In [67]:
!pip install torch==2.3.0 # To work with torchtext
!pip install torchtext
!pip install portalocker
!pip install torchdata

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip[0m
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Defaulting to user installation because normal sit

In [68]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

import torch
import torch.nn as nn
import math
import einops
from tokenizers import CharBPETokenizer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

In [69]:
# Import huggingface char-bpe tokenizer
en_tokenizer, de_tokenizer = CharBPETokenizer(), CharBPETokenizer()

special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
en_tokenizer.add_special_tokens(special_symbols)
de_tokenizer.add_special_tokens(special_symbols)

# Train tokenizers
train_iter, test_iter, valid_iter = Multi30k(language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

de_data, en_data = (list(zip(*train_iter)))
de_data_test, en_data_test = (list(zip(*test_iter)))

en_tokenizer.train_from_iterator(iterator=en_data)
de_tokenizer.train_from_iterator(iterator=de_data)

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = tuple(en_tokenizer.encode(x).ids[0] for x in special_symbols)










In [70]:
en_tokenizer.encode(en_data[0]).tokens, en_tokenizer.encode(en_data[0]).ids

(['Two</w>',
  'young</w>',
  ',</w>',
  'White</w>',
  'males</w>',
  'are</w>',
  'outside</w>',
  'near</w>',
  'many</w>',
  'bushes</w>',
  '.</w>'],
 [218, 255, 112, 2727, 1295, 201, 375, 463, 1086, 2994, 154])

In [71]:
# Text transforms for data
text_transform_en = lambda x : torch.tensor([BOS_IDX] + en_tokenizer.encode(x.rstrip("\n")).ids + [EOS_IDX])
text_transform_de = lambda x : torch.tensor([BOS_IDX] + de_tokenizer.encode(x.rstrip("\n")).ids + [EOS_IDX])


In [72]:
"""Position encoding"""
class PositionEncoding2(nn.Module):
    def __init__(self, d_embed):
        super(PositionEncoding2, self).__init__()
        self.d_embed = d_embed
        self.dropout = nn.Dropout(p=0.1)

    def get_position_encoding(self, seq_len):
        encoding = torch.zeros((seq_len, self.d_embed))
        dimensions = torch.arange(0, self.d_embed//2)
        timesteps = torch.arange(0, seq_len)

        encoding[:, 0::2] = torch.sin(torch.einsum('i,j -> ji', 1/(10000**(2*dimensions/self.d_embed)), timesteps))
        encoding[:, 1::2] = torch.cos(torch.einsum('i,j -> ji', 1/(10000**(2*dimensions/self.d_embed)), timesteps))
        
        return encoding
    
    def forward(self, inp):
        pos_encoding = nn.Parameter(self.get_position_encoding(seq_len=inp.shape[-2]), requires_grad=False).to(DEVICE)
        return self.dropout(inp + pos_encoding) # + dropout ?


In [73]:
"""Multihead attention"""
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_hidden, n_heads):
        super().__init__()#MultiheadAttention)
        self.d_hidden=d_hidden
        self.n_heads = n_heads
        self.W_q = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )

        self.W_o = nn.Linear(in_features=d_hidden, out_features=d_model, bias=False)

    def forward(self, q, k, v, get_attn_scores=False, mask=None, is_causal=False):
        mask_shape = (q.shape[-2], k.shape[-2]) 
        # print("mask shape: ", mask_shape)
        mask = mask if mask else (self.causal_mask(self.n_heads, mask_shape) if is_causal else self.empty_mask(self.n_heads, mask_shape))
        mask = mask.to(DEVICE)
        
        # Compute Q, K, V (matrix multiplication, batched along heads)
        q = einops.einsum(self.W_q, q, 'h d_model d_h, b T d_model -> b h  T d_h') # T is timesteps (sequence length)
        k = einops.einsum(self.W_k, k, 'h d_model d_h, b T d_model -> b h  T d_h')
        v = einops.einsum(self.W_v, v, 'h d_model d_h, b T d_model -> b h  T d_h')

        # print("q_shape: ", q.shape)
        # print("k_shape: ", k.shape)

        attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(self.d_hidden//self.n_heads) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

        attn_head_outs = einops.einsum(attn_scores, v, '... T_out T_in, ... T_in d_h -> ... T_out d_h')

        concatted_outs = einops.rearrange(attn_head_outs, 'b h T_out d_h -> b T_out (h d_h)')
        concatted_outs = self.W_o(concatted_outs)

        return (concatted_outs, attn_scores) if get_attn_scores else concatted_outs
    
    @classmethod
    def causal_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.triu(
                torch.fill(torch.zeros(mask_shape),  -torch.inf),
                diagonal=1
                )
            , pattern='... -> k ...', k=n_heads
            )
    
    @classmethod
    def empty_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.zeros(mask_shape) , pattern='... -> k ...', k=n_heads
            )

In [74]:
"""Multihead attention"""
class MultiHeadAttention2(nn.Module):
    def __init__(self, d_model, d_hidden, n_heads):
        super().__init__()#MultiheadAttention)
        self.d_hidden=d_hidden
        self.n_heads = n_heads

        self.W_q = nn.Linear(in_features=d_model, out_features=d_hidden)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_hidden)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_hidden)

        self.W_o = nn.Linear(in_features=d_hidden, out_features=d_model, bias=False)

    def forward(self, q, k, v, get_attn_scores=False, mask=None, is_causal=False):
        mask_shape = (q.shape[-2], k.shape[-2]) 
        # print("mask shape: ", mask_shape)
        mask = mask if mask else (self.causal_mask(self.n_heads, mask_shape) if is_causal else self.empty_mask(self.n_heads, mask_shape))
        mask = mask.to(DEVICE)

        # Compute Q, K, V (matrix multiplication, batched along heads)
        q = self.W_q(q)
        k = self.W_k(k)
        v = self.W_v(v)

        # Split into attention heads
        q = einops.rearrange(q, 'b T (h w) -> b h T w', h=self.n_heads)
        k = einops.rearrange(k, 'b T (h w) -> b h T w', h=self.n_heads)
        v = einops.rearrange(v, 'b T (h w) -> b h T w', h=self.n_heads)

        attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(self.d_hidden//self.n_heads) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

        attn_head_outs = einops.einsum(attn_scores, v, '... T_out T_in, ... T_in d_h -> ... T_out d_h')

        concatted_outs = einops.rearrange(attn_head_outs, 'b h T_out d_h -> b T_out (h d_h)')
        concatted_outs = self.W_o(concatted_outs)

        return (concatted_outs, attn_scores) if get_attn_scores else concatted_outs
    
    @classmethod
    def causal_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.triu(
                torch.fill(torch.zeros(mask_shape),  -torch.inf),
                diagonal=1
                )
            , pattern='... -> k ...', k=n_heads
            )
    
    @classmethod
    def empty_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.zeros(mask_shape) , pattern='... -> k ...', k=n_heads
            )

In [75]:
class PosWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super(PosWiseFeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # FFN(x) = max(0, xW1 + b1)W2 + b2 ; f(x) = max(0, x) is relu
        return self.fc2(self.relu(self.fc1(x)))

In [76]:
# Encoder and Decoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff, dropout=0.1, N_layers=6, model_kind=1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, d_hidden, num_heads) if model_kind==1 else MultiHeadAttention2(d_model, d_hidden, num_heads)
        self.feed_forward = PosWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None, is_causal=False):
        #print("Q, K, V: ", x.size())
        attn_output = self.self_attn(x, x, x, mask=mask, is_causal=is_causal) # self attention
        x = self.norm1(x + self.dropout(attn_output)) # layer-norm + dropout + skip connection
        ff_output = self.feed_forward(x) # feed-forward
        x = self.norm2(x + self.dropout(ff_output)) # layer-norm + dropout + skip connection
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff, dropout=0.1, N_layers=6, model_kind=1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, d_hidden, num_heads) if model_kind==1 else MultiHeadAttention2(d_model, d_hidden, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, d_hidden, num_heads) if model_kind==1 else MultiHeadAttention2(d_model, d_hidden, num_heads)
        self.feed_forward = PosWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask, is_causal=False):
        attn_output = self.self_attn(x, x, x, mask=tgt_mask, is_causal=is_causal) # self attention
        x = self.norm1(x + self.dropout(attn_output)) # layer-norm + dropout + skip connection
        attn_output = self.cross_attn(q=x, k=enc_output, v=enc_output, mask=src_mask)
        x = self.norm2(x + self.dropout(attn_output)) # layer-norm + droput + skip connection
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output)) # layer-norm + dropout + skip connection
        return x

In [77]:
class Transformer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff,src_vocab_size, tgt_vocab_size, dropout=0, n_enc_layers=6, n_dec_layers=6, model_kind=1):
        super(Transformer, self).__init__()

        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model=d_model,
                         d_hidden=d_hidden,
                         num_heads=num_heads,
                         d_ff=d_ff,
                         model_kind=model_kind) for _ in range(n_enc_layers)]
        )
        self.dec_layers = nn.ModuleList(
            [DecoderLayer(d_model=d_model,
                         d_hidden=d_hidden,
                         d_ff=d_ff,
                         num_heads=num_heads,
                         model_kind=model_kind) for _ in range(n_dec_layers)]
        )

        self.pos_encoding_layer = PositionEncoding2(d_embed=d_model)

        self.src_embedding = nn.Embedding(embedding_dim=d_model, num_embeddings=src_vocab_size, padding_idx=PAD_IDX)
        self.tgt_embedding = nn.Embedding(embedding_dim=d_model, num_embeddings=tgt_vocab_size, padding_idx=PAD_IDX)

        self.final = nn.Linear(in_features=d_hidden, out_features=tgt_vocab_size)

    def forward(self, inp_seq, tgt_seq, inp_mask=None, tgt_mask=None, is_causal_tgt=False):
        

        inp_seq = self.src_embedding(inp_seq)
        inp_seq = self.pos_encoding_layer(inp_seq)

        for enc in self.enc_layers:
            inp_seq = enc(inp_seq, mask=inp_mask, is_causal=False)

        # inp_seq = self.enc_layers(inp_seq, mask=inp_mask, is_causal=False)

        tgt_seq = self.src_embedding(tgt_seq)
        # tgt_seq = self.dec_layers(x=tgt_seq, enc_out=inp_seq, mask=tgt_mask, is_causal=is_causal_tgt)
        for dec in self.dec_layers:
            tgt_seq = dec(x=tgt_seq, enc_output=inp_seq, src_mask=inp_mask, tgt_mask=tgt_mask, is_causal=is_causal_tgt)

        out_logits = self.final(tgt_seq)
        return out_logits





In [78]:
"""Sample running data through transformer"""
t = Transformer(d_model=8, d_hidden=8, num_heads=4, d_ff=10, src_vocab_size=3, tgt_vocab_size=7)
src, tgt = torch.arange(0, 3).tile(2, 1), torch.arange(0, 3).tile(2, 1)

src.shape, t(src, tgt)#.shape

(torch.Size([2, 3]),
 tensor([[[ 0.2983,  0.2131,  1.3389, -0.1188, -0.1022,  0.1562, -0.5975],
          [ 0.5211, -0.2773,  1.0763, -0.4771, -0.0267,  0.5911, -0.0099],
          [ 0.7991,  0.2512,  1.3321,  0.6017, -0.5192,  0.3061, -0.5167]],
 
         [[ 0.8551,  0.0807,  0.3819, -0.2371, -0.1956,  0.2857, -0.5817],
          [ 1.3447, -0.0757, -0.2926,  0.3488, -0.6398,  0.2266, -0.4109],
          [ 0.5988,  0.6159,  1.3154,  0.3818, -0.5002,  0.1380, -0.8236]]],
        grad_fn=<ViewBackward0>))

In [79]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # print(batch)
    src_batch, tgt_batch = list(zip(*batch))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX).transpose(-1, -2)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX).transpose(-1, -2)

    return src_batch, tgt_batch 

In [80]:
# Simple task to overfit the model on; helps to catch dumb bugs in the architecture and training
text_transform_en('one two three four'), text_transform_en('five six seven eight')

(tensor([   2,  313,  400,  826, 1276,    3]),
 tensor([   2, 1859, 2613, 5024, 3934,    3]))

In [81]:
from torch.utils.data import DataLoader


D_MODEL = 512
D_FF = 2048
N_HEADS = 8

BATCH_SIZE = 32#64
LR = 0.00001 # make this a schedule
BETAS = (0.9, 0.98)


def train_epoch(model, optimizer, data_loader=None):
    model.train()
    losses = 0

    tokenized_train_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))
    
    train_dataloader = DataLoader(tokenized_train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    data_count = 0
    import tqdm
    for src, tgt in tqdm.tqdm(train_dataloader):#data_loader):#train_dataloader):
        data_count+=1

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        out_logits = model(src, tgt[:, :-1], is_causal_tgt=True)

        optimizer.zero_grad()

        loss = loss_fn(out_logits.transpose(-1, -2), tgt[:, 1:])

        loss.backward()
        optimizer.step()
        losses += loss.item()

    print("Trained on %s datapts"%data_count)

    return losses / data_count


def evaluate(model):
    tokenized_test_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), test_iter))
    valid_dataloader = DataLoader(tokenized_test_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    model.eval()
    losses = 0
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    data_count=0
    for src, tgt in valid_dataloader:
        data_count+=1

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        output = model(src, tgt[:,:-1])

    
        loss = loss_fn(output.transpose(-1, -2), tgt[:, 1:])
        losses += loss.item()

    return losses / data_count

In [82]:

model = Transformer(d_model=D_MODEL, 
                    d_hidden=D_MODEL, 
                    num_heads=N_HEADS, 
                    d_ff=D_FF, 
                    src_vocab_size=de_tokenizer.get_vocab_size(), 
                    tgt_vocab_size=en_tokenizer.get_vocab_size(),
                    model_kind=1 # type 2 is more common (refers to using MultiHeadAttenion2 class in this notebook); see both implementations above for detail
                    )

optimizer = torch.optim.Adam(betas=BETAS, eps=10e-9, lr=LR, params=model.parameters())


In [83]:
from timeit import default_timer as timer
NUM_EPOCHS = 1#0

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)#, train_dataloader)
    end_time = timer()
    val_loss = evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

907it [10:08,  1.49it/s]


Trained on 907 datapts
Epoch: 1, Train loss: 5.419, Val loss: 6.631, Epoch time = 608.327s


In [86]:
# Inference 
y_toks = [BOS_IDX]
a_sample = 80
src_toks = de_tokenizer.encode(de_data_test[a_sample]).ids
model.eval()
toks_generated=0
max_len=100
while y_toks[-1] != EOS_IDX and toks_generated < max_len:
    toks_generated+=1
    out_logits = model(torch.tensor(src_toks).unsqueeze(0), torch.tensor(y_toks).unsqueeze(0), is_causal_tgt=True)
    idx_predicted = torch.argmax(out_logits[:,-1])
    y_toks.append(idx_predicted)

In [26]:
print("predicted tokens: ", y_toks)
print("text: ", en_tokenizer.decode(y_toks))
print("ground truth: ", en_data_test[a_sample])

predicted tokens:  [2, tensor(218), tensor(384), tensor(559), tensor(201), tensor(287), tensor(170), tensor(97), tensor(487), tensor(287), tensor(170), tensor(97), tensor(487), tensor(149), tensor(3)]
text:  Two large dogs are playing on a beach playing on a beach .
ground truth:  Two large tan dogs play along a sandy beach.
