Learning the number sequence 1, 2, 3, 4, … (up to 100) with the full `nn.Transformer` model.

In [None]:
import math
from tempfile import TemporaryDirectory
from typing import Tuple
import torch
from torch import nn, Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
      
class TransformerModel(nn.Module):
    """Encoder-decoder ``nn.Transformer`` over a single shared token embedding.

    Token ids are embedded with one ``nn.Embedding`` (used for both the source
    and the target stream), scaled by ``sqrt(d_model)``, given sinusoidal
    positional encodings, run through ``nn.Transformer`` and projected back to
    ``ntoken`` logits. ``nn.Transformer`` is built with its default
    ``batch_first=False``, so inputs are sequence-first.

    Args:
        ntoken: vocabulary size (number of distinct token ids).
        d_model: embedding / model width.
        nhead: number of attention heads.
        d_hid: feed-forward hidden size inside each transformer layer.
        nlayers: number of encoder layers and, separately, of decoder layers.
        dropout: dropout probability for the transformer and positional encoding.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        # Cached causal mask for the decoder; rebuilt when length or device changes.
        self.tgt_mask = None
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=d_hid,
            dropout=dropout,
        )
        # One embedding table shared by the src and tgt token streams.
        self.encoder = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        # Projects transformer outputs to per-token logits.
        self.decoder = nn.Linear(d_model, ntoken)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz: int, device=None) -> Tensor:
        """Return an ``sz x sz`` float causal mask.

        Entries on and below the diagonal are 0 (attend); entries above it are
        ``-inf`` (blocked), so position ``i`` only sees positions ``<= i``.
        The mask is created directly on ``device`` (default: CPU).
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def init_weights(self) -> None:
        """Initialise the embedding and output weights in [-0.1, 0.1] and zero the output bias."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src: Tensor, tgt: Tensor, use_mask: bool = True) -> Tensor:
        """Run the encoder-decoder and return per-position token logits.

        Args:
            src: source token ids, sequence-first (e.g. ``[src_len, batch]``).
            tgt: decoder-input token ids, sequence-first (e.g. ``[tgt_len, batch]``).
            use_mask: if True, apply a causal mask to decoder self-attention.

        Returns:
            Logits with the target's leading dims plus a trailing ``ntoken`` dim,
            e.g. ``[tgt_len, batch, ntoken]``.
        """
        mask = None
        if use_mask:
            tgt_len = tgt.size(0)
            cached = self.tgt_mask
            # The cache is a plain attribute, not a registered buffer, so
            # ``model.to(...)`` does not move it. Rebuild on a device change as
            # well as on a length change.
            if cached is None or cached.size(0) != tgt_len or cached.device != tgt.device:
                self.tgt_mask = self._generate_square_subsequent_mask(tgt_len, tgt.device)
            mask = self.tgt_mask

        scale = math.sqrt(self.d_model)
        src = self.pos_encoder(self.encoder(src) * scale)
        tgt = self.pos_encoder(self.encoder(tgt) * scale)

        output = self.transformer(src, tgt, tgt_mask=mask)
        return self.decoder(output)
  
  
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor, then apply dropout.

    Even feature columns hold ``sin(pos * f_i)`` and odd columns hold
    ``cos(pos * f_i)``, with frequencies ``f_i = 10000^(-2i / d_model)``.

    Args:
        d_model: embedding width. Odd widths are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # There are floor(d_model / 2) odd columns. When d_model is odd that is
        # one fewer than the number of angle columns, so trim to fit.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer: moves with .to()/.cuda() but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``,
                with ``seq_len <= max_len``.

        Returns:
            ``x`` plus the positional table for its first ``seq_len`` positions
            (broadcast over the batch), after dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [None]:
# Initialize model, optimizer and loss function
NTOKEN = 101  # token ids 0..100: 0 is the <start> token, 1..100 are the numbers
model = TransformerModel(ntoken=NTOKEN, d_model=512, nhead=8, d_hid=2048, nlayers=6)
lr = 0.0002  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

# Create the number sequence and prepare input and target tensors.
# Tensors are sequence-first ([seq_len, batch=1]) to match nn.Transformer's default layout.
num_seq = torch.arange(1, NTOKEN).unsqueeze(1)  # Shape: [100, 1]
# Decoder input: <start> (0) followed by 1..99, i.e. the targets shifted right by one.
input_seq = torch.cat([torch.zeros(1, 1, dtype=torch.long), num_seq[:-1]])
target_seq = num_seq  # position i must predict the number i + 1

# Model training. num_steps must be >= 1 so `loss` is defined for the final print.
# A single step only checks the wiring; raise it to actually fit the sequence.
num_steps = 1
model.train()
for _ in range(num_steps):
    optimizer.zero_grad()

    # Forward pass: the same sequence is fed as encoder source and decoder input.
    output = model(input_seq, input_seq)  # [seq_len, 1, NTOKEN]
    output = output.reshape(-1, NTOKEN)  # flatten to [seq_len, NTOKEN] for the loss

    # Compute loss
    loss = loss_fn(output, target_seq.reshape(-1))

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

print('Training loss:', loss.item())
