# Encoder
## Easy-to-understand implementation of the encoder architecture

In [87]:
import torch

In [None]:
class PositionalEncoder(torch.nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings.

    For position ``j`` and feature index ``i`` the table holds
    ``sin(j * w_k)`` when ``i`` is even and ``cos(j * w_k)`` when ``i`` is
    odd, with ``k = i // 2`` and ``w_k = 1 / 10000 ** (2 * k / d_model)``.

    The table is stored as a non-persistent buffer named ``pe``: it follows
    the module through ``.to(...)`` / ``.cuda()`` and is replicated by
    ``DataParallel``, but it is not written to ``state_dict``.

    Args:
        max_seq_len: number of positions (rows) in the table.
        d_model: embedding size (columns) of the table.
    """

    def __init__(self, max_seq_len, d_model):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model

        # Build the whole table with tensor ops instead of a Python double loop.
        # float64 is used for the intermediate angles and cast down at the end.
        positions = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)
        pair_index = (torch.arange(d_model) // 2).to(torch.float64)
        frequencies = 1.0 / (10000.0 ** (2.0 * pair_index / d_model))
        angles = positions * frequencies  # (max_seq_len, d_model)

        table = torch.zeros(max_seq_len, d_model, dtype=torch.float64)
        table[:, 0::2] = torch.sin(angles[:, 0::2])
        table[:, 1::2] = torch.cos(angles[:, 1::2])

        # persistent=False keeps state_dict keys identical to the version that
        # stored ``pe`` as a plain attribute.
        self.register_buffer(
            "pe", table.to(torch.get_default_dtype()), persistent=False
        )

    def show_pe(self):
        """Print the positional-encoding table."""
        print(self.pe)

    def forward(self, x):
        """Return ``x`` with the first ``x.shape[1]`` encoding rows added.

        A 2-D input is given a leading batch dimension of size 1.

        Raises:
            Exception: if ``x.shape[1]`` exceeds ``max_seq_len`` or
                ``x.shape[2]`` differs from ``d_model``.
        """
        if len(x.shape) == 2:
            x = x.unsqueeze(0)

        if x.shape[1] > self.max_seq_len:
            raise Exception("Number of tokens exceeds max_seq_len")

        if x.shape[2] != self.d_model:
            raise Exception("Token dimension do not match model dimension")

        # Fallback for callers that never moved the module itself.
        if x.device != self.pe.device:
            self.pe = self.pe.to(x.device)

        return x + self.pe[:x.shape[1], :]

In [None]:
# Build a tiny table so it can be compared by eye with the reference values below.
pe = PositionalEncoder(d_model=5, max_seq_len=10)
print(pe.pe)

## Reference positional encodings from the paper (expected output of the cell above)
Positional Encodings (max_seq_len=10, dim_model=5):<br>
[[ 0.00000000e+00  1.00000000e+00  0.00000000e+00  1.00000000e+00 0.00000000e+00]<br>
 [ 8.41470985e-01  5.40302306e-01  2.51162229e-02  9.99684538e-01 6.30957303e-04]<br>
 [ 9.09297427e-01 -4.16146837e-01  5.02165994e-02  9.98738351e-01 1.26191435e-03]<br>
 [ 1.41120008e-01 -9.89992497e-01  7.52852930e-02  9.97162035e-01 1.89287090e-03]<br>
 [-7.56802495e-01 -6.53643621e-01  1.00306487e-01  9.94956586e-01 2.52382670e-03]<br>
 [-9.58924275e-01  2.83662185e-01  1.25264396e-01  9.92123395e-01 3.15478149e-03]<br>
 [-2.79415498e-01  9.60170287e-01  1.50143272e-01  9.88664249e-01 3.78573502e-03]<br>
 [ 6.56986599e-01  7.53902254e-01  1.74927419e-01  9.84581331e-01 4.41668705e-03]<br>
 [ 9.89358247e-01 -1.45500034e-01  1.99601200e-01  9.79877217e-01 5.04763732e-03]<br>
 [ 4.12118485e-01 -9.11130262e-01  2.24149048e-01  9.74554875e-01 5.67858558e-03]]<br>

In [89]:
class FeedForwardNN(torch.nn.Module):
    """Two-layer position-wise network: Linear -> GELU -> Linear.

    Args:
        d_model: size of the input and output feature dimension.
        d_ff: size of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.d_model, self.d_ff = d_model, d_ff

        self.ff1 = torch.nn.Linear(d_model, d_ff)
        self.act = torch.nn.GELU()
        self.ff2 = torch.nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the network to ``x``.

        A 2-D input gets a leading batch dimension (a warning is printed).

        Raises:
            Exception: if ``x.shape[2]`` differs from ``d_model``.
        """
        missing_batch = x.ndim == 2
        if missing_batch:
            x = x.unsqueeze(0)
            print("Warning: batch size not present")

        if x.size(2) != self.d_model:
            raise Exception("Token dimension do not match model dimension")

        hidden = self.act(self.ff1(x))
        return self.ff2(hidden)

In [90]:
class LayerNorm(torch.nn.Module):
    """Thin wrapper around ``torch.nn.LayerNorm`` over the last ``d_model`` features."""

    def __init__(self, d_model):
        super().__init__()
        self.layerNorm = torch.nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, x):
        """Return the layer-normalised version of ``x``."""
        return self.layerNorm(x)

In [None]:
class SelfAttentionHead(torch.nn.Module):
    """Single scaled dot-product attention head with its own output projection.

    Inputs are projected from ``d_model`` to ``d_k`` features, attention is
    computed, and the result is projected back to ``d_model`` by ``W_O``.

    Args:
        d_model: feature size of the incoming Q, K and V tensors.
        d_k: feature size of the projected queries, keys and values.
        max_seq_len: maximum number of tokens accepted along dimension 1.
    """

    def __init__(self, d_model, d_k, max_seq_len):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        # Not read by this class; kept so existing attribute access still works.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.W_q = torch.nn.Linear(in_features=d_model, out_features=d_k)
        self.W_k = torch.nn.Linear(in_features=d_model, out_features=d_k)
        self.W_v = torch.nn.Linear(in_features=d_model, out_features=d_k)

        self.W_O = torch.nn.Linear(in_features=d_k, out_features=d_model)

    def forward(self, Q, K, V, mask=False):
        """Compute attention of ``Q`` over ``K``/``V``.

        Args:
            Q, K, V: tensors whose dimension 1 is the token axis and whose
                dimension 2 has size ``d_model``. A 2-D tensor gets a leading
                batch dimension (a warning is printed).
            mask: when ``True``, apply a lower-triangular (causal) mask so
                query ``i`` only attends to keys ``0..i``.

        Returns:
            Tensor with the token count of ``Q`` and ``d_model`` features.

        Raises:
            Exception: if any input exceeds ``max_seq_len`` tokens or does
                not have ``d_model`` features.
        """
        # ``unsqueeze`` is not in-place: the result has to be re-bound.
        if len(Q.shape) == 2:
            Q = Q.unsqueeze(0)
            print("Warning: batch size not present")

        if len(K.shape) == 2:
            K = K.unsqueeze(0)
            print("Warning: batch size not present")

        if len(V.shape) == 2:
            V = V.unsqueeze(0)
            print("Warning: batch size not present")

        if Q.shape[1] > self.max_seq_len or V.shape[1] > self.max_seq_len or K.shape[1] > self.max_seq_len:
            raise Exception("Number of tokens exceed max sequence length")

        if Q.shape[2] != self.d_model or K.shape[2] != self.d_model or V.shape[2] != self.d_model:
            raise Exception("Tokens dimension do not match model dimension")

        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        # (batch, n_queries, n_keys)
        logits = torch.matmul(Q, torch.transpose(K, 1, 2)) / (self.d_k ** 0.5)
        if mask is True:
            # Sized (n_queries, n_keys) so it also fits when Q and K differ in
            # length; it broadcasts over the batch dimension.
            allowed = torch.tril(
                torch.ones(Q.shape[1], K.shape[1], dtype=torch.bool, device=Q.device)
            )
            logits = logits.masked_fill(~allowed, -1e9)

        # Normalise over the key axis so each query's weights sum to 1.
        scores = torch.nn.functional.softmax(logits, dim=-1)
        attention = torch.matmul(scores, V)

        return self.W_O(attention)

In [None]:
class MultiHeadAttention(torch.nn.Module):
    """Run several attention heads in parallel and merge their outputs.

    Each head returns ``d_model`` features; the head outputs are concatenated
    on the last dimension and projected back to ``d_model`` by ``W_O``.

    Args:
        num_heads: number of ``SelfAttentionHead`` instances.
        d_model: feature size of the inputs and of the output.
        d_k: per-head projection size passed to every head.
        max_seq_len: maximum number of tokens accepted along dimension 1.
    """

    def __init__(self, num_heads, d_model, d_k, max_seq_len):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_k
        self.max_seq_len = max_seq_len

        self.heads = torch.nn.ModuleList([
            SelfAttentionHead(d_model=self.d_model, d_k=self.d_k, max_seq_len=self.max_seq_len)
            for _ in range(num_heads)
        ])

        self.W_O = torch.nn.Linear(num_heads * self.d_model, self.d_model)

    def forward(self, Q, K, V, mask=False):
        """Apply every head to ``(Q, K, V, mask)`` and project the concatenation.

        A 2-D input gets a leading batch dimension (a warning is printed).

        Raises:
            Exception: if any input exceeds ``max_seq_len`` tokens or does
                not have ``d_model`` features.
        """
        # ``unsqueeze`` returns a new tensor; re-bind it so the checks below
        # and the heads see the batched shape.
        if len(Q.shape) == 2:
            Q = Q.unsqueeze(0)
            print("Warning: batch size not present")

        if len(K.shape) == 2:
            K = K.unsqueeze(0)
            print("Warning: batch size not present")

        if len(V.shape) == 2:
            V = V.unsqueeze(0)
            print("Warning: batch size not present")

        if Q.shape[1] > self.max_seq_len or V.shape[1] > self.max_seq_len or K.shape[1] > self.max_seq_len:
            raise Exception("Number of tokens exceed max sequence length")

        if Q.shape[2] != self.d_model or K.shape[2] != self.d_model or V.shape[2] != self.d_model:
            raise Exception("Tokens dimension do not match model dimension")

        head_outputs = [head(Q, K, V, mask) for head in self.heads]
        concatenated = torch.cat(head_outputs, dim=-1)
        return self.W_O(concatenated)

In [92]:
class EncoderBlock(torch.nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each followed
    by a residual connection and layer normalisation.

    The same ``layerNorm`` module is applied after both sub-layers.

    Args:
        num_heads: number of attention heads.
        d_model: feature size of the tokens.
        d_k: per-head projection size.
        d_ff: hidden size of the feed-forward network.
        max_seq_len: maximum number of tokens handled by the attention module.
    """

    def __init__(self, num_heads, d_model, d_k, d_ff, max_seq_len):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len

        self.multiHeadAttention = MultiHeadAttention(
            num_heads=num_heads,
            d_model=d_model,
            d_k=d_k,
            max_seq_len=max_seq_len,
        )
        self.layerNorm = LayerNorm(d_model=d_model)
        self.ffnn = FeedForwardNN(d_model=d_model, d_ff=d_ff)

    def forward(self, x):
        """Run the block on ``x``.

        A 2-D input gets a leading batch dimension (a warning is printed).

        Raises:
            Exception: if ``x.shape[2]`` differs from ``d_model``.
        """
        if x.ndim == 2:
            x = x.unsqueeze(0)
            print("Warning: batch size not present")

        if x.size(2) != self.d_model:
            raise Exception("Token dimension do not match model dimension")

        # Unmasked self-attention, then residual + norm.
        attended = self.multiHeadAttention(x, x, x, False)
        x = self.layerNorm(x + attended)

        # Feed-forward, then residual + norm.
        transformed = self.ffnn(x)
        return self.layerNorm(x + transformed)

In [93]:
class EncoderWrapper(torch.nn.Module):
    """Full encoder: optional embedding, positional encoding, then a stack of
    ``EncoderBlock`` layers.

    Args:
        num_heads: number of attention heads per block.
        d_model: feature size of the tokens.
        d_k: per-head projection size.
        d_ff: hidden size of each block's feed-forward network.
        max_seq_len: maximum number of tokens per sequence.
        num_blocks: number of stacked encoder blocks.
        word_embedder: optional callable applied to the raw input before any
            shape validation; when ``None`` the input is used as-is.
    """

    def __init__(self, num_heads, d_model, d_k, d_ff, max_seq_len, num_blocks, word_embedder=None):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        self.num_blocks = num_blocks
        self.word_embedder = word_embedder

        self.pe = PositionalEncoder(max_seq_len=self.max_seq_len, d_model=self.d_model)
        self.encoders = torch.nn.ModuleList([
            EncoderBlock(
                d_model=self.d_model,
                d_k=self.d_k,
                max_seq_len=self.max_seq_len,
                d_ff=self.d_ff,
                num_heads=self.num_heads,
            )
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Encode ``x``.

        If a ``word_embedder`` was supplied it is applied first. A 2-D
        result then gets a leading batch dimension (a warning is printed).

        Raises:
            Exception: if the sequence is longer than ``max_seq_len`` or the
                feature size differs from ``d_model``.
        """
        # Embed before validating. The checks below describe embedded tokens;
        # running them on raw (batch, seq) ids would treat the ids as an
        # unbatched sequence and reject them after embedding.
        if self.word_embedder is not None:
            x = self.word_embedder(x)

        if len(x.shape) == 2:
            x = x.unsqueeze(0)
            print("Warning: batch size not present")

        if x.shape[1] > self.max_seq_len:
            print(x.shape)
            raise Exception("Number of tokens exceeds max sequence length")

        if x.shape[2] != self.d_model:
            raise Exception("Token dimension do not match model dimension")

        x = self.pe(x)

        for block in self.encoders:
            x = block(x)

        return x


## From here on, tests

In [94]:
# Hyper-parameters shared by the test cells below.
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_HEADS = 3       # attention heads per encoder block
MAX_SEQ_LEN = 500   # maximum number of tokens per sequence
D_MODEL = 512       # token feature size
D_K = 128           # per-head projection size
D_FF = 256          # hidden size of the feed-forward network
NUM_BLOCKS = 2      # number of stacked encoder blocks
BATCH_SIZE = 16     # not referenced by the cells visible here

In [95]:
# Shape test: the encoder must return a tensor of the same shape as its input.
seq_len = 10
d_model = D_MODEL  # keep the test input in sync with the model's feature size

model = EncoderWrapper(num_heads=NUM_HEADS, num_blocks=NUM_BLOCKS, d_model=D_MODEL, d_k=D_K, d_ff=D_FF, max_seq_len=MAX_SEQ_LEN).to(DEVICE)
# Wrap in DataParallel only when both GPUs 0 and 1 exist; requesting
# device_ids=[0, 1] on a single-GPU host raises.
if torch.cuda.device_count() >= 2:
    model = torch.nn.DataParallel(model, device_ids=[0, 1])

input_tensor = torch.randn(1, seq_len, d_model).to(DEVICE)
output = model(input_tensor)

assert output.shape == (1, seq_len, d_model), f"Unexpected output shape: {output.shape}"
print("Shape test passed!")

Shape test passed!


In [96]:
# Overfit the encoder on one random tensor (identity reconstruction) to check
# that gradients flow and the loss decreases.
input_tensor = torch.randn(1, seq_len, d_model).to(DEVICE)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
    optimizer.zero_grad()
    output = model(input_tensor)
    loss = loss_fn(output, input_tensor)
    loss.backward()
    optimizer.step()

    # Report every tenth epoch.
    if not epoch % 10:
        print(f"Epoch {epoch}: Loss = {loss.item()}")

Epoch 0: Loss = 0.31085100769996643


Epoch 10: Loss = 0.012060209177434444
Epoch 20: Loss = 0.00608134875074029
Epoch 30: Loss = 0.0060678087174892426
Epoch 40: Loss = 0.005880665499716997
Epoch 50: Loss = 0.003627615747973323
Epoch 60: Loss = 0.002997001400217414
Epoch 70: Loss = 0.002640163293108344
Epoch 80: Loss = 0.0024372104089707136
Epoch 90: Loss = 0.0022794497199356556


In [97]:
# Largest absolute element-wise difference between the input and its reconstruction.
print((input_tensor - model(input_tensor)).abs().max())

tensor(0.1905, device='cuda:0', grad_fn=<MaxBackward1>)


In [None]:
#TODO: need more tests