<a href="https://colab.research.google.com/github/dineshkumar-2003/A-simple-transformer-to-predict-next-letter/blob/main/Transformer_to_predict_next_letter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``; the table is precomputed for ``max_len`` positions.

    Args:
        d_model: Embedding width; the last dimension of the input. May be odd.
        max_len: Longest sequence the table covers.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A registered buffer follows module.to(...) together with the model;
        # persistent=False keeps it out of state_dict so existing checkpoints
        # load with the same keys as before.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={max_len}")
        return x + self.pe[:, :seq_len].to(x.device)

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one fused linear projection, are split
    into ``num_heads`` heads of width ``d_model // num_heads``, attended, then
    merged back and passed through an output projection.

    Args:
        d_model: Input/output feature width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        causal: If True (default), position t may only attend to positions
            <= t. This is required for next-token training, where the targets
            are the inputs shifted by one: without the mask each position could
            read the very token it is asked to predict.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, causal=True):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model={d_model} must be divisible by num_heads={num_heads}")
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.causal = causal

        self.qkv_proj = nn.Linear(d_model, d_model * 3)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over a (B, T, d_model) input and return a tensor of the same shape."""
        B, T, C = x.size()
        qkv = self.qkv_proj(x)  # (B, T, 3 * C)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Each: (B, num_heads, T, d_k)

        attn_scores = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)  # (B, num_heads, T, T)
        if self.causal:
            # True strictly above the diagonal = "key is in the future of the query".
            # The diagonal stays unmasked, so every row keeps at least one finite score.
            future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
            attn_scores = attn_scores.masked_fill(future, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = attn_weights @ v  # (B, num_heads, T, d_k)

        attn_output = attn_output.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(attn_output)

class TransformerBlock(nn.Module):
    """One post-norm Transformer layer: self-attention followed by an MLP.

    Each sub-layer is applied as ``LayerNorm(x + sublayer(x))``, i.e. the
    normalization comes after the residual addition.

    Args:
        d_model: Width of the residual stream.
        num_heads: Number of heads given to ``MultiHeadSelfAttention``.
        ffn_hidden: Hidden width of the two-layer ReLU feed-forward network.
    """

    def __init__(self, d_model, num_heads, ffn_hidden):
        super().__init__()
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.ln1 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, ffn_hidden)
        activation = nn.ReLU()
        project_back = nn.Linear(ffn_hidden, d_model)
        self.ffn = nn.Sequential(expand, activation, project_back)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Run attention then feed-forward, each with residual + LayerNorm."""
        hidden = self.ln1(x + self.attn(x))
        residual = hidden + self.ffn(hidden)
        return self.ln2(residual)

class MiniTransformer(nn.Module):
    """Small Transformer mapping token ids to per-position vocabulary logits.

    Pipeline: token embedding -> sinusoidal positions -> ``num_layers``
    ``TransformerBlock``s -> final LayerNorm -> linear head over the vocabulary.

    Args:
        vocab_size: Number of distinct token ids (also the logit width).
        d_model: Embedding / residual width.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        ffn_hidden: Hidden width of each block's feed-forward network.
        max_len: Length of the positional-encoding table.
    """

    def __init__(self, vocab_size, d_model=128, num_heads=4, num_layers=2, ffn_hidden=512, max_len=128):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(TransformerBlock(d_model, num_heads, ffn_hidden))
        self.layers = stack
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Return logits of shape ``x.shape + (vocab_size,)`` for integer ids ``x``."""
        hidden = self.pos_enc(self.token_emb(x))
        for block in self.layers:
            hidden = block(hidden)
        normed = self.ln_f(hidden)
        return self.head(normed)

# # Example usage
# vocab_size = 5000
# model = MiniTransformer(vocab_size)
# x = torch.randint(0, vocab_size, (2, 32))  # (batch_size=2, seq_len=32)
# logits = model(x)  # (2, 32, vocab_size)
# print(logits.shape)  # should print torch.Size([2, 32, 5000])


In [None]:
# Toy corpus and character-level vocabulary.
text = "hello world"
chars = sorted(set(text))  # unique characters, sorted so the id mapping is deterministic
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}
data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

# Sliding windows: each target row Y[i] is the input row X[i] shifted left by
# one character, so position t is trained to predict character t + 1.
block_size = 4
num_windows = len(data) - block_size
if num_windows <= 0:
    raise ValueError(
        f"text has {len(data)} characters; need more than block_size={block_size} to form a window"
    )
X = torch.stack([data[i:i + block_size] for i in range(num_windows)])
Y = torch.stack([data[i + 1:i + block_size + 1] for i in range(num_windows)])


In [None]:
# Deliberately tiny configuration for the toy character corpus.
model = MiniTransformer(
    vocab_size=len(chars),
    d_model=32,
    num_heads=2,
    num_layers=2,
)


In [None]:
# Full-batch training: every window of (X, Y) is used on every step.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
num_steps = 1000
log_every = 100

# Ensure training mode, e.g. if an earlier cell left the model in eval mode.
model.train()
for step in range(num_steps):
    logits = model(X)  # (batch, seq_len, vocab)
    # CrossEntropyLoss wants (N, C) logits and (N,) targets, so flatten batch and
    # time; reshape (unlike view) also works on non-contiguous tensors.
    loss = loss_fn(logits.reshape(-1, logits.size(-1)), Y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % log_every == 0:
        print(f"Step {step}, Loss {loss.item():.4f}")


Step 0, Loss 2.1370
Step 100, Loss 0.0383
Step 200, Loss 0.0144
Step 300, Loss 0.0076
Step 400, Loss 0.0048
Step 500, Loss 0.0033
Step 600, Loss 0.0024
Step 700, Loss 0.0019
Step 800, Loss 0.0015
Step 900, Loss 0.0012


In [None]:
def generate(start_seq, max_new_tokens=10):
    """Greedily extend ``start_seq`` by ``max_new_tokens`` characters.

    Uses the notebook-level ``model``, ``stoi``, ``itos`` and ``block_size``.
    At each step the model sees at most the last ``block_size`` ids, and the
    highest-scoring next character (argmax) is appended.

    Args:
        start_seq: Non-empty prompt; every character must be in ``stoi``.
        max_new_tokens: Number of characters to append.

    Returns:
        The prompt followed by the generated characters.

    Raises:
        ValueError: if ``start_seq`` is empty (there is no context to extend).
        KeyError: if ``start_seq`` contains a character outside the vocabulary.
    """
    if not start_seq:
        raise ValueError("start_seq must contain at least one character")

    device = next(model.parameters()).device
    input_ids = torch.tensor([stoi[c] for c in start_seq], dtype=torch.long, device=device).unsqueeze(0)

    was_training = model.training
    model.eval()
    try:
        # Inference only: skip building autograd graphs.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                input_chunk = input_ids[:, -block_size:]
                logits = model(input_chunk)
                next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (batch, 1)
                input_ids = torch.cat([input_ids, next_id], dim=1)
    finally:
        # Restore the caller's mode so a later training cell is unaffected.
        model.train(was_training)

    return ''.join(itos[i] for i in input_ids[0].tolist())


In [None]:
# Greedy continuation of the single-character prompt "w".
print(generate(start_seq="w"))


worldrldrld
