In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
## Positional encoding -- skip this cell if you want to jump straight to the attention code.

class PositionalEncoding(nn.Module):
  """Fixed (non-learned) sinusoidal positional encoding table.

  The table has shape ``(max_sequence_length, d_model)``. Even columns hold
  ``sin(pos / 10000^(2i / d_model))`` and odd columns hold the matching
  ``cos`` term.

  Args:
    d_model: width of each position vector (number of columns in the table).
    max_sequence_length: number of positions (rows) to precompute.
  """

  def __init__(self, d_model, max_sequence_length):
    super(PositionalEncoding, self).__init__()

    # One row per position, one column per embedding dimension.
    encoding = torch.zeros(max_sequence_length, d_model)

    # Column vector of positions: (max_sequence_length, 1).
    pos = torch.arange(0, max_sequence_length).float().unsqueeze(dim=1)

    # Even dimension indices 0, 2, 4, ... used as the exponent numerator.
    _2i = torch.arange(0, d_model, step=2).float()

    # Broadcasts to (max_sequence_length, ceil(d_model / 2)).
    angles = pos / (10000 ** (_2i / d_model))

    encoding[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even columns,
    # so trim the angle table to the number of odd columns.
    encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # Non-persistent buffer: follows .to(device) / .cuda() with the module,
    # is never trained, and is not written to the state_dict.
    self.register_buffer("encoding", encoding, persistent=False)

  def forward(self, x):
    """Return the encodings for the first ``seq_len`` positions.

    Args:
      x: tensor whose second dimension is the sequence length, e.g. token
        ids of shape ``(batch_size, seq_len)``. Only its shape is read.

    Returns:
      Tensor of shape ``(seq_len, d_model)``.

    Raises:
      ValueError: if ``seq_len`` exceeds the precomputed table length.
    """
    seq_len = x.size(1)
    if seq_len > self.encoding.size(0):
      raise ValueError(
          f"sequence length {seq_len} exceeds max_sequence_length "
          f"{self.encoding.size(0)}")
    return self.encoding[:seq_len, :]

In [None]:
class ScaledDotProduct(nn.Module):
  """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``."""

  def __init__(self):
    super(ScaledDotProduct, self).__init__()
    # Normalise over the last axis of the score matrix (the key positions).
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, q, k, v, mask=None):
    """Compute attention weights and the weighted sum of values.

    Args:
      q: queries, shape ``(..., len_q, d_k)``.
      k: keys, shape ``(..., len_k, d_k)``.
      v: values, shape ``(..., len_k, d_v)``.
      mask: optional additive mask broadcastable to ``(..., len_q, len_k)``;
        use ``-inf`` (or a large negative number) at positions to hide.

    Returns:
      ``(values, attention)`` where ``values`` has shape
      ``(..., len_q, d_v)`` and ``attention`` has shape
      ``(..., len_q, len_k)`` with rows summing to 1.
    """
    # Per-head feature size: the last dimension of q, not the full shape.
    d_k = q.size(-1)
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
      # Out-of-place add so a mask that broadcasts to a larger shape works.
      scaled = scaled + mask
    attention = self.softmax(scaled)
    values = torch.matmul(attention, v)
    return values, attention

In [None]:
## Multi-head attention: the same module is used in the decoder, and it doubles as cross-attention when k and v are passed in from a different sequence than q.

class MultiHeadAttention(nn.Module):
  """Multi-head attention built on ``ScaledDotProduct``.

  q, k and v are linearly projected, split into ``n_head`` heads, attended
  per head, merged back together and passed through an output projection.

  Args:
    d_model: feature size of the inputs and of the output.
    n_head: number of attention heads; must divide ``d_model`` evenly.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``n_head``.
  """

  def __init__(self, d_model, n_head):
    super(MultiHeadAttention, self).__init__()
    if d_model % n_head != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by n_head ({n_head})")
    self.n_head = n_head
    self.attention = ScaledDotProduct()
    # Separate projections so q may come from a different source than k and v.
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
    # Output projection applied after the heads are merged back together.
    self.w_concat = nn.Linear(d_model, d_model)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: queries, shape ``(batch_size, len_q, d_model)``.
      k: keys, shape ``(batch_size, len_k, d_model)``.
      v: values, shape ``(batch_size, len_k, d_model)``.
      mask: optional mask forwarded unchanged to the attention module.

    Returns:
      Tensor of shape ``(batch_size, len_q, d_model)``.
    """
    q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
    q, k, v = self.split(q), self.split(k), self.split(v)
    # The per-head attention weights are computed but not returned.
    out, attention = self.attention(q, k, v, mask=mask)

    out = self.concat(out)
    out = self.w_concat(out)
    return out

  def split(self, tensor):
    """Reshape ``(batch, length, d_model)`` into ``(batch, n_head, length, d_tensor)``."""
    batch_size, length, d_model = tensor.size()
    # Each head receives an equal slice of the feature dimension.
    d_tensor = d_model // self.n_head
    return tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)

  def concat(self, tensor):
    """Inverse of ``split``: merge heads back into ``(batch, length, d_model)``."""
    batch_size, head, length, d_tensor = tensor.size()
    d_model = head * d_tensor
    # transpose() yields a non-contiguous view; view() needs contiguous memory.
    return tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)

In [None]:
## feedforward neural network

class PositionwiseFeedForward(nn.Module):
  """Two-layer feed-forward network applied along the last dimension.

  Computes ``linear2(dropout(relu(linear1(x))))``.

  Args:
    d_model: input and output feature size.
    hidden: width of the inner layer.
    drop_prob: dropout probability applied after the activation.
  """

  def __init__(self, d_model, hidden, drop_prob = 0.1):
    super(PositionwiseFeedForward, self).__init__()
    self.linear1 = nn.Linear(d_model, hidden)
    # The class is spelled nn.ReLU; nn.Relu does not exist.
    self.relu = nn.ReLU()
    self.linear2 = nn.Linear(hidden, d_model)
    self.dropout = nn.Dropout(p = drop_prob)

  def forward(self, x):
    """Map ``(..., d_model)`` to ``(..., d_model)`` through the hidden layer."""
    x = self.linear1(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.linear2(x)
    return x

In [None]:
# layer normalization

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing dimensions of the input.

    Each slice spanning the last ``len(parameters_shape)`` dimensions is
    shifted to zero mean and scaled to unit (biased) variance, then passed
    through a learnable affine transform ``gamma * y + beta``.

    Args:
        parameters_shape: shape of the trailing dimensions to normalize,
            given as a list/tuple (e.g. ``[512]``) or a single int.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super(LayerNormalization, self).__init__()
        # Accept a bare int so LayerNormalization(512) works like [512].
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape=parameters_shape
        self.eps=eps
        # Learnable scale (starts at 1) and shift (starts at 0).
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, inputs):
        """Normalize ``inputs`` over its trailing dims; output has the same shape."""
        # Negative indices of the trailing dimensions: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        # Biased (population) variance, as is conventional for layer norm.
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (inputs - mean) / std
        out = self.gamma * y  + self.beta
        return out