# Simple corruption and denoising demo
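This notebook shows how to corrupt a small batch of sentences by masking tokens, train a small model on the corrupted batch, and then use that model to denoise the masked positions, using the `Diffusion` class from this repository.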

In [None]:
import torch
from transformers import AutoTokenizer
from diffusion import Diffusion

# Load the pretrained tokenizer identified by the name 'bert-base-uncased'.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Register '[MASK]' as the mask token only when the tokenizer does not already define one.
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})
# Integer id of the mask token; later cells use it both to corrupt tokens and
# to locate the corrupted positions.
mask_id = tokenizer.mask_token_id


In [None]:
# Toy corpus: five short sentences that make up the entire dataset.
sentences = [
    'The quick brown fox jumps over the lazy dog.',
    'A second example for training.',
    'Another simple sentence for the dataset.',
    'Learning diffusion models is fun.',
    'Masked language modeling is interesting.',
]

# Tokenize the whole batch in one call. Every sequence is padded or truncated
# to exactly 16 token ids and the result is returned as PyTorch tensors.
enc = tokenizer(
    sentences,
    padding='max_length',
    truncation=True,
    max_length=16,
    return_tensors='pt',
)

# Clean (uncorrupted) token ids for the batch.
input_ids = enc['input_ids']


In [None]:
class Corruptor:
    """Lightweight object that forwards corruption to ``Diffusion.q_xt``.

    Only ``mask_index`` is stored. Presumably that is the only attribute
    ``Diffusion.q_xt`` reads from ``self`` -- confirm against the
    ``Diffusion`` implementation.
    """

    def __init__(self, mask_index):
        """Remember the token id passed in as ``mask_index``."""
        self.mask_index = mask_index

    def corrupt(self, tokens, move_chance):
        """Return the result of ``Diffusion.q_xt`` for ``tokens`` and ``move_chance``."""
        # Call through the class so that this object, rather than a
        # Diffusion instance, is passed as ``self``.
        noised = Diffusion.q_xt(self, tokens, move_chance)
        return noised

# Build the corruptor around the tokenizer's mask token id.
corruptor = Corruptor(mask_id)
# Corrupt the whole batch with a move chance of 0.3. The (1, 1) tensor
# presumably broadcasts against the token ids inside q_xt -- TODO confirm.
noisy_input = corruptor.corrupt(input_ids, torch.tensor([[0.3]]))


In [None]:
class SmallMLM(torch.nn.Module):
    """Tiny token-level model: embedding -> single-layer LSTM -> linear head.

    Args:
        vocab: Number of token ids. Sizes both the embedding table and the
            output projection.
        hidden_size: Width of the token embeddings and of the LSTM hidden
            state. Defaults to 64.
    """

    def __init__(self, vocab, hidden_size=64):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab, hidden_size)
        # batch_first=True: the LSTM consumes and produces (batch, seq, feature).
        self.rnn = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.out = torch.nn.Linear(hidden_size, vocab)

    def forward(self, x):
        """Return per-position logits over the vocabulary.

        Args:
            x: Integer token ids of shape (batch, seq).

        Returns:
            Tensor of shape (batch, seq, vocab).
        """
        embedded = self.embed(x)
        # Keep only the per-step outputs; the final (h, c) state is unused.
        hidden, _ = self.rnn(embedded)
        return self.out(hidden)

# Size the model with len(tokenizer) rather than tokenizer.vocab_size:
# vocab_size does not count tokens added through add_special_tokens, so a
# newly added '[MASK]' id would otherwise fall outside the embedding table.
model = SmallMLM(len(tokenizer))
# Positions labelled -100 are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
def train_step(x_noisy, x_clean):
    """Run one optimisation step and return the loss as a Python float.

    Only positions where ``x_noisy`` equals ``mask_id`` contribute to the
    loss; every other position is labelled -100, the index ``criterion``
    is configured to ignore.
    """
    # Targets: the clean token where the input was masked, -100 elsewhere.
    not_masked = x_noisy != mask_id
    targets = x_clean.masked_fill(not_masked, -100)

    scores = model(x_noisy)
    vocab_dim = scores.size(-1)
    # Flatten (batch, seq) into one axis so the loss sees one row per token.
    loss = criterion(scores.view(-1, vocab_dim), targets.view(-1))

    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss.item()

# Take three training steps, reusing the same corrupted batch each time
# (noisy_input is not re-sampled inside the loop).
for epoch in range(3):
    loss = train_step(noisy_input, input_ids)
    print('epoch', epoch, 'loss', loss)


In [None]:
def denoise(tokens):
    """Fill every ``mask_id`` position in ``tokens`` with the model's top prediction.

    Positions that are not masked are returned unchanged.
    """
    is_masked = tokens == mask_id
    # Inference only: no gradients are needed.
    with torch.no_grad():
        best_guess = model(tokens).argmax(-1)
        filled = torch.where(is_masked, best_guess, tokens)
    return filled

denoised = denoise(noisy_input)
# Print each corrupted sequence next to its reconstruction, one pair per row.
for noisy_row, denoised_row in zip(noisy_input, denoised):
    print('noisy:', tokenizer.decode(noisy_row))
    print('denoised:', tokenizer.decode(denoised_row))
    print()
