In [4]:
from config import config
import torch
import logging
import os
from PIL import Image
import pandas as pd
import numpy as np
from torchvision import transforms

from src.net import build_model
from src.utils import set_all_seeds

# Logger for this notebook, named after the running module (__name__).
logger = logging.getLogger(__name__)

In [2]:
# Emit INFO-level (and above) log records.
logging.basicConfig(level=logging.INFO)

# Run on the first CUDA device only when the config requests it AND CUDA is
# actually available; fall back to the CPU otherwise.
device = torch.device(
    "cuda:0" if (config.use_cuda and torch.cuda.is_available()) else "cpu"
)
logger.info(f"Test on device {device}")

# set_all_seeds comes from src.utils; presumably it seeds every RNG from
# config.seed -- TODO confirm against its implementation.
set_all_seeds(config)
logger.info(f"Random seed value: {config.seed}")

INFO:__main__:Test on device cuda:0
INFO:__main__:Random seed value: 42


In [3]:
# Build the network and restore its pretrained weights.
model = build_model(config)
weight_path = os.path.join(config.weight_dir, config.weight_name + ".pth")

# Fail loudly when the checkpoint is missing. Calling exit(0) here would report
# success and tear the kernel down instead of surfacing the problem.
if not os.path.exists(weight_path):
    raise FileNotFoundError(f"Pretrained weights not found at {weight_path}")

logger.info(f"Loading pretrained weights from {weight_path}")
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(weight_path, map_location='cpu')

# strict=False tolerates key mismatches between checkpoint and model; report
# them so a partially initialised model does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    logger.warning(f"Keys missing from checkpoint (left at initial values): {load_result.missing_keys}")
if load_result.unexpected_keys:
    logger.warning(f"Unexpected keys in checkpoint (ignored): {load_result.unexpected_keys}")

model.eval()
# Last expression of the cell: moves the model to the device and displays its repr.
model.to(device)

INFO:__main__:Loading pretrained weights from ./weight/event_bert.pth


EventBERT(
  (eventImg2Token): EventImg2Token(
    (conv): Sequential(
      (0): ConvDw(
        (conv): Sequential(
          (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
          (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
          (3): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): ConvDw(
        (conv): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.01, inplace=True)
          (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=Fa

In [4]:
from tqdm import tqdm
def measure(model, root_dir, log_dir, config, device):
    """Run sliding-window inference over one sequence and save the raw predictions.

    The sequence is read from ``root_dir``:

    * ``img/<NNNN>_pos.png`` and ``img/<NNNN>_neg.png`` -- one pair per frame,
      ``NNNN`` being the zero-padded frame index starting at 0;
    * ``data/trajectory.csv`` -- one row per frame; the last 7 columns of a row
      are passed to the model alongside the images.

    For every window of ``config.max_seq_len`` consecutive frames the model is
    called once. Row ``i`` of each result matrix holds the prediction of the
    window starting at frame ``i``, written into columns
    ``i .. i + max_seq_len - 1``; all other cells stay NaN. The three matrices
    are saved to ``log_dir`` as ``predicted_vx.csv``, ``predicted_vy.csv`` and
    ``predicted_vz.csv`` (no header, no index).

    Each frame is loaded and transformed exactly once: the window is advanced
    by appending the newest frame and dropping the oldest one.

    Args:
        model: Callable invoked as ``model(x_seq, traj_seq)``.
        root_dir: Directory of the sequence (contains ``img/`` and ``data/``).
        log_dir: Output directory; created if it does not exist.
        config: Object exposing ``max_seq_len``.
        device: Torch device the input tensors are moved to.

    Raises:
        ValueError: If ``config.max_seq_len`` is smaller than 1, or the sequence
            has fewer frames than ``config.max_seq_len`` (no window would fit
            and the outputs would be entirely NaN).
    """
    transform_base = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    img_dir = os.path.join(root_dir, 'img')
    csv_path = os.path.join(root_dir, 'data', 'trajectory.csv')
    df = pd.read_csv(csv_path)

    # The number of '*_pos.png' files defines the number of frames.
    num_timestamps = len([f for f in os.listdir(img_dir) if f.endswith('_pos.png')])
    seq_len = config.max_seq_len
    if seq_len < 1:
        raise ValueError(f"config.max_seq_len must be >= 1, got {seq_len}")
    if num_timestamps < seq_len:
        raise ValueError(
            f"{img_dir} contains {num_timestamps} frame(s), "
            f"but at least {seq_len} are required for one window"
        )

    vx_results = np.full((num_timestamps, num_timestamps), np.nan)
    vy_results = np.full((num_timestamps, num_timestamps), np.nan)
    vz_results = np.full((num_timestamps, num_timestamps), np.nan)
    num_windows = num_timestamps - seq_len + 1

    # Pre-fill the window with the first seq_len - 1 frames; every loop
    # iteration then only has to load the single frame that enters the window.
    window = [_load_frame(img_dir, df, j, transform_base) for j in range(seq_len - 1)]

    for i in tqdm(range(num_windows), desc="Measuring"):
        window.append(_load_frame(img_dir, df, i + seq_len - 1, transform_base))
        pos_images, neg_images, traj_list = zip(*window)

        # Stack frames along a new leading axis, join the two polarities along
        # dim 1, then add a batch axis of size 1 in front.
        x_pos_seq = torch.stack(pos_images)
        x_neg_seq = torch.stack(neg_images)
        x_seq = torch.cat([x_pos_seq, x_neg_seq], dim=1).unsqueeze(0)
        traj_seq = torch.stack(traj_list).unsqueeze(0)

        x_seq = x_seq.to(device)
        traj_seq = traj_seq.to(device)

        with torch.no_grad():
            output = model(x_seq, traj_seq)

        # Assumes output[0] is indexable as [frame, component] with components
        # 0/1/2 being vx/vy/vz for the frames of this window -- TODO confirm
        # against the model definition.
        velocities = output[0].detach().cpu().numpy()
        vx_results[i, i : i + seq_len] = velocities[:, 0]
        vy_results[i, i : i + seq_len] = velocities[:, 1]
        vz_results[i, i : i + seq_len] = velocities[:, 2]

        # Slide the window: the oldest frame is not part of the next window.
        del window[0]

    os.makedirs(log_dir, exist_ok=True)
    pd.DataFrame(vx_results).to_csv(os.path.join(log_dir, 'predicted_vx.csv'), index=False, header=False)
    pd.DataFrame(vy_results).to_csv(os.path.join(log_dir, 'predicted_vy.csv'), index=False, header=False)
    pd.DataFrame(vz_results).to_csv(os.path.join(log_dir, 'predicted_vz.csv'), index=False, header=False)

    print(f"Measurement finished. Predicted velocity matrices saved in {log_dir}")


def _load_frame(img_dir, df, frame_idx, transform):
    """Load one frame as a ``(pos_tensor, neg_tensor, traj_tensor)`` triple.

    The two images are converted to single-channel ('L') mode and passed through
    ``transform``; ``traj_tensor`` holds the last 7 columns of row ``frame_idx``
    of ``df`` as a float tensor.
    """
    frame_name = str(frame_idx).zfill(4)
    # The context managers close the image files as soon as they are decoded.
    with Image.open(os.path.join(img_dir, f'{frame_name}_pos.png')) as pos_file:
        pos_image = transform(pos_file.convert('L'))
    with Image.open(os.path.join(img_dir, f'{frame_name}_neg.png')) as neg_file:
        neg_image = transform(neg_file.convert('L'))
    # Select the last 7 columns of the dataframe row at index frame_idx.
    traj = torch.tensor(df.iloc[frame_idx, -7:].to_numpy()).float()
    return pos_image, neg_image, traj


In [2]:
def analyze_results(log_dir, stride=4, num_samples=120):
    """Collapse the per-window prediction matrices into one velocity per timestamp.

    Reads ``predicted_vx.csv``, ``predicted_vy.csv`` and ``predicted_vz.csv``
    (header-less matrices) from ``log_dir``, keeps the columns
    ``0, stride, 2 * stride, ..., (num_samples - 1) * stride`` and averages each
    kept column over all rows. NaN cells are ignored by the average. The result
    is written to ``<log_dir>/final_estimated_velocity.csv`` with the columns
    ``timestamp_idx`` (the original column index), ``vx_final``, ``vy_final``
    and ``vz_final``.

    If any of the three input files is missing, a message is printed and the
    function returns without writing anything.

    Args:
        log_dir: Directory holding the prediction matrices; also receives the output.
        stride: Distance between the kept column indices. Defaults to 4.
        num_samples: Number of columns to keep. Defaults to 120.

    Raises:
        IndexError: If a matrix has fewer than ``(num_samples - 1) * stride + 1``
            columns.
    """
    print("Analyzing prediction matrices...")

    try:
        matrices = {
            axis: pd.read_csv(os.path.join(log_dir, f'predicted_{axis}.csv'), header=None)
            for axis in ('vx', 'vy', 'vz')
        }
    except FileNotFoundError:
        print(f"Prediction files not found in {log_dir}. Please run measure() first.")
        return

    # --- Core computation: average every kept column ---
    # axis=0 averages down each column; pandas skips NaN automatically.
    column_indices = [k * stride for k in range(num_samples)]

    # --- Combine the three components into the final trajectory DataFrame ---
    final_trajectory_df = pd.DataFrame({
        f'{axis}_final': matrix.iloc[:, column_indices].mean(axis=0)
        for axis, matrix in matrices.items()
    })

    # --- Save the final result ---
    output_path = os.path.join(log_dir, 'final_estimated_velocity.csv')
    final_trajectory_df.to_csv(output_path, index_label='timestamp_idx')

    print(f"Final estimated velocity trajectory saved to {output_path}")


In [6]:
# Run windowed inference for every test sequence with id 28..92; inputs come
# from ./dataset/test/<id>_4 and the raw predictions go to ./log/test/<id>_4.
for idx in range(28, 93):
    measure(
        model,
        root_dir=f'./dataset/test/{idx:04d}_4',
        log_dir=f'./log/test/{idx:04d}_4',
        config=config,
        device=device,
    )

Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.01it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0028_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.44it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0029_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.95it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0030_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.07it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0031_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.52it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0032_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.99it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0033_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.79it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0034_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.98it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0035_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.81it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0036_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.33it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0037_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.98it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0038_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.96it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0039_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.10it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0040_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.29it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0041_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 20.96it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0042_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.85it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0043_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.44it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0044_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.34it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0045_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.44it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0046_4


Measuring: 100%|██████████| 454/454 [00:22<00:00, 20.48it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0047_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.92it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0048_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.59it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0049_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.18it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0050_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 20.76it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0051_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.11it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0052_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.56it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0053_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.49it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0054_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.06it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0055_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.36it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0056_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.27it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0057_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 20.97it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0058_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.01it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0059_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.46it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0060_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.91it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0061_4


Measuring: 100%|██████████| 454/454 [00:22<00:00, 20.38it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0062_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.68it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0063_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.11it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0064_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.23it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0065_4


Measuring: 100%|██████████| 454/454 [00:22<00:00, 20.52it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0066_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.20it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0067_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.08it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0068_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.17it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0069_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.05it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0070_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.81it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0071_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.84it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0072_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 20.94it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0073_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.29it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0074_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.24it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0075_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.21it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0076_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.00it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0077_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.10it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0078_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.05it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0079_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.39it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0080_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.45it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0081_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.89it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0082_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.33it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0083_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.59it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0084_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.20it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0085_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 22.15it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0086_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.02it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0087_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.99it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0088_4


Measuring: 100%|██████████| 454/454 [00:20<00:00, 21.68it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0089_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.04it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0090_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.05it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0091_4


Measuring: 100%|██████████| 454/454 [00:21<00:00, 21.21it/s]


Measurement finished. Predicted velocity matrices saved in ./log/test/0092_4


In [None]:
# Reduce the prediction matrices of every test sequence (ids 28..92) to a
# final velocity trajectory. The export cell reads the CSV files written here.
for idx in range(28, 93):
    analyze_results(log_dir=f'./log/test/{idx:04d}_4')

In [8]:
import json

# Collect the final velocity components of every test sequence (ids 28..92),
# keyed by the integer sequence id.
output = {}
for idx in range(28, 93):
    velocity_csv = os.path.join(f'./log/test/{idx:04d}_4', 'final_estimated_velocity.csv')
    velocity_df = pd.read_csv(velocity_csv)
    output[idx] = {
        "vx": velocity_df['vx_final'].tolist(),
        "vy": velocity_df['vy_final'].tolist(),
        "vz": velocity_df['vz_final'].tolist(),
    }

# json.dump serialises the integer keys as strings.
with open('submission.json', 'wt') as f:
    json.dump(output, f, indent=4)