In [None]:
import torch
import torch.nn as nn

In [None]:
from torchtext.legacy.data import Field, BucketIterator
import spacy


In [None]:
from torchtext.legacy.datasets import Multi30k

In [None]:
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
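# One-time setup: uncomment to download the German spaCy model loaded below
# (the English model en_core_web_sm must also be installed).
# !python3 -m spacy download de_core_news_sm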

Collecting de_core_news_sm==2.2.5
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-2.2.5/de_core_news_sm-2.2.5.tar.gz (14.9 MB)
[K     |████████████████████████████████| 14.9 MB 2.4 MB/s 
Building wheels for collected packages: de-core-news-sm
  Building wheel for de-core-news-sm (setup.py) ... [?25l[?25hdone
  Created wheel for de-core-news-sm: filename=de_core_news_sm-2.2.5-py3-none-any.whl size=14907055 sha256=47b40b1f1b5469dc48a0dac51759d0a844f94919dc1a01d69006ecc7cc3e613d
  Stored in directory: /tmp/pip-ephem-wheel-cache-r0f6tjgh/wheels/00/66/69/cb6c921610087d2cab339062345098e30a5ceb665360e7b32a
Successfully built de-core-news-sm
Installing collected packages: de-core-news-sm
Successfully installed de-core-news-sm-2.2.5
[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('de_core_news_sm')


In [None]:
# spaCy pipelines for English (target side) and German (source side); only
# their tokenizers are used, via en_token / ger_token.
spacy_en, spacy_ger = spacy.load("en_core_web_sm"), spacy.load("de_core_news_sm")

In [None]:
def _spacy_texts(nlp, text):
    """Return the surface strings of the tokens ``nlp.tokenizer`` yields for ``text``."""
    return [tok.text for tok in nlp.tokenizer(text)]


def en_token(text):
    """Tokenize English ``text`` into a list of strings (target-side Field tokenizer)."""
    return _spacy_texts(spacy_en, text)


def ger_token(text):
    """Tokenize German ``text`` into a list of strings (source-side Field tokenizer)."""
    return _spacy_texts(spacy_ger, text)
# Both sides are lower-cased, wrapped in <sos>/<eos> and batched as
# (batch, seq_len) tensors (batch_first=True).
eng_field = Field(
    tokenize=en_token,
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    batch_first=True,
)
ger_field = Field(
    tokenize=ger_token,
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    batch_first=True,
)

# ".de" files feed the first field (source), ".en" files the second (target).
train, val, test = Multi30k.splits(exts=(".de", ".en"), fields=(ger_field, eng_field))

# Vocabularies are built from the training split only.
eng_field.build_vocab(train, max_size=10000, min_freq=2)
ger_field.build_vocab(train, max_size=10000, min_freq=2)

In [None]:
BATCH_SIZE: int = 16  # sentence pairs per batch for the train/val/test iterators

In [None]:
# Group examples of similar source length to reduce padding; each batch is
# also sorted by that same key, and tensors are created directly on `device`.
train_iter, val_iter, test_iter = BucketIterator.splits(
    (train, val, test),
    batch_size=BATCH_SIZE,
    sort_key=lambda example: len(example.src),
    sort_within_batch=True,
    device=device,
)

In [None]:
class MultiHead_Attention(nn.Module):
    """Scaled dot-product multi-head attention.

    Args:
        hidden_dim: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        device: kept for backward compatibility. The scale is a plain float,
            so nothing here is pinned to a device at construction time.

    ``forward(query, key, value, mask)`` expects ``query`` of shape
    ``(batch, query_len, hidden_dim)`` and ``key``/``value`` of shape
    ``(batch, key_len, hidden_dim)``. ``mask`` is optional and must broadcast to
    ``(batch, n_head, query_len, key_len)``; entries where ``mask == 0`` get
    (effectively) zero attention weight. Returns ``(output, attention)`` with
    shapes ``(batch, query_len, hidden_dim)`` and
    ``(batch, n_head, query_len, key_len)``.
    """

    def __init__(self, hidden_dim, n_head, device):
        super(MultiHead_Attention, self).__init__()
        if hidden_dim % n_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by n_head ({n_head})"
            )
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.head_dim = hidden_dim // n_head
        self.q_fc = nn.Linear(hidden_dim, hidden_dim)
        self.k_fc = nn.Linear(hidden_dim, hidden_dim)
        self.v_fc = nn.Linear(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        # Plain float: valid on any device and unaffected by module.to(...).
        self.scale = self.head_dim ** 0.5

    def _split_heads(self, x):
        """Reshape ``(batch, len, hidden_dim)`` into ``(batch, n_head, len, head_dim)``."""
        batch, length = x.shape[0], x.shape[1]
        # Split the feature axis into heads first, then move the head axis in
        # front of the sequence axis. Viewing directly as
        # (batch, n_head, len, head_dim) would mix different tokens into one head.
        return x.view(batch, length, self.n_head, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask):
        batch, query_len = query.shape[0], query.shape[1]
        Q = self._split_heads(self.q_fc(query))
        K = self._split_heads(self.k_fc(key))
        V = self._split_heads(self.v_fc(value))

        # (batch, n_head, query_len, key_len)
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)

        # (batch, n_head, query_len, head_dim) -> (batch, query_len, hidden_dim):
        # heads must be moved back next to their features before merging.
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, query_len, self.hidden_dim)
        return self.fc(x), attention

In [None]:
# Expected source-mask shape is (batch, 1, 1, src_len), e.g. torch.Size([16, 1, 1, 14]).

In [None]:
class Positon_wise_feed(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Applied independently at every position of a ``(..., hidden_dim)`` tensor;
    the output has the same shape as the input.

    Args:
        hidden_dim: input/output feature size.
        pf_dim: inner (expanded) feature size.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, hidden_dim, pf_dim, dropout):
        super(Positon_wise_feed, self).__init__()
        self.hid_to_pf = nn.Linear(hidden_dim, pf_dim)
        self.pf_to_hid = nn.Linear(pf_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The non-linearity is essential: without it the two Linear layers
        # compose into a single affine map.
        x = self.dropout(torch.relu(self.hid_to_pf(x)))
        return self.pf_to_hid(x)
        

In [None]:
class Encoder_layer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention followed by a position-wise feed-forward network, each
    wrapped as ``LayerNorm(residual + Dropout(sublayer(residual)))``.

    Args:
        hidden: model width.
        n_head: number of attention heads.
        pf_dim: inner size of the feed-forward network.
        dropout: dropout probability on each sub-layer output.
        device: forwarded to ``MultiHead_Attention``.
    """

    def __init__(self, hidden, n_head, pf_dim, dropout, device):
        super(Encoder_layer, self).__init__()
        self.multihead = MultiHead_Attention(hidden, n_head, device)
        self.posi_feed = Positon_wise_feed(hidden, pf_dim, dropout)
        self.Norm_multi = nn.LayerNorm(hidden)
        self.Norm_posifeed = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask):
        """Return the layer output (same shape as ``query``).

        When called from ``Encoder`` all of ``query``/``key``/``value`` are the
        same tensor; ``mask`` is passed unchanged to the attention.
        """
        attn_out, _ = self.multihead(query, key, value, mask)
        x_mul = self.Norm_multi(query + self.dropout(attn_out))
        ff_out = self.posi_feed(x_mul)
        # Dropout on the feed-forward output too, matching the attention
        # sub-layer above and the decoder layer's residual blocks.
        return self.Norm_posifeed(x_mul + self.dropout(ff_out))

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: scaled token embeddings plus learned positional
    embeddings, followed by a stack of ``Encoder_layer``s.

    Args:
        input_vocab: source vocabulary size.
        hidden: model width.
        n_head: attention heads per layer.
        n_layers: number of encoder layers.
        pf_dim: feed-forward inner size.
        dropout: dropout probability.
        max_length: number of learned positions; inputs longer than this
            cannot be embedded.
        device: forwarded to the layers. Positions are created on the input's
            own device, so the module follows ``.to(...)``.
    """

    def __init__(self, input_vocab, hidden,
                 n_head, n_layers, pf_dim, dropout, max_length, device):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_vocab, hidden)
        self.position_encoding = nn.Embedding(max_length, hidden)
        self.encoder_layer = nn.ModuleList(
            [Encoder_layer(hidden, n_head, pf_dim, dropout, device) for _ in range(n_layers)]
        )
        self.dropout = nn.Dropout(dropout)
        # Plain float: valid on any device and unaffected by module.to(...).
        self.scale = hidden ** 0.5

    def forward(self, x, maks):
        """Encode token ids ``x`` of shape ``(batch, src_len)``.

        ``maks`` (the source padding mask) is passed to every layer.
        Returns a tensor of shape ``(batch, src_len, hidden)``.
        """
        seq_len = x.shape[1]
        # Build positions on x's device rather than relying on a global
        # `device`; shape (seq_len, hidden) broadcasts over the batch.
        positions = torch.arange(seq_len, device=x.device)
        src = self.dropout(self.embedding(x) * self.scale + self.position_encoding(positions))
        for layer in self.encoder_layer:
            src = layer(src, src, src, maks)
        return src

In [None]:
class Decoder_layer(nn.Module):
    """One post-norm Transformer decoder layer.

    Masked self-attention, encoder-decoder (cross) attention and a
    position-wise feed-forward network, each wrapped as
    ``LayerNorm(residual + Dropout(sublayer(residual)))``.

    Note the constructor order is ``(hidden, pf_dim, n_head, dropout, device)``,
    which differs from ``Encoder_layer``.
    """

    def __init__(self, hidden, pf_dim, n_head, dropout, device):
        super(Decoder_layer, self).__init__()
        self.mask_multihead = MultiHead_Attention(hidden, n_head, device)
        self.Norm_mask = nn.LayerNorm(hidden)
        self.multihead = MultiHead_Attention(hidden, n_head, device)
        self.Norm_multi = nn.LayerNorm(hidden)
        self.feed = Positon_wise_feed(hidden, pf_dim, dropout)
        self.Norm_feed = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, key_encoder, value_encoder, src_mask, trg_mask):
        """Return the layer output (same shape as ``query``).

        ``query``/``key``/``value`` are the decoder states (the same tensor when
        called from ``Decoder``). ``key_encoder``/``value_encoder`` are the
        encoder output. ``trg_mask`` masks the self-attention and ``src_mask``
        masks the cross-attention.
        """
        self_attn, _ = self.mask_multihead(query, key, value, trg_mask)
        x_norm = self.Norm_mask(query + self.dropout(self_attn))
        # The cross-attention must query with the self-attention sub-layer's
        # output; querying with the raw input would let the masked
        # self-attention affect the result only through the residual path.
        cross_attn, _ = self.multihead(x_norm, key_encoder, value_encoder, src_mask)
        x_multi_norm = self.Norm_multi(x_norm + self.dropout(cross_attn))
        ff_out = self.feed(x_multi_norm)
        return self.Norm_feed(x_multi_norm + self.dropout(ff_out))

In [None]:
class Decoder(nn.Module):
    """Transformer decoder: scaled token embeddings plus learned positional
    embeddings, a stack of ``Decoder_layer``s, and a projection to vocabulary
    logits.

    Args:
        input_vocab: target vocabulary size (also the output size).
        hidden: model width.
        n_layer: number of decoder layers.
        pf_hidden: feed-forward inner size.
        n_head: attention heads per layer.
        max_length: number of learned positions; longer targets cannot be embedded.
        dropout: dropout probability.
        device: forwarded to the layers and stored as ``self.device``.
    """

    def __init__(self, input_vocab, hidden, n_layer, pf_hidden, n_head, max_length, dropout, device):
        super(Decoder, self).__init__()
        self.embedding_input = nn.Embedding(input_vocab, hidden)
        self.embedding_position = nn.Embedding(max_length, hidden)
        self.layer = nn.ModuleList([Decoder_layer(hidden, pf_hidden, n_head, dropout, device)
                                    for _ in range(n_layer)])
        self.dropout = nn.Dropout(dropout)
        self.device = device
        self.fc = nn.Linear(hidden, input_vocab)
        # Plain float: valid on any device and unaffected by module.to(...).
        self.scale = hidden ** 0.5

    def forward(self, trg_value, key_encoder, value_encoder, src_mask, trg_mask):
        """Decode target ids ``trg_value`` of shape ``(batch, trg_len)``.

        Returns unnormalised logits of shape ``(batch, trg_len, input_vocab)``;
        no softmax is applied here.
        """
        seq_len = trg_value.shape[1]
        # Positions on the input's device; shape (seq_len, hidden) broadcasts over the batch.
        positions = torch.arange(seq_len, device=trg_value.device)
        # Dropout covers the sum of token and position embeddings, as in Encoder.
        x = self.dropout(self.embedding_input(trg_value) * self.scale
                         + self.embedding_position(positions))
        for layer in self.layer:
            x = layer(x, x, x, key_encoder, value_encoder, src_mask, trg_mask)
        return self.fc(x)
        
        

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that builds attention masks and runs both halves.

    Args:
        encoder: called as ``encoder(src, src_mask)``.
        decoder: called as ``decoder(target, enc_out, enc_out, src_mask, trg_mask)``.
        src_pad: padding index of the source vocabulary.
        trg_pad: padding index of the target vocabulary.
        device: device on which the look-ahead (causal) mask is created.
    """

    def __init__(self, encoder, decoder, src_pad, trg_pad, device):
        super(Seq2Seq, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_pad, self.trg_pad = src_pad, trg_pad
        self.device = device

    def make_src_mask(self, x):
        """Padding mask of shape ``(batch, 1, 1, src_len)``; True at non-pad tokens."""
        return (x != self.src_pad)[:, None, None]

    def make_trg_mask(self, x):
        """Target mask of shape ``(batch, 1, trg_len, trg_len)``.

        Entry ``(i, j)`` is True only when token ``j`` is not padding and
        ``j <= i``, so no position attends to later positions.
        """
        not_pad = (x != self.trg_pad)[:, None, None]
        size = x.shape[1]
        causal = torch.ones(size, size, device=self.device).tril().bool()
        return not_pad & causal

    def forward(self, src, target):
        """Encode ``src`` and decode ``target`` against it; returns the decoder output."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(target)
        memory = self.encoder(src, src_mask)
        return self.decoder(target, memory, memory, src_mask, trg_mask)

In [None]:
# --- model size ---
hidden = 512            # model width
pf_dim = 2048           # feed-forward inner size
n_head = 8
n_layer_encoder = 4
n_layer_decoder = 2
max_length = 100        # number of learned positions
dropout = 0.2

# --- values derived from the vocabularies built on the training split ---
ger_vocab = len(ger_field.vocab)
eng_vocab = len(eng_field.vocab)
src_pad = ger_field.vocab.stoi[ger_field.pad_token]
trg_pad = eng_field.vocab.stoi[eng_field.pad_token]

In [None]:
encoder = Encoder(
    input_vocab=ger_vocab,
    hidden=hidden,
    n_head=n_head,
    n_layers=n_layer_encoder,
    pf_dim=pf_dim,
    dropout=dropout,
    max_length=max_length,
    device=device,
)
decoder = Decoder(
    input_vocab=eng_vocab,
    hidden=hidden,
    n_layer=n_layer_decoder,
    pf_hidden=pf_dim,
    n_head=n_head,
    max_length=max_length,
    dropout=dropout,
    device=device,
)


In [None]:
transformer = Seq2Seq(
    encoder=encoder,
    decoder=decoder,
    src_pad=src_pad,
    trg_pad=trg_pad,
    device=device,
).to(device)

In [None]:
lr = 0.0001
num_epoch = 1
# Target padding positions are excluded from the loss.
loss = nn.CrossEntropyLoss(ignore_index=trg_pad)
optim = torch.optim.Adam(transformer.parameters(), lr=lr)

In [None]:
from tqdm import tqdm

In [None]:
clip: int = 1  # max gradient norm passed to clip_grad_norm_ in the training loop

In [None]:
# Training loop with teacher forcing. Per-batch debug printing was removed:
# it flooded the output and dominated the runtime.
for i in range(num_epoch):
    # Make sure dropout is active even if the model was switched to eval() earlier.
    transformer.train()
    loop = tqdm(train_iter, total=len(train_iter), leave=False)
    loop.set_description(f"Epoch [{i + 1}/{num_epoch}]")
    for data in loop:
        # BucketIterator was built with device=device, so these are no-ops then.
        src = data.src.to(device)
        trg = data.trg.to(device)

        optim.zero_grad()
        # Feed the target without its last token; predict it shifted by one
        # (i.e. without its first token).
        output = transformer(src, trg[:, :-1])

        output_dim = output.shape[-1]
        output1 = output.contiguous().view(-1, output_dim)
        target = trg[:, 1:].contiguous().view(-1)
        los = loss(output1, target)
        los.backward()
        torch.nn.utils.clip_grad_norm_(transformer.parameters(), clip)
        optim.step()

        loop.set_postfix(loss=los.item())

  0%|          | 0/1813 [00:00<?, ?it/s]

trg shape is torch.Size([16, 26])
input multi shape torch.Size([16, 22, 512])
query fully connected torch.Size([16, 22, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 22, 64])
energy bnqk shape torch.Size([16, 8, 22, 22]) 
matmul with value torch.Size([16, 8, 22, 64])
out multi shape torch.Size([16, 22, 512])
input multi shape torch.Size([16, 22, 512])
query fully connected torch.Size([16, 22, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 22, 64])
energy bnqk shape torch.Size([16, 8, 22, 22]) 
matmul with value torch.Size([16, 8, 22, 64])
out multi shape torch.Size([16, 22, 512])
input multi shape torch.Size([16, 22, 512])
query fully connected torch.Size([16, 22, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 22, 64])
energy bnqk shape torch.Size([16, 8, 22, 22]) 
matmul with value torch.Size([16, 8, 22, 64])
out multi shape torch.Size([16, 22, 512])
input multi shape torch.Size([16, 22, 512])
query f

Epoch [0/1]:   0%|          | 1/1813 [00:02<1:04:53,  2.15s/it, loss=8.8]

trg shape is torch.Size([16, 18])
input multi shape torch.Size([16, 13, 512])
query fully connected torch.Size([16, 13, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 13, 64])
energy bnqk shape torch.Size([16, 8, 13, 13]) 
matmul with value torch.Size([16, 8, 13, 64])
out multi shape torch.Size([16, 13, 512])
input multi shape torch.Size([16, 13, 512])
query fully connected torch.Size([16, 13, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 13, 64])
energy bnqk shape torch.Size([16, 8, 13, 13]) 
matmul with value torch.Size([16, 8, 13, 64])
out multi shape torch.Size([16, 13, 512])
input multi shape torch.Size([16, 13, 512])
query fully connected torch.Size([16, 13, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 13, 64])
energy bnqk shape torch.Size([16, 8, 13, 13]) 
matmul with value torch.Size([16, 8, 13, 64])
out multi shape torch.Size([16, 13, 512])
input multi shape torch.Size([16, 13, 512])
query f

Epoch [0/1]:   0%|          | 2/1813 [00:03<50:47,  1.68s/it, loss=8.57]  

trg shape is torch.Size([16, 13])
input multi shape torch.Size([16, 11, 512])
query fully connected torch.Size([16, 11, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 11, 64])
energy bnqk shape torch.Size([16, 8, 11, 11]) 
matmul with value torch.Size([16, 8, 11, 64])
out multi shape torch.Size([16, 11, 512])
input multi shape torch.Size([16, 11, 512])
query fully connected torch.Size([16, 11, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 11, 64])
energy bnqk shape torch.Size([16, 8, 11, 11]) 
matmul with value torch.Size([16, 8, 11, 64])
out multi shape torch.Size([16, 11, 512])
input multi shape torch.Size([16, 11, 512])
query fully connected torch.Size([16, 11, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 11, 64])
energy bnqk shape torch.Size([16, 8, 11, 11]) 
matmul with value torch.Size([16, 8, 11, 64])
out multi shape torch.Size([16, 11, 512])
input multi shape torch.Size([16, 11, 512])
query f

Epoch [0/1]:   0%|          | 3/1813 [00:04<43:10,  1.43s/it, loss=8.31]

trg shape is torch.Size([16, 25])
input multi shape torch.Size([16, 21, 512])
query fully connected torch.Size([16, 21, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 21, 64])
energy bnqk shape torch.Size([16, 8, 21, 21]) 
matmul with value torch.Size([16, 8, 21, 64])
out multi shape torch.Size([16, 21, 512])
input multi shape torch.Size([16, 21, 512])
query fully connected torch.Size([16, 21, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 21, 64])
energy bnqk shape torch.Size([16, 8, 21, 21]) 
matmul with value torch.Size([16, 8, 21, 64])
out multi shape torch.Size([16, 21, 512])
input multi shape torch.Size([16, 21, 512])
query fully connected torch.Size([16, 21, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 21, 64])
energy bnqk shape torch.Size([16, 8, 21, 21]) 
matmul with value torch.Size([16, 8, 21, 64])
out multi shape torch.Size([16, 21, 512])
input multi shape torch.Size([16, 21, 512])
query f

Epoch [0/1]:   0%|          | 4/1813 [00:06<48:01,  1.59s/it, loss=8.31]

trg shape is torch.Size([16, 17])
input multi shape torch.Size([16, 13, 512])
query fully connected torch.Size([16, 13, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 13, 64])
energy bnqk shape torch.Size([16, 8, 13, 13]) 
matmul with value torch.Size([16, 8, 13, 64])
out multi shape torch.Size([16, 13, 512])
input multi shape torch.Size([16, 13, 512])
query fully connected torch.Size([16, 13, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 13, 64])
energy bnqk shape torch.Size([16, 8, 13, 13]) 
matmul with value torch.Size([16, 8, 13, 64])
out multi shape torch.Size([16, 13, 512])
input multi shape torch.Size([16, 13, 512])
query fully connected torch.Size([16, 13, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 13, 64])
energy bnqk shape torch.Size([16, 8, 13, 13]) 
matmul with value torch.Size([16, 8, 13, 64])
out multi shape torch.Size([16, 13, 512])
input multi shape torch.Size([16, 13, 512])
query f

Epoch [0/1]:   0%|          | 5/1813 [00:07<45:06,  1.50s/it, loss=7.94]

trg shape is torch.Size([16, 25])
input multi shape torch.Size([16, 19, 512])
query fully connected torch.Size([16, 19, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 19, 64])
energy bnqk shape torch.Size([16, 8, 19, 19]) 
matmul with value torch.Size([16, 8, 19, 64])
out multi shape torch.Size([16, 19, 512])
input multi shape torch.Size([16, 19, 512])
query fully connected torch.Size([16, 19, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 19, 64])
energy bnqk shape torch.Size([16, 8, 19, 19]) 
matmul with value torch.Size([16, 8, 19, 64])
out multi shape torch.Size([16, 19, 512])
input multi shape torch.Size([16, 19, 512])
query fully connected torch.Size([16, 19, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 19, 64])
energy bnqk shape torch.Size([16, 8, 19, 19]) 
matmul with value torch.Size([16, 8, 19, 64])
out multi shape torch.Size([16, 19, 512])
input multi shape torch.Size([16, 19, 512])
query f

Epoch [0/1]:   0%|          | 6/1813 [00:09<47:08,  1.57s/it, loss=8.1]

trg shape is torch.Size([16, 20])
input multi shape torch.Size([16, 14, 512])
query fully connected torch.Size([16, 14, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 14, 64])
energy bnqk shape torch.Size([16, 8, 14, 14]) 
matmul with value torch.Size([16, 8, 14, 64])
out multi shape torch.Size([16, 14, 512])
input multi shape torch.Size([16, 14, 512])
query fully connected torch.Size([16, 14, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 14, 64])
energy bnqk shape torch.Size([16, 8, 14, 14]) 
matmul with value torch.Size([16, 8, 14, 64])
out multi shape torch.Size([16, 14, 512])
input multi shape torch.Size([16, 14, 512])
query fully connected torch.Size([16, 14, 512])
Q view(batch,self.n_head,query_len,self.head_dim) torch.Size([16, 8, 14, 64])
energy bnqk shape torch.Size([16, 8, 14, 14]) 
matmul with value torch.Size([16, 8, 14, 64])
out multi shape torch.Size([16, 14, 512])
input multi shape torch.Size([16, 14, 512])
query f



KeyboardInterrupt: ignored