In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        Q: query tensor of shape (..., L_q, d_k).
        K: key tensor of shape (..., L_k, d_k).
        V: value tensor of shape (..., L_k, d_v).
        mask: optional tensor broadcastable to the score shape (..., L_q, L_k).
            Positions where ``mask`` is False/0 are excluded from attention
            (their score is set to -inf before the softmax). ``None`` (the
            default) attends to every key, exactly as before. Note that a query
            row whose keys are *all* masked out yields NaN weights.

    Returns:
        ``(output, weights)`` where ``output`` has shape (..., L_q, d_v) and
        ``weights`` has shape (..., L_q, L_k), each row summing to 1 over keys.
    """
    d_k = Q.size(-1)
    # Scale by sqrt(d_k) (as in "Attention Is All You Need") so the logits do
    # not grow with the head size and push the softmax into saturation.
    scores = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
    if mask is not None:
        # logical_not() works for bool and numeric (0/1) masks alike.
        scores = scores.masked_fill(mask.logical_not(), float("-inf"))
    weights = F.softmax(scores, dim=-1)
    output = weights @ V
    return output, weights

In [3]:
class SelfAttention(nn.Module):
    """Single-head self-attention.

    Queries, keys and values are three independent ``embed_dim -> embed_dim``
    linear projections of the same input; the attended values are returned and
    the attention weights are discarded.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Built (and registered) in the order Wq, Wk, Wv.
        self.Wq, self.Wk, self.Wv = (
            nn.Linear(embed_dim, embed_dim) for _ in range(3)
        )

    def forward(self, x):
        # x is presumably (batch, seq_len, embed_dim) -- the embedding output
        # fed in by the surrounding model; only the last dim must equal embed_dim.
        queries, keys, values = self.Wq(x), self.Wk(x), self.Wv(x)
        attended, _ = scaled_dot_product_attention(queries, keys, values)
        return attended

In [4]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: ``embed_dim -> hidden_dim -> ReLU -> embed_dim``.

    Applied independently to every position along the last dimension of ``x``.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)  # expand to the hidden width
        self.fc2 = nn.Linear(hidden_dim, embed_dim)  # project back to embed_dim

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        projected = self.fc2(hidden)
        return projected

In [5]:
class TransformerBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    Runs self-attention and then the feed-forward network, each wrapped as a
    residual sub-layer followed by LayerNorm: ``x = norm(x + sublayer(x))``.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.attention = SelfAttention(embed_dim)
        self.ffn = FeedForward(embed_dim, hidden_dim)
        self.norm1, self.norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Sub-layers are applied in order; each output feeds the next stage.
        for sublayer, norm in ((self.attention, self.norm1), (self.ffn, self.norm2)):
            x = norm(x + sublayer(x))
        return x

In [6]:
class ToyTransformer(nn.Module):
    """Minimal sequence classifier.

    Pipeline: token embedding -> one TransformerBlock -> mean over the sequence
    axis -> linear layer producing ``num_classes`` logits. No positional
    encoding is added, so with the blocks defined in this notebook the
    prediction does not depend on token order.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.block = TransformerBlock(embed_dim, hidden_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x: integer token ids, (batch, seq_len) in the training cell below.
        contextual = self.block(self.embedding(x))
        pooled = contextual.mean(dim=1)  # average over the sequence positions
        return self.fc(pooled)

In [7]:
# Toy training run. Every "epoch" is a single optimisation step on a freshly
# sampled batch of random token ids with *random* labels, so there is no signal
# to learn: loss near ln(num_classes) and accuracy near chance are expected.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and sampled batches on Restart & Run All

vocab_size = 50
embed_dim = 32
hidden_dim = 64
num_classes = 2

model = ToyTransformer(
    vocab_size,
    embed_dim,
    hidden_dim,
    num_classes
)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

batch_size = 16
seq_len = 10
num_epochs = 5

model.train()
for epoch in range(num_epochs):
    inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels = torch.randint(0, num_classes, (batch_size,))

    outputs = model(inputs)  # (batch_size, num_classes) logits
    loss = loss_fn(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy of the logits computed *before* this step's parameter update.
    preds = outputs.argmax(dim=1)
    acc = (preds == labels).float().mean().item()

    print("Epoch", epoch, "Loss", loss.item(), "Accuracy", acc)

Epoch 0 Loss 0.7529621720314026 Accuracy 0.3125
Epoch 1 Loss 0.7105728983879089 Accuracy 0.375
Epoch 2 Loss 0.6698642373085022 Accuracy 0.5625
Epoch 3 Loss 0.6829836368560791 Accuracy 0.5
Epoch 4 Loss 0.7062085270881653 Accuracy 0.4375
