In [2]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 
import torch.utils.data as data

In [8]:
class MiniTransformer(nn.Module):
    """A single post-norm Transformer encoder block.

    The block applies multi-head self-attention followed by a two-layer
    feed-forward network with a ReLU in between. Each sub-layer is wrapped in
    a residual connection and followed by ``LayerNorm``.

    Args:
        embedding_dim: Size of the feature (last) dimension of the input.
            Must be divisible by ``num_heads``; ``nn.MultiheadAttention``
            enforces this.
        ff_dim: Hidden width of the feed-forward network.
        num_heads: Number of attention heads. Defaults to 4, the value that
            was previously hard-coded.

    Attributes:
        attn_out: Output of the attention sub-layer from the most recent
            ``forward`` call (before the residual add and normalisation).
        ffn_out: Output of the feed-forward sub-layer from the most recent
            ``forward`` call (before the residual add and normalisation).
    """

    def __init__(self, embedding_dim, ff_dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embedding_dim),
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape.

        Args:
            x: Input tensor. The attention layer is built with the default
                ``batch_first=False``, so a batched input is laid out as
                ``(sequence_length, batch_size, embedding_dim)``.

        Returns:
            The transformed tensor, same shape as ``x``.
        """
        # The attention weights are never used, so ask PyTorch not to compute
        # them; the second element of the returned tuple is then ``None``.
        self.attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + self.attn_out)
        # The sub-layer outputs are kept on the instance so they can be
        # inspected after a forward pass; they are overwritten on every call.
        self.ffn_out = self.ffn(x)
        x = self.norm2(x + self.ffn_out)
        return x
    

In [9]:
# Smoke test: run one Adam optimisation step of MiniTransformer on random data.

# Seed the RNG first so the weight initialisation, the random input/target and
# therefore the printed loss are repeatable from run to run.
torch.manual_seed(0)

SEQ_LEN = 10
BATCH_SIZE = 32
EMBEDDING_DIM = 256
FF_DIM = 1024

model = MiniTransformer(embedding_dim=EMBEDDING_DIM, ff_dim=FF_DIM)
x = torch.randn(SEQ_LEN, BATCH_SIZE, EMBEDDING_DIM)  # (sequence_length, batch_size, embedding_dim)
target = torch.randn_like(x)  # random regression target with the same shape as x
optimizer = optim.Adam(model.parameters())

# Forward pass, loss, backward pass, parameter update.
output = model(x)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()  # clear gradients so they do not accumulate into a later step

# `output` and `loss` were computed before the update, so these values describe
# the freshly initialised model.
print("Output shape:", output.shape)
print("Loss:", loss.item())

Output shape: torch.Size([10, 32, 256])
Loss: 1.98832106590271
