In [1]:
from map_gnn_encoder import SMARTMapDecoder
import pickle
import torch
import sys
import os
import contextlib
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.data import HeteroData
from torch.optim.lr_scheduler import LambdaLR
import math
import numpy as np
from collections import defaultdict

# Add the parent directory to the Python path to import SMARTMapDecoder
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../../'))


# Define the SMART class copied from smart_z_infer.py
class SMART(nn.Module):
    """Notebook-local copy of the SMART map-token matching utilities.

    The class owns no learnable layers in this notebook; it only
      * loads the map-token vocabulary from a pickle file (``init_map_token``),
      * assigns one vocabulary token to every map polyline segment
        (``match_token_map``), and
      * samples which map tokens are masked / predicted (``sample_pt_pred``).

    Args:
        map_token_traj_path: Optional path to the pickled token vocabulary.
            When ``None`` (default) the original hard-coded location
            ``../../small_model/SMART/smart/tokens/map_traj_token5.pkl``
            (resolved against the current working directory) is used.
    """

    def __init__(self, map_token_traj_path=None) -> None:
        super().__init__()
        self.warmup_steps = 1000
        self.lr = 1e-4
        self.total_steps = 100000
        self.dataset = 'waymo'
        self.input_dim = 128
        self.hidden_dim = 128
        self.output_dim = 128
        self.output_head = True
        self.num_historical_steps = 11
        self.num_future_steps = 80
        self.num_freq_bands = 64
        self.vis_map = False
        # When True, match_token_map samples uniformly among the `noise_top_k`
        # nearest tokens instead of always taking the single nearest one.
        self.noise = True
        self.noise_top_k = 8
        if map_token_traj_path is None:
            module_dir = os.path.abspath('../../small_model/SMART/smart')
            map_token_traj_path = os.path.join(module_dir, 'tokens/map_traj_token5.pkl')
        self.map_token_traj_path = map_token_traj_path
        self.init_map_token()
        self.inference_token = False
        self.rollout_num = 1

    def init_map_token(self):
        """Load the token vocabulary and derive the tensors used for matching.

        Populates ``self.map_token`` with float tensors:
          * ``traj_src``       - the token trajectories as stored in the file,
          * ``sample_pt``      - ``argmin_sample_len`` points per token, taken at
                                 evenly spaced indices along the trajectory,
          * ``traj_end_theta`` - heading of the last segment of each token.
        """
        self.argmin_sample_len = 3
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # map_token_traj_path at trusted token files.
        with open(self.map_token_traj_path, 'rb') as f:
            map_token_traj = pickle.load(f)
        traj_src = np.asarray(map_token_traj['traj_src'])

        # Heading of the final segment (last point minus second-to-last point).
        traj_end_theta = np.arctan2(traj_src[:, -1, 1] - traj_src[:, -2, 1],
                                    traj_src[:, -1, 0] - traj_src[:, -2, 0])
        # Evenly spaced point indices; converted to NumPy so the NumPy array is
        # indexed with a native index array.
        sample_idx = torch.linspace(0, traj_src.shape[1] - 1,
                                    steps=self.argmin_sample_len).long().numpy()
        self.map_token = {
            'traj_src': torch.from_numpy(traj_src).to(torch.float),
            'sample_pt': torch.from_numpy(traj_src[:, sample_idx]).to(torch.float),
            'traj_end_theta': torch.from_numpy(traj_end_theta).to(torch.float),
        }

    def match_token_map(self, data):
        """Assign a vocabulary token to every map segment and build map masks.

        Reads ``data['map_save']`` (``traj_pos``, ``traj_theta``,
        ``pl_idx_list``) and ``data['pt_token']`` (``side``, ``num_nodes``).
        Assumes ``traj_pos`` is ``(num_segments, argmin_sample_len, 2)`` and
        ``traj_theta`` is ``(num_segments,)`` - TODO confirm against the
        preprocessing code, which is not part of this notebook.

        Writes into ``data`` (mutated in place and returned):
          * ``data['pt_token']['token_idx']``   - chosen token per segment,
          * ``data['pt_token']['position']``    - segment start point, padded
            with a zero third coordinate,
          * ``data['pt_token']['orientation']`` / ``['height']``,
          * ``data['pt_token']['traj_mask']``   - bool mask of shape
            ``(num_polylines, 3, max_count)`` marking how many segments each
            polyline has with side value 0, 1 and 2,
          * ``data[('pt_token', 'to', 'map_polygon')]['edge_index']``.

        The world-frame reconstruction of every token that used to be computed
        here was never stored or returned (and hard-coded a 1024 x 11
        vocabulary), so it has been dropped.
        """
        traj_pos = data['map_save']['traj_pos'].to(torch.float)
        traj_theta = data['map_save']['traj_theta'].to(torch.float)
        pl_idx_list = data['map_save']['pl_idx_list']
        device = traj_pos.device
        token_sample_pt = self.map_token['sample_pt'].to(device)
        pl_num = traj_pos.shape[0]

        pt_token_pos = traj_pos[:, 0, :].clone()
        pt_token_orientation = traj_theta.clone()

        # Express each segment relative to its own first point and heading.
        cos, sin = traj_theta.cos(), traj_theta.sin()
        rot_mat = traj_theta.new_zeros(pl_num, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = -sin
        rot_mat[:, 1, 0] = sin
        rot_mat[:, 1, 1] = cos
        traj_pos_local = torch.bmm(traj_pos - traj_pos[:, 0:1], rot_mat)

        # Squared distance between every segment and every token's sample points.
        distance = torch.sum((token_sample_pt[None] - traj_pos_local.unsqueeze(1)) ** 2, dim=(-2, -1))

        if self.noise:
            # Pick uniformly among the closest tokens instead of the single best.
            topk_indices = torch.argsort(distance, dim=1)[:, :self.noise_top_k]
            sample_topk = torch.randint(0, topk_indices.shape[-1], size=(pl_num, 1), device=device)
            pt_token_id = torch.gather(topk_indices, 1, sample_topk).squeeze(-1)
        else:
            pt_token_id = torch.argmin(distance, dim=1)

        # Row 0: segment index, row 1: index of the polyline it belongs to.
        pl_idx_full = pl_idx_list.clone()
        token2pl = torch.stack([torch.arange(len(pl_idx_list), device=device), pl_idx_full.long()])

        # For every polyline count its segments with side value 0, 1 and 2.
        side = data['pt_token']['side']
        count_rows = []
        for pl in pl_idx_full.unique():
            pl_side = side[token2pl[0, token2pl[1, :] == pl]]
            count_rows.append(torch.stack([(pl_side == s).sum() for s in (0, 1, 2)]))
        if count_rows:
            count_nums = torch.stack(count_rows, dim=0)
            num_polyline = int(count_nums.max().item())
        else:
            # No polylines at all: produce an empty, well-formed mask.
            count_nums = torch.zeros((0, 3), dtype=torch.long, device=device)
            num_polyline = 0

        # traj_mask[p, s, k] is True for the first count_nums[p, s] slots.
        slot_idx = torch.arange(num_polyline, device=count_nums.device)
        traj_mask = slot_idx.view(1, 1, -1) < count_nums.unsqueeze(-1)

        data['pt_token']['traj_mask'] = traj_mask
        height_pad = torch.zeros((data['pt_token']['num_nodes'], 1), device=device, dtype=torch.float)
        data['pt_token']['position'] = torch.cat([pt_token_pos, height_pad], dim=-1)
        data['pt_token']['orientation'] = pt_token_orientation
        data['pt_token']['height'] = data['pt_token']['position'][:, -1]
        data[('pt_token', 'to', 'map_polygon')] = {}
        data[('pt_token', 'to', 'map_polygon')]['edge_index'] = token2pl
        data['pt_token']['token_idx'] = pt_token_id
        return data

    def sample_pt_pred(self, data):
        """Randomly mask map-token slots and derive valid/pred/target masks.

        Reads ``data['pt_token']['traj_mask']`` (bool, three-dimensional) and
        writes three flat bool tensors, each indexed by ``traj_mask``:
          * ``pt_valid_mask``  - slots that were not masked,
          * ``pt_pred_mask``   - unmasked slots whose next slot is masked and
            present in ``traj_mask``,
          * ``pt_target_mask`` - ``pt_pred_mask`` shifted one slot to the right.

        Slot 0 is never masked. The random draw is taken over the flattened
        index table, so a row can receive the same slot index more than once.
        With at most one slot nothing can be masked, so the prediction and
        target masks are all ``False``.
        """
        traj_mask = data['pt_token']['traj_mask']
        num_pl, num_side, num_slot = traj_mask.shape

        if num_slot <= 1:
            pt_valid_mask = traj_mask.clone()
            pt_pred_mask = torch.zeros_like(traj_mask)
            pt_target_mask = torch.zeros_like(traj_mask)
        else:
            num_masked = (num_slot - 1) // 3
            # Candidate slot indices 1..num_slot-1, repeated for every row.
            raw_pt_index = torch.arange(1, num_slot, device=traj_mask.device).repeat(num_pl, num_side, 1)
            perm = torch.randperm(raw_pt_index.numel())[:num_pl * num_side * num_masked]
            perm = perm.reshape(num_pl, num_side, num_masked).to(raw_pt_index.device)
            masked_pt_index = torch.sort(raw_pt_index.view(-1)[perm], -1)[0]

            pt_valid_mask = traj_mask.clone()
            pt_valid_mask.scatter_(2, masked_pt_index, False)

            # True at slot j when slot j + 1 was selected for masking.
            next_is_masked = torch.zeros_like(traj_mask)
            next_is_masked.scatter_(2, masked_pt_index - 1, True)

            pt_pred_mask = pt_valid_mask & next_is_masked & torch.roll(traj_mask, shifts=-1, dims=2)
            pt_target_mask = torch.roll(pt_pred_mask, shifts=1, dims=2)

        data['pt_token']['pt_valid_mask'] = pt_valid_mask[traj_mask]
        data['pt_token']['pt_pred_mask'] = pt_pred_mask[traj_mask]
        data['pt_token']['pt_target_mask'] = pt_target_mask[traj_mask]

        return data




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make the SMART package importable; the import below depends on this path.
sys.path.append(os.path.abspath('../../small_model/SMART/'))

from smart.datasets.preprocess import TokenProcessor

# match_token_map (noise sampling) and sample_pt_pred (randperm) are stochastic;
# seed torch so that Restart & Run All reproduces the same tokens and masks.
torch.manual_seed(0)

# Load the pre-processed data
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
file_path = '../../small_model/SMART/data/waymo_processed/training/1001f7ed426203e1.pkl'
with open(file_path, 'rb') as f:
    data = pickle.load(f)

# Instantiate the TokenProcessor and process the data
token_processor = TokenProcessor(token_size=2048)
processed_data = token_processor.preprocess(data)

# Feed the *processed* scene into the token matcher. The previous version
# passed the raw `data` object and silently discarded the return value of
# preprocess(), which only works if preprocess() mutates its argument in place.
smart_model = SMART()
processed_data = smart_model.match_token_map(processed_data)
processed_data = smart_model.sample_pt_pred(processed_data)


In [None]:
# Instantiate the SMARTMapDecoder
# Assuming the encoder expects these dimensions. You may need to adjust them.
smart_map_decoder = SMARTMapDecoder(
    input_dim=2,
    hidden_dim=128,
    num_historical_steps=11,
    pl2pl_radius=50.0,
    num_freq_bands=64,
    num_layers=2,
    num_heads=8,
    head_dim=16,
    dropout=0.1,
    map_token_traj_path='../../small_model/SMART/smart/tokens/map_traj_token5.pkl',
    data_dtype='float32'
)

# Get the map features from the encoder
# The TokenProcessor might modify the data structure, so we pass the whole dictionary
smart_map_decoder = smart_map_decoder.to(dtype=torch.float32)
# Feature extraction only: switch off dropout and autograd bookkeeping.
smart_map_decoder.eval()
# float32 is not a valid CUDA autocast target dtype (only float16/bfloat16 are);
# to force full precision, autocast has to be disabled explicitly instead.
# NOTE(review): the earlier "mat1 and mat2 must have the same dtype, but got
# BFloat16 and Float" error means some activation is still bfloat16. If it
# persists with autocast disabled, the cast presumably happens inside
# SMARTMapDecoder or in a tensor of processed_data - verify there.
with torch.no_grad(), torch.autocast(device_type="cuda", enabled=False):
    map_features = smart_map_decoder(processed_data)

print("Map features obtained successfully!")
print("Shape of map features:", map_features.shape)

RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

In [12]:
# Inspect the per-segment token index that SMART.match_token_map stored under
# data['pt_token']['token_idx'].
processed_data['pt_token']['token_idx']

tensor([473,  52,  52,  ...,  39, 634,  39])