In [1]:
import torch
import numpy as np
from scenarionet import read_dataset_summary, read_scenario
from metadrive.engine.asset_loader import AssetLoader
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, Batch

In [2]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [3]:
torch.cuda.empty_cache()

In [4]:
av2_data =  AssetLoader.file_path("/home/light/Documents/Thesis/preprocessed_dataset", unix_style=False)
dataset_summary, scenario_ids, mapping = read_dataset_summary(dataset_path=av2_data)

scenario_file_name = scenario_ids[0]
scenario = read_scenario(dataset_path=av2_data, mapping=mapping, scenario_file_name=scenario_file_name)
scenario.keys()

dict_keys(['id', 'version', 'length', 'tracks', 'dynamic_map_states', 'map_features', 'metadata'])

In [5]:
import torch
import numpy as np

class AdaptiveTokenScenarioNet:
    def __init__(self, importance_threshold=0.5, max_tokens=512, token_length=10, use_gpu=False):
        self.importance_threshold = importance_threshold
        self.max_tokens = max_tokens
        self.token_length = token_length
        self.device = torch.device('cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu')
        self.type_importance = {
            'VEHICLE': 1.0,
            'PEDESTRIAN': 0.8,
            'BICYCLE': 0.7,
            'STATIC': 0.5,
            'UNKNOWN': 0.3
        }
        self.map_feature_importance = {
            'LANE': 0.9,
            'CROSSWALK': 0.7,
            'STOP_SIGN': 0.6,
            'TRAFFIC_LIGHT': 0.8,
            'UNKNOWN': 0.4
        }
        # Need to check these scores whether its proper or should I turn this to a hyperparameter learning patameter would be super costly

    def score_tracks(self, tracks):
        importance_scores = {}
        for track_id, track in tracks.items():
            state = track['state']
            valid = state['valid']
            if not valid.any():
                importance_scores[track_id] = 0
                continue
            positions = state['position'][valid]
            velocities = state['velocity'][valid]
            headings = state['heading'][valid]
            lengths = state['length'][valid]
            widths = state['width'][valid]
            heights = state['height'][valid]
            activity_score = np.linalg.norm(positions, axis=1).sum()
            velocity_score = np.linalg.norm(velocities, axis=1).sum()
            heading_score = np.abs(headings).sum()
            size_score = lengths.sum() + widths.sum() + heights.sum()
            type_score = self.type_importance.get(track.get('type', 'UNKNOWN'), 0.3)
            importance_scores[track_id] = activity_score + velocity_score + heading_score + size_score + type_score
            # No dynamic map features and the logic for importance score needs to be verified and should I normalise the dtaa here
        return importance_scores

    def score_map_features(self, map_features):
        importance_scores = {}
        for feature_id, feature in map_features.items():
            feature_type = feature.get('type', 'UNKNOWN')
            polyline = feature.get('polyline', [])
            if len(polyline) == 0:
                importance_scores[feature_id] = 0
                continue
            feature_score = len(polyline) * self.map_feature_importance.get(feature_type, 0.4)
            importance_scores[feature_id] = feature_score
        return importance_scores

    def pad_or_truncate(self, token, expected_dim):
        if token.shape[1] != expected_dim:
            token = token[:, :expected_dim] if token.shape[1] > expected_dim else np.pad(
                token, ((0, 0), (0, expected_dim - token.shape[1])), mode='constant'
            )
        if token.shape[0] > self.token_length:
            token = token[:self.token_length]
        else:
            pad_length = self.token_length - token.shape[0]
            pad = np.zeros((pad_length, expected_dim), dtype=np.float32)
            token = np.vstack((token, pad))
        return token

    def tokenize_tracks(self, tracks):
        tokenized_tracks = {}
        for track_id, track in tracks.items():
            state = track['state']
            valid = state['valid']
            if not valid.any():
                tokenized_tracks[track_id] = torch.empty((self.token_length, 9), dtype=torch.float32, device=self.device)
                continue
            positions = state['position'][valid]
            velocities = state['velocity'][valid]
            headings = state['heading'][valid]
            lengths = state['length'][valid]
            widths = state['width'][valid]
            heights = state['height'][valid]
            combined_token = np.column_stack((positions, velocities, headings, lengths, widths, heights))
            token = self.pad_or_truncate(combined_token, expected_dim=9)
            tokenized_tracks[track_id] = torch.tensor(token, dtype=torch.float32, device=self.device)
        return tokenized_tracks

    def tokenize_map_features(self, map_features):
        tokenized_features = {}
        for feature_id, feature in map_features.items():
            polyline = feature.get('polyline', [])
            if len(polyline) == 0:
                tokenized_features[feature_id] = torch.empty((self.token_length, 2), dtype=torch.float32, device=self.device)
                continue
            polyline_array = np.array(polyline)
            polyline_array = polyline_array[:, :2] if polyline_array.shape[1] > 2 else polyline_array
            token = self.pad_or_truncate(polyline_array, expected_dim=2)
            tokenized_features[feature_id] = torch.tensor(token, dtype=torch.float32, device=self.device)
            # Currently just added polyline need to check left right neighbours and need to still properluy sturcture also need to add boundary
            # Traffic info is not available
        return tokenized_features

    def process(self, scenario):
        tracks = scenario.get('tracks', {})
        map_features = scenario.get('map_features', {})
        metadata = scenario.get('metadata', {})
        track_scores = self.score_tracks(tracks)
        map_scores = self.score_map_features(map_features)
        tokenized_tracks = self.tokenize_tracks(tracks)
        tokenized_map_features = self.tokenize_map_features(map_features)
        for track_id, track in tracks.items():
            track['importance_score'] = track_scores[track_id]
            track['token'] = tokenized_tracks[track_id]
        for feature_id, feature in map_features.items():
            feature['importance_score'] = map_scores[feature_id]
            feature['token'] = tokenized_map_features[feature_id]
        scenario['tracks'] = tracks
        scenario['map_features'] = map_features
        return scenario

In [6]:
# scenarios = [
#     {
#         'id': 'scenario_1',
#         'version': 1,
#         'length': 10,
#         'tracks': {
#             'track_1': {'state': {'position': np.random.rand(10, 3), 'velocity': np.random.rand(10, 2), 
#                                   'heading': np.random.rand(10), 'length': np.random.rand(10), 
#                                   'width': np.random.rand(10), 'height': np.random.rand(10), 
#                                   'valid': np.array([True] * 10), 'type': 'VEHICLE'
#                                   }
#                                   }
#         },
#         'map_features': {
#             'lane_1': {'type': 'LANE', 'polyline': np.random.rand(10, 2)}
#         },
#         'metadata': {'scenario_id': 'sample_001', 'map': 'city_map_1'}
#     }
# ]
tokenizer = AdaptiveTokenScenarioNet(token_length=10, use_gpu=True)
tokens = tokenizer.process(scenario)



In [7]:
def custom_collate_fn(batch):
    batch_data = {'tracks': [], 'map_features': [], 'edge_index': []}
    for scenario in batch:
        if 'tracks' in scenario and 'map_features' in scenario:
            track_tokens = [t['token'] for t in scenario['tracks'].values()]
            map_tokens = [m['token'] for m in scenario['map_features'].values()]
            #Adding dummy edge_index for now latere need to write a logic to extract the data from the scenario, I am thinking idwise a->b->c->d and vice versa
            edge_index = scenario.get('edge_index', torch.empty((2, 0), dtype=torch.long))
            batch_data['tracks'].append(torch.stack(track_tokens))
            batch_data['map_features'].append(torch.stack(map_tokens))
            batch_data['edge_index'].append(edge_index)
    batch_data['tracks'] = torch.cat(batch_data['tracks'], dim=0) if batch_data['tracks'] else torch.empty(0)
    batch_data['map_features'] = torch.cat(batch_data['map_features'], dim=0) if batch_data['map_features'] else torch.empty(0)
    batch_data['edge_index'] = torch.cat(batch_data['edge_index'], dim=1) if batch_data['edge_index'] else torch.empty((2, 0), dtype=torch.long)
    return batch_data


class AVQVAE(nn.Module):
    def __init__(self, lr=1e-2, epochs=2000, commitment_loss_weight=0.1, device='cuda'):
        super(AVQVAE, self).__init__()
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.track_encoder = nn.Sequential(nn.Linear(9, 256), nn.LayerNorm(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU(), nn.Dropout(0.2),
            nn.TransformerEncoder( nn.TransformerEncoderLayer(d_model=256, nhead=2, dim_feedforward=256, dropout=0.2), num_layers=6),
            nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU())
        self.map_encoder_gnn = GCNConv(2, 256)
        self.map_encoder_post = nn.Sequential(nn.LayerNorm(256), nn.ReLU(), nn.Dropout(0.2))
        self.codebook = nn.Embedding(256, 256)
        self.track_decoder = nn.Sequential(nn.Linear(256, 256), nn.LayerNorm(256),nn.ReLU(),nn.Dropout(0.2), nn.Linear(256, 256),nn.LayerNorm(256),nn.ReLU(),nn.Linear(256, 9))
        self.map_decoder = nn.Sequential(nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 256), nn.LayerNorm(256), nn.ReLU(), nn.Linear(256, 2))
        # GAT is not possible because av2 scenarionet doesnt proivide dynamic map details its empty
        self.epochs = epochs
        self.optimizer = optim.AdamW(self.parameters(), lr=lr)
        self.commitment_loss_weight = commitment_loss_weight
        self.reconstruction_loss = nn.SmoothL1Loss()
        self.to(self.device)
    
    def forward(self, track_tokens, map_tokens, edge_index):
        track_tokens = track_tokens.to(self.device)
        map_tokens = map_tokens.to(self.device)
        edge_index = edge_index.to(self.device)
        track_encoded = self.track_encoder(track_tokens)
        #Honestly, I have no idea whether I should normalize the map and track tokens or not
        graph_data = Data(x=map_tokens, edge_index=edge_index)
        map_encoded = self.map_encoder_gnn(graph_data.x, graph_data.edge_index)
        map_encoded = self.map_encoder_post(map_encoded)
        track_quantized = self.codebook(torch.argmax(track_encoded, dim=-1))
        map_quantized = self.codebook(torch.argmax(map_encoded, dim=-1))
        # Thinking of adding another layer for map_quamtized not sure if its a good idea to increase the parameters
        track_decoded = self.track_decoder(track_quantized)
        map_decoded = self.map_decoder(map_quantized)
        return track_decoded, map_decoded, track_quantized, map_quantized, track_encoded, map_encoded
    
    def calculate_loss(self, track_tokens, map_tokens, track_decoded, map_decoded, track_quantized, map_quantized, track_encoded, map_encoded):
        # Need to play around this yet I just added the loss function from original paper
        track_decoded = track_decoded.view(-1, track_decoded.shape[-1])
        map_decoded = map_decoded.view(-1, map_decoded.shape[-1])
        track_loss = self.reconstruction_loss(track_decoded, track_tokens.view(-1, track_tokens.shape[-1]))
        map_loss = self.reconstruction_loss(map_decoded, map_tokens.view(-1, map_tokens.shape[-1]))
        commitment_loss = self.commitment_loss_weight * (torch.mean((track_encoded.detach()-track_quantized)**2)+torch.mean((map_encoded.detach()- map_quantized)** 2))
        codebook_loss = torch.mean((track_encoded - track_quantized.detach()) ** 2) + \
                         torch.mean((map_encoded - map_quantized.detach()) ** 2)
        total_loss = track_loss + map_loss + commitment_loss + codebook_loss
        return total_loss, track_loss, map_loss, commitment_loss, codebook_loss
    
    def train_model(self, scenarios):
        self.train()
        dataloader = DataLoader(scenarios, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
        for epoch in range(self.epochs):
            total_loss = 0
            for batch in dataloader:
                if batch['tracks'].nelement() > 0 and batch['map_features'].nelement() > 0:
                    track_tokens = batch['tracks'].to(self.device)
                    map_tokens = batch['map_features'].to(self.device)
                    edge_index = batch['edge_index'].to(self.device)
                    self.optimizer.zero_grad()
                    track_decoded, map_decoded, track_quantized, map_quantized, track_encoded, map_encoded = self.forward(track_tokens, map_tokens, edge_index)
                    loss, track_loss, map_loss, commitment_loss, codebook_loss = self.calculate_loss(track_tokens, map_tokens, track_decoded, map_decoded, track_quantized, map_quantized, track_encoded, map_encoded)
                    loss.backward()
                    nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                    self.optimizer.step()
                    total_loss += loss.item()
                
            print(f"Epoch [{epoch+1}/{self.epochs}] | Total Loss: {total_loss / len(dataloader):.4f} | "
                  f"Track Loss: {track_loss.item():.4f} | Map Loss: {map_loss.item():.4f} | "
                  f"Commitment Loss: {commitment_loss.item():.4f} | Codebook Loss: {codebook_loss.item():.4f}")
        torch.save(self.state_dict(), 'avqvae_final.pth')


In [None]:
vqvae = AVQVAE(device='cuda')
vqvae.train_model([tokens])
# Just loaded one scenario need to check the loss and other stuffs currently the loss is not decreasing



Epoch [1/2000] | Total Loss: 1871.2417 | Track Loss: 467.1244 | Map Loss: 1400.9512 | Commitment Loss: 0.2878 | Codebook Loss: 2.8783
Epoch [2/2000] | Total Loss: 1869.4338 | Track Loss: 466.3592 | Map Loss: 1399.5183 | Commitment Loss: 0.3233 | Codebook Loss: 3.2331
Epoch [3/2000] | Total Loss: 1867.4413 | Track Loss: 465.6418 | Map Loss: 1398.5233 | Commitment Loss: 0.2978 | Codebook Loss: 2.9783
Epoch [4/2000] | Total Loss: 1867.3229 | Track Loss: 465.3111 | Map Loss: 1398.8956 | Commitment Loss: 0.2833 | Codebook Loss: 2.8328
Epoch [5/2000] | Total Loss: 1865.1094 | Track Loss: 464.9706 | Map Loss: 1397.2604 | Commitment Loss: 0.2617 | Codebook Loss: 2.6167
Epoch [6/2000] | Total Loss: 1863.3250 | Track Loss: 464.5686 | Map Loss: 1396.0995 | Commitment Loss: 0.2415 | Codebook Loss: 2.4152
Epoch [7/2000] | Total Loss: 1861.7893 | Track Loss: 464.2170 | Map Loss: 1395.0173 | Commitment Loss: 0.2323 | Codebook Loss: 2.3226
Epoch [8/2000] | Total Loss: 1860.0726 | Track Loss: 463.8891 