# 09. Transformers

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/pytorch_tutorial/blob/main/09_transformers/demo.ipynb)

---

In [None]:
import torch
import torch.nn as nn
import math

## Self-Attention from Scratch

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute attention as softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: query tensor of shape (..., L_q, d_k).
        K: key tensor of shape (..., L_k, d_k).
        V: value tensor of shape (..., L_k, d_v).
        mask: optional tensor broadcastable to (..., L_q, L_k). Positions
            where ``mask`` is False/0 are excluded from attention (their
            weight becomes 0). If every key in a row is masked, that row's
            weights become NaN, so callers must leave at least one key visible.

    Returns:
        A tuple ``(output, attention)``. ``output`` has shape (..., L_q, d_v).
        ``attention`` holds the softmax weights, shape (..., L_q, L_k), and
        each row sums to 1.
    """
    d_k = Q.size(-1)
    # Dividing by sqrt(d_k) keeps the variance of the dot products from
    # growing with d_k, which would otherwise saturate the softmax.
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # -inf becomes exactly 0 after softmax, so masked keys get no weight.
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention = torch.softmax(scores, dim=-1)
    return torch.matmul(attention, V), attention

# Sanity check: a single batch of 4 token vectors (width 8) serves as the
# queries, keys and values at once, i.e. plain self-attention.
seq_len, d_model = 4, 8
tokens = torch.randn(1, seq_len, d_model)
Q = K = V = tokens
output, attn = scaled_dot_product_attention(Q, K, V)
# Report the result shape, then the (seq_len x seq_len) weight matrix
# for the only batch element.
print(f'Output shape: {output.shape}')
print(f'Attention weights:\n{attn[0]}')

## Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a batch of sequences.

    The input is projected to queries, keys and values. Each projection is
    split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    and every head is attended independently with
    ``scaled_dot_product_attention``. The heads are then concatenated back to
    width ``d_model`` and passed through an output projection.

    Args:
        d_model: feature width of both the input and the output.
        num_heads: number of attention heads; must be positive and divide
            ``d_model`` exactly.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # The heads must tile d_model exactly. Otherwise the per-head view()
        # in forward() either fails with an opaque shape error or, for some
        # sequence lengths, silently regroups features across tokens.
        if num_heads <= 0:
            raise ValueError(f'num_heads must be positive, got {num_heads}')
        if d_model % num_heads != 0:
            raise ValueError(
                f'd_model ({d_model}) must be divisible by num_heads ({num_heads})'
            )
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: input tensor, expected shape (batch, seq_len, d_model).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size = x.size(0)

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))
        output, _ = scaled_dot_product_attention(Q, K, V)
        # Merge heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        # contiguous() is needed because view() cannot follow the transpose.
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.W_o(output)

# Demo: 2 sequences of 10 tokens, model width 64 split across 4 heads.
mha = MultiHeadAttention(num_heads=4, d_model=64)
x = torch.randn(2, 10, 64)
y = mha(x)
print(f'MHA output: {y.shape}')