In [1]:
import torch, torch.nn as nn

import sys, os
sys.path.append(os.path.abspath(".."))
from model.module.utils import DynamicMLP
from model.module.utils import DropPath
from diffusion_planner.model.module.mixer import MixerBlock

from diffusion_planner.model.module.decoder import(
  Decoder, RouteEncoder,DiT
)
from diffusion_planner.model.module.encoder import(
  Encoder
)
from diffusion_planner.utils.normalizer import StateNormalizer
from diffusion_planner.model.diffusion_utils.sde import SDE, VPSDE_linear
import yaml

In [2]:
from typing import Any
class Config:
    """Plain attribute container built from arbitrary keyword arguments.

    Every keyword passed to the constructor becomes an instance attribute of
    the same name. The annotations below only declare a few attributes; they
    carry no defaults and are not checked at construction time.
    """

    hidden_dim: int
    agent_num: int
    static_objects_num: int
    lane_num: int

    def __init__(self, **kwargs):
        # Mirror each keyword argument onto the instance.
        for name in kwargs:
            setattr(self, name, kwargs[name])

def load_config(config_file: str) -> Config:
    """Parse a YAML file and wrap its top-level mapping in a Config.

    Args:
        config_file: Path to the YAML configuration file.

    Returns:
        A Config whose attributes are the top-level keys of the YAML document.

    Raises:
        ValueError: If the parsed document is not a dict at the top level.
    """
    # Decode explicitly as UTF-8: without an encoding argument open() uses the
    # platform locale, which can mis-decode non-ASCII text in the file.
    with open(config_file, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)  # safe_load: no arbitrary object construction
    # Debug aid: shows the type of the parsed top-level object.
    print(type(config_dict))
    # Config(**...) needs a mapping, so reject anything else up front.
    if not isinstance(config_dict, dict):
        raise ValueError("YAML 配置文件解析后应为字典")

    return Config(**config_dict)

In [3]:
# Load the training configuration from the notebook's working directory.
config = load_config('train.yaml')

# Attributes are read directly off the Config instance.
print(config.hidden_dim)

# Build the encoder on the configured device.
encoder = Encoder(config).to(config.device)

# Full forward pass on one saved sample. The sample is mapped onto the same
# device as the encoder (previously hard-coded to "cuda:0") so that inputs and
# weights cannot end up on different devices.
# NOTE(review): absolute, machine-specific path; torch.load unpickles the file,
# so only point this at trusted data.
device = torch.device(config.device)
inputs = torch.load("/cailiu2/Diffusion-Planner/diffusion_planner/test/2021.10.01.19.16.42_veh-28_03307_03808_0001593541ec55c3_stationary_in_traffic.pt", map_location=device)

# Call the module rather than .forward() so any registered hooks also run.
encoder_output = encoder(inputs)

<class 'dict'>
192


In [4]:
# The loaded config holds the normalizer as a mapping with 'mean' and 'std'
# entries; replace it with a StateNormalizer before the config is handed to
# the Decoder. The isinstance guard keeps the cell re-runnable: on a second
# run the attribute is already a StateNormalizer and must not be indexed again.
if not isinstance(config.state_normalizer, StateNormalizer):
    config.state_normalizer = StateNormalizer(
        mean=config.state_normalizer['mean'],
        std=config.state_normalizer['std'],
    )
decoder = Decoder(config)

# Full decoder forward pass, currently disabled:
# decoder.forward(encoder_output, inputs)

In [5]:
import torch

# Sanity check: prints True when PyTorch can see a CUDA device, else False.
print(torch.cuda.is_available())

True


In [6]:
# Ego: first four state features, with a singleton agent axis inserted.
ego_current = inputs['ego_current_state'][:, None, :4]

# Neighbors: last past time step, first four features, limited to the number
# of neighbors the decoder predicts.
neighbors_current = inputs["neighbor_agents_past"][
    :, :decoder._predicted_neighbor_num, -1, :4
]
# A neighbor slot is flagged when all four of its features are exactly zero.
neighbor_current_mask = (neighbors_current == 0).all(dim=-1)

# Ego first, then neighbors, along the agent axis -> [B, P, 4].
current_states = torch.cat((ego_current, neighbors_current), dim=1)

B, P, _ = current_states.shape
# One ego plus the predicted neighbors.
assert P == decoder._predicted_neighbor_num + 1

In [7]:
ego_neighbor_encoding = encoder_output['encoding']
route_lanes = inputs['route_lanes']
print("B: ", B)
print("P: ", P)

print(inputs['ego_current_state'].shape)
print(inputs['neighbor_agents_past'].shape)  # x, y, cos h, sin h, vx, vy, length, width, type[3]

# Build a placeholder 'sampled_trajectories' entry from the current states:
# the ego state as a single agent row, followed by the neighbors' states.
x = inputs['ego_current_state'].view(B, 1, -1)
# Index the last past time step explicitly (instead of squeezing the time
# axis, which only works when it has length 1) and keep at most P - 1
# neighbors so the concatenation always has P agents.
y = inputs['neighbor_agents_past'][:, :P - 1, -1, :4]
print(x.shape)
print(y.shape)
inputs['sampled_trajectories'] = torch.cat((x, y), dim=1)
# NOTE(review): stored as a plain Python int; confirm the type DiT expects.
inputs['diffusion_time'] = 10
print(inputs['sampled_trajectories'].shape)
# reshape returns a new tensor rather than modifying in place, so the result
# must be assigned back for it to take effect.
inputs['sampled_trajectories'] = inputs['sampled_trajectories'].reshape(B, P, -1)
print(inputs['sampled_trajectories'].shape)

B:  1
P:  31
torch.Size([1, 4])
torch.Size([1, 30, 1, 11])
torch.Size([1, 1, 4])
torch.Size([1, 30, 4])
torch.Size([1, 31, 4])
torch.Size([1, 31, 4])


In [12]:
# Decoder drop-path rate from the config; passed to DiT as its dropout argument.
dpr = config.decoder_drop_path_rate

# Stand-alone DiT wired with a linear VP-SDE and its own route encoder, then
# moved to the configured device.
dit = DiT(
    sde=VPSDE_linear(),
    route_encoder=RouteEncoder(
        config.route_num,
        config.lane_len,
        drop_path_rate=config.encoder_drop_path_rate,
        hidden_dim=config.hidden_dim,
    ),
    depth=config.decoder_depth,
    output_dim=(config.future_len + 1) * 4,  # x, y, cos, sin
    hidden_dim=config.hidden_dim,
    heads=config.num_heads,
    dropout=dpr,
    model_type=config.diffusion_model_type,
).to(config.device)

# Arguments for the (currently disabled) DiT call, all on the configured device.
sample_trajecotries = inputs['sampled_trajectories'].to(config.device)
diffusion_time = 10

ego_neighbor_encoding = encoder_output['encoding'].to(config.device)
route_lanes = inputs['route_lanes'].to(config.device)

# Last past time step, first four features of the predicted neighbors.
neighbors_current = inputs["neighbor_agents_past"][
    :, :config.predicted_neighbor_num, -1, :4
].to(config.device)
# A neighbor slot is flagged when all four of its features are exactly zero.
neighbor_current_mask = (neighbors_current == 0).all(dim=-1).to(config.device)
print(inputs['sampled_trajectories'].shape)
# score = dit(sample_trajecotries, diffusion_time, ego_neighbor_encoding, route_lanes, neighbor_current_mask)

torch.Size([1, 31, 4])


In [None]:
# Same expression as the output_dim passed to DiT: (future_len + 1) * 4.
output_dim = (config.future_len + 1) * 4
# Read the hidden size from the config: no bare `hidden_dim` name is defined
# in the notebook, so the previous reference raised NameError.
preproj = DynamicMLP(in_features=output_dim, hidden_features=512, out_features=config.hidden_dim, drop=0.)