## Exercício: Modelo de Linguagem com auto-atenção e máscaras causais

Seguimos na mesma linha de treinar um modelo de linguagem a partir dos textos do livro "O Guarani", de José de Alencar.

Neste exercício, vamos treinar um modelo de linguagem com auto-atenção e com máscara causal. A máscara causal é necessária para que o modelo não tenha acesso a palavras futuras, que é a abordagem usada por grandes modelos de linguagem, como o GPT.

Use a implementação matricial de auto-atenção da aula passada.

### Modificações necessárias

* Adicione a máscara causal na função `forward` da cabeça de auto-atenção.
* Modifique o nosso dataloader para retornar inputs (uma lista de tokens de tamanho $n$), targets (uma lista de tokens de tamanho $n$ deslocada para a esquerda em 1 token). Exemplo `input = [1, 2, 3, 4]`, `target = [2, 3, 4, 5]` para a sequência `[1, 2, 3, 4, 5]` com `seq_len=4`, por exemplo (Ver slide 50).

### Extra
* MultiHeadAttention: modifique a cabeça de auto-atenção para ter múltiplas cabeças. Isso não é obrigatório, mas pode ser interessante para ver como o modelo se comporta.
* Diagrama da geração: fazer diagrama que mostre os passos da geração de tokens (conforme slide 47).

### Dicas

* Use como base o vídeo do Karpathy: https://www.youtube.com/watch?v=kCc8FmEb1nY. Observe que, no vídeo, ele primeiro implementa um modelo bi-grama, depois um modelo de linguagem com auto-atenção. O modelo de auto-atenção é implementado por volta do minuto 40, mas vale a pena assistir o vídeo todo.
* Use esta implementação como base: https://colab.research.google.com/drive/1vFTg4MSXVJwNSzPjaCcvmqhxTP7gK7HA?usp=sharing. Observe como o modelo é organizado e como a máscara é implementada na classe MultiHeadAttention.
* Use `context_size=9`

## Imports

In [1]:
import os
import sys
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
import re
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset, DataLoader
from collections import Counter

## Variáveis Globais e Inicialização

In [2]:
# Global variables

# Vocabulary
vocab_size = 10000
seq_len = 9
pattern = r'\w+|[,;.:!?\']'

# Training
batch_size = 64
epochs = 100
lr = 0.0001

# Model
embedding_dim = 256
dropout_rate = 0.2

# Colab environment
IN_COLAB = 'google.colab' in sys.modules

if (IN_COLAB):
    %pip install colorama

    # Google Drive
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

    project_folder="/content/drive/MyDrive/Classes/IA024/Aula_2_3"
    os.chdir(project_folder)
    !ls -la

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


## Faz download e carrega o dataset

In [3]:
# Check if download is necessary
if not os.path.exists("67724.txt.utf-8"):
    print("Downloading Gutenberg texts")

    !wget https://www.gutenberg.org/ebooks/67724.txt.utf-8
    !wget https://www.gutenberg.org/ebooks/67725.txt.utf-8

In [4]:
text = open("67724.txt.utf-8","r").read()
text += open("67725.txt.utf-8","r").read()

paragraphs = text.split("\n\n")

len(paragraphs)

4969

In [5]:
# Checking the text
print(paragraphs[0])

The Project Gutenberg eBook of O Guarany: romance brazileiro, Vol. 1 (of 2)
    
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restrictions
whatsoever. You may copy it, give it away or re-use it under the terms
of the Project Gutenberg License included with this ebook or online
at www.gutenberg.org. If you are not located in the United States,
you will have to check the laws of the country where you are located
before using this eBook.


In [6]:
cleaned_paragraphs = [paragraph.replace("\n", " ") for paragraph in paragraphs if paragraph.strip()]

# Print 5 random paragraphs
num_paragraphs = len(cleaned_paragraphs)
for i in range(0,5):
    idx = random.randrange(num_paragraphs)
    print(f"{cleaned_paragraphs[idx]}\n")

print("Number of paragraphs: " + str(num_paragraphs))

len(cleaned_paragraphs)

--Porque é preciso.

Voltou-se para ver se alguem estava alli que reparasse no que ia fazer, e deo com o italiano que a dous passos delle o olhava com um dos seus sorrisos sarcasticos.

 PAG. 292.--=Guanumby=.

Um incidente veio atear a chamma que lastrava; Pery, apenas começou a romper o dia, via a alguma distancia do jardim o cadaver do Ruy Soeiro; e temendo que sua senhora acordando não presenciasse este triste espectaculo, tomou o corpo, e atravessando a esplanada, veio atira-lo no meio do pateo.

--Tu me offendes, Pery! exclamou o fidalgo; a minha casa está aberta para todos, e sobretudo para ti que és amigo, e salvaste minha filha.

Number of paragraphs: 4892


4892

## Análise do dataset

In [7]:
# Conta as palavras no dataset
def count_words(texts):
    word_counts = Counter()
    for text in texts:
        word_counts.update(re.findall(r'\w+', text.lower()))
    return word_counts

word_counts = count_words(cleaned_paragraphs)

len(word_counts)

12603

## Criando um vocabulário

In [8]:
most_frequent_words = [word for word, count in word_counts.most_common(vocab_size)]
vocab = {word: i for i, word in enumerate(most_frequent_words, 1)}

In [9]:
def encode_sentence(sentence, vocab):
    return [vocab.get(word, 0) for word in re.findall(pattern, sentence.lower())]

print(cleaned_paragraphs[20])
print(encode_sentence(cleaned_paragraphs[20], vocab))

 Publicando este livro em 1857, se disse ser aquella primeira edição uma prova typographica, que algum dia talvez o autor se dispuzesse a rever.
[6594, 139, 4376, 19, 6595, 0, 6, 44, 110, 269, 259, 2662, 10, 1064, 6596, 0, 2, 186, 130, 280, 3, 2257, 6, 6597, 1, 2665, 0]


## Classe do dataset

### Dataset Modificado para Máscara Causal

Modifique o nosso dataloader para retornar inputs (uma lista de tokens de tamanho $n$), targets (uma lista de tokens de tamanho $n$ deslocada para a esquerda em 1 token). Exemplo `input = [1, 2, 3, 4]`, `target = [2, 3, 4, 5]` para a sequência `[1, 2, 3, 4, 5]` com `seq_len=4`, por exemplo (Ver slide 50).

In [10]:
# Dataset class
import torch
from torch.utils.data import Dataset, DataLoader

class CausalMaskDataset(Dataset):
  def __init__(self, paragraphs, vocab, seq_len):
    self.paragraphs = paragraphs
    self.vocab = vocab
    self.seq_len = seq_len
    self.tokens, self.targets = self.setup()

  def __len__(self):
    return len(self.tokens)

  def __getitem__(self, idx):
    tokens = torch.tensor(self.tokens[idx])
    targets = torch.tensor(self.targets[idx])
    return tokens, targets
  
  def setup(self):
    tokens = []
    targets = []
    for paragraph in self.paragraphs:
      encoded = encode_sentence(paragraph, self.vocab)
      
      # If paragraph is smaller than the sequence length, skip it.
      if len(encoded) < self.seq_len + 1:
          continue

      for i in range(len(encoded) - self.seq_len):
        tks = encoded[i:i+self.seq_len]
        # Return targets with seq_len instead of a single one.
        tgt = encoded[i+1:i+1+self.seq_len]
        # Only add if there are no unknown tokens in both context and target.
        bad_token = 0
        if not (bad_token in tks or tgt == bad_token):
          tokens.append(tks)
          targets.append(tgt)
    return tokens, targets

#### Verificação básica do Dataset modificado

In [11]:
tst_paragraphs = cleaned_paragraphs[:50]
tst_seq_len = 9
tst_dataset = CausalMaskDataset(tst_paragraphs, vocab, tst_seq_len)

print("Input and Target Tensors:")
tst_dataset[0]


Input and Target Tensors:


(tensor([133,  47, 712, 144, 537, 324, 275,  35, 499]),
 tensor([ 47, 712, 144, 537, 324, 275,  35, 499,  47]))

### Verificação do Dataset

In [12]:
# Train/Validation split
train_data, val_data = train_test_split(cleaned_paragraphs, test_size=0.2, random_state=18)

train_dataset = CausalMaskDataset(train_data, vocab, seq_len)
val_dataset = CausalMaskDataset(val_data, vocab, seq_len)

# Counting all Samples
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print()
print(f"Training dataset samples: {len(train_dataset)}")
print(f"Validation dataset samples: {len(val_dataset)}")

Training samples: 3913
Validation samples: 979

Training dataset samples: 15993
Validation dataset samples: 4197


In [13]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

sample = next(iter(train_loader))
print(sample)

[tensor([[  20,   16,  392, 1076,    5,   63,  120,  104, 2154],
        [6513,  298,  117,   40,    5, 2643,   23,    7, 6516],
        [   1, 4073,    3,  782,    4, 1412,   22,  208,    4],
        [1502,  118, 6118,   75,  489,    5,  203,   39, 1545],
        [  10, 1139,    6,  856, 4116,   22,    2,    6,  196],
        [   5,   21,  152,  151,   88,  193,    1, 1837,    4],
        [ 933,  500,    2,    3,  115,    3, 1243,    6,   27],
        [1909,   32,   12,  935,    2,  358,  299,   56, 1822],
        [1339,   53,  197,   22, 3062,    4,   10, 1643, 4262],
        [ 459,    1,  146,  665,    5, 7812,    4,   29,    2],
        [   6,  467,  387,    1,  638,  177,   36,  935,  281],
        [1821,  119, 1145, 4018,   53,  794,    4,  793,  228],
        [ 499,  465, 1061,   35,  333,   47,   35,   67,   59],
        [ 474,  806,   86,    7, 2420,    4, 1872,   75,  160],
        [  71,  121,  648,   15,    2,   17,    9, 2464,   40],
        [2651,  109,   35,  463,  375, 

### Implementação do Modelo com Máscara Causal

In [14]:
# Código baseado no tutorial do Andrej Karpathy https://github.com/karpathy/ng-video-lecture
class Head(nn.Module):

    def __init__(self, head_size, seq_len):
        super(Head, self).__init__()
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(seq_len, seq_len)))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)

        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)

        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)

        wei = F.softmax(wei, dim=-1) # (B, T, T)

        wei = self.dropout(wei)

        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)

        return out
    
class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, embedding_dim):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        return self.net(x)
    
class KarpathyModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim, seq_len):
      super(KarpathyModel, self).__init__()
    
      # Embedding
      self.embedding = nn.Embedding(vocab_size+1, embedding_dim)
      # Positional Embedding (simpler?)
      self.positional_embedding = nn.Embedding(seq_len, embedding_dim)
      # Single Head Attention
      self.attention = Head(embedding_dim, seq_len)
      # Projection
      self.WO = nn.Linear(embedding_dim, embedding_dim)
      # Linear Layers
      self.ffwd = FeedForward(embedding_dim)

      # Normalization Layers
      self.ln_attention = nn.LayerNorm(embedding_dim)
      self.ln_ffwd = nn.LayerNorm(embedding_dim)
      self.ln_f = nn.LayerNorm(embedding_dim) # final layer norm

      self.lm_head = nn.Linear(embedding_dim, vocab_size+1)
      self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input):
      B,T = input.shape
      # idx and targets are both (B,T) tensor of integers
      embedding_input = self.embedding(input) # (B,T,C)
      positions = self.positional_embedding(torch.arange(T, device=device)) # (T,C)
      x = embedding_input + positions # (B,T,C)
      # Auto Atenção
      attention = self.ln_attention(x)
      attention = self.attention(attention)
      attention = self.WO(attention)
      attention = self.dropout(attention)
      x = x + attention
      # MLP
      ffwd = self.ln_ffwd(x)
      ffwd = self.ffwd(ffwd)

      x = x + ffwd
      # Camada de saida
      x = self.ln_f(x)
      logits = self.lm_head(x)

      return logits

    def generate(self, idx, max_new_tokens=10, seq_len=9):
      # idx is (B, T) array of indices in the current context
      for _ in range(max_new_tokens):
          # crop idx to the last block_size tokens
          idx_cond = idx[:, -seq_len:]
          # get the predictions
          logits = self(idx_cond)
          # focus only on the last time step
          logits = logits[:, -1, :] # becomes (B, C)
          # apply softmax to get probabilities
          probs = F.softmax(logits, dim=-1) # (B, C)
          #Extraido do Ramon Simoes --> Excluir o token <unk> (codificado como 0) atribuindo probabilidade zero
          probs[:,0] = 0.0
          probs = probs / probs.sum(dim=-1, keepdim=True)
          # sample from the distribution
          idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
          # append sampled index to the running sequence
          idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
      return idx

## Treinamento e Avaliação

### Funções de Treinamento e Avaliação

#### Contagem de Parâmetros do Modelo

In [15]:
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'The model has a total of {total_params:,} parameters.')

#### Avaliação Inicial Pré-Treinamento

In [16]:
def initial_eval(model, criterion, device):
    # Initial Perplexity and Loss
    # Before training
    model.eval()

    loss = 0
    perp = 0

    with torch.no_grad():
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)

            loss = criterion(logits, targets)
            
    loss /= len(train_loader)

    perp = torch.exp(loss)

    print(f'Initial Loss: {loss:.4f}')
    print(f'Initial Perplexity: {perp:.4f}')

#### Função de Treinamento

In [17]:
def train(model, criterion, optimizer, device):
      # Training Loop
      model.train()
      for epoch in range(epochs):

            epoch_start = time.time()
            # Metrics
            epoch_loss = 0
            epoch_correct = 0
            epoch_samples = 0

            for inputs, targets in train_loader:
                  inputs = inputs.to(device)  # Move input data to the device
                  targets = targets.to(device)

                  # Forward pass
                  logits = model(inputs)
                  B, T, C = logits.shape
                  logits = logits.view(B*T, C)
                  targets = targets.view(B*T)
                  loss = criterion(logits, targets)

                  # Backward pass and optimization
                  optimizer.zero_grad()
                  loss.backward()
                  optimizer.step()

                  # Loss
                  epoch_loss += loss.item()

                  # Predicted
                  _, predicted = torch.max(logits, 1)
                  epoch_correct += (predicted == targets).sum().item()
                  epoch_samples += targets.size(0)

            # Calculate average loss and accuracy for epoch
            avg_loss = epoch_loss / len(train_loader)
            acc = epoch_correct / epoch_samples

            # Perplexity
            perp = torch.exp(torch.tensor(avg_loss))

            epoch_end = time.time()
            epoch_time = epoch_end - epoch_start
            # Print epoch statistics
            print(f'Epoch [{epoch+1}/{epochs}], Time:{epoch_time:.2f}, Loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%, Perplexity: {perp:.4f}')


#### Função de Avaliação do Modelo Treinado

In [18]:
def eval(model, criterion, device):
    model.eval()

    loss_sum = 0
    total_sum = 0
    correct_sum = 0
    eval_round = 0

    loss = 0
    perp = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = criterion(logits, targets)     

            loss_sum += loss

            # Get the predicted labels
            _, predicted = torch.max(logits, 1)

            total_sum += targets.size(0)
            correct_sum += (predicted == targets).sum().item()
            eval_round += 1

    # Calculate accuracy
    acc = 100 * correct_sum / total_sum

    # Calculate average perplexity
    average_loss = loss_sum / len(val_loader)
    average_perplexity = torch.exp(average_loss)

    print(f'Test Accuracy: {acc:.2f}%')
    print(f'Average Loss: {average_loss:.2f}')
    print(f'Average Perplexity: {average_perplexity:.2f}')

### Avaliação Inicial

In [19]:
model_attn = KarpathyModel(vocab_size, embedding_dim, seq_len)
print(model_attn)

KarpathyModel(
  (embedding): Embedding(10001, 256)
  (positional_embedding): Embedding(9, 256)
  (attention): Head(
    (key): Linear(in_features=256, out_features=256, bias=False)
    (query): Linear(in_features=256, out_features=256, bias=False)
    (value): Linear(in_features=256, out_features=256, bias=False)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (WO): Linear(in_features=256, out_features=256, bias=True)
  (ffwd): FeedForward(
    (net): Sequential(
      (0): Linear(in_features=256, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=256, bias=True)
      (3): Dropout(p=0.2, inplace=False)
    )
  )
  (ln_attention): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (ln_ffwd): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=256, out_features=10001, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)


In [20]:
print("Model with Self Attention:")
print()
count_parameters(model_attn)

# Cross Entropy
criterion = nn.CrossEntropyLoss()

model_attn.to(device)

print()
print("Initial Evaluation")
print()
initial_eval(model_attn, criterion, device)

Model with Self Attention:

The model has a total of 5,922,321 parameters.

Initial Evaluation

Initial Loss: 0.0375
Initial Perplexity: 1.0382


### Treinamento do Modelo

In [21]:
print("Model with Self Attention and Causal Mask:")
print()

# Cross Entropy
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.AdamW(model_attn.parameters(), lr)

model_attn.to(device)

print()
print("Training Start")
print()
train(model_attn, criterion, optimizer, device)

Model with Self Attention and Causal Mask:


Training Start

Epoch [1/100], Time:2.99, Loss: 7.3040, Accuracy: 0.06%, Perplexity: 1486.1747
Epoch [2/100], Time:2.87, Loss: 6.4204, Accuracy: 0.08%, Perplexity: 614.2295
Epoch [3/100], Time:2.88, Loss: 5.9556, Accuracy: 0.11%, Perplexity: 385.9193
Epoch [4/100], Time:3.03, Loss: 5.4894, Accuracy: 0.13%, Perplexity: 242.1127
Epoch [5/100], Time:3.10, Loss: 5.1247, Accuracy: 0.15%, Perplexity: 168.1209
Epoch [6/100], Time:3.53, Loss: 4.8321, Accuracy: 0.17%, Perplexity: 125.4745
Epoch [7/100], Time:3.67, Loss: 4.5867, Accuracy: 0.18%, Perplexity: 98.1659
Epoch [8/100], Time:3.48, Loss: 4.3737, Accuracy: 0.19%, Perplexity: 79.3368
Epoch [9/100], Time:3.49, Loss: 4.1870, Accuracy: 0.21%, Perplexity: 65.8251
Epoch [10/100], Time:3.47, Loss: 4.0224, Accuracy: 0.22%, Perplexity: 55.8356
Epoch [11/100], Time:3.45, Loss: 3.8733, Accuracy: 0.24%, Perplexity: 48.1031
Epoch [12/100], Time:3.50, Loss: 3.7383, Accuracy: 0.25%, Perplexity: 42.0285
Epoch

KeyboardInterrupt: 

### Avaliação do Modelo

In [None]:
print()
print("Evaluation Start")
print()
eval(model_attn, criterion, device)

## Exemplo de uso

In [None]:
# Simple decoder for tokens into a sentence
def decode(tokens):
  words = []
  for key, value in vocab.items():
    if value in tokens:
      words.append(key)
  
  sentence = ' '.join(words)
  sentence = sentence.capitalize() + '.'
  return sentence

seq = torch.zeros((1, 1), dtype=torch.long, device=device)
tokens = model_attn.generate(seq, 20, seq_len)[0].tolist()
print(decode(tokens))