In [1]:
import math
import torch
from torch import nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class CausalSelfAttention(nn.Module):
    def __init__(self, d_K, d_model, n_heads, max_len):
        super().__init__()
        # assume d_V == d_K
        self.d_K = d_K
        self.n_heads = n_heads
        self.key = nn.Linear(in_features=d_model,
                             out_features=d_K * n_heads)
        self.query = nn.Linear(in_features=d_model,
                      out_features=d_K * n_heads)
        self.value = nn.Linear(in_features=d_model,
                      out_features=d_K * n_heads)

        # final fully connected linear layer
        self.fc = nn.Linear(in_features=d_K * n_heads,
                            out_features=d_model)
        # causal mask
        causal_mask = torch.tril(
            torch.ones(max_len, max_len)
        )

        self.register_buffer("causal_mask",
                             causal_mask.view(1, 1, max_len, max_len))


    def forward(self, query, key, value, pad_mask=None):
        query = self.query(query) # N x T x (hd_Q)
        key = self.key(key) # N x T x (hd_K)
        value = self.value(value) # N x T x (hd_V)
        # get the dimensions
        N = query.shape[0]
        T = query.shape[1]
        # swap dimensions order for proper matrix multiplication
        # N x T x H x d_K -> N x H x T x d_K
        query = query.view(N, T, self.n_heads, self.d_K).transpose(1, 2)
        key = key.view(N, T, self.n_heads, self.d_K).transpose(1, 2)
        value = value.view(N, T, self.n_heads, self.d_K).transpose(1, 2)
        # compute attention weights
        # (N x H x T x d_K) x (N x H x d_K x T ) -> (N, H, T, T)
        attention_logits = torch.matmul(query,  key.transpose(-2, -1)) / math.sqrt(self.d_K)
        if pad_mask is not None:
            attention_logits = attention_logits.masked_fill(
                mask=pad_mask[:, None, None, :] == 0,
                value=float('-inf')
            )
        attention_logits = attention_logits.masked_fill(
            mask=self.causal_mask[:, :, :T, :T] == 0,
            value=float('-inf')
        )
        attention_weights = F.softmax(attention_logits, dim=-1)
        # compute attention weighted values
        attention = torch.matmul(attention_weights, value)
        # before inputing at final FC layer transpose back
        attention = attention.transpose(1, 2) # to (N x T x H x d_K)
        attention = attention.contiguous().view(N, T, self.d_K * self.n_heads)

        return self.fc(attention)

In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, d_K: int, d_model: int, n_heads: int,
                 max_len: int, dropout_rate: float = 0.1):
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(normalized_shape=d_model)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=d_model)
        self.mh_attention = CausalSelfAttention(d_K=d_K,
                                                d_model=d_model,
                                                n_heads=n_heads,
                                                max_len=max_len)
        self.network = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=d_model * 4),
            nn.GELU(),
            nn.Linear(in_features=d_model * 4, out_features=d_model),
            nn.Dropout(p=dropout_rate)
        )
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, pad_mask=None):
        x = self.layer_norm_1(x + self.mh_attention(x, x, x, pad_mask))
        x = self.layer_norm_1(x + self.network(x))
        x = self.dropout(x)
        return x



In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 2048, dropout_rate=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        position = torch.arange(max_len).unsqueeze(1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
        pos_enc = torch.zeros(1, max_len, d_model)
        pos_enc[0, :, 0::2] = torch.sin(position * div_term)
        pos_enc[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        # x.shape: N x T x D
        x = x + self.pos_enc[:, :x.size(1), :]
        return self.dropout(x)

In [5]:
# the Decoder's main task to predict the next token
# the 'n_classes' is equal to a vocab_size
class Decoder(nn.Module):
    def __init__(self,
                 vocab_size,
                 max_len,
                 d_K,
                 d_model,
                 n_heads,
                 n_layers,
                 dropout_rate):
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size,
                                      embedding_dim=d_model)
        self.pos_encodding = PositionalEncoding(
            d_model=d_model, max_len=max_len, dropout_rate=dropout_rate
        )
        transformer_blocks = [
            TransformerBlock(
                d_K=d_K,
                d_model=d_model,
                n_heads=n_heads,
                max_len=max_len,
                dropout_rate=dropout_rate
            ) for _ in range(n_layers)
        ]
        self.transformer_blocks = nn.Sequential(*transformer_blocks)
        self.norm = nn.LayerNorm(normalized_shape=d_model)
        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x, pad_mask=None):
        x = self.embedding(x)
        x = self.pos_encodding(x)
        for block in self.transformer_blocks:
            x = block(x, pad_mask)
        # normalization and linear transformation
        x = self.norm(x)
        # many-to-many
        x = self.fc(x)
        return x

In [6]:
model = Decoder(
    20000, 1024, 16, 64, 4, 2, 0.1
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

cuda:0


In [7]:
x = torch.randint(0, 20000, size=(8, 512))
y = model(x.to(device))

print("Shape: ", y.shape)

Shape:  torch.Size([8, 512, 20000])


In [8]:
mask = np.ones((8, 512))
mask[:, 256: ] = 0
mask = torch.tensor(mask).to(device)

y = model(x.to(device), pad_mask=mask)

print("Shape: ", y.shape)

Shape:  torch.Size([8, 512, 20000])


In [9]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [10]:
checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [11]:
pip install -q datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [12]:
from datasets import load_dataset

In [13]:
raw_data = load_dataset("glue", "sst2")

Downloading readme:   0%|          | 0.00/31.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [14]:
raw_data

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [15]:
def tokenizer_func(batch):
    return tokenizer(batch['sentence'], truncation=True)

In [16]:
tokenized_dataset = raw_data.map(tokenizer_func, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [17]:
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [18]:
tokenized_dataset = tokenized_dataset.remove_columns([
    "sentence", "idx", "label"
])

tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [19]:
train_dl = DataLoader(
    tokenized_dataset["train"],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator
)

# checkout how it works
for batch in train_dl:
    for key, value in batch.items():
        print("key: ", key, "value.shape ", value.shape)
    break

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


key:  input_ids value.shape  torch.Size([32, 40])
key:  attention_mask value.shape  torch.Size([32, 40])


In [20]:
tokenizer.pad_token_id

0

In [21]:
model = Decoder(
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.max_model_input_sizes[checkpoint],
    d_K=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_rate=0.1
)

model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encodding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (layer_norm_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mh_attention): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (network): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (l

In [22]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters())

In [23]:
from datetime import datetime

In [24]:
def train(model, criterion, optimizer, train_dl, epochs):
    train_losses = np.zeros(epochs)

    for iteration in range(epochs):
        model.train()
        tic = datetime.now()
        # accumulate loss per batch
        train_loss = []

        for batch in train_dl:
            # to gpu
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            # set that target is our next token, similar as input
            # but with rolling by 1 idx and shift it to the left
            # avoid the special [CLS] token at the 0 position
            targets = batch['input_ids'].clone().detach()
            targets = torch.roll(targets,
                                 shifts=-1,
                                 dims=1)
            targets[:, -1] = tokenizer.pad_token_id

            output = model(
                batch['input_ids'], batch['attention_mask']
            )

            #print("outputs shape: ", output.shape)
            #print("targets shape: ", targets.shape)

            loss = criterion(output.transpose(2, 1),
                             targets)

            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # get an average loss per epoch
        train_loss = np.mean(train_loss)
        train_losses[iteration] = train_loss

        tac = datetime.now()

        iter_time = tac - tic

        print(f"Epoch {iteration + 1}/{epochs}, Train_loss: {train_loss:.4f}, Duration: {iter_time}")

    return train_losses



In [25]:
tr_losses = train(
    model, criterion, optimizer, train_dl, 15
)

Epoch 1/15, Train_loss: 5.9713, Duration: 0:01:08.963321
Epoch 2/15, Train_loss: 5.0204, Duration: 0:01:05.533060
Epoch 3/15, Train_loss: 4.6833, Duration: 0:01:05.919571
Epoch 4/15, Train_loss: 4.4936, Duration: 0:01:07.138839
Epoch 5/15, Train_loss: 4.3579, Duration: 0:01:07.626623
Epoch 6/15, Train_loss: 4.2543, Duration: 0:01:05.714643
Epoch 7/15, Train_loss: 4.1640, Duration: 0:01:04.520695
Epoch 8/15, Train_loss: 4.0891, Duration: 0:01:06.424164
Epoch 9/15, Train_loss: 4.0226, Duration: 0:01:05.732206
Epoch 10/15, Train_loss: 3.9638, Duration: 0:01:03.564603
Epoch 11/15, Train_loss: 3.9085, Duration: 0:01:01.257541
Epoch 12/15, Train_loss: 3.8591, Duration: 0:00:58.971661
Epoch 13/15, Train_loss: 3.8113, Duration: 0:01:01.725275
Epoch 14/15, Train_loss: 3.7713, Duration: 0:01:05.703578
Epoch 15/15, Train_loss: 3.7321, Duration: 0:01:07.400636


In [28]:
valid_dl = DataLoader(
    tokenized_dataset['validation'],
    batch_size=32,
    collate_fn=data_collator
)

In [31]:
model.eval()
for batch in valid_dl:
    batch = {k: v.to(device) for k, v in batch.items()}
    output = model(batch['input_ids'], batch['attention_mask'])
    break

print("output shape: ", output.shape)
predictions = torch.argmax(output, axis=-1)
print("preds shape: ", predictions.shape)
print("tokenizer decode preds: ", tokenizer.decode(predictions[0]))
print("tokenizer decode batch: ", tokenizer.decode(batch['input_ids'][0]))
print("-"*30)
print(batch['input_ids'][0, :5])
print(predictions[:, 4])
print("-"*30)
print("Concat batch inputs and preds: \n",
      tokenizer.decode(
          torch.concat((batch['input_ids'][0, :5],
                        predictions[:, 4]))
      ))

output shape:  torch.Size([32, 51, 28996])
preds shape:  torch.Size([32, 51])
tokenizer decode preds:  a's a good, un handled film [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP]ddddddddddd... [SEP] [SEP] [SEP] [SEP]............. fare
tokenizer decode batch:  [CLS] it's a charming and often affecting journey. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
------------------------------
tensor([ 101, 1122,  112,  188,  170], device='cuda:0')
tensor([ 1363,  2806,   102,  1132,  1106,  1105,  1193,  3381,   189,  1104,
         6185,  8869,  1273,   170,  1129,   118,  6288,  7777,  1273,   102,
         8362, 24914,   102, 24183,  1115,   170,  1103, 13143,  1363,  1120,
        24181,  1106], device='cuda:0')
------------------------------
Concat batch inputs and preds: 
 [CLS] it's a

In [41]:
prompt = "it's"

tokenized_prompt = tokenizer(prompt, return_tensors='pt')
print(tokenized_prompt)

{'input_ids': tensor([[ 101, 1122,  112,  188,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}


In [42]:
output = model(
    tokenized_prompt['input_ids'][:, :-1].to(device),
    tokenized_prompt['attention_mask'][:, :-1].to(device)
)

print("Output shape ", output.shape)

preds = torch.argmax(output[: , -1, :], dim=-1)
print("Decoded preds: ", tokenizer.decode(preds[0]))


Output shape  torch.Size([1, 4, 28996])
Decoded preds:  a


In [45]:
prompt = "it's a"
tokenized_prompt = tokenizer(prompt, return_tensors='pt')

input_ids = tokenized_prompt['input_ids'][:, :-1].to(device)
mask = tokenized_prompt['attention_mask'][:, :-1].to(device)

for _ in range(20):
    output = model(input_ids, mask)
    preds = torch.argmax(output[:, -1, :], dim=-1)
    input_ids = torch.hstack(
        (input_ids, preds.view(1, 1))
    )
    mask = torch.ones_like(input_ids)

    if preds == tokenizer.sep_token_id:
        break

In [46]:
tokenizer.decode(input_ids[0])

"[CLS] it's a good time [SEP]"