In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("..")

In [18]:
import torch

In [83]:
from recibrew.data_util import construct_torchtext_iterator
train_csv = '../data/processed/train.csv'
dev_csv = '../data/processed/dev.csv'
test_csv = '../data/processed/test.csv'
constructed_ttext = construct_torchtext_iterator(train_csv, dev_csv, test_csv, device='cpu', fix_length=None)

In [102]:
constructed_ttext.keys()

dict_keys(['train_iter', 'val_iter', 'test_iter', 'src_field', 'tgt_field'])

In [84]:
train_iter = constructed_ttext['train_iter']

In [85]:
src_field = constructed_ttext['src_field']
tgt_field = constructed_ttext['tgt_field']

In [86]:
max_vocab = len(src_field.vocab)
max_vocab

3004

In [87]:
btch = next(train_iter.__iter__())

In [88]:
src, tgt = btch.src, btch.tgt

In [89]:
from torch.nn import Transformer, Embedding, Dropout, Module

In [90]:
trfm = Transformer(d_model=128, dim_feedforward=512, num_encoder_layers=4, num_decoder_layers=4, dropout=0.3)

In [91]:
num_embedding = 128
dropout = 0.2
max_len = 140

In [92]:
import math

In [93]:
class PositionalEncoding(Module):
    """Add fixed sinusoidal position signals to embeddings, then apply dropout.

    Even feature columns receive ``sin(pos * w_i)`` and odd columns the matching
    ``cos(pos * w_i)``, with ``w_i = 10000^(-2i / d_model)``.

    Args:
        d_model: Embedding width of the inputs (odd widths are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest sequence (size of the input's first dimension) that can
            be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column, so trim
        # the frequency vector to fit; for even d_model this slice keeps everything.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # [max_len, 1, d_model] so it broadcasts over the batch dim of a
        # sequence-first input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Registered as a buffer: follows .to(device) and state_dict, not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Tensor laid out sequence-first with ``d_model`` as the last dim,
               e.g. ``[seq_len, batch_size, d_model]``.

        Raises:
            ValueError: If ``x`` has more positions than ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                "sequence length {} exceeds positional-encoding max_len {}".format(seq_len, self.pe.size(0))
            )
        x = x + self.pe[:seq_len, :]
        return self.dropout(x)

In [94]:
out_emb_dec = pos_embedding(inp_embedding(tgt))

In [95]:
out_emb_enc.shape

torch.Size([128, 64, 128])

In [96]:
out_emb_dec.shape

torch.Size([121, 64, 128])

In [97]:
class FullTransformer(Module):
    """Token-level encoder-decoder built on ``torch.nn.Transformer``.

    Source and target token ids share one embedding table and one sinusoidal
    positional encoding. ``forward`` returns the decoder output of width
    ``num_embedding``; no projection onto the vocabulary happens here.

    Args:
        num_embedding: Model width (``d_model``), also the embedding size.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        num_encoder_layer: Number of encoder layers.
        num_decoder_layer: Number of decoder layers.
        dropout: Dropout rate for the positional encoding and the transformer.
        padding_idx: Token id treated as padding; passed to the embedding and
            masked out of attention via key-padding masks.
        vocab_size: Rows in the shared embedding table. ``None`` (default) falls
            back to the notebook-level ``max_vocab``, as before.
        max_seq_len: Longest sequence the positional encoding accepts. ``None``
            (default) falls back to the notebook-level ``max_len``, as before.
    """

    def __init__(self, num_embedding=128, dim_feedforward=512, num_encoder_layer=4, num_decoder_layer=4,
                 dropout=0.3, padding_idx=1, vocab_size=None, max_seq_len=None):
        super(FullTransformer, self).__init__()

        # Resolve notebook globals lazily so the class can also be built with
        # explicit sizes (e.g. for a different vocabulary).
        if vocab_size is None:
            vocab_size = max_vocab
        if max_seq_len is None:
            max_seq_len = max_len

        self.padding_idx = padding_idx

        # [ ids : seq_len, batch_size ] -> [ seq_len, batch_size, num_embedding ]
        self.inp_embedding = Embedding(vocab_size, num_embedding, padding_idx=padding_idx)

        # Adds position information; shape is unchanged.
        self.pos_embedding = PositionalEncoding(num_embedding, dropout, max_len=max_seq_len)

        self.trfm = Transformer(d_model=num_embedding, dim_feedforward=dim_feedforward,
                                num_encoder_layers=num_encoder_layer, num_decoder_layers=num_decoder_layer,
                                dropout=dropout)

    def make_pad_mask(self, inp):
        """Build a key-padding mask for ``nn.Transformer``.

        Args:
            inp: Token ids laid out ``[seq_len, batch_size]``.

        Returns:
            Bool tensor ``[batch_size, seq_len]``, ``True`` where ``inp`` equals
            ``self.padding_idx``; those positions are not attended to.
        """
        return (inp == self.padding_idx).transpose(0, 1)

    def forward(self, src, tgt):
        """Encode ``src`` and decode ``tgt`` with causal and padding masks.

        Args:
            src: Source token ids ``[src_len, batch_size]``.
            tgt: Decoder-input token ids ``[tgt_len, batch_size]`` (in this
                notebook the gold sequence minus its last step, ``tgt[:-1]``).

        Returns:
            Decoder output ``[tgt_len, batch_size, num_embedding]``.
        """
        # Square subsequent (causal) mask [tgt_len, tgt_len], taken from this
        # module's own transformer and placed on the input's device.
        tgt_mask = self.trfm.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        # Key-padding masks [batch_size, seq_len]; True marks padding to ignore.
        src_pad_mask = self.make_pad_mask(src)
        tgt_pad_mask = self.make_pad_mask(tgt)

        # [ src : src_len, batch_size, num_embedding ]
        out_emb_enc = self.pos_embedding(self.inp_embedding(src))

        # [ tgt : tgt_len, batch_size, num_embedding ]
        out_emb_dec = self.pos_embedding(self.inp_embedding(tgt))

        # memory_key_padding_mask reuses the source mask so decoder cross-attention
        # also skips padded source positions.
        out_trf = self.trfm(out_emb_enc, out_emb_dec, src_mask=None, tgt_mask=tgt_mask, memory_mask=None,
                            src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask,
                            memory_key_padding_mask=src_pad_mask)

        # [ out_trf : tgt_len, batch_size, num_embedding ]
        return out_trf

In [63]:
ft = FullTransformer()

In [64]:
# [ x : seq_len, batch_size, num_embedding ]

out_emb_enc = ft.pos_embedding(ft.inp_embedding(src))
out_emb_dec = ft.pos_embedding(ft.inp_embedding(tgt))

In [74]:
tgt[:-1]

tensor([[  2,   2,   2,  ...,   2,   2,   2],
        [  5,   6,   5,  ...,   6,   5,   5],
        [ 16, 225, 114,  ..., 193,  16,  93],
        ...,
        [  1,   1,   1,  ...,   1,   1,   1],
        [  1,   1,   1,  ...,   1,   1,   1],
        [  1,   1,   1,  ...,   1,   1,   1]])

In [105]:
output_trf = ft(src,tgt[:-1])

In [106]:
gold_truth = tgt[1:,:]

In [78]:
trg_mask = ft.trfm.generate_square_subsequent_mask(len(tgt))  # Mask for generator 

In [79]:
trg_mask.shape

torch.Size([128, 128])

In [38]:
ft.forward(src, tgt).shape

torch.Size([128, 64, 128])

In [28]:
trfm.forward(out_emb_enc, out_emb_dec)

tensor([[[-0.6885,  0.3914, -0.1431,  ..., -2.0010, -2.1248,  1.2258],
         [-0.8926,  0.4760,  0.0989,  ..., -1.8607, -1.8448, -0.6924],
         [-0.7496,  1.6615,  0.3059,  ..., -1.6595, -2.5151, -1.3831],
         ...,
         [-1.8085,  1.0539, -0.6407,  ..., -1.0351, -1.2879, -1.1842],
         [-0.5450,  1.0887, -0.3561,  ..., -0.5517, -2.0140, -1.0108],
         [-1.0373,  1.8606,  1.2161,  ..., -0.6123, -2.4126, -0.0155]],

        [[-0.5035,  0.7363,  1.6802,  ..., -1.0266, -1.4050, -0.6545],
         [-0.6525, -0.1513, -0.5397,  ..., -0.1573, -2.2122, -0.4740],
         [-1.3595,  0.4303, -0.4063,  ..., -0.6392, -0.9940, -1.8435],
         ...,
         [-0.8114, -1.2614, -0.2073,  ..., -1.0146, -0.2468,  0.0290],
         [-2.1922,  0.4243,  1.3759,  ..., -0.9851, -1.4986, -1.1615],
         [-0.9675, -1.4978, -1.0785,  ..., -0.3849, -1.1353, -0.1319]],

        [[-0.2701,  0.3993,  0.7897,  ..., -2.2512, -0.6533,  0.2493],
         [-2.8280,  0.7702,  0.4581,  ...,  0

In [120]:
linear_to_out(reformer_out[0]).shape

torch.Size([64, 128, 3004])

# Combine the pieces into a single module

In [127]:
from recibrew.nn.transformers import FullTransformer

In [128]:
ft = FullTransformer(max_vocab)

In [132]:
out_trf = ft.forward(src, tgt[:-1])

In [133]:
out_trf.shape

torch.Size([120, 64, 3004])

## Calculate Loss

In [134]:
from torch.nn import functional as F

In [135]:
out_trf.size()

torch.Size([120, 64, 3004])

In [136]:
tgt[1:].shape

torch.Size([120, 64])

In [145]:
tgt[1:].view(-1).shape

torch.Size([7680])

In [138]:
out_trf.view(-1, output_dim).shape

torch.Size([180240, 128])

In [142]:
out_trf.shape

torch.Size([120, 64, 3004])

In [143]:
output_dim = out_trf.shape[-1]
loss = torch.nn.CrossEntropyLoss()

In [144]:
loss(out_trf.view(-1, output_dim), tgt[1:,:].view(-1))

tensor(8.1723, grad_fn=<NllLossBackward>)

In [120]:
F.nll_loss(out_trf, tgt[1:])

ValueError: Expected target size (120, 128), got torch.Size([120, 64])