In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch import nn
import torchtune
import torch.optim as optim

In [8]:
class GCSTransformer(nn.Module):
    def __init__(self, num_time_bins = 512, num_space_bins = 1024, d_model = 128*6, nhead = 8):
        super().__init__()

        self.num_time_bins = num_time_bins
        self.num_space_bins = num_space_bins

        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        
        self.map_embs = nn.Embedding(16, 128)
        self.traj_embs = nn.Embedding(num_time_bins + num_space_bins + 3, # Plus 3 for start/end of traj, end of seq, and pad tokens
                                        128*6, padding_idx = num_time_bins + num_space_bins + 2)
        
        self.transformer = nn.Transformer(d_model=self.d_model, nhead=self.nhead, 
                                          num_encoder_layers=12, num_decoder_layers=12, 
                                          dim_feedforward=2048, batch_first=True)
        self.rope = torchtune.modules.RotaryPositionalEmbeddings(self.d_head)

        self.linear = nn.Linear(self.d_model, self.traj_embs.num_embeddings)

    
    def forward(self, map_array, trajs):
        assert map_array.shape[0] == trajs.shape[0]

        # Assume map_array has structure:
        # [B, H * W, Channels]        
        batch_size, map_seq_len, channels = map_array.shape
        map_embs = self.map_embs(map_array).reshape(batch_size, map_seq_len, -1)

        batch_size, traj_seq_len = trajs.shape
        traj_embs = self.traj_embs(trajs)

        # Apply rope after concatenating map and traj embs, then split again.
        all_embs = torch.cat([map_embs, traj_embs], dim=1).reshape(batch_size, map_seq_len + traj_seq_len, self.nhead, self.d_head)
        all_embs = self.rope(all_embs).reshape(batch_size, map_seq_len + traj_seq_len, -1)

        map_embs = all_embs[:, :map_seq_len]
        traj_embs = all_embs[:, map_seq_len:]

        # Pass through transformer
        # [batch_size, traj_seq_len, d_model]
        output = self.transformer(src = map_embs, tgt = traj_embs,
                                  tgt_mask = nn.Transformer.generate_square_subsequent_mask(traj_seq_len, device=map_array.device),
                                  src_is_causal = False,
                                  tgt_is_causal = True)
        
        logits = self.linear(output)

        return logits


In [9]:
device = "cuda:0"
policy = GCSTransformer().to(device)

In [11]:
def loss_fn(pred_logits, labels, vocab_size):
    shifted_pred_logits = pred_logits[:, :-1].reshape(-1, vocab_size)
    shifted_labels = labels[:, 1:].reshape(-1)
    return F.cross_entropy(shifted_pred_logits, shifted_labels)

In [None]:

# optimizer = optim.AdamW(policy.parameters(), lr=0.01)
optimizer = optim.SGD(policy.parameters())

In [None]:
# map_tensor: batch size x seq len=36 x channels=6
# traj_tensor: batch size x seq len

for i in range(1000):
    optimizer.zero_grad()
    pred_logits = policy(map_tensor, traj_tensor)
    loss = loss_fn(pred_logits, traj_tensor, policy.traj_embs.num_embeddings)
    loss.backward()
    optimizer.step()
    print(round(loss.item(), 4), round(torch.mean((torch.argmax(pred_logits[:, :-1], dim=-1) == traj_tensor[:, 1:]) * 1.).item(), 4))

0.3566 0.9132
0.344 0.9358
0.3529 0.9245
0.3391 0.9245
0.3484 0.9245
0.3455 0.9283
0.3553 0.9132
0.3399 0.9245
0.3501 0.9132
0.3554 0.9245
0.3462 0.9283
0.3221 0.9208
0.3474 0.9245
0.3418 0.9321
0.354 0.9358
0.35 0.9208
0.3493 0.9245
0.3219 0.9396
0.3362 0.9283
0.3238 0.9321
0.3442 0.917
0.3306 0.9358
0.333 0.9472
0.3387 0.9094
0.3359 0.9321
0.3583 0.9208
0.3256 0.9208
0.3468 0.9208
0.3444 0.9245
0.3493 0.9245
0.3304 0.9358
0.3253 0.9434
0.3392 0.9321
0.3332 0.9321
0.3232 0.9245
0.3179 0.9208
0.3315 0.917
0.3427 0.9321
0.3504 0.9208
0.3345 0.917
0.3337 0.9321
0.3391 0.9283
0.3269 0.9283
0.3384 0.9245
0.3192 0.9208
0.3432 0.9245
0.3287 0.9358
0.3349 0.9283
0.3498 0.917
0.3287 0.9358
0.325 0.9245
0.3432 0.9321
0.3278 0.9245
0.32 0.9245
0.323 0.9396
0.3234 0.9472
0.3232 0.9358
0.3325 0.9283
0.3156 0.9358
0.3221 0.9358
0.3263 0.9434
0.3253 0.9283
0.3208 0.9245
0.3486 0.917
0.3249 0.9321
0.3272 0.9283
0.3295 0.9283
0.3222 0.9434
0.3333 0.9283
0.3533 0.9245
0.3261 0.9208
0.3188 0.9245
0.3275