In [105]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

import torch
import torch.nn as nn
import math
import einops
from tokenizers import CharBPETokenizer

multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'



In [106]:
# Import huggingface char-bpe tokenizer
en_tokenizer, de_tokenizer = CharBPETokenizer(), CharBPETokenizer()

special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
en_tokenizer.add_special_tokens(special_symbols)
de_tokenizer.add_special_tokens(special_symbols)

# Train tokenizers
train_iter, test_iter, valid_iter = Multi30k(language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

de_data, en_data = (list(zip(*train_iter)))
en_tokenizer.train_from_iterator(iterator=en_data)
de_tokenizer.train_from_iterator(iterator=de_data)

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = tuple(en_tokenizer.encode(x).ids[0] for x in special_symbols)










In [107]:
en_tokenizer.encode(en_data[0])#.tokens

Encoding(num_tokens=11, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [108]:
text_transform_en = lambda x : torch.tensor([BOS_IDX] + en_tokenizer.encode(x.rstrip("\n")).ids + [EOS_IDX])
text_transform_de = lambda x : torch.tensor([BOS_IDX] + de_tokenizer.encode(x.rstrip("\n")).ids + [EOS_IDX])



In [109]:
# torch.tensor(list(train_iter), dtype=torch.StringType)
# type(train_iter)
# (list(zip(*train_iter)))[0][9]
# en_data, de_data = (list(zip(*train_iter)))
# de_data_tokenized, en_data_tokenized = 
# map(lambda x : (x[0], x[1]), train_iter)#list(map(text_transform_de, de_data)), list(map(text_transform_de, de_data))
# # de_data_tokenized
# tokenized_multik3 = list(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))

In [110]:
# len(tokenized_multik3), len(list(train_iter)), type(train_iter)
# torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(tokenized_multik3)

In [111]:
# list(torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter)))
tokenized_train_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))
tokenized_test_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), test_iter))
tokenized_valid_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), valid_iter))

In [112]:
# len(list(tokenized_train_iter))

In [113]:
"""Position encoding"""
class PositionEncoding2(nn.Module):
    def __init__(self, d_embed):
        super(PositionEncoding2, self).__init__()
        self.d_embed = d_embed
        self.dropout = nn.Dropout(p=0.1)

    def get_position_encoding(self, seq_len):
        encoding = torch.zeros((seq_len, self.d_embed))
        dimensions = torch.arange(0, self.d_embed//2)
        timesteps = torch.arange(0, seq_len)

        encoding[:, 0::2] = torch.sin(torch.einsum('i,j -> ji', 1/(10000**(2*dimensions/self.d_embed)), timesteps))
        encoding[:, 1::2] = torch.cos(torch.einsum('i,j -> ji', 1/(10000**(2*dimensions/self.d_embed)), timesteps))
        
        return encoding
    
    def forward(self, inp):
        pos_encoding = nn.Parameter(self.get_position_encoding(seq_len=inp.shape[-2]), requires_grad=False)
        return self.dropout(inp + pos_encoding) # + dropout ?


In [114]:
"""Multihead attention"""
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_hidden, n_heads):
        super().__init__()#MultiheadAttention)
        self.d_hidden=d_hidden
        self.n_heads = n_heads
        self.W_q = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )

        self.W_o = nn.Linear(in_features=d_hidden, out_features=d_model, bias=False)

    def forward(self, q, k, v, get_attn_scores=False, mask=None, is_causal=False):
        mask_shape = (q.shape[-2], k.shape[-2]) 
        # print("mask shape: ", mask_shape)
        mask = mask if mask else (self.causal_mask(self.n_heads, mask_shape) if is_causal else self.empty_mask(self.n_heads, mask_shape))

        # Compute Q, K, V (matrix multiplication, batched along heads)
        q = einops.einsum(self.W_q, q, 'h d_model d_h, b T d_model -> b h  T d_h') # T is timesteps (sequence length)
        k = einops.einsum(self.W_k, k, 'h d_model d_h, b T d_model -> b h  T d_h')
        v = einops.einsum(self.W_v, v, 'h d_model d_h, b T d_model -> b h  T d_h')

        # print("q_shape: ", q.shape)
        # print("k_shape: ", k.shape)

        attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(self.d_hidden//self.n_heads) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

        attn_head_outs = einops.einsum(attn_scores, v, '... T_out T_in, ... T_in d_h -> ... T_out d_h')
        # print(attn_head_outs.shape)
        # print(attn_scores, "mask: ", mask)
        # print("mask: ", mask)
        concatted_outs = einops.rearrange(attn_head_outs, 'b h T_out d_h -> b T_out (h d_h)')
        concatted_outs = self.W_o(concatted_outs)

        return (concatted_outs, attn_scores) if get_attn_scores else concatted_outs
    
    @classmethod
    def causal_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.triu(
                torch.fill(torch.zeros(mask_shape),  -torch.inf),
                diagonal=1
                )
            , pattern='... -> k ...', k=n_heads
            )
    
    @classmethod
    def empty_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.zeros(mask_shape) , pattern='... -> k ...', k=n_heads
            )

In [115]:
"""Multihead attention"""
class MultiHeadAttention2(nn.Module):
    def __init__(self, d_model, d_hidden, n_heads):
        super().__init__()#MultiheadAttention)
        self.d_hidden=d_hidden
        self.n_heads = n_heads
        # self.W_q = nn.Parameter(
        #     nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden)))
        # )
        # self.W_k = nn.Parameter(
        #     nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden)))
        # )
        # self.W_v = nn.Parameter(
        #     nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden)))
        # )

        self.W_q = nn.Linear(in_features=d_model, out_features=d_hidden)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_hidden)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_hidden)

        self.W_o = nn.Linear(in_features=d_hidden, out_features=d_model, bias=False)

    def forward(self, q, k, v, get_attn_scores=False, mask=None, is_causal=False):
        mask_shape = (q.shape[-2], k.shape[-2]) 
        # print("mask shape: ", mask_shape)
        mask = mask if mask else (self.causal_mask(self.n_heads, mask_shape) if is_causal else self.empty_mask(self.n_heads, mask_shape))
        # print("in q", k.shape)
        # Compute Q, K, V (matrix multiplication, batched along heads)
        q = self.W_q(q)#einops.einsum(self.W_q, q, 'h d_model d_h, b T d_model -> b h T d_h') # T is timesteps (sequence length)
        k = self.W_k(k)#einops.einsum(self.W_k, k, 'h d_model d_h, b T d_model -> b h T d_h')
        v = self.W_v(v)#einops.einsum(self.W_v, v, 'h d_model d_h, b T d_model -> b h T d_h')
        # print("out1 q", k.shape)
        # q = einops.rearrange(q, 'b T d_h -> b T h k', h=self.n_heads)
        q = einops.rearrange(q, 'b T (h w) -> b h T w', h=self.n_heads)
        k = einops.rearrange(k, 'b T (h w) -> b h T w', h=self.n_heads)
        v = einops.rearrange(v, 'b T (h w) -> b h T w', h=self.n_heads)
        # print("q_shape: ", q.shape)
        # print("k_shape: ", k.shape)
        # print("out2 q", k.shape)

        attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(self.d_hidden//self.n_heads) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

        attn_head_outs = einops.einsum(attn_scores, v, '... T_out T_in, ... T_in d_h -> ... T_out d_h')
        # print(attn_head_outs.shape)
        # print(attn_scores, "mask: ", mask)
        # print("mask: ", mask)
        concatted_outs = einops.rearrange(attn_head_outs, 'b h T_out d_h -> b T_out (h d_h)')
        concatted_outs = self.W_o(concatted_outs)

        return (concatted_outs, attn_scores) if get_attn_scores else concatted_outs
    
    @classmethod
    def causal_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.triu(
                torch.fill(torch.zeros(mask_shape),  -torch.inf),
                diagonal=1
                )
            , pattern='... -> k ...', k=n_heads
            )
    
    @classmethod
    def empty_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.zeros(mask_shape) , pattern='... -> k ...', k=n_heads
            )

In [116]:
b, h, T_out, d_h = 2, 3, 4, 6
q = torch.ones((b, h, T_out, d_h)) # target seq
k = torch.ones((b, h, T_out+2, d_h)) # inp seq

mask = MultiHeadAttention2.empty_mask(n_heads=h, mask_shape=(q.shape[-2], k.shape[-2]))#(k.shape[-2], q.shape[-2]))
# mask = 
# print(mask)
attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(2) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )
attn_scores#.shape

tensor([[[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

         [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

         [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],


        [[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

         [[0.1667, 0.1667, 0.1

In [117]:
q.shape
q = torch.ones((b, T_out, d_h))
s = MultiHeadAttention2(d_model=d_h, d_hidden=d_h, n_heads=h)
ti = PositionEncoding2(d_embed=d_h)
q = ti(q)
q = s.W_q(q)
q.shape
print(q)
q = einops.rearrange(q, 'b T (h w) -> b h T w', h=h)#, w=d_h//h)
q.shape, q
print(q)

tensor([[[-0.0892,  0.3523,  0.3153, -1.1927,  0.8217, -0.6601],
         [-0.5424,  0.5256, -0.4009, -0.9617,  0.7779, -0.9696],
         [-0.1952,  1.3644, -0.2244, -1.2910,  0.5016, -0.7282],
         [-0.2089,  0.8486,  0.0332, -1.5708,  0.7290, -1.0882]],

        [[ 0.2547,  0.8348, -0.1010, -0.2742, -0.4685,  0.2282],
         [-0.1984,  1.0087, -0.8172, -0.0431, -0.5128, -0.0809],
         [-0.1952,  1.3644, -0.2244, -1.2910,  0.5016, -0.7282],
         [-0.1173,  1.1746, -0.0354, -1.3611,  0.2980, -0.7629]]],
       grad_fn=<ViewBackward0>)
tensor([[[[-0.0892,  0.3523],
          [-0.5424,  0.5256],
          [-0.1952,  1.3644],
          [-0.2089,  0.8486]],

         [[ 0.3153, -1.1927],
          [-0.4009, -0.9617],
          [-0.2244, -1.2910],
          [ 0.0332, -1.5708]],

         [[ 0.8217, -0.6601],
          [ 0.7779, -0.9696],
          [ 0.5016, -0.7282],
          [ 0.7290, -1.0882]]],


        [[[ 0.2547,  0.8348],
          [-0.1984,  1.0087],
          [-0.19

In [118]:
class PosWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super(PosWiseFeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # FFN(x) = max(0, xW1 + b1)W2 + b2 ; f(x) = max(0, x) is relu
        return self.fc2(self.relu(self.fc1(x)))

In [119]:
# Encoder and Decoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff, dropout=0.1, N_layers=6):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model=d_model, d_hidden=d_hidden, n_heads=num_heads)
        self.feed_forward = PosWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None, is_causal=False):
        #print("Q, K, V: ", x.size())
        attn_output = self.self_attn(x, x, x, mask=mask, is_causal=is_causal) # self attention
        x = self.norm1(x + self.dropout(attn_output)) # layer-norm + dropout + skip connection
        ff_output = self.feed_forward(x) # feed-forward
        x = self.norm2(x + self.dropout(ff_output)) # layer-norm + dropout + skip connection
        return x
    

    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff, dropout=0.1, N_layers=6):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, d_hidden, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, d_hidden, num_heads)
        self.feed_forward = PosWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask, is_causal=False):
        attn_output = self.self_attn(x, x, x, mask=tgt_mask, is_causal=is_causal) # self attention
        x = self.norm1(x + self.dropout(attn_output)) # layer-norm + dropout + skip connection
        attn_output = self.cross_attn(q=x, k=enc_output, v=enc_output, mask=src_mask)
        x = self.norm2(x + self.dropout(attn_output)) # layer-norm + droput + skip connection
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output)) # layer-norm + dropout + skip connection
        return x

In [120]:
class Transformer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff,src_vocab_size, tgt_vocab_size, dropout=0, n_enc_layers=6, n_dec_layers=6):
        super(Transformer, self).__init__()

        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model=d_model,
                         d_hidden=d_hidden,
                         num_heads=num_heads,
                         d_ff=d_ff) for _ in range(n_enc_layers)]
        )
        self.dec_layers = nn.ModuleList(
            [DecoderLayer(d_model=d_model,
                         d_hidden=d_hidden,
                         d_ff=d_ff,
                         num_heads=num_heads) for _ in range(n_dec_layers)]
        )

        self.pos_encoding_layer = PositionEncoding2(d_embed=d_model)

        self.src_embedding = nn.Embedding(embedding_dim=d_model, num_embeddings=src_vocab_size, padding_idx=PAD_IDX)
        self.tgt_embedding = nn.Embedding(embedding_dim=d_model, num_embeddings=tgt_vocab_size, padding_idx=PAD_IDX)

        self.final = nn.Linear(in_features=d_hidden, out_features=tgt_vocab_size)

    def forward(self, inp_seq, tgt_seq, inp_mask=None, tgt_mask=None, is_causal_tgt=False):
        

        inp_seq = self.src_embedding(inp_seq)
        inp_seq = self.pos_encoding_layer(inp_seq)

        for enc in self.enc_layers:
            inp_seq = enc(inp_seq, mask=inp_mask, is_causal=False)

        # inp_seq = self.enc_layers(inp_seq, mask=inp_mask, is_causal=False)

        tgt_seq = self.src_embedding(tgt_seq)
        # tgt_seq = self.dec_layers(x=tgt_seq, enc_out=inp_seq, mask=tgt_mask, is_causal=is_causal_tgt)
        for dec in self.dec_layers:
            tgt_seq = dec(x=tgt_seq, enc_output=inp_seq, src_mask=inp_mask, tgt_mask=tgt_mask, is_causal=is_causal_tgt)

        out_logits = self.final(tgt_seq)
        return out_logits





In [121]:
t = Transformer(d_model=8, d_hidden=8, num_heads=4, d_ff=10, src_vocab_size=3, tgt_vocab_size=7)
src, tgt = torch.arange(0, 3).tile(2, 1), torch.arange(0, 3).tile(2, 1)

src.shape, t(src, tgt)#.shape

(torch.Size([2, 3]),
 tensor([[[ 0.4632, -0.7032, -0.8027,  0.1558, -0.2824, -0.8684, -0.4869],
          [ 1.1471, -0.0685, -0.1344,  0.0605, -0.5780, -0.6808, -0.3462],
          [ 1.3599, -0.2954, -0.0264,  0.1407, -0.1240, -0.5058, -0.4269]],
 
         [[ 0.2772, -0.9365, -0.7626, -0.4455, -0.2436, -0.6227, -0.7306],
          [ 1.0347, -0.5363, -0.3671, -0.3951, -0.3378, -0.9007, -0.5157],
          [ 1.0146, -0.5092, -0.3398,  0.2083, -0.4329, -0.6571, -0.3672]]],
        grad_fn=<ViewBackward0>))

In [122]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # print(batch)
    src_batch, tgt_batch = list(zip(*batch))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX).transpose(-1, -2)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX).transpose(-1, -2)

    return src_batch, tgt_batch 

p = list(zip(*tokenized_test_iter))
# list(tokenized_test_iter)

In [123]:
# de_tokenizer.decode((pad_sequence(p[0], padding_value=PAD_IDX).transpose(-1, -2))[1013].tolist())

In [124]:
# samp_batch = list(test_iter)#train_iter#list(train_iter)
# collated_samp = collate_fn(tokenized_test_iter)
# uncolled = list(tokenized_test_iter)


In [125]:
# collated_samp[0][len(collated_samp[0])-3]
# uncolled = list(tokenized_test_iter)
# l_iter = list(test_iter)

In [126]:
# len(uncolled), len(collated_samp[0]) #(1015, 42)

In [127]:
# uncolled[0], collated_samp[0][0], torch.max(collated_samp[0][len(collated_samp[0])-3])

In [128]:
# Simple task to overfit the model on; helps to catch dumb bugs in the architecture and training
text_transform_en('one two three four'), text_transform_en('five six seven eight')

(tensor([   2,  313,  400,  826, 1276,    3]),
 tensor([   2, 1859, 2613, 5024, 3934,    3]))

In [129]:
from torch.utils.data import DataLoader


D_MODEL = 512
D_FF = 2048
N_HEADS = 8

BATCH_SIZE = 1#64
DEVICE = 'cpu'
LR = 0.00001#math.sqrt(D_MODEL/4000) # make this a schedule
BETAS = (0.9, 0.98)

# train_dataloader = DataLoader(tokenized_train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# test_dataloader = DataLoader(tokenized_test_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# valid_dataloader = DataLoader(tokenized_valid_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)


def train_epoch(model, optimizer, data_loader=None):
    model.train()
    losses = 0

    tokenized_train_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))
    train_dataloader = DataLoader(tokenized_train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    data_count = 0
    import tqdm
    for src, tgt in tqdm.tqdm(train_dataloader):#data_loader):#train_dataloader):
        data_count+=1

        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        out_logits = model(src, tgt[:, :-1], is_causal_tgt=True)

        # print("inputs: ", de_tokenizer.decode(list(src.squeeze()), skip_special_tokens=False), en_tokenizer.decode(list(tgt[:, :-1].squeeze()), skip_special_tokens=False))


        # print("out_logits shape: ", out_logits.shape)
        # print("tgt shape: ", tgt.shape) b seq vcb

        optimizer.zero_grad()

        # print("out_logits: ", out_logits.transpose(-1, -2), out_logits.transpose(-1, -2).shape)
        # print("tgt: ", tgt )
        

        loss = loss_fn(out_logits.transpose(-1, -2), tgt[:, 1:])

        loss.backward()
        optimizer.step()
        losses += loss.item()

    print("Trained on %s datapts"%data_count)

    return losses / len(list(train_dataloader))


def evaluate(model):
    model.eval()
    losses = 0
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    data_count=0
    for src, tgt in val_dataloader:
        data_count+=1
        if data_count==50:
            break
        src = src.to(DEVICE).T
        tgt = tgt.to(DEVICE).T

        tgt_input = tgt[:-1, :]

        print("src", src.T.shape)
        print("tgt", tgt.T.shape)

        # src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        # logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        output = model(src, tgt[:,:-1])

        

        loss = loss_fn(output.transpose(-1, -2), tgt)
        # loss = criterion(output.contiguous().view(-1, TGT_VOCAB_SIZE), tgt[:, 1:].contiguous().view(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [130]:

model = Transformer(d_model=D_MODEL, 
                    d_hidden=D_MODEL, 
                    num_heads=N_HEADS, 
                    d_ff=D_FF, 
                    src_vocab_size=de_tokenizer.get_vocab_size(), 
                    tgt_vocab_size=en_tokenizer.get_vocab_size()
                    )

optimizer = torch.optim.Adam(betas=BETAS, eps=10e-9, lr=LR, params=model.parameters())


In [131]:
# de_tokenizer.get_vocab_size()
from timeit import default_timer as timer
NUM_EPOCHS = 1

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)#, train_dataloader)
    end_time = timer()
    val_loss = 0#evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))

29001it [1:26:40,  5.58it/s]


Trained on 29001 datapts


ZeroDivisionError: float division by zero

In [220]:
# torch.save(model, 'oct2+1')
# model = torch.load('oct2+1')

In [221]:
q, k, mask = torch.ones((32, 8, 23, 64)), torch.ones((32, 8, 24, 64)), torch.ones((23, 24))
attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(2) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

In [222]:
(einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(2)).shape# + mask

torch.Size([32, 8, 23, 24])

In [138]:
y_toks = [BOS_IDX]
random_sample = 80
src_toks = de_tokenizer.encode(de_data[random_sample]).ids
model.eval()
toks_generated=0
max_len=100
while y_toks[-1] != EOS_IDX and toks_generated < max_len:
    toks_generated+=1
    out_logits = model(torch.tensor(src_toks).unsqueeze(0), torch.tensor(y_toks).unsqueeze(0), is_causal_tgt=True)
    # print(out_logits)
    # print(out_logits[:,-1].shape, "sfdsfhadsfhsdkjhfSDFfdf")
    idx_predicted = torch.argmax(out_logits[:,-1])
    y_toks.append(idx_predicted)

In [139]:
(y_toks,
en_tokenizer.decode(y_toks),#[2456]),
en_data[random_sample])

([2,
  tensor(218),
  tensor(384),
  tensor(559),
  tensor(201),
  tensor(287),
  tensor(170),
  tensor(97),
  tensor(487),
  tensor(287),
  tensor(170),
  tensor(97),
  tensor(487),
  tensor(149),
  tensor(3)],
 'Two large dogs are playing on a beach playing on a beach .',
 'Two large tan dogs play along a sandy beach.')

In [135]:
torch.save(model, 'epoch_1_multiheadattention1_translator_de_en')

In [226]:
# torch.zeros((2, 2, 3, 3)) + MultiHeadAttention.causal_mask(n_heads=2, mask_shape=(3,3))