In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [2]:
class CausalSelfAttention(nn.Module):
  def __init__(self, d_k, d_model, n_heads, max_len):
    super().__init__()

    self.d_k = d_k
    self.n_heads = n_heads

    self.key = nn.Linear(d_model, d_k*n_heads)
    self.query = nn.Linear(d_model, d_k*n_heads)
    self.value = nn.Linear(d_model, d_k*n_heads)

    self.fc = nn.Linear(d_k*n_heads, d_model)

    cm = torch.tril(torch.ones(max_len, max_len))
    self.register_buffer(
        "causal_mask",
        cm.view(1,1,max_len,max_len)
    )

  def forward(self, q, k, v, pad_mask = None):
    q = self.query(q)
    k = self.key(k)
    v = self.value(v)

    N = q.shape[0]
    T = q.shape[1]

    q = q.view(N,T,self.n_heads, self.d_k).transpose(1,2)
    k = k.view(N,T,self.n_heads, self.d_k).transpose(1,2)
    v = v.view(N,T,self.n_heads, self.d_k).transpose(1,2)

    attn_scores = q@k.transpose(-2,-1)/math.sqrt(self.d_k)

    if pad_mask is not None:
      attn_scores = attn_scores.masked_fill(
          pad_mask[:,None, None,:] == 0, float('-inf')
      )
    attn_scores = attn_scores.masked_fill(
        self.causal_mask[:,:,:T,:T] == 0, float('-inf')
    )
    attn_weights = F.softmax(attn_scores, dim = -1)

    A = attn_weights @ v

    A = A.transpose(1,2)
    A = A.contiguous().view(N,T,self.d_k * self.n_heads)

    return self.fc(A)

In [3]:
class TransformerBlock(nn.Module):
  def __init__(self, d_k,d_model, n_heads,max_len, dropout_prob = 0.1):
    super().__init__()

    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = CausalSelfAttention(d_k, d_model, n_heads, max_len)
    self.ann = nn.Sequential(
        nn.Linear(d_model, d_model *4),
        nn.GELU(),
        nn.Linear(d_model*4, d_model),
        nn.Dropout(dropout_prob)
    )
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x, pad_mask = None):
    x = self.ln1(x+self.mha(x,x,x,pad_mask))
    x = self.ln2(x+self.ann(x))
    x = self.dropout(x)
    return x

In [4]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len = 2048, dropout_prob = 0.1):
    super().__init__()

    self.dropout = nn.Dropout(p=dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)
    exp_term = torch.arange(0,d_model, 2)
    div_term = torch.exp(exp_term*(-math.log(10000.0)/d_model))
    pe = torch.zeros(1,max_len, d_model)
    pe[0,:,0::2] = torch.sin(position * div_term)
    pe[0,:,1::2] = torch.cos(position * div_term)
    self.register_buffer('pe',pe)

  def forward(self, x):
    x = x+self.pe[:,:x.size(1),:]
    return self.dropout(x)

In [5]:
class Decoder(nn.Module):
  def __init__(self,
               vocab_size,
               max_len,
               d_k,
               d_model,
               n_heads,
               n_layers,
               dropout_prob):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    transformer_blocks = [
        TransformerBlock(
            d_k,
            d_model,
            n_heads,
            max_len,
            dropout_prob) for _ in range(n_layers)]
    self.transfomer_blocks = nn.Sequential(*transformer_blocks)
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, x, pad_mask = None):
    x = self.embedding(x)
    x = self.pos_encoding(x)
    for block in self.transfomer_blocks:
      x = block(x,pad_mask)

    x = self.ln(x)
    x = self.fc(x)

    return x

In [6]:
!pip install transformers datasets

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.16.1 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [7]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [8]:
checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [9]:
from datasets import load_dataset

In [10]:
raw_datasets = load_dataset("glue", "sst2")

Downloading readme:   0%|          | 0.00/31.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [11]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [12]:
def tokenize_fn(batch):
  return tokenizer(batch['sentence'],truncation = True)

In [13]:
tokenized_datasets = raw_datasets.map(tokenize_fn, batched = True)
data_collator = DataCollatorWithPadding(tokenizer = tokenizer)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [14]:
tokenized_datasets = tokenized_datasets.remove_columns(["sentence","idx", "label"])

In [15]:
from torch.utils.data import DataLoader
train_loader = DataLoader(
    tokenized_datasets["train"],
    shuffle = True,
    batch_size = 32,
    collate_fn = data_collator
)

In [16]:
for batch in train_loader:
  for k,v in batch.items():
    print("k:", k, "v.shape:",v.shape)
  break

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


k: input_ids v.shape: torch.Size([32, 33])
k: attention_mask v.shape: torch.Size([32, 33])


In [17]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [18]:
model = Decoder(
    vocab_size = tokenizer.vocab_size,
    max_len = tokenizer.max_model_input_sizes[checkpoint],
    d_k = 16,
    d_model = 64,
    n_heads = 4,
    n_layers = 2,
    dropout_prob = 0.1
)
model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transfomer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, 

In [19]:
criterion = nn.CrossEntropyLoss(ignore_index = tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters())

In [20]:
from datetime import datetime

In [21]:
def train(model, criterion, optimizer, train_loader, epochs):
  train_losses = np.zeros(epochs)


  for it in range(epochs):
    model.train()
    t0 = datetime.now()
    train_loss = []

    for batch in train_loader:
      batch = {k:v.to(device) for k,v in batch.items()}

      optimizer.zero_grad()

      targets = batch["input_ids"].clone().detach()
      targets = torch.roll(targets, shifts = -1, dims =1)
      targets[:,-1] = tokenizer.pad_token_id

      outputs = model(batch['input_ids'], batch['attention_mask'])
      loss = criterion(outputs.transpose(2,1), targets)

      loss.backward()
      optimizer.step()

      train_loss.append(loss.item())


    train_loss = np.mean(train_loss)
    train_losses[it] = train_loss

    dt = datetime.now() - t0
    print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, Duration:{dt}')
  return train_losses


In [22]:
train_losses = train(
    model, criterion, optimizer, train_loader, epochs = 15
)

Epoch 1/15, Train Loss: 5.9601, Duration:0:01:02.639954
Epoch 2/15, Train Loss: 5.0147, Duration:0:01:09.253377
Epoch 3/15, Train Loss: 4.6855, Duration:0:01:11.820829
Epoch 4/15, Train Loss: 4.4975, Duration:0:00:58.809222
Epoch 5/15, Train Loss: 4.3658, Duration:0:01:00.118868
Epoch 6/15, Train Loss: 4.2594, Duration:0:00:59.020143
Epoch 7/15, Train Loss: 4.1717, Duration:0:01:00.472894
Epoch 8/15, Train Loss: 4.0951, Duration:0:01:05.484818
Epoch 9/15, Train Loss: 4.0292, Duration:0:01:05.491189
Epoch 10/15, Train Loss: 3.9686, Duration:0:01:07.631013
Epoch 11/15, Train Loss: 3.9118, Duration:0:01:08.058552
Epoch 12/15, Train Loss: 3.8620, Duration:0:01:00.091354
Epoch 13/15, Train Loss: 3.8140, Duration:0:01:02.772734
Epoch 14/15, Train Loss: 3.7735, Duration:0:01:02.045789
Epoch 15/15, Train Loss: 3.7302, Duration:0:01:13.769166


In [25]:
valid_loader = DataLoader(
  tokenized_datasets["validation"],
  batch_size = 1,
  collate_fn = data_collator
)

In [26]:
model.eval()
for batch in valid_loader:
  batch = {k:v.to(device) for k,v in batch.items()}

  outputs = model(batch["input_ids"], batch["attention_mask"])
  break

In [27]:
outputs.shape

torch.Size([1, 12, 28996])

In [28]:
torch.argmax(outputs, axis = -1)

tensor([[ 170,  112,  188,  170, 1363,  117, 6276, 6276, 1642,  102,  102,  102]],
       device='cuda:0')

In [29]:
predictions_ids = torch.argmax(outputs, axis = -1)

In [30]:
tokenizer.decode(predictions_ids[0])

"a's a good, funny funny story [SEP] [SEP] [SEP]"

In [31]:
tokenizer.decode(batch["input_ids"][0])

"[CLS] it's a charming and often affecting journey. [SEP]"

In [33]:
tokenizer.decode(torch.concat((batch["input_ids"][0,:5], predictions_ids[:,4])))

"[CLS] it's a good"

In [34]:
prompt = "hello"

In [35]:
tokenized_prompt = tokenizer(prompt, return_tensors = 'pt')
tokenized_prompt

{'input_ids': tensor([[  101, 19082,   102]]), 'attention_mask': tensor([[1, 1, 1]])}

In [36]:
outputs = model(
    tokenized_prompt['input_ids'][:,:-1].to(device),
    tokenized_prompt['attention_mask'][:,:-1].to(device)
)

In [37]:
predictions_ids = torch.argmax(outputs[:,-1,:], axis = -1)

In [38]:
tokenizer.decode(predictions_ids[0])

','

In [40]:
prompt = "it's a"

tokenized_prompt = tokenizer(prompt, return_tensors = 'pt')

input_ids = tokenized_prompt['input_ids'][:,:-1].to(device)
mask = tokenized_prompt['attention_mask'][:,:-1].to(device)

for _ in range(20):
  outputs = model(input_ids, mask)
  prediction_id = torch.argmax(outputs[:,-1,:],axis = -1)

  input_ids = torch.hstack((input_ids, prediction_id.view(1,1)))
  mask = torch.ones_like(input_ids)

  if prediction_id == tokenizer.sep_token_id:
    break

In [41]:
tokenizer.decode(input_ids[0])

"[CLS] it's a good time [SEP]"