In [51]:
import math
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.parameter import Parameter

In [52]:
# Module-level shape constants. They are used as the default `batch_size` /
# `seq_length` constructor arguments of MultiheadAttention, so this cell must
# run before the cell defining that class.
batch_size, seq_length = 128, 27

# Multi-head Attention

In [53]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k) [+ mask]) @ v.

    Args:
        q: queries, ``(..., len_q, d_k)``.
        k: keys, ``(..., len_k, d_k)``.
        v: values, ``(..., len_k, d_v)``.
        mask: optional tensor broadcastable to ``(..., len_q, len_k)``. Follows the
            ``torch.nn.MultiheadAttention`` convention: a bool mask marks positions
            that may NOT be attended (True = masked out); any other dtype is added
            to the logits (use ``-inf`` to block a position). A query row whose keys
            are all masked yields NaN weights.

    Returns:
        ``(values, attention)``, where ``values`` is ``(..., len_q, d_v)`` and
        ``attention`` holds the softmax weights ``(..., len_q, len_k)``.
    """
    d_k = q.size(-1)
    attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        if mask.dtype == torch.bool:
            attn_logits = attn_logits.masked_fill(mask, float("-inf"))
        else:
            attn_logits = attn_logits + mask

    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)

    return values, attention

In [54]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with a packed, bias-free input projection.

    Inputs are batch-first: ``query`` is ``(batch, tgt_len, embed_dim)`` and ``key`` /
    ``value`` are ``(batch, src_len, embed_dim)``. ``src_len`` may differ from
    ``tgt_len``, e.g. for decoder cross-attention over an encoder memory.

    Args:
        embed_dim: model width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        dropout: dropout probability applied to the attention weights in training mode.
        batch_size, seq_length: stored on the module for backward compatibility only;
            ``forward`` takes the actual sizes from its inputs.
    """

    def __init__(self, embed_dim, nhead, dropout = 0.1, batch_size = batch_size, seq_length = seq_length):
        super().__init__()
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.dropout = dropout
        self.head_dim = embed_dim // nhead
        self.batch_size = batch_size
        self.seq_length = seq_length

        assert self.head_dim * nhead == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Rows [0:E) project queries, [E:2E) keys and [2E:3E) values.
        self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))
        # torch.empty leaves the weight as arbitrary memory (possibly NaN/inf);
        # initialise it the way torch.nn.MultiheadAttention does.
        nn.init.xavier_uniform_(self.in_proj_weight)

        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x, batch_size, length):
        """Reshape ``(batch, length, embed_dim)`` to ``(batch, nhead, length, head_dim)``."""
        return x.reshape(batch_size, length, self.nhead, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, attn_mask = None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(batch, tgt_len, embed_dim)``.
            key, value: ``(batch, src_len, embed_dim)``.
            attn_mask: optional mask broadcastable to ``(batch, nhead, tgt_len, src_len)``,
                e.g. ``(tgt_len, src_len)``. A bool mask marks positions that may NOT be
                attended (True = masked out); any other dtype is added to the logits
                (use ``-inf`` to block a position), as in ``torch.nn.MultiheadAttention``.
                A query row whose keys are all masked yields NaN.

        Returns:
            ``(output, attention)``. ``output`` is ``(batch, tgt_len, embed_dim)``.
            ``attention`` holds the per-head softmax weights
            ``(batch, nhead, tgt_len, src_len)``, taken before dropout.
        """
        batch_size, tgt_len, _ = query.shape
        # Keys/values may have a different length than the queries (cross-attention).
        src_len = key.size(1)

        if query is key and key is value:
            # Self-attention: one packed projection, then split into q, k, v.
            q, k, v = F.linear(query, self.in_proj_weight).chunk(3, dim=-1)
        else:
            w_q, w_k, w_v = self.in_proj_weight.chunk(3)
            q, k, v = F.linear(query, w_q), F.linear(key, w_k), F.linear(value, w_v)

        q = self._split_heads(q, batch_size, tgt_len)
        k = self._split_heads(k, batch_size, src_len)
        v = self._split_heads(v, batch_size, src_len)

        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_logits = attn_logits.masked_fill(attn_mask, float("-inf"))
            else:
                attn_logits = attn_logits + attn_mask
        attention = F.softmax(attn_logits, dim=-1)

        values = torch.matmul(F.dropout(attention, p=self.dropout, training=self.training), v)
        # (batch, nhead, tgt_len, head_dim) -> (batch, tgt_len, embed_dim)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, tgt_len, self.embed_dim)

        return self.o_proj(values), attention



# Encoder

In [55]:
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder block.

    The input passes through two sub-layers, each wrapped as
    ``norm(x + dropout(sublayer(x)))``:

    1. self-attention (``self.attention``), optionally masked by ``src_mask``;
    2. a feed-forward network ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used by every dropout in the block.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout = 0.1):
        super().__init__()
        # Feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm after each residual connection.
        self.norm1, self.norm2 = (nn.LayerNorm(d_model, eps=1e-5) for _ in range(2))

        # `dropout` sits inside the feed-forward net; dropout1/dropout2 guard the residuals.
        self.dropout, self.dropout1, self.dropout2 = (nn.Dropout(dropout) for _ in range(3))

        self.activation = nn.ReLU()

        self.attention = MultiheadAttention(d_model, nhead, dropout = dropout)

    def forward(self, x, src_mask = None):
        """Apply self-attention, then the feed-forward network, to ``x``."""
        # Self-attention sub-layer; only the attended values are kept.
        attended = self.attention(x, x, x, attn_mask = src_mask)[0]
        hidden = self.norm1(x + self.dropout1(attended))

        # Feed-forward sub-layer.
        ff = self.linear1(hidden)
        ff = self.activation(ff)
        ff = self.dropout(ff)
        ff = self.linear2(ff)
        return self.norm2(hidden + self.dropout2(ff))

In [56]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` TransformerEncoderLayer blocks with an optional final norm.

    Args:
        num_layers: number of encoder layers to stack (0 makes the stack a pass-through
            apart from ``norm``).
        norm: optional module applied to the output of the last layer (e.g. nn.LayerNorm).
        **block_args: keyword arguments forwarded to every TransformerEncoderLayer
            (``d_model``, ``nhead``, ``dim_feedforward``, ``dropout``).
    """

    def __init__(self, num_layers, norm=None, ** block_args):
        super().__init__()
        self.num_layers = num_layers
        self.norm = norm
        self.layers = nn.ModuleList([TransformerEncoderLayer(**block_args) for _ in range(num_layers)])

    def forward(self, x, src_mask = None):
        """Run ``x`` through each layer in turn, then through ``norm`` if one was given.

        Each layer consumes the previous layer's output, so the layers actually stack.
        Starting from ``x`` also keeps ``output`` defined when there are no layers.
        """
        output = x
        for mod in self.layers:
            output = mod(output, src_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output
        

# Decoder

In [57]:
class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder block.

    Three sub-layers, each wrapped as ``norm(x + dropout(sublayer(x)))``:

    1. self-attention over the target (``self_attn``), masked by ``tgt_mask`` if given;
    2. cross-attention from the target to the encoder ``memory`` (``multihead_attn``),
       masked by ``memory_mask`` if given;
    3. a feed-forward network ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability used by every dropout in the block.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout = 0.1):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout = dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout = dropout)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm3 = nn.LayerNorm(d_model, eps=1e-5)
        # One residual dropout per sub-layer.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, x, memory, tgt_mask = None, memory_mask = None):
        """Decode ``x`` against the encoder output ``memory``.

        Args:
            x: target sequence; assumed ``(batch, tgt_len, d_model)`` (batch-first like
                MultiheadAttention) - TODO confirm against callers.
            memory: encoder output attended to by the cross-attention sub-layer.
            tgt_mask: optional self-attention mask. No causal mask is applied unless
                one is passed here.
            memory_mask: optional cross-attention mask.
        """
        sa = self.self_attn(x, x, x, attn_mask=tgt_mask)[0]
        x = self.norm1(x + self.dropout1(sa))

        ma = self.multihead_attn(x, memory, memory, attn_mask=memory_mask)[0]
        x = self.norm2(x + self.dropout2(ma))

        ff = self.linear2(self.dropout(self.activation(self.linear1(x))))
        # The feed-forward residual uses its own dropout module.
        x = self.norm3(x + self.dropout3(ff))

        return x

In [58]:
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` TransformerDecoderLayer blocks with an optional final norm.

    Args:
        num_layers: number of decoder layers to stack (0 makes the stack a pass-through
            apart from ``norm``).
        norm: optional module applied to the output of the last layer (e.g. nn.LayerNorm).
        **block_args: keyword arguments forwarded to every TransformerDecoderLayer
            (``d_model``, ``nhead``, ``dim_feedforward``, ``dropout``).
    """

    def __init__(self, num_layers, norm = None, ** block_args):
        super().__init__()
        self.num_layers = num_layers
        self.norm = norm
        self.layers = nn.ModuleList([TransformerDecoderLayer(**block_args) for _ in range(num_layers)])

    def forward(self, tgt, memory, tgt_mask = None, memory_mask = None):
        """Run ``tgt`` through each layer in turn (all attending to ``memory``), then ``norm``.

        Each layer consumes the previous layer's output, so the layers actually stack.
        Starting from ``tgt`` also keeps ``output`` defined when there are no layers.
        """
        output = tgt
        for mod in self.layers:
            output = mod(output, memory, tgt_mask, memory_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

# Transformer

In [59]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer stack (no embeddings or output projection).

    ``src`` is encoded into a memory, which the decoder then attends to while
    processing ``tgt``. Both stacks end with their own LayerNorm.

    Args:
        d_model: model width shared by every layer.
        nhead: number of attention heads.
        num_encoder_layers: depth of the encoder stack.
        num_decoder_layers: depth of the decoder stack.
        dim_feedforward: hidden width of each feed-forward network.
        dropout: dropout probability used throughout.
    """

    def __init__(self, d_model = 100, nhead = 2, num_encoder_layers = 3, num_decoder_layers = 3, dim_feedforward = 64, dropout = 0.1):
        super().__init__()

        self.d_model, self.nhead = d_model, nhead

        layer_kwargs = dict(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)

        # Each stack gets its own final LayerNorm; the norm is built before its stack.
        self.encoder = TransformerEncoder(num_encoder_layers, nn.LayerNorm(d_model, eps=1e-5), **layer_kwargs)
        self.decoder = TransformerDecoder(num_decoder_layers, nn.LayerNorm(d_model, eps=1e-5), **layer_kwargs)

    def forward(self, src, tgt, src_mask = None, tgt_mask = None, memory_mask = None):
        """Encode ``src``, then decode ``tgt`` against the resulting memory."""
        return self.decoder(tgt, self.encoder(src, src_mask), tgt_mask, memory_mask)