In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import one_hot
import string

# Hyperparameters (module-level; read by the data pipeline, model and training loop)
batch_size = 64        # samples per optimizer step (DataLoader batch size)
seq_len = 50           # characters per training window
latent_dim = 128       # dimensionality of the VAE latent vector
learning_rate = 0.001  # Adam step size
num_epochs = 20        # full passes over the dataset

# Text dataset class (character-level)
class TextDataset(Dataset):
    """Character-level next-character prediction dataset over one string.

    Sample ``idx`` pairs the window ``text[idx:idx + seq_len]``, one-hot
    encoded as a float tensor of shape (seq_len, vocab_size), with the same
    window shifted one character to the right, given as a long tensor of
    class indices of shape (seq_len,).

    Args:
        text: the full training corpus.
        seq_len: number of characters per input window.
    """

    def __init__(self, text, seq_len):
        self.text = text
        self.seq_len = seq_len
        # Sorted rather than list(set(...)): str hashing is randomized per
        # process (PYTHONHASHSEED), so set iteration order -- and with it every
        # character index -- would otherwise change between runs and make
        # saved weights incompatible with a freshly built vocabulary.
        self.vocab = sorted(set(text))
        self.vocab_size = len(self.vocab)
        self.char_to_idx = {ch: idx for idx, ch in enumerate(self.vocab)}
        self.idx_to_char = {idx: ch for idx, ch in enumerate(self.vocab)}
        # Encode the corpus once so __getitem__ only slices a tensor.
        self._encoded = torch.tensor(
            [self.char_to_idx[ch] for ch in text], dtype=torch.long
        )

    def __len__(self):
        # A sample needs seq_len + 1 characters (input plus one-step-shifted
        # target). Clamp so a too-short text gives an empty dataset instead of
        # a negative length.
        return max(0, len(self.text) - self.seq_len)

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        window = self._encoded[idx:idx + self.seq_len + 1]
        # Inputs are one-hot floats; targets stay integer class indices.
        input_seq = one_hot(window[:-1], num_classes=self.vocab_size).float()
        target_seq = window[1:].clone()  # copy so callers cannot mutate the cache
        return input_seq, target_seq

# VAE Model
class VAE(nn.Module):
    """Sequence VAE with an LSTM encoder and an LSTM decoder.

    The encoder LSTM reads a one-hot sequence. Its final hidden state is
    projected to the mean and log-variance of a diagonal Gaussian over the
    latent vector ``z``. The decoder LSTM receives ``z`` repeated at every
    time step and projects each hidden state to vocabulary logits.

    Args:
        vocab_size: one-hot input width and number of output logits.
        hidden_size: hidden width of both LSTMs.
        latent_dim: dimensionality of ``z``.
        seq_len: default number of positions ``decode`` emits. ``None`` (the
            default) keeps the original behavior of reading the module-level
            ``seq_len`` hyperparameter at call time.
    """

    def __init__(self, vocab_size, hidden_size, latent_dim, seq_len=None):
        super().__init__()

        self.hidden_size = hidden_size
        self.latent_dim = latent_dim
        self.seq_len = seq_len

        # Encoder
        self.encoder_rnn = nn.LSTM(input_size=vocab_size, hidden_size=hidden_size, batch_first=True)
        self.fc_mu = nn.Linear(hidden_size, latent_dim)
        self.fc_logvar = nn.Linear(hidden_size, latent_dim)

        # Decoder
        self.decoder_rnn = nn.LSTM(input_size=latent_dim, hidden_size=hidden_size, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, vocab_size)

    def encode(self, x):
        """Map ``x`` (batch, time, vocab_size) to posterior ``(mu, logvar)``, each (batch, latent_dim)."""
        _, (h, _) = self.encoder_rnn(x)
        # For this unidirectional LSTM, h_n is (num_layers, batch, hidden).
        # [-1] picks the top layer's final state and, unlike squeeze(0), stays
        # correct if more layers are ever stacked.
        h = h[-1]
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``, differentiably in mu and logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, length=None):
        """Unroll latent ``z`` (batch, latent_dim) into logits (batch, length, vocab_size).

        ``length`` defaults to ``self.seq_len``. If that is also ``None``, the
        module-level ``seq_len`` hyperparameter is used (original behavior).
        """
        if length is None:
            length = self.seq_len
        if length is None:
            length = seq_len  # module-level hyperparameter fallback
        # The decoder sees the same latent vector at every time step.
        steps = z.unsqueeze(1).repeat(1, length, 1)
        out, _ = self.decoder_rnn(steps)
        return self.fc_out(out)

    def forward(self, x, length=None):
        """Encode ``x``, sample ``z`` and decode it.

        Returns ``(logits, mu, logvar)``. ``length`` is forwarded to ``decode``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z, length)
        return out, mu, logvar

# Loss function: Reconstruction loss + KL divergence
def loss_function(recon_x, x, mu, logvar, kl_weight=1.0):
    """Return the VAE training loss: token cross-entropy plus weighted KL divergence.

    Args:
        recon_x: decoder logits whose last dimension is the vocabulary.
        x: integer class targets with one entry per logit row, i.e. the
            shape ``recon_x.shape[:-1]``.
        mu: posterior mean of the latent Gaussian.
        logvar: posterior log-variance of the latent Gaussian.
        kl_weight: multiplier on the KL term; 1.0 reproduces the original loss.

    Both terms are means. The cross-entropy is averaged over every token, and
    the KL over every batch element *and* latent dimension.

    Returns:
        A scalar loss tensor.
    """
    num_classes = recon_x.size(-1)
    # reshape (not view) so non-contiguous logits/targets are accepted too.
    recon_loss = nn.functional.cross_entropy(
        recon_x.reshape(-1, num_classes), x.reshape(-1)
    )
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), averaged element-wise.
    kl_div = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kl_div

# Training the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load text data. The context manager closes the file handle, and the explicit
# encoding keeps the character vocabulary independent of the platform locale.
with open('the-verdict.txt', 'r', encoding='utf-8') as text_file:  # Replace with your text file
    text = text_file.read()
dataset = TextDataset(text, seq_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, optimizer
model = VAE(vocab_size=dataset.vocab_size, hidden_size=256, latent_dim=latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: one optimizer step per batch; prints the mean batch loss per epoch.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, target, mu, logvar)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError when the loader yields no batches.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss/max(len(dataloader), 1):.4f}")

# Sampling new text
def sample_text(model, start_char, length=100, *, temperature=None,
                context_len=None, vocab_source=None):
    """Generate ``length`` characters from ``model`` following a prompt.

    At each step the most recent characters (up to ``context_len``) are
    one-hot encoded and passed through ``model``. The logits at the position
    of the last context character select the next character. Feeding the
    growing context, rather than only the previous character, keeps the
    encoder's input close to the fixed-length windows it was trained on.

    Args:
        model: module returning ``(logits, mu, logvar)`` with logits shaped
            (batch, positions, vocab), where output position t predicts the
            character after input position t.
        start_char: the prompt; a single character or a longer string.
        length: number of characters to generate after the prompt.
        temperature: ``None`` for greedy argmax decoding (the original
            behavior). A positive float samples from
            ``softmax(logits / temperature)`` instead, which avoids the
            single-character loops greedy decoding can fall into.
        context_len: maximum number of trailing characters fed to the model.
            ``None`` uses the number of positions the model emits on its
            first call.
        vocab_source: object exposing ``char_to_idx``, ``idx_to_char`` and
            ``vocab_size``. Defaults to the module-level ``dataset``.

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: if ``start_char`` is empty, ``temperature`` is not
            positive, or ``context_len`` is less than 1.
        KeyError: if the prompt contains a character outside the vocabulary.
    """
    vocab = dataset if vocab_source is None else vocab_source
    if not start_char:
        raise ValueError("start_char must be a non-empty string")
    if temperature is not None and temperature <= 0:
        raise ValueError("temperature must be positive")
    if context_len is not None and context_len < 1:
        raise ValueError("context_len must be at least 1")

    # Run on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    ids = [vocab.char_to_idx[ch] for ch in start_char]
    window = context_len
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(length):
                context = ids if window is None else ids[-window:]
                inputs = one_hot(
                    torch.tensor([context], device=model_device),
                    num_classes=vocab.vocab_size,
                ).float()
                logits, _, _ = model(inputs)
                if window is None:
                    # The decoder emits a fixed number of positions; reuse it as the window.
                    window = logits.size(1)
                # Output position t predicts the character after input position t.
                pos = min(len(context), logits.size(1)) - 1
                step_logits = logits[0, pos]
                if temperature is None:
                    next_idx = int(step_logits.argmax())
                else:
                    probs = torch.softmax(step_logits / temperature, dim=-1)
                    next_idx = int(torch.multinomial(probs, 1))
                ids.append(next_idx)
    finally:
        # Restore the caller's train/eval mode.
        model.train(was_training)

    return "".join(vocab.idx_to_char[i] for i in ids)

# Generate 200 characters starting from 'T' and print them
print(sample_text(model, 'T', length=200))


Epoch [1/20], Loss: 3.1706
Epoch [2/20], Loss: 3.1162
Epoch [3/20], Loss: 3.1072
Epoch [4/20], Loss: 3.0981
Epoch [5/20], Loss: 3.0971
Epoch [6/20], Loss: 3.0958
Epoch [7/20], Loss: 3.0944
Epoch [8/20], Loss: 3.0907
Epoch [9/20], Loss: 3.0860
Epoch [10/20], Loss: 3.0827
Epoch [11/20], Loss: 3.0801
Epoch [12/20], Loss: 3.0781
Epoch [13/20], Loss: 3.0754
Epoch [14/20], Loss: 3.0734
Epoch [15/20], Loss: 3.0721
Epoch [16/20], Loss: 3.0708
Epoch [17/20], Loss: 3.0690
Epoch [18/20], Loss: 3.0670
Epoch [19/20], Loss: 3.0626
Epoch [20/20], Loss: 3.0547
T                                                                                                                                                                                                        
