In [1]:
%reload_ext autoreload
%autoreload 2

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
import random
import time
import os

import sys
sys.path.append("..")
from data import get_dataset # custom helper function to get dataset

In [28]:
BATCH_SIZE = 128

In [29]:
train_data, val_data, test_data = get_dataset(["train", "val", "test"])
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=6)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=6)

class WrappedDataLoader:
    def __init__(self, dataloader, func):
        self.dataloader = dataloader
        self.func = func
        
    def __len__(self):
        return len(self.dataloader)
    
    def __iter__(self):
        iter_dataloader = iter(self.dataloader)
        for batch in iter_dataloader:
            yield self.func(*batch)
            
def preprocess(x, y):
    """
    x = [batch size, 20, 2], encoder input
    y = [batch size, 30, 2], decoder target output
    we also need decoder input in train and eval stage
    decoder input is the last sequence in x concated with y without the last sequence
    """
    encoder_input = x
    decoder_input = torch.cat((x[:, -1, :].unsqueeze(1), y[:, :-1, :]), 1)
    target = y
    return encoder_input.transpose(0, 1), decoder_input.transpose(0, 1), target.transpose(0, 1)

train_loader = WrappedDataLoader(train_loader, preprocess)
val_loader = WrappedDataLoader(val_loader, preprocess)

In [30]:
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [31]:
for x, y, z in train_loader:
    print(x.shape)
    print(y.shape)
    print(z.shape)
    break

torch.Size([20, 128, 2])
torch.Size([30, 128, 2])
torch.Size([30, 128, 2])


In [42]:
class TrajectoryTransformer(nn.Module):
    def __init__(self,
                 device,
                 input_seq_len: int = 20,
                 target_seq_len: int = 30,
                 input_dim: int = 2,
                 output_dim: int = 2,
                 d_model: int = 512,
                 nhead: int = 8,
                 num_encoder_layers: int = 6,
                 num_decoder_layers: int = 6,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 activation: str = 'relu'):
        super().__init__()
        self.input_seq_len = input_seq_len
        self.target_seq_len = target_seq_len
        self.total_seq_len = self.input_seq_len + self.target_seq_len
        self.device = device
        
        self.encoder_embedding = nn.Linear(input_dim, d_model)
        self.decoder_embedding = nn.Linear(output_dim, d_model)
        self.pos_embedding = nn.Embedding(self.total_seq_len, d_model)
        self.transformer = nn.Transformer(d_model,
                                          nhead,
                                          num_encoder_layers,
                                          num_decoder_layers,
                                          dim_feedforward,
                                          dropout,
                                          activation)
        self.linear = nn.Linear(d_model, output_dim)
        
    def batch_position(self, batch_size, start, end):
        """
        return tensor shape: [end - start + 1, batch size]
        content is like the following:
        [
         [start, start, ..., start],
         [start + 1, start + 1, ..., start + 1],
         ...
         [end - 1, end - 1, ..., end - 1]
        ]
        """
        return torch.arange(start, end).unsqueeze(0).repeat(batch_size, 1).transpose(0, 1).to(self.device)

    
    def forward(self, encoder_input, decoder_input):
        batch_size = encoder_input.shape[1]
        
        encoder_input_pos = self.batch_position(batch_size, 0, self.input_seq_len)
        decoder_input_pos = self.batch_position(batch_size,
                                                self.input_seq_len - 1,
                                                self.total_seq_len - 1)
        
        decoder_input_len = decoder_input.shape[0]
        decoder_mask = self.transformer.generate_square_subsequent_mask(decoder_input_len).to(self.device)
        
        encoder_input = self.encoder_embedding(encoder_input) + self.pos_embedding(encoder_input_pos)
        decoder_input = self.decoder_embedding(decoder_input) + self.pos_embedding(decoder_input_pos)
        output = self.transformer(encoder_input, decoder_input, tgt_mask = decoder_mask)
        
        return self.linear(output)

In [43]:
model = TrajectoryTransformer(device = dev).to(dev)

In [44]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 44,170,242 trainable parameters


In [45]:
def initialize_weights(model):
    if hasattr(model, 'weight') and model.weight.dim() > 1:
        nn.init.xavier_uniform_(model.weight.data)

In [46]:
model.apply(initialize_weights)

TrajectoryTransformer(
  (encoder_embedding): Linear(in_features=2, out_features=512, bias=True)
  (decoder_embedding): Linear(in_features=2, out_features=512, bias=True)
  (pos_embedding): Embedding(50, 512)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): Multih

In [47]:
LEARNING_RATE = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.MSELoss()

In [48]:
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, (encoder_input, decoder_input, target) in enumerate(iterator):
        encoder_input = encoder_input.to(dev)
        decoder_input = decoder_input.to(dev)
        target = target.to(dev)
        
        optimizer.zero_grad()
        output = model(encoder_input, decoder_input)
        loss = criterion(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [49]:
def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, (encoder_input, decoder_input, target) in enumerate(iterator):
            encoder_input = encoder_input.to(dev)
            decoder_input = decoder_input.to(dev)
            target = target.to(dev)
            
            output = model(encoder_input, decoder_input)
            loss = criterion(output, target)
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [50]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [52]:
best_val_loss = float('inf')

In [None]:
import os

N_EPOCHES = 60
CLIP = 1

# load previous best model params if exists
model_dir = "saved_models/Transformer"
saved_model_path = model_dir + "/best_transformer.pt"
if os.path.isfile(saved_model_path):
    model.load_state_dict(torch.load(saved_model_path))
    print("successfully load previous best model parameters")
    
for epoch in range(N_EPOCHES):
    start_time = time.time()
    
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    val_loss = evaluate(model, val_loader, criterion)
    
    end_time = time.time()
    
    mins, secs = epoch_time(start_time, end_time)
    
    print(F'Epoch: {epoch+1:02} | Time: {mins}m {secs}s')
    print(F'\tTrain Loss: {train_loss:.3f}')
    print(F'\t Val. Loss: {val_loss:.3f}')

    if val_loss < best_val_loss:
        os.makedirs(model_dir, exist_ok=True)
        torch.save(model.state_dict(), saved_model_path)

successfully load previous best model parameters
Epoch: 01 | Time: 7m 11s
	Train Loss: 910835.144
	 Val. Loss: 589285.062
Epoch: 02 | Time: 7m 12s
	Train Loss: 1215558.381
	 Val. Loss: 1166729.546
Epoch: 03 | Time: 7m 12s
	Train Loss: 1213849.419
	 Val. Loss: 1145203.096
Epoch: 04 | Time: 7m 12s
	Train Loss: 1213756.981
	 Val. Loss: 1143709.401
