# Generación de texto con una RNN

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import tqdm
import re

# Run on the CUDA device when one is visible to PyTorch, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Tokenización

In [None]:
class Tokenizer:
    """Regex-based tokenizer backed by a fixed vocabulary.

    Text is split into single digits, single punctuation symbols, runs of word
    characters and single whitespace characters.  Four special tokens are
    appended after the corpus vocabulary, so their ids are
    ``len(corpus_vocabulary) + 0..3`` in the order BOS, EOS, UNK, PAD.
    """

    # Splitting pattern shared by every instance, compiled once instead of on
    # each ``encode`` call.  Alternatives are tried left to right: a token that
    # starts with a digit is a single digit, while digits that follow letters
    # stay inside the ``\w+`` run.
    TOKEN_PATTERN = re.compile(r'\d|[^\w\s]|\w+|\s')
    SPECIAL_TOKENS = ('<BOS>', '<EOS>', '<UNK>', '<PAD>')

    def __init__(self, corpus_vocabulary):
        """Build the token <-> id mappings.

        Args:
            corpus_vocabulary: Iterable of token strings (list, tuple, sorted
                set, ...); the position of each token becomes its id.
        """
        self.vocabulary = list(corpus_vocabulary) + list(self.SPECIAL_TOKENS)
        self.vocab_size = len(self.vocabulary)

        # ``idx`` instead of ``id`` so the builtin is not shadowed.
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocabulary)}
        self.bos_id = self.token_to_id['<BOS>']
        self.eos_id = self.token_to_id['<EOS>']
        self.unk_id = self.token_to_id['<UNK>']
        self.pad_id = self.token_to_id['<PAD>']

    def encode(self, text):
        """Split ``text`` into tokens and return their ids (``unk_id`` for unknown tokens)."""
        unk_id = self.unk_id
        return [self.token_to_id.get(token, unk_id) for token in self.TOKEN_PATTERN.findall(text)]

    def decode(self, seq_ids):
        """Concatenate the tokens of ``seq_ids`` back into a string.

        Special tokens are emitted verbatim (e.g. ``'<BOS>'``).  Elements only
        need to be usable as list indices.
        """
        return ''.join(self.vocabulary[i] for i in seq_ids)

In [None]:
# Usage example: a tiny vocabulary in which 'e' is missing, so it maps to <UNK>.

vocabulary = ['a', 'b', 'c', 'd', ' ']
tokenizer = Tokenizer(vocabulary)

# The special tokens sit right after the corpus vocabulary.
for special_name, special_id in (
    ('<BOS>', tokenizer.bos_id),
    ('<EOS>', tokenizer.eos_id),
    ('<UNK>', tokenizer.unk_id),
    ('<PAD>', tokenizer.pad_id),
):
    print(f'{special_name} ID: {special_id}.')

token_seq = 'a b c d e'
id_seq = tokenizer.encode(token_seq)
decoded_text = tokenizer.decode(id_seq)
print(f'encode("{token_seq}") = {id_seq}.')
print(f'decode({id_seq}) = "{decoded_text}".')

### Dataset

In [None]:
class TextDataset(Dataset):
    """Line-per-sentence text corpus served as fixed-length id tensors.

    Every non-blank line of the file becomes one sample:
    ``<BOS>`` + token ids + ``<EOS>``.  ``__getitem__`` truncates or right-pads
    (with ``<PAD>``) each sample to exactly ``seq_length`` ids.
    """

    def __init__(self, filename, seq_length):
        """Read ``filename`` (UTF-8), build the tokenizer and encode every line.

        Args:
            filename: Path to the text corpus.
            seq_length: Length of every tensor returned by ``__getitem__``.
        """
        with open(filename, 'r', encoding='utf-8') as file:
            corpus = file.read()

        # Same splitting pattern as Tokenizer.encode.
        corpus_vocabulary = sorted(set(re.findall(r'\d|[^\w\s]|\w+|\s', corpus)))
        self.tokenizer = Tokenizer(corpus_vocabulary)

        bos_id, eos_id = self.tokenizer.bos_id, self.tokenizer.eos_id
        sentences = [line.strip() for line in corpus.split('\n') if line.strip()]
        self.data = [[bos_id] + self.tokenizer.encode(sentence) + [eos_id] for sentence in sentences]

        self.seq_length = seq_length

    def __len__(self):
        """Number of samples (non-blank lines in the corpus)."""
        return len(self.data)

    def __getitem__(self, n):
        """Return sample ``n`` as a 1-D int64 tensor of exactly ``seq_length`` ids.

        Longer samples are truncated (which may drop the final ``<EOS>``);
        shorter ones are right-padded with ``<PAD>``.  The stored sample in
        ``self.data`` is never modified.
        """
        # Slicing copies, and concatenation builds a new list: extending
        # self.data[n] in place would permanently pad the stored sample.
        seq_ids = self.data[n][:self.seq_length]
        padding = [self.tokenizer.pad_id] * (self.seq_length - len(seq_ids))
        return torch.tensor(seq_ids + padding, dtype=torch.long)

In [None]:
# Ids per sample and samples per batch.
seq_length = 256
batch_size = 32

dataset = TextDataset('data.txt', seq_length)
# drop_last=True discards the final incomplete batch, so every batch has batch_size rows.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

print(f'Tamaño del dataset: {len(dataset)} secuencias.')
print(f'Tamaño del dataloader: {len(dataloader)} batches.')
print(f'Tamaño del vocabulario: {dataset.tokenizer.vocab_size} tokens.')

In [None]:
# Usage example: decode the first 40 ids of sample 0 back into text.
seq_ids = dataset[0]
leading_ids = seq_ids[:40]
seq_tokens = dataset.tokenizer.decode(leading_ids)
print(seq_tokens)

## Embeddings

In [None]:
# Usage example: nn.Embedding turns integer ids into dense vectors, adding a
# trailing embedding_dim axis to the input shape.

vocab_size = dataset.tokenizer.vocab_size
embedding_dim = 384
emb = nn.Embedding(vocab_size, embedding_dim)

x = torch.randint(vocab_size, size=(batch_size, seq_length))
y = emb(x)

# One weight row per vocabulary entry; the output gains the embedding axis.
assert emb.weight.shape == (vocab_size, embedding_dim)
assert y.shape == (batch_size, seq_length, embedding_dim)

## Red recurrente

In [None]:
class GenerativeRNN(nn.Module):
    """Next-token language model: embedding -> ``nn.RNN`` -> linear head.

    Args:
        vocab_size: Number of token ids (embedding rows and head outputs).
        embedding_dim: Size of the token embeddings; also the RNN hidden size
            unless ``hidden_dim`` is given.
        hidden_dim: Optional RNN hidden size (defaults to ``embedding_dim``).
        num_layers: Number of stacked RNN layers (default 1).

    With the default ``hidden_dim``/``num_layers`` the submodules are
    ``embedding``, a single-layer ``rnn`` of width ``embedding_dim`` and
    ``lm_head``, so checkpoints saved from that configuration load unchanged.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim=None, num_layers=1):
        super().__init__()
        hidden_dim = embedding_dim if hidden_dim is None else hidden_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map ids ``x`` [batch_size, seq_length] to logits [batch_size, seq_length, vocab_size].

        ``x`` is cast with ``.long()`` before the embedding lookup.  No initial
        hidden state is passed to the RNN, so each call starts from PyTorch's
        default (zeros).
        """
        seq_embedding = self.embedding(x.long())  # [batch_size, seq_length, embedding_dim].
        rnn_output, _ = self.rnn(seq_embedding)   # [batch_size, seq_length, hidden_dim]; final hidden state unused.
        logits = self.lm_head(rnn_output)         # [batch_size, seq_length, vocab_size].
        return logits

In [None]:
# Usage example: the model maps a batch of ids to per-position vocabulary logits.

vocab_size = dataset.tokenizer.vocab_size
embedding_dim = 384
rnn = GenerativeRNN(vocab_size, embedding_dim)

x = torch.randint(vocab_size, (batch_size, seq_length))
y = rnn(x)

expected_shape = (batch_size, seq_length, vocab_size)
assert y.shape == expected_shape

Red que se usará:

In [None]:
rnn = GenerativeRNN(dataset.tokenizer.vocab_size, embedding_dim=384)

# Total number of elements over all parameter tensors, in millions.
n_params = sum(map(torch.numel, rnn.parameters())) / 1e6
print(f'Cantidad de parámetros: {n_params:.3} millones.')

## Entrenamiento

### Equivalencia entre la log-verosimilitud normalizada (con signo negativo) y `CrossEntropyLoss`

In [None]:
# Check that CrossEntropyLoss (default mean reduction) equals the negative
# log-likelihood averaged over positions and then over the batch.
batch_size, seq_length, vocab_size = 32, 256, dataset.tokenizer.vocab_size

logits = torch.randn(batch_size, seq_length, vocab_size)
targets = torch.randint(0, vocab_size, size=(batch_size, seq_length))

# Manual computation, one step at a time.
probs = logits.softmax(dim=-1)
# Probability assigned to the target token at each position: [batch_size, seq_length].
target_probs = torch.gather(probs, -1, targets[..., None])[..., 0]

log_likelihood = torch.log(target_probs).sum(dim=-1)  # [batch_size].
normalized_log_likelihood = log_likelihood / seq_length
manual_loss = -normalized_log_likelihood.mean()

# Same quantity through CrossEntropyLoss on flattened logits and targets.
loss_fn = nn.CrossEntropyLoss()
logits_reshaped = logits.reshape(-1, vocab_size)
targets_reshaped = targets.reshape(-1)
direct_loss = loss_fn(logits_reshaped, targets_reshaped)

assert torch.isclose(manual_loss, direct_loss)

### Loop de entrenamiento

In [None]:
def train_model(model, optimizer, dataloader, epochs, ckpt_filename, max_grad_norm=None):
    """Train a next-token language model and save a checkpoint.

    Each batch of ids ``[batch_size, seq_length]`` is split into inputs
    ``seq[:, :-1]`` and targets ``seq[:, 1:]``.  Positions whose target is the
    dataset's ``<PAD>`` id are ignored by the loss.  Training can be stopped
    early with Ctrl+C (KeyboardInterrupt); the checkpoint is written in either
    case.

    Args:
        model: Module mapping ``[batch, seq]`` ids to ``[batch, seq, vocab]`` logits.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: DataLoader whose ``dataset.tokenizer.pad_id`` is the padding id.
        epochs: Number of passes over ``dataloader``.
        ckpt_filename: Destination path passed to ``torch.save``.
        max_grad_norm: If given, clip the global gradient norm to this value
            before every optimizer step (a common remedy for exploding RNN
            gradients).  ``None`` (default) disables clipping.

    Returns:
        The saved dict ``{'losses': [per-batch loss floats], 'model': state_dict}``.
    """
    model.to(DEVICE)
    model.train()

    pad_id = dataloader.dataset.tokenizer.pad_id
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)

    training = {'losses': [], 'model': None}

    progressbar = tqdm.trange(epochs)
    try:
        for _ in progressbar:
            for seq_batch in dataloader:
                seq_batch = seq_batch.to(DEVICE)  # [batch_size, seq_length].
                # Teacher forcing: predict token t+1 from the true tokens <= t.
                x_batch, y_batch = seq_batch[:, :-1], seq_batch[:, 1:]

                logits = model(x_batch)  # [batch_size, seq_length - 1, vocab_size].
                # CrossEntropyLoss expects [N, C] logits and [N] targets.
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))

                optimizer.zero_grad()
                loss.backward()
                if max_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()

                # Read the scalar once: on CUDA every .item() is a host-device sync.
                loss_value = loss.item()
                training['losses'].append(loss_value)
                progressbar.set_postfix(loss=loss_value)

    except KeyboardInterrupt:
        print('Entrenamiento interrumpido.')
    finally:
        progressbar.close()

    training['model'] = model.state_dict()
    torch.save(training, ckpt_filename)
    return training

In [None]:
rnn_optimizer = optim.AdamW(rnn.parameters())

# Training is off by default; the saved checkpoint 'rnn_training.pt' is loaded
# instead.  Set to True to retrain from scratch (overwrites that file).
TRAIN_FROM_SCRATCH = False
if TRAIN_FROM_SCRATCH:
    train_model(rnn, rnn_optimizer, dataloader, epochs=32, ckpt_filename='rnn_training.pt')

In [None]:
# Load the checkpoint dict, mapping its tensors onto DEVICE.
rnn_training = torch.load('rnn_training.pt', map_location=DEVICE, weights_only=True)
rnn.load_state_dict(rnn_training['model'])

# Loss recorded after every optimizer step.
losses = rnn_training['losses']
plt.plot(losses)
plt.xlabel('Iteración')
plt.ylabel('Entropía cruzada media')
plt.show()

## Generación

### Temperatura

In [None]:
# Usage example: effect of the temperature T on softmax(logits / T) for random
# logits over a 10-token vocabulary.  Smaller T sharpens the distribution,
# larger T flattens it.

vocab_size = 10
vocabulary = [f'$a_{{{k}}}$' for k in range(1, vocab_size + 1)]
logits = torch.randn(vocab_size)
temperatures = [0.1, 0.5, 1, 2, 5, 10]

fig, axes = plt.subplots(2, 3, figsize=(10, 4))
axes = axes.flatten()

for ax, t in zip(axes, temperatures):
    prob = (logits / t).softmax(-1)
    ax.bar(vocabulary, prob)
    ax.set_title(f'$T={t}$')
    ax.set_ylabel('Probabilidad')

plt.tight_layout()
plt.show()

### Loop de generación

In [None]:
def _apply_repetition_penalty(logits, seq_ids, repetition_penalty):
    """Return ``logits`` with tokens already in ``seq_ids`` penalized (no in-place change).

    A token seen ``c`` times gets the factor ``repetition_penalty ** c``.
    Positive logits are divided by it and negative ones multiplied, so a
    penalty > 1 always lowers that token's probability.  Plain division would
    *raise* it for negative logits.
    """
    if repetition_penalty == 1:
        return logits
    vocab_size = logits.size(-1)
    counts = torch.bincount(seq_ids, minlength=vocab_size)[:vocab_size]
    factors = repetition_penalty ** counts.to(logits.dtype)  # 1 for unseen tokens.
    return torch.where(logits > 0, logits / factors, logits * factors)


def generate_tokens(model, context, tokenizer, temperature=1, top_k=50, max_tokens=512, repetition_penalty=1):
    """Autoregressively extend ``context`` and return the decoded text.

    The prompt is ``<BOS>`` + ``tokenizer.encode(context)``.  At every step the
    model is run on the whole sequence and only the last position's logits are
    used.  Generation stops after ``max_tokens`` new tokens, or as soon as the
    ``<EOS>`` or ``<PAD>`` id is produced (that token is kept).

    Args:
        model: Module mapping ``[1, seq]`` ids to ``[1, seq, vocab]`` logits;
            it is moved to ``DEVICE`` and switched to eval mode.
        context: Prompt text.
        tokenizer: Provides ``encode``, ``decode`` and the special-token ids.
        temperature: Softmax temperature; ``0`` selects greedy (argmax) decoding.
        top_k: Sample only among the ``top_k`` most likely tokens.  Values above
            the vocabulary size are clamped; ``None`` or ``0`` disables the filter.
        max_tokens: Maximum number of tokens to generate.
        repetition_penalty: Factor > 1 discourages tokens already in the
            sequence (compounded once per occurrence); ``1`` disables it.
            Applied in both greedy and sampled decoding.

    Returns:
        ``tokenizer.decode`` of the full id sequence, including the leading
        ``<BOS>`` text and any final ``<EOS>``/``<PAD>`` text.
    """
    model.to(DEVICE)
    model.eval()

    seq_id = torch.tensor([tokenizer.bos_id] + tokenizer.encode(context), device=DEVICE)
    stop_ids = {tokenizer.eos_id, tokenizer.pad_id}

    with torch.no_grad():
        for _ in range(max_tokens):
            logits = model(seq_id.unsqueeze(0))[0, -1, :]  # Last position: [vocab_size].
            # Sign-aware penalty commutes with dividing by a positive temperature,
            # so it can be applied before the decoding branch.
            logits = _apply_repetition_penalty(logits, seq_id, repetition_penalty)

            if temperature == 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                vocab_size = logits.size(-1)
                k = min(top_k, vocab_size) if top_k else vocab_size
                top_k_logits, top_k_indices = torch.topk(logits, k)
                probs = top_k_logits.softmax(dim=-1)
                next_token = top_k_indices[torch.multinomial(probs, num_samples=1)]

            seq_id = torch.cat((seq_id, next_token), dim=0)

            if next_token.item() in stop_ids:
                break

    return tokenizer.decode(seq_id.tolist())

In [None]:
# Draw 10 sampled continuations of the same prompt with the default decoding
# settings, at most 64 new tokens each.
context = 'Habia una vez'

for _ in range(10):
    new_tokens = generate_tokens(rnn, context, dataset.tokenizer, max_tokens=64)
    print(new_tokens)