# Imports and Setup

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import random

# Fix the seeds of the Python and PyTorch RNGs so that runs are repeatable.
SEED = 1234
for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Positional Encoding Module

In [4]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature indices receive sin(pos / 10000^(2i/d_model)) and odd indices
    receive the matching cosine, so every position gets a distinct, smooth
    signature of the same width as the embeddings.

    Args:
        d_model: embedding width. Both even and odd widths are supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: longest sequence (dim 1 of the input) that can be encoded.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # pe[pos, j] depends only on the position and the feature index.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        # Trim div_term so that the cosine fill matches the slice width. This
        # is a no-op for even d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Register as a buffer: saved in state_dict and moved with .to(), but not trained.
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """
        Add the encoding for the first x.size(1) positions to x, then apply dropout.

        x is expected to be [Batch, Seq_len, Dim] with Dim == d_model.

        Raises:
            ValueError: if the sequence is longer than the max_len given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

# Seq2Seq Transformer Model Definition

In [6]:
class Seq2SeqTransformer(nn.Module):
    """
    Encoder-decoder Transformer whose source and target share one vocabulary.

    Source and target tokens go through the same embedding table and the same
    sinusoidal positional encoder. nn.Transformer (batch_first, so tensors are
    [Batch, Seq, Dim]) encodes and decodes them, and a linear head projects
    the decoder states onto vocabulary logits.
    """
    def __init__(self, num_tokens, dim_model, num_heads, num_encoder_layers, num_decoder_layers, dropout_p=0.1):
        super().__init__()

        # Sub-modules are built in a fixed order: embedding, positional
        # encoder, transformer, output head. Weight initialisation draws from
        # the global RNG, so this order determines the initial weights.
        self.embedding = nn.Embedding(num_tokens, dim_model)
        self.pos_encoder = PositionalEncoding(dim_model, dropout_p)

        transformer_cfg = dict(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
            batch_first=True,
        )
        self.transformer = nn.Transformer(**transformer_cfg)

        self.out = nn.Linear(dim_model, num_tokens)
        self.dim_model = dim_model

    def _embed(self, tokens):
        """Embed token ids, scale by sqrt(dim_model) and add positional encodings."""
        scaled = self.embedding(tokens) * math.sqrt(self.dim_model)
        return self.pos_encoder(scaled)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        """
        Return vocabulary logits for every target position.

        src: source token ids
        tgt: decoder-input token ids (target shifted right)
        src_mask: optional encoder self-attention mask
        tgt_mask: optional decoder self-attention mask (typically causal)
        src_padding_mask: optional bool mask, True at source padding positions.
            It is also passed as the memory padding mask, so the decoder
            ignores padded encoder outputs.
        tgt_padding_mask: optional bool mask, True at target padding positions
        """
        encoded_src = self._embed(src)
        encoded_tgt = self._embed(tgt)

        # nn.Transformer names its padding masks *_key_padding_mask.
        hidden = self.transformer(
            encoded_src,
            encoded_tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        return self.out(hidden)

# Masking Utilities

In [5]:
def generate_square_subsequent_mask(sz, device):
    """
    Build the additive causal (look-ahead) mask used by the decoder.

    Returns a float [sz, sz] tensor. Entry (i, j) is 0.0 when j <= i (the
    position is visible) and -inf when j > i (a future position that must be
    hidden).
    """
    # Start from an all -inf matrix and keep only the strictly upper triangle.
    # triu zeroes everything on and below the main diagonal.
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt, device, pad_idx=0):
    """
    Create the causal target mask plus the source/target padding masks.

    Args:
        src: source token ids; shape[1] is taken as the source length.
        tgt: decoder-input token ids; shape[1] is taken as the target length.
        device: device for the square attention masks.
        pad_idx: token id treated as padding (default 0).

    Returns:
        (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask), all bool
        tensors where True means "this position may not be attended to".
        - src_mask: all-False [src_len, src_len]; the encoder sees every token.
        - tgt_mask: [tgt_len, tgt_len], True strictly above the diagonal (future).
        - src_padding_mask / tgt_padding_mask: True where the token equals pad_idx.
    """
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    # The causal mask is boolean so that it has the same dtype as the bool
    # key-padding masks. PyTorch deprecates mixing a float attn_mask with a
    # bool key_padding_mask and warns about it. Bool True here is equivalent
    # to the -inf entries of the additive float mask.
    tgt_mask = torch.triu(
        torch.ones((tgt_seq_len, tgt_seq_len), device=device), diagonal=1
    ).bool()

    # The encoder needs no causal restriction.
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=device)

    src_padding_mask = src == pad_idx
    tgt_padding_mask = tgt == pad_idx

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [8]:
# Preview a 5x5 causal mask built on the CPU.
mask = generate_square_subsequent_mask(sz=5, device='cpu')
print(mask)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])


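Each row is a query position: it may attend to itself and to earlier positions (0), but it cannot attend to future positions (-inf).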

# Hyperparameters and Model Initialization

In [7]:

# Special token ids
PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2

# Hyperparameters
VOCAB_SIZE = 20     # token ids 0..19, special tokens included
DIM_MODEL = 128
NUM_HEADS = 4
NUM_LAYERS = 2      # used for both the encoder and the decoder depth
BATCH_SIZE = 64
MAX_LEN = 12

# Dropout is disabled for this small toy task.
model = Seq2SeqTransformer(
    num_tokens=VOCAB_SIZE,
    dim_model=DIM_MODEL,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_LAYERS,
    num_decoder_layers=NUM_LAYERS,
    dropout_p=0.0,
).to(device)

# Target positions equal to PAD_IDX contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model Initialized. Parameters: {trainable_params:,}")

Model Initialized. Parameters: 2,510,356


# Data Generation (Toy Task: Reverse Sequence)

In [9]:
def get_batch(bsz):
    """
    Sample `bsz` random examples for the sequence-reversal toy task.

    Each example is a sequence of random length 3..MAX_LEN-2 whose tokens are
    drawn from 3..VOCAB_SIZE-1, so ids below 3 are never used as data. For a
    sequence s = [a, b, c] the three rows produced are:

        src       = [a, b, c, PAD, ...]        (length MAX_LEN)
        tgt_input = [SOS, c, b, a, PAD, ...]   (length MAX_LEN)
        tgt_out   = [c, b, a, EOS, PAD, ...]   (length MAX_LEN)

    Returns:
        (src, tgt_input, tgt_out) as tensors of shape [bsz, MAX_LEN] on `device`.
    """
    src_rows, in_rows, out_rows = [], [], []

    for _ in range(bsz):
        # Draw the length first, then the tokens.
        n = random.randint(3, MAX_LEN - 2)
        tokens = [random.randint(3, VOCAB_SIZE - 1) for _ in range(n)]
        reversed_tokens = tokens[::-1]

        # The target rows carry one extra special token, so they need one less PAD.
        tail = [PAD_IDX] * (MAX_LEN - n - 1)

        src_rows.append(tokens + [PAD_IDX] * (MAX_LEN - n))
        in_rows.append([SOS_IDX] + reversed_tokens + tail)
        out_rows.append(reversed_tokens + [EOS_IDX] + tail)

    def to_device(rows):
        return torch.tensor(rows).to(device)

    return to_device(src_rows), to_device(in_rows), to_device(out_rows)

In [11]:
# Draw a tiny batch of 4 examples to inspect the padded source rows.
src, tgt_in, tgt_out = get_batch(bsz=4)
print(src)

tensor([[14, 11, 10,  9, 17, 19,  3, 19,  5,  3,  0,  0],
        [ 7, 18, 15,  8,  4,  5, 18,  8,  0,  0,  0,  0],
        [ 9,  7,  3, 14,  0,  0,  0,  0,  0,  0,  0,  0],
        [19,  7, 12, 17,  0,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')


In [12]:
# Decoder input: SOS, then the reversed source sequence, then PAD.
print(tgt_in)
# Teacher forcing: input [SOS, c, b, a] is trained to predict [c, b, a, EOS].

tensor([[ 1,  3,  5, 19,  3, 19, 17,  9, 10, 11, 14,  0],
        [ 1,  8, 18,  5,  4,  8, 15, 18,  7,  0,  0,  0],
        [ 1, 14,  3,  7,  9,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 17, 12,  7, 19,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')


In [18]:
# Decoder ground truth: reversed source sequence, then EOS, then PAD.
# This is the target used to compute the loss.
print(tgt_out)

tensor([[ 3,  5, 19,  3, 19, 17,  9, 10, 11, 14,  2,  0],
        [ 8, 18,  5,  4,  8, 15, 18,  7,  2,  0,  0,  0],
        [14,  3,  7,  9,  2,  0,  0,  0,  0,  0,  0,  0],
        [17, 12,  7, 19,  2,  0,  0,  0,  0,  0,  0,  0]], device='cuda:0')


# Training Loop

In [19]:
model.train()
print("Training Start...")

EPOCHS = 3000

for epoch in range(EPOCHS):
    # Each iteration draws a fresh random batch, so one "epoch" here is one optimisation step.
    src, tgt_input, tgt_out = get_batch(BATCH_SIZE)
    src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(src, tgt_input, device)

    optimizer.zero_grad()

    logits = model(
        src,
        tgt_input,
        src_mask=src_mask,
        tgt_mask=tgt_mask,
        src_padding_mask=src_pad_mask,
        tgt_padding_mask=tgt_pad_mask,
    )

    # CrossEntropyLoss expects [N, C] logits against [N] class ids, so flatten batch and time.
    flat_logits = logits.reshape(-1, logits.size(-1))
    loss = criterion(flat_logits, tgt_out.reshape(-1))

    loss.backward()
    optimizer.step()

    if not epoch % 50:
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f}")

print("Training Complete.")

Training Start...




Epoch   0 | Loss: 3.1807
Epoch  50 | Loss: 1.6873
Epoch 100 | Loss: 1.5163
Epoch 150 | Loss: 1.4808
Epoch 200 | Loss: 1.3423
Epoch 250 | Loss: 1.2897
Epoch 300 | Loss: 1.1034
Epoch 350 | Loss: 0.9643
Epoch 400 | Loss: 0.8307
Epoch 450 | Loss: 0.6528
Epoch 500 | Loss: 0.6369
Epoch 550 | Loss: 0.5581
Epoch 600 | Loss: 0.4480
Epoch 650 | Loss: 0.3925
Epoch 700 | Loss: 0.3075
Epoch 750 | Loss: 0.2828
Epoch 800 | Loss: 0.3065
Epoch 850 | Loss: 0.1718
Epoch 900 | Loss: 0.2212
Epoch 950 | Loss: 0.1399
Epoch 1000 | Loss: 0.1058
Epoch 1050 | Loss: 0.1404
Epoch 1100 | Loss: 0.4015
Epoch 1150 | Loss: 0.1511
Epoch 1200 | Loss: 0.0985
Epoch 1250 | Loss: 0.1048
Epoch 1300 | Loss: 0.0421
Epoch 1350 | Loss: 0.2300
Epoch 1400 | Loss: 0.0781
Epoch 1450 | Loss: 0.0633
Epoch 1500 | Loss: 0.0465
Epoch 1550 | Loss: 0.0316
Epoch 1600 | Loss: 0.1687
Epoch 1650 | Loss: 0.0456
Epoch 1700 | Loss: 0.2620
Epoch 1750 | Loss: 0.0920
Epoch 1800 | Loss: 0.0243
Epoch 1850 | Loss: 0.0538
Epoch 1900 | Loss: 0.0381
Epoch 

# Inference (Greedy Decoding)

In [20]:

def greedy_decode(model, src, max_len, start_symbol):
    """
    Greedily decode a single target sequence for one source sequence.

    At each step the decoder sees everything generated so far and the most
    likely next token is appended. Decoding stops after EOS_IDX is produced or
    when the sequence reaches max_len tokens (start symbol included).

    Args:
        model: a Seq2SeqTransformer. Its embedding, pos_encoder, transformer
            and out sub-modules and its dim_model attribute are used.
        src: source token ids for a batch of one, e.g. shape [1, src_len].
            The `.item()` call below only supports a single sequence.
        max_len: maximum length of the returned sequence.
        start_symbol: token id placed at position 0 of the decoder input.

    Returns:
        Long tensor of shape [1, L], L <= max_len. It starts with start_symbol
        and ends with EOS_IDX if EOS was generated.
    """
    # Run on whatever device the model lives on, and use the model's own width
    # for the embedding scale rather than a global constant.
    model_device = next(model.parameters()).device
    scale = math.sqrt(model.dim_model)

    src = src.to(model_device)
    src_padding_mask = src == PAD_IDX

    # Pure inference: skip building an autograd graph for every decoding step.
    with torch.no_grad():
        memory = model.transformer.encoder(
            model.pos_encoder(model.embedding(src) * scale),
            src_key_padding_mask=src_padding_mask,
        )

        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=model_device)

        for _ in range(max_len - 1):
            tgt_mask = generate_square_subsequent_mask(ys.size(1), model_device)

            out = model.transformer.decoder(
                model.pos_encoder(model.embedding(ys) * scale),
                memory,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=src_padding_mask,
            )

            # Only the last position's prediction is needed for the next token.
            next_word = model.out(out[:, -1]).argmax(dim=1).item()

            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=model_device)
            ys = torch.cat([ys, next_token], dim=1)

            if next_word == EOS_IDX:
                break

    return ys

# --- Run Inference ---
model.eval()
print("\nTesting (Source -> Reversed Prediction):")

test_src_list = [3, 4, 5, 6, 7, 8, 9]
print(f"Input Sequence: {test_src_list}")

# Right-pad to MAX_LEN and wrap in a batch of one.
padding = [PAD_IDX] * (MAX_LEN - len(test_src_list))
src = torch.tensor([test_src_list + padding]).to(device)

pred_tensor = greedy_decode(model, src, MAX_LEN, start_symbol=SOS_IDX)

result = pred_tensor.squeeze().tolist()
print(f"Raw Output: {result}")

# Strip SOS tokens and cut the prediction at the first EOS.
final_res = []
for token in result:
    if token == EOS_IDX:
        break
    if token != SOS_IDX:
        final_res.append(token)

print(f"Predicted Result: {final_res}")

expected = test_src_list[::-1]
print("✅ Success! The sequence is correctly reversed." if final_res == expected else "❌ Failed.")


Testing (Source -> Reversed Prediction):
Input Sequence: [3, 4, 5, 6, 7, 8, 9]
Raw Output: [1, 9, 8, 7, 6, 5, 4, 3, 2]
Predicted Result: [9, 8, 7, 6, 5, 4, 3]
✅ Success! The sequence is correctly reversed.
