In [1]:
import os
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import math
import time

from Nuscenes_dataset_test import NuscenesDatasetTest
from Model import EnDeWithPooling, EnDeConvLSTM_ws, SkipLSTMEnDe
from torchvision import transforms
from PIL import Image

In [2]:
def saveTransformedImages(imageTensor):
    """Print and display a min-max normalised grayscale view of an image tensor.

    Despite the name, nothing is written to disk: the input is converted to a
    PIL image, rescaled to the [0, 1] range, printed, and shown with matplotlib.

    Args:
        imageTensor: Any input accepted by ``torchvision.transforms.ToPILImage``.
    """
    to_pil = torchvision.transforms.ToPILImage()
    # Convert to an explicit float array so the arithmetic below does not
    # depend on implicit PIL -> NumPy coercion.
    im = np.asarray(to_pil(imageTensor), dtype=np.float64)
    mn, mx = im.min(), im.max()
    value_range = mx - mn
    if value_range > 0:
        im = (im - mn) / value_range
    else:
        # A constant image has no dynamic range; show zeros rather than NaNs.
        im = np.zeros_like(im)
    print(im)
    plt.imshow(im, cmap='gray')
    plt.show()

In [3]:
def plotTrajectory(xValsGT, yValsGT, xValsPred, yValsPred, xValsPredMulti, yValsPredMulti, seqLen, im_path, numFrames=None):
    """Plot ground-truth and predicted trajectories and save the figure.

    The ``y*`` sequences are drawn on the horizontal axis and the ``x*``
    sequences on the vertical axis; both axes are fixed to the range [1, 512].

    Args:
        xValsGT, yValsGT: Ground-truth coordinates (red line).
        xValsPred, yValsPred: Single-peak prediction coordinates (green line).
        xValsPredMulti, yValsPredMulti: Top-K prediction coordinates (blue line).
        seqLen: Unused; kept so existing callers keep working.
        im_path: Destination passed to ``Figure.savefig``.
        numFrames: Optional frame count. When given, the title shows
            ``numFrames // 10 - 2`` as a duration in seconds.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.plot(yValsGT, xValsGT, c='r', label='Ground Truth')
        ax.plot(yValsPred, xValsPred, c='g', label='Prediction')
        ax.plot(yValsPredMulti, xValsPredMulti, c='b', label='Multimodal Prediction', alpha=0.8)
        ax.set_xlim([1, 512])
        ax.set_ylim([1, 512])
        ax.set_xlabel('X-Axis')
        ax.set_ylabel('Y-Axis')
        ax.legend(loc='upper right')
        # Identity check: `== None` misbehaves for array-like arguments.
        if numFrames is None:
            ax.set_title('Trajectory')
        else:
            ax.set_title('Trajectory (' + str(numFrames // 10 - 2) + "s)")
        fig.savefig(im_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

In [4]:
def heatmapAccuracy(outputMap, labelMap, thr=1.5):
    """Compare the peak of a predicted heatmap with the peak of the label map.

    Args:
        outputMap: Predicted heatmap (array with at least two dimensions).
        labelMap: Ground-truth heatmap of the same layout.
        thr: Maximum Euclidean distance between the peaks, in grid cells, that
            still counts as a hit.

    Returns:
        ``(hit, dist, pred_peak, gt_peak)`` where ``hit`` is 1 or 0, ``dist`` is
        the Euclidean distance between the two peaks, and each peak is the
        pair of the first two indices of the respective argmax.
    """
    peak_pred = np.unravel_index(outputMap.argmax(), outputMap.shape)
    peak_gt = np.unravel_index(labelMap.argmax(), labelMap.shape)

    delta_0 = peak_pred[0] - peak_gt[0]
    delta_1 = peak_pred[1] - peak_gt[1]
    dist = math.sqrt(delta_0 ** 2 + delta_1 ** 2)

    hit = 1 if dist <= thr else 0
    return hit, dist, (peak_pred[0], peak_pred[1]), (peak_gt[0], peak_gt[1])

In [5]:
def largest_indices(ary, n):
    """Return the indices of the ``n`` largest entries of ``ary``, largest first.

    Args:
        ary: Array-like of any shape.
        n: Number of entries requested. Values larger than the number of
            elements are clamped; ``n <= 0`` yields empty index arrays.

    Returns:
        A tuple with one index array per dimension of ``ary`` (the same layout
        as ``np.unravel_index``), ordered by descending value.
    """
    ary = np.asarray(ary)
    flat = ary.ravel()
    n = min(int(n), flat.size)
    if n <= 0:
        # Without this guard `[-0:]` below would select *every* element.
        return tuple(np.empty(0, dtype=np.intp) for _ in ary.shape)
    # argpartition finds the top-n in linear time; only those n are then sorted.
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)

In [6]:
def multiAccuracy(outputMap, labelMap, topK=5, radius=4):
    """Score the best of the ``topK`` strongest heatmap cells against the label peak.

    Args:
        outputMap: Predicted 2-D heatmap.
        labelMap: Ground-truth 2-D heatmap.
        topK: Number of highest-valued cells of ``outputMap`` to consider.
        radius: Maximum distance, in grid cells, for the closest candidate to
            count as being within radius (defaults to the previous fixed 4).

    Returns:
        ``(0, min_dist, best_pred, gt_peak, within_radius)``. The first element
        is always 0. ``min_dist`` is the smallest Euclidean distance between a
        candidate and the label peak, ``best_pred`` is that candidate's
        ``(row, col)``, and ``within_radius`` is 1 when ``min_dist <= radius``,
        otherwise 0.
    """
    pred = largest_indices(outputMap, topK)
    gt = np.unravel_index(labelMap.argmax(), labelMap.shape)

    # Distances from every candidate to the ground-truth peak, computed at once.
    dist_arr = np.sqrt((pred[0] - gt[0]) ** 2 + (pred[1] - gt[1]) ** 2)
    min_idx = np.argmin(dist_arr)
    min_val = dist_arr[min_idx]

    within_radius = 1 if min_val <= radius else 0
    return 0, min_val, (pred[0][min_idx], pred[1][min_idx]), (gt[0], gt[1]), within_radius

In [7]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

cuda:0


In [8]:
# Make newly created tensors live on the GPU by default, but only when CUDA is
# actually present: selecting the CUDA tensor type on a CPU-only machine fails.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

### Select the cross-validation split

In [9]:
# Checkpoint to evaluate. Set the INFER_CHECKPOINT_PATH environment variable to
# point at a different file without editing the notebook; the default below is
# the original machine-specific location.
checkpoint_path = os.environ.get(
    "INFER_CHECKPOINT_PATH",
    "/home/sas115/trajectory_prediction_INFER-master/ablation_cache_nuscenes/skipLSTM/split-0/checkpoint_future_best.tar",
)

In [10]:
# Legacy KITTI checkpoints, one per cross-validation split. Kept for reference only: `repo_dir` is not defined in the cells shown here.
# checkpoint_path = os.path.join(repo_dir, "models", "kitti-main", "cv-0", "checkpoint_future.tar")
# checkpoint_path = os.path.join(repo_dir, "models", "kitti-main", "cv-1", "checkpoint_future.tar")
# checkpoint_path = os.path.join(repo_dir, "models", "kitti-main", "cv-2", "checkpoint_future.tar")
# checkpoint_path = os.path.join(repo_dir, "models", "kitti-main", "cv-3", "checkpoint_future.tar")
# checkpoint_path = os.path.join(repo_dir, "models", "kitti-main", "cv-4", "checkpoint_future.tar")

In [11]:
# Restore the trained weights. `map_location` remaps the stored tensors onto the
# device selected earlier, so the load also works when the checkpoint was saved
# on a device that is not available here.
checkpoint = torch.load(checkpoint_path, map_location=device)
model = SkipLSTMEnDe(activation="relu", initType="default", numChannels=5, imageHeight=256, imageWidth=256, batchnorm=False, softmax=False)
model.load_state_dict(checkpoint["model_state_dict"])
# Move the model (and, explicitly, its ConvLSTM attribute) to the selected device
# instead of unconditionally calling .cuda().
model = model.to(device)
model.convlstm = model.convlstm.to(device)

In [12]:
# data_dir = "/home/fbd/rrc/submission/INFER-datasets/kitti"
# val_dir = os.path.join(data_dir, "final-validation", "test" + str(cv_num) + ".csv")
# val_dataset = KittiDataset(data_dir, height=256, width=256, train=False, infoPath=val_dir, augmentation=False, groundTruth=True)

In [13]:
# Root directory of the test data. Set the INFER_NUSCENES_TEST_DIR environment
# variable to use another location; the default is the original
# machine-specific path.
data_dir = os.environ.get(
    "INFER_NUSCENES_TEST_DIR",
    "/home/sas115/trajectory_prediction_INFER-master/nuScenes_project_dataset_test",
)
val_dataset = NuscenesDatasetTest(data_dir, height=256, width=256)

### Future Prediction (Final, Validation)

In [20]:
# 2x bilinear upsampling layer.
upsample_512 = torch.nn.Upsample(scale_factor=2, mode='bilinear')
# Transform pipeline that only converts its input to a tensor.
labelTransform = transforms.Compose([transforms.ToTensor()])

# Per-sequence mean losses: four truncated horizons plus the full sequence.
valLoss1 = []
valLoss2 = []
valLoss3 = []
valLoss4 = []
valLoss = []

# Number of top heatmap cells considered by the multimodal metric.
topK = 5
# Counters for the hit-rate statistic.
totalPreds = 0
hitPreds = 0

In [21]:
# Evaluation over every sample of `val_dataset`.
#
# For frames with frame_num < 4 the dataset input is fed to the model as-is.
# From frame 4 onward the model runs on its own prediction: channel 0 of the
# input is replaced by the (min-max normalised) previous output and channel 4
# by channel 4 of the last input seen with frame_num < 4.
debug, prevOut, state = True, None, None
prevChannels = None
xValsGT, yValsGT, xValsPred, yValsPred, xValsPredMulti, yValsPredMulti = [], [], [], [], [], []
seqLoss, seqVals = [], []
seqNum, seqLen = 0, 0

start_time = time.time()
model.eval()
new_seq_loss = []

# Inference only: without no_grad the recurrent state would keep an autograd
# graph alive across every frame of a sequence.
with torch.no_grad():
    for i in range(len(val_dataset)):
        grid, sceneNum, seqNum, frame_num, endOfSequence = val_dataset[i]
        grid = grid.to(device=device, dtype=torch.float32)

        # The last channel is the target frame and the first n - 1 are source frames.
        inp = grid[:-1, :].unsqueeze(0)
        currLabel = grid[-1:, :].unsqueeze(0)

        if frame_num < 4:
            prevChannels = inp

        if frame_num >= 4:
            new_inp = inp.clone().squeeze(0)
            mn, mx = torch.min(prevOut), torch.max(prevOut)
            value_range = mx - mn
            if value_range > 0:
                prevOut = (prevOut - mn) / value_range
            else:
                # Flat prediction: avoid dividing by zero and feeding NaNs back in.
                prevOut = torch.zeros_like(prevOut)
            new_inp[0] = prevOut
            new_inp[4] = prevChannels[0, 4, :, :]
            inp = new_inp.unsqueeze(0).to(device)

        # Forward the input and carry the recurrent state over to the next frame.
        out = model(inp, state)
        state = (model.h, model.c, model.h1, model.c1, model.h2, model.c2)

        # One host copy of the output serves both the feedback tensor and the metrics.
        prevOut = out.detach().cpu().squeeze(0).squeeze(0)
        currOutputMap = prevOut.numpy()
        currLabel = currLabel.detach().cpu().numpy().squeeze(0).squeeze(0)

        _, dist1, predCoordinates1, gtCoordinates1 = heatmapAccuracy(currOutputMap, currLabel)
        _, dist2, predCoordinates2, gtCoordinates2, within_radius = multiAccuracy(currOutputMap, currLabel, topK=topK)

        # Only frames from index 4 onward contribute to the reported metrics.
        if frame_num >= 4:
            seqLoss.append(dist2)
            new_seq_loss.append([dist2])
            totalPreds += 1
            if within_radius == 1:
                hitPreds += 1

        seqLen += 1
        xValsGT.append(gtCoordinates1[0])
        yValsGT.append(gtCoordinates1[1])
        xValsPred.append(predCoordinates1[0])
        yValsPred.append(predCoordinates1[1])
        xValsPredMulti.append(predCoordinates2[0])
        yValsPredMulti.append(predCoordinates2[1])

        if endOfSequence:
            # NOTE(review): seqLen is never reset, so seqVals holds cumulative
            # frame counts rather than per-sequence lengths — confirm intent.
            seqVals.append(seqLen)
            xValsGT, yValsGT, xValsPred, yValsPred, xValsPredMulti, yValsPredMulti = [], [], [], [], [], []
            # Only affects the number printed below; seqNum is reloaded from the
            # dataset on the next iteration.
            seqNum += 1
            state = None
            valLoss.append(np.mean(seqLoss))
            # Means over the first 2 / 4 / 6 / 8 predicted frames of the sequence.
            valLoss1.append(np.mean(seqLoss[:2]))
            valLoss2.append(np.mean(seqLoss[:4]))
            valLoss3.append(np.mean(seqLoss[:6]))
            valLoss4.append(np.mean(seqLoss[:8]))
            print("Scene: {}, Sequence: {}, Frame: {}, Seq Loss: {}".format(sceneNum, seqNum, frame_num, np.mean(seqLoss)))
            seqLoss = []

end_time = time.time()

Scene: 44.0, Sequence: 1.0, Frame: 11, Seq Loss: 18.863141051654853
Scene: 44.0, Sequence: 2.0, Frame: 11, Seq Loss: 17.051666480941222
Scene: 44.0, Sequence: 3.0, Frame: 11, Seq Loss: 13.065673497156007
Scene: 44.0, Sequence: 4.0, Frame: 11, Seq Loss: 18.870043407738603
Scene: 44.0, Sequence: 5.0, Frame: 11, Seq Loss: 17.10087016907076
Scene: 44.0, Sequence: 6.0, Frame: 11, Seq Loss: 24.861222851055068
Scene: 44.0, Sequence: 7.0, Frame: 11, Seq Loss: 28.187009557687553
Scene: 44.0, Sequence: 8.0, Frame: 11, Seq Loss: 23.88066000181337
Scene: 44.0, Sequence: 9.0, Frame: 11, Seq Loss: 20.33354319797276
Scene: 44.0, Sequence: 10.0, Frame: 11, Seq Loss: 28.88684620209215
Scene: 44.0, Sequence: 11.0, Frame: 11, Seq Loss: 24.038254719523742
Scene: 44.0, Sequence: 12.0, Frame: 11, Seq Loss: 29.07378766858003
Scene: 44.0, Sequence: 13.0, Frame: 11, Seq Loss: 29.242405934600594
Scene: 44.0, Sequence: 14.0, Frame: 11, Seq Loss: 24.46201954806481
Scene: 44.0, Sequence: 15.0, Frame: 11, Seq Loss:

Scene: 48.0, Sequence: 18.0, Frame: 11, Seq Loss: 11.713130071595804
Scene: 48.0, Sequence: 19.0, Frame: 11, Seq Loss: 9.943978831174121
Scene: 48.0, Sequence: 20.0, Frame: 11, Seq Loss: 18.48655106339094
Scene: 48.0, Sequence: 21.0, Frame: 11, Seq Loss: 8.485792129122839
Scene: 48.0, Sequence: 22.0, Frame: 11, Seq Loss: 8.330818966704374
Scene: 48.0, Sequence: 23.0, Frame: 11, Seq Loss: 11.953490677190317
Scene: 48.0, Sequence: 24.0, Frame: 11, Seq Loss: 17.022283788318944
Scene: 48.0, Sequence: 25.0, Frame: 11, Seq Loss: 25.718979425183264
Scene: 48.0, Sequence: 26.0, Frame: 11, Seq Loss: 14.394732339738983
Scene: 49.0, Sequence: 1.0, Frame: 11, Seq Loss: 14.987533170640413
Scene: 49.0, Sequence: 2.0, Frame: 11, Seq Loss: 10.804497871008717
Scene: 49.0, Sequence: 3.0, Frame: 11, Seq Loss: 6.5573584813073476
Scene: 49.0, Sequence: 4.0, Frame: 11, Seq Loss: 4.700535882715187
Scene: 49.0, Sequence: 5.0, Frame: 11, Seq Loss: 3.686186517852847
Scene: 49.0, Sequence: 6.0, Frame: 11, Seq Lo

In [22]:
# Average the per-sequence losses collected for each of the four horizons.
horizon_means = [np.mean(losses) for losses in (valLoss1, valLoss2, valLoss3, valLoss4)]
print("1s: {}, 2s: {}, 3s: {}, 4s: {}".format(*horizon_means))

1s: 11.41506158263851, 2s: 15.13690567416456, 3s: 17.95680730955954, 4s: 22.228192675476198


In [23]:
# Fraction of scored frames whose best top-K candidate fell within the hit radius.
# Report NaN instead of raising ZeroDivisionError when nothing was scored.
hit_rate = hitPreds / totalPreds if totalPreds else float("nan")
print("HitPreds: {}, TotalPreds: {}, Hit Rate: {}".format(hitPreds, totalPreds, hit_rate))

HitPreds: 648, TotalPreds: 1872, Hit Rate: 0.34615384615384615
