In [None]:
!pip install datasets transformers==4.28.0 torchinfo



In [None]:
import numpy as np
from torch.utils.data import dataset
import torchinfo
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import optim
import math
import datasets
from transformers import AutoTokenizer, DataCollatorWithPadding
from datetime import datetime

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention.

    Every position may attend only to itself and earlier positions; optionally,
    key positions whose ``pad_mask`` entry is 0 are also excluded.

    Args:
        d_k: size of each head's query/key/value vectors.
        d_model: feature width of the input and of the output.
        n_heads: number of attention heads.
        max_len: longest sequence length the causal mask is built for.
    """

    def __init__(self, d_k, d_model, n_heads, max_len):
        super().__init__()

        self.d_k = d_k
        self.d_model = d_model
        self.n_heads = n_heads

        proj_dim = d_k * n_heads
        # Q/K/V projections: d_model -> n_heads * d_k (creation order kept: key, query, value)
        self.key = nn.Linear(d_model, proj_dim)
        self.query = nn.Linear(d_model, proj_dim)
        self.value = nn.Linear(d_model, proj_dim)

        # merge the heads back to the model width
        self.final_layer = nn.Linear(proj_dim, d_model)

        # lower-triangular ones: row i has ones in columns 0..i, i.e. "may attend to"
        lower = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer('causal_mask', lower.view(1, 1, max_len, max_len))

    def forward(self, x, pad_mask=None):
        """Attend over ``x`` (N, T, d_model) and return a tensor of the same shape.

        ``pad_mask`` is indexed as (N, T) (or anything broadcastable to it);
        entries equal to 0 mark key positions that must not be attended to.
        """
        batch, seq = x.shape[0], x.shape[1]

        def heads(t):
            # (N, T, h*d_k) -> (N, T, h, d_k) -> (N, h, T, d_k)
            return t.view(batch, seq, self.n_heads, self.d_k).transpose(1, 2)

        k = heads(self.key(x))
        q = heads(self.query(x))
        v = heads(self.value(x))

        # scaled dot-product scores: (N, h, T, d_k) @ (N, h, d_k, T) -> (N, h, T, T)
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)
        if pad_mask is not None:
            # broadcast the key-padding mask over heads and query rows: (N, 1, 1, T)
            scores = scores.masked_fill(pad_mask[:, None, None, :] == 0, float('-inf'))
        future = self.causal_mask[:, :, :seq, :seq] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (N, h, T, T) @ (N, h, T, d_k) -> (N, h, T, d_k) -> (N, T, h*d_k)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq, self.n_heads * self.d_k)
        return self.final_layer(merged)

In [None]:
class TransformerBlock(nn.Module):
    """Post-LayerNorm decoder block: causal self-attention followed by a feed-forward net.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input (residual connection) and is then layer-normalised. Input and output
    both have ``d_model`` features on the last axis.

    Args:
        d_k: per-head query/key/value size.
        d_model: model width.
        n_heads: number of attention heads.
        max_len: maximum sequence length supported by the causal mask.
        dropout_prob: dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()

        self.d_k = d_k
        self.d_model = d_model
        self.n_heads = n_heads

        self.mha = CausalSelfAttention(d_k, d_model, n_heads, max_len)
        self.ln1 = nn.LayerNorm(d_model)
        # position-wise feed-forward: (N x T x d_model) -> (N x T x 4*d_model) -> (N x T x d_model);
        # its trailing Dropout is the dropout for this sub-layer's output
        self.ann = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(p=dropout_prob)
        )
        self.ln2 = nn.LayerNorm(d_model)
        # dropout for the attention sub-layer's output
        self.drop = nn.Dropout(p=dropout_prob)

    def forward(self, x, pad_mask=None):
        """Apply the block to ``x`` (N x T x d_model); ``pad_mask`` is forwarded to attention."""
        # Dropout is applied to the sub-layer output *before* the residual add.
        # Applying it after ln2 instead would also zero the skip-connection path
        # and leave the attention output without any dropout.
        x = self.ln1(x + self.drop(self.mha(x, pad_mask=pad_mask)))
        x = self.ln2(x + self.ann(x))
        return x

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature width of the input (odd widths are supported).
        dropout: dropout probability applied after the encoding is added.
        max_length: number of positions to precompute; inputs must not be longer.
    """

    def __init__(self, d_model, dropout=0.1, max_length=2048):
        super().__init__()

        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_length).unsqueeze(1)            # (max_length, 1)
        even_idx = torch.arange(0, d_model, 2)                       # the "2i" values
        div_term = torch.exp(even_idx * (-np.log(10000.0)) / d_model)  # 1 / 10000^(2i / d_model)
        angles = position * div_term                                 # (max_length, ceil(d_model / 2))

        pe = torch.zeros(1, max_length, d_model)
        pe[:, :, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns. When d_model is odd, the last
        # angle has no cosine slot, so trim it to keep the shapes aligned.
        pe[:, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` position encodings to ``x`` (N x T x d_model)."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [None]:
class Decoder(nn.Module):
    """Decoder-only (GPT-style) language model.

    token ids (N x T) -> embedding (N x T x d_model) -> + sinusoidal positions
    -> ``n_layers`` causal TransformerBlocks -> LayerNorm -> logits (N x T x vocab_size).

    Args:
        d_k: per-head query/key/value size.
        max_length: maximum sequence length (positional table and causal mask size).
        vocab_size: number of token ids / output classes.
        d_model: model width.
        n_heads: number of attention heads per block.
        n_layers: number of TransformerBlocks.
        dropout: dropout probability used in the positional encoding and every block.
    """

    def __init__(
            self,
            d_k,
            max_length,
            vocab_size,
            d_model,
            n_heads,
            n_layers,
            dropout=0.1
            ):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)
        # Forward `dropout` here too, so the constructor argument controls every dropout layer.
        self.positional_encoding = PositionalEncoding(
            d_model, dropout=dropout, max_length=max_length
        )
        # ModuleList rather than Sequential: the blocks are called one at a time so
        # `pad_mask` can be passed to each (Sequential.forward cannot do that).
        # The state-dict keys ("transformer_blocks.<i>. ...") are the same either way.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(
                d_k,
                d_model,
                n_heads,
                max_length,
                dropout_prob=dropout
            ) for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model)
        self.final_layer = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Map token ids ``x`` (N x T) to next-token logits (N x T x vocab_size).

        ``pad_mask`` (same layout as ``x``; 0 = padding) is passed to every block.
        """
        x = self.embed(x)
        x = self.positional_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, pad_mask=pad_mask)
        x = self.layer_norm(x)
        return self.final_layer(x)

In [None]:
# Smoke test: build a decoder and run one forward pass on random token ids, so
# shape errors surface before any real data is loaded.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

DUMMY_BATCH = 8
DUMMY_LEN = 256
DUMMY_VOCAB = 20_000

dummy_model = Decoder(
    d_k=64,
    max_length=DUMMY_LEN,
    vocab_size=DUMMY_VOCAB,
    d_model=512,
    n_heads=4,
    n_layers=2,
).to(device)

x_input = torch.randint(0, DUMMY_VOCAB, (DUMMY_BATCH, DUMMY_LEN)).to(device)
x_pad_mask = torch.ones(DUMMY_BATCH, DUMMY_LEN).to(device)
x_pad_mask[:, DUMMY_LEN // 2:] = 0  # treat the second half of every sequence as padding

with torch.no_grad():  # inference only: no autograd graph needed
    y = dummy_model(x_input, x_pad_mask)

assert y.shape == (DUMMY_BATCH, DUMMY_LEN, DUMMY_VOCAB), y.shape
del dummy_model  # free (GPU) memory before the real model is built

In [None]:
# Data & model setup: SST-2 sentences, tokenized with the DistilBERT (cased)
# tokenizer, serve as an unlabeled corpus for next-token prediction.

SEED = 42
torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffling

checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

raw_ds = datasets.load_dataset('glue', 'sst2')

def tokenize_dataset(batch):
    """Tokenize a batch of SST-2 sentences with truncation enabled."""
    return tokenizer(batch['sentence'], truncation=True)

# Keep only tokenizer outputs; the original columns (text, label, ...) are not needed.
tokenized_ds = raw_ds.map(
    tokenize_dataset,
    batched=True,
    remove_columns=raw_ds['train'].column_names
    )

# Pads each batch dynamically to its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(
    tokenized_ds['train'],
    batch_size=32,
    shuffle=True,
    collate_fn=data_collator,
)

model = Decoder(
    d_k=64,
    # positional table / causal mask sized to the checkpoint's maximum input length
    max_length=tokenizer.max_model_input_sizes[checkpoint],
    vocab_size=tokenizer.vocab_size,
    d_model=512,
    n_heads=4,
    n_layers=2,
    dropout=0.1,
).to(device)

# Padding targets (and the shifted-off last position) are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters())


def train_decoder(
        model,
        num_epochs,
        loss_fn,
        optimizer,
        train_dataloader,
        save_path='/content/drive/MyDrive/Data Science/NLP/Transformers from scratch/decoder_model.pt',
        ):
    """Train a causal decoder on next-token prediction (teacher forcing).

    Each batch is a dict containing at least 'input_ids' and 'attention_mask'
    tensors of shape (N x T). The target at position t is the input token at
    t + 1. The last position has no successor, so it is filled with
    ``loss_fn.ignore_index`` and excluded from the loss.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning logits of shape (N x T x vocab).
        num_epochs: number of passes over ``train_dataloader``.
        loss_fn: classification loss taking (N x vocab x T) scores and (N x T) targets.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: iterable of batch dicts as described above.
        save_path: where the whole model is written with ``torch.save`` after the
            final epoch; pass ``None`` to skip saving.

    Returns:
        List with the mean training loss of each epoch (NaN for an epoch with no batches).
    """
    # Use the device the model lives on instead of relying on a notebook global.
    device = next(model.parameters()).device
    # Targets equal to this id are skipped by the loss (CrossEntropyLoss default: -100).
    ignore_id = getattr(loss_fn, 'ignore_index', -100)

    epoch_losses = []
    for epoch in range(1, num_epochs + 1):
        model.train()
        batch_losses = []
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            input_ids = batch['input_ids']

            # shift one step left along T: target[:, t] = input_ids[:, t + 1]
            # (torch.roll returns a new tensor, so input_ids itself is untouched)
            targets = torch.roll(input_ids, shifts=-1, dims=1)
            targets[:, -1] = ignore_id

            optimizer.zero_grad()
            logits = model(input_ids, batch['attention_mask'])  # (N x T x vocab)
            # the loss expects class scores on dim 1: (N x vocab x T)
            loss = loss_fn(logits.transpose(1, 2), targets)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
        epoch_losses.append(epoch_loss)

        print(f'Epoch {epoch}/{num_epochs} ---> Train Loss: {epoch_loss:.4f}')
        if epoch == num_epochs and save_path is not None:
            torch.save(model, save_path)
    return epoch_losses



  0%|          | 0/3 [00:00<?, ?it/s]



In [None]:
train_losses = train_decoder(model, 15, loss_fn, optimizer, train_dataloader)

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch 1/15 ---> Train Loss: 4.8155
Epoch 2/15 ---> Train Loss: 3.5209
Epoch 3/15 ---> Train Loss: 2.9287


In [None]:
def generate_text(prompt, text_len, stop_token_id=None):
    """Greedily extend ``prompt`` by up to ``text_len`` tokens using the notebook's ``model``.

    Relies on the notebook-level ``model``, ``tokenizer`` and ``device``. The
    tokenizer appends a trailing special token ([SEP] for the DistilBERT tokenizer),
    which is dropped so generation continues the sentence rather than starting
    after its end.

    Generation stops early once the model predicts ``stop_token_id``. The default
    is ``tokenizer.sep_token_id``, because training ignores padding targets
    (``ignore_index=pad_token_id``), so the model never learns to emit [PAD]; the
    end-of-sentence token it does learn to predict is [SEP].

    Args:
        prompt: text to continue.
        text_len: maximum number of tokens to generate.
        stop_token_id: token id that ends generation (``None`` -> ``tokenizer.sep_token_id``).

    Returns:
        LongTensor of shape (1 x total_len): the prompt ids followed by the generated ids.
    """
    if stop_token_id is None:
        stop_token_id = tokenizer.sep_token_id

    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids = inputs['input_ids'][:, :-1].to(device)
    mask = inputs['attention_mask'][:, :-1].to(device)

    was_training = model.training
    model.eval()  # disable dropout so decoding is deterministic
    try:
        with torch.no_grad():  # no autograd graph is needed for generation
            for _ in range(text_len):
                logits = model(input_ids, mask)
                next_id = torch.argmax(logits[:, -1, :], dim=-1)  # greedy choice, shape (1,)
                input_ids = torch.cat((input_ids, next_id.view(1, 1)), dim=1)
                mask = torch.ones_like(input_ids)

                if next_id.item() == stop_token_id:
                    break
    finally:
        model.train(was_training)  # restore whatever mode the caller had set
    return input_ids

In [None]:
# Greedy continuation of a short prompt (at most 30 new tokens), decoded back to text.
input_ids = generate_text(
    prompt="it's a rather rare",
    text_len=30,
)
tokenizer.decode(input_ids[0])