In [1]:
import json
import torch
import numpy as np
import os
from torch.utils.data import Dataset, random_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader 
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
import matplotlib.pyplot as plt
import sys
from torch_geometric.utils import to_dense_batch
import torch.optim as optim
import json

# Put the parent directory on the import path; the project-local `functions.*`
# modules imported in later cells are presumably resolved through it.
# The membership check keeps this cell idempotent: re-running it does not
# append the same entry to sys.path again.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

#### 1. Make Data Loaders

- Set the directory that contains the JSON graph dataset (the first argument of `get_graph_dataset`).

- Additionally you can change the following parameters:

    - mode: This lets you choose which features to include in the loaded data. The options are:
        - all: includes all nodes and edges
        - ego_veh: includes all ego and vehicle nodes and edges
        - ego: includes only ego nodes and edges
        - ego_env: includes all ego and environment nodes and edges
        - all_no_edges: includes all node features but no edges
        - ego_veh_no_edges: includes all ego and vehicle node features but no edges
        - ego_no_edges: includes only ego node features with no edges
        - ego_env_no_edges: includes all ego and environment node features but no edges

    - normalize: True/False. If True, a normalization transform is applied, followed by a NaN-to-number replacement (`nan_to_num`).

    - norm_method: `zscore` or `l2`

In [2]:
from functions.data_loaders import get_graph_dataset

# Same batch size for both domains so paired batches in later cells line up.
BATCH_SIZE = 64

# --- L2D dataset ------------------------------------------------------------
# mode="ego": only ego nodes and edges are loaded (see the option list above).
l2d_train_set = get_graph_dataset("../data/processed_graphical/l2d/",
                                  mode="ego", normalize=True, norm_method="zscore")
l2d_train_loader = DataLoader(l2d_train_set, batch_size=BATCH_SIZE, shuffle=True)

# Show the dataset size and the structure of one batch as a sanity check.
print(f"Total graphs: {len(l2d_train_set)}")
for data in l2d_train_loader:
    print(data)
    break

print('===========')

# --- nuPlan dataset ---------------------------------------------------------
nup_train_set = get_graph_dataset("../data/processed_graphical/nuplan/",
                                  mode="ego", normalize=True, norm_method="zscore")
# Fix: the loader must wrap the nuPlan dataset. It previously wrapped
# `l2d_train_set`, so every "nuPlan" batch downstream was actually L2D data.
nup_train_loader = DataLoader(nup_train_set, batch_size=BATCH_SIZE, shuffle=True)

print(f"Total graphs: {len(nup_train_set)}")
# Fix: preview a batch from the nuPlan loader (previously iterated the L2D loader again).
for data in nup_train_loader:
    print(data)
    break

Total graphs: 9593
HeteroDataBatch(
  ego={
    x=[492, 6],
    batch=[492],
    ptr=[65],
  },
  window_meta={ episode_path=[64] },
  (ego, to, ego)={ edge_index=[2, 428] }
)
Total graphs: 88370
HeteroDataBatch(
  ego={
    x=[517, 6],
    batch=[517],
    ptr=[65],
  },
  window_meta={ episode_path=[64] },
  (ego, to, ego)={ edge_index=[2, 453] }
)


### 🚧 Under Construction 🚧

In [3]:
from functions.models import GraphEmbedder, ProjectionHead, kl_divergence_between_gaussians


def _feature_dims_by_node_type(batch):
    """Return ``{node_type: batch[node_type].x.size(1)}`` for every node type with an ``x`` attribute."""
    dims = {}
    for node_type in batch.node_types:
        store = batch[node_type]
        if hasattr(store, 'x'):
            dims[node_type] = store.x.size(1)
    return dims


# Draw one batch from each loader; they are only used to read feature widths here.
batch_l2d = next(iter(l2d_train_loader))
batch_nup = next(iter(nup_train_loader))

node_dims_l2d = _feature_dims_by_node_type(batch_l2d)
node_dims_nup = _feature_dims_by_node_type(batch_nup)

print("L2D node dims:", node_dims_l2d)
print("NUP node dims:", node_dims_nup)

# One encoder per dataset and a single projection head applied to both.
# in_dim=32 equals the last hidden_dims entry; presumably that is the encoder's
# output width — confirm against functions.models.
l2d_encoder = GraphEmbedder(node_dims_l2d, hidden_dims=[64, 128, 32])
nup_encoder = GraphEmbedder(node_dims_nup, hidden_dims=[64, 128, 32])
projector = ProjectionHead(in_dim=32, proj_dim=16)

L2D node dims: {'ego': 6}
NUP node dims: {'ego': 6}


In [4]:
# Untrained forward pass, one domain at a time: dataset-specific encoder first,
# then the shared projection head.
# NOTE(review): ProjectionHead was built with proj_dim=16, so the projected
# embeddings are presumably [B, 16] — confirm against functions.models.
emb_l2d = l2d_encoder(batch_l2d)
z_l2d = projector(emb_l2d)

emb_nup = nup_encoder(batch_nup)
z_nup = projector(emb_nup)

# One-directional divergence between the two batches of embeddings, shown as the cell output.
loss = kl_divergence_between_gaussians(z_l2d, z_nup)
loss

tensor(3428.4546, grad_fn=<MulBackward0>)

In [None]:
import torch
from tqdm import tqdm

# --- Hyperparameters (named so they are easy to find and tune) ---------------
learning_rate = 1e-4
num_epochs = 40

# Send models to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
l2d_encoder = l2d_encoder.to(device)
nup_encoder = nup_encoder.to(device)
projector = projector.to(device)
trainable_modules = (l2d_encoder, nup_encoder, projector)

# A single optimizer updates both encoders and the shared projection head.
optimizer = torch.optim.Adam(
    [param for module in trainable_modules for param in module.parameters()],
    lr=learning_rate,
)

# zip() stops at the shorter loader, so this is the number of paired batches per epoch.
batches_per_epoch = min(len(l2d_train_loader), len(nup_train_loader))
if batches_per_epoch == 0:
    # Fail early with a clear message instead of a ZeroDivisionError after the first epoch.
    raise ValueError("Both data loaders must yield at least one batch to train.")

# Average loss of every completed epoch, kept for later inspection/plotting.
epoch_losses = []

for epoch in range(num_epochs):
    for module in trainable_modules:
        module.train()

    total_loss = 0.0
    num_batches = 0  # batches actually processed this epoch

    pbar = tqdm(zip(l2d_train_loader, nup_train_loader),
                total=batches_per_epoch,
                desc=f"Epoch {epoch+1}/{num_epochs}")

    # NOTE: these loop variables rebind batch_l2d / batch_nup / z_l2d / z_nup
    # from the earlier cells.
    for batch_l2d, batch_nup in pbar:
        batch_l2d = batch_l2d.to(device)
        batch_nup = batch_nup.to(device)

        z_l2d = projector(l2d_encoder(batch_l2d))  # [B, D]
        z_nup = projector(nup_encoder(batch_nup))  # [B, D]

        # L2-normalize each embedding along the last dimension before comparing.
        z_l2d = F.normalize(z_l2d, dim=-1)
        z_nup = F.normalize(z_nup, dim=-1)

        # Symmetric objective: sum of the divergence in both directions.
        loss = (kl_divergence_between_gaussians(z_l2d, z_nup)
                + kl_divergence_between_gaussians(z_nup, z_l2d))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; each .item() call forces a device-to-host sync.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        pbar.set_postfix(loss=loss_value)

    # Average over the batches that were really processed.
    avg_loss = total_loss / num_batches
    epoch_losses.append(avg_loss)
    print(f"Epoch {epoch+1} complete. Average KL Loss: {avg_loss:.4f}")