<a href="https://colab.research.google.com/github/eisbetterthanpi/pytorch/blob/main/translation_transformer_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title setup

# https://pytorch.org/tutorials/beginner/translation_transformer.html
# https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/c64c91cf87c13c0e83586b8e66e4d74e/translation_transformer.ipynb

# https://github.com/pytorch/data
%pip install portalocker
%pip install torchdata

# Create source and target language tokenizer. Make sure to install the dependencies.
!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm


In [2]:
# @title data

from torchtext.datasets import multi30k, Multi30k
# modify the URLs for the dataset since the links to the original dataset are broken https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

from torchtext.data.utils import get_tokenizer
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')


UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3 # unknown, pad, bigining, end of sentence
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

from torchtext.vocab import build_vocab_from_iterator
train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

de_tokens = [de_tokenizer(data_sample[0]) for data_sample in train_iter]
en_tokens = [en_tokenizer(data_sample[1]) for data_sample in train_iter]

de_vocab = build_vocab_from_iterator(de_tokens, min_freq=1, specials=special_symbols, special_first=True)
en_vocab = build_vocab_from_iterator(en_tokens, min_freq=1, specials=special_symbols, special_first=True)
de_vocab.set_default_index(UNK_IDX)
en_vocab.set_default_index(UNK_IDX)

import torch

def de_transform(o):
    o=de_tokenizer(o)
    o=de_vocab(o)
    return torch.cat((torch.tensor([BOS_IDX]), torch.tensor(o), torch.tensor([EOS_IDX])))

def en_transform(o):
    o=en_tokenizer(o)
    o=en_vocab(o)
    return torch.cat((torch.tensor([BOS_IDX]), torch.tensor(o), torch.tensor([EOS_IDX])))


from torch.nn.utils.rnn import pad_sequence
# function to collate data samples into batch tensors
def collate_fn(batch): # convert a batch of raw strings into batch tensors
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(de_transform(src_sample.rstrip("\n")))
        tgt_batch.append(en_transform(tgt_sample.rstrip("\n")))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    # src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    # tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_IDX)
    return src_batch, tgt_batch


torch.manual_seed(0)

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
batch_size = 128 # 128
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_iter, batch_size=batch_size, collate_fn=collate_fn)

# vocab_transform = {SRC_LANGUAGE:de_vocab, TGT_LANGUAGE:en_vocab}
# text_transform = {SRC_LANGUAGE:de_transform, TGT_LANGUAGE:en_transform}




In [4]:
# @title model nn.
# import torch
# import torch.nn as nn
# import math
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# class PositionalEncoding(nn.Module):
#     def __init__(self, emb_size, dropout, maxlen = 5000):
#         super(PositionalEncoding, self).__init__()
#         self.dropout = nn.Dropout(dropout)
#         den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
#         pos = torch.arange(0, maxlen).reshape(maxlen, 1) # .reshape(-1, 1)
#         pos_emb = torch.zeros((maxlen, emb_size))
#         pos_emb[:, 0::2] = torch.sin(pos * den) # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
#         pos_emb[:, 1::2] = torch.cos(pos * den) # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
#         pos_emb = pos_emb.unsqueeze(-2)
#         self.register_buffer('pos_emb', pos_emb) # register as buffer so optimizer wont update it

#     def forward(self, token_emb):
#         return self.dropout(token_emb + self.pos_emb[:token_emb.size(0), :])


# class TokenEmbedding(nn.Module):
#     def __init__(self, vocab_size, emb_size):
#         super(TokenEmbedding, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, emb_size)
#         self.emb_size = emb_size

#     def forward(self, tokens):
#         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


# class Transformer(nn.Module):
#     def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward = 512, dropout = 0.1):
#         super(Transformer, self).__init__()
#         self.emb_size = emb_size
#         # self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
#         # self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, emb_size)
#         self.pos_enc = PositionalEncoding(emb_size, dropout=dropout)
#         self.transformer = nn.Transformer(d_model=emb_size, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout)
#         self.generator = nn.Linear(emb_size, tgt_vocab_size)

#         self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
#         self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
#         for p in self.parameters():
#         # for p in transformer.parameters():
#             if p.dim() > 1:
#                 nn.init.xavier_uniform_(p)

#     def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
#         src_emb = self.pos_enc(self.src_tok_emb(src))
#         tgt_emb = self.pos_enc(self.tgt_tok_emb(tgt))

#         # src = self.src_tok_emb(src.long()) * math.sqrt(self.emb_size) # https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod
#         # src_emb = self.pos_enc(src)
#         # tgt = self.src_tok_emb(tgt.long()) * math.sqrt(self.emb_size)
#         # tgt_emb = self.pos_enc(tgt)

#         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
#         return self.generator(outs)

#     def encode(self, src, src_mask):
#         return self.transformer.encoder(self.pos_enc(self.src_tok_emb(src)), src_mask)

#         # src = self.src_tok_emb(src.long()) * math.sqrt(self.emb_size)
#         # src_emb = self.pos_enc(src)
#         # return self.transformer.encoder(src_emb, src_mask)

#     def decode(self, tgt, memory, tgt_mask):
#         return self.transformer.decoder(self.pos_enc(self.tgt_tok_emb(tgt)), memory, tgt_mask)

#         # tgt = self.tgt_tok_emb(tgt.long()) * math.sqrt(self.emb_size)
#         # tgt_emb = self.pos_enc(tgt)
#         # return self.transformer.decoder(tgt_emb, memory, tgt_mask)


# def greedy_decode(model, src, src_mask, max_len, start_symbol):
#     src = src.to(device)
#     src_mask = src_mask.to(device)
#     memory = model.encode(src, src_mask)
#     ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
#     for i in range(max_len-1):
#         memory = memory.to(device)
#         tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(device)
#         out = model.decode(ys, memory, tgt_mask)
#         out = out.transpose(0, 1)
#         prob = model.generator(out[:, -1])
#         _, next_word = torch.max(prob, dim=1)
#         next_word = next_word.item()
#         ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
#         if next_word == EOS_IDX:
#             break
#     return ys

# # actual function to translate input sentence into target language
# def translate(model, src_sentence):
#     model.eval()
#     src = de_transform(src_sentence).view(-1, 1)
#     num_tokens = src.shape[0]
#     src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
#     tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
#     return " ".join(en_vocab.lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")


# src_vocab_size = len(de_vocab)
# tgt_vocab_size = len(en_vocab)

# emb_size = 512
# nhead = 8
# dim_feedforward = 512
# num_encoder_layers = 3
# num_decoder_layers = 3

# transformer = Transformer(num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward).to(device)


In [17]:
# @title TranQuocTrinh transformer
import torch
import torch.nn as nn
import numpy as np
import math
from torch.autograd import Variable
# https://github.com/TranQuocTrinh/transformer/blob/main/models.py

# # Embedding the input sequence
# class Embedding(nn.Module):
#     def __init__(self, vocab_size, embedding_dim):
#         super(Embedding, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, embedding_dim)

#     def forward(self, x):
#         return self.embedding(x)

class PositionalEncoder(nn.Module):
    def __init__(self, embedding_dim, max_seq_length=512, dropout=0.1):
        super(PositionalEncoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_length, embedding_dim)
        # for pos in range(max_seq_length):
        #     for i in range(0, embedding_dim, 2):
        #         pe[pos, i] = math.sin(pos/(10000**(2*i/embedding_dim)))
        #         pe[pos, i+1] = math.cos(pos/(10000**((2*i+1)/embedding_dim)))
        position = torch.arange(0, max_seq_length).unsqueeze(1) # https://nlp.seas.harvard.edu/annotated-transformer/
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x = x*math.sqrt(self.embedding_dim) # ?
        # return self.dropout(x + Variable(self.pe[:, :x.size(1)], requires_grad=False).to(x.device))
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

# Self-attention layer
class SelfAttention(nn.Module): #Scaled Dot-Product Attention
    def __init__(self, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        key_dim = key.size(-1)
        attn = torch.matmul(query / np.sqrt(key_dim), key.transpose(2, 3))
        # attn = query @ key.transpose(-1, -2) / np.sqrt(key_dim)
        if mask is not None:
            mask = mask.unsqueeze(1) #?
            # print("selfatt",mask.shape,attn.shape) # [11, 1, 11], [11, 8, 1, 1]
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(torch.softmax(attn, dim=-1))
        output = torch.matmul(attn, value)
        # output = attn @ value
        return output

# Multi-head attention layer
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.self_attention = SelfAttention(dropout)
        self.num_heads = num_heads
        self.dim_per_head = embedding_dim // num_heads
        self.query_projection = nn.Linear(embedding_dim, embedding_dim) # bias=False
        self.key_projection = nn.Linear(embedding_dim, embedding_dim) # bias=False
        self.value_projection = nn.Linear(embedding_dim, embedding_dim) # bias=False
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query = self.query_projection(query)
        key = self.key_projection(key)
        value = self.value_projection(value)
        query = query.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1, 2)
        scores = self.self_attention(query, key, value, mask)
        output = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.embedding_dim)
        output = self.out(output)
        return output

# class Norm(nn.Module):
#     def __init__(self, embedding_dim):
#         super(Norm, self).__init__()
#         self.norm = nn.LayerNorm(embedding_dim)

#     def forward(self, x):
#         return self.norm(x)


# Transformer encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, embedding_dim, num_heads, ff_dim=2048, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(embedding_dim, num_heads, dropout)
        self.feed_forward = nn.Sequential(nn.Linear(embedding_dim, ff_dim), nn.ReLU(),
            nn.Linear(ff_dim, embedding_dim))
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # self.norm1 = Norm(embedding_dim)
        # self.norm2 = Norm(embedding_dim)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x, mask=None):
        x2 = self.norm1(x) # add first then norm?
        # Add and Muti-head attention
        x = x + self.dropout1(self.self_attention(x2, x2, x2, mask))
        x2 = self.norm2(x)
        x = x + self.dropout2(self.feed_forward(x2))
        return x

# Transformer decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, num_heads, ff_dim=2048, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(embedding_dim, num_heads, dropout)
        self.encoder_attention = MultiHeadAttention(embedding_dim, num_heads, dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, ff_dim), nn.ReLU(),
            nn.Linear(ff_dim, embedding_dim))
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        # self.norm1 = Norm(embedding_dim)
        # self.norm2 = Norm(embedding_dim)
        # self.norm3 = Norm(embedding_dim)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.norm3 = nn.LayerNorm(embedding_dim)

    def forward(self, x, memory, source_mask, target_mask):
        x2 = self.norm1(x) # should add then norm?
        x = x + self.dropout1(self.self_attention(x2, x2, x2, target_mask))
        x2 = self.norm2(x)
        x = x + self.dropout2(self.encoder_attention(x2, memory, memory, source_mask))
        x2 = self.norm3(x)
        x = x + self.dropout3(self.feed_forward(x2))
        return x

# Encoder transformer
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_seq_len, num_heads, num_layers, dropout=0.1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # self.num_layers = num_layers
        # self.num_heads = num_heads
        # self.embedding_dim = embedding_dim
        self.layers = nn.ModuleList([EncoderLayer(embedding_dim, num_heads, 2048, dropout) for _ in range(num_layers)])
        # self.norm = Norm(embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
        self.position_embedding = PositionalEncoder(embedding_dim, max_seq_len, dropout)

    def forward(self, source, source_mask):
        x = self.embedding(source)
        x = self.position_embedding(x)
        for layer in self.layers:
            x = layer(x, source_mask)
        x = self.norm(x)
        return x

class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_seq_len,num_heads, num_layers, dropout=0.1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # self.num_layers = num_layers
        # self.num_heads = num_heads
        # self.embedding_dim = embedding_dim
        self.layers = nn.ModuleList([DecoderLayer(embedding_dim, num_heads, 2048, dropout) for _ in range(num_layers)])
        # self.norm = Norm(embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
        self.position_embedding = PositionalEncoder(embedding_dim, max_seq_len, dropout)

    def forward(self, target, memory, source_mask, target_mask):
        x = self.embedding(target)
        x = self.position_embedding(x)
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)
        x = self.norm(x)
        return x


# Transformers
class Transformer(nn.Module):
    def __init__(self, source_vocab_size, target_vocab_size, source_max_seq_len, target_max_seq_len, embedding_dim, num_heads, num_layers, dropout=0.1):
        super(Transformer, self).__init__()
        # self.source_vocab_size = source_vocab_size
        # self.target_vocab_size = target_vocab_size
        # self.source_max_seq_len = source_max_seq_len
        # self.target_max_seq_len = target_max_seq_len
        # self.embedding_dim = embedding_dim
        # self.num_heads = num_heads
        # self.num_layers = num_layers
        # self.dropout = dropout
        self.encoder = Encoder(source_vocab_size, embedding_dim, source_max_seq_len, num_heads, num_layers, dropout)
        self.decoder = Decoder(target_vocab_size, embedding_dim, target_max_seq_len, num_heads, num_layers, dropout)
        self.final_linear = nn.Linear(embedding_dim, target_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, source, target, source_mask, target_mask):
        memory = self.encoder(source, source_mask)
        output = self.decoder(target, memory, source_mask, target_mask)
        output = self.dropout(output)
        output = self.final_linear(output)
        return output

    def make_source_mask(self, source_ids, source_pad_id):
        return (source_ids != source_pad_id).unsqueeze(-2)

    def make_target_mask(self, target_ids):
        batch_size, len_target = target_ids.size()
        subsequent_mask = (1 - torch.triu(torch.ones((1, len_target, len_target), device=target_ids.device), diagonal=1)).bool()
        return subsequent_mask

# src_mask = make_source_mask(source_ids, source_pad_id)
# tgt_mask = make_target_mask(target_ids)

src_vocab_size = len(de_vocab)
tgt_vocab_size = len(en_vocab)

emb_size = 512
nhead = 8
dim_feedforward = 512
num_encoder_layers = 3
num_decoder_layers = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# transformer = Transformer(num_encoder_layers, num_decoder_layers, emb_size, nhead, src_vocab_size, tgt_vocab_size, dim_feedforward).to(device)
source_max_seq_len, target_max_seq_len = 512,512
transformer = Transformer(src_vocab_size, tgt_vocab_size, source_max_seq_len, target_max_seq_len, emb_size, nhead, num_encoder_layers, dropout=0.1).to(device)



In [49]:
def collate_fn(batch): # convert a batch of raw strings into batch tensors
    mt=torch.empty(101)
    src_batch, tgt_batch = [], []
    # print(batch) # batch_size*(src,tgt)
    for src_sample, tgt_sample in batch:
        src_batch.append(de_transform(src_sample.rstrip("\n")))
        tgt_batch.append(en_transform(tgt_sample.rstrip("\n")))
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    # print(src_batch.shape) # [27, batch_size]
    # src_batch = pad_sequence(src_batch, batch_first=True, padding_value=PAD_IDX)
    # tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=PAD_IDX)
    return src_batch, tgt_batch


torch.manual_seed(0)

# train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
batch_size = 47 # 128
train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)
# train_loader = torch.utils.data.DataLoader(train_iter, batch_size=batch_size)


In [54]:

for src, tgt in train_loader:
    # print(src, tgt)
    # print(len(src), len(tgt))
    print(src.shape,tgt.shape) # [27, batch_size], [24, batch_size]
    src = src.to(device)
    tgt = tgt.to(device)
    # src=torch.randint(10,(batch_size,13),device=device)
    # tgt=torch.randint(10,(batch_size,13),device=device)
    src=torch.randint(10,(13,batch_size),device=device)
    tgt=torch.randint(10,(11,batch_size),device=device)
    tgt_input = tgt[:-1, :]
    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
    # logits = transformer(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
    logits = transformer(src, tgt_input, src_mask, tgt_mask)
    tgt_out = tgt[1:, :]
    loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
    break

# [32, 1, 32]) torch.Size([32, 8, 64, 64]
# selfatt torch.Size([batch_size, 1, batch_size]) torch.Size([batch_size, 8, 13, 13]) # src batch_size,13
# selfatt torch.Size([13, 1, 13]) torch.Size([13, 8, batch_size, batch_size]) # src 13, batch_size




torch.Size([25, 47]) torch.Size([24, 47])
selfatt torch.Size([47, 1, 47]) torch.Size([47, 8, 13, 13])


RuntimeError: ignored

In [18]:
# @title train eval

# subsequent word mask that will prevent the model from looking into the future words when making predictions.
# also need masks to hide source and target padding token
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    # mask = (torch.tril(torch.ones((sz, sz), device=device)) == 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

# def train_epoch(dataloader, model, loss_fn, optimizer, scheduler=None, verbose=True):
def train_epoch(dataloader, model, loss_fn, optimizer):
    model.train()
    losses = 0
    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        optimizer.zero_grad()
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()
        optimizer.step()
        losses += loss.item()
    return losses / len(list(dataloader))

# def evaluate(dataloader, model, loss_fn, verbose=True):
def evaluate(dataloader, model, loss_fn):
    model.eval()
    losses = 0
    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
    return losses / len(list(dataloader))


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)
    # print("greedy_decode",src.shape, src_mask.shape) # [11, 1][11, 11]
    memory = model.encoder(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len-1):
        memory = memory.to(device)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys

# actual function to translate input sentence into target language
def translate(model, src_sentence):
    model.eval()
    src = de_transform(src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(en_vocab.lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")


print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))


selfatt torch.Size([11, 1, 11]) torch.Size([11, 8, 1, 1])


RuntimeError: ignored

In [13]:
# @title wwwwwwwwwwwwww

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) # lr=0.0001

import time
epochs = 18

for epoch in range(epochs):
    start_time = time.time()
    train_loss = train_epoch(train_loader, transformer, loss_fn, optimizer)
    end_time = time.time()
    val_loss = evaluate(val_loader, transformer, loss_fn)
    print((f"Epoch: {epoch+1}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))
    print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))




TypeError: ignored

In [None]:
# @title inference

print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))
