## Exercício: Modelo de Linguagem com auto-atenção e máscaras causais

Seguimos na mesma linha de treinar um modelo de linguagem a partir dos textos do livro "O Guarani", de José de Alencar.

Neste exercício, vamos treinar um modelo de linguagem com auto-atenção e com máscara causal. A máscara causal é necessária para que o modelo não tenha acesso a palavras futuras, que é a abordagem usada por grandes modelos de linguagem, como o GPT.

Use a implementação matricial de auto-atenção da aula passada.

### Modificações necessárias

* Adicione a máscara causal na função `forward` da cabeça de auto-atenção.
* Modifique o nosso dataloader para retornar inputs (uma lista de tokens de tamanho $n$), targets (uma lista de tokens de tamanho $n$ deslocada para a esquerda em 1 token). Exemplo `input = [1, 2, 3, 4]`, `target = [2, 3, 4, 5]` para a sequência `[1, 2, 3, 4, 5]` com `seq_len=4`, por exemplo (Ver slide 50).

### Extra
* MultiHeadAttention: modifique a cabeça de auto-atenção para ter múltiplas cabeças. Isso não é obrigatório, mas pode ser interessante para ver como o modelo se comporta.
* Diagrama da geração: fazer diagrama que mostre os passos da geração de tokens (conforme slide 47).

### Dicas

* Use como base o vídeo do Karpathy: https://www.youtube.com/watch?v=kCc8FmEb1nY. Observe que, no vídeo, ele primeiro implementa um modelo bi-grama, depois um modelo de linguagem com auto-atenção. O modelo de auto-atenção é implementado por volta do minuto 40, mas vale a pena assistir o vídeo todo.
* Use esta implementação como base: https://colab.research.google.com/drive/1vFTg4MSXVJwNSzPjaCcvmqhxTP7gK7HA?usp=sharing. Observe como o modelo é organizado e como a máscara é implementada na classe MultiHeadAttention.
* Use `context_size=9`

## Imports

In [1]:
import os
import sys
import random
import torch.nn as nn
import torch.nn.functional as F
import time
import math
from sklearn.model_selection import train_test_split
from torch import nn

## Variáveis Globais e Inicialização

In [2]:
# Global variables

# Vocabulary
vocab_size = 5000
context_size = 5
pattern = r'\w+|[,;.:!?\']'

# Training
batch_size = 128
epochs = 10
lr = 0.1

# Model
embedding_dim = 256
hidden_dim = 128
dropout_rate = 0.2

# Colab environment
IN_COLAB = 'google.colab' in sys.modules

if (IN_COLAB):
    %pip install colorama

    # Google Drive
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

    project_folder="/content/drive/MyDrive/Classes/IA024/Aula_2_3"
    os.chdir(project_folder)
    !ls -la

## Faz download e carrega o dataset

In [3]:
# Check if download is necessary
if not os.path.exists("67724.txt.utf-8"):
    print("Downloading Gutenberg texts")

    !wget https://www.gutenberg.org/ebooks/67724.txt.utf-8
    !wget https://www.gutenberg.org/ebooks/67725.txt.utf-8

In [4]:
text = open("67724.txt.utf-8","r").read()
text += open("67725.txt.utf-8","r").read()

paragraphs = text.split("\n\n")

len(paragraphs)

4969

In [5]:
# Checking the text
print(paragraphs[0])

The Project Gutenberg eBook of O Guarany: romance brazileiro, Vol. 1 (of 2)
    
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
of the Project Gutenberg License included with this ebook or online
at www.gutenberg.org. If you are not located in the United States,
you will have to check the laws of the country where you are located
before using this eBook.


In [6]:
cleaned_paragraphs = [paragraph.replace("\n", " ") for paragraph in paragraphs if paragraph.strip()]

# Print 5 random paragraphs
num_paragraphs = len(cleaned_paragraphs)
for i in range(0,5):
    idx = random.randrange(num_paragraphs)
    print(f"{cleaned_paragraphs[idx]}\n")

print("Number of paragraphs: " + str(num_paragraphs))

len(cleaned_paragraphs)

--Não tenhais esse receio; qualquer que seja a desgraça que me annunciardes, será bem vinda pelos vossos labios; é sempre um consolo receber-se a má nova de voz amiga!

--Pery ia salvar-te!

Depois, fatigado do esforço supremo, se estende sobre a terra, e adormece n'uma linda bacia que a natureza formou, e onde o recebe como em um leito de noiva, sob as cortinas de trepadeiras e flores agrestes.

While we cannot and do not solicit contributions from states where we have not met the solicitation requirements, we know of no prohibition against accepting unsolicited donations from donors in such states who approach us with offers to donate.

 PAG. 42.--=Biribá=.

Number of paragraphs: 4892


4892

## Análise do dataset

In [7]:
# Conta as palavras no dataset
from collections import Counter
import re

def count_words(texts):
    word_counts = Counter()
    for text in texts:
        word_counts.update(re.findall(r'\w+', text.lower()))
    return word_counts

word_counts = count_words(cleaned_paragraphs)

len(word_counts)

12603

## Criando um vocabulário

In [8]:
most_frequent_words = [word for word, count in word_counts.most_common(vocab_size)]
vocab = {word: i for i, word in enumerate(most_frequent_words, 1)}

In [9]:
def encode_sentence(sentence, vocab):
    return [vocab.get(word, 0) for word in re.findall(pattern, sentence.lower())]

print(cleaned_paragraphs[20])
print(encode_sentence(cleaned_paragraphs[20], vocab))

 Publicando este livro em 1857, se disse ser aquella primeira edição uma prova typographica, que algum dia talvez o autor se dispuzesse a rever.
[0, 139, 4376, 19, 0, 0, 6, 44, 110, 269, 259, 2662, 10, 1064, 0, 0, 2, 186, 130, 280, 3, 2257, 6, 0, 1, 2665, 0]


## Classe do dataset

In [10]:
# Dataset class
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
  def __init__(self, paragraphs, vocab, context):
    self.paragraphs = paragraphs
    self.vocab = vocab
    self.context = context
    self.tokens, self.targets = self.setup()

  def __len__(self):
    return len(self.tokens)

  def __getitem__(self, idx):
    return torch.tensor(self.tokens[idx]), torch.tensor(self.targets[idx])
  
  def setup(self):
    tokens = []
    targets = []
    for paragraph in self.paragraphs:
      encoded = encode_sentence(paragraph, self.vocab)
      
      # If paragraph is smaller than the context, skip it.
      if len(encoded) < self.context + 1:
          continue

      for i in range(len(encoded) - self.context):
        tks = encoded[i:i+self.context]
        tgt = encoded[i+self.context]
        # Only add if there are no unknown tokens in both context and target.
        bad_token = 0
        if not (bad_token in tks or tgt == bad_token):
          tokens.append(tks)
          targets.append(tgt)
    return tokens, targets

In [11]:
# Train/Validation split
train_data, val_data = train_test_split(cleaned_paragraphs, test_size=0.2, random_state=18)

train_dataset = CustomDataset(train_data, vocab, context_size)
val_dataset = CustomDataset(val_data, vocab, context_size)

# Counting all Samples
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print()
print(f"Training dataset samples: {len(train_dataset)}")
print(f"Validation dataset samples: {len(val_dataset)}")

Training samples: 3913
Validation samples: 979

Training dataset samples: 22886
Validation dataset samples: 5853


In [12]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

sample = next(iter(train_loader))
print(sample)

[tensor([[  15,   10,  292,  594,    2],
        [   2,   48,  247,    7,  320],
        [   5, 4046,    3,   37,   11],
        [1336,  544,  813,   26,  156],
        [   2,    6, 1220,  953,  585],
        [  20,   84,   11,  669,    2],
        [  94,   23,  337,    4,  125],
        [   8,  355,    2,   39,  140],
        [   1, 1104,    2,   17, 1306],
        [ 514,    4, 2061,   11,  177],
        [1648,   19,    7,   36, 2266],
        [   9,  216,   80,   15,    9],
        [   1,   76,  595,   80,   38],
        [   6,  101,    3, 2012,   11],
        [   6,   15, 1031,    3,   16],
        [ 214, 1696,    6,    3,  746],
        [1754,    4,  934,  539,  743],
        [ 292,   19,    2,    6, 1140],
        [ 215,   18, 4248,   77, 1567],
        [1829, 4621,    4,   10,  365],
        [ 385,    7, 2632,    5,   10],
        [  17,  867,   77, 1834,   18],
        [   3, 1878,    2,    6,  674],
        [ 162,   12,   49, 1131,   32],
        [ 175,   64,    2,  488,    3],

## Model

In [13]:
# Positional Embedding - as described in "Attention is All You Need"
class PositionalEncoding(nn.Module):
    def __init__(self, max_sequence, embedding_dim):
        super().__init__()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.positional_encoding = torch.zeros(max_sequence, embedding_dim, device=device)
        position = torch.arange(0, max_sequence, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2, device=device) * (-math.log(10000.0) / embedding_dim))
        self.positional_encoding[:, 0::2] = torch.sin(position * div_term)
        self.positional_encoding[:, 1::2] = torch.cos(position * div_term)
        self.positional_encoding = self.positional_encoding.unsqueeze(0)

    def forward(self, x):
        _, seq_length, _ = x.size()
        positional_encoding = self.positional_encoding[:, :seq_length, :]
        positional_encoding = positional_encoding.to(x.device)
        # Position encoding is added to the input embeddings.
        return x + positional_encoding   

In [14]:
# Matrix Implementation
class SelfAttention_Matrix(nn.Module):
  def __init__(self, embedding_dim, vocab_size):
    super().__init__()

    self.WQ = nn.Linear(embedding_dim, embedding_dim, bias=False)
    self.WK = nn.Linear(embedding_dim, embedding_dim, bias=False)
    self.WV = nn.Linear(embedding_dim, embedding_dim, bias=False)
    self.WO = nn.Linear(embedding_dim, embedding_dim, bias=False)

  def setProjections(self, WQ, WK, WV, WO):
    self.WQ = WQ
    self.WK = WK
    self.WV = WV
    self.WO = WO

  def forward(self, inputs):
    # Linear projections
    Q = self.WQ(inputs)
    K = self.WK(inputs)
    V = self.WV(inputs)

    scores = torch.matmul(Q, K.transpose(-2, -1))
    probs = F.softmax(scores, dim=-1)
    new_embedding = torch.matmul(probs, V)
    # Projection in WO
    new_embedding = self.WO(new_embedding)
    return new_embedding

In [15]:
class LanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size, h):
        super(LanguageModel, self).__init__()
        self.embeddings = nn.Embedding(vocab_size+1, embedding_dim)
        self.posencoding = PositionalEncoding(context_size, embedding_dim)
        self.attention = SelfAttention_Matrix(embedding_dim, vocab_size)        
        self.linear1 = nn.Linear(context_size * embedding_dim, h)
        self.dropout1 = nn.Dropout(p = dropout_rate)
        self.relu = torch.nn.ReLU()
        self.linear2 = nn.Linear(h, vocab_size+1)
        self.dropout2 = nn.Dropout(p = dropout_rate)
        # Softmax to scale outputs
        self.logSoftMax = torch.nn.LogSoftmax(dim=1)

    def forward(self, inputs):
        embeds = self.embeddings(inputs)
        embeds_pos = self.posencoding(embeds)
        x = torch.stack(torch.unbind(embeds_pos, dim=1), dim=1)
        # Camada de autoatenção
        attention  = self.attention(x)
        # Flatten embeddings
        embeds = embeds.view(attention.size(0), -1)
        # Linear layer
        out = self.linear1(embeds)
        out = self.dropout1(out)
        out = self.relu(out)
        # Second layer
        out = self.linear2(out)
        out = self.dropout2(out)
        # Softmax output
        out = self.logSoftMax(out)
        return out

## Funções de Treinamento e Avaliação do Modelo

In [16]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'The model has a total of {total_params:,} parameters.')

In [17]:
def initial_eval(model):
    # Initial Perplexity and Loss
    # Before training
    model.eval()

    loss = 0
    perp = 0

    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss += criterion(outputs, targets).item()

    loss /= len(train_loader)
    perp = torch.exp(torch.tensor(loss))

    print(f'Initial Loss: {loss:.4f}')
    print(f'Initial Perplexity: {perp:.4f}')

In [18]:
def train(model, criterion, optimizer):
      # Training Loop
      model.train()
      for epoch in range(epochs):

            epoch_start = time.time()
            # Metrics
            epoch_loss = 0
            epoch_correct = 0
            epoch_samples = 0

            for inputs, targets in train_loader:
                  inputs = inputs.to(device)  # Move input data to the device
                  targets = targets.to(device)

                  # Forward pass
                  outputs = model(inputs)
                  loss = criterion(outputs, targets)

                  # Backward pass and optimization
                  optimizer.zero_grad()
                  loss.backward()
                  optimizer.step()

                  # Loss
                  epoch_loss += loss.item()

                  # Predicted
                  _, predicted = torch.max(outputs, 1)
                  epoch_correct += (predicted == targets).sum().item()
                  epoch_samples += targets.size(0)

            # Calculate average loss and accuracy for epoch
            avg_loss = epoch_loss / len(train_loader)
            acc = epoch_correct / epoch_samples

            # Perplexity
            perp = torch.exp(torch.tensor(avg_loss))

            epoch_end = time.time()
            epoch_time = epoch_end - epoch_start
            # Print epoch statistics
            print(f'Epoch [{epoch+1}/{epochs}], Time:{epoch_time:.2f}, Loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%, Perplexity: {perp:.4f}')


In [19]:
def eval(model, criterion):
    model.eval()

    loss_sum = 0
    total_sum = 0
    correct_sum = 0
    eval_round = 0

    loss = 0
    perp = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)      
            loss_sum += loss

            # Get the predicted labels
            _, predicted = torch.max(outputs, 1)

            total_sum += targets.size(0)
            correct_sum += (predicted == targets).sum().item()
            eval_round += 1

    # Calculate accuracy
    acc = 100 * correct_sum / total_sum

    # Calculate average perplexity
    average_loss = loss_sum / len(val_loader)
    average_perplexity = torch.exp(average_loss)

    print(f'Test Accuracy: {acc:.2f}%')
    print(f'Average Loss: {average_loss:.2f}')
    print(f'Average Perplexity: {average_perplexity:.2f}')

In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

## Training

In [21]:
model_attn = LanguageModel(vocab_size, embedding_dim, context_size, hidden_dim)
print("Model with Self Attention:")
print()
count_parameters(model_attn)

# Cross Entropy
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.SGD(model_attn.parameters(), lr)

model_attn.to(device)

print()
print("Training Start")
print()
train(model_attn, criterion, optimizer)

Model with Self Attention:

The model has a total of 2,351,497 parameters.

Training Start

Epoch [1/10], Time:2.14, Loss: 7.7416, Accuracy: 0.05%, Perplexity: 2302.1047
Epoch [2/10], Time:1.24, Loss: 6.9527, Accuracy: 0.09%, Perplexity: 1045.9573
Epoch [3/10], Time:1.25, Loss: 6.5878, Accuracy: 0.11%, Perplexity: 726.1639
Epoch [4/10], Time:1.24, Loss: 6.2952, Accuracy: 0.14%, Perplexity: 541.9409
Epoch [5/10], Time:1.25, Loss: 6.0578, Accuracy: 0.16%, Perplexity: 427.4146
Epoch [6/10], Time:1.17, Loss: 5.8786, Accuracy: 0.17%, Perplexity: 357.3179
Epoch [7/10], Time:1.16, Loss: 5.6768, Accuracy: 0.19%, Perplexity: 292.0055
Epoch [8/10], Time:1.20, Loss: 5.5365, Accuracy: 0.21%, Perplexity: 253.7761
Epoch [9/10], Time:1.24, Loss: 5.3893, Accuracy: 0.22%, Perplexity: 219.0574
Epoch [10/10], Time:1.23, Loss: 5.2480, Accuracy: 0.23%, Perplexity: 190.1868


## Avaliação

In [22]:
print()
print("Evaluation Start")
print()
eval(model_attn, criterion)


Evaluation Start

Test Accuracy: 12.68%
Average Loss: 6.13
Average Perplexity: 459.38


## Exemplo de uso

In [26]:
# Código adaptado da implementação do Cesar Bastos
from colorama import Fore, Style

text = cleaned_paragraphs
model_attn.to(device)
def generate_text(model, vocab, text, max_length, context_size):
    words = []
    # Ensure there are enough words for at least one sequence
    while len(words) < context_size:
        random_number = random.randint(1, 4891)
        words = encode_sentence(text[random_number], vocab)
        if not words:
            words = []
            continue  # Skip if the sentence cannot be encoded
        words = words[:context_size]
        #print(words)
        if any(token == 0 for token in words):
            words = []
            continue  # Skip if any token is zero (assuming 0 is a special token)
        context = words

    print(f"Frase: {cleaned_paragraphs[random_number]}")
    print(words)

    for _ in range(max_length):
        words_tensor = torch.tensor(context[-context_size:], dtype=torch.long).unsqueeze(0).to(device)
        logits = model(words_tensor)
        probs = F.softmax(logits, dim=1)
        next_token = torch.multinomial(probs, num_samples=1)
        context.append(next_token.item())
        print(context)
    frase = []
    for i in context: ##Agradecimentos a Ramon Abilio
        word = next((word for word, code in vocab.items() if code == i), "<UNKNOWN>")
        frase.append(word)

    print(f"{Fore.BLUE}{frase[:context_size]}{Style.RESET_ALL} {Fore.RED}{frase[-max_length:]}{Style.RESET_ALL} ")


max_length= 10
generate_text(model_attn, vocab, text, max_length, context_size)

Frase: --Sob pretexto de que os selvagens podem cortar-nos a entrada da casa por alguns dias, levamos provisão de viveres. Caminhamos sem parar, sem olhar atrás; e prometto-vos que nos salvaremos.
[523, 2410, 4, 2, 12]
[523, 2410, 4, 2, 12, 278]
[523, 2410, 4, 2, 12, 278, 277]
[523, 2410, 4, 2, 12, 278, 277, 579]
[523, 2410, 4, 2, 12, 278, 277, 579, 12]
[523, 2410, 4, 2, 12, 278, 277, 579, 12, 49]
[523, 2410, 4, 2, 12, 278, 277, 579, 12, 49, 4986]
[523, 2410, 4, 2, 12, 278, 277, 579, 12, 49, 4986, 24]
[523, 2410, 4, 2, 12, 278, 277, 579, 12, 49, 4986, 24, 627]
[523, 2410, 4, 2, 12, 278, 277, 579, 12, 49, 4986, 24, 627, 11]
[523, 2410, 4, 2, 12, 278, 277, 579, 12, 49, 4986, 24, 627, 11, 2438]
[34m['sob', 'pretexto', 'de', 'que', 'os'][0m [31m['passou', 'quanto', 'morto', 'os', 'olhos', 'ignorando', 'lhe', 'objecto', 'da', 'reptis'][0m 
