# Minimal Transformer Example 

Without using a real dataset - for learning purposes

Materials: http://jalammar.github.io/illustrated-transformer/

omitted causal masking and a few dropout layers for simplicity

In [1]:
import torch
import numpy as np

## Embed

In [2]:
dict_size = 10
embedding_dim = 4
scale_factor = 1.0 / np.sqrt(embedding_dim)
max_seq_length = 4
num_heads = 5

In [3]:
embedding_layer = torch.nn.Embedding(dict_size, embedding_dim)

In [4]:
input = torch.LongTensor([[1, 2, 0, 2], [0, 4, 2, 9]])
embedding = embedding_layer(input)

In [5]:
print(embedding.shape)

torch.Size([2, 4, 4])


In [6]:
target = torch.LongTensor([0, 1])

In [7]:
class PositionalEncoder():
    def __init__(self, embedding_dim, max_seq_length):
        super().__init__()
        self.positional_encoding = torch.zeros((1, max_seq_length, embedding_dim), requires_grad=False)
        
        for pos in range(max_seq_length):
            for i in range(0, embedding_dim, 2):
                self.positional_encoding[0, pos, i] = np.sin(float(pos) / (10000 ** (2 * float(i)/embedding_dim)))
                self.positional_encoding[0, pos, i + 1] = np.cos(float(pos) / (10000 ** (2 * float(i)/embedding_dim)))        
    
    def get(self):
        return self.positional_encoding

## Self attention

In [8]:
# create query, keys, values using the embeddings
class AttentionHead(torch.nn.Module):
    def __init__(self, embedding_dim, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.w_q = torch.nn.Linear(embedding_dim, embedding_dim)
        self.w_k = torch.nn.Linear(embedding_dim, embedding_dim)
        self.w_v = torch.nn.Linear(embedding_dim, embedding_dim)
        
    def compute_attention(self, q, k, v, input, embedding):
        mask = np.zeros(embedding.shape)
        mask[input > 0.] = 1.0
        
        attention = torch.nn.functional.softmax(torch.bmm(q, torch.transpose(k, 2, 1)) * self.scale_factor, dim = -1)
        attention = torch.masked_fill(attention, torch.Tensor(mask) == 0, -1e9)
        attention = torch.bmm(attention, v)
        
        return attention
        
    def forward(self, input, embedding):
        q = self.w_q(embedding)
        k = self.w_k(embedding)
        v = self.w_v(embedding)
        
        return self.compute_attention(q, k, v, input, embedding)

In [9]:
# create multi-head attention
class MultiAttentionHead(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length, num_heads, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor
        self.heads = torch.nn.ModuleList([AttentionHead(embedding_dim, scale_factor) for _ in range(num_heads)])
        self.linear = torch.nn.Linear(num_heads * embedding_dim, embedding_dim)
        
        self.norm_layer = torch.nn.LayerNorm([max_seq_length, embedding_dim])
    
    def forward(self, input, embedding):
        heads_attention = torch.cat([head(input, embedding) for head in self.heads], dim=-1)

        heads_attention = self.linear(heads_attention)
        
        # normalize
        heads_attention = self.norm_layer(heads_attention)
        
        return heads_attention 

## Residual Encoder

In [10]:
# feed forward
class FeedForward(torch.nn.Module):
    def __init__(self, embedding_dim, dim_net=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(embedding_dim, dim_net)
        self.dropout = torch.nn.Dropout(0.1)
        self.linear2 = torch.nn.Linear(dim_net, embedding_dim)
        
    def forward(self, x):
        x = torch.nn.functional.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        
        return x

In [11]:
class EncoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length, num_heads, scale_factor):
        super().__init__()
        self.attention_head = MultiAttentionHead(embedding_dim, max_seq_length, num_heads, scale_factor)
        self.ffw_net = FeedForward(embedding_dim)
        self.encoder_norm = torch.nn.LayerNorm([max_seq_length, embedding_dim])
        
    def forward(self, x):
        x = self.attention_head(input, x)
        x = self.ffw_net(x)
        # + residual
        x += x
        x = self.encoder_norm(x)
        
        return x

In [12]:
class Encoder(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length, num_heads, scale_factor, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.layers = torch.nn.ModuleList([EncoderLayer(embedding_dim, max_seq_length, num_heads, scale_factor) for _ in range(num_layers)])
        self.pe = PositionalEncoder(embedding_dim, max_seq_length)
        
    def forward(self, x):
        x += self.pe.get()
        for i in range(self.num_layers):
            x = self.layers[i](x)
        
        return x

## Decoder

In [13]:
class DecoderLayer(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length, num_heads, scale_factor):
        super().__init__()
        self.init_norm = torch.nn.LayerNorm([max_seq_length, embedding_dim])
        self.attention_head1 = MultiAttentionHead(embedding_dim, max_seq_length, num_heads, scale_factor)
        self.attention_head2 = MultiAttentionHead(embedding_dim, max_seq_length, num_heads, scale_factor)
        self.ffw_net = FeedForward(embedding_dim)
        self.decoder_norm1 = torch.nn.LayerNorm([max_seq_length, embedding_dim])
        self.decoder_norm2 = torch.nn.LayerNorm([max_seq_length, embedding_dim])
        
        
    def forward(self, x):
        x = self.init_norm(x)
        
        x = self.attention_head1(input, x)
        x = self.ffw_net(x)
        # + residual
        x += x
        x = self.decoder_norm1(x)
        x = self.attention_head2(input, x)
        x = self.ffw_net(x)
        x += x
        x = self.decoder_norm2(x)
        
        return x

In [14]:
class Decoder(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length, num_heads, scale_factor, num_layers):
        super().__init__()
        
        self.num_layers = num_layers
        self.layers = torch.nn.ModuleList([EncoderLayer(embedding_dim, max_seq_length, num_heads, scale_factor) for _ in range(num_layers)])
        self.pe = PositionalEncoder(embedding_dim, max_seq_length)
        
    def forward(self, x):
        for i in range(self.num_layers):
            x = self.layers[i](x)
        x += self.pe.get()
        
        return x

In [15]:
class Transformer(torch.nn.Module):
    def __init__(self, embedding_dim, max_seq_length, 
                 num_heads, scale_factor, 
                 num_layers_encoder=2, num_layers_decoder=2,
                 target_dim=2):
        super().__init__()
        self.encoder = Encoder(embedding_dim, max_seq_length, num_heads, scale_factor, num_layers_encoder)
        self.decoder = Decoder(embedding_dim, max_seq_length, num_heads, scale_factor, num_layers_decoder)
        self.out = torch.nn.Linear(embedding_dim * max_seq_length, target_dim)
        
    def forward(self, x):
        encoder_x = self.encoder(x)
        decoder_x = self.decoder(encoder_x)
        decoder_x = torch.reshape(decoder_x, (decoder_x.shape[0], -1))
        out = self.out(decoder_x)
        
        return out

## Training

In [16]:
model = Transformer(embedding_dim, max_seq_length, num_heads, scale_factor)

In [17]:
optim = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
loss = torch.nn.CrossEntropyLoss()

In [20]:
model.train()
for epoch in range(1000):
    output = model(embedding)

    optim.zero_grad()

    loss_value = loss(output, target)
    print(loss_value)
    
    loss_value.backward(retain_graph=True)
    
    optim.step()

tensor(0.3799, grad_fn=<NllLossBackward>)
tensor(0.3296, grad_fn=<NllLossBackward>)
tensor(0.3300, grad_fn=<NllLossBackward>)
tensor(0.2916, grad_fn=<NllLossBackward>)
tensor(0.3360, grad_fn=<NllLossBackward>)
tensor(0.5611, grad_fn=<NllLossBackward>)
tensor(0.4924, grad_fn=<NllLossBackward>)
tensor(0.5628, grad_fn=<NllLossBackward>)
tensor(0.2354, grad_fn=<NllLossBackward>)
tensor(0.3193, grad_fn=<NllLossBackward>)
tensor(0.3984, grad_fn=<NllLossBackward>)
tensor(0.4316, grad_fn=<NllLossBackward>)
tensor(0.3363, grad_fn=<NllLossBackward>)
tensor(0.2720, grad_fn=<NllLossBackward>)
tensor(0.3062, grad_fn=<NllLossBackward>)
tensor(0.2967, grad_fn=<NllLossBackward>)
tensor(0.2839, grad_fn=<NllLossBackward>)
tensor(0.2568, grad_fn=<NllLossBackward>)
tensor(0.4196, grad_fn=<NllLossBackward>)
tensor(0.3694, grad_fn=<NllLossBackward>)
tensor(0.6779, grad_fn=<NllLossBackward>)
tensor(0.2094, grad_fn=<NllLossBackward>)
tensor(0.2244, grad_fn=<NllLossBackward>)
tensor(0.3022, grad_fn=<NllLossBac

tensor(0.2023, grad_fn=<NllLossBackward>)
tensor(0.1662, grad_fn=<NllLossBackward>)
tensor(0.1532, grad_fn=<NllLossBackward>)
tensor(0.1487, grad_fn=<NllLossBackward>)
tensor(0.1694, grad_fn=<NllLossBackward>)
tensor(0.1331, grad_fn=<NllLossBackward>)
tensor(0.2109, grad_fn=<NllLossBackward>)
tensor(0.1573, grad_fn=<NllLossBackward>)
tensor(0.1902, grad_fn=<NllLossBackward>)
tensor(0.1589, grad_fn=<NllLossBackward>)
tensor(0.3417, grad_fn=<NllLossBackward>)
tensor(0.1475, grad_fn=<NllLossBackward>)
tensor(0.1992, grad_fn=<NllLossBackward>)
tensor(0.1469, grad_fn=<NllLossBackward>)
tensor(0.1714, grad_fn=<NllLossBackward>)
tensor(0.3623, grad_fn=<NllLossBackward>)
tensor(0.1416, grad_fn=<NllLossBackward>)
tensor(0.1792, grad_fn=<NllLossBackward>)
tensor(0.1359, grad_fn=<NllLossBackward>)
tensor(0.1488, grad_fn=<NllLossBackward>)
tensor(0.1695, grad_fn=<NllLossBackward>)
tensor(0.2019, grad_fn=<NllLossBackward>)
tensor(0.1322, grad_fn=<NllLossBackward>)
tensor(0.1596, grad_fn=<NllLossBac

tensor(0.1092, grad_fn=<NllLossBackward>)
tensor(0.1360, grad_fn=<NllLossBackward>)
tensor(0.1010, grad_fn=<NllLossBackward>)
tensor(0.1310, grad_fn=<NllLossBackward>)
tensor(0.0931, grad_fn=<NllLossBackward>)
tensor(0.1079, grad_fn=<NllLossBackward>)
tensor(0.1323, grad_fn=<NllLossBackward>)
tensor(0.1185, grad_fn=<NllLossBackward>)
tensor(0.0971, grad_fn=<NllLossBackward>)
tensor(0.1024, grad_fn=<NllLossBackward>)
tensor(0.1248, grad_fn=<NllLossBackward>)
tensor(0.1166, grad_fn=<NllLossBackward>)
tensor(0.1019, grad_fn=<NllLossBackward>)
tensor(0.0961, grad_fn=<NllLossBackward>)
tensor(0.1114, grad_fn=<NllLossBackward>)
tensor(0.0845, grad_fn=<NllLossBackward>)
tensor(0.0935, grad_fn=<NllLossBackward>)
tensor(0.1042, grad_fn=<NllLossBackward>)
tensor(0.2602, grad_fn=<NllLossBackward>)
tensor(0.0960, grad_fn=<NllLossBackward>)
tensor(0.1091, grad_fn=<NllLossBackward>)
tensor(0.1031, grad_fn=<NllLossBackward>)
tensor(0.1265, grad_fn=<NllLossBackward>)
tensor(0.1229, grad_fn=<NllLossBac

tensor(0.0791, grad_fn=<NllLossBackward>)
tensor(0.0759, grad_fn=<NllLossBackward>)
tensor(0.0594, grad_fn=<NllLossBackward>)
tensor(0.0883, grad_fn=<NllLossBackward>)
tensor(0.0981, grad_fn=<NllLossBackward>)
tensor(0.0709, grad_fn=<NllLossBackward>)
tensor(0.0725, grad_fn=<NllLossBackward>)
tensor(0.0823, grad_fn=<NllLossBackward>)
tensor(0.0819, grad_fn=<NllLossBackward>)
tensor(0.0520, grad_fn=<NllLossBackward>)
tensor(0.0720, grad_fn=<NllLossBackward>)
tensor(0.0973, grad_fn=<NllLossBackward>)
tensor(0.0620, grad_fn=<NllLossBackward>)
tensor(0.0631, grad_fn=<NllLossBackward>)
tensor(0.0653, grad_fn=<NllLossBackward>)
tensor(0.0616, grad_fn=<NllLossBackward>)
tensor(0.0720, grad_fn=<NllLossBackward>)
tensor(0.0672, grad_fn=<NllLossBackward>)
tensor(0.0793, grad_fn=<NllLossBackward>)
tensor(0.0717, grad_fn=<NllLossBackward>)
tensor(0.0785, grad_fn=<NllLossBackward>)
tensor(0.0704, grad_fn=<NllLossBackward>)
tensor(0.0855, grad_fn=<NllLossBackward>)
tensor(0.0734, grad_fn=<NllLossBac

tensor(0.0546, grad_fn=<NllLossBackward>)
tensor(0.0517, grad_fn=<NllLossBackward>)
tensor(0.0533, grad_fn=<NllLossBackward>)
tensor(0.0492, grad_fn=<NllLossBackward>)
tensor(0.0452, grad_fn=<NllLossBackward>)
tensor(0.0373, grad_fn=<NllLossBackward>)
tensor(0.0545, grad_fn=<NllLossBackward>)
tensor(0.0585, grad_fn=<NllLossBackward>)
tensor(0.0583, grad_fn=<NllLossBackward>)
tensor(0.0599, grad_fn=<NllLossBackward>)
tensor(0.0667, grad_fn=<NllLossBackward>)
tensor(0.0731, grad_fn=<NllLossBackward>)
tensor(0.0564, grad_fn=<NllLossBackward>)
tensor(0.0423, grad_fn=<NllLossBackward>)
tensor(0.0667, grad_fn=<NllLossBackward>)
tensor(0.0481, grad_fn=<NllLossBackward>)
tensor(0.0475, grad_fn=<NllLossBackward>)
tensor(0.0478, grad_fn=<NllLossBackward>)
tensor(0.0433, grad_fn=<NllLossBackward>)
tensor(0.0437, grad_fn=<NllLossBackward>)
tensor(0.0507, grad_fn=<NllLossBackward>)
tensor(0.0444, grad_fn=<NllLossBackward>)
tensor(0.0493, grad_fn=<NllLossBackward>)
tensor(0.0487, grad_fn=<NllLossBac

tensor(0.0336, grad_fn=<NllLossBackward>)
