In [68]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F

## Positional Encoding Layer

In [69]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to a batch of embeddings.

  For position `pos` and feature index `dim` the table holds
  `sin(pos * w)` when `dim` is even and `cos(pos * w)` when `dim` is odd,
  where `w = (1 / 1e4) ** ((dim // 2 * 2) / d_model)`, so each even/odd
  feature pair shares one frequency.

  Args:
    max_seq_length: number of positions to precompute.
    d_model: number of features per position.

  Attributes:
    allDimVals: tensor of shape (max_seq_length, d_model). It is registered
      as a non-persistent buffer so that it follows the module through
      `.to()`, `.cuda()` and `.double()` without adding a key to
      `state_dict()`.
  """

  def __init__(self, max_seq_length, d_model):
    super(PositionalEncoding, self).__init__()
    # Column vector of positions: (max_seq_length, 1).
    positions = torch.arange(0, max_seq_length, dtype = torch.float).unsqueeze(1)
    # Row vector of per-feature frequencies: (1, d_model). The exponent is
    # evaluated in Python floats before the tensor is built.
    frequencies = torch.tensor(
      [(1/1e4)**((dim//2 * 2) / d_model) for dim in range(d_model)]
    ).unsqueeze(0)
    angles = frequencies * positions  # broadcasts to (max_seq_length, d_model)

    encodings = torch.zeros_like(angles)
    encodings[:, 0::2] = torch.sin(angles[:, 0::2])  # even features
    encodings[:, 1::2] = torch.cos(angles[:, 1::2])  # odd features
    self.register_buffer("allDimVals", encodings, persistent = False)

  def forward(self, x):
    """Return `x` plus the encodings of its first `x.size(1)` positions.

    `x.size(1)` is used as the sequence length and the table slice is
    broadcast over the leading dimension of `x`; the last dimension of `x`
    therefore has to equal `d_model`.
    """
    return x + self.allDimVals[:x.size(1)]

## Multi-Head Attention Layer

In [70]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Queries, keys and values are projected with separate linear layers, split
  into `n_heads` heads of width `d_k = d_model // n_heads`, attended
  independently, concatenated again and passed through an output projection.

  Args:
    d_model: feature width of the inputs and of the output.
    n_heads: number of attention heads; must divide `d_model`.
  """

  def __init__(self, d_model, n_heads):
    super(MultiHeadAttention, self).__init__()
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    self.d_model = d_model
    self.n_heads = n_heads
    self.d_k = d_model // n_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)

  def scaleDotProduct(self, Q, K, V, mask = None):
    """Compute `softmax(Q K^T / sqrt(d_k)) V` over the last two dimensions.

    Args:
      Q, K, V: per-head tensors as produced by `split_heads`.
      mask: optional tensor broadcastable to the score matrix; entries equal
        to 0 mark positions that must not be attended to.
    """
    # `self.d_k` is a Python int, so it cannot be passed to torch.sqrt
    # (which only accepts tensors); take the square root in plain Python.
    attentionScore = torch.matmul(Q, K.transpose(-1, -2)) / (self.d_k ** 0.5)
    if mask is not None:
      # Out-of-place fill keeps the tensor produced above untouched.
      attentionScore = attentionScore.masked_fill(mask == 0, -1e9)
    attentionScore = torch.softmax(attentionScore, dim = -1)

    return torch.matmul(attentionScore, V)

  def split_heads(self, x):
    """Reshape (batch, seq, d_model) into (batch, n_heads, seq, d_k)."""
    batch_size, seq_length, _ = x.size()
    return x.view(batch_size, seq_length, self.n_heads, self.d_k).transpose(1, 2)

  def combine_heads(self, x):
    """Reshape (batch, n_heads, seq, d_k) back into (batch, seq, n_heads * d_k)."""
    batch_size, n_heads, seq_length, d_k = x.size()
    # transpose() returns a non-contiguous view and .view() cannot merge the
    # head and feature dimensions of such a tensor, so copy it first.
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, d_k * n_heads)

  def forward(self, Q, K, V, mask = None):
    """Attend from `Q` to `K`/`V` and return a (batch, seq_q, d_model) tensor.

    Args:
      Q, K, V: tensors of shape (batch, seq, d_model).
      mask: optional mask forwarded to `scaleDotProduct`.
    """
    Q = self.split_heads(self.W_q(Q))
    K = self.split_heads(self.W_k(K))
    V = self.split_heads(self.W_v(V))

    attentionMat = self.scaleDotProduct(Q, K, V, mask)
    return self.W_o(self.combine_heads(attentionMat))

## Feed-Forward Layer (applied after the attention Add & Norm step)

In [71]:
class FeedForward(nn.Module):
  """Position-wise feed-forward block: `fc2(relu(fc1(x)))`.

  Args:
    d_model: input and output feature width.
    d_fflayer: width of the hidden layer.
  """

  def __init__(self, d_model, d_fflayer):
    super(FeedForward, self).__init__()
    self.fc1 = nn.Linear(d_model, d_fflayer)
    self.fc2 = nn.Linear(d_fflayer, d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    """Expand to `d_fflayer`, apply ReLU, then project back to `d_model`."""
    # The non-linearity has to sit between the two linear layers; applied
    # after fc2 the two layers collapse into a single affine map and the
    # block could only ever emit non-negative values.
    hidden = self.relu(self.fc1(x))
    return self.fc2(hidden)

## Encoder Layer

In [72]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and feed-forward, each wrapped in a
  residual connection followed by layer normalization.

  Args:
    d_model: feature width of the layer input and output.
    d_fflayer: hidden width of the feed-forward sub-layer.
    max_seq_length: not used by this layer; kept so existing constructor
      calls keep working.
    dropout: dropout probability applied to each sub-layer output.
    n_heads: number of attention heads.
  """

  def __init__(self, d_model, d_fflayer, max_seq_length, dropout, n_heads):
    super(EncoderLayer, self).__init__()
    self.attention = MultiHeadAttention(d_model, n_heads)
    self.ff = FeedForward(d_model, d_fflayer)
    # Each sub-layer gets its own LayerNorm so the two normalizations learn
    # independent scale and shift parameters instead of sharing one set.
    self.attention_norm = nn.LayerNorm(d_model)
    self.ff_norm = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """Run self-attention over `x` (masked by `mask`), then the feed-forward block."""
    attentionScore = self.attention(x, x, x, mask)
    x = self.attention_norm(x + self.dropout(attentionScore))
    ff_output = self.ff(x)
    x = self.ff_norm(x + self.dropout(ff_output))
    return x

## Decoder Layer

In [73]:
class DecoderLayer(nn.Module):
  """One decoder block: masked self-attention, cross-attention over the
  encoder output, and feed-forward, each wrapped in a residual connection
  followed by layer normalization.

  Args:
    d_model: feature width of the layer input and output.
    d_fflayer: hidden width of the feed-forward sub-layer.
    dropout: dropout probability applied to each sub-layer output.
    n_heads: number of attention heads.
  """

  def __init__(self, d_model, d_fflayer, dropout, n_heads):
    super(DecoderLayer, self).__init__()
    self.in_attention = MultiHeadAttention(d_model, n_heads)
    self.cross_attention = MultiHeadAttention(d_model, n_heads)
    self.ff = FeedForward(d_model, d_fflayer)
    # One LayerNorm per sub-layer so the three normalizations do not share
    # a single set of scale and shift parameters.
    self.in_attention_norm = nn.LayerNorm(d_model)
    self.cross_attention_norm = nn.LayerNorm(d_model)
    self.ff_norm = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_output, src_mask, tgt_mask):
    """Decode `x` against `enc_output`.

    Args:
      x: decoder-side input.
      enc_output: encoder output used as keys and values in cross-attention.
      src_mask: mask passed to the cross-attention.
      tgt_mask: mask passed to the decoder self-attention.
    """
    in_attentionScore = self.in_attention(x, x, x, tgt_mask)
    x = self.in_attention_norm(x + self.dropout(in_attentionScore))
    cross_attentionScore = self.cross_attention(x, enc_output, enc_output, src_mask)
    x = self.cross_attention_norm(x + self.dropout(cross_attentionScore))
    ffoutput = self.ff(x)
    x = self.ff_norm(x + self.dropout(ffoutput))
    return x

## Transformer Model

In [74]:
class Transformer(nn.Module):
  """Encoder-decoder transformer over integer token ids.

  Token id 0 is treated as padding when the attention masks are built.

  Args:
    d_model: embedding and hidden feature width.
    d_fflayer: hidden width of every feed-forward sub-layer.
    max_seq_length: number of positions the positional encoding precomputes.
    dropout: dropout probability.
    n_layers: number of encoder layers and of decoder layers.
    n_heads: number of attention heads per attention block.
    src_vocab_size: size of the source embedding table.
    tgt_vocab_size: size of the target embedding table and of the output logits.
  """

  def __init__(self, d_model, d_fflayer, max_seq_length, dropout, n_layers, n_heads, src_vocab_size, tgt_vocab_size):
    super(Transformer, self).__init__()
    self.n_layers = n_layers

    self.embed1 = nn.Embedding(src_vocab_size, d_model)
    self.embed2 = nn.Embedding(tgt_vocab_size, d_model)
    self.pe = PositionalEncoding(max_seq_length, d_model)
    # Build a fresh module per layer. `[layer] * n_layers` would repeat one
    # instance, making every layer in the stack share the same weights.
    self.encoders = nn.ModuleList(
      [EncoderLayer(d_model, d_fflayer, max_seq_length, dropout, n_heads) for _ in range(n_layers)]
    )
    self.decoders = nn.ModuleList(
      [DecoderLayer(d_model, d_fflayer, dropout, n_heads) for _ in range(n_layers)]
    )
    self.dropout = nn.Dropout(dropout)
    self.fc = nn.Linear(d_model, tgt_vocab_size)

  def generate(self, src, tgt):
    """Build the boolean attention masks for a batch of token ids.

    Args:
      src: (batch, src_len) source token ids.
      tgt: (batch, tgt_len) target token ids.

    Returns:
      src_mask: (batch, 1, 1, src_len), True where the source id is not 0.
      tgt_mask: (batch, 1, tgt_len, tgt_len), True where the query row is
        not padding and the key column is at or before the query position.
    """
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
    seq_length = tgt.size(1)
    # Lower-triangular "no peeking ahead" mask, created on the same device
    # as the inputs so the `&` below also works off-CPU.
    peak_mask = torch.tril(torch.ones(1, seq_length, seq_length, device = tgt.device)).bool()
    return src_mask, tgt_mask & peak_mask

  def forward(self, x, src, tgt = None):
    """Return logits of shape (batch, tgt_len, tgt_vocab_size).

    Two call forms are accepted:
      * `model(src, tgt)` - the two tensors are taken positionally;
      * `model(x, src, tgt)` - the original three-argument form, in which
        `x` is ignored.
    """
    if tgt is None:
      # Two-argument call: the positional values landed in `x` and `src`.
      src, tgt = x, src

    src_mask, tgt_mask = self.generate(src, tgt)
    src_embed = self.dropout(self.pe(self.embed1(src)))
    tgt_embed = self.dropout(self.pe(self.embed2(tgt)))

    for enc_layer in self.encoders:
      src_embed = enc_layer(src_embed, src_mask)

    for dec_layer in self.decoders:
      tgt_embed = dec_layer(tgt_embed, src_embed, src_mask, tgt_mask)

    return self.fc(tgt_embed)

## Train

In [75]:
# Hyperparameters for the toy training run.
seed = 42
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1
batch_size = 64
num_epochs = 100

# Seed before the model is built so that weight initialisation, the random
# sample data and dropout are reproducible across "Restart & Run All".
torch.manual_seed(seed)

# Keyword arguments map this cell's names (d_ff, num_layers, num_heads) onto
# the constructor's parameter names (d_fflayer, n_layers, n_heads).
transformer = Transformer(
    d_model = d_model,
    d_fflayer = d_ff,
    max_seq_length = max_seq_length,
    dropout = dropout,
    n_layers = num_layers,
    n_heads = num_heads,
    src_vocab_size = src_vocab_size,
    tgt_vocab_size = tgt_vocab_size,
)

# Generate random sample data; ids start at 1 because 0 is the padding id
# ignored by the loss below.
src_data = torch.randint(1, src_vocab_size, (batch_size, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (batch_size, max_seq_length))  # (batch_size, seq_length)

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# Teacher forcing: the decoder sees tokens [0, L-1) and is scored against
# tokens [1, L).
tgt_input = tgt_data[:, :-1]
tgt_expected = tgt_data[:, 1:].contiguous().view(-1)

for epoch in range(num_epochs):
    optimizer.zero_grad()
    # `x` is the leading parameter of Transformer.forward and is not used by
    # the model, so it is passed explicitly as None.
    output = transformer(x = None, src = src_data, tgt = tgt_input)
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_expected)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")