In [1]:
from config import config
import torch
import logging
import os
from PIL import Image
import pandas as pd
import numpy as np
from torchvision import transforms

from src.net import build_model
from src.utils import set_all_seeds

# Module-level logger; inside the notebook kernel __name__ is "__main__", hence the "INFO:__main__:" prefix in cell outputs.
logger = logging.getLogger(name=__name__)

In [2]:
logging.basicConfig(level=logging.INFO)

# Use the first GPU only when the config asks for CUDA *and* one is actually present; otherwise run on CPU.
device = torch.device("cuda:0" if config.use_cuda and torch.cuda.is_available() else "cpu")
logger.info(f"Test on device {device}")

# Seed through the project helper (presumably covers torch/numpy/random — see src.utils) for reproducible runs.
set_all_seeds(config)
logger.info(f"Random seed value: {config.seed}")

INFO:__main__:Test on device cuda:0
INFO:__main__:Random seed value: 42


In [3]:
# Build the network and load the pretrained checkpoint named by config.weight_dir / config.weight_name.
model = build_model(config)
weight_path = os.path.join(config.weight_dir, config.weight_name + ".pth")
if not os.path.exists(weight_path):
    logger.warning(f"Pretrained weights not found at {weight_path}")
    # Fail loudly: exit(0) would report success (and in a Jupyter kernel may shut the kernel down)
    # while every later cell silently depends on a model that was never loaded.
    raise FileNotFoundError(f"Pretrained weights not found at {weight_path}")

logger.info(f"Loading pretrained weights from {weight_path}")
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
state_dict = torch.load(weight_path, map_location='cpu')
# strict=False tolerates key mismatches; surface them so a partially-initialised model is not used unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    logger.warning(f"Keys missing from checkpoint (left at init values): {load_result.missing_keys}")
if load_result.unexpected_keys:
    logger.warning(f"Unexpected keys in checkpoint (ignored): {load_result.unexpected_keys}")

model.eval()
model.to(device)

INFO:__main__:Loading pretrained weights from ./weight/event_bert_v2.pth


EventBERTV2(
  (eventImg2Token): EventImg2TokenV2(
    (conv): Sequential(
      (0): ConvDw(
        (conv): Sequential(
          (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=2, bias=False)
          (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
          (3): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): ConvDw(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
          (3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bi

In [4]:
from tqdm import tqdm


def _load_event_frame(img_dir, frame_name, transform):
    """Load one frame's positive/negative event images as a single 2-channel tensor.

    Each PNG is converted to 8-bit grayscale ('L') and passed through ``transform``;
    with the ToTensor + Normalize pipeline used by ``measure`` each polarity becomes
    a (1, H, W) tensor, so the result is (2, H, W): channel 0 = pos, channel 1 = neg.
    """
    channels = []
    for polarity in ('pos', 'neg'):
        image_path = os.path.join(img_dir, f'{frame_name}_{polarity}.png')
        # Context manager closes the file handle even if conversion/transform raises.
        with Image.open(image_path) as image:
            channels.append(transform(image.convert('L')))
    return torch.cat(channels, dim=0)


def measure(model, root_dir, log_dir, config, device):
    """Run sliding-window velocity inference over one sequence and save the raw predictions.

    Expected layout of ``root_dir``:
        img/NNNN_pos.png, img/NNNN_neg.png  (NNNN = zero-padded frame index from 0)
        data/trajectory.csv                 (one row per frame; the last 7 columns are fed to the model)

    For every window of ``config.max_seq_len`` consecutive frames the model is called as
    ``model(x_seq, traj_seq)`` with x_seq of shape (1, max_seq_len, 2, H, W) and traj_seq of
    shape (1, max_seq_len, 7). ``output[0]`` is expected to be (max_seq_len, >=3) with
    columns vx, vy, vz.

    Writes ``predicted_vx.csv``, ``predicted_vy.csv``, ``predicted_vz.csv`` to ``log_dir``:
    square (num_frames x num_frames) matrices where row i holds the predictions of the
    window starting at frame i, column j is frame j, and cells not covered by window i are NaN.

    Args:
        model: network in eval mode, already on ``device``.
        root_dir: sequence directory (layout above).
        log_dir: output directory; created if missing.
        config: must provide ``max_seq_len``.
        device: torch device used for inference.

    Raises:
        ValueError: if trajectory.csv has fewer rows than there are frames.
    """
    transform_base = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    img_dir = os.path.join(root_dir, 'img')
    csv_path = os.path.join(root_dir, 'data', 'trajectory.csv')
    df = pd.read_csv(csv_path)
    # Frames are counted by their *_pos.png file; each is assumed to have a matching *_neg.png.
    image_files = sorted(f for f in os.listdir(img_dir) if f.endswith('_pos.png'))
    num_timestamps = len(image_files)
    seq_len = config.max_seq_len
    num_windows = num_timestamps - seq_len + 1

    axes = ('vx', 'vy', 'vz')
    results = {axis: np.full((num_timestamps, num_timestamps), np.nan) for axis in axes}

    if num_windows > 0:
        if len(df) < num_timestamps:
            raise ValueError(
                f"{csv_path} has {len(df)} rows but {img_dir} has {num_timestamps} frames"
            )
        # Load and transform every frame / trajectory row exactly once; windows are slices of these
        # lists (the original re-read each frame once per window it appeared in). All frames are
        # kept in memory for the duration of the sequence.
        frames = [
            _load_event_frame(img_dir, str(frame_idx).zfill(4), transform_base)
            for frame_idx in range(num_timestamps)
        ]
        traj_rows = torch.tensor(df.iloc[:num_timestamps, -7:].to_numpy()).float().unbind(0)

        for start in tqdm(range(num_windows), desc="Measuring"):
            stop = start + seq_len
            # torch.stack copies, so each window gets fresh tensors the model cannot alias.
            x_seq = torch.stack(frames[start:stop]).unsqueeze(0).to(device)
            traj_seq = torch.stack(traj_rows[start:stop]).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(x_seq, traj_seq)

            velocities = output[0].detach().cpu().numpy()
            for col, axis in enumerate(axes):
                results[axis][start, start:stop] = velocities[:, col]
    else:
        print(f"Warning: {img_dir} has {num_timestamps} frames, fewer than max_seq_len={seq_len}; "
              f"all predictions will be NaN.")

    os.makedirs(log_dir, exist_ok=True)
    for axis in axes:
        pd.DataFrame(results[axis]).to_csv(
            os.path.join(log_dir, f'predicted_{axis}.csv'), index=False, header=False
        )

    print(f"Measurement finished. Predicted velocity matrices saved in {log_dir}")


In [5]:
def analyze_results(log_dir, num_cols=120):
    """Average overlapping sliding-window predictions into one velocity per timestamp.

    Reads the ``predicted_vx.csv`` / ``predicted_vy.csv`` / ``predicted_vz.csv`` matrices that
    ``measure`` writes (row = window, column = timestamp, NaN where a window does not cover a
    timestamp), averages each column ignoring NaN, and writes
    ``final_estimated_velocity.csv`` (columns vx_final, vy_final, vz_final; index label
    ``timestamp_idx``) into ``log_dir``.

    Args:
        log_dir: directory holding the prediction matrices; the result is written here too.
        num_cols: number of leading timestamp columns to keep. Defaults to 120 (the value that
            used to be hard-coded); ``None`` keeps every column. Matrices narrower than
            ``num_cols`` are used in full instead of raising.

    Returns:
        The final velocity DataFrame, or None if any prediction file is missing.
    """
    print("Analyzing prediction matrices...")

    try:
        matrices = {
            axis: pd.read_csv(os.path.join(log_dir, f'predicted_{axis}.csv'), header=None)
            for axis in ('vx', 'vy', 'vz')
        }
    except FileNotFoundError:
        print(f"Prediction files not found in {log_dir}. Please run measure() first.")
        return None

    # Column-wise mean (axis=0); DataFrame.mean skips NaN by default, so each timestamp is the
    # average over only the windows that covered it.
    final_trajectory_df = pd.DataFrame({
        f'{axis}_final': matrix.iloc[:, :num_cols].mean(axis=0)
        for axis, matrix in matrices.items()
    })

    # Timestamps no window covered stay NaN; flag them, since NaN later yields invalid JSON.
    uncovered = final_trajectory_df.isna().any(axis=1)
    if uncovered.any():
        print(f"Warning: {int(uncovered.sum())} timestamp(s) in {log_dir} have no prediction (NaN).")

    output_path = os.path.join(log_dir, 'final_estimated_velocity.csv')
    final_trajectory_df.to_csv(output_path, index_label='timestamp_idx')

    print(f"Final estimated velocity trajectory saved to {output_path}")
    return final_trajectory_df


In [7]:
# Sliding-window inference for every test sequence 0028_1 … 0092_1.
for idx in range(28, 93):
    seq_tag = f'{idx:04d}_1'
    root_dir, log_dir = f'./dataset/testv2/{seq_tag}', f'./log/testv2/{seq_tag}'
    measure(model, root_dir=root_dir, log_dir=log_dir, config=config, device=device)

Measuring: 100%|██████████| 115/115 [00:04<00:00, 27.87it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0028_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.31it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0029_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.01it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0030_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.56it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0031_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.09it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0032_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.92it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0033_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.60it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0034_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.84it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0035_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.21it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0036_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.17it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0037_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.06it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0038_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.92it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0039_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.33it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0040_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.09it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0041_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.57it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0042_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.10it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0043_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.64it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0044_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.10it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0045_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.05it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0046_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.15it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0047_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.82it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0048_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.13it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0049_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.92it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0050_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.02it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0051_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.40it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0052_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.73it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0053_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.60it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0054_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.89it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0055_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.84it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0056_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.27it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0057_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.37it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0058_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.48it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0059_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.94it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0060_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.01it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0061_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.35it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0062_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.32it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0063_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.17it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0064_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.16it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0065_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 28.81it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0066_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.25it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0067_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.47it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0068_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.21it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0069_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.47it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0070_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.35it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0071_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.56it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0072_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.50it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0073_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.26it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0074_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.81it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0075_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.22it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0076_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.42it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0077_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.38it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0078_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.17it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0079_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.93it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0080_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.75it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0081_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.59it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0082_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.57it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0083_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.08it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0084_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.91it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0085_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.75it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0086_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 28.78it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0087_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 31.00it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0088_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.56it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0089_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 29.54it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0090_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.12it/s]


Measurement finished. Predicted velocity matrices saved in ./log/testv2/0091_1


Measuring: 100%|██████████| 115/115 [00:03<00:00, 30.11it/s]

Measurement finished. Predicted velocity matrices saved in ./log/testv2/0092_1





In [9]:
# Collapse each sequence's prediction matrices into one velocity per timestamp.
for idx in range(28, 93):
    log_dir = './log/testv2/' + f'{idx:04d}_1'
    analyze_results(log_dir=log_dir)

Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0028_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0029_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0030_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0031_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0032_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0033_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0034_1/final_estimated_velocity.csv
Analyzing prediction matrices...
Final estimated velocity trajectory saved to ./log/testv2/0035_1/final_estimat

In [11]:
import json

# Gather every sequence's final velocity trajectory into one dict keyed by sequence index
# (json turns the int keys into strings such as "28").
output = dict()
for idx in range(28, 93):
    log_dir = f'./log/testv2/{idx:04d}_1'
    velocity_df = pd.read_csv(os.path.join(log_dir, 'final_estimated_velocity.csv'))
    velocity_cols = ['vx_final', 'vy_final', 'vz_final']
    # json.dump would otherwise write bare NaN tokens, which are not valid JSON; fail with the
    # offending sequence named instead of producing a file strict parsers reject.
    if velocity_df[velocity_cols].isna().any().any():
        raise ValueError(f"NaN velocities in {log_dir}; cannot write a valid submission.json")
    output[idx] = {"vx": velocity_df['vx_final'].tolist(), "vy": velocity_df['vy_final'].tolist(), "vz": velocity_df['vz_final'].tolist()}

with open('submission.json', 'wt', encoding='utf-8') as f:
    json.dump(output, f, indent=4, allow_nan=False)