In [None]:
import torch
import torch.nn as nn

In [None]:
class self_attention(nn.Module):
  '''
  Multi-head scaled dot-product attention.

  Args:
    embedded_dim:   size of the model/embedding space (e.g. 256).
    heads_parallel: number of attention heads (e.g. 8). Must divide
                    embedded_dim; each head then works in
                    head_dim = embedded_dim // heads_parallel dimensions
                    (e.g. 256 // 8 = 32).

  Raises:
    ValueError: if embedded_dim is not divisible by heads_parallel.
  '''
  def __init__(self, embedded_dim, heads_parallel):
    super().__init__()
    if embedded_dim % heads_parallel != 0:
      raise ValueError(
          f"embedded_dim ({embedded_dim}) must be divisible by "
          f"heads_parallel ({heads_parallel})")
    self.embedded_dim = embedded_dim
    self.heads_parallel = heads_parallel
    self.head_dim = embedded_dim // heads_parallel

    # Learned V, K, Q projections (all heads at once, split later by reshape).
    self.values   = nn.Linear(self.embedded_dim, self.embedded_dim, bias=False)
    self.queries  = nn.Linear(self.embedded_dim, self.embedded_dim, bias=False)
    self.keys     = nn.Linear(self.embedded_dim, self.embedded_dim, bias=False)

    # Mixes the concatenated head outputs back into the model space.
    self.multi_head_out = nn.Linear(self.heads_parallel*self.head_dim, self.embedded_dim)

  def forward(self, values, keys, queries, mask):
    '''
    Attend from `queries` to `keys`/`values`.

    Args:
      values:  (N, value_len, embedded_dim)
      keys:    (N, key_len, embedded_dim); key_len must equal value_len.
      queries: (N, queries_len, embedded_dim)
      mask:    None, or a tensor broadcastable to
               (N, heads_parallel, queries_len, key_len); positions where
               mask == 0 get a large negative logit and so ~zero weight.

    Returns:
      Tensor of shape (N, queries_len, embedded_dim).
    '''
    N = queries.shape[0]  # batch size
    value_len, key_len, queries_len = values.shape[1], keys.shape[1], queries.shape[1]

    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # Split the feature dimension into heads: (N, length, heads, head_dim).
    # Each tensor uses its OWN sequence length so cross-attention works.
    values  = values.reshape(N, value_len, self.heads_parallel, self.head_dim)
    keys    = keys.reshape(N, key_len, self.heads_parallel, self.head_dim)
    queries = queries.reshape(N, queries_len, self.heads_parallel, self.head_dim)

    # raw_weights shape: (N, heads_parallel, queries_len, key_len)
    raw_weights = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

    # Optional masking (padding / causal): masked logits become -1e10.
    if mask is not None:
      raw_weights = raw_weights.masked_fill(mask == 0, float("-1e10"))

    # Scale by sqrt(d_k), where d_k is the per-head key size, then normalise
    # over the key axis.
    attention = torch.softmax(raw_weights / (self.head_dim ** 0.5), dim=3)

    # Weighted sum of values per head: (N, queries_len, heads_parallel, head_dim)
    out = torch.einsum('nhqk,nkhd->nqhd', attention, values)
    # Concatenate heads: (N, queries_len, heads_parallel * head_dim).
    out = out.reshape(N, queries_len, self.heads_parallel*self.head_dim)
    return self.multi_head_out(out)

In [None]:
class encoder_module(nn.Module):
  '''
  One transformer block: multi-head attention, then a position-wise
  feed-forward network; each sub-layer is followed by a residual add,
  LayerNorm and dropout.

  Args:
    embedded_dim:   model dimension.
    heads_parallel: number of attention heads.
    dropout:        dropout probability.
    forward_dim:    expansion factor; the feed-forward hidden layer has
                    forward_dim * embedded_dim units.
  '''
  def __init__(self, embedded_dim, heads_parallel, dropout, forward_dim):
    super().__init__()
    self.en_module_attention = self_attention(embedded_dim, heads_parallel)
    self.add_norm_attention = nn.LayerNorm(embedded_dim)
    self.add_norm_feed_forward = nn.LayerNorm(embedded_dim)
    self.feed_forward_nn = nn.Sequential(
        nn.Linear(embedded_dim, forward_dim*embedded_dim),
        nn.ReLU(),
        nn.Linear(forward_dim*embedded_dim, embedded_dim)
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query, mask):
    '''
    Run attention (query attends to key/value) plus feed-forward.

    Args:
      value, key: sequences attended over (same length as each other).
      query:      sequence producing the output rows.
      mask:       passed unchanged to the attention layer (may be None).

    Returns:
      Tensor with one row per query position.
    '''
    attention = self.en_module_attention(value, key, query, mask)
    # The residual must come from the query: the attention output has one
    # row per query position. In self-attention value is query, so nothing
    # changes there, and cross-attention with differing lengths now works.
    x = self.dropout(self.add_norm_attention(attention + query))
    forward_nn = self.feed_forward_nn(x)
    return self.dropout(self.add_norm_feed_forward(forward_nn + x))



In [None]:
class encoder(nn.Module):
  '''
  Transformer encoder: token embeddings plus learned positional embeddings,
  followed by a stack of encoder_module self-attention blocks.

  Args:
    vocabulary:     input vocabulary size.
    device:         device on which position indices are created.
    embedded_dim:   model dimension.
    layers:         number of stacked encoder_module blocks.
    heads_parallel: number of attention heads.
    forward_dim:    feed-forward expansion factor (hidden size is
                    forward_dim * embedded_dim).
    dropout:        dropout probability.
    max_dim:        maximum sequence (sentence) length; size of the position
                    embedding table.
  '''
  def __init__(self, vocabulary, device, embedded_dim, layers, heads_parallel,
               forward_dim, dropout, max_dim):
    super().__init__()
    self.embedded_dim = embedded_dim
    self.device = device
    self.word_embedding = nn.Embedding(vocabulary, embedded_dim)
    self.position_embedding = nn.Embedding(max_dim, embedded_dim)

    self.architecture = nn.ModuleList(
        [
            encoder_module(embedded_dim, heads_parallel, dropout, forward_dim)
            for _ in range(layers)
        ]
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    '''
    Encode a batch of token-id sequences.

    Args:
      x:    (N, seq_dim) integer token ids; seq_dim must not exceed max_dim.
      mask: attention mask forwarded to every layer (may be None).

    Returns:
      The output of the last encoder layer (the embeddings if layers == 0).
    '''
    N, seq_dim = x.shape
    positions = torch.arange(0, seq_dim, device=self.device).expand(N, seq_dim)
    vkq = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

    # Self-attention: value, key and query are all the current representation.
    for layer in self.architecture:
      vkq = layer(vkq, vkq, vkq, mask)
    return vkq


In [None]:
class decoder_module(nn.Module):
  '''
  One decoder block. It first applies self-attention over the target
  sequence, followed by a residual add, LayerNorm and dropout. It then
  applies an encoder_module in which the result serves as the query and
  attends to the given value/key (the encoder output).

  Args:
    embedded_dim:   model dimension.
    heads_parallel: number of attention heads.
    forward_dim:    feed-forward expansion factor.
    dropout:        dropout probability.
    device:         accepted for interface compatibility; not used here.
  '''
  def __init__(self, embedded_dim, heads_parallel, forward_dim, dropout, device):
    super().__init__()
    self.attention = self_attention(embedded_dim, heads_parallel)
    self.add_norm_attention = nn.LayerNorm(embedded_dim)
    self.add_norm_feed_forward = nn.LayerNorm(embedded_dim)
    self.encoder_module = encoder_module(embedded_dim, heads_parallel, dropout, forward_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, value, key, in_mask, target_mask):
    '''
    Args:
      x:           current target-side representation.
      value, key:  sequence to cross-attend to (presumably the encoder output).
      in_mask:     mask for the cross-attention step (may be None).
      target_mask: mask for the target self-attention step (may be None).

    Returns:
      The encoder_module output, one row per position of x.
    '''
    attention = self.attention(x, x, x, target_mask)
    query = self.dropout(self.add_norm_attention(attention + x))
    return self.encoder_module(value, key, query, in_mask)


In [None]:
class decoder(nn.Module):
  '''
  Transformer decoder: target token embeddings plus learned positional
  embeddings, a stack of decoder_module blocks attending to the encoder
  output, and a final linear projection to target-vocabulary scores.

  Args:
    target_vocabulary: size of the target vocabulary.
    embedded_dim:      model dimension.
    layers:            number of stacked decoder_module blocks.
    heads_parallel:    number of attention heads.
    forward_dim:       feed-forward expansion factor.
    dropout:           dropout probability.
    device:            device on which position indices are created.
    max_dim:           maximum target length; size of the position table.
  '''
  def __init__(self, target_vocabulary, embedded_dim, layers, heads_parallel,
               forward_dim, dropout, device, max_dim):
    super().__init__()

    self.device = device
    self.word_embedding = nn.Embedding(target_vocabulary, embedded_dim)
    self.position_embedding = nn.Embedding(max_dim, embedded_dim)

    self.architecture = nn.ModuleList(
        [
            decoder_module(embedded_dim, heads_parallel, forward_dim, dropout, device)
            for _ in range(layers)
        ]
    )
    self.forward_nn = nn.Linear(embedded_dim, target_vocabulary)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, encoder_out, in_mask, target_mask):
    '''
    Decode a batch of target token-id sequences.

    Args:
      x:           (N, seq_dim) integer target token ids.
      encoder_out: encoder output, used as both keys and values of the
                   cross-attention in every layer.
      in_mask:     mask for attention over encoder_out (may be None).
      target_mask: mask for target self-attention (may be None).

    Returns:
      Unnormalised scores over target_vocabulary for every target position
      (no softmax is applied here).
    '''
    N, seq_dim = x.shape
    positions = torch.arange(0, seq_dim, device=self.device).expand(N, seq_dim)
    vkq = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

    # Chain the layers: each one consumes the previous layer's output.
    for layer in self.architecture:
      vkq = layer(vkq, encoder_out, encoder_out, in_mask, target_mask)
    return self.forward_nn(vkq)

  # The method was originally misspelled `forwars`; keep that name working.
  forwars = forward