In [1]:
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import torch

from src.data.argoverse_datamodule import ArgoverseDataModule
from src.data.argoverse_dataset import ArgoverseDataset
from src.data.dummy_datamodule import DummyDataModule
from src.models.WIMP import WIMP

from torch.utils.data import DataLoader, Dataset

import XAI_utils

In [2]:
# Run configuration: one entry per line, in the original (alphabetical) key order.
# The data cell reads dataroot / predict_delta / map_features / no_heuristic / IFC /
# use_oracle / batch_size / workers from it, and the whole dict is passed to WIMP(args).
args = {
    "IFC": True,
    "add_centerline": False,
    "attention_heads": 4,
    "batch_norm": False,
    "batch_size": 25,
    "check_val_every_n_epoch": 3,
    "dataroot": "./data/argoverse_processed_simple",
    "dataset": "argoverse",
    "distributed_backend": "ddp",
    "dropout": 0.5,
    "early_stop_threshold": 5,
    "experiment_name": "example",
    "gpus": 1,
    "gradient_clipping": True,
    "graph_iter": 1,
    "hidden_dim": 512,
    "hidden_key_generator": True,
    "hidden_transform": False,
    "input_dim": 2,
    "k_value_threshold": 10,
    "k_values": [6, 5, 4, 3, 2, 1],
    "lr": 0.0001,
    "map_features": False,
    "max_epochs": 200,
    "mode": "train",
    "model_name": "WIMP",
    "no_heuristic": False,
    "non_linearity": "relu",
    "num_layers": 4,
    "num_mixtures": 6,
    "num_nodes": 1,
    "output_conv": True,
    "output_dim": 2,
    "output_prediction": True,
    "precision": 32,
    "predict_delta": False,
    "resume_from_checkpoint": None,
    "scheduler_step_size": [60, 90, 120, 150, 180],
    "seed": None,
    "segment_CL": False,
    "segment_CL_Encoder": False,
    "segment_CL_Encoder_Gaussian": False,
    "segment_CL_Encoder_Gaussian_Prob": False,
    "segment_CL_Encoder_Prob": True,
    "segment_CL_Gaussian_Prob": False,
    "segment_CL_Prob": False,
    "use_centerline_features": True,
    "use_oracle": False,
    "waypoint_step": 5,
    "weight_decay": 0.0,
    "workers": 8,
    "wta": False,
}

In [3]:
def _build_dataset(mode):
    """Create the ArgoverseDataset for one split.

    Args:
        mode: split name forwarded to ArgoverseDataset ('train', 'val' or 'test').

    Returns:
        An ArgoverseDataset configured from the module-level `args` dict.
    """
    return ArgoverseDataset(args['dataroot'], mode=mode, delta=args['predict_delta'],
                            map_features_flag=args['map_features'],
                            social_features_flag=True, heuristic=(not args['no_heuristic']),
                            ifc=args['IFC'], is_oracle=args['use_oracle'])


def _build_loader(dataset, shuffle, drop_last):
    """Wrap a dataset in a DataLoader using the batch size / worker count from `args`.

    Args:
        dataset: the ArgoverseDataset to iterate over.
        shuffle: whether the loader reshuffles samples every epoch.
        drop_last: whether an incomplete final batch is discarded.

    Returns:
        A torch DataLoader that collates samples with ArgoverseDataset.collate.
    """
    return DataLoader(dataset, batch_size=args['batch_size'], num_workers=args['workers'],
                      pin_memory=True, collate_fn=ArgoverseDataset.collate,
                      shuffle=shuffle, drop_last=drop_last)


# NOTE: the historical names are swapped relative to what they hold -- the
# *_loader variables are Dataset objects and the *_dataset variables are
# DataLoaders. They are kept unchanged because other cells refer to them.
train_loader = _build_dataset('train')
val_loader = _build_dataset('val')
test_loader = _build_dataset('test')

# Only the training loader shuffles and drops the incomplete last batch.
train_dataset = _build_loader(train_loader, shuffle=True, drop_last=True)
val_dataset = _build_loader(val_loader, shuffle=False, drop_last=False)
test_dataset = _build_loader(test_loader, shuffle=False, drop_last=False)

# Backward-compatible alias for the original (misspelled) name.
trest_dataset = test_dataset

In [4]:
# Build the WIMP model from the configuration dict and restore the trained weights.
model = WIMP(args)
# Deserialize onto the CPU first so loading does not depend on the GPU index the
# checkpoint was saved from; the weights are moved to the GPU by .cuda() below.
state_dict = torch.load("experiments/example/checkpoints/epoch=122.ckpt",
                        map_location="cpu")['state_dict']
model.load_state_dict(state_dict)
model = model.cuda()
# NOTE(review): model.eval() is never called, so (assuming WIMP is a regular
# nn.Module) it stays in training mode and dropout (args['dropout'] = 0.5) would
# remain active in the cells below -- confirm this is intended. If the model
# contains cuDNN RNN layers, their backward pass requires training mode, so do
# not switch modes without checking that loss.backward() still works.

In [None]:
# No optimizer is declared in this notebook.
# Gradients will therefore probably keep accumulating on the model weights, but
# since what matters here is the gradient of the adjacency matrix, this is
# expected to be harmless.
import os
import math
import copy
import torch.nn as nn
# math / nn / Relu are not used in this cell; they are kept so that any other
# cell relying on these names keeps working.
Relu = nn.ReLU()

save_foler = "ResultsImg/"

save_XAI = save_foler + "/XAI/"
save_attention = save_foler + "/attention"

# makedirs(exist_ok=True) also creates the missing parent folder and lets the
# cell be re-run without raising FileExistsError.
os.makedirs(save_XAI, exist_ok=True)
os.makedirs(save_attention, exist_ok=True)

# The loop stops after the first batch whose index exceeds this value.
LAST_BATCH_IDX = 100

# Entries of input_dict / input_dict['ifc_helpers'] that are moved to the GPU.
_INPUT_TENSOR_KEYS = ('agent_features', 'social_features', 'social_label_features',
                      'adjacency', 'label_adjacency', 'num_agent_mask')
_IFC_TENSOR_KEYS = ('agent_oracle_centerline', 'agent_oracle_centerline_lengths',
                    'social_oracle_centerline', 'social_oracle_centerline_lengths')

for batch_idx, batch in enumerate(tqdm(val_dataset)):
    input_dict, target_dict = batch[0], batch[1]

    # Move every tensor the model consumes onto the GPU (in place in the dicts).
    for key in _INPUT_TENSOR_KEYS:
        input_dict[key] = input_dict[key].cuda()
    ifc_helpers = input_dict['ifc_helpers']
    for key in _IFC_TENSOR_KEYS:
        ifc_helpers[key] = ifc_helpers[key].cuda()
    target_dict['agent_labels'] = target_dict['agent_labels'].cuda()

    preds, waypoint_preds, all_dist_params, attention, adjacency = model(**input_dict)
    # NOTE(review): requires_grad is switched on only after the forward pass has
    # already run, and retain_grad() on the input tensor is issued after that.
    # The attribution below reads the gradient of the *returned* `adjacency`,
    # so these two calls on input_dict['adjacency'] may be redundant -- confirm
    # against the model's forward() before removing them.
    input_dict['adjacency'].requires_grad = True
    loss, metrics = model.eval_preds(preds, target_dict, waypoint_preds)
    input_dict['adjacency'].retain_grad()
    # Keep the gradient of the returned adjacency so it can be read after backward().
    adjacency.retain_grad()
    loss.backward()

    # Use the real batch length instead of a hard-coded 25: the validation
    # loader keeps its incomplete last batch (drop_last=False).
    num_samples = input_dict['agent_features'].shape[0]
    for idx in range(num_samples):
        att = attention[idx].detach().cpu().numpy()
        agent_features = input_dict['agent_features'][idx].cpu().numpy()
        social_features = input_dict['social_features'][idx].cpu().numpy()
        # Keep only the first two channels of the last axis, first entry along
        # the leading axis (presumably x/y of the agent prediction -- TODO confirm).
        sample_preds = preds[idx][:, :, :, :2][0].cpu().detach().numpy()
        city_name = ifc_helpers['city'][idx]
        rotation = ifc_helpers['rotation'][idx].numpy()
        translation = ifc_helpers['translation'][idx].numpy()
        # Gradient of the loss with respect to the returned adjacency, first row
        # of this sample.
        weight = adjacency.grad[idx][0].cpu().numpy()

        file_name = str(batch_idx) + "_" + str(idx) + ".png"

        XAI_utils.draw_attention(agent_features, social_features, sample_preds, city_name,
                                 rotation, translation,
                                 weight=copy.deepcopy(att), draw_future=True, save_fig=True,
                                 save_name=os.path.join(save_attention, file_name))

        XAI_utils.draw(agent_features, social_features, sample_preds, city_name,
                       rotation, translation,
                       weight=copy.deepcopy(weight), draw_future=True, save_fig=True,
                       save_name=os.path.join(save_XAI, file_name))

    if batch_idx > LAST_BATCH_IDX:
        break


  1%|          | 11/1579 [03:28<7:57:34, 18.27s/it]