In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [3]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)

  def scaled_dot_product_attention(self, Q, K, V, mask=None):
    att_scores = torch.matmul(Q, K.transpose(-2,-1)) / math.sqrt(self.d_k)
    if mask is not None:
      att_scores = att_scores.masked_fill(mask == 0, -1e9)
    attn_probs = torch.softmax(att_scores, dim=-1)
    output = torch.matmul(attn_probs, V)
    return output

  def split_heads(self, x):
    batch_size, seq_length, d_model = x.size()
    return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1,2)

  def combine_heads(self, x):
    batch_size, _, seq_length, d_k = x.size()
    return x.transpose(1,2).contiguous().view(batch_size, seq_length, self.d_model)

  def forward(self, Q, K, V, mask = None):
    Q = self.split_heads(self.W_q(Q))
    K = self.split_heads(self.W_k(K))
    V = self.split_heads(self.W_v(V))

    att_output = self.scaled_dot_product_attention(Q, K, V, mask)
    output = self.W_o(self.combine_heads(att_output))
    return output

In [4]:
class PositionWiseFeedForward(nn.Module):
  def __init__(self, d_model, d_ff):
    super(PositionWiseFeedForward, self).__init__()

    self.fc1 = nn.Linear(d_model, d_ff)
    self.fc2 = nn.Linear(d_ff, d_model)
    self.relu = nn.ReLU()

  def forward(self, x):
    return self.fc2(self.relu(self.fc1(x)))

In [5]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_seq_length):
    super(PositionalEncoding, self).__init__()

    pe = torch.zeros(max_seq_length, d_model)
    position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self,x):
    return x + self.pe[:, :x.size(1)]

In [6]:
class EncoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super(EncoderLayer, self).__init__()

    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    attn_output = self.self_attn(x, x, x, mask)
    x = self.norm1( x + self.dropout(attn_output))
    ff_output = self.feed_forward(x)
    x = self.norm2(x + self.dropout(ff_output))

    return x

In [8]:
class DecoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout):
    super(DecoderLayer, self).__init__()

    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.cross_attn = MultiHeadAttention(d_model, num_heads)
    self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_output, src_mask, tgt_mask):
    attn_output = self.self_attn(x, x, x, tgt_mask)
    x = self.norm1(x + self.dropout(attn_output))

    attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
    x = self.norm2(x + self.dropout(attn_output))

    ff_output = self.feed_forward(x)
    x = self.norm3(x + self.dropout(ff_output))
    return x

In [23]:
class Transformer(nn.Module):
  def __init__(self, src_vocab, tgt_vocab, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
    super(Transformer, self).__init__()

    self.encoder_embedding = nn.Embedding(src_vocab, d_model)
    self.decoder_embedding = nn.Embedding(tgt_vocab, d_model)
    self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

    self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
    self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

    self.fc = nn.Linear(d_model, tgt_vocab)
    self.dropout = nn.Dropout(dropout)

  def generate_mask(self, src, tgt):
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
    seq_length = tgt.size(1)
    nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
    tgt_mask = tgt_mask & nopeak_mask
    return src_mask, tgt_mask

  def forward(self, src, tgt):
    src_mask, tgt_mask = self.generate_mask(src, tgt)
    src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
    tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

    enc_output = src_embedded
    for enc_layer in self.encoder_layers:
        enc_output = enc_layer(enc_output, src_mask)

    dec_output = tgt_embedded
    for dec_layer in self.decoder_layers:
        dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

    output = self.fc(dec_output)
    return output


In [35]:
src_vocab_size = 8
tgt_vocab_size = 7
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

print(src_data)
print(tgt_data[:,:-1], tgt_data)

tensor([[1, 5, 6,  ..., 6, 6, 1],
        [6, 4, 1,  ..., 5, 5, 3],
        [6, 2, 2,  ..., 6, 2, 4],
        ...,
        [3, 4, 7,  ..., 4, 1, 2],
        [4, 4, 5,  ..., 7, 3, 4],
        [4, 6, 4,  ..., 3, 6, 6]])
tensor([[6, 3, 5,  ..., 5, 6, 5],
        [2, 4, 6,  ..., 5, 5, 5],
        [6, 6, 3,  ..., 3, 3, 1],
        ...,
        [1, 1, 5,  ..., 5, 3, 5],
        [3, 2, 3,  ..., 1, 6, 1],
        [2, 5, 6,  ..., 3, 5, 1]]) tensor([[6, 3, 5,  ..., 6, 5, 2],
        [2, 4, 6,  ..., 5, 5, 5],
        [6, 6, 3,  ..., 3, 1, 3],
        ...,
        [1, 1, 5,  ..., 3, 5, 5],
        [3, 2, 3,  ..., 6, 1, 4],
        [2, 5, 6,  ..., 5, 1, 1]])


In [12]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 8.687286376953125
Epoch: 2, Loss: 8.548818588256836
Epoch: 3, Loss: 8.478421211242676
Epoch: 4, Loss: 8.415529251098633
Epoch: 5, Loss: 8.354681015014648
Epoch: 6, Loss: 8.279276847839355
Epoch: 7, Loss: 8.202736854553223
Epoch: 8, Loss: 8.122802734375
Epoch: 9, Loss: 8.038825035095215
Epoch: 10, Loss: 7.960371017456055
Epoch: 11, Loss: 7.872034549713135
Epoch: 12, Loss: 7.7895026206970215
Epoch: 13, Loss: 7.704648971557617
Epoch: 14, Loss: 7.623815536499023
Epoch: 15, Loss: 7.54539155960083
Epoch: 16, Loss: 7.4590606689453125
Epoch: 17, Loss: 7.376415729522705
Epoch: 18, Loss: 7.2908477783203125
Epoch: 19, Loss: 7.21025276184082
Epoch: 20, Loss: 7.133049011230469
Epoch: 21, Loss: 7.052592754364014
Epoch: 22, Loss: 6.976838111877441
Epoch: 23, Loss: 6.903621673583984
Epoch: 24, Loss: 6.836450576782227
Epoch: 25, Loss: 6.747838020324707
Epoch: 26, Loss: 6.6709699630737305
Epoch: 27, Loss: 6.602648735046387
Epoch: 28, Loss: 6.5303263664245605
Epoch: 29, Loss: 6.4576864242

In [36]:
english_vocab = {
    "<START>": 1, "<END>": 2, "I": 3, "am": 4, "learning": 5, "machine": 6, "translation": 7
}
german_vocab = {
    "<START>": 1, "<END>": 2, "Ich": 3, "lerne": 4, "maschinelle": 5, "Übersetzung": 6
}

# Tokenized sentences
english_sentence = ["I", "am", "learning", "machine", "translation"]
german_sentence = ["Ich", "lerne", "maschinelle", "Übersetzung"]

In [37]:
src_data = torch.tensor([english_vocab[word] for word in english_sentence], dtype=torch.long).unsqueeze(0)  # Shape: (1, seq_len)
tgt_data = torch.tensor([german_vocab[word] for word in german_sentence], dtype=torch.long).unsqueeze(0)  # Shape: (1, seq_len)

print(src_data.shape, tgt_data)

torch.Size([1, 5]) tensor([[3, 4, 5, 6]])


In [38]:
transformer = Transformer(src_vocab=8,  # Includes <START>, <END>, and 6 words in the English sentence
                          tgt_vocab=7,  # Includes <START>, <END>, and 5 words in the German translation
                          d_model=512,
                          num_heads=8,
                          num_layers=6,
                          d_ff=2048,
                          max_seq_length=50,
                          dropout=0.1)

In [44]:
def generate_sequence(transformer, src, start_token, end_token, max_length=50):
  src_mask, _ = transformer.generate_mask(src, src)

  src_embedded = transformer.dropout(transformer.positional_encoding(transformer.encoder_embedding(src)))

  enc_output = src_embedded

  for enc_layer in transformer.encoder_layers:
    enc_output = enc_layer(enc_output, src_mask)

  tgt = torch.tensor([[start_token]], device=src.device)

  generated_sequence = [start_token]

  for _ in range(5):
    tgt_embedded = transformer.dropout(transformer.positional_encoding(transformer.decoder_embedding(tgt)))

    dec_output = tgt_embedded

    for dec_layer in transformer.decoder_layers:
      dec_output = dec_layer(dec_output, enc_output, src_mask, None)

    #print(dec_output[: ,-1])

    logits = transformer.fc(dec_output[: , -1, :])
    predicted_token = torch.argmax(logits, dim=-1)

    generated_sequence.append(predicted_token.item())

    tgt = torch.cat([tgt, predicted_token.unsqueeze(-1)], dim=-1)

    if predicted_token.item() == end_token:
      break

  return generated_sequence

In [41]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 2.0559885501861572
Epoch: 2, Loss: 0.5785306096076965
Epoch: 3, Loss: 0.6938692927360535
Epoch: 4, Loss: 0.325442910194397
Epoch: 5, Loss: 0.1525500863790512
Epoch: 6, Loss: 0.19429273903369904
Epoch: 7, Loss: 0.11214840412139893
Epoch: 8, Loss: 0.010207866318523884
Epoch: 9, Loss: 0.0039897519163787365
Epoch: 10, Loss: 0.00397447170689702
Epoch: 11, Loss: 0.003894123015925288
Epoch: 12, Loss: 0.004027724731713533
Epoch: 13, Loss: 0.0018542163306847215
Epoch: 14, Loss: 0.004249962512403727
Epoch: 15, Loss: 0.0043480959720909595
Epoch: 16, Loss: 0.002685801824554801
Epoch: 17, Loss: 0.0015867622569203377
Epoch: 18, Loss: 0.0021295829210430384
Epoch: 19, Loss: 0.0013033106224611402
Epoch: 20, Loss: 0.0012513699475675821
Epoch: 21, Loss: 0.0013546844711527228
Epoch: 22, Loss: 0.0008079048711806536
Epoch: 23, Loss: 0.000776016095187515
Epoch: 24, Loss: 0.0006637386395595968
Epoch: 25, Loss: 0.0009534747223369777
Epoch: 26, Loss: 0.000535212573595345
Epoch: 27, Loss: 0.00069

In [45]:

start_token = 1
end_token = 2

generated_seq = generate_sequence(transformer, src_data, start_token, end_token, max_length=50)


reverse_vocab = {v:k for k, v in german_vocab.items()}
gen_words =  [reverse_vocab[token] for token in generated_seq]

print(" ".join(gen_words))

<START> lerne maschinelle Übersetzung maschinelle Übersetzung
