<a href="https://colab.research.google.com/github/chidambarambaskaran/MachineLearning/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math

import torch
import torch.nn
import torch.nn as nn
import torch.optim as optim

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with frequencies ``w_i = 10000^(-2i / d_model)``.

    Args:
        d_model: Embedding width (size of the last dimension of ``x``).
        max_len: Longest sequence length the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        # Positions 0..max_len-1 as a column vector, shape (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column, computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        # Stride-2 slices interleave sin/cos across all columns.
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A non-persistent buffer moves with .to()/.half() but stays out of
        # the state_dict, since it is recomputed at construction.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
               ``d_model`` — presumably (batch, seq_len, d_model).
        """
        return x + self.pe[:, : x.size(1), :].to(x.device)

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed over ``num_heads`` parallel heads.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of heads; each head works on ``d_model // num_heads``
            features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        # Lower-case name so forward() can find it.
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, Q, K, V, mask=None):
        """Return ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            Q, K, V: Per-head tensors whose last two dims are (length, d_k).
            mask: Optional tensor broadcastable to the score tensor
                (..., len_q, len_k); entries equal to 0 are excluded by
                setting their score to -inf before the softmax.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)
        # Weight the values by the normalized attention, not the raw scores.
        return torch.matmul(attn_weights, V)

    def forward(self, q, k, v, mask=None):
        """Project inputs, attend per head, merge heads, apply output projection.

        Args:
            q: Queries with last dim ``d_model``; expected (batch, len_q, d_model).
            k, v: Keys/values with last dim ``d_model``; expected
                (batch, len_k, d_model).
            mask: Optional mask forwarded to :meth:`attention`.

        Returns:
            Tensor of shape (batch, len_q, d_model).
        """
        batch_size = q.size(0)
        # (batch, len, d_model) -> (batch, num_heads, len, d_k)
        Q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attn_output = self.attention(Q, K, V, mask)
        # Merge heads back: (batch, num_heads, len_q, d_k) -> (batch, len_q, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.out(attn_output)


# Backward-compatible alias for the original (misspelled) class name.
MutliHeadAttention = MultiHeadAttention

# FeedForward
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width of the inner layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        # Map the d_ff hidden features back to d_model so the output has the
        # same width as the input.
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        return self.fc2(self.relu(self.fc1(x)))

# Encoder Layer
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sub-layer output is added back to its input (residual) and then
    passed through LayerNorm.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Run self-attention and feed-forward, each with residual + LayerNorm."""
        attn_out = self.self_attn(x, x, x, mask)
        x = self.norm1(x + attn_out)
        ff_out = self.feed_forward(x)
        return self.norm2(x + ff_out)


# Full Encoder
class Encoder(nn.Module):
    """Token embedding + positional encoding + a stack of ``EncoderLayer``s.

    Args:
        vocab_size: Number of source tokens the embedding table holds.
        d_model: Model width.
        max_len: Longest sequence the positional encoding supports.
        d_ff: Hidden width of each feed-forward sub-layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Number of attention heads per layer.
    """

    def __init__(self, vocab_size, d_model, max_len, d_ff, num_layers, num_heads):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        # Keyword arguments keep num_heads/d_ff from being swapped positionally.
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff)
             for _ in range(num_layers)]
        )

    def forward(self, x, mask=None):
        """Embed token ids ``x`` and run them through every encoder layer.

        Args:
            x: Token-id tensor accepted by ``nn.Embedding``.
            mask: Optional attention mask passed unchanged to each layer.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

# Decoder Layer
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer output is added back to its input (residual) and then
    passed through LayerNorm.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        # Plain super(); the original `super(DecoderLayer. self)` looked up a
        # nonexistent attribute `DecoderLayer.self`.
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Decode ``x`` while attending to ``enc_output``.

        Args:
            x: Decoder-side hidden states.
            enc_output: Encoder output used as keys/values for cross-attention.
            src_mask: Optional mask for cross-attention over encoder positions.
            tgt_mask: Optional mask for decoder self-attention.
        """
        attn_out = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + attn_out)
        attn_out = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + attn_out)
        ff_out = self.feed_forward(x)
        return self.norm3(ff_out + x)

# Full Decoder
class Decoder(nn.Module):
    """Target-side stack: embedding, positional encoding, then DecoderLayers.

    Args:
        vocab_size: Number of target tokens the embedding table holds.
        d_model: Model width.
        num_heads: Number of attention heads per layer.
        d_ff: Hidden width of each feed-forward sub-layer.
        num_layers: How many ``DecoderLayer`` instances to stack.
        max_len: Longest sequence the positional encoding supports.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        )

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        """Embed target ids ``x`` and feed them through each decoder layer in turn."""
        hidden = self.pos_encoding(self.embedding(x))
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, src_mask, tgt_mask)
        return hidden

# Transformer
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final projection to target-vocab logits.

    Args:
        src_vocab_size: Source vocabulary size.
        tgt_vocab_size: Target vocabulary size (also the output logit width).
        d_model: Model width.
        num_heads: Attention heads per layer.
        d_ff: Feed-forward hidden width.
        num_layers: Layers in each of the encoder and decoder stacks.
        max_len: Longest sequence the positional encodings support.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, num_layers, max_len):
        super().__init__()
        # Keyword arguments: Encoder and Decoder declare these parameters in
        # different orders, and a positional call would swap num_heads/max_len.
        self.encoder = Encoder(
            vocab_size=src_vocab_size,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            num_layers=num_layers,
            max_len=max_len,
        )
        self.decoder = Decoder(
            vocab_size=tgt_vocab_size,
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            num_layers=num_layers,
            max_len=max_len,
        )
        self.final_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Encode ``src``, decode ``tgt`` against it, and return vocabulary logits.

        Args:
            src: Source token ids.
            tgt: Target token ids.
            src_mask: Optional mask used by encoder self-attention and decoder
                cross-attention.
            tgt_mask: Optional mask used by decoder self-attention.
        """
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        return self.final_layer(dec_output)

# Hyperparameters (source and target share one vocabulary size).
tgt_vocab_size = 10000
src_vocab_size = tgt_vocab_size
d_model = 512
num_heads = 8
d_ff = 2048
num_layers = 6
max_len = 100

# Build the full encoder-decoder model from the settings above.
model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    num_layers=num_layers,
    max_len=max_len,
)
