In [1]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torchvision.transforms as transforms

import numpy as np
import os
import argparse
import visdom
from tqdm import tqdm

from mtp2 import MTP, MTPLoss
import util

In [30]:
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
batch_size = 64
num_workers = 8
shuffle = False

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

dataset = util.DataSet_proj('./dataset_chh/' + 'train_val', 'train_val')
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)

parser = argparse.ArgumentParser()
parser.add_argument('--name',       required=True,  type=str,   help='experiment name. saved to ./exps/[name]')
parser.add_argument('--ep',         required=True,  type=str)
args = parser.parse_args('--name hoon17 --ep best'.split(' '))
exp_path, train_path, val_path, infer_path, ckpt_path = util.make_path(args)


./exps/hoon17 already exsits.


In [31]:
model = torch.load(ckpt_path + '/' + 'model.archi')
model.load_state_dict(torch.load(ckpt_path + '/' + 'weight_' + args.ep + '.pth')['state_dict'])

<All keys matched successfully>

In [34]:
model = model.to(device)
model.eval()
transform_dac = transforms.Compose([transforms.ToPILImage(), transforms.Resize(500), transforms.ToTensor()])

loss_val_mean = []
dac = 0.
num_samp = 0
ade = 0.
fde = 0.
for raster, road, lane, agents, state, past, gt in tqdm(dataloader):
    raster, road, lane, agents, state, past, gt = util.NaN2Zero(raster), util.NaN2Zero(road), util.NaN2Zero(lane), util.NaN2Zero(agents), util.NaN2Zero(state), util.NaN2Zero(past),  util.NaN2Zero(gt)
    raster, road, lane, agents, state, past, gt = raster.to(device), road.to(device), lane.to(device), agents.to(device), state.to(device), past.to(device), gt.to(device)

    prediction = model(road, lane, agents, state, past)

    for road_, pred_ in zip(road, prediction):
        road_ = util.restore_img(road_.cpu())
        road_ = transform_dac(road_).numpy()
        road_ = (road_ * 255).astype(np.uint8)

        dac_ = util.dac_metric(road_, pred_)
        dac += dac_

        gt_e = gt_.view(-1)
        pred_e = pred_[:-1]

        ade_ = ((gt_e - pred_e)**2).sum(dim=0) / len(gt_e)
        ade += ade_.item()

        fde_ = ((gt_e[-2:] - pred_e[-2:])**2).sum(dim=0) / 2
        fde += fde_.item()

        num_samp += 1

print('\n', 'dac : ', dac/num_samp)
print('\n', 'ade : ', ade/num_samp)
print('\n', 'fde : ', fde/num_samp)



100%|██████████| 49/49 [01:47<00:00,  2.19s/it]
 dac :  0.9605934907466485

 ade :  74.30407556311846

 fde :  193.41256296569222



## Baseline

In [35]:
ckpt_path = './exps/1130_mode1/ckpt'
ep = 'best'

model = torch.load(ckpt_path + '/' + 'model.archi')
model.load_state_dict(torch.load(ckpt_path + '/' + 'weight_' + ep + '.pth')['state_dict'])

<All keys matched successfully>

In [36]:
model = model.to(device)
model.eval()
transform_dac = transforms.Compose([transforms.ToPILImage(), transforms.Resize(500), transforms.ToTensor()])

loss_val_mean = []
dac = 0.
num_samp = 0
ade = 0.
fde = 0.
for raster, road, lane, agents, state, past, gt in tqdm(dataloader):
    raster, road, lane, agents, state, past, gt = util.NaN2Zero(raster), util.NaN2Zero(road), util.NaN2Zero(lane), util.NaN2Zero(agents), util.NaN2Zero(state), util.NaN2Zero(past),  util.NaN2Zero(gt)
    raster, road, lane, agents, state, past, gt = raster.to(device), road.to(device), lane.to(device), agents.to(device), state.to(device), past.to(device), gt.to(device)

    prediction = model(raster, state)

    for road_, pred_, gt_ in zip(road, prediction, gt):
        road_ = util.restore_img(road_.cpu())
        road_ = transform_dac(road_).numpy()
        road_ = (road_ * 255).astype(np.uint8)

        gt_e = gt_.view(-1)
        pred_e = pred_[:-1]

        ade_ = ((gt_e - pred_e)**2).sum(dim=0) / len(gt_e)
        ade += ade_.item()

        fde_ = ((gt_e[-2:] - pred_e[-2:])**2).sum(dim=0) / 2
        fde += fde_.item()

        dac_ = util.dac_metric(road_, pred_)
        dac += dac_
        num_samp += 1

print('\n', 'dac : ', dac/num_samp)
print('\n', 'ade : ', ade/num_samp)
print('\n', 'fde : ', fde/num_samp)

100%|██████████| 49/49 [01:23<00:00,  1.71s/it]
 dac :  0.952111252924909

 ade :  23.13349279402812

 fde :  82.25350927679288



In [33]:
fde_.item()

22.403121948242188