In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
# import seaborn
# seaborn.set_context(context="talk")
%matplotlib inline

### Encoder Decoder Class 
*STANDARD ENCODER DECODER ARCHITECTURE (existed before this work)*

---

(src, src_mask) --> ENCODER --> (TBD)

(tgt_embed_input, memory, src_mask, tgt_mask) --> DECODER -->  (TBD)

In [3]:
class EncoderDecoder(nn.Module):
    """Generic encoder-decoder wrapper.

    Holds an encoder, a decoder, the source/target embedding callables and a
    generator. `forward` chains `encode` and `decode`; the generator is only
    stored here and is not applied inside `forward`.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator) -> None:
        """Store the five sub-components as attributes.

        Args:
            encoder: callable invoked as `encoder(src_embed(src), src_mask)`.
            decoder: callable invoked as
                `decoder(tgt_embed(tgt), memory, src_mask, tgt_mask)`.
            src_embed: callable applied to `src` before encoding.
            tgt_embed: callable applied to `tgt` before decoding.
            generator: stored on the instance; not called by this class.
        """
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode `src`, then decode `tgt` against the encoder output.

        Returns:
            Whatever `self.decoder` returns for the embedded target.
        """
        memory = self.encode(src, src_mask)
        # The result must be returned; otherwise calling the model yields None.
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed `src` and run it through the encoder with `src_mask`."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed `tgt` and run the decoder over it, attending to `memory`."""
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

#### LOG SOFTMAX INTERPRETATION

Information theoretic interpretation: Log softmax can be seen as a measure of the information content or surprise of a given output. The lower the log softmax value, the more surprising or unlikely the output is. This can help the model to penalize incorrect predictions more heavily and reward correct predictions more appropriately

logits -> softmax (maps them to values between 0 and 1, i.e. probabilities) -> log_softmax (the lower the probability, the more negative the resulting log-probability)

#### GENERATOR
maps the output of the model decoder to vocab words




In [4]:
class Generator(nn.Module):
    """Linear projection followed by log-softmax over the last dimension."""

    def __init__(self, d_model, vocab):
        """Create the projection layer.

        Args:
            d_model: size of the last dimension of the input.
            vocab: size of the last dimension of the output.
        """
        super(Generator, self).__init__()
        # `nn.Linear` (not `nn.Liner`): the misspelled name does not exist in torch.nn.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities: `log_softmax(proj(x))` along the last axis."""
        return F.log_softmax(self.proj(x), dim=-1)

### clones
It takes a layer and makes `N` independent (deep) copies of it.



In [5]:
def clones(module, N):
    """Return an `nn.ModuleList` holding `N` deep copies of `module`."""
    replicas = []
    for _ in range(N):
        # deepcopy so every replica owns its own parameters
        replicas.append(copy.deepcopy(module))
    return nn.ModuleList(replicas)

## LayerNorm 
Features -> output of a layer

a2 -> Parameters -- sizeof(Features), init with 1 (multiplicative gain)

b2 -> Parameters -- sizeof(Features), init with 0 (additive bias)

output 
$$lnorm(X) = \frac{a2 \times (X - \mu(X))}{\sigma{(X)} + \epsilon} + b2$$

a2, b2 -> registered as `nn.Parameter`s, so they are learnt during backpropagation

In [8]:
class LayerNorm(nn.Module):
    """Normalise the last dimension, then apply a learnable scale and shift.

    `a_2` (initialised to ones) multiplies the normalised input and `b_2`
    (initialised to zeros) is added afterwards; both have length `features`.
    """

    def __init__(self, features, eps=1e-6) -> None:
        super().__init__()
        # Registration order (a_2, then b_2) is kept as in the original.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        # Added to the standard deviation to avoid dividing by zero.
        self.eps = eps

    def forward(self, x):
        """Return `a_2 * (x - mean) / (std + eps) + b_2` over the last axis."""
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2




In [9]:
class SublayerConnection(nn.Module):
    """Residual wrapper: `x + dropout(sublayer(norm(x)))`.

    The normalisation is applied to the input of the sublayer (before it),
    and the un-normalised `x` is added back to the dropped-out result.
    """

    def __init__(self, size, dropout) -> None:
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply `sublayer` to the normalised `x` and add the residual."""
        normalised = self.norm(x)
        residual_branch = self.dropout(sublayer(normalised))
        return x + residual_branch

In [10]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block.

    Each of the two steps is wrapped in its own `SublayerConnection`
    (two clones are created in `__init__`).
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Run self-attention (query = key = value) then feed-forward on `x`."""
        attn_wrapper = self.sublayer[0]
        ff_wrapper = self.sublayer[1]

        def attend(normed):
            # Self-attention: the same tensor serves as query, key and value.
            return self.self_attn(normed, normed, normed, mask)

        attended = attn_wrapper(x, attend)
        return ff_wrapper(attended, self.feed_forward)

In [None]:
class Encoder(nn.Module):
    """Stack of `N` cloned layers followed by a final layer normalisation."""

    def __init__(self, layer, N):
        """Build the stack.

        Args:
            layer: module to replicate; must expose a `size` attribute, which
                is used as the normalised shape of the final norm.
            N: number of copies of `layer` to stack.
        """
        # `nn.Module.__init__` must actually be called before sub-modules are
        # assigned; a bare `super(Encoder, self)` expression does nothing and
        # the assignments below would raise.
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        # NOTE(review): this is torch's built-in nn.LayerNorm, not the
        # notebook's own LayerNorm class -- confirm which one is intended.
        self.norm = nn.LayerNorm(layer.size)

    def forward(self, x, mask):
        """Pass `x` (with `mask`) through every layer in order, then normalise."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)