In [1]:
import pickle
import torch
import sys
import os
import numpy as np
from train_utils import *

sys.path.append('..')


In [2]:
print(torch.__version__)


1.9.0


In [2]:
model_name = 'rho_reg_pecco'
trained_model = torch.load('../checkpoints/' + model_name + '.pth',map_location=torch.device('cpu'))
trained_model.eval()

ECCONetwork(
  (conv_fluid): EquiCtsConv2dRho1ToReg()
  (conv_obstacle): EquiCtsConv2dRho1ToReg()
  (dense_fluid): Sequential(
    (0): EquiLinearRho1ToReg()
    (1): EquiLinearRegToReg()
  )
  (convs): ModuleList(
    (0): EquiCtsConv2dRegToReg()
    (1): EquiCtsConv2dRegToReg()
    (2): EquiCtsConv2dRegToReg()
    (3): EquiCtsConv2dRegToRho1()
  )
  (denses): ModuleList(
    (0): EquiLinearRegToReg()
    (1): EquiLinearRegToReg()
    (2): EquiLinearRegToReg()
    (3): Sequential(
      (0): EquiLinearRegToReg()
      (1): EquiLinearRegToRho1()
    )
  )
)

In [3]:
from evaluate_network import evaluate
from datasets.argoverse_lane_loader import read_pkl_data
from argoverse.map_representation.map_api import ArgoverseMap


In [13]:
dataset_path = '../../argoverse_data/rose'
val_path = os.path.join(dataset_path, 'val') #, 'lane_data'
val_dataset = read_pkl_data(val_path, batch_size=1, shuffle=False, repeat=False)


In [14]:
with torch.no_grad():
    total_loss, prediction_gt = evaluate(trained_model, val_dataset,
                                       train_window=4, max_iter=50, 
                                       device='cpu', start_iter=50, use_lane=False)

evaluating.. {'nll': 0.0, 'ADE': nan, 'ADE_std': nan, 'DE@1s': nan, 'DE@1s_std': nan, 'DE@2s': nan, 'DE@2s_std': nan, 'DE@3s': nan, 'DE@3s_std': nan}
done


In [17]:
from evaluate_network import *

In [25]:
model = trained_model
max_iter =50
start_iter =50
device = 'cpu'
batch_size=1

In [36]:
for i, sample in enumerate(val_dataset):
    pred = []
    gt = []

    data = process_batch(sample, device)

    lane = data['lane']
    lane_normals = data['lane_norm']
    agent_id = data['agent_id']
    city = data['city']
    scenes = data['scene_idx'].tolist()

    inputs = ([
        data['pos_2s'], data['vel_2s'], 
        data['pos0'], data['vel0'], 
        data['accel'], data['sigmas'],
        data['lane'], data['lane_norm'], 
        data['car_mask'], data['lane_mask']
    ])

    pr_pos1, pr_vel1, pr_m1, states = model(inputs)
    gt_pos1 = data['pos1']

    losses = nll(pr_pos1, gt_pos1, pr_m1, data['car_mask'].squeeze(-1))

    pr_agent, gt_agent = get_agent(pr_pos1, data['pos1'],
                                   data['track_id0'].squeeze(-1), 
                                   data['track_id1'].squeeze(-1), 
                                   agent_id.squeeze(-1), device)

    pred.append(pr_agent.unsqueeze(1).detach().cpu())
    gt.append(gt_agent.unsqueeze(1).detach().cpu())
    del pr_agent, gt_agent
    clean_cache(device)

    pos0 = data['pos0']
    vel0 = data['vel0']
    m0 = torch.zeros((batch_size, 60, 2, 2), device=pos0.device)
    for j in range(29):
        pos_enc = torch.unsqueeze(pos0, 2)
        vel_enc = torch.unsqueeze(vel0, 2)
        inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, data['accel'],
                  torch.cat([m0, pr_m1], dim=-2), 
                  data['lane'],
                  data['lane_norm'], data['car_mask'], data['lane_mask'])
        
        pos0, vel0, m0 = pr_pos1, pr_vel1, pr_m1

        pr_pos1, pr_vel1, pr_m1, states = model(inputs, states)
        clean_cache(device)

        gt_pos1 = data['pos'+str(j+1)]
        losses += nll(pr_pos1, gt_pos1, pr_m1, data['car_mask'].squeeze(-1))

        pr_agent, gt_agent = get_agent(pr_pos1, data['pos'+str(j+1)],
                                       data['track_id0'].squeeze(-1), 
                                       data['track_id'+str(j+1)].squeeze(-1),
                                       agent_id.squeeze(-1), device)

        pred.append(pr_agent.unsqueeze(1).detach().cpu())
        gt.append(gt_agent.unsqueeze(1).detach().cpu())
        
    print("batch complete: ", i)


batch complete:  0
batch complete:  1
batch complete:  2
batch complete:  3
batch complete:  4


KeyboardInterrupt: 

tensor(4.8100e+15, grad_fn=<AddBackward0>)

In [38]:
pred

[tensor([[[ 589.4445, 1381.8385]]]),
 tensor([[[ 589.6328, 1381.6088]]]),
 tensor([[[ 589.3992, 1380.2305]]]),
 tensor([[[ 588.9202, 1380.0366]]]),
 tensor([[[ 579.9214, 1367.5590]]]),
 tensor([[[ 584.3655, 1377.3729]]]),
 tensor([[[ 596.6658, 1271.8363]]]),
 tensor([[[ 640.7068, 1356.7920]]]),
 tensor([[[597.8172, 894.6785]]]),
 tensor([[[ 828.4139, 1853.6736]]]),
 tensor([[[  226.0106, -1018.5403]]]),
 tensor([[[1994.9950, 6174.8740]]]),
 tensor([[[ -2903.6455, -13136.2568]]]),
 tensor([[[10917.8076, 36676.0664]]]),
 tensor([[[-29450.9258, -93312.3594]]]),
 tensor([[[ 86699.4688, 243709.0000]]]),
 tensor([[[-246646.8906, -628298.2500]]]),
 tensor([[[ 739073.2500, 1645480.5000]]]),
 tensor([[[-2385644.5000, -4320317.0000]]]),
 tensor([[[ 7071482., 11239530.]]]),
 tensor([[[-20403056., -27931576.]]])]

In [39]:
gt

[tensor([[[ 587.6106, 1381.2997]]]),
 tensor([[[ 587.6106, 1381.2997]]]),
 tensor([[[ 586.9632, 1380.9113]]]),
 tensor([[[ 586.6739, 1380.7225]]]),
 tensor([[[ 586.3462, 1380.4492]]]),
 tensor([[[ 585.7589, 1379.9354]]]),
 tensor([[[ 585.4659, 1379.6121]]]),
 tensor([[[ 584.9191, 1379.0210]]]),
 tensor([[[ 584.6808, 1378.6652]]]),
 tensor([[[ 584.3546, 1378.2303]]]),
 tensor([[[ 583.8191, 1377.5601]]]),
 tensor([[[ 583.6369, 1377.1196]]]),
 tensor([[[ 583.3925, 1376.6952]]]),
 tensor([[[ 582.9846, 1375.8479]]]),
 tensor([[[ 582.7881, 1375.3280]]]),
 tensor([[[ 582.6030, 1374.8840]]]),
 tensor([[[ 582.4222, 1374.3217]]]),
 tensor([[[ 582.2029, 1373.8376]]]),
 tensor([[[ 582.0811, 1373.3918]]]),
 tensor([[[ 581.8999, 1372.7820]]]),
 tensor([[[ 581.7318, 1372.3556]]])]

In [31]:
prediction_gt.keys()

dict_keys([])