In [2]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

import torch
import torch.nn as nn
import math
import einops
from tokenizers import CharBPETokenizer




# We need to modify the URLs for the dataset since the links to the original dataset are broken
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'



In [3]:
# Import huggingface char-bpe tokenizer
en_tokenizer, de_tokenizer = CharBPETokenizer(), CharBPETokenizer()

special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
en_tokenizer.add_special_tokens(special_symbols)
de_tokenizer.add_special_tokens(special_symbols)

# Train tokenizers
train_iter, test_iter, valid_iter = Multi30k(language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

de_data, en_data = (list(zip(*train_iter)))
en_tokenizer.train_from_iterator(iterator=en_data)
de_tokenizer.train_from_iterator(iterator=de_data)

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = tuple(en_tokenizer.encode(x).ids[0] for x in special_symbols)


################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################











In [4]:
en_tokenizer.encode(en_data[0])#.tokens

Encoding(num_tokens=11, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [5]:
text_transform_en = lambda x : torch.tensor([BOS_IDX] + en_tokenizer.encode(x.rstrip("\n")).ids + [EOS_IDX])
text_transform_de = lambda x : torch.tensor([BOS_IDX] + de_tokenizer.encode(x.rstrip("\n")).ids + [EOS_IDX])



In [6]:
# torch.tensor(list(train_iter), dtype=torch.StringType)
# type(train_iter)
# (list(zip(*train_iter)))[0][9]
# en_data, de_data = (list(zip(*train_iter)))
# de_data_tokenized, en_data_tokenized = 
# map(lambda x : (x[0], x[1]), train_iter)#list(map(text_transform_de, de_data)), list(map(text_transform_de, de_data))
# # de_data_tokenized
# tokenized_multik3 = list(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))

In [7]:
# len(tokenized_multik3), len(list(train_iter)), type(train_iter)
# torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(tokenized_multik3)

In [8]:
# list(torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter)))
tokenized_train_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))
tokenized_test_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), test_iter))
tokenized_valid_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), valid_iter))

In [9]:
# len(list(tokenized_train_iter))

In [10]:
"""Position encoding"""
class PositionEncoding2(nn.Module):
    def __init__(self, d_embed):
        super(PositionEncoding2, self).__init__()
        self.d_embed = d_embed
        self.dropout = nn.Dropout(p=0.1)

    def get_position_encoding(self, seq_len):
        encoding = torch.zeros((seq_len, self.d_embed))
        dimensions = torch.arange(0, self.d_embed//2)
        timesteps = torch.arange(0, seq_len)

        encoding[:, 0::2] = torch.sin(torch.einsum('i,j -> ji', 1/(10000**(2*dimensions/self.d_embed)), timesteps))
        encoding[:, 1::2] = torch.cos(torch.einsum('i,j -> ji', 1/(10000**(2*dimensions/self.d_embed)), timesteps))
        
        return encoding
    
    def forward(self, inp):
        pos_encoding = nn.Parameter(self.get_position_encoding(seq_len=inp.shape[-2]), requires_grad=False)
        return self.dropout(inp + pos_encoding) # + dropout ?


In [11]:
"""Multihead attention"""
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, d_hidden, n_heads):
        super().__init__()#MultiheadAttention)
        self.d_hidden=d_hidden
        self.n_heads = n_heads
        self.W_q = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden//n_heads)))
        )

        self.W_o = nn.Linear(in_features=d_hidden, out_features=d_model, bias=False)

    def forward(self, q, k, v, get_attn_scores=False, mask=None, is_causal=False):
        mask_shape = (q.shape[-2], k.shape[-2]) 
        # print("mask shape: ", mask_shape)
        mask = mask if mask else (self.causal_mask(self.n_heads, mask_shape) if is_causal else self.empty_mask(self.n_heads, mask_shape))

        # Compute Q, K, V (matrix multiplication, batched along heads)
        q = einops.einsum(self.W_q, q, 'h d_model d_h, b T d_model -> b h  T d_h') # T is timesteps (sequence length)
        k = einops.einsum(self.W_k, k, 'h d_model d_h, b T d_model -> b h  T d_h')
        v = einops.einsum(self.W_v, v, 'h d_model d_h, b T d_model -> b h  T d_h')

        # print("q_shape: ", q.shape)
        # print("k_shape: ", k.shape)

        attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(self.d_hidden//self.n_heads) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

        attn_head_outs = einops.einsum(attn_scores, v, '... T_out T_in, ... T_in d_h -> ... T_out d_h')
        # print(attn_head_outs.shape)
        # print(attn_scores, "mask: ", mask)
        # print("mask: ", mask)
        concatted_outs = einops.rearrange(attn_head_outs, 'b h T_out d_h -> b T_out (h d_h)')
        concatted_outs = self.W_o(concatted_outs)

        return (concatted_outs, attn_scores) if get_attn_scores else concatted_outs
    
    @classmethod
    def causal_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.triu(
                torch.fill(torch.zeros(mask_shape),  -torch.inf),
                diagonal=1
                )
            , pattern='... -> k ...', k=n_heads
            )
    
    @classmethod
    def empty_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.zeros(mask_shape) , pattern='... -> k ...', k=n_heads
            )

In [46]:
"""Multihead attention"""
class MultiHeadAttention2(nn.Module):
    def __init__(self, d_model, d_hidden, n_heads):
        super().__init__()#MultiheadAttention)
        self.d_hidden=d_hidden
        self.n_heads = n_heads
        # self.W_q = nn.Parameter(
        #     nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden)))
        # )
        # self.W_k = nn.Parameter(
        #     nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden)))
        # )
        # self.W_v = nn.Parameter(
        #     nn.init.xavier_normal_(torch.zeros((n_heads, d_model, d_hidden)))
        # )

        self.W_q = nn.Linear(in_features=d_model, out_features=d_hidden)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_hidden)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_hidden)

        self.W_o = nn.Linear(in_features=d_hidden, out_features=d_model, bias=False)

    def forward(self, q, k, v, get_attn_scores=False, mask=None, is_causal=False):
        mask_shape = (q.shape[-2], k.shape[-2]) 
        # print("mask shape: ", mask_shape)
        mask = mask if mask else (self.causal_mask(self.n_heads, mask_shape) if is_causal else self.empty_mask(self.n_heads, mask_shape))
        # print("in q", k.shape)
        # Compute Q, K, V (matrix multiplication, batched along heads)
        q = self.W_q(q)#einops.einsum(self.W_q, q, 'h d_model d_h, b T d_model -> b h T d_h') # T is timesteps (sequence length)
        k = self.W_k(k)#einops.einsum(self.W_k, k, 'h d_model d_h, b T d_model -> b h T d_h')
        v = self.W_v(v)#einops.einsum(self.W_v, v, 'h d_model d_h, b T d_model -> b h T d_h')
        # print("out1 q", k.shape)
        # q = einops.rearrange(q, 'b T d_h -> b T h k', h=self.n_heads)
        q = einops.rearrange(q, 'b T (h w) -> b h T w', h=self.n_heads)
        k = einops.rearrange(k, 'b T (h w) -> b h T w', h=self.n_heads)
        v = einops.rearrange(v, 'b T (h w) -> b h T w', h=self.n_heads)
        # print("q_shape: ", q.shape)
        # print("k_shape: ", k.shape)
        # print("out2 q", k.shape)

        attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(self.d_hidden//self.n_heads) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

        attn_head_outs = einops.einsum(attn_scores, v, '... T_out T_in, ... T_in d_h -> ... T_out d_h')
        # print(attn_head_outs.shape)
        # print(attn_scores, "mask: ", mask)
        # print("mask: ", mask)
        concatted_outs = einops.rearrange(attn_head_outs, 'b h T_out d_h -> b T_out (h d_h)')
        concatted_outs = self.W_o(concatted_outs)

        return (concatted_outs, attn_scores) if get_attn_scores else concatted_outs
    
    @classmethod
    def causal_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.triu(
                torch.fill(torch.zeros(mask_shape),  -torch.inf),
                diagonal=1
                )
            , pattern='... -> k ...', k=n_heads
            )
    
    @classmethod
    def empty_mask(cls, n_heads, mask_shape):
        return einops.repeat(
            torch.zeros(mask_shape) , pattern='... -> k ...', k=n_heads
            )

In [47]:
b, h, T_out, d_h = 2, 3, 4, 6
q = torch.ones((b, h, T_out, d_h)) # target seq
k = torch.ones((b, h, T_out+2, d_h)) # inp seq

mask = MultiHeadAttention2.empty_mask(n_heads=h, mask_shape=(q.shape[-2], k.shape[-2]))#(k.shape[-2], q.shape[-2]))
# mask = 
# print(mask)
attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(2) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )
attn_scores#.shape

tensor([[[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

         [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

         [[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],


        [[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
          [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],

         [[0.1667, 0.1667, 0.1

In [48]:
q.shape
q = torch.ones((b, T_out, d_h))
s = MultiHeadAttention2(d_model=d_h, d_hidden=d_h, n_heads=h)
ti = PositionEncoding2(d_embed=d_h)
q = ti(q)
q = s.W_q(q)
q.shape
print(q)
q = einops.rearrange(q, 'b T (h w) -> b h T w', h=h)#, w=d_h//h)
q.shape, q
print(q)

tensor([[[-1.0915, -0.5712,  1.2721, -0.3021,  0.7753,  0.3741],
         [-0.1946, -0.4819,  0.8439,  0.3889, -0.3510, -0.6231],
         [-0.8813, -0.0814,  1.3627,  0.1860,  0.1811, -0.1485],
         [-0.7180, -0.7588,  1.0270, -0.4516,  0.7069,  0.0185]],

        [[-0.6276, -0.3813,  0.9738, -0.0662,  0.3060,  0.1939],
         [-0.8500, -0.8558,  0.8711, -0.5391,  0.8880,  0.7814],
         [-1.5107,  0.5947,  1.3644,  0.5068, -0.0897, -0.0935],
         [-0.2300,  0.1176,  0.9877,  0.3936,  0.5128,  0.4988]]],
       grad_fn=<ViewBackward0>)
tensor([[[[-1.0915, -0.5712],
          [-0.1946, -0.4819],
          [-0.8813, -0.0814],
          [-0.7180, -0.7588]],

         [[ 1.2721, -0.3021],
          [ 0.8439,  0.3889],
          [ 1.3627,  0.1860],
          [ 1.0270, -0.4516]],

         [[ 0.7753,  0.3741],
          [-0.3510, -0.6231],
          [ 0.1811, -0.1485],
          [ 0.7069,  0.0185]]],


        [[[-0.6276, -0.3813],
          [-0.8500, -0.8558],
          [-1.51

In [49]:
class PosWiseFeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super(PosWiseFeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # FFN(x) = max(0, xW1 + b1)W2 + b2 ; f(x) = max(0, x) is relu
        return self.fc2(self.relu(self.fc1(x)))

In [207]:
# Encoder and Decoder
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff, dropout=0.1, N_layers=6):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model=d_model, d_hidden=d_hidden, n_heads=num_heads)
        self.feed_forward = PosWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None, is_causal=False):
        #print("Q, K, V: ", x.size())
        attn_output = self.self_attn(x, x, x, mask=mask, is_causal=is_causal) # self attention
        x = self.norm1(x + self.dropout(attn_output)) # layer-norm + dropout + skip connection
        ff_output = self.feed_forward(x) # feed-forward
        x = self.norm2(x + self.dropout(ff_output)) # layer-norm + dropout + skip connection
        return x
    

    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff, dropout=0.1, N_layers=6):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, d_hidden, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, d_hidden, num_heads)
        self.feed_forward = PosWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask, is_causal=False):
        attn_output = self.self_attn(x, x, x, mask=tgt_mask, is_causal=is_causal) # self attention
        x = self.norm1(x + self.dropout(attn_output)) # layer-norm + dropout + skip connection
        attn_output = self.cross_attn(q=x, k=enc_output, v=enc_output, mask=src_mask)
        x = self.norm2(x + self.dropout(attn_output)) # layer-norm + droput + skip connection
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output)) # layer-norm + dropout + skip connection
        return x

In [208]:
class Transformer(nn.Module):
    def __init__(self, d_model, d_hidden, num_heads, d_ff,src_vocab_size, tgt_vocab_size, dropout=0, n_enc_layers=6, n_dec_layers=6):
        super(Transformer, self).__init__()
        # self.enc_layers = nn.Sequential(
        #     *(EncoderLayer(d_model=d_model,
        #                  d_hidden=d_hidden,
        #                  num_heads=num_heads,
        #                  d_ff=d_ff) for _ in range(n_enc_layers))
        # )
        # self.dec_layers = nn.Sequential(
        #     *(DecoderLayer(d_model=d_model,
        #                  d_hidden=d_hidden,
        #                  d_ff=d_ff,
        #                  num_heads=num_heads) for _ in range(n_enc_layers))
        # )

        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model=d_model,
                         d_hidden=d_hidden,
                         num_heads=num_heads,
                         d_ff=d_ff) for _ in range(n_enc_layers)]
        )
        self.dec_layers = nn.ModuleList(
            [DecoderLayer(d_model=d_model,
                         d_hidden=d_hidden,
                         d_ff=d_ff,
                         num_heads=num_heads) for _ in range(n_dec_layers)]
        )

        self.pos_encoding_layer = PositionEncoding2(d_embed=d_model)

        self.src_embedding = nn.Embedding(embedding_dim=d_model, num_embeddings=src_vocab_size, padding_idx=PAD_IDX)
        self.tgt_embedding = nn.Embedding(embedding_dim=d_model, num_embeddings=tgt_vocab_size, padding_idx=PAD_IDX)

        self.final = nn.Linear(in_features=d_hidden, out_features=tgt_vocab_size)

    def forward(self, inp_seq, tgt_seq, inp_mask=None, tgt_mask=None, is_causal_tgt=False):
        

        inp_seq = self.src_embedding(inp_seq)
        inp_seq = self.pos_encoding_layer(inp_seq)

        for enc in self.enc_layers:
            inp_seq = enc(inp_seq, mask=inp_mask, is_causal=False)

        # inp_seq = self.enc_layers(inp_seq, mask=inp_mask, is_causal=False)

        tgt_seq = self.src_embedding(tgt_seq)
        # tgt_seq = self.dec_layers(x=tgt_seq, enc_out=inp_seq, mask=tgt_mask, is_causal=is_causal_tgt)
        for dec in self.dec_layers:
            tgt_seq = dec(x=tgt_seq, enc_output=inp_seq, src_mask=inp_mask, tgt_mask=tgt_mask, is_causal=is_causal_tgt)

        out_logits = self.final(tgt_seq)
        return out_logits





In [209]:
t = Transformer(d_model=8, d_hidden=8, num_heads=4, d_ff=10, src_vocab_size=3, tgt_vocab_size=7)
src, tgt = torch.arange(0, 3).tile(2, 1), torch.arange(0, 3).tile(2, 1)

src.shape, t(src, tgt)#.shape

(torch.Size([2, 3]),
 tensor([[[-0.5123,  0.6123,  0.4138, -0.6868, -0.0120, -0.6168,  0.0660],
          [-0.2988,  0.8658,  0.3645, -0.4842, -0.2050, -0.7479, -0.1059],
          [-0.4016,  0.6682,  0.4790, -0.5985,  0.2125, -0.4946, -0.1214]],
 
         [[-0.7274,  0.3957,  0.1740, -0.6914,  0.0929, -0.6376,  0.4161],
          [-0.7089,  0.5095,  0.2555, -0.6303,  0.3580, -0.3460,  0.2380],
          [-0.7822,  0.0296,  0.3872, -0.7767,  0.3205, -0.1270,  0.5384]]],
        grad_fn=<ViewBackward0>))

In [210]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    # print(batch)
    src_batch, tgt_batch = list(zip(*batch))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX).transpose(-1, -2)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX).transpose(-1, -2)

    return src_batch, tgt_batch 

p = list(zip(*tokenized_test_iter))
# list(tokenized_test_iter)

In [211]:
# de_tokenizer.decode((pad_sequence(p[0], padding_value=PAD_IDX).transpose(-1, -2))[1013].tolist())

In [212]:
# samp_batch = list(test_iter)#train_iter#list(train_iter)
# collated_samp = collate_fn(tokenized_test_iter)
# uncolled = list(tokenized_test_iter)


In [213]:
# collated_samp[0][len(collated_samp[0])-3]
# uncolled = list(tokenized_test_iter)
# l_iter = list(test_iter)

In [214]:
# len(uncolled), len(collated_samp[0]) #(1015, 42)

In [215]:
# uncolled[0], collated_samp[0][0], torch.max(collated_samp[0][len(collated_samp[0])-3])

In [216]:
# for t in train_dataloader:
#     print(t)
text_transform_en('one two three four'), text_transform_en('five six seven eight')

(tensor([   2,  313,  400,  826, 1276,    3]),
 tensor([   2, 1859, 2613, 5024, 3934,    3]))

In [217]:
from torch.utils.data import DataLoader


D_MODEL = 512
D_FF = 2048
N_HEADS = 8

BATCH_SIZE = 1#64
DEVICE = 'cpu'
LR = 0.00001#math.sqrt(D_MODEL/4000) # make this a schedule
BETAS = (0.9, 0.98)

# train_dataloader = DataLoader(tokenized_train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# test_dataloader = DataLoader(tokenized_test_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
# valid_dataloader = DataLoader(tokenized_valid_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)


def train_epoch(model, optimizer, data_loader=None):
    model.train()
    losses = 0
    # train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # tokenized_train_iter = ('one two three four', 'five six seven eight')#torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))
    # tokenized_train_iter = [(torch.Tensor([   2,  313,  400,  826, 1276,    3]).long(), torch.Tensor([   2, 1357, 1694,  210,  139, 1016, 5501, 8206, 1386,    3]).long())]
    # tokenized_train_iter = [(torch.Tensor([      2,  313,  400,  826, 1276,    3]).long(), torch.Tensor([   2, 1859, 2613, 5024, 3934,    3]).long())]
    tokenized_train_iter = torch.utils.data.datapipes.iter.sharding.ShardingFilterIterDataPipe(map(lambda x : (text_transform_de(x[0]), text_transform_en(x[1])), train_iter))

    train_dataloader = DataLoader(tokenized_train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)#[PAD_IDX, BOS_IDX])

    data_count = 0
    import tqdm
    for src, tgt in tqdm.tqdm(train_dataloader):#data_loader):#train_dataloader):
        data_count+=1
        if data_count==100:
            break
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        out_logits = model(src, tgt[:, :-1], is_causal_tgt=True)

        # print("inputs: ", de_tokenizer.decode(src.squeeze().tolist()), en_tokenizer.decode(tgt[:, :-1]).squeeze().tolist())
        print("inputs: ", de_tokenizer.decode(list(src.squeeze()), skip_special_tokens=False), en_tokenizer.decode(list(tgt[:, :-1].squeeze()), skip_special_tokens=False))


        # print("out_logits shape: ", out_logits.shape)
        # print("tgt shape: ", tgt.shape) b seq vcb

        optimizer.zero_grad()

        # print("out_logits: ", out_logits.transpose(-1, -2), out_logits.transpose(-1, -2).shape)
        # print("tgt: ", tgt )
        

        loss = loss_fn(out_logits.transpose(-1, -2), tgt[:, 1:])

        loss.backward()
        optimizer.step()
        losses += loss.item()

    print("Trained on %s datapts"%data_count)

    return losses / len(list(train_dataloader))


def evaluate(model):
    model.eval()
    losses = 0
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    data_count=0
    for src, tgt in val_dataloader:
        data_count+=1
        if data_count==50:
            break
        src = src.to(DEVICE).T
        tgt = tgt.to(DEVICE).T

        tgt_input = tgt[:-1, :]

        print("src", src.T.shape)
        print("tgt", tgt.T.shape)

        # src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        # logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        output = model(src, tgt[:,:-1])

        

        loss = loss_fn(output.transpose(-1, -2), tgt)
        # loss = criterion(output.contiguous().view(-1, TGT_VOCAB_SIZE), tgt[:, 1:].contiguous().view(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [218]:

model = Transformer(d_model=D_MODEL, 
                    d_hidden=D_MODEL, 
                    num_heads=N_HEADS, 
                    d_ff=D_FF, 
                    src_vocab_size=de_tokenizer.get_vocab_size(), 
                    tgt_vocab_size=en_tokenizer.get_vocab_size()
                    )

optimizer = torch.optim.Adam(betas=BETAS, eps=10e-9, lr=LR, params=model.parameters())


In [219]:
# de_tokenizer.get_vocab_size()
from timeit import default_timer as timer
NUM_EPOCHS = 1

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(model, optimizer)#, train_dataloader)
    end_time = timer()
    val_loss = 0#evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))



inputs:  <bos>Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche . <eos> <bos>Two young , White males are outside near many bushes .


1it [00:00,  1.25it/s]

inputs:  <bos>Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem . <eos> <bos>Several men in hard hats are operating a giant pulley system .


2it [00:01,  1.91it/s]

inputs:  <bos>Ein kleines Mädchen klettert in ein Spielhaus aus Holz . <eos> <bos>A little girl climbing into a wooden playhouse .


3it [00:01,  1.96it/s]

inputs:  <bos>Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster . <eos> <bos>A man in a blue shirt is standing on a ladder cleaning a window .


4it [00:01,  2.37it/s]

inputs:  <bos>Zwei Männer stehen am Herd und bereiten Essen zu . <eos> <bos>Two men are at the stove preparing food .


6it [00:02,  2.68it/s]

inputs:  <bos>Ein Mann in grün hält eine Gitarre , während der andere Mann sein Hemd ansieht . <eos> <bos>A man in green holds a guitar while the other man observes his shirt .


7it [00:02,  3.17it/s]

inputs:  <bos>Ein Mann lächelt einen ausgestopften Löwen an . <eos> <bos>A man is smiling at a stuffed lion
inputs:  <bos>Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt . <eos> <bos>A trendy girl talking on her cellphone while gliding slowly down the street .


8it [00:03,  3.35it/s]

inputs:  <bos>Eine Frau mit einer großen Geldbörse geht an einem Tor vorbei . <eos> <bos>A woman with a large purse is walking by a gate .


9it [00:03,  3.29it/s]

inputs:  <bos>Jungen tanzen mitten in der Nacht auf Pfosten . <eos> <bos>Boys dancing on poles in the middle of the night .


11it [00:03,  3.47it/s]

inputs:  <bos>Eine Ballettklasse mit fünf Mädchen , die nacheinander springen . <eos> <bos>A ballet class of five girls jumping in sequence .


12it [00:04,  3.76it/s]

inputs:  <bos>Vier Typen , von denen drei Hüte tragen und einer nicht , springen oben in einem Treppenhaus . <eos> <bos>Four guys three wearing hats one not are jumping at the top of a staircase .


13it [00:04,  3.63it/s]

inputs:  <bos>Ein schwarzer Hund und ein gefleckter Hund kämpfen . <eos> <bos>A black dog and a spotted dog are fighting
inputs:  <bos>Ein Mann in einer neongrünen und orangefarbenen Uniform fährt auf einem grünen Traktor . <eos> <bos>A man in a neon green and orange uniform is driving on a green tractor .


15it [00:04,  3.91it/s]

inputs:  <bos>Mehrere Frauen warten in einer Stadt im Freien . <eos> <bos>Several women wait outside in a city .


16it [00:05,  4.21it/s]

inputs:  <bos>Eine Frau mit schwarzem Oberteil und Brille streut Puderzucker auf einem Gugelhupf . <eos> <bos>A lady in a black top with glasses is sprinkling powdered sugar on a bundt cake .
inputs:  <bos>Ein kleines Mädchen sitzt vor einem großen gemalten Regenbogen . <eos> <bos>A little girl is sitting in front of a large painted rainbow .


18it [00:05,  4.96it/s]

inputs:  <bos>Ein Mann liegt auf der Bank , an die auch ein weißer Hund angebunden ist . <eos> <bos>A man lays on the bench to which a white dog is also tied .
inputs:  <bos>Fünf Personen sitzen mit Instrumenten im Kreis . <eos> <bos>Five people are sitting in a circle with instruments .


19it [00:05,  4.76it/s]

inputs:  <bos>Eine Gruppe älterer Frauen spielt zusammen Klarinette von Notenblättern . <eos> <bos>A bunch of elderly women play their clarinets together as they read off sheet music .


20it [00:06,  3.08it/s]

inputs:  <bos>Ein großes Bauwerk ist kaputt gegangen und liegt auf einer Fahrbahn . <eos> <bos>A large structure has broken and is laying in a roadway .


22it [00:06,  3.39it/s]

inputs:  <bos>Eine große Menschenmenge steht außen vor dem Eingang einer Metrostation . <eos> <bos>A large crowd of people stand outside in front of the entrance to a Metro station .


23it [00:07,  3.72it/s]

inputs:  <bos>Ein Mann , der ein Tattoo auf seinem Rücken erhält . <eos> <bos>A man getting a tattoo on his back .
inputs:  <bos>Zwei Kinder sitzen auf einer kleinen Wippe im Sand . <eos> <bos>Two children sit on a small seesaw in the sand .


25it [00:07,  4.13it/s]

inputs:  <bos>Ein Mann , der eine reflektierende Weste und einen Schutzhelm trägt , hält eine Flagge in die Straße . <eos> <bos>A man wearing a reflective vest and a hard hat holds a flag in the road
inputs:  <bos>Eine Person in einem blauen Mantel steht auf einem belebten Gehweg und betrachtet ein Gemälde einer Straßenszene . <eos> <bos>A person dressed in a blue coat is standing in on a busy sidewalk , studying painting of a street scene .


27it [00:07,  4.18it/s]

inputs:  <bos>Ein Mann in grünen Hosen läuft die Straße entlang . <eos> <bos>A man in green pants walking down the road .


28it [00:08,  4.55it/s]

inputs:  <bos>Das kleine Kind klettert an roten Seilen auf einem Spielplatz . <eos> <bos>The small child climbs on a red ropes on a playground .
inputs:  <bos>Du weißt , dass ich aussehe wie Justin Bieber . <eos> <bos>You know i am looking like Justin Bieber .


30it [00:08,  5.05it/s]

inputs:  <bos>Ein junger Mann in einer schwarz - gelben Jacke blickt etwas an und lächelt . <eos> <bos>A young man in a black and yellow jacket is gazing at something and smiling .
inputs:  <bos>Ein Mann , der mit einer Tasse Kaffee an einem Urinal steht . <eos> <bos>A man standing at a urinal with a coffee cup .


32it [00:08,  5.38it/s]

inputs:  <bos>Fünf gehende Personen mit einem mehrfarbigen Himmel im Hintergrund . <eos> <bos>Five people walking with a multicolored sky in the background .
inputs:  <bos>Ein alter Mann , der allein ein Bier trinkt . <eos> <bos>A old man having a beer alone .


34it [00:09,  5.03it/s]

inputs:  <bos>Ein geschulter Polizeihund sitzt neben dem Hundeführer vor dem Polizeitransporter . <eos> <bos>A trained police dog sits next to his handler in front of the police van .
inputs:  <bos>Eine Person fährt auf einer verschneiten Straße Fahrrad . <eos> <bos>A person riding a bike on a snowy road .


36it [00:09,  4.91it/s]

inputs:  <bos>Fünf Männer , die alle weiße Hemden , Krawatten und schwarze Freizeithosen tragen , unterhalten sich hinter einem Lieferwagen . <eos> <bos>Five men , uniformly dressed in white shirts , tie and black slacks converse at the back of an open van .


37it [00:09,  4.90it/s]

inputs:  <bos>Ein Mann mit einem nach hinten gerichteten Hut arbeitet an Maschinen . <eos> <bos>A man with a backwards hat works on machinery .


38it [00:10,  4.91it/s]

inputs:  <bos>Eine schwarze Frau und ein weißer Mann arbeiten in einer Fabrikumgebung und packen Gläser mit Kerzen in Kartons . <eos> <bos>A black woman and a white man working in a factory setting packing jars with candles into boxes .
inputs:  <bos>Ein asiatischer Mann kehrt den Gehweg . <eos> <bos>Asian man sweeping the walkway .


40it [00:10,  4.78it/s]

inputs:  <bos>Ein Mann lehnt sich in ein Auto , um mit dem Fahrer zu reden , während ein Mann auf einem Fahrrad zusieht . <eos> <bos>A man leans into a car to talk to the driver , as a man on a bicycle looks on .
inputs:  <bos>Zwei Kleinkinder im Freien auf dem Gras . <eos> <bos>Two young toddlers outside on the grass .


42it [00:10,  5.14it/s]

inputs:  <bos>Leute sehen einer Person in einem seltsamen Fahrzeug auf einem Platz zu . <eos> <bos>People are watching a person in a weird vehicle in a plaza .


43it [00:11,  4.69it/s]

inputs:  <bos>Ein Mann geht an einem silbernen Fahrzeug vorbei . <eos> <bos>A man walks by a silver vehicle .


44it [00:11,  4.90it/s]

inputs:  <bos>Eine schöne Braut geht auf einem Gehweg mit ihrem neuen Ehemann . <eos> <bos>A beautiful bride walking on a sidewalk with her new husband .
inputs:  <bos>Ein kleiner Junge spielt bei McDonald ' s GameCube . <eos> <bos>A little boy playing GameCube at a McDonald ' s .


46it [00:11,  4.96it/s]

inputs:  <bos>Ein weißer Hund schüttelt sich am Rande eines Strands mit einem orangefarbenen Ball . <eos> <bos>A white dog shakes on the edge of a beach with an orange ball .
inputs:  <bos>Eine Gruppe von Personen , die im Park grillen . <eos> <bos>A group of people having a barbecue at a park .


48it [00:12,  4.86it/s]

inputs:  <bos>Ein Mann mit Sonnenbrille legt seinen Arm um eine Frau in einer schwarz - weißen Bluse . <eos> <bos>A man in sunglasses puts his arm around a woman in a black and white blouse .


49it [00:12,  5.02it/s]

inputs:  <bos>Ein Mann mit einem Luftballonhut und Leute , die im Freien an Picknicktischen essen . <eos> <bos>A man with a balloon hat and people eating outdoors at picnic tables .
inputs:  <bos>Ein Junge , der während eines Taekwondo - Wettbewerbs einen Sprungtritt über drei Kinder macht und dabei auf Holz tritt . <eos> <bos>A boy jump kicking over three kids kicking wood during a tae kwon do competition .


51it [00:12,  5.35it/s]

inputs:  <bos>Ein Junge in einer roten Jacke , der Wasser auf einen Mann in einem weißen Hemd gießt . <eos> <bos>A boy in a red jacket pouring water on a man in a white shirt
inputs:  <bos>Ein Mann mit einer roten Jacke , der sich vor der Sonne schützt und versucht , ein Stück Papier zu lesen . <eos> <bos>A man with a red jacket is shielding himself from the sun trying to read a piece of paper .


53it [00:13,  4.81it/s]

inputs:  <bos>Männer , die eine Straße mit Kindern entlang laufen . <eos> <bos>Men walking down a street with children .
inputs:  <bos>Ein kleiner Junge , der auf der Straße steht , während ein Mann in einem Overall an einer Steinwand arbeitet . <eos> <bos>A little boy is standing on the street while a man in overalls is working on a stone wall .


55it [00:13,  4.80it/s]

inputs:  <bos>Ein schwarzer Hund springt über einen Baumstamm . <eos> <bos>A black dog leaps over a log .


56it [00:13,  4.94it/s]

inputs:  <bos>Ein Mann in einem Anzug rennt an zwei anderen Herren vorbei , die auch einen Anzug tragen . <eos> <bos>A man in a suit is running past two other gentleman , also dressed in a suit .
inputs:  <bos>Ein Mann in einem roten Hemd , der mit dem Fahrrad um Wasser herum fährt . <eos> <bos>Man in a red shirt riding his bicycle around water .


58it [00:14,  4.96it/s]

inputs:  <bos>Ein Mann , der barfuß ist , olivgrüne kurze Hosen trägt , auf einem kleinen Propangasgrill Hotdogs grillt und gleichzeitig eine blaue Kunststofftasse hält . <eos> <bos>A barefooted man wearing olive green shorts grilling hotdogs on a small propane grill while holding a blue plastic cup .
inputs:  <bos>Ein Hund rennt im Schnee . <eos> <bos>A dog is running in the snow


60it [00:14,  5.21it/s]

inputs:  <bos>Eine Menschenmenge steht und wartet , bis die Ampel grün wird . <eos> <bos>A crowd is standing and waiting for the green light .
inputs:  <bos>Mann auf Skiern , das zu verkaufende Kunstwerke im Schnee betrachtet . <eos> <bos>Man on skis looking at artwork for sale in the snow


62it [00:14,  5.46it/s]

inputs:  <bos>Sieben Kletterer klettern eine Felswand hoch , während ein anderer Mann dasteht und das Seil hält . <eos> <bos>Seven climbers are ascending a rock face whilst another man stands holding the rope .
inputs:  <bos>Der gelenkige Körper des jungen Turners schwebt über dem Schwebebalken . <eos> <bos>The young gymnast ' s supple body soars above the balance beam .


64it [00:15,  5.41it/s]

inputs:  <bos>Ein Junge schiebt ein Spielzeug - Geländefahrzeug um einen Gummi - Pool . <eos> <bos>A young boy is pushing a toy ATV around a rubber pool
inputs:  <bos>Eine Frau in einer roten Windjacke , die über ein auf einem Dach installiertes Fernrohr auf die darunterliegende Stadt blickt . <eos> <bos>Woman in red windbreaker looking though a rooftop binoculars at the city below .


66it [00:15,  4.87it/s]

inputs:  <bos>Ein Mann steht vor einem kleinen roten Objekt , das wie ein Flugzeug aussieht . <eos> <bos>A man is standing in front of a small red object that looks like a plane .
inputs:  <bos>Ein Hund spielt mit einem Schlauch . <eos> <bos>A dog is playing with a hose .


68it [00:16,  5.17it/s]

inputs:  <bos>Ein Mann und ein kleines Mädchen posieren glücklich vor ihrem Einkaufswagen im Supermarkt . <eos> <bos>A man and a little girl happily posing in front of their cart in a supermarket .
inputs:  <bos>Ein weißer Hund ist kurz davor , ein gelbes Hundespielzeug zu fangen . <eos> <bos>A white dog is about to catch a yellow dog toy .


70it [00:16,  5.43it/s]

inputs:  <bos>Ein Kerl in einem grünen Hemd , dessen Hand einen Teil seines Gesichts bedeckt , in einer Nische im Restaurant . <eos> <bos>Guy in green shirt with hand covering part of his face in restaurant booth .
inputs:  <bos>Ein schwarz - weißer Hund spring zu einem gelben Spielzeug hoch . <eos> <bos>A black and white dog jumps up towards a yellow toy .


72it [00:16,  5.51it/s]

inputs:  <bos>Zwei Wanderer machen bei einem Stückchen Schnee Pause . <eos> <bos>Two hikers resting by a patch of snow .
inputs:  <bos>Ein Mann führt seine neue hölzerne Kreation vor . <eos> <bos>A man showing off his new wooden creation .


74it [00:17,  5.33it/s]

inputs:  <bos>Ein älterer Vater und sein erwachsener Sohn bereiten sich auf einen Camping - Ausflug in der Wildnis vor . <eos> <bos>A elderly father and his grown son are preparing for a camping trip in the wild .
inputs:  <bos>Ein Reisender mit Bart in einem roten Hemd , der in einem Auto sitzt und eine Karte liest . <eos> <bos>A bearded traveler in a red shirt sitting in a car and reading a map .


76it [00:17,  5.05it/s]

inputs:  <bos>Ein Junge winkt einer Ente im Wasser zu , umgeben von einer Grünanlage . <eos> <bos>A young boy waves his hand at the duck in the water surrounded by a green park .


77it [00:17,  4.92it/s]

inputs:  <bos>Ein Paar sitzt mit Baby und Sportwagen im Gras . <eos> <bos>A couple sit on the grass with a baby and stroller .
inputs:  <bos>Ein paar Männer stehen vor einem Gebäude neben einem parkenden Auto . <eos> <bos>Some men standing in front of a building next to a parked car .


79it [00:18,  4.63it/s]

inputs:  <bos>Der schwarze Hund rennt durch das Wasser . <eos> <bos>The black dog runs through the water .
inputs:  <bos>Ein Mann bohrt durch das gefrorene Eis eines Teichs . <eos> <bos>A man is drilling through the frozen ice of a pond .


81it [00:18,  4.62it/s]

inputs:  <bos>Zwei große lohfarbene Hunde spielen an einem sandigen Strand . <eos> <bos>Two large tan dogs play along a sandy beach .
inputs:  <bos>Eine Person in blau und rot , die mit zwei Pickeln eisklettert . <eos> <bos>A person in blue and red ice climbing with two picks .


82it [00:18,  4.91it/s]

inputs:  <bos>Drei Personen , die auf einem Pfad in einer Wiese gehen . <eos> <bos>Three people walking on a path in a meadow .


84it [00:19,  4.45it/s]

inputs:  <bos>Ein Mann in Schwarz schaufelt Schnee auf die Straße und ignoriert dabei die öffentliche Sicherheit . <eos> <bos>A man in black attire shovels snow into the street , disregarding all public safety .
inputs:  <bos>Ein Paar steht hinter seiner Hochzeitstorte . <eos> <bos>A couple stands behind their wedding cake .


86it [00:19,  4.35it/s]

inputs:  <bos>Ein nasser schwarzer Hund trägt ein grünes Spielzeug durch das Gras . <eos> <bos>A wet black dog is carrying a green toy through the grass .
inputs:  <bos>Dorfbewohner verkaufen ihre Ernte auf dem Markt . <eos> <bos>Villagers selling their crops at the market .


88it [00:20,  3.25it/s]

inputs:  <bos>In einem vollen Konzert nähert sich ein Mann dem Hauptsänger , der ein gelbes Hemd trägt . <eos> <bos>In a crowded concert a man in white is approaching the main singer who is wearing a yellow shirt .


89it [00:20,  3.69it/s]

inputs:  <bos>Ein Junge springt auf seinem Skateboard , und eine Menschenmenge sieht zu . <eos> <bos>A boy jumps on his skateboard while a crowd watches
inputs:  <bos>Ein Mann und ein Baby befinden sich in einem gelben Kajak auf dem Wasser . <eos> <bos>A man and a baby are in a yellow kayak on water .


91it [00:21,  4.36it/s]

inputs:  <bos>Zwei Personen sitzen auf einer Bank , und eine Frau steht neben ihnen . <eos> <bos>Two people are sitting on a bench , and one women is standing by them .
inputs:  <bos>Eine Baustelle auf einer Straße mit drei arbeitenden Männern . <eos> <bos>A construction site on a street with three men working .


93it [00:21,  4.58it/s]

inputs:  <bos>Zwei Männer sitzen auf einer Bank und reden , im Hintergrund eine Reklamefläche mit Werbung für Brillen . <eos> <bos>Two men sitting on a bench talking , with a billboard advertisement for glasses in the background .
inputs:  <bos>Eine Gruppe Jugendlicher geht die Straße entlang und schwenkt Fahnen , die das Farbspektrum zeigen . <eos> <bos>A group of youths march down the street waving flags showing the color spectrum .


95it [00:21,  4.99it/s]

inputs:  <bos>Drei alte Männer sehen einem anderen Mann zu , wie er Fisch zubereitet . <eos> <bos>Three old men are watching another man prepare fish .
inputs:  <bos>Eine Frau einem einem weißen Pullunder mit einem grünen , wallenden Rock ein Lied singend auf der Bühne . <eos> <bos>A woman in a white tank top with a green flowing skirt , on stage singing a song .


97it [00:22,  5.11it/s]

inputs:  <bos>Zwei Männer und zwei Frauen , die auf Treppenstufen im Freien sitzen . <eos> <bos>Two men and two women sitting on steps outdoors .
inputs:  <bos>Ein brauner und ein schwarzer Labrador im Freien , wobei der schwarze Labrador ein Spielzeug in seinem Maul hat . <eos> <bos>A brown and black lab are outside and the black lab is catching a toy in its mouth .


99it [00:22,  4.36it/s]


inputs:  <bos>Junger männlicher Hockey - Goalie in roter Jacke duckt sich mit Stock beim Tor . <eos> <bos>Hockey goalie boy in red jacket crouches by goal , with stick .
Trained on 100 datapts
Epoch: 1, Train loss: 0.027, Val loss: 0.000, Epoch time = 25.822s


In [220]:
# torch.save(model, 'oct2+1')
# model = torch.load('oct2+1')

In [221]:
q, k, mask = torch.ones((32, 8, 23, 64)), torch.ones((32, 8, 24, 64)), torch.ones((23, 24))
attn_scores = torch.softmax(
            einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(2) + mask,#torch.FloatTensor(self.d_hidden//8)),
            dim=-1
        )

In [222]:
(einops.einsum(q, k, 'b h T_out d_h, b h T_in d_h -> b h T_out T_in')/math.sqrt(2)).shape# + mask

torch.Size([32, 8, 23, 24])

In [230]:
y_toks = [BOS_IDX] #+ list(text_transform_en('[BOS]'))
src_toks = de_tokenizer.encode(de_data[80]).ids
# src_toks = text_transform_en('<bos>')#torch.Tensor([   2,  313,  400,  826, 1276,    3]).long()
model = torch.load('it_kindof_works')
model.eval()
toks_generated=0
max_len=10#100
while y_toks[-1] != EOS_IDX and toks_generated < max_len:
    toks_generated+=1
    out_logits = model(torch.tensor(src_toks).unsqueeze(0), torch.tensor(y_toks).unsqueeze(0), is_causal_tgt=True)
    print(out_logits)
    # print(out_logits[:,-1].shape, "sfdsfhadsfhsdkjhfSDFfdf")
    idx_predicted = torch.argmax(out_logits[:,-1])
    y_toks.append(idx_predicted)

tensor([[[-1.1573, -2.0253, -1.4835,  ..., -0.8829, -0.1583, -1.5284]]],
       grad_fn=<ViewBackward0>)
tensor([[[-1.1573, -2.0253, -1.4835,  ..., -0.8829, -0.1583, -1.5284],
         [-1.1587, -1.4320, -1.3938,  ..., -0.6338, -0.1273, -1.6783]]],
       grad_fn=<ViewBackward0>)
tensor([[[-1.1573, -2.0253, -1.4835,  ..., -0.8829, -0.1583, -1.5284],
         [-1.1587, -1.4320, -1.3938,  ..., -0.6338, -0.1273, -1.6783],
         [-1.6576, -1.1608, -1.0758,  ..., -1.1079, -0.1805, -2.0747]]],
       grad_fn=<ViewBackward0>)
tensor([[[-1.1573, -2.0253, -1.4835,  ..., -0.8829, -0.1583, -1.5284],
         [-1.1587, -1.4320, -1.3938,  ..., -0.6338, -0.1273, -1.6783],
         [-1.6576, -1.1608, -1.0758,  ..., -1.1079, -0.1805, -2.0747],
         [-1.6403, -0.9788, -1.5941,  ..., -1.2258, -0.2059, -1.8604]]],
       grad_fn=<ViewBackward0>)
tensor([[[-1.1573, -2.0253, -1.4835,  ..., -0.8829, -0.1583, -1.5284],
         [-1.1587, -1.4320, -1.3938,  ..., -0.6338, -0.1273, -1.6783],
         [-1

In [231]:
(y_toks,
en_tokenizer.decode(y_toks),#[2456]),
en_data[80])

([2,
  tensor(120),
  tensor(167),
  tensor(162),
  tensor(101),
  tensor(96),
  tensor(3)],
 'A man in a .',
 'Two large tan dogs play along a sandy beach.')

In [229]:
# torch.save(model, 'it_doeskn_works')

In [226]:
# torch.zeros((2, 2, 3, 3)) + MultiHeadAttention.causal_mask(n_heads=2, mask_shape=(3,3))