# Transformer

In [59]:
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
import torch.nn.functional as F
from torch.nn.modules.normalization import LayerNorm
import copy
import math

In [54]:
# Use the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64

# Random sample tensor; presumably meant as a (seq_len, batch, embed_dim) input
# for the modules below — TODO confirm.
tmp = torch.rand(10, 2, 16).to(device=device, dtype=dtype)
tmp.shape

torch.Size([10, 2, 16])

In [74]:
# TODO: implement LayerNorm from scratch
# TODO: implement softmax from scratch

class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention on (seq_len, batch, embed_dim) inputs.

    Query, key and value each get their own input projection; the per-head
    results are concatenated and passed through an output projection.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        # Keys and values share the query width in this implementation.
        self.kdim = self.vdim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Separate projections, each mapping (seq, bsz, embed_dim) -> (seq, bsz, embed_dim).
        self.linear_q = nn.Linear(embed_dim, embed_dim)
        self.linear_k = nn.Linear(embed_dim, embed_dim)
        self.linear_v = nn.Linear(embed_dim, embed_dim)
        self.linear_output = nn.Linear(embed_dim, embed_dim)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, attn_mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: tensor of shape (tgt_len, bsz, embed_dim).
            key: tensor of shape (src_len, bsz, embed_dim).
            value: tensor of shape (src_len, bsz, embed_dim).
            attn_mask: optional mask of shape (tgt_len, src_len) or
                (bsz * num_heads, tgt_len, src_len). For a bool mask, True marks
                positions that may NOT be attended to. A float mask is added to
                the attention logits (use ``-inf`` to block a position).

        Returns:
            attn_output: (tgt_len, bsz, embed_dim).
            attn_output_weights: (bsz, num_heads, tgt_len, src_len), per head,
                taken after dropout.
        """
        tgt_len, bsz, embed_dim = query.shape
        src_len = key.shape[0]

        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)

        # Split heads: (seq, bsz, embed_dim) -> (bsz * num_heads, seq, head_dim).
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        # Scaled dot-product logits: (bsz * num_heads, tgt_len, src_len).
        scores = torch.bmm(q / math.sqrt(self.head_dim), k.transpose(-2, -1))
        if attn_mask is not None:
            # A 2-D mask broadcasts over the (bsz * num_heads) dimension.
            if attn_mask.dtype == torch.bool:
                scores = scores.masked_fill(attn_mask, float("-inf"))
            else:
                scores = scores + attn_mask

        attn_output_weights = self.dropout(self.softmax(scores))

        attn_output = torch.bmm(attn_output_weights, v)  # (bsz * num_heads, tgt_len, head_dim)
        # Merge heads back and project: -> (tgt_len, bsz, embed_dim).
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
        attn_output = self.linear_output(attn_output).view(tgt_len, bsz, -1)

        attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights

class TransformerEncoder(nn.Module):
    """Stack of ``num_encoder_layers`` independent copies of ``encoder_layer``.

    Args:
        encoder_layer: template layer, called as ``layer(x, src_mask)``. It is
            deep-copied so every stacked layer has its own parameters.
        num_encoder_layers: number of copies to stack.
        encoder_norm: module applied to the output of the last layer, or
            ``None`` to skip the final normalization.
    """

    def __init__(self, encoder_layer, num_encoder_layers, encoder_norm=None):
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_encoder_layers)]
        )
        self.num_layers = num_encoder_layers
        self.norm = encoder_norm

    def forward(self, src, src_mask=None):
        """Run ``src`` through every layer in order, then the optional final norm.

        ``src_mask`` is handed unchanged to every layer.
        """
        output = src
        for layer in self.layers:
            output = layer(output, src_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Two sub-blocks, each wrapped as ``norm(x + dropout(sublayer(x)))``:
    self-attention over ``src``, then a position-wise feed-forward network.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout)

        # position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(self, src, src_mask=None):
        """Apply self-attention then feed-forward; ``src_mask`` goes to the attention."""
        x = src
        # Post-norm: normalize once, after each residual sum.
        x = self.norm1(x + self._sa_block(x, src_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

    def _sa_block(self, x, attn_mask):
        # The attention module returns (output, weights); only the output feeds the residual.
        x, _ = self.self_attn(x, x, x, attn_mask)
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
    
class TransformerDecoder(nn.Module):
    """Stack of ``num_decoder_layers`` independent copies of ``decoder_layer``.

    Args:
        decoder_layer: template layer, called as
            ``layer(x, memory, tgt_mask, memory_mask)``. It is deep-copied so
            every stacked layer has its own parameters.
        num_decoder_layers: number of copies to stack.
        decoder_norm: module applied to the output of the last layer, or
            ``None`` to skip the final normalization.
    """

    def __init__(self, decoder_layer, num_decoder_layers, decoder_norm=None):
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(decoder_layer) for _ in range(num_decoder_layers)]
        )
        self.num_layers = num_decoder_layers
        self.norm = decoder_norm

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Run ``tgt`` through every layer (each also sees ``memory``), then the optional norm."""
        output = tgt
        for layer in self.layers:
            output = layer(output, memory, tgt_mask, memory_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Three sub-blocks, each wrapped as ``norm(x + dropout(sublayer(x)))``:
    self-attention over ``tgt``, cross-attention from ``tgt`` to ``memory``,
    then a position-wise feed-forward network.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout)

        # position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = F.relu

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Apply self-attention (``tgt_mask``), cross-attention (``memory_mask``), feed-forward."""
        x = tgt
        x = self.norm1(x + self._sa_block(x, tgt_mask))
        x = self.norm2(x + self._mha_block(x, memory, memory_mask))
        x = self.norm3(x + self._ff_block(x))
        return x

    def _sa_block(self, x, attn_mask):
        # The attention module returns (output, weights); keep only the output.
        x, _ = self.self_attn(x, x, x, attn_mask)
        return self.dropout1(x)

    def _mha_block(self, x, mem, attn_mask):
        # Queries come from the decoder stream; keys/values from the encoder memory.
        x, _ = self.multihead_attn(x, mem, mem, attn_mask)
        return self.dropout2(x)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Builds an encoder stack and a decoder stack, each followed by a final
    LayerNorm, from a single template layer per stack.
    """

    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, 
                 dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead

        # Encoder first, then decoder: keeps the submodule/parameter order stable.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout),
            num_encoder_layers,
            LayerNorm(d_model),
        )
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout),
            num_decoder_layers,
            LayerNorm(d_model),
        )

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src`` into memory, then decode ``tgt`` against it.

        src: (seq_s, batch_size, embedding)
        tgt: presumably (seq_t, batch_size, embedding), same layout as src
        src_mask: handed to the encoder; tgt_mask: handed to the decoder.
        """
        return self.decoder(tgt, self.encoder(src, src_mask), tgt_mask)
