In [None]:
import os
import numpy as np
import time
from pathlib import Path
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim import SGD
import torchtext
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
from torchtext.data.utils import get_tokenizer
from torchtext import data
from torchtext.data.metrics import bleu_score
import spacy
from spacy.symbols import ORTH
import math
import random
import tqdm.notebook as tq
import copy

In [None]:
class LSTMCell(nn.Module):
    """One time step of an LSTM, implemented with two fused linear maps.

    Both ``linear_input`` and ``linear_hidden`` produce ``4 * hidden_size``
    features; their sum is split along dim 1 into the forget, input,
    candidate and output pre-activations (in that order).
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Each projection yields the pre-activations of all four gates at once.
        self.linear_input = nn.Linear(input_size, 4 * hidden_size)
        self.linear_hidden = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        """Advance the cell by one step.

        Args:
            x: input for this step; dim 1 must be the feature dim because the
                gate split happens along dim 1.
            state: pair ``(h, c)`` holding the previous hidden and cell state.

        Returns:
            ``(h_next, (h_next, c_next))`` - the new hidden state doubles as
            the step output.
        """
        prev_h, prev_c = state[0], state[1]

        fused = self.linear_hidden(prev_h) + self.linear_input(x)
        forget_pre, input_pre, cand_pre, output_pre = fused.chunk(4, dim=1)

        forget_gate = torch.sigmoid(forget_pre)
        input_gate = torch.sigmoid(input_pre)
        candidate = torch.tanh(cand_pre)
        output_gate = torch.sigmoid(output_pre)

        # Keep part of the old memory and blend in the gated candidate.
        next_c = prev_c * forget_gate + input_gate * candidate
        next_h = torch.tanh(next_c) * output_gate

        return next_h, (next_h, next_c)

In [None]:
class LSTMLayer(nn.Module):
    """Unrolls a single ``LSTMCell`` over every time step of a sequence.

    All positional constructor arguments are forwarded unchanged to
    ``LSTMCell``.
    """
    def __init__(self,*cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = LSTMCell(*cell_args)

    def forward(self, x, state, length_x=None):
        # DO NOT MODIFY
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        """Run the cell over ``x`` one step at a time.

        Args:
            x: input sequence; it is unbound along dim 0, so dim 0 is the
                step axis and each slice is fed to the cell as one step.
            state: ``(h, c)`` pair passed to the cell at the first step.
            length_x: optional 1-D tensor of per-row sequence lengths. It must
                already be sorted in descending order (asserted below).

        Returns:
            ``(outputs, state)`` where ``outputs`` stacks the per-step cell
            outputs along a new dim 0. Without ``length_x`` the state is the
            one after the last step. With ``length_x`` it is rebuilt from the
            rows captured at the step where ``i + 1`` equals each row's length.
        """
        inputs = x.unbind(0)
        assert (length_x is None) or torch.all(length_x == length_x.sort(descending=True)[0])
        outputs = [] 
        out_hidden_state = []
        out_cell_state = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i] , state)
            outputs += [out] 
            if length_x is not None:
                # Rows whose sequence ends at this step: snapshot their state now,
                # before later (padding) steps overwrite it.
                if torch.any(i+1 == length_x):
                    # Prepending puts later-finishing (longer) rows first, which
                    # matches the descending-length row order asserted above.
                    out_hidden_state = [state[0][i+1==length_x]] + out_hidden_state
                    out_cell_state = [state[1][i+1==length_x]] + out_cell_state
        if length_x is not None:
            # NOTE(review): a row whose length is 0 or exceeds the number of steps
            # is never captured, so the result would have fewer rows - confirm
            # callers never pass such lengths.
            state = (torch.cat(out_hidden_state, dim=0), torch.cat(out_cell_state, dim=0))
        return torch.stack(outputs), state 

In [None]:
class LSTM(nn.Module):
    """Stack of ``LSTMLayer`` modules with dropout between consecutive layers.

    The first layer maps ``ninp`` features to ``nhid``; every further layer
    maps ``nhid`` to ``nhid``.
    """

    def __init__(self, ninp, nhid, num_layers, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        stack = [
            LSTMLayer(ninp if depth == 0 else nhid, nhid)
            for depth in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, states, length_x=None):
        """Feed ``x`` through all layers.

        Args:
            x: input handed to the first layer.
            states: indexable collection with one initial state per layer.
            length_x: optional lengths forwarded to every layer.

        Returns:
            ``(output, output_states)`` - the last layer's output (no dropout
            applied to it) and the list of per-layer final states.
        """
        top = len(self.layers) - 1
        final_states = []
        hidden = x

        for depth, layer in enumerate(self.layers):
            hidden, layer_state = layer(hidden, states[depth], length_x=length_x)
            # Dropout only sits between layers, never after the topmost one.
            if depth < top:
                hidden = self.dropout(hidden)
            final_states.append(layer_state)

        return hidden, final_states

In [None]:
class LSTMEncoder(nn.Module):
    """Source-side encoder: embedding -> dropout -> stacked LSTM.

    Sizes are read from the notebook-level ``args``, ``src_ntoken`` and
    ``pad_id`` globals.
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(src_ntoken, args.ninp, padding_idx=pad_id)
        self.dropout = nn.Dropout(args.dropout)
        self.lstm = LSTM(args.ninp, args.nhid, args.nlayers, args.dropout)

    def forward(self, x, states, length_x=None):
        """Encode token ids ``x``.

        Args:
            x: token-id tensor looked up in the embedding table.
            states: initial per-layer states for the LSTM stack.
            length_x: optional sequence lengths forwarded to the LSTM.

        Returns:
            ``(output, context_vectors)`` exactly as produced by the LSTM
            stack: its output and its list of final states.
        """
        embedded = self.dropout(self.embed(x))
        encoded, final_states = self.lstm(embedded, states, length_x)
        return encoded, final_states

In [None]:
class LSTMDecoder(nn.Module):
    """Target-side decoder: embedding -> dropout -> stacked LSTM -> vocabulary logits.

    The output projection shares its weight matrix with the embedding table.
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(trg_ntoken, args.ninp, padding_idx=pad_id)
        self.lstm = LSTM(args.ninp, args.nhid, args.nlayers, args.dropout)
        self.fc_out = nn.Linear(args.nhid, trg_ntoken)
        self.dropout = nn.Dropout(args.dropout)
        # Weight tying: the shared matrix has shape (trg_ntoken, ninp), so the
        # projection only accepts LSTM outputs when ninp equals nhid.
        self.fc_out.weight = self.embed.weight

    def forward(self, x, states):
        """Decode token ids ``x`` starting from ``states``.

        Returns:
            ``(logits, output_states)`` - vocabulary scores for every input
            position and the LSTM stack's final states.
        """
        embedded = self.dropout(self.embed(x))
        hidden, output_states = self.lstm(embedded, states)
        logits = self.fc_out(hidden)
        return logits, output_states

In [None]:
class LSTMSeq2Seq(nn.Module):
    """Encoder-decoder LSTM; the decoder starts from the encoder's final states."""

    def __init__(self):
        super().__init__()
        self.encoder = LSTMEncoder()
        self.decoder = LSTMDecoder()

    def _get_init_states(self, x):
        """Return one all-zero ``(h, c)`` pair per layer, sized by ``x.size(1)``."""
        batch = x.size(1)

        def blank():
            return torch.zeros((batch, args.nhid)).to(x.device)

        return [(blank(), blank()) for _ in range(args.nlayers)]

    def forward(self, x, y, length, max_len=None, teacher_forcing=True):
        """Encode ``x`` and decode step by step.

        Args:
            x: source token ids; dim 1 is used as the batch size.
            y: target token ids; ``y[0:1]`` is the first decoder input and,
                with teacher forcing, ``y[t:t+1]`` is the input at step ``t``.
            length: source lengths forwarded to the encoder.
            max_len: number of target positions to cover; defaults to
                ``y.size(0)``.
            teacher_forcing: feed gold tokens when true, otherwise feed the
                argmax of the previous step's logits.

        Returns:
            Logits of all ``trg_len - 1`` decoding steps concatenated on dim 0.
        """
        _, states = self.encoder(x, self._get_init_states(x), length)
        trg_len = y.size(0) if max_len is None else max_len

        # The first step is always driven by the first target row.
        step_logits, states = self.decoder(y[0:1], states)
        collected = [step_logits]

        for t in range(1, trg_len - 1):
            if teacher_forcing:
                step_input = y[t:t + 1]
            else:
                step_input = step_logits.argmax(-1)
            step_logits, states = self.decoder(step_input, states)
            collected.append(step_logits)

        return torch.cat(collected)

In [None]:
class Attention(nn.Module):
    """Additive attention over encoder outputs.

    Scores are ``W3(tanh(W1(enc_o) + W2(dec_h)))``, normalised with a softmax
    over the encoder step axis (dim 0). The resulting context vector is
    concatenated to the decoder input ``x`` on the last dim.
    """

    def __init__(self):
        super().__init__()

        self.nhid_enc = args.nhid
        self.nhid_dec = args.nhid
        self.W1 = nn.Linear(self.nhid_enc, args.nhid_attn)
        self.W2 = nn.Linear(self.nhid_dec, args.nhid_attn)
        self.W3 = nn.Linear(args.nhid_attn, 1)

    def forward(self, x, enc_o, dec_h, length_enc=None):
        """Attend over ``enc_o`` and append the context to ``x``.

        Args:
            x: decoder input of shape ``(1, B, ninp)``.
            enc_o: encoder outputs of shape ``(L, B, nhid)``.
            dec_h: decoder hidden state of shape ``(B, nhid)``; it broadcasts
                over the ``L`` axis of ``enc_o``.
            length_enc: optional ``(B,)`` source lengths. Encoder positions
                ``>= length`` are padding and get zero attention weight.

        Returns:
            Tensor of shape ``(1, B, ninp + nhid)``.
        """
        energy = torch.tanh(self.W1(enc_o) + self.W2(dec_h))
        # Drop only the trailing singleton from W3. A bare squeeze() would also
        # collapse the batch (or step) axis whenever its size is 1.
        scores = self.W3(energy).squeeze(-1)  # (L, B)

        if length_enc is not None:
            steps = scores.size(0)
            # Build the mask on the scores' own device instead of a global one.
            positions = torch.arange(steps, device=scores.device).unsqueeze(1)
            lengths = torch.as_tensor(length_enc, device=scores.device).unsqueeze(0)
            # Valid positions are 0 .. length-1, so position == length is
            # already padding and must be masked too.
            scores = scores.masked_fill(positions >= lengths, float("-inf"))

        weights = F.softmax(scores, dim=0).unsqueeze(-1)  # (L, B, 1)
        context = torch.sum(weights * enc_o, dim=0, keepdim=True)  # (1, B, nhid)
        return torch.cat((x, context), dim=-1)

In [None]:
class LSTMAttnDecoder(nn.Module):
    """Decoder whose LSTM input is the token embedding joined with an attention context.

    The output projection shares its weight matrix with the embedding table.
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(trg_ntoken, args.ninp, padding_idx=pad_id)
        # The LSTM consumes the embedding plus a context vector of size nhid.
        self.lstm = LSTM(args.ninp + args.nhid, args.nhid, args.nlayers, args.dropout)
        self.fc_out = nn.Linear(args.nhid, trg_ntoken)
        self.dropout = nn.Dropout(args.dropout)
        self.attn = Attention()
        self.fc_out.weight = self.embed.weight

    def forward(self, x, enc_o, states, length_enc=None):
        """Run one decoding call.

        Args:
            x: target token ids.
            enc_o: encoder outputs handed to the attention module.
            states: per-layer ``(h, c)`` pairs; the hidden state of the last
                entry is used as the attention query.
            length_enc: optional source lengths forwarded to the attention.

        Returns:
            ``(logits, output_states)``.
        """
        embedded = self.dropout(self.embed(x))
        top_hidden, _ = states[-1]

        attended = self.attn(embedded, enc_o, top_hidden, length_enc)
        hidden, output_states = self.lstm(attended, states)
        return self.fc_out(hidden), output_states

In [None]:
class LSTMAttnSeq2Seq(nn.Module):
    """Encoder-decoder LSTM in which the decoder attends over all encoder outputs."""

    def __init__(self):
        super().__init__()
        self.encoder = LSTMEncoder()
        self.decoder = LSTMAttnDecoder()

    def _get_init_states(self, x):
        """Return one all-zero ``(h, c)`` pair per layer, sized by ``x.size(1)``."""
        batch = x.size(1)

        def blank():
            return torch.zeros((batch, args.nhid)).to(x.device)

        return [(blank(), blank()) for _ in range(args.nlayers)]

    def forward(self, x, y, length, max_len=None, teacher_forcing=True):
        """Encode ``x`` and decode step by step with attention.

        Args:
            x: source token ids; dim 1 is used as the batch size.
            y: target token ids; ``y[0:1]`` is the first decoder input and,
                with teacher forcing, ``y[t:t+1]`` is the input at step ``t``.
            length: source lengths, used by the encoder and by the attention.
            max_len: number of target positions to cover; defaults to
                ``y.size(0)``.
            teacher_forcing: feed gold tokens when true, otherwise feed the
                argmax of the previous step's logits.

        Returns:
            Logits of all ``trg_len - 1`` decoding steps concatenated on dim 0.
        """
        zero_states = self._get_init_states(x)
        enc_output, _ = self.encoder(x, zero_states, length)
        trg_len = y.size(0) if max_len is None else max_len

        # NOTE(review): the decoder is seeded with the zero states, not with the
        # encoder's final states (those are discarded above) - confirm intended.
        step_logits, states = self.decoder(y[0:1], enc_output, zero_states, length)
        collected = [step_logits]

        for t in range(1, trg_len - 1):
            if teacher_forcing:
                step_input = y[t:t + 1]
            else:
                step_input = step_logits.argmax(-1)
            step_logits, states = self.decoder(step_input, enc_output, states, length)
            collected.append(step_logits)

        return torch.cat(collected)

In [None]:
MAX_LEN = 100  # side length of the pre-computed causal mask buffer


class MaskedMultiheadAttention(nn.Module):
    """
    A vanilla multi-head masked attention layer with a projection at the end.

    Inputs are batch-first: ``(B, T, nhid_tran)``.
    """
    def __init__(self, mask=False):
        """
        Args:
            mask: when true, register a lower-triangular ``mask`` buffer so
                that every query position only attends to itself and earlier
                key positions (causal self-attention).
        """
        super(MaskedMultiheadAttention, self).__init__()
        assert args.nhid_tran % args.nhead == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(args.nhid_tran, args.nhid_tran)
        self.query = nn.Linear(args.nhid_tran, args.nhid_tran)
        self.value = nn.Linear(args.nhid_tran, args.nhid_tran)
        # regularization
        self.attn_drop = nn.Dropout(args.attn_pdrop)
        # output projection
        self.proj = nn.Linear(args.nhid_tran, args.nhid_tran)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        if mask:
            self.register_buffer("mask", torch.tril(torch.ones(MAX_LEN, MAX_LEN)))
        self.nhead = args.nhead
        self.d_k = args.nhid_tran // args.nhead

    def _split_heads(self, projected):
        """Reshape ``(B, T, nhead * d_k)`` into ``(B, nhead, T, d_k)``."""
        batch, steps = projected.size(0), projected.size(1)
        return projected.reshape(batch, steps, self.nhead, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Scaled dot-product attention over several heads.

        Args:
            q: queries, ``(B, T_q, nhid_tran)``.
            k: keys, ``(B, T, nhid_tran)``.
            v: values, ``(B, T, nhid_tran)``.
            mask: optional ``(B, T)`` key-padding mask; entries equal to 0
                mark key positions that must receive no attention.

        Returns:
            Tensor of shape ``(B, T_q, nhid_tran)``.
        """
        q1 = self._split_heads(self.query(q))
        k1 = self._split_heads(self.key(k))
        v1 = self._split_heads(self.value(v))

        score = torch.matmul(q1, k1.transpose(2, 3)) / math.sqrt(self.d_k)
        T_q, T = score.size(-2), score.size(-1)

        if hasattr(self, 'mask'):
            if T_q <= self.mask.size(0) and T <= self.mask.size(1):
                causal = self.mask[:T_q, :T]
            else:
                # Sequence longer than the pre-computed buffer: build it on the fly.
                causal = torch.tril(torch.ones(T_q, T, device=score.device))
            # (T_q, T) broadcasts over the batch and head axes. Indexing the 4-D
            # score with the 2-D mask would instead address those first two axes.
            score = score.masked_fill(causal == 0, float('-inf'))

        if mask is not None:
            # (B, T) -> (B, 1, 1, T): the same key mask for every head and query.
            score = score.masked_fill(mask[:, None, None, :] == 0, float('-inf'))

        weights = self.attn_drop(F.softmax(score, dim=-1))
        context = torch.matmul(weights, v1).transpose(1, 2)
        # Merge the heads back into a single feature axis.
        context = context.reshape(context.size(0), context.size(1), -1)
        return self.proj(context)

In [None]:
class TransformerEncLayer(nn.Module):
    """Encoder block: self-attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer normalises its input first and adds its (dropped-out)
    result back onto that *normalised* input.
    """

    def __init__(self):
        super().__init__()
        width = args.nhid_tran
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = MaskedMultiheadAttention()
        self.dropout1 = nn.Dropout(args.resid_pdrop)
        self.dropout2 = nn.Dropout(args.resid_pdrop)
        self.ff = nn.Sequential(
            nn.Linear(width, args.nff),
            nn.ReLU(),
            nn.Linear(args.nff, width),
        )

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is passed on to the attention."""
        # Self-attention sub-layer; the residual branch starts after ln1.
        normed = self.ln1(x)
        attended = self.dropout1(self.attn(normed, normed, normed, mask))
        hidden = attended + normed

        # Feed-forward sub-layer; the residual branch starts after ln2.
        normed = self.ln2(hidden)
        transformed = self.dropout2(self.ff(normed))
        return transformed + normed


In [None]:
class TransformerDecLayer(nn.Module):
    """Decoder block: causal self-attention, attention over the encoder, feed-forward.

    Each sub-layer normalises its input first and adds its (dropped-out)
    result back onto that *normalised* input.
    """

    def __init__(self):
        super().__init__()
        width = args.nhid_tran
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.ln3 = nn.LayerNorm(width)
        self.dropout1 = nn.Dropout(args.resid_pdrop)
        self.dropout2 = nn.Dropout(args.resid_pdrop)
        self.dropout3 = nn.Dropout(args.resid_pdrop)
        self.attn1 = MaskedMultiheadAttention(mask=True)  # self-attention
        self.attn2 = MaskedMultiheadAttention()  # tgt to src attention
        self.ff = nn.Sequential(
            nn.Linear(width, args.nff),
            nn.ReLU(),
            nn.Linear(args.nff, width),
        )

    def forward(self, x, enc_o, enc_mask=None):
        """Apply the block.

        Args:
            x: decoder-side activations.
            enc_o: encoder outputs, used as keys and values of the second
                attention.
            enc_mask: optional mask forwarded to the second attention.
        """
        # 1) self-attention over the target side
        normed = self.ln1(x)
        hidden = self.dropout1(self.attn1(normed, normed, normed)) + normed

        # 2) attention from the target side onto the encoder outputs
        normed = self.ln2(hidden)
        hidden = self.dropout2(self.attn2(normed, enc_o, enc_o, enc_mask)) + normed

        # 3) position-wise feed-forward network
        normed = self.ln3(hidden)
        return self.dropout3(self.ff(normed)) + normed

In [None]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to its input.

    The table has ``max_len`` rows and ``args.nhid_tran`` columns; even
    columns hold sines and odd columns hold cosines of ``pos / 10000**(2i/dim)``.
    """

    def __init__(self, max_len=4096):
        super().__init__()
        dim = args.nhid_tran
        positions = np.arange(0, max_len)[:, None]
        pair_index = np.arange(0, dim // 2)
        angles = positions / (10000 ** (2 * pair_index / dim))

        table = np.zeros([max_len, dim])
        table[:, 0::2] = np.sin(angles)
        table[:, 1::2] = np.cos(angles)

        # Stored as a buffer: follows the module across devices, is not trained.
        self.register_buffer('pe', torch.from_numpy(table).float())

    def forward(self, x):
        # DO NOT MODIFY
        return x + self.pe[:x.shape[1]]

class TransformerEncoder(nn.Module):
    """Source-side Transformer: embedding + positions, a stack of encoder layers, final LayerNorm."""

    def __init__(self):
        super().__init__()
        # input embedding stem
        self.tok_emb = nn.Embedding(src_ntoken, args.nhid_tran)
        self.pos_enc = PositionalEncoding()
        self.dropout = nn.Dropout(args.embd_pdrop)
        # transformer
        blocks = [TransformerEncLayer() for _ in range(args.nlayers_transformer)]
        self.transform = nn.ModuleList(blocks)
        # final normalisation applied to the last block's output
        self.ln_f = nn.LayerNorm(args.nhid_tran)

    def forward(self, x, mask):
        """Encode token ids ``x``; ``mask`` is handed unchanged to every layer."""
        hidden = self.dropout(self.pos_enc(self.tok_emb(x)))
        for block in self.transform:
            hidden = block(hidden, mask)
        return self.ln_f(hidden)

In [None]:
class TransformerDecoder(nn.Module):
    """Target-side Transformer producing vocabulary logits.

    The output projection shares its weight matrix with the token embedding.
    """

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(trg_ntoken, args.nhid_tran)
        self.pos_enc = PositionalEncoding()
        self.dropout = nn.Dropout(args.embd_pdrop)
        blocks = [TransformerDecLayer() for _ in range(args.nlayers_transformer)]
        self.transform = nn.ModuleList(blocks)
        self.ln_f = nn.LayerNorm(args.nhid_tran)
        self.lin_out = nn.Linear(args.nhid_tran, trg_ntoken)
        # Weight tying between the input embedding and the output projection.
        self.lin_out.weight = self.tok_emb.weight

    def forward(self, x, enc_o, enc_mask):
        """Decode token ids ``x`` against encoder outputs ``enc_o``.

        ``enc_o`` and ``enc_mask`` are passed unchanged to every decoder
        layer. Returns logits divided by ``sqrt(args.nhid_tran)``.
        """
        hidden = self.dropout(self.pos_enc(self.tok_emb(x)))
        for block in self.transform:
            hidden = block(hidden, enc_o, enc_mask)

        logits = self.lin_out(self.ln_f(hidden))
        logits /= args.nhid_tran ** 0.5 # Scaling logits. Do not modify this
        return logits

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer operating on batch-first token tensors."""

    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder()
        self.decoder = TransformerDecoder()

    def forward(self, x, y, length_x, max_len=None, teacher_forcing=True):
        """Translate ``x``, either in one teacher-forced pass or greedily.

        Args:
            x: source token ids, ``(B, T)``.
            y: target token ids, ``(B, T_y)``; ``y[:, :1]`` seeds greedy decoding.
            length_x: optional ``(B,)`` source lengths. Source positions
                ``>= length`` are masked out. ``None`` disables masking.
            max_len: greedy decoding only - target length to decode up to,
                including the seed token. Defaults to ``y.size(1)``.
            teacher_forcing: when true (or whenever the module is in training
                mode) the decoder consumes ``y[:, :-1]`` in a single pass.

        Returns:
            Logits of shape ``(B, T_y - 1, V)`` with teacher forcing, or
            ``(B, max_len - 1, V)`` for greedy decoding.

        Raises:
            ValueError: greedy decoding was requested with ``max_len < 2``.
        """
        enc_mask = None
        if length_x is not None:
            # Build the mask on the input's own device instead of a global one.
            positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
            lengths = torch.as_tensor(length_x, device=x.device).reshape(-1, 1)
            # 1.0 for real tokens, 0.0 for padding; shape (B, T).
            enc_mask = (positions < lengths).float()

        enc_o = self.encoder(x, enc_mask)
        if teacher_forcing is True or self.training is True:
            return self.decoder(y[:, :-1], enc_o, enc_mask)

        if max_len is None:
            max_len = y.size(1)
        if max_len < 2:
            raise ValueError("max_len must be at least 2 for greedy decoding")

        dec_input = y[:, :1]
        for _ in range(1, max_len):
            # Re-decode the whole prefix and append the most likely next token.
            dec_output = self.decoder(dec_input, enc_o, enc_mask)
            next_token = dec_output[:, -1:].argmax(-1)
            dec_input = torch.cat((dec_input, next_token), dim=1)

        return dec_output