In [1]:
import torch, torch.nn as nn

import sys, os
sys.path.append(os.path.abspath(".."))
from model.module.utils import DynamicMLP
from model.module.utils import DropPath
from diffusion_planner.model.module.mixer import MixerBlock

from diffusion_planner.model.module.encoder import(
  Encoder,SelfAttentionBlock,
  AgentFusionEncoder, StaticFusionEncoder,
  LaneFusionEncoder, FusionEncoder
)
import yaml

In [2]:
"""
  The forward of diffusion planner is:
    planner_encoder ---> planner_decoder

  planner_encoder: inputs
      fusion_encoder( [neighbor_encoder(neighbors) + static_encoder(static) +\
                       lane_encoder(lanes)] + encoding_pos)



"""
from typing import Any
class Config:
    """Attribute bag built from keyword arguments (e.g. a parsed YAML mapping).

    Every keyword passed to the constructor becomes an instance attribute with
    the same name. The annotations below only document some expected keys;
    they are not enforced and carry no default values.
    """

    hidden_dim: int
    agent_num: int
    static_objects_num: int
    lane_num: int

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])

def load_config(config_file: str) -> Config:
    """Load a YAML config file and wrap its top-level mapping in a Config.

    Args:
        config_file: Path to the YAML file.

    Returns:
        A Config whose attributes are the top-level keys of the YAML mapping.

    Raises:
        ValueError: If the document does not parse to a dict (e.g. an empty
            file, a list, or a scalar).
        TypeError: If the mapping has non-string keys (they cannot become
            keyword arguments / attribute names).
    """
    # Open in binary mode so PyYAML detects the encoding itself (UTF-8, or
    # UTF-16 with a BOM) instead of relying on the platform's locale encoding.
    with open(config_file, 'rb') as f:
        config_dict = yaml.safe_load(f)
    # The parsed document must be a mapping so its keys can become attributes.
    # Message reads: "the YAML config file should parse to a dict".
    if not isinstance(config_dict, dict):
        raise ValueError("YAML 配置文件解析后应为字典")

    return Config(**config_dict)

# Build the encoder from the training config file.
config = load_config('train.yaml')

# Echo one attribute as a quick sanity check that the YAML was parsed.
print(config.hidden_dim)

# Construct the encoder and move it to the device named in the config.
encoder = Encoder(config).to(device=config.device)


<class 'dict'>
192


In [None]:
# Step-by-step replication of encoder.forward(inputs) on one saved sample.
# Load inputs onto the same device the encoder was moved to (config.device) so
# weights and inputs agree; a hard-coded "cuda:0" breaks when config.device differs.
device = torch.device(config.device)

# Serialized sample to inspect; override with the DIFFUSION_PLANNER_SAMPLE env var.
SAMPLE_PATH = os.environ.get(
    "DIFFUSION_PLANNER_SAMPLE",
    "/cailiu2/Diffusion-Planner/diffusion_planner/test/2021.10.01.19.16.42_veh-28_03307_03808_0001593541ec55c3_stationary_in_traffic.pt",
)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted samples (weights_only=True is safer if the file holds only tensors).
inputs = torch.load(SAMPLE_PATH, map_location=device)

encoder_outputs = {}

# agents
neighbors = inputs['neighbor_agents_past']

# static objects
static = inputs['static_objects']

# vector maps
lanes = inputs['lanes']
lanes_speed_limit = inputs['lanes_speed_limit']
lanes_has_speed_limit = inputs['lanes_has_speed_limit']

# Batch size (dim 0 of the neighbor tensor).
B = neighbors.shape[0]


In [4]:
# Encode each modality separately. Each sub-encoder returns
# (token features, token mask, positional features); mask entries that are
# True are excluded when the positional embedding is applied.
encoding_neighbors, neighbors_mask, neighbor_pos = encoder.neighbor_encoder(neighbors)
encoding_static, static_mask, static_pos = encoder.static_encoder(static)
encoding_lanes, lanes_mask, lane_pos = encoder.lane_encoder(
    lanes,
    lanes_speed_limit,
    lanes_has_speed_limit,
)

In [5]:
# Per-modality token features, shaped (batch, tokens, hidden_dim).
for _feat in (encoding_neighbors, encoding_static, encoding_lanes):
    print(_feat.shape)
del _feat

# Stack tokens along dim=1 in the order neighbors -> static objects -> lanes.
encoding_input = torch.cat((encoding_neighbors, encoding_static, encoding_lanes), dim=1)
print(encoding_input.shape)

torch.Size([1, 30, 192])
torch.Size([1, 30, 192])
torch.Size([1, 40, 192])
torch.Size([1, 100, 192])


In [12]:
# Add positional embeddings to the concatenated tokens and run the fusion encoder.
print(neighbor_pos.shape)
print(static_pos.shape)
print(lane_pos.shape)

# Rebuild the token features from their parts instead of doing
# `encoding_input = encoding_input + ...`, so re-running this cell does not
# add the positional embedding a second time.
encoding_input = torch.cat([encoding_neighbors, encoding_static, encoding_lanes], dim=1)
# Guard the reshapes below: a mismatched token count could otherwise be
# silently regrouped by .view(B * token_num, -1).
assert encoding_input.shape[1] == encoder.token_num, "token count must equal encoder.token_num"

# Flatten positions and masks in the same token order as the features
# (neighbors, static objects, lanes).
encoding_pos_raw = torch.cat([neighbor_pos, static_pos, lane_pos], dim=1).view(B * encoder.token_num, -1)
encoding_mask = torch.cat([neighbors_mask, static_mask, lanes_mask], dim=1).view(-1)
print(encoding_mask.shape)

# Embed only the valid tokens (mask == False), then scatter into a zero-filled
# buffer so masked tokens receive a zero positional embedding.
encoding_pos = encoder.pos_emb(encoding_pos_raw[~encoding_mask])
encoding_pos_result = torch.zeros((B * encoder.token_num, encoder.hidden_dim), device=encoding_pos.device)
encoding_pos_result[~encoding_mask] = encoding_pos  # Fill in valid parts

encoding_input = encoding_input + encoding_pos_result.view(B, encoder.token_num, -1)
encoder_outputs['encoding'] = encoder.fusion(encoding_input, encoding_mask.view(B, encoder.token_num))

torch.Size([1, 30, 7])
torch.Size([1, 30, 7])
torch.Size([1, 40, 7])
torch.Size([100])
