In [23]:
# ============================== tesnor lib ============================================ # 
import torch
import torch.nn as nn
# ====================================================================================== # 

from torch_cluster import radius_graph
from torch_cluster import radius
from torch_geometric.data import Batch
import matplotlib.pyplot as plt
from layers import FourierEmbedding
from utils.circle import minimum_enclosing_circle
from modules.encoder import QCNetEncoder
from torch_geometric.utils import dense_to_sparse
from layers.attention_layer import AttentionLayer
from utils.geometry import angle_between_2d_vectors

# ============================== building transformer lib ============================== #  
from transformers import BertModel, BertConfig
# ====================================================================================== # 


In [40]:
class BertTR(nn.Module):
    def __init__(self, hidden_dim, num_layers, num_heads):
        super(BertTR, self).__init__()
        
        config = BertConfig(
            hidden_size=hidden_dim,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=hidden_dim * 4,
            max_position_embeddings=512,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
        )
        self.encoder = BertModel(config)

    def forward(self, x, valid_mask):
        # x: [B, t, dim]
        # valid_mask: [B, t]
        
        # Generate time mask
        seq_length = x.size(1)
        time_mask = torch.triu(torch.ones((seq_length, seq_length)), diagonal=1).bool().to(x.device)
        
        # Combine time mask with valid_mask
        extended_valid_mask = valid_mask[:, None, None, :].to(x.device)
        combined_mask = time_mask[None, None, :, :] | (~extended_valid_mask)
        
        # Attention mask for BERT expects float values (0 or -inf)
        attention_mask = combined_mask.float() * -1e9
        # Pass through the transformer encoder
        outputs = self.encoder(inputs_embeds=x, attention_mask=attention_mask)
        
        return outputs.last_hidden_state


In [None]:
import torch
import torch.nn.functional as F 

class Agent_Self_Attn(nn.Module):
    def __init__(self,
                 num_head = 1) -> None:
        super(Agent_Self_Attn, self).__init__()
    
    def forward(self, x)

## agent 2 query cross attention

In [41]:
class Agent_2_query_attn(nn.Module):
    def __init__(self) -> None:
        super(Agent_2_query_attn, self).__init__()

    def forward(self, q, k, v):
        pass 

In [74]:
class Area_Prior(nn.Module):
    def __init__(self,
                 agent_self_attn_hidden_dim = 128,
                 agent_self_attn_layer_num = 2,
                 agent_self_attn_num_head = 8,
                 agent_self_attn_embed_num_freq_bands = 64,
                 agent_self_attn_embed_hidden_dim = 128,
                 history_step = 50,
                 prediction_step = 60,
                 input_dim = 2,
                 agent_embed_input_dim = 4) -> None:
        super(Area_Prior, self).__init__()
        self.history_step = history_step
        self.input_dim = input_dim
        self.agent_self_attn_hidden_dim = agent_self_attn_hidden_dim
        self.agent_self_attn = Agent_self_attn(agent_self_attn_hidden_dim,
                                               agent_self_attn_layer_num,
                                               agent_self_attn_num_head)
        # self.agent2query_attn = Agent_2_query_attn()
        self.linear = nn.Linear(1,1).to(device=torch.device('cuda'))
        self.type_a_emb = nn.Embedding(10, agent_self_attn_hidden_dim)
        self.x_a_emb = FourierEmbedding(input_dim=agent_embed_input_dim, 
                                        hidden_dim=agent_self_attn_embed_hidden_dim, 
                                        num_freq_bands=agent_self_attn_embed_num_freq_bands)
        
        self.proj_geo = nn.Linear(agent_self_attn_hidden_dim, input_dim)
        self.proj_t = nn.Linear(history_step, prediction_step)
        

    def forward(self, data):
        # ============================== setting data ============================== # 
        pred_mask = data['agent']['predict_mask'].any(dim=-1, keepdim=True).squeeze()
        valid_mask_history = data['agent']['valid_mask'][pred_mask,:self.history_step]
        #agent_data_pos = data['agent']['position'][pred_mask,:self.history_step,:self.input_dim]
        #agent_data_heading = data['agent']['heading'][pred_mask,:self.history_step,:self.input_dim]

        pos_a = data['agent']['position'][:, :self.history_step, :self.input_dim].contiguous()
        motion_vector_a = torch.cat([pos_a.new_zeros(data['agent']['num_nodes'], 1, self.input_dim),
                                     pos_a[:, 1:] - pos_a[:, :-1]], dim=1)
        head_a = data['agent']['heading'][:, :self.history_step].contiguous()
        head_vector_a = torch.stack([head_a.cos(), head_a.sin()], dim=-1)
        vel = data['agent']['velocity'][:, :self.history_step, :self.input_dim].contiguous()
        categorical_embs = [self.type_a_emb(data['agent']['type'].long()).repeat_interleave(repeats=self.history_step,dim=0)]
        agent_geo_feat = torch.stack(
                [torch.norm(motion_vector_a[:, :, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=motion_vector_a[:, :, :2]),
                 torch.norm(vel[:, :, :2], p=2, dim=-1),
                 angle_between_2d_vectors(ctr_vector=head_vector_a, nbr_vector=vel[:, :, :2])], dim=-1)
        agent_feat = self.x_a_emb(continuous_inputs=agent_geo_feat.view(-1, agent_geo_feat.size(-1)), categorical_embs=categorical_embs)
        agent_feat = agent_feat.view(-1, self.history_step, self.agent_self_attn_hidden_dim)
        # ============================================================================= # 

        # ============================== extract self_attn feat ============================== #
        agent_feat = self.agent_self_attn(agent_feat, valid_mask_history)
        # ==================================================================================== # 

        agent_feat_t = self.proj_t(agent_feat.permute(0,2,1)).permute(0,2,1)
        agent_feat_pos = self.proj_geo(agent_feat_t)
        return agent_feat_pos


## loss function

In [75]:
class CustomLoss(nn.Module):
    def __init__(self, lambda_area=1.0):
        super(CustomLoss, self).__init__()
        self.lambda_area = lambda_area

    def forward(self, points_A, points_B):
        distribution_loss = self.compute_distribution_loss(points_A, points_B)
        area_loss = self.compute_area_loss(points_A)
        total_loss = distribution_loss + self.lambda_area * area_loss * 0.1
        return total_loss

    def compute_distribution_loss(self, points_A, points_B):
        dist_matrix = torch.cdist(points_A, points_B)
        min_dist, _ = dist_matrix.min(dim=1)
        return min_dist.mean()

    def compute_area_loss(self, points_A):
        points_A = points_A.detach().cpu().numpy()
        if len(points_A) < 3:  # ConvexHull requires at least 3 points
            return torch.tensor(0.0)
        hull = ConvexHull(points_A)
        return torch.tensor(hull.volume, requires_grad=True)

In [77]:
from dataset_prepare.argoverse_v2_dataset import ArgoverseV2Dataset
from torch_geometric.data import DataLoader
import numpy as np

dataset = ArgoverseV2Dataset('D:\\argoverse2', 'train', None, None, None)
loader = DataLoader(dataset,batch_size=1,shuffle=False)
device = torch.device('cuda')

area_prior = Area_Prior(128,1,8).to(device)
optimizer = torch.optim.SGD(area_prior.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = CustomLoss(lambda_area=0.1)
epoch = 1

min_loss = np.inf
best_model_path = 'C:\\Users\\Lenovo\\OneDrive - City University of Hong Kong - Student\\Desktop\\mikumiku\\mikumiku\\best_model.pth'


for epoch_i in range(epoch):
    print(epoch_i)
    for data, loader_num in zip(loader,range(len(loader))):
        loss = 0
        pred = area_prior(data = data.to(device))
        pred_mask = data['agent']['predict_mask'].any(dim=-1, keepdim=True).squeeze()
        gt = data['agent']['position'][pred_mask,:,:]
        for each_pred, each_gt, data_in_loader_i in zip(pred,gt,range(gt.shape[0])):
            each_gt = each_gt[50:,:][data['agent']['predict_mask'][pred_mask][data_in_loader_i,50:],:2]
            loss += criterion(each_pred,each_gt)
        loss.backward()
        optimizer.step()
        # ====================================save best model param==================================== #
        if loss.item() < min_loss:
            min_loss = loss.item()
            torch.save(area_prior.state_dict(), best_model_path)
    print(loss)



0
torch.Size([1, 50, 50])


ValueError: Wrong shape for input_ids (shape torch.Size([30, 50])) or attention_mask (shape torch.Size([22, 1, 50, 50]))

In [None]:
class Area_Anchor(nn.Module):
    def __init__(self) -> None:
        super(Area_Anchor, self).__init__()

        # 1. agent self_attn 提取feature
        # 2. 设计anchor-free的 detr模型推理可行域点集，是预训练的
    def forward(self, data):
        pass

In [None]:
class Pts_Anchor(nn.Module):
    def __init__(self) -> None:
        super(Pts_Anchor, self).__init__()
        # 3. 初始化 Pts-query-content + Pts-query-pos, 并自注意力结合特征， query-pos是几何上的
        # 4. Pts-Query 对 最后一个历史步做解码得到新的 pos_ 和 content_， 并以 pos_new = pos+pos_ ;content = content_ + content 迭代
        # 5. 组合迭代结果对area_archor做ca 实现 时序关联
        # 6. refine
    
    def forward(self, data):
        pass 