In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math

Transformer encoder architecture

In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Positions where ``mask == 0`` receive (numerically) zero attention weight.
    """

    def forward(self, Q, K, V, mask=None):
        """Return the attention-weighted combination of V.

        Args:
            Q: queries, expected shape (..., L_q, d_k).
            K: keys, expected shape (..., L_k, d_k) (last dim must match Q for the matmul).
            V: values, expected shape (..., L_k, d_v).
            mask: optional tensor broadcastable to the score shape (..., L_q, L_k);
                entries equal to 0 are excluded from attention.

        Returns:
            Tensor of shape (..., L_q, d_v).
        """
        d_k = Q.size(-1)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype instead of a fixed -1e9:
            # -1e9 is outside float16's range, so masked_fill would raise under fp16/autocast.
            fill_value = torch.finfo(attn_scores.dtype).min
            attn_scores = attn_scores.masked_fill(mask == 0, fill_value)
        attn = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn, V)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ScaledDotProductAttention.

    Q, K and V are each linearly projected, split into ``num_heads`` heads of
    width ``embed_size // num_heads``, attended independently, merged back and
    passed through a final output projection.

    Args:
        embed_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        assert embed_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        # Sub-modules are created in this exact order so that weight initialisation
        # (RNG consumption) and state_dict keys are unchanged.
        self.q_linear = nn.Linear(embed_size, embed_size)
        self.k_linear = nn.Linear(embed_size, embed_size)
        self.v_linear = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)
        self.attention = ScaledDotProductAttention()

    def split_heads(self, x):
        """Reshape (batch, seq, embed) into (batch, heads, seq, head_dim)."""
        n_batch, n_tokens, _width = x.size()
        per_head = x.view(n_batch, n_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of split_heads: (batch, heads, seq, head_dim) -> (batch, seq, embed)."""
        n_batch, _heads, _tokens, _width = x.size()
        tokens_first = x.transpose(1, 2).contiguous()
        return tokens_first.view(n_batch, -1, self.num_heads * self.head_dim)

    def forward(self, Q, K, V, mask=None):
        """Project, attend per head, merge heads and apply the output projection.

        ``mask`` is passed unchanged to the underlying scaled dot-product attention.
        """
        projections = (
            (self.q_linear, Q),
            (self.k_linear, K),
            (self.v_linear, V),
        )
        q_heads, k_heads, v_heads = [self.split_heads(layer(inp)) for layer, inp in projections]
        attended = self.attention(q_heads, k_heads, v_heads, mask)
        return self.fc_out(self.combine_heads(attended))

class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Linear.

    Args:
        embed_size: input and output width.
        hidden_dim: width of the intermediate (expanded) layer.
    """

    def __init__(self, embed_size, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_size)

    def forward(self, x):
        """Apply the MLP independently along the last dimension of ``x``."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input.

    Even feature columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i),
    with w_i = 10000^(-2i / embed_size).

    Args:
        embed_size: feature width of the inputs (odd widths are supported).
        max_len: longest sequence length the precomputed table covers.
    """

    def __init__(self, embed_size, max_len=100):
        super().__init__()

        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * -(math.log(10000.0) / embed_size))
        angles = position * div_term  # (max_len, ceil(embed_size / 2))

        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(embed_size / 2) odd columns; trim the angles so an odd
        # embed_size does not raise a shape mismatch here.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Registered as a buffer: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: tensor whose dim 1 is the sequence axis and last dim is ``embed_size``.

        Raises:
            ValueError: if the sequence is longer than ``max_len``. Without this check
                the slice below would be silently short and either fail to broadcast or,
                when ``max_len == 1``, add position 0's encoding to every position.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len].to(x.device)

class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention and a feed-forward block, each wrapped in a residual
    connection followed by LayerNorm.

    Args:
        embed_size: model width.
        num_heads: number of attention heads.
        hidden_dim: inner width of the feed-forward block.
    """

    def __init__(self, embed_size, num_heads, hidden_dim):
        super().__init__()
        # Creation order kept identical so weight initialisation is unchanged.
        self.self_attn = MultiHeadAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size, hidden_dim)

    def forward(self, x, mask=None):
        """Apply attention and feed-forward sub-layers; ``mask`` goes to self-attention."""
        # Sub-layer 1: self-attention, residual add, then normalise.
        x = self.norm1(x + self.self_attn(x, x, x, mask))
        # Sub-layer 2: feed-forward, residual add, then normalise.
        return self.norm2(x + self.ff(x))

class TransformerModel(nn.Module):
    """Encoder-only Transformer mapping token ids to per-position vocabulary logits.

    Args:
        vocab_size: number of token ids (embedding rows and output classes).
        embed_size: model width.
        num_heads: attention heads per encoder layer.
        hidden_dim: inner width of each feed-forward sub-layer.
        num_layers: number of stacked encoder layers.
        seq_len: passed to PositionalEncoding as its table length, so inputs
            should be no longer than this.
    """

    def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers, seq_len):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.positional_encoding = PositionalEncoding(embed_size, seq_len)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_size, num_heads, hidden_dim) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, x, mask=None):
        """Return logits of shape (*x.shape, vocab_size) for token ids ``x``.

        Args:
            x: LongTensor of token ids, presumably (batch, seq).
            mask: optional attention mask forwarded unchanged to every encoder
                layer (positions where ``mask == 0`` are not attended). The default
                ``None`` gives full bidirectional attention, as before.
        """
        x = self.embedding(x)
        x = self.positional_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, mask)
        return self.fc_out(x)


Train the model to reverse random token sequences

In [3]:
# Task: reverse sequences of `seq_len` tokens drawn from a vocabulary of `vocab_size` ids.
seq_len = 6
vocab_size = 10
embed_size = 16
num_heads = 8        # head_dim = embed_size // num_heads = 2
num_layers = 8
hidden_dim = 32
num_epochs = 1000
learning_rate = 0.001
seed = 0

# Seed both RNGs the notebook uses: torch (weight init) and numpy (batch sampling in
# generate_data). Restart & Run All then reproduces the same model and outputs.
torch.manual_seed(seed)
np.random.seed(seed)

model = TransformerModel(vocab_size, embed_size, num_heads, hidden_dim, num_layers, seq_len)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def generate_data(batch_size=64):
    """Sample a batch for the sequence-reversal task.

    Tokens are drawn uniformly from 1..vocab_size-1 (this generator never emits 0),
    and each target row is the corresponding input row reversed. Reads the notebook
    globals ``seq_len`` and ``vocab_size``.

    Args:
        batch_size: number of sequences to sample.

    Returns:
        (x, y): int64 tensors of shape (batch_size, seq_len) with y[i] == x[i] reversed.
    """
    x = np.random.randint(1, vocab_size, (batch_size, seq_len))
    # np.flip returns a negative-stride view; copy() materialises it with positive strides,
    # which torch requires when wrapping a NumPy array.
    y = np.flip(x, axis=1).copy()
    # as_tensor with an explicit dtype replaces the legacy torch.LongTensor constructor and
    # always yields int64, regardless of NumPy's platform-default integer width.
    return torch.as_tensor(x, dtype=torch.long), torch.as_tensor(y, dtype=torch.long)

log_every = 200     # print the loss this often so convergence is visible
loss_history = []   # per-epoch training loss, kept for later inspection/plotting

# Nothing inside the loop switches to eval mode, so set training mode once.
model.train()
for epoch in range(num_epochs):
    x_train, y_train = generate_data()
    optimizer.zero_grad()
    logits = model(x_train)  # one vocab_size-wide logit vector per input position
    # CrossEntropyLoss expects (N, C) logits and (N,) targets: flatten batch and sequence axes.
    loss = criterion(logits.reshape(-1, vocab_size), y_train.reshape(-1))
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())
    if (epoch + 1) % log_every == 0:
        print(f"epoch {epoch + 1:4d}/{num_epochs}  loss {loss_history[-1]:.4f}")

Test on a sample sequence

In [4]:
model.eval()
# Must be at most seq_len tokens long: that is the size of the positional-encoding table.
test_seq = [1, 3, 2, 4, 5, 3]
test_input = torch.tensor([test_seq], dtype=torch.long)  # batch containing one sequence
with torch.no_grad():
    output = model(test_input)
# Take batch row 0 instead of squeeze(): squeeze() would collapse a length-1
# sequence to a bare int, and tolist() would then not return a list.
predicted_seq = output.argmax(dim=-1)[0].tolist()
expected_seq = test_seq[::-1]
print("Input:", test_seq)
print("Reversed:", predicted_seq)
print("Expected:", expected_seq)
print("Correct:", predicted_seq == expected_seq)

Input: [1, 3, 2, 4, 5, 3]
Reversed: [3, 5, 4, 2, 3, 1]
