In [None]:
import torch
import torch.nn as nn

In [None]:
# Embedding subclasses nn.Module (nn.Module is the base class)
# Learnable token embeddings, randomly initialised by nn.Embedding
class Embedding(nn.Module):
  """Token-embedding lookup table.

  Thin wrapper around ``nn.Embedding`` that maps integer token indices to
  ``dmodel``-dimensional vectors.

  Args:
    dict_size: number of rows in the lookup table.
    dmodel: width of each embedding vector.
  """

  def __init__(self, dict_size, dmodel = 512):
    super().__init__()
    self.dict_size, self.dmodel = dict_size, dmodel
    self.embeddings = nn.Embedding(num_embeddings = dict_size, embedding_dim = dmodel)

  def forward(self, x):
    """Look up the embedding vector for every index in ``x``."""
    token_vectors = self.embeddings(x)
    return token_vectors

In [None]:
# Periodic Functions
# Bounded Functions
class PositionalEncoding(nn.Module):
  """Fixed sinusoidal positional encoding.

  Builds a ``(max_len, dmodel)`` table where, for position ``pos`` and
  pair index ``i``::

      pe[pos, 2i]     = sin(pos / 10000^(2i / dmodel))
      pe[pos, 2i + 1] = cos(pos / 10000^(2i / dmodel))

  The table is stored as the buffer ``pe`` (not a trainable parameter), so
  it follows the module across ``.to(device)`` calls.

  Args:
    max_len: number of positions to precompute (longest supported sequence).
    dmodel: embedding width.
  """

  def __init__(self, max_len, dmodel = 512):
    super().__init__()
    self.max_len = max_len
    self.dmodel = dmodel

    # Build the table in a local variable: assigning it to self.pe first
    # would make register_buffer('pe', ...) fail because the attribute
    # name would already be taken.
    pe = torch.zeros(max_len, dmodel)

    pos = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
    # Exponent is 2i / dmodel; without the division the frequencies
    # underflow to zero after the first few columns.
    exponent = torch.arange(0, dmodel, 2, dtype = torch.float) / dmodel
    mul_term = torch.pow(10000.0, -exponent)
    angles = pos * mul_term

    pe[:, 0::2] = torch.sin(angles)
    # For an odd dmodel there is one fewer odd column than even columns.
    pe[:, 1::2] = torch.cos(angles[:, : dmodel // 2])

    self.register_buffer('pe', pe)

  def forward(self, embedding):
    """Add positional encodings to ``embedding``.

    Args:
      embedding: tensor whose last two dimensions are (sequence, dmodel).

    Returns:
      ``embedding`` plus the first ``sequence`` rows of the table.

    Raises:
      ValueError: if the sequence is longer than ``max_len``.
    """
    seq_len = embedding.size(-2)
    if seq_len > self.max_len:
      raise ValueError(
        f"sequence length {seq_len} exceeds max_len {self.max_len}")
    # Slice so the (seq_len, dmodel) table broadcasts over the batch.
    return embedding + self.pe[:seq_len]


In [None]:
# dk = dv = dmodel // num_heads
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Queries, keys and values are linearly projected, split into
  ``num_heads`` heads of width ``head_dim = dmodel // num_heads``,
  attended per head, concatenated and projected once more.

  Args:
    dmodel: model width; must be divisible by ``num_heads``.
    num_heads: number of attention heads.

  Raises:
    ValueError: if ``dmodel`` is not divisible by ``num_heads``.
  """

  def __init__(self, dmodel = 512, num_heads = 8):
    super().__init__()
    if dmodel % num_heads != 0:
      raise ValueError(
        f"dmodel ({dmodel}) must be divisible by num_heads ({num_heads})")
    self.dmodel = dmodel
    self.num_heads = num_heads
    self.head_dim = dmodel // num_heads
    self.softmax_layer = nn.Softmax(dim = -1)

    self.w_key = nn.Linear(dmodel, dmodel)
    self.w_query = nn.Linear(dmodel, dmodel)
    self.w_value = nn.Linear(dmodel, dmodel)

    self.output = nn.Linear(dmodel, dmodel)

  # Size of Query / Key / Value : (NB, NH, S/T, HD)
  # Returns the attention-weighted values : (NB, NH, S/T, HD)
  def attention(self, query, key, value, mask = None):
    """Scaled dot-product attention on already-split heads.

    Args:
      query, key, value: per-head tensors of shape (NB, NH, len, HD).
      mask: optional tensor broadcastable to the score matrix
        (NB, NH, query_len, key_len); positions where ``mask == 0``
        are excluded from attention.
    """
    scores = torch.matmul(query, key.transpose(-1, -2))
    # Plain Python float: no tensor allocation per call and no
    # device mismatch with the scores.
    scores = scores / (self.head_dim ** 0.5)

    # Masked positions get a very negative score so that their
    # softmax weight becomes (numerically) zero.
    if mask is not None:
      scores = scores.masked_fill(mask == 0, -1e10)

    weights = self.softmax_layer(scores)
    return torch.matmul(weights, value)

  # Size of Query / Key / Value : (NB, S/T, ED)
  def forward(self, query, key, value, mask = None):
    """Apply multi-head attention.

    Args:
      query: (NB, query_len, dmodel).
      key, value: (NB, key_len, dmodel).
      mask: optional mask forwarded to ``attention``.

    Returns:
      Tensor of shape (NB, query_len, dmodel).
    """
    batch_size = query.shape[0]
    key = self.w_key(key)
    query = self.w_query(query)
    value = self.w_value(value)

    # (NB, len, ED) -> (NB, len, NH, HD) -> (NB, NH, len, HD)
    key = key.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
    query = query.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
    value = value.reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    context = self.attention(query, key, value, mask)
    # (NB, NH, len, HD) -> (NB, len, NH, HD) -> (NB, len, ED)
    context = context.transpose(1, 2)
    context = context.reshape(batch_size, -1, self.dmodel)

    return self.output(context)



In [None]:

# FFN(x) = MAX(xW1 + b1, 0)W2 + b2
class FeedForwardNetwork(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Linear.

  Args:
    dmodel: input and output width.
    hidden_dim: width of the intermediate layer.
  """

  def __init__(self, dmodel = 512, hidden_dim = 2048):
    super().__init__()
    self.dmodel, self.hidden_dim = dmodel, hidden_dim
    self.linear1 = nn.Linear(in_features = dmodel, out_features = hidden_dim)
    self.linear2 = nn.Linear(in_features = hidden_dim, out_features = dmodel)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Expand to ``hidden_dim``, apply ReLU, project back to ``dmodel``."""
    return self.linear2(self.relu(self.linear1(x)))

In [None]:

# y(x) = LayerNorm(x + Sublayer(x))
class Sublayer(nn.Module):
  """Residual connection followed by layer normalisation.

  Args:
    dmodel: size of the last dimension that ``nn.LayerNorm`` normalises.
  """

  def __init__(self, dmodel = 512):
    super().__init__()
    self.dmodel = dmodel
    self.norm = nn.LayerNorm(dmodel)

  def forward(self, x, sublayer):
    """Return ``LayerNorm(x + sublayer)``.

    Args:
      x: the input that was fed to the sub-layer.
      sublayer: the sub-layer's output (a tensor, not a callable).
    """
    residual = x + sublayer
    return self.norm(residual)

In [None]:
class EncoderLayer(nn.Module):
  """Single encoder layer.

  Self-attention followed by a feed-forward network; the output of each
  stage is combined with that stage's input through a ``Sublayer``.

  Args:
    dmodel: model width.
    num_heads: number of attention heads.
    hidden_layer: hidden width of the feed-forward network.
  """

  def __init__(self, dmodel = 512, num_heads = 8, hidden_layer = 2048):
    super().__init__()
    self.dmodel = dmodel

    # Stage 1: self-attention.
    self.multi_head_attention = MultiHeadAttention(dmodel, num_heads)
    self.sublayer1 = Sublayer(dmodel)

    # Stage 2: feed-forward.
    self.feed_forward_network = FeedForwardNetwork(dmodel, hidden_layer)
    self.sublayer2 = Sublayer(dmodel)

  def forward(self, vector_in, src_mask = None):
    """Run the layer on ``vector_in``.

    Args:
      vector_in: input used as query, key and value of the attention.
      src_mask: optional mask handed to the attention module.
    """
    attended = self.sublayer1(
      vector_in,
      self.multi_head_attention(vector_in, vector_in, vector_in, src_mask),
    )
    return self.sublayer2(attended, self.feed_forward_network(attended))

In [None]:
class EncoderBlock(nn.Module):
  """Stack of encoder layers applied one after another.

  Args:
    num_layers: number of layers requested from ``get_clone``.
    dmodel: model width.
    num_heads: number of attention heads per layer.
    hidden_layer: hidden width of each layer's feed-forward network.
  """

  def __init__(self, num_layers = 6, dmodel = 512, num_heads = 8, hidden_layer = 2048):
    super().__init__()
    self.num_layers = num_layers
    self.dmodel = dmodel
    self.encoder_layer = EncoderLayer(dmodel, num_heads, hidden_layer)
    # NOTE(review): get_clone is not defined in the visible part of this
    # notebook; presumably it returns num_layers independent copies of the
    # prototype layer as an iterable of modules -- confirm.
    self.layers = get_clone(self.encoder_layer, num_layers)

  def forward(self, vector_in, src_mask = None):
    """Feed ``vector_in`` through every layer in order.

    Each layer receives the previous layer's output together with
    ``src_mask``. With no layers the input is returned unchanged.
    """
    # Start from the input so an empty stack returns it instead of
    # hitting an unbound local variable.
    vector_out = vector_in
    for layer in self.layers:
      vector_out = layer(vector_out, src_mask)

    return vector_out

In [None]:
class DecoderLayer(nn.Module):
  """Single decoder layer.

  Three stages, each wrapped in a ``Sublayer`` (residual + layer norm):
  masked self-attention over the decoder input, cross-attention over the
  encoder output, and a feed-forward network.

  Args:
    dmodel: model width.
    num_heads: number of attention heads.
    hidden_layer: hidden width of the feed-forward network.
  """

  def __init__(self, dmodel = 512, num_heads = 8, hidden_layer = 2048):
    super().__init__()

    self.dmodel = dmodel
    # Stage 1: self-attention over the target sequence.
    self.multi_head_attention1 = MultiHeadAttention(dmodel, num_heads)
    self.sublayer1 = Sublayer(dmodel)

    # Stage 2: cross-attention (queries from decoder, keys/values from encoder).
    self.multi_head_attention2 = MultiHeadAttention(dmodel, num_heads)
    self.sublayer2 = Sublayer(dmodel)

    # Stage 3: feed-forward.
    self.feed_forward_network = FeedForwardNetwork(dmodel, hidden_layer)
    self.sublayer3 = Sublayer(dmodel)

  def forward(self, enc_in, dec_in, target_mask = None, src_mask = None):
    """Run the layer.

    Args:
      enc_in: encoder output, used as key and value of the cross-attention.
      dec_in: decoder input, used as query/key/value of the self-attention.
      target_mask: optional mask for the self-attention scores
        (target_len x target_len); positions equal to 0 are blocked.
      src_mask: optional mask for the cross-attention scores
        (target_len x source_len); positions equal to 0 are blocked.
    """
    # The target mask belongs on the self-attention: its scores are
    # target x target, whereas cross-attention scores are target x source.
    attention_out1 = self.multi_head_attention1(dec_in, dec_in, dec_in, target_mask)
    attention_norm1 = self.sublayer1(dec_in, attention_out1)

    attention_out2 = self.multi_head_attention2(attention_norm1, enc_in, enc_in, src_mask)
    attention_norm2 = self.sublayer2(attention_norm1, attention_out2)

    ffn_out = self.feed_forward_network(attention_norm2)
    # Sublayer.forward needs both the residual input and the sub-layer output.
    ffn_norm = self.sublayer3(attention_norm2, ffn_out)

    return ffn_norm


In [None]:
class DecoderBlock(nn.Module):
  """Stack of decoder layers applied one after another.

  Args:
    num_layers: number of layers requested from ``get_clone``.
    dmodel: model width.
    num_heads: number of attention heads per layer.
    hidden_layer: hidden width of each layer's feed-forward network.
  """

  def __init__(self, num_layers = 6, dmodel = 512, num_heads = 8, hidden_layer = 2048):
    super().__init__()
    self.dmodel = dmodel
    self.num_layers = num_layers
    self.decoder_layer = DecoderLayer(dmodel, num_heads, hidden_layer)
    # NOTE(review): get_clone is not defined in the visible part of this
    # notebook; presumably it returns num_layers independent copies of the
    # prototype layer as an iterable of modules -- confirm.
    self.layers = get_clone(self.decoder_layer, num_layers)

  def forward(self, enc_in, dec_in, target_mask = None):
    """Feed ``dec_in`` through every layer in order.

    Every layer sees the same encoder output ``enc_in`` and the same
    ``target_mask``; the decoder stream is the previous layer's output.
    With no layers ``dec_in`` is returned unchanged.
    """
    # Thread the running decoder state through the stack. Passing the
    # original dec_in to every layer would discard all but the last
    # layer's work.
    vector_out = dec_in
    for layer in self.layers:
      vector_out = layer(enc_in, vector_out, target_mask)
    return vector_out

In [None]:
class DecoderOutput(nn.Module):
  """Output head: linear projection followed by softmax.

  Args:
    dmodel: width of the incoming decoder vectors.
    vocab_size: number of output classes the softmax ranges over.
      Defaults to ``dmodel`` so existing one-argument calls behave as
      before; pass the target vocabulary size to obtain a distribution
      over tokens.
  """

  def __init__(self, dmodel, vocab_size = None):
    super().__init__()
    self.dmodel = dmodel
    self.vocab_size = dmodel if vocab_size is None else vocab_size
    self.linear = nn.Linear(dmodel, self.vocab_size)
    self.softmax = nn.Softmax(dim = -1)

  def forward(self, target_vec):
    """Return softmax probabilities over the last dimension.

    Args:
      target_vec: tensor whose last dimension has size ``dmodel``.

    Returns:
      Tensor with the last dimension replaced by ``vocab_size``; entries
      along that dimension sum to 1.
    """
    dout = self.linear(target_vec)
    dout = self.softmax(dout)
    return dout


In [None]:
class Transformers(nn.Module):
  """Encoder-decoder Transformer assembled from the blocks in this notebook.

  Args:
    src_vocab_size: number of entries in the source embedding table.
    target_vocab_size: number of entries in the target embedding table and
      number of classes produced by the output head.
    dmodel: model width.
    num_heads: number of attention heads per layer.
    hidden_layer: hidden width of the feed-forward networks.
    num_layers: number of encoder layers and of decoder layers.
    max_len: number of positions for which positional encodings are
      precomputed (independent of the vocabulary sizes).
  """

  def __init__(self, src_vocab_size, target_vocab_size, dmodel = 512, num_heads = 8, hidden_layer = 2048, num_layers = 6, max_len = 5000):
    super().__init__()
    self.dmodel = dmodel
    self.max_len = max_len

    # Positional tables are indexed by position in the sequence, so they
    # are sized by max_len rather than by the vocabulary size.
    self.src_embedding = Embedding(src_vocab_size, dmodel)
    self.src_pe = PositionalEncoding(max_len, dmodel)

    self.target_embedding = Embedding(target_vocab_size, dmodel)
    self.target_pe = PositionalEncoding(max_len, dmodel)

    self.encoder_block = EncoderBlock(num_layers, dmodel, num_heads, hidden_layer)
    self.decoder_block = DecoderBlock(num_layers, dmodel, num_heads, hidden_layer)

    self.output = DecoderOutput(dmodel)
    # DecoderOutput(dmodel) sizes its projection dmodel -> dmodel; replace
    # it so the final softmax ranges over the target vocabulary.
    self.output.linear = nn.Linear(dmodel, target_vocab_size)

  def forward(self, src_word_idx, target_word_idx, src_mask = None, target_mask = None):
    """Run the full model.

    Args:
      src_word_idx: source token indices for the source embedding.
      target_word_idx: target token indices for the target embedding.
      src_mask: optional mask passed to the encoder block only.
      target_mask: optional mask passed to the decoder block.

    Returns:
      The output head's softmax probabilities for every target position.
    """
    src_embedding = self.src_pe(self.src_embedding(src_word_idx))
    target_embedding = self.target_pe(self.target_embedding(target_word_idx))

    enc_out = self.encoder_block(src_embedding, src_mask)
    # NOTE(review): src_mask is not forwarded to the decoder here because
    # DecoderBlock.forward takes only (enc_in, dec_in, target_mask).
    dec_out = self.decoder_block(enc_out, target_embedding, target_mask)
    output = self.output(dec_out)

    return output