In [24]:
import os

import numpy as np
import torch
from metadrive.engine.asset_loader import AssetLoader
from scenarionet import read_dataset_summary, read_scenario

In [82]:
# Housekeeping: ask PyTorch to release its cached, currently unused CUDA memory.
torch.cuda.empty_cache()

In [94]:
class AdaptiveTokenizer:
    """Convert scenario tracks and map polylines into fixed-size token tensors.

    Every track with at least one valid timestep, and every map feature with a
    non-empty ``'polyline'``, becomes one token of exactly ``token_length``
    rows (zero-padded or truncated). Tokens are concatenated row-wise, so the
    returned tensors are 2-D with shape ``(num_tokens * token_length, dim)``;
    ``tensor.view(-1, token_length, dim)`` recovers the per-token blocks.

    Notes:
        * ``importance_threshold`` and ``max_tokens`` are stored but are not
          applied anywhere in this class; importance scores are reported, never
          used to drop tokens.
        * Lookups in ``type_importance`` / ``map_feature_importance`` are
          exact-match: a type string that is not a key (for example a longer
          name that merely starts with ``'LANE'``) gets the fallback weight.
    """

    # Per-timestep arrays read from a track's ``state`` dict, in output column order.
    _TRACK_STATE_KEYS = ('position', 'velocity', 'heading', 'length', 'width', 'height')
    # Column count of a track row. This equals position + velocity + heading +
    # length + width + height when position is 3-D and velocity is 2-D; any
    # other width is zero-padded / truncated to 9 by ``pad_or_truncate``.
    _TRACK_DIM = 9
    # Column count of a map row: only the first two polyline columns are kept.
    _MAP_DIM = 2

    def __init__(self, importance_threshold=0.5, max_tokens=512, token_length=10, use_gpu=False):
        """Store the configuration and pick the torch device.

        Args:
            importance_threshold: Stored only; not used by any method here.
            max_tokens: Stored only; not used by any method here.
            token_length: Number of rows every token is padded/truncated to.
            use_gpu: Use CUDA when it is available, otherwise fall back to CPU.
        """
        self.importance_threshold = importance_threshold
        self.max_tokens = max_tokens
        self.token_length = token_length
        self.device = torch.device('cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu')
        self.type_importance = {
            'VEHICLE': 1.0,
            'PEDESTRIAN': 0.8,
            'BICYCLE': 0.7,
            'STATIC': 0.5,
            'UNKNOWN': 0.3
        }
        self.map_feature_importance = {
            'LANE': 0.9,
            'CROSSWALK': 0.7,
            'STOP_SIGN': 0.6,
            'TRAFFIC_LIGHT': 0.8,
            'UNKNOWN': 0.4
        }

    def _masked_track_state(self, state):
        """Return the track arrays restricted to valid timesteps, or ``None``.

        The validity flags are cast to ``bool`` first: indexing with an integer
        0/1 array would otherwise be interpreted as row indices (silently
        selecting the wrong rows) and a float array would raise ``IndexError``.

        Returns:
            A tuple ``(position, velocity, heading, length, width, height)`` of
            masked arrays, or ``None`` when no timestep is valid.
        """
        valid = np.asarray(state['valid']).astype(bool)
        if not valid.any():
            return None
        return tuple(np.asarray(state[key])[valid] for key in self._TRACK_STATE_KEYS)

    def score_importance(self, tracks, map_features):
        """Score every track and map feature, normalised by the maximum score.

        Args:
            tracks: Mapping ``track_id -> {'state': {...}, 'type': str}``.
            map_features: Mapping ``feature_id -> {'type': str, 'polyline': ...}``.

        Returns:
            1-D ``np.ndarray`` of length ``len(tracks) + len(map_features)``:
            track scores first (in iteration order), then map-feature scores.
            Entries without valid data are 0. An empty input yields an empty
            array instead of raising.
        """
        n_tracks = len(tracks)
        importance_map = np.zeros(n_tracks + len(map_features))

        for i, track in enumerate(tracks.values()):
            arrays = self._masked_track_state(track['state'])
            if arrays is None:
                continue  # no valid timestep: score stays 0
            positions, velocities, headings, lengths, widths, heights = arrays

            # NOTE(review): this sums the norm of absolute positions, so it grows
            # with distance from the coordinate origin rather than with motion.
            activity_score = np.linalg.norm(positions, axis=1).sum()
            velocity_score = np.linalg.norm(velocities, axis=1).sum()
            heading_score = np.abs(headings).sum()
            size_score = lengths.sum() + widths.sum() + heights.sum()
            type_score = self.type_importance.get(track.get('type', 'UNKNOWN'), 0.3)

            importance_map[i] = activity_score + velocity_score + heading_score + size_score + type_score

        for i, feature in enumerate(map_features.values(), start=n_tracks):
            polyline = feature.get('polyline', [])
            if len(polyline) == 0:
                continue  # empty polyline: score stays 0
            weight = self.map_feature_importance.get(feature.get('type', 'UNKNOWN'), 0.4)
            importance_map[i] = len(polyline) * weight

        if importance_map.size == 0:
            # ``max()`` of a zero-size array raises; nothing to normalise anyway.
            return importance_map
        # The small epsilon keeps the division defined when every score is 0.
        return importance_map / (importance_map.max() + 1e-5)

    def pad_or_truncate(self, token, expected_dim):
        """Force ``token`` to shape ``(token_length, expected_dim)``.

        Extra rows/columns are dropped; missing rows/columns are zero-filled.

        Args:
            token: 2-D array-like of shape ``(rows, cols)``.
            expected_dim: Required number of columns.

        Returns:
            ``np.ndarray`` of dtype float32 and shape
            ``(self.token_length, expected_dim)``.
        """
        token = np.asarray(token)
        rows = min(token.shape[0], self.token_length)
        cols = min(token.shape[1], expected_dim)
        fixed = np.zeros((self.token_length, expected_dim), dtype=np.float32)
        fixed[:rows, :cols] = token[:rows, :cols]
        return fixed

    def _to_tensor(self, token_arrays, dim):
        """Concatenate token arrays row-wise and move them to the device once."""
        if not token_arrays:
            # Same rank as the non-empty result so callers see a consistent layout.
            return torch.empty((0, dim), dtype=torch.float32, device=self.device)
        stacked = np.concatenate(token_arrays, axis=0)
        return torch.tensor(stacked, dtype=torch.float32, device=self.device)

    def tokenize(self, tracks, map_features, metadata):
        """Tokenize one scenario.

        Args:
            tracks: Mapping ``track_id -> {'state': {...}, 'type': str}``.
            map_features: Mapping ``feature_id -> {'type': str, 'polyline': ...}``.
            metadata: Passed through unchanged.

        Returns:
            Dict with:
              * ``'track_regions'``: float32 tensor ``(n_track_tokens * token_length, 9)``.
              * ``'map_regions'``: float32 tensor ``(n_map_tokens * token_length, 2)``.
              * ``'token_types'``: list of ``'high-detail-<type>'`` / ``'map-<type>'`` labels.
              * ``'token_ids'``: list of ids, track tokens first, then map tokens.
              * ``'importance_scores'``: list of floats aligned with ``'token_ids'``.
              * ``'metadata'``: the ``metadata`` argument.
        """
        importance = self.score_importance(tracks, map_features)
        track_tokens = []
        map_tokens = []
        token_types = []
        token_ids = []
        token_scores = []

        for i, (track_id, track) in enumerate(tracks.items()):
            arrays = self._masked_track_state(track['state'])
            if arrays is None:
                continue  # tracks with no valid timestep produce no token

            combined_token = np.column_stack(arrays)
            track_tokens.append(self.pad_or_truncate(combined_token, expected_dim=self._TRACK_DIM))
            token_types.append(f'high-detail-{track.get("type", "UNKNOWN")}')
            token_ids.append(track_id)
            token_scores.append(float(importance[i]))

        for i, (feature_id, feature) in enumerate(map_features.items(), start=len(tracks)):
            polyline = feature.get('polyline', [])
            if len(polyline) == 0:
                continue  # features without a polyline produce no token

            # Keep only the first two columns of each polyline point.
            polyline_array = np.asarray(polyline)[:, :self._MAP_DIM]
            map_tokens.append(self.pad_or_truncate(polyline_array, expected_dim=self._MAP_DIM))
            token_types.append(f'map-{feature.get("type", "UNKNOWN")}')
            token_ids.append(feature_id)
            token_scores.append(float(importance[i]))

        return {
            'track_regions': self._to_tensor(track_tokens, self._TRACK_DIM),
            'map_regions': self._to_tensor(map_tokens, self._MAP_DIM),
            'token_types': token_types,
            'token_ids': token_ids,
            'importance_scores': token_scores,
            'metadata': metadata
        }


In [95]:
# Synthetic one-track / one-lane scenario used as a quick self-check of the tokenizer.
rng = np.random.default_rng(0)  # fixed seed so the self-check is reproducible
sample_tracks = {
    'track_1': {
        'state': {
            'position': rng.random((10, 3)),
            'velocity': rng.random((10, 2)),
            'heading': rng.random(10),
            'length': rng.random(10),
            'width': rng.random(10),
            'height': rng.random(10),
            'valid': np.ones(10, dtype=bool),
        },
        'type': 'VEHICLE',
    }
}
sample_map_features = {
    'lane_1': {'type': 'LANE', 'polyline': rng.random((5, 2))}
}
sample_metadata = {'scenario_id': 'sample_001', 'map': 'city_map_1'}

tokenizer = AdaptiveTokenizer(token_length=10, use_gpu=True)

# Self-check on the synthetic sample: one track token and one map token,
# each padded/truncated to token_length=10 rows.
sample_tokens = tokenizer.tokenize(sample_tracks, sample_map_features, sample_metadata)
assert tuple(sample_tokens['track_regions'].shape) == (10, 9)
assert tuple(sample_tokens['map_regions'].shape) == (10, 2)
assert sample_tokens['token_ids'] == ['track_1', 'lane_1']

# NOTE: `scenario` must already exist in the kernel; it is created by the
# dataset-loading cell (In [73]), which has to be run before this one.
tokens = tokenizer.tokenize(scenario["tracks"], scenario["map_features"], scenario["metadata"])

print("Tokenization Complete!")
print(f"Track Tensor Shape: {tokens['track_regions'].shape}")
print(f"Map Tensor Shape: {tokens['map_regions'].shape}")
print(f"Token Types: {set(tokens['token_types'])}")
print(f"Token IDs: {tokens['token_ids']}")
print(f"Metadata: {tokens['metadata']}")

Tokenization Complete!
Track Tensor Shape: torch.Size([210, 9])
Map Tensor Shape: torch.Size([2050, 2])
Token Types: {'map-ROAD_LINE_BROKEN_SINGLE_WHITE', 'map-ROAD_LINE_SOLID_SINGLE_WHITE', 'map-ROAD_LINE_SOLID_DOUBLE_YELLOW', 'high-detail-OTHER', 'high-detail-VEHICLE', 'map-UNKNOWN_LINE', 'map-LANE_SURFACE_STREET'}
Token IDs: ['73980', '74036', '74140', '74157', '74167', '74225', '74231', '74250', '74252', '74253', '74254', '74256', '74262', '74263', '74264', '74271', '74275', '74276', '74286', '74290', 'AV', '942716285', '942716286', '471289997', '942716768', '942716769', '471290480', '942716831', '942716832', '471290543', '942716964', '942716965', '471290676', '942716984', '942716985', '471290696', '942717007', '942717008', '471290719', '942717015', '942717016', '471290727', '942717250', '942717251', '471290962', '942717338', '942717339', '471291050', '942717350', '942717351', '471291062', '942717352', '471291063', '942717361', '942717362', '471291073', '942717363', '942717364', '4

In [99]:
# Inspect the first row of the track tensor (a single 9-value feature row).
tokens['track_regions'][0]

tensor([ 2.7942e+03,  1.3763e+03,  0.0000e+00,  4.8375e+00, -5.1642e+00,
        -8.1722e-01,  4.0000e+00,  2.0000e+00,  1.0000e+00], device='cuda:0')

In [73]:
# Dataset location. Set the SCENARIO_DATASET_DIR environment variable to point at
# another copy of the dataset; the default is the original author's local path.
DATASET_DIR = os.environ.get("SCENARIO_DATASET_DIR", "/home/light/Documents/Thesis/preprocessed_dataset")

av2_data = AssetLoader.file_path(DATASET_DIR, unix_style=False)
dataset_summary, scenario_ids, mapping = read_dataset_summary(dataset_path=av2_data)

# Fail with a clear message instead of a bare IndexError on `scenario_ids[0]`.
if len(scenario_ids) == 0:
    raise ValueError(f"No scenarios found in dataset at {av2_data!r}")

# Load the first scenario listed in the dataset summary; later cells read `scenario`.
scenario_file_name = scenario_ids[0]
scenario = read_scenario(dataset_path=av2_data, mapping=mapping, scenario_file_name=scenario_file_name)

In [74]:
# List the top-level entries of the loaded scenario; the tokenizer cell reads
# 'tracks', 'map_features' and 'metadata' from it.
scenario.keys()

dict_keys(['id', 'version', 'length', 'tracks', 'dynamic_map_states', 'map_features', 'metadata'])