In [1]:
import torch
import torch.nn as nn

In [2]:
# My version using pytorch for help
class FeedFowardLayer(nn.Module):
    def __init__(self, d_model: int, hidden: int, act_fn: nn.Module = nn.ReLU):
        super().__init__()
        self.linear = nn.Sequential(
            nn.Linear(d_model, hidden),
            act_fn(),
            nn.Linear(hidden, d_model))
        
    def forward(self, x): return self.linear(x)

class EncoderBlock(nn.Module):
    def __init__(self, seq_len, embed_dim, ff_hidden, num_heads, batch_first, act_fn):
        super().__init__()
        self.seq_len, self.embed_dim, self.num_heads = seq_len, embed_dim, num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attention_layer = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=batch_first)
        self.layer_norm1 = nn.LayerNorm([seq_len, embed_dim])
        self.feed_forward_layer = FeedFowardLayer(d_model=embed_dim, hidden=ff_hidden, act_fn=act_fn)
        self.layer_norm2 = nn.LayerNorm([seq_len, embed_dim])

    def attention_block(self, x, q, k, v):
        output, _ = self.attention_layer(query=q, key=k, value=v) # we dont use mask in encoder
        return self.layer_norm1(output + x)
    
    def feed_forward_block(self, attn_out):
        output = self.feed_forward_layer(attn_out)
        return self.layer_norm2(output + attn_out)
   
    def forward(self, x):
        # project our input into qkv linear
        qkv = self.qkv(x)

        # split into sperate matrices
        q,k,v = qkv.chunk(3, dim=-1)

        # feed into our attention block
        out = self.attention_block(x, q, k, v)

        # return output from feed_forward_block
        return self.feed_forward_block(out)

class Encoder(nn.Module):
    def __init__(self, seq_len: int, embed_dim: int = 512, ff_hidden: int = 2048, num_heads: int = 8, n_blocks: int = 12, batch_first: bool = True, linear_act_fn: nn.Module = nn.ReLU):
        super().__init__()
        self.blocks = nn.Sequential(*[EncoderBlock(seq_len=seq_len, embed_dim=embed_dim, ff_hidden=ff_hidden, num_heads=num_heads, batch_first=batch_first, act_fn=linear_act_fn) for i in range(n_blocks)])

    def forward(self, x): return self.blocks(x)

In [7]:
# our data:
batch_size = 64
seq_len = 4000 # 4k context
embedding_size = 512
x = torch.Tensor(batch_size, seq_len, embedding_size) 

# construct our encoder
encoder = Encoder(
    seq_len=seq_len, 
    embed_dim=embedding_size,
)

In [8]:
# how many parameters are in the model
print(f"Number of parameters in model: {sum(p.numel() for p in encoder.parameters())}")

Number of parameters in model: 833691648


In [None]:
# does it work? 
out = encoder(x)
out.shape