In [None]:
#import packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

In [None]:
# Toy corpus of short, lower-case, whitespace-tokenizable sentences.
# Replace with a larger dataset for anything beyond a smoke test.
sentences = list((
    "i love machine learning",
    "deep learning is powerful",
    "nlp is fun",
    "i enjoy coding",
    "language models are awesome",
    "i like natural language processing",
))

In [None]:
# 1. Vocabulary & Tokenization
class Vocab:
    """Whitespace-token vocabulary with reserved ids 0 = <pad> and 1 = <unk>.

    Corpus words get ids from 2 upward in sorted order, so the mapping is the same
    on every run. Iterating a raw ``set`` of strings is not stable across processes,
    because Python randomizes string hashes.
    """

    def __init__(self, texts):
        # Sorting makes the ids reproducible across kernel restarts.
        tokens = sorted({word for sent in texts for word in sent.split()})
        self.word2idx = {w: i + 2 for i, w in enumerate(tokens)}
        self.word2idx['<pad>'] = 0
        self.word2idx['<unk>'] = 1
        self.idx2word = {i: w for w, i in self.word2idx.items()}
        self.vocab_size = len(self.word2idx)
        print(f"Vocabulary size: {self.vocab_size}")

    def encode(self, sent):
        """Map a whitespace-split sentence to ids; unknown words become 1 (<unk>)."""
        return [self.word2idx.get(word, 1) for word in sent.split()]

    def decode(self, ids):
        """Join the words for ``ids``, dropping <pad> (0).

        Ids missing from the vocabulary render as ``<unk>`` instead of raising KeyError.
        """
        return ' '.join(self.idx2word.get(i, '<unk>') for i in ids if i != 0)

vocab = Vocab(sentences)
# Show the actual token -> id mapping. Printing the object itself only gives its
# default repr (class name and memory address), which tells the reader nothing.
print(vocab.word2idx)

Vocabulary size: 20
<__main__.Vocab object at 0x78c58741ec90>


In [None]:
# 2. Dataset
class TextDataset(Dataset):
    """Fixed-length token-id sequences for the VAE.

    Each sentence is encoded with ``vocab.encode``, truncated to ``max_len`` ids and
    right-padded with 0 (<pad>), so every item has exactly ``max_len`` elements.
    Equal lengths are required by the DataLoader's default collate (``torch.stack``).
    Without truncation, a sentence longer than ``max_len`` would break batching.
    """

    def __init__(self, sentences, vocab, max_len=8):
        encoded = (vocab.encode(s)[:max_len] for s in sentences)
        self.data = [ids + [0] * (max_len - len(ids)) for ids in encoded]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.long)

dataset = TextDataset(sentences, vocab)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# Summarize the dataset rather than printing its default object repr.
print(f"Dataset: {len(dataset)} padded sequences")

<__main__.TextDataset object at 0x78c587305f50>


In [None]:
# 3. VAE Model
class TextVAE(nn.Module):
    """GRU-encoder / GRU-decoder variational autoencoder over token-id sequences.

    The encoder's last hidden state is projected to the mean and log-variance of a
    diagonal Gaussian. A latent ``z`` is mapped to the decoder's initial hidden state.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: token embedding size, shared by encoder and decoder.
        hidden_dim: GRU hidden size.
        latent_dim: size of ``z``.
        pad_idx: id used for decoder input positions that carry no token, and skipped
            in generated output. 0 is ``<pad>`` in this notebook's ``Vocab``.
    """

    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128, latent_dim=32, pad_idx=0):
        super(TextVAE, self).__init__()
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder_rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.hidden2mean = nn.Linear(hidden_dim, latent_dim)
        self.hidden2logv = nn.Linear(hidden_dim, latent_dim)
        self.latent2hidden = nn.Linear(latent_dim, hidden_dim)
        self.decoder_rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.outputs2vocab = nn.Linear(hidden_dim, vocab_size)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a (batch, seq_len) LongTensor of ids."""
        emb = self.embedding(x)
        _, h = self.encoder_rnn(emb)  # h: (1, batch, hidden_dim) for this 1-layer GRU
        h = h.squeeze(0)
        return self.hidden2mean(h), self.hidden2logv(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _random_start_idx(self):
        """Return a random word id, skipping 0/1 (``<pad>``/``<unk>`` in this notebook's Vocab)."""
        return random.randint(2, self.vocab_size - 1)

    def decode_stepwise(self, z, max_len=15, temperature=0.8, start_token="i", word2idx=None):
        """Autoregressively sample up to ``max_len`` token ids from latent ``z``.

        Args:
            z: latent of shape (1, latent_dim). Only batch size 1 is supported, because
                each sampled id is read back with ``.item()``.
            max_len: number of decoding steps. The result can be shorter (see Returns).
            temperature: softmax temperature; < 1 sharpens, > 1 flattens.
            start_token: first decoder input, given either as a word or an integer id.
                A word missing from the vocabulary falls back to a random word id.
            word2idx: word -> id mapping used to look up a string ``start_token``.
                Defaults to the notebook-level ``vocab.word2idx``.

        Returns:
            A list[int] of sampled ids. The start token is excluded, and pad ids and
            immediate repeats are dropped.
        """
        h = self.latent2hidden(z).unsqueeze(0)
        if isinstance(start_token, int):
            start_idx = start_token
        else:
            if word2idx is None:
                word2idx = vocab.word2idx  # notebook-level Vocab instance
            start_idx = word2idx.get(start_token)
            if start_idx is None:
                # Drawn lazily, so RNG state is consumed only when the fallback is used.
                start_idx = self._random_start_idx()
        inputs = torch.tensor([[start_idx]], dtype=torch.long, device=z.device)

        outputs = []
        prev_token = None

        for _ in range(max_len):
            emb = self.embedding(inputs)
            out, h = self.decoder_rnn(emb, h)
            logits = self.outputs2vocab(out[:, -1, :])
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
            next_token_id = next_token.item()

            # Skip pad and immediate repeats in the output. The sampled id is still
            # fed back as the next input either way.
            if next_token_id != prev_token and next_token_id != self.pad_idx:
                outputs.append(next_token_id)
                prev_token = next_token_id

            inputs = next_token

        return outputs

    def forward(self, x):
        """Encode ``x``, sample ``z`` and return ``(logits, mu, logvar)``.

        ``x`` is a (batch, seq_len) LongTensor, and the logits are
        (batch, seq_len, vocab_size). The decoder is teacher-forced on ``x``, so during
        training it sees real previous tokens, as ``decode_stepwise`` does when generating.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode_from_z(z, x.size(1), targets=x)
        return x_recon, mu, logvar

    def decode_from_z(self, z, seq_len, targets=None):
        """Run the decoder for ``seq_len`` steps in one pass and return the logits.

        Step 0 is fed a random word id (one draw shared by the batch), matching
        ``decode_stepwise``, which starts from an arbitrary word.

        Args:
            z: latent batch of shape (batch, latent_dim).
            seq_len: number of decoder steps.
            targets: optional (batch, seq_len) ids. When given, step t >= 1 is fed
                ``targets[:, t - 1]`` (teacher forcing). Otherwise those inputs are ``pad_idx``.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        h = self.latent2hidden(z).unsqueeze(0)
        inputs = torch.full((z.size(0), seq_len), self.pad_idx, dtype=torch.long, device=z.device)
        inputs[:, 0] = self._random_start_idx()
        if targets is not None:
            # Teacher forcing: step t sees the true token t-1 and predicts token t.
            inputs[:, 1:] = targets[:, :seq_len - 1]
        emb = self.embedding(inputs)
        out, _ = self.decoder_rnn(emb, h)
        return self.outputs2vocab(out)

In [None]:
# 4. Loss Function
def vae_loss(recon_x, x, mu, logvar, kl_weight=1.0, pad_idx=0):
    """Token cross-entropy plus a weighted KL divergence to N(0, I).

    Args:
        recon_x: logits of shape (batch, seq_len, vocab_size).
        x: target ids of shape (batch, seq_len). Positions equal to ``pad_idx`` are ignored.
        mu: posterior mean, shape (batch, latent_dim).
        logvar: posterior log-variance, shape (batch, latent_dim).
        kl_weight: multiplier on the KL term (e.g. for KL annealing). 1.0 gives the plain sum.
        pad_idx: target id excluded from the reconstruction loss.

    Returns:
        A scalar tensor: mean per-token cross-entropy + kl_weight * per-sample KL.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    recon_loss = F.cross_entropy(
        recon_x.reshape(-1, recon_x.size(-1)), x.reshape(-1), ignore_index=pad_idx
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims and averaged over the batch.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + kl_weight * kl_div

In [None]:
# 5. Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TextVAE(vocab_size=vocab.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 20
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        loss = vae_loss(recon, batch, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean batch loss. A raw sum grows with the number of batches and is
    # not comparable across dataset or batch sizes.
    print(f"Epoch {epoch+1}, Loss: {total_loss / max(len(dataloader), 1):.4f}")

Epoch 1, Loss: 10.5910
Epoch 2, Loss: 9.3555
Epoch 3, Loss: 9.0976
Epoch 4, Loss: 8.9725
Epoch 5, Loss: 8.7550
Epoch 6, Loss: 8.5339
Epoch 7, Loss: 8.4723
Epoch 8, Loss: 8.5010
Epoch 9, Loss: 8.4817
Epoch 10, Loss: 8.3999
Epoch 11, Loss: 8.3621
Epoch 12, Loss: 8.3269
Epoch 13, Loss: 8.2855
Epoch 14, Loss: 7.9773
Epoch 15, Loss: 8.1695
Epoch 16, Loss: 8.0114
Epoch 17, Loss: 7.9640
Epoch 18, Loss: 8.0304
Epoch 19, Loss: 8.0251
Epoch 20, Loss: 7.7483


In [None]:
# 6. Text Generation
model.eval()
with torch.no_grad():
    # Sample from the prior using the model's own latent size instead of a hard-coded 32.
    latent_dim = model.hidden2mean.out_features
    z = torch.randn(1, latent_dim, device=device)
    print(z)
    output_ids = model.decode_stepwise(z, max_len=12, temperature=0.8, start_token="language")
    print("\n📝 Generated Sentence:\n", vocab.decode(output_ids))


tensor([[ 1.9235,  0.4659,  1.2726, -0.4064,  0.2897, -0.9690,  0.0390, -1.3075,
         -0.0789,  1.8578, -0.7934, -1.2537,  0.9834,  0.1393, -0.9805, -0.6045,
          1.4312,  0.0861,  1.0327, -0.1965, -1.7563, -0.4395,  0.2122, -1.0704,
         -0.3338, -0.8653, -0.4654, -0.9062,  0.4435,  1.6028,  1.1854,  1.7066]])

📝 Generated Sentence:
 processing are awesome like fun <unk> like fun powerful models i are


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Toy training corpus for the sentence VAE (whitespace-tokenized, lower-case).
sentences = list((
    "i enjoy working on ai",
    "machine learning is fun",
    "deep learning helps a lot",
    "natural language is powerful",
    "models can learn patterns",
))

# Vocabulary preparation
class Vocab:
    """Word-level vocabulary with <pad>, <sos>, <eos> and <unk> special tokens.

    Ids are assigned over the sorted token set, so the mapping is deterministic. The
    ids of the specials depend on that sort order (<pad> is generally NOT 0), so
    always use ``self.pad`` / ``self.sos`` / ``self.eos`` / ``self.unk``.
    """

    def __init__(self, sentences):
        tokens = {"<pad>", "<sos>", "<eos>", "<unk>"}
        for sentence in sentences:
            tokens.update(sentence.split())
        self.word2idx = {w: i for i, w in enumerate(sorted(tokens))}
        self.idx2word = {i: w for w, i in self.word2idx.items()}
        self.pad = self.word2idx["<pad>"]
        self.sos = self.word2idx["<sos>"]
        self.eos = self.word2idx["<eos>"]
        self.unk = self.word2idx["<unk>"]
        self.vocab_size = len(self.word2idx)

    def encode(self, sentence, max_len=10):
        """Encode to exactly ``max_len`` ids: <sos>, word ids, <eos>, then <pad> fill.

        Over-long sentences are truncated *before* <eos> is appended, so the end
        marker survives whenever ``max_len >= 2``. Unknown words map to <unk>.
        """
        words = sentence.split()[:max(max_len - 2, 0)]
        tokens = [self.sos] + [self.word2idx.get(w, self.unk) for w in words] + [self.eos]
        tokens += [self.pad] * (max_len - len(tokens))
        return tokens[:max_len]

    def decode(self, indices):
        """Turn ids back into text, stopping at the first <eos>.

        <pad> and <sos> are skipped. Ids not in the vocabulary render as "<unk>".
        Stopping at <eos> keeps tokens generated after the end of the sentence out
        of the result.
        """
        words = []
        for i in indices:
            w = self.idx2word.get(i, "<unk>")
            if w == "<eos>":
                break
            if w in {"<pad>", "<sos>"}:
                continue
            words.append(w)
        return " ".join(words)

vocab = Vocab(sentences=sentences)  # shared by the dataset, model, loss and helpers below

# Dataset
class SentenceDataset(Dataset):
    """Autoencoding dataset: each item is an (input, target) pair of the same id tensor."""

    def __init__(self, sentences, vocab):
        self.data = list(map(vocab.encode, sentences))
        self.vocab = vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = torch.tensor(self.data[idx])
        return ids, ids  # input, target

dataset = SentenceDataset(sentences, vocab)
# Tiny batches, reshuffled each epoch.
loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

# VAE Modules
class Encoder(nn.Module):
    """GRU encoder returning a reparameterized latent sample and its Gaussian parameters."""

    def __init__(self, vocab_size, emb_dim=64, hidden_dim=128, latent_dim=32):
        super().__init__()
        # Submodules are created in a fixed order, which keeps seeded initialisation reproducible.
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.hidden2mean = nn.Linear(hidden_dim, latent_dim)
        self.hidden2logv = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return ``(z, mean, logv)`` for a (batch, seq_len) batch of ids.

        ``z = mean + exp(logv / 2) * eps`` with ``eps ~ N(0, I)``. A sample is drawn
        in both train and eval mode.
        """
        _, last_hidden = self.rnn(self.embed(x))
        summary = last_hidden.squeeze(0)
        mean, logv = self.hidden2mean(summary), self.hidden2logv(summary)
        std = torch.exp(0.5 * logv)
        noise = torch.randn_like(std)
        z = mean + std * noise
        return z, mean, logv

class Decoder(nn.Module):
    """GRU decoder that unrolls token predictions from a latent vector.

    Args:
        vocab_size: number of output classes and embedding rows.
        emb_dim, latent_dim, hidden_dim: layer sizes.
        max_len: number of steps generated when no ``targets`` are given.
        sos_idx: id fed at the first step. When None, the notebook-level
            ``vocab.sos`` is looked up at call time.
    """

    def __init__(self, vocab_size, emb_dim=64, latent_dim=32, hidden_dim=128, max_len=10, sos_idx=None):
        super().__init__()
        self.latent2hidden = nn.Linear(latent_dim, hidden_dim)
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.max_len = max_len
        self.sos_idx = sos_idx

    def forward(self, z, targets=None, teacher_forcing_ratio=0.5):
        """Return logits of shape (batch, steps, vocab_size).

        ``steps`` is ``targets.size(1)`` when targets are given, so the logits always
        line up with the targets in the loss. Otherwise it is ``self.max_len``. At each
        step, the next input is the ground-truth token with probability
        ``teacher_forcing_ratio`` (one draw per step, shared by the whole batch).
        Otherwise it is the greedy (argmax) prediction.
        """
        batch_size = z.size(0)
        hidden = self.latent2hidden(z).unsqueeze(0)
        sos = self.sos_idx if self.sos_idx is not None else vocab.sos
        input_token = torch.full((batch_size, 1), sos, dtype=torch.long, device=z.device)
        steps = self.max_len if targets is None else targets.size(1)
        outputs = []

        for t in range(steps):
            emb = self.embed(input_token)
            output, hidden = self.rnn(emb, hidden)
            logits = self.fc(output)
            outputs.append(logits)
            use_target = targets is not None and torch.rand(1).item() <= teacher_forcing_ratio
            input_token = targets[:, t].unsqueeze(1) if use_target else logits.argmax(2)

        return torch.cat(outputs, dim=1)

# VAE Model
class SentenceVAE(nn.Module):
    """Sentence VAE composed of an ``Encoder`` and a ``Decoder`` over the same vocabulary size."""

    def __init__(self, vocab_size):
        super().__init__()
        # The encoder is built before the decoder, which keeps seeded initialisation reproducible.
        self.encoder = Encoder(vocab_size)
        self.decoder = Decoder(vocab_size)

    def forward(self, x):
        """Encode ``x``, decode with ``targets=x`` (stochastic teacher forcing); return (logits, mean, logv)."""
        latent, mean, logv = self.encoder(x)
        return self.decoder(latent, targets=x), mean, logv

# Loss Function
def vae_loss(recon, target, mean, logv, kl_weight=1.0, pad_idx=None):
    """Token cross-entropy (ignoring padding) plus a weighted KL(q(z|x) || N(0, I)).

    Args:
        recon: logits of shape (batch, seq_len, vocab_size).
        target: ids of shape (batch, seq_len).
        mean: posterior mean, shape (batch, latent_dim).
        logv: posterior log-variance, shape (batch, latent_dim).
        kl_weight: weight on the KL term (1.0 = unweighted).
        pad_idx: target id ignored by the cross-entropy. Defaults to the
            notebook-level ``vocab.pad``.

    Returns:
        A scalar tensor: mean per-token CE + kl_weight * per-sample KL.
    """
    if pad_idx is None:
        pad_idx = vocab.pad
    # Take the class count from the logits themselves instead of a global vocabulary,
    # and use reshape so non-contiguous inputs are accepted.
    recon = recon.reshape(-1, recon.size(-1))
    target = target.reshape(-1)
    CE = F.cross_entropy(recon, target, ignore_index=pad_idx)
    KL = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp()) / mean.size(0)
    return CE + kl_weight * KL

# Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SentenceVAE(vocab.vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
MAX_GRAD_NORM = 1.0  # cap the global gradient norm to stabilise GRU training

for epoch in range(50):
    model.train()
    total_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        recon, mean, logv = model(x)
        loss = vae_loss(recon, y, mean, logv)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        # Mean batch loss, independent of how many batches the loader yields.
        print(f"Epoch {epoch+1}, Loss: {total_loss / max(len(loader), 1):.4f}")

# Sentence Generation
def reconstruct(sentence, sample=True):
    """Encode ``sentence`` with the global ``model`` and greedily decode it back to text.

    Args:
        sentence: whitespace-tokenized input text.
        sample: if True (the default), decode from the sampled latent ``z``, so repeated
            calls can give different rewrites. If False, decode from the posterior
            mean for a deterministic reconstruction.

    Returns:
        The decoded string.

    Note: switches the global ``model`` to eval mode and leaves it there.
    """
    model.eval()
    with torch.no_grad():
        x = torch.tensor([vocab.encode(sentence)]).to(device)
        z, mean, _ = model.encoder(x)
        outputs = model.decoder(z if sample else mean)
        tokens = outputs.argmax(2).squeeze(0).tolist()
        return vocab.decode(tokens)

# Reconstruct each training sentence and show it next to the original.
print("\n--- Sentence Reconstruction ---")
for original in sentences:
    print(f"Input     : {original}")
    rewritten = reconstruct(original)
    print(f"Rewritten : {rewritten}\n")


Epoch 10, Loss: 4.1399
Epoch 20, Loss: 2.8232
Epoch 30, Loss: 2.0676
Epoch 40, Loss: 2.4734
Epoch 50, Loss: 3.9451

--- Sentence Reconstruction ---
Input     : i enjoy working on ai
Rewritten : deep learning helps a lot

Input     : machine learning is fun
Rewritten : deep learning helps a lot

Input     : deep learning helps a lot
Rewritten : models can learn patterns

Input     : natural language is powerful
Rewritten : machine learning helps a lot

Input     : models can learn patterns
Rewritten : models can learn patterns

