# Transformer

In [14]:
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
import torch.nn.functional as F
from torch.nn.modules.normalization import LayerNorm
import copy

In [4]:
# Prefer CUDA when PyTorch reports a usable GPU; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Double precision is the dtype selected for this notebook.
dtype = torch.float64
# Smoke test: allocate a small tensor and move it to the chosen device/dtype.
torch.empty((2, 2)).to(dtype=dtype, device=device)

tensor([[0., 0.],
        [0., 0.]], dtype=torch.float64)

In [None]:
# TODO: implement my own LayerNorm

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a packed input projection.

    Inputs and output use the ``(seq_len, batch_size, embed_dim)`` layout.
    The query, key and value projections are stored stacked along dim 0 of
    ``in_proj_weight`` / ``in_proj_bias`` (rows ``[0:E]`` query, ``[E:2E]``
    key, ``[2E:3E]`` value, with ``E = embed_dim``).

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        device: Optional device for the parameters (PyTorch default if None).
        dtype: Optional dtype for the parameters (PyTorch default if None).

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout, *, device=None, dtype=None):
        super().__init__()
        self.embed_dim = embed_dim
        # Keys and values share the query width; unpacking an int here would
        # raise, so assign both attributes from the same scalar.
        self.kdim = self.vdim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Parameters default to PyTorch's device/dtype, like the nn.Linear and
        # LayerNorm layers used alongside this module, so the assembled model
        # can be moved as a whole with `.to(...)`. Registering them as
        # nn.Parameter makes them visible to `parameters()` and optimizers.
        factory_kwargs = {"device": device, "dtype": dtype}
        # One packed matrix holds the q, k and v projections: 3 * embed_dim
        # rows, matching the length of the packed bias.
        self.in_proj_weight = nn.Parameter(
            torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
        )
        xavier_uniform_(self.in_proj_weight)
        self.in_proj_bias = nn.Parameter(torch.zeros(3 * embed_dim, **factory_kwargs))
        # Mixes the concatenated heads back into the model width.
        self.out_proj = nn.Linear(embed_dim, embed_dim, **factory_kwargs)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor of shape ``(tgt_len, batch_size, embed_dim)``.
            key: Tensor of shape ``(src_len, batch_size, embed_dim)``.
            value: Tensor of shape ``(src_len, batch_size, embed_dim)``.
            attn_mask: Optional mask broadcastable to
                ``(batch_size * num_heads, tgt_len, src_len)``. A boolean mask
                blocks the positions that are ``True``; any other dtype is
                added to the attention scores before the softmax.

        Returns:
            Tensor of shape ``(tgt_len, batch_size, embed_dim)``.
        """
        tgt_len, bsz, _ = query.shape
        src_len = key.shape[0]

        # Split the packed projection into its query/key/value thirds.
        w_q, w_k, w_v = self.in_proj_weight.chunk(3, dim=0)
        b_q, b_k, b_v = self.in_proj_bias.chunk(3, dim=0)
        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        v = F.linear(value, w_v, b_v)

        # (seq, batch, embed) -> (batch * heads, seq, head_dim)
        q = q.reshape(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.reshape(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.reshape(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Scale by 1/sqrt(head_dim) so score magnitude does not grow with width.
        scores = torch.bmm(q, k.transpose(1, 2)) * self.head_dim ** -0.5
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask

        weights = F.softmax(scores, dim=-1)
        weights = F.dropout(weights, p=self.dropout, training=self.training)

        # (batch * heads, tgt_len, head_dim) -> (tgt_len, batch, embed)
        attended = torch.bmm(weights, v).transpose(0, 1).reshape(tgt_len, bsz, self.embed_dim)
        return self.out_proj(attended)







class TransformerEncoder(nn.Module):
    """Stack of independent copies of one encoder layer, followed by a norm.

    Args:
        encoder_layer: Template layer; it is deep-copied once per stack entry,
            so the copies do not share parameters with each other.
        num_encoder_layers: Number of copies to stack.
        encoder_norm: Module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_encoder_layers, encoder_norm):
        super().__init__()
        self.num_layers = num_encoder_layers

        clones = []
        for _ in range(num_encoder_layers):
            clones.append(copy.deepcopy(encoder_layer))
        self.layers = nn.ModuleList(clones)

        self.norm = encoder_norm

    def forward(self, src, src_mask):
        """Run ``src`` through every layer in order, then the final norm.

        ``src_mask`` is handed unchanged to each layer.
        """
        # Todo: mask
        hidden = src
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return self.norm(hidden)

class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer: self-attention followed by a feed-forward block.

    Each sublayer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Width of the last dimension of the input.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability used throughout the layer.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout)

        # feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(self, src, src_mask):
        """Apply the attention and feed-forward sublayers to ``src``.

        Args:
            src: Input tensor whose last dimension is ``d_model``.
            src_mask: Attention mask forwarded to the self-attention module.

        Returns:
            Tensor with the same shape as ``src``.
        """
        # Todo: mask

        x = src

        # Post-norm, same as the feed-forward sublayer below: normalise only
        # the residual sum. Normalising the attention input as well would
        # apply norm1 twice and mix pre-norm with post-norm.
        x = self.norm1(x + self._sa_block(x, src_mask))
        x = self.norm2(x + self._ff_block(x))

        return x

    def _sa_block(self, x, attn_mask):
        """Self-attention (query = key = value = ``x``) followed by dropout."""
        x = self.self_attn(x, x, x, attn_mask)
        return self.dropout1(x)

    def _ff_block(self, x):
        """Linear -> activation -> dropout -> linear, followed by dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
    


class Transformer(nn.Module):
    """Encoder-decoder model assembled from the layer classes in this file.

    Args:
        d_model: Model width passed to every layer and LayerNorm.
        nhead: Number of attention heads passed to every layer.
        num_encoder_layers: Depth of the encoder stack.
        num_decoder_layers: Depth of the decoder stack.
        dim_feedforward: Hidden width passed to every layer.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers,
                 dim_feedforward, dropout):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        # Encoder stack: one template layer (copied by the encoder) plus a
        # final LayerNorm.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout),
            num_encoder_layers,
            LayerNorm(d_model),
        )

        # Decoder stack, built the same way.
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout),
            num_decoder_layers,
            LayerNorm(d_model),
        )

    def forward(self, src, tgt, src_mask, tgt_mask):
        """
        Encode ``src`` and decode ``tgt`` against the encoder output.

        src: (seq_s, batch_size, embedding)
        """
        return self.decoder(tgt, self.encoder(src, src_mask), tgt_mask)
