In [174]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

In [19]:
# Reference model from torch.nn, used to sanity-check tensor shapes below.
transformer_model = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
)

In [20]:
torch.manual_seed(0)  # reproducible random inputs (and dropout masks) for this demo
# nn.Transformer defaults to batch_first=False, so these are (seq_len, batch, d_model):
# src = 10 positions x 32 sequences, tgt = 20 positions x 32 sequences.
# NOTE: the from-scratch modules below unpack dim 0 as batch and dim 1 as time,
# so they treat src as 10 sequences of length 32.
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = transformer_model(src, tgt)

In [21]:
out.size()

torch.Size([20, 32, 512])

# Transformer from scratch

In [121]:
class hEmbedding(nn.Module):
    """Add fixed sinusoidal positional encodings (Vaswani et al., 2017) to token vectors.

    Expects ``x`` laid out as ``(batch, time, dmodel)``; the position index runs
    along dim 1. For position ``pos`` and pair index ``i``::

        PE[pos, 2i]   = sin(pos / 10000**(2i / dmodel))
        PE[pos, 2i+1] = cos(pos / 10000**(2i / dmodel))

    The table is cached in the non-persistent buffer ``pe`` with shape
    ``(cached_len, dmodel)``. It is rebuilt only when a longer sequence or a
    different device/dtype is seen.

    Args:
        dmodel: token width; must equal the last dim of the input.
    """

    def __init__(self, dmodel=512):
        super().__init__()
        self.dmodel = dmodel
        # Registered once here (not in forward) so it follows .to()/.cuda(). It is
        # non-persistent because it is fully determined by dmodel.
        self.register_buffer('pe', torch.empty(0, dmodel), persistent=False)

    def _build_table(self, length, device, dtype):
        """Return the (length, dmodel) sinusoidal table, computed in float32 and cast to dtype."""
        pos = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        # sin and cos of the same pair i share the frequency 10000**(-2i/dmodel).
        two_i = torch.arange(0, self.dmodel, 2, device=device, dtype=torch.float32)
        angles = pos / torch.pow(10000.0, two_i / self.dmodel)  # (length, ceil(dmodel/2))
        table = torch.zeros(length, self.dmodel, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        # For odd dmodel the last pair has no cosine slot, so drop that column.
        table[:, 1::2] = torch.cos(angles[:, : self.dmodel // 2])
        return table.to(dtype)

    def forward(self, x):
        """Return ``x + PE`` without modifying ``x``.

        Args:
            x: tensor of shape ``(batch, time, dmodel)``.

        Returns:
            A new tensor of the same shape as ``x``.

        Raises:
            AssertionError: if the last dim of ``x`` is not ``dmodel``.
        """
        b, t, d = x.size()
        assert d==self.dmodel, "the size of token doesn't match that of tensor x"
        if self.pe.size(0) < t or self.pe.device != x.device or self.pe.dtype != x.dtype:
            self.pe = self._build_table(t, x.device, x.dtype)
        # Out-of-place add, so the caller's tensor is never changed. The
        # (t, d) table broadcasts over the batch dim.
        return x + self.pe[:t]

In [122]:
he = hEmbedding(dmodel=512)  # same token width as the reference model above

In [123]:
# Clone so `src` is left untouched even if the embedding adds in place; this also
# keeps the cell idempotent when re-run.
e_src = he(src.clone())

In [159]:
e_src.size()

torch.Size([10, 32, 512])

In [164]:
class hAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a post-norm residual.

    Args:
        dmodel: token width; the input must be ``(batch, time, dmodel)``.
        dim_feedforward: accepted for signature parity with the other blocks; unused.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dmodel=512, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.dmodel = dmodel
        # One Linear produces q, k and v together; forward splits them apart.
        self.projection = nn.Linear(dmodel, dmodel * 3)
        self.att_drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dmodel)

    def forward(self, x):
        """Return ``LayerNorm(x + softmax(q k^T / sqrt(dmodel)) v)``, same shape as ``x``."""
        batch, steps, width = x.shape
        assert width==self.dmodel, "the size of token doesn't match that of tensor x"
        # The 3*dmodel output is viewed as (..., dmodel, 3), so q/k/v are the
        # interleaved feature columns 0::3, 1::3 and 2::3.
        packed = self.projection(x).view(batch, steps, width, 3)
        query, key, value = packed.unbind(dim=-1)
        scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(width)
        weights = self.att_drop(scores.softmax(dim=-1))
        mixed = torch.matmul(weights, value)
        return self.norm(mixed + x)

In [165]:
hatt = hAttention(dmodel=512, dropout=0.1)

In [166]:
a_src = hatt(x=src)

In [167]:
a_src.size()

torch.Size([10, 32, 512])

In [181]:
class hMultiHeadAttention(nn.Module):
    """Multi-head self-attention in which every head works at the full ``dmodel`` width.

    The input is projected to ``nhead`` independent q/k/v sets, each ``dmodel``
    wide. Each head computes scaled dot-product attention. The ``nhead * dmodel``
    head outputs are concatenated (head-major), mapped back to ``dmodel`` by
    ``self.linear``, and combined with the input through a post-norm residual.

    Args:
        nhead: number of heads.
        dmodel: token width; the input must be ``(batch, time, dmodel)``.
        dim_feed_forward: accepted for signature parity with the other blocks; unused.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, nhead=8, dmodel=512, dim_feed_forward=2048, dropout=0.1):
        super().__init__()
        self.dmodel = dmodel
        self.nhead = nhead
        self.projection = nn.Linear(dmodel, dmodel*3*nhead)
        self.att_drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dmodel)
        self.linear = nn.Linear(dmodel*nhead, dmodel)

    def forward(self, x):
        """Return ``LayerNorm(x + Linear(concat_h(attention_h(x))))``, same shape as ``x``.

        Raises:
            AssertionError: if the last dim of ``x`` is not ``dmodel``.
        """
        b, t, d = x.shape
        assert d==self.dmodel, "the size of token doesn't match that of tensor x"
        # (b, t, d, h, 3): the last axis selects q/k/v, axis 3 selects the head.
        qkv = self.projection(x).view(b, t, d, self.nhead, 3)
        q, k, v = qkv.unbind(dim=-1)  # each (b, t, d, h)
        # Per-head scores (b, h, q, k). Each head is d wide, hence the sqrt(d) scale.
        att = torch.einsum('bqdh,bkdh->bhqk', q, k) / math.sqrt(d)
        att = self.att_drop(F.softmax(att, dim=-1))
        # Weighted values laid out (b, t, h, d), so flattening gives the head-major
        # concatenation of length h*d that self.linear expects.
        out = torch.einsum('bhqk,bkdh->bqhd', att, v).flatten(start_dim=2)
        out = self.linear(out)
        # Normalize the residual sum, not x alone; otherwise the attention output is discarded.
        return self.norm(out + x)

In [182]:
hmatt = hMultiHeadAttention(nhead=8, dmodel=512)

In [183]:
ma_src = hmatt(x=src)

In [None]:
class hFeedForward(nn.Module):
    """Position-wise feed-forward sublayer with a post-norm residual.

    Computes ``LayerNorm(x + drop(W2 · drop(ReLU(W1 · x))))`` over the last dim.

    Args:
        dmodel: token width (last dim of the input).
        dim_feed_forward: hidden width of the inner layer.
        dropout: dropout probability after the activation and on the sublayer output.
    """

    def __init__(self, dmodel=512, dim_feed_forward=2048, dropout=0.1):
        super().__init__()  # must run before any submodule is assigned
        self.linear_1 = nn.Linear(dmodel, dim_feed_forward)
        self.linear_2 = nn.Linear(dim_feed_forward, dmodel)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dmodel)

    def forward(self, x):
        """Return the normalized residual output; same shape as ``x``."""
        # Without the ReLU the two Linears would collapse into a single affine map.
        hidden = self.dropout(F.relu(self.linear_1(x)))
        out = self.dropout(self.linear_2(hidden))
        return self.norm(out + x)