<a href="https://colab.research.google.com/github/ayyucedemirbas/Denoising_Autoencoder_for_Text_Generation/blob/main/char_level_tokenization_Transformer_Based_Denoising_Autoencoder_for_Text_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import requests
import string
import re
from collections import Counter

In [None]:
# Download the Tiny Shakespeare corpus that serves as the training text.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# Bound the wait so a stalled connection cannot hang the cell, and fail fast on
# HTTP errors: without the status check an error page would silently be used
# as the training corpus.
response = requests.get(url, timeout=30)
response.raise_for_status()
text = response.text

In [None]:
def preprocess_text(text, min_char_freq=3):
    """Turn raw text into a sequence of character ids plus its vocabulary.

    Line endings are normalised to ``\\n`` first. The vocabulary starts with
    the special tokens ``<pad>``, ``<unk>`` and ``<mask>`` (ids 0, 1, 2),
    followed by every character that occurs at least ``min_char_freq`` times,
    in order of first appearance. Rarer characters are encoded as ``<unk>``.

    Args:
        text: The raw corpus as a single string.
        min_char_freq: Minimum number of occurrences for a character to get
            its own vocabulary entry.

    Returns:
        A tuple ``(data, char2idx, idx2char, vocab_size)`` where ``data`` is
        the list of ids for every character of the normalised text.
    """
    normalized = text.replace('\r\n', '\n').replace('\r', '\n')

    frequencies = Counter(normalized)
    frequent_chars = [ch for ch, seen in frequencies.items() if seen >= min_char_freq]
    vocab = ['<pad>', '<unk>', '<mask>'] + frequent_chars

    idx2char = dict(enumerate(vocab))
    char2idx = {symbol: index for index, symbol in idx2char.items()}

    # Anything that did not make it into the vocabulary falls back to <unk>.
    unk_id = char2idx['<unk>']
    data = [char2idx.get(ch, unk_id) for ch in normalized]

    return data, char2idx, idx2char, len(vocab)

# Encode the downloaded corpus. char2idx, idx2char and vocab_size are module-level
# names that later cells (e.g. add_noise and the model construction) read directly.
data, char2idx, idx2char, vocab_size = preprocess_text(text)

In [None]:
# Hyperparameters and runtime configuration shared by the cells below.
seq_length = 64  # width (in characters) of each slice taken from data_tensor per training step
batch_size = 32  # number of rows the encoded corpus is folded into by create_batches
embed_dim = 256  # embedding width passed to DenoisingTransformer
num_heads = 4  # attention heads per TransformerBlock
ff_dim = 512  # hidden width of each block's feed-forward sub-layer
num_layers = 3  # number of stacked TransformerBlocks
noise_prob = 0.3  # base per-character replacement probability read by add_noise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU when one is available

In [None]:
def create_batches(data, batch_size, seq_length):
    """Fold a flat list of token ids into a ``(batch_size, -1)`` tensor.

    The tail of ``data`` that does not fill a whole multiple of
    ``batch_size * seq_length`` tokens is dropped, so every row can later be
    cut into an integer number of ``seq_length``-wide slices.

    Args:
        data: Flat sequence of integer token ids.
        batch_size: Number of rows in the returned tensor.
        seq_length: Length of the slices the caller will take from each row.

    Returns:
        A tensor with ``batch_size`` rows holding the truncated data in order.
    """
    tokens_per_step = batch_size * seq_length
    usable_length = (len(data) // tokens_per_step) * tokens_per_step
    trimmed = torch.tensor(data[:usable_length])
    return trimmed.view(batch_size, -1)

def add_noise(batch):
    """Return a copy of ``batch`` with some token ids replaced at random.

    Each position is selected for replacement with probability ``noise_prob``
    (a module-level setting). Positions holding a newline or an ASCII
    punctuation character only stay selected if a second random draw falls
    below 0.1, so they are corrupted far less often. Replacement ids are
    drawn uniformly from ``[3, vocab_size)``, i.e. never one of the three
    special tokens.

    Args:
        batch: Integer tensor of token ids; it is not modified.

    Returns:
        A new tensor of the same shape and device with the noise applied.
    """
    target_device = batch.device
    corrupted = batch.clone()

    # Draw the selection mask first, then the replacement ids.
    replace_here = torch.rand_like(corrupted.float(), device=target_device) < noise_prob
    substitutes = torch.randint(3, vocab_size, corrupted.shape, device=target_device)

    # Newlines and punctuation carry structure, so they are mostly protected.
    punctuation_ids = torch.tensor(
        [char2idx[symbol] for symbol in string.punctuation if symbol in char2idx],
        device=target_device
    )
    is_newline = corrupted == char2idx.get('\n', -1)
    is_punctuation = torch.isin(corrupted, punctuation_ids)
    protected = is_newline | is_punctuation

    # A protected position stays selected only if this second draw is < 0.1.
    survives = torch.rand_like(replace_here.float()) < 0.1
    replace_here[protected] &= survives[protected]

    corrupted[replace_here] = substitutes[replace_here]
    return corrupted

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to an embedded sequence.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as
    the non-trainable buffer ``pe``. Even feature columns hold sines and odd
    columns hold cosines of the position scaled by geometrically spaced
    frequencies.

    Args:
        d_model: Size of the feature dimension. Both even and odd values are
            supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim div_term to match; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``.

        The sequence axis of ``x`` is dimension 1; the table is broadcast over
        dimension 0.
        """
        return x + self.pe[:, :x.size(1)]

class TransformerBlock(nn.Module):
    """Post-norm encoder block: self-attention followed by a feed-forward net.

    Inputs and outputs are laid out as ``(batch, seq, embed_dim)``.

    Args:
        embed_dim: Feature width of the input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        # batch_first=True is required because the model feeds tensors shaped
        # (batch, seq, embed). With the default (batch_first=False) the module
        # treats dimension 0 as the sequence axis, so attention would mix
        # different samples of the batch instead of positions within a sample.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply self-attention and the feed-forward net, each with a residual + LayerNorm."""
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)
        ff_output = self.ff(x)
        return self.norm2(x + ff_output)

class DenoisingTransformer(nn.Module):
    """Maps a tensor of (possibly corrupted) token ids to per-token vocabulary logits.

    The pipeline is: token embedding, positional encoding, a stack of
    ``num_layers`` ``TransformerBlock`` modules, and a final linear projection
    whose output size is ``vocab_size``.

    Args:
        vocab_size: Number of distinct token ids (embedding rows / logits).
        embed_dim: Embedding and hidden width.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward hidden width per block.
        num_layers: Number of stacked blocks.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)

        encoder_stack = []
        for _ in range(num_layers):
            encoder_stack.append(TransformerBlock(embed_dim, num_heads, ff_dim))
        self.transformer_blocks = nn.ModuleList(encoder_stack)

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return logits whose last dimension has size ``vocab_size``."""
        hidden = self.positional_encoding(self.embedding(x))
        for layer in self.transformer_blocks:
            hidden = layer(hidden)
        logits = self.fc(hidden)
        return logits

In [None]:
# Build the model on the selected device together with its optimizer and loss.
model = DenoisingTransformer(vocab_size, embed_dim, num_heads, ff_dim, num_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)
# Targets equal to the <pad> id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=char2idx['<pad>'])

# Fold the encoded corpus into batch_size rows; each training step then reads one
# seq_length-wide column slice, so num_batches is the number of such slices per row.
data_tensor = create_batches(data, batch_size, seq_length)
num_batches = data_tensor.size(1) // seq_length

In [None]:
# Fixed probe sentence used to eyeball reconstruction quality every 10 epochs.
# It never changes, so it is encoded once instead of inside the epoch loop.
test_text = "First Citizen:\nWe are accounted poor citizens..."
test_chars = list(test_text)
test_data = [char2idx.get(c, char2idx['<unk>']) for c in test_chars]
test_input = torch.tensor(test_data[:seq_length]).unsqueeze(0).to(device)

for epoch in range(300):
    model.train()
    total_loss = 0
    for i in range(num_batches):
        # Column slice of the folded corpus: one seq_length window per row.
        inputs = data_tensor[:, i*seq_length:(i+1)*seq_length].to(device)
        noisy_inputs = add_noise(inputs)

        optimizer.zero_grad()
        outputs = model(noisy_inputs)
        # reshape instead of view: `inputs` is a strided column slice, and when
        # running on CPU `.to(device)` returns it unchanged (non-contiguous),
        # which makes `.view(-1)` raise. reshape copies only when it has to.
        loss = criterion(outputs.reshape(-1, vocab_size), inputs.reshape(-1))
        loss.backward()
        # Clip the global gradient norm to keep updates bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()

    if (epoch+1) % 10 == 0:
        # Fresh noise each time so the probe shows different corruptions.
        noisy_test = add_noise(test_input)

        model.eval()
        with torch.no_grad():
            prediction = model(noisy_test).argmax(-1)
            restored = ''.join([idx2char[idx.item()] for idx in prediction[0]])

        print(f"\nEpoch {epoch+1}")
        print("Noisy input:", ''.join([idx2char.get(idx.item(), '?') for idx in noisy_test[0]]))
        print("Restored:", restored)

    print(f"Epoch {epoch+1}, Loss: {total_loss/num_batches:.4f}")


Epoch 1, Loss: 1.3368
Epoch 2, Loss: 1.2851
Epoch 3, Loss: 1.2847
Epoch 4, Loss: 1.2857
Epoch 5, Loss: 1.2805
Epoch 6, Loss: 1.2851
Epoch 7, Loss: 1.2789
Epoch 8, Loss: 1.2810
Epoch 9, Loss: 1.2817

Epoch 10
Noisy input: FiEatvCitsren:
WeMaaeMacJou ted pooG .idizenX...
Restored: FiEatv itsren:
We aae ac ou ted poo  .idi en ...
Epoch 10, Loss: 1.2811
Epoch 11, Loss: 1.2816
Epoch 12, Loss: 1.2825
Epoch 13, Loss: 1.2810
Epoch 14, Loss: 1.2833
Epoch 15, Loss: 1.2792
Epoch 16, Loss: 1.2772
Epoch 17, Loss: 1.2843
Epoch 18, Loss: 1.2791
Epoch 19, Loss: 1.2810

Epoch 20
Noisy input: Bjlst Cik3zEn:
Ge pri acc!uFted ;oor cktiztn&...
Restored:   lst  ik  En:
 e pri acc u ted ;oor ckti tn ...
Epoch 20, Loss: 1.2809
Epoch 21, Loss: 1.2812
Epoch 22, Loss: 1.2780
Epoch 23, Loss: 1.2783
Epoch 24, Loss: 1.2807
Epoch 25, Loss: 1.2821
Epoch 26, Loss: 1.2820
Epoch 27, Loss: 1.2810
Epoch 28, Loss: 1.2810
Epoch 29, Loss: 1.2824

Epoch 30
Noisy input: j&'se pQdizer:
WS?akeNaccounted poor cXRizRns...
Restored

In [None]:
def save_model(model, path, char2idx, idx2char):
    """Save the model weights, vocabulary mappings and architecture settings.

    The architecture hyperparameters are read from ``model`` itself rather
    than from notebook-level variables, so the checkpoint always describes the
    network whose weights it contains, even if the configuration cell was
    edited or re-run after the model was built.

    Args:
        model: A ``DenoisingTransformer``-style module exposing ``embedding``
            and ``transformer_blocks`` (each block with ``attention`` and
            ``ff``).
        path: Destination accepted by ``torch.save`` (file path or file-like
            object).
        char2idx: Character-to-id mapping to store alongside the weights.
        idx2char: Id-to-character mapping to store alongside the weights.
    """
    blocks = model.transformer_blocks
    model_embed_dim = model.embedding.embedding_dim
    if len(blocks) > 0:
        model_num_heads = blocks[0].attention.num_heads
        # First layer of the feed-forward stack maps embed_dim -> ff_dim.
        model_ff_dim = blocks[0].ff[0].out_features
    else:
        # No block to inspect; these values are never used to build a layer
        # when num_layers is 0, so any valid placeholder works.
        model_num_heads = 1
        model_ff_dim = model_embed_dim

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'char2idx': char2idx,
        'idx2char': idx2char,
        'hyperparameters': {
            'vocab_size': model.embedding.num_embeddings,
            'embed_dim': model_embed_dim,
            'num_heads': model_num_heads,
            'ff_dim': model_ff_dim,
            'num_layers': len(blocks)
        }
    }
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")

def load_model(path, device='cpu'):
    """Rebuild a ``DenoisingTransformer`` from a checkpoint written by ``save_model``.

    Args:
        path: Checkpoint location passed to ``torch.load``.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        A tuple ``(model, char2idx, idx2char)`` with the restored network and
        the vocabulary mappings stored in the checkpoint.
    """
    state = torch.load(path, map_location=device)
    config = state['hyperparameters']

    # Positional constructor arguments, in the order DenoisingTransformer expects.
    arch_keys = ('vocab_size', 'embed_dim', 'num_heads', 'ff_dim', 'num_layers')
    restored = DenoisingTransformer(*(config[key] for key in arch_keys)).to(device)

    restored.load_state_dict(state['model_state_dict'])
    return restored, state['char2idx'], state['idx2char']

In [None]:
# Write the trained weights together with the vocabulary mappings to disk.
save_model(model, 'char_denoising_transformer.pth', char2idx, idx2char)

Model saved to char_denoising_transformer.pth


In [None]:
# Restore the checkpoint into a separate model object. Note that this rebinds the
# module-level char2idx / idx2char to the mappings stored in the checkpoint.
loaded_model, char2idx, idx2char = load_model('char_denoising_transformer.pth', device=device)

In [None]:
# Corrupt a short test phrase and let the in-memory model reconstruct it.
test_phrase = "\nSecond Citizen:\nLet us kill him!\n"
test_data = [char2idx[c] if c in char2idx else char2idx['<unk>'] for c in test_phrase]
test_input = torch.tensor(test_data).unsqueeze(0).to(device)
noisy_test = add_noise(test_input)

with torch.no_grad():
    logits = model(noisy_test)
    prediction = logits.argmax(-1)
restored = ''.join(idx2char[token] for token in prediction[0].tolist())

# Ids missing from idx2char are shown as '?'.
noisy_text = ''.join(idx2char.get(token, '?') for token in noisy_test[0].tolist())

print("Original:", test_phrase)
print("Noisy:   ", noisy_text)
print("Restored:", restored)

Original: 
Second Citizen:
Let us kill him!

Noisy:    
veaondqCitizen:
LeD uE kilA him!

Restored:  ve onMGV t  enD  e        A h    


In [None]:
# Same probe as the previous cell, but run through the model restored from disk.
test_phrase = "\nSecond Citizen:\nLet us kill him!\n"
test_data = [char2idx[c] if c in char2idx else char2idx['<unk>'] for c in test_phrase]
test_input = torch.tensor(test_data).unsqueeze(0).to(device)
noisy_test = add_noise(test_input)

with torch.no_grad():
    logits = loaded_model(noisy_test)
    prediction = logits.argmax(-1)
restored = ''.join(idx2char[token] for token in prediction[0].tolist())

# Ids missing from idx2char are shown as '?'.
noisy_text = ''.join(idx2char.get(token, '?') for token in noisy_test[0].tolist())

print("Original:", test_phrase)
print("Noisy:   ", noisy_text)
print("Restored:", restored)