In [1]:
import torch
from typing import Tuple

In [154]:
class Attention3D(torch.nn.Module):
    """Scaled dot-product attention over ``(context_length, batch_size, n_features)`` tensors.

    Queries are always projected from ``input``. Keys and values are projected from
    ``input`` as well (self-attention) unless ``K_input`` / ``V_input`` are given. In that
    case those tensors are used verbatim; they do not pass through ``self.key`` /
    ``self.value``.

    Args:
        n_features: size of the last (feature) axis. The Q/K/V projections map
            ``n_features -> n_features``.
        masked: if True, apply a causal mask so query position ``t`` only attends to key
            positions ``<= t``.
    """

    def __init__(self, n_features: int, masked: bool = False) -> None:
        super(Attention3D, self).__init__()

        self.query, self.key, self.value = [
            torch.nn.Linear(n_features, n_features)
            for _ in range(3)
        ]

        self.masked = masked
        # Normalise over the key axis, which is the last axis of the
        # (batch, query, key) score tensor built in forward().
        self.softmax = torch.nn.Softmax(dim = -1)

    def forward(
        self,
        input: torch.Tensor,
        K_input: torch.Tensor = None,
        V_input: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Attend from ``input`` to the keys/values.

        Args:
            input: ``(context_length, batch_size, n_features)`` tensor the queries come from.
            K_input: optional ``(K_context, batch_size, n_features)`` keys, used verbatim.
            V_input: optional ``(V_context, batch_size, n_features)`` values, used verbatim.
                The attention product requires ``V_context == K_context``.

        Returns:
            ``(output, (K, V))``. ``output`` has the same shape as ``input``. ``K`` and
            ``V`` are the keys and values that were attended over.
        """
        _, _, n_features = input.shape

        Q = self.query(input)
        K = self.key(input) if K_input is None else K_input
        V = self.value(input) if V_input is None else V_input

        # Process every batch element at once: (seq, batch, feat) -> (batch, seq, feat).
        # Nothing is detached, so gradients flow back to `input` and to K_input/V_input.
        q, k, v = (t.transpose(0, 1) for t in (Q, K, V))
        scores = q @ k.transpose(1, 2) / n_features ** 0.5   # (batch, context_length, K_context)

        if self.masked:
            # Causal mask: query t must not see keys at positions > t.
            future = torch.ones(
                scores.shape[-2:], dtype=torch.bool, device=scores.device
            ).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        attention = self.softmax(scores) @ v                  # (batch, context_length, n_features)
        output = attention.transpose(0, 1)                    # (context_length, batch, n_features)

        return output, (K, V)

In [155]:
class Encoder(torch.nn.Module):
    """One post-norm encoder layer.

    The layer applies self-attention, then a feed-forward block. Each sub-layer is wrapped
    in a residual connection followed by LayerNorm.

    Args:
        batch_sample: tensor shaped ``(context_length, batch_size, n_features)``. Only
            ``n_features`` is used, to size the sub-layers.
        final_module: if True, ``forward`` returns the ``(K, V)`` tuple produced by the
            self-attention sub-layer instead of the encoded tensor.
    """

    def __init__(self, batch_sample: torch.Tensor, final_module: bool = False) -> None:
        super(Encoder, self).__init__()

        _, _, n_features = batch_sample.shape
        # NOTE(review): one LayerNorm (a single set of affine parameters) is shared by both
        # residual sub-layers; the reference architecture uses a separate one per sub-layer.
        self.layer_norm = torch.nn.LayerNorm(normalized_shape = n_features)

        self.attention_layer = Attention3D(n_features, masked = False)
        # Position-wise feed-forward block. The ReLU makes it non-linear; without it the
        # two Linear layers compose into a single affine map.
        self.feed_forward_layer = torch.nn.Sequential(
            torch.nn.Linear(n_features, n_features * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(n_features * 2, n_features),
        )

        self.final_module = final_module

    def forward(self, input: torch.Tensor) -> "torch.Tensor | Tuple[torch.Tensor, torch.Tensor]":
        """Encode ``input``, or return its attention ``(K, V)`` if this is the final module.

        Args:
            input: ``(context_length, batch_size, n_features)`` tensor.

        Returns:
            A tensor shaped like ``input``, or the ``(K, V)`` tuple when ``final_module``.
        """
        attention_output, KV_tuple = self.attention_layer(input)
        if self.final_module:
            # NOTE(review): these K/V are projections of this layer's *input*. The final
            # layer's attention and feed-forward outputs are never used, so they are not
            # computed here.
            return KV_tuple

        norm_attention_output = self.layer_norm(input + attention_output)

        feed_forward_output = self.feed_forward_layer(norm_attention_output)
        return self.layer_norm(norm_attention_output + feed_forward_output)

In [156]:
class Decoder(torch.nn.Module):
    """One post-norm decoder layer.

    The layer applies causal self-attention, cross-attention to supplied keys/values, and
    a feed-forward block. Each is followed by a residual connection and LayerNorm.

    Args:
        batch_sample: tensor shaped ``(context_length, batch_size, n_features)``. Only
            ``n_features`` is used, to size the sub-layers.
    """

    def __init__(self, batch_sample: torch.Tensor) -> None:

        super(Decoder, self).__init__()

        _, _, n_features = batch_sample.shape
        # NOTE(review): one LayerNorm is shared by all three residual sub-layers; the
        # reference architecture uses a separate one per sub-layer.
        self.layer_norm = torch.nn.LayerNorm(normalized_shape = n_features)

        self.attention_layer = Attention3D(n_features, masked = True)
        self.cross_attention_layer = Attention3D(n_features, masked = False)

        # Position-wise feed-forward block. The ReLU makes it non-linear; without it the
        # two Linear layers compose into a single affine map.
        self.feed_forward_layer = torch.nn.Sequential(
            torch.nn.Linear(n_features, n_features * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(n_features * 2, n_features),
        )

    def forward(self, input: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
        """Decode ``input`` while cross-attending to ``K`` / ``V``.

        Args:
            input: ``(context_length, batch_size, n_features)`` decoder-side tensor.
            K, V: keys and values handed verbatim to the cross-attention sub-layer.

        Returns:
            A tensor shaped like ``input``.
        """
        masked_attention_output, _ = self.attention_layer(input)
        norm_attention_output = self.layer_norm(input + masked_attention_output)

        cross_attention_output, _ = self.cross_attention_layer(norm_attention_output, K, V)
        norm_cross_attention_output = self.layer_norm(norm_attention_output + cross_attention_output)

        feed_forward_output = self.feed_forward_layer(norm_cross_attention_output)
        return self.layer_norm(norm_cross_attention_output + feed_forward_output)

In [None]:
class DecoderStack(torch.nn.Sequential):
    """A ``torch.nn.Sequential`` that forwards extra call arguments to every child.

    Each child is invoked as ``child(x, *args, **kwargs)``. Its return value becomes ``x``
    for the next child, so every layer sees the same extra arguments. An empty stack
    returns its input unchanged.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, input, *args, **kwargs):
        hidden = input
        for layer in self:
            hidden = layer(hidden, *args, **kwargs)
        return hidden

In [157]:
class Transformer(torch.nn.Module):
    """Encoder-decoder transformer that outputs per-position probabilities over a vocabulary.

    Args:
        batch_sample: tensor shaped ``(context_length, batch_size, n_features)``. Only
            ``n_features`` is used, to size the layers.
        stack_size: number of encoder layers and of decoder layers. Must be >= 1.
        vocab_size: size of the output distribution.

    Raises:
        ValueError: if ``stack_size < 1``.
    """

    def __init__(self, batch_sample: torch.Tensor, stack_size: int, vocab_size: int) -> None:
        super(Transformer, self).__init__()
        if stack_size < 1:
            # The final encoder must exist: it supplies (K, V) to the decoders.
            raise ValueError(f"stack_size must be >= 1, got {stack_size}")

        _, _, n_features = batch_sample.shape
        self.encoders, self.decoders = torch.nn.Sequential(), DecoderStack()

        for i in range(stack_size):
            # Only the last encoder returns its (K, V) tuple instead of a tensor.
            is_final_encoder = i == stack_size - 1
            self.encoders.append(Encoder(batch_sample, is_final_encoder))
            self.decoders.append(Decoder(batch_sample))

        self.linear = torch.nn.Linear(n_features, vocab_size)
        # Normalise over the vocabulary (last) axis. dim=1 would normalise over the batch
        # axis of the (context_length, batch_size, vocab_size) logits.
        self.softmax = torch.nn.Softmax(dim = -1)

    def forward(self, encoder_input: torch.Tensor, decoder_input: torch.Tensor) -> torch.Tensor:
        """Return probabilities shaped like ``decoder_input`` but with last dim ``vocab_size``.

        Args:
            encoder_input: ``(encoder_length, batch_size, n_features)`` source tensor.
            decoder_input: ``(decoder_length, batch_size, n_features)`` target tensor.
        """
        # The encoder stack ends with the final encoder, whose output is its (K, V) tuple.
        K, V = self.encoders(encoder_input)

        # Every decoder layer cross-attends to the same encoder K/V.
        decoder_output = self.decoders(decoder_input, K, V)
        return self.softmax(self.linear(decoder_output))

In [159]:
# Seed torch's global RNG so the sample tensors and model weights created below are
# reproducible under Restart & Run All.
torch.manual_seed(0)
# Batch-shape template (context_length, batch_size, n_features); the model reads n_features from it.
bs = torch.rand(25, 1, 8)

In [160]:
# 3 encoder + 3 decoder layers over a 2048-token output vocabulary.
tr = Transformer(batch_sample=bs, stack_size=3, vocab_size=2048)

In [163]:
# Random source (encoder) and target (decoder) inputs. The sequence lengths may differ;
# the last dim must equal the model's n_features.
et = torch.rand(25, 1, 8)
dt = torch.rand(30, 1, 8)

# Call the module itself rather than .forward() so any registered hooks run.
r = tr(et, dt)

In [164]:
r.size()

torch.Size([30, 1, 2048])