In [3]:
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import torch
from src.data.argoverse_datamodule import ArgoverseDataModule
from src.data.argoverse_dataset import ArgoverseDataset
from src.data.dummy_datamodule import DummyDataModule
from src.models.WIMP import WIMP


In [146]:
args = {"IFC":True, "add_centerline":False, "attention_heads":4, "batch_norm":False, "batch_size":1, "check_val_every_n_epoch":3, 
          "dataroot":'./data/argoverse_processed_simple', "dataset":'argoverse', "distributed_backend":'ddp', "dropout":0.0, 
          "early_stop_threshold":5, "experiment_name":'example', "gpus":1, "gradient_clipping":True, "graph_iter":1, 
          "hidden_dim":512, "hidden_key_generator":True, "hidden_transform":False, "input_dim":2, "k_value_threshold":10, 
          "k_values":[6, 5, 4, 3, 2, 1], "lr":0.0001, "map_features":False, "max_epochs":200, "mode":'train', "model_name":'WIMP', 
          "no_heuristic":False, "non_linearity":'relu', "num_layers":4, "num_mixtures":6, "num_nodes":1, "output_conv":True, "output_dim":2, 
          "output_prediction":True, "precision":32, "predict_delta":False, "resume_from_checkpoint":None, 
          "scheduler_step_size":[60, 90, 120, 150, 180], "seed":None, "segment_CL":False, "segment_CL_Encoder":False, 
          "segment_CL_Encoder_Gaussian":False, "segment_CL_Encoder_Gaussian_Prob":False, "segment_CL_Encoder_Prob":True, 
          "segment_CL_Gaussian_Prob":False, "segment_CL_Prob":False, "use_centerline_features":True, "use_oracle":False, "waypoint_step":5, 
          "weight_decay":0.0, "workers":8, "wta":False, "draw_image" : False, "remove_high_related_score" : True, "maximum_delete_num" : 3}

In [147]:
train_loader = ArgoverseDataset(args['dataroot'], mode='train', delta=args['predict_delta'],
                              map_features_flag=args['map_features'],
                              social_features_flag=True, heuristic=(not args['no_heuristic']),
                              ifc=args['IFC'], is_oracle=args['use_oracle'])

val_loader = ArgoverseDataset(args['dataroot'], mode='val', delta=args['predict_delta'],
                              map_features_flag=args['map_features'],
                              social_features_flag=True, heuristic=(not args['no_heuristic']),
                              ifc=args['IFC'], is_oracle=args['use_oracle'])

test_loader = ArgoverseDataset(args['dataroot'], mode='test', delta=args['predict_delta'],
                              map_features_flag=args['map_features'],
                              social_features_flag=True, heuristic=(not args['no_heuristic']),
                              ifc=args['IFC'], is_oracle=args['use_oracle'])

train_dataset = DataLoader(train_loader, batch_size=args['batch_size'], num_workers=args['workers'],
                                pin_memory=True, collate_fn=ArgoverseDataset.collate,
                                shuffle=True, drop_last=True)

val_dataset = DataLoader(val_loader, batch_size=args['batch_size'], num_workers=args['workers'],
                                pin_memory=True, collate_fn=ArgoverseDataset.collate,
                                shuffle=False, drop_last=False)

trest_dataset = DataLoader(test_loader, batch_size=args['batch_size'], num_workers=args['workers'],
                                pin_memory=True, collate_fn=ArgoverseDataset.collate,
                                shuffle=False, drop_last=False)
# dm = ArgoverseDataset(args)

In [148]:
model = WIMP(args)
model.load_state_dict(torch.load("experiments/example/checkpoints/epoch=122.ckpt")['state_dict'])
model = model.cuda()

In [149]:
def get_metric(metric_dict, ade,fde,mr,loss):
    metric_dict["ade"] += ade
    metric_dict["fde"] += fde
    metric_dict["mr"] += mr
    metric_dict["loss"] += loss
    metric_dict["length"]+=1

# a = torch.FloatTensor([1])
# get_metric(origina_model_metric, a,a,a,a)
# origina_model_metric

In [150]:
#  optimizer는 선언하지 않았습니다.
#  따라서 아마 model weight에 gradient는 계속 쌓이겠지만 저희가 중요한 것은 adjacency matrix의 graident이므로 상관 없을 것으로 예측됩니다.
import os
import math
import copy
import torch.nn as nn
Relu = nn.ReLU()

save_foler = "ResultsImg/"

save_XAI = save_foler + "/XAI/"
save_attention = save_foler + "/attention"

slicing = lambda a, idx: torch.cat((a[:, :idx], a[:, idx+1:]), axis=1)



In [151]:

abs_min = lambda weight : torch.argmin(abs(weight)[1:]).item()
abs_max = lambda weight : torch.argmax(abs(weight)[1:]).item()
simple_min = lambda weight : torch.argmin(weight[1:]).item()
simple_max = lambda weight : torch.argmax(weight[1:]).item()


names = ["abs_min", "abs_max", "simple_min","simple_max"]
for name_idx, function in enumerate([abs_min, abs_max, simple_min, simple_max]):
    origina_model_metric  ={
            "ade" : torch.FloatTensor([0]).cuda(),                                                                                                                                         "mr" : torch.FloatTensor([0]).cuda(),
            "loss" : torch.FloatTensor([0]).cuda(),
        "length" : 0
    }

    DA_model_metric = [
        {
            "ade" : torch.FloatTensor([0]).cuda(),
            "fde" : torch.FloatTensor([0]).cuda(),
            "mr" : torch.FloatTensor([0]).cuda(),
            "loss" : torch.FloatTensor([0]).cuda(),
            "length" : 0    
        },
        {
            "ade" : torch.FloatTensor([0]).cuda(),
            "fde" : torch.FloatTensor([0]).cuda(),
            "mr" : torch.FloatTensor([0]).cuda(),
            "loss" : torch.FloatTensor([0]).cuda(),
            "length" : 0    
        },
        {
            "ade" : torch.FloatTensor([0]).cuda(),
            "fde" : torch.FloatTensor([0]).cuda(),
            "mr" : torch.FloatTensor([0]).cuda(),
            "loss" : torch.FloatTensor([0]).cuda(),
            "length" : 0
        },
    ] # 하나씩 지우면서 metric을 잴것임



    for batch_idx, batch in enumerate(tqdm(val_dataset)):
        input_dict, target_dict = batch[0], batch[1]

        # get cuda
        input_dict['agent_features'] = input_dict['agent_features'].cuda()
        input_dict['social_features'] = input_dict['social_features'].cuda()
        input_dict['social_label_features'] = input_dict['social_label_features'].cuda()
        input_dict['adjacency'] = input_dict['adjacency'].cuda()
        input_dict['label_adjacency'] = input_dict['label_adjacency'].cuda()
        input_dict['num_agent_mask'] = input_dict['num_agent_mask'].cuda()
        input_dict['ifc_helpers']['agent_oracle_centerline'] = input_dict['ifc_helpers']['agent_oracle_centerline'].cuda()
        input_dict['ifc_helpers']['agent_oracle_centerline_lengths'] = input_dict['ifc_helpers']['agent_oracle_centerline_lengths'].cuda()
        input_dict['ifc_helpers']['social_oracle_centerline'] = input_dict['ifc_helpers']['social_oracle_centerline'].cuda()
        input_dict['ifc_helpers']['social_oracle_centerline_lengths'] = input_dict['ifc_helpers']['social_oracle_centerline_lengths'].cuda()
        input_dict['ifc_helpers']['agent_oracle_centerline'] = input_dict['ifc_helpers']['agent_oracle_centerline'].cuda()
        target_dict['agent_labels'] = target_dict['agent_labels'].cuda()


        preds, waypoint_preds, all_dist_params, attention, adjacency = model(**input_dict)
        loss, (ade, fde, mr) = model.eval_preds(preds, target_dict, waypoint_preds)
        get_metric(origina_model_metric, ade,fde,mr,loss)


        input_dict['adjacency'].requires_grad = True
        input_dict['adjacency'].retain_grad()

        adjacency.retain_grad()
        print(torch.sum(model.decoder.value_generator.weight))
        loss.backward()
        print(torch.sum(model.decoder.value_generator.weight))
        batch_preds = preds



        for idx in range(args["batch_size"]):
            if args["draw_image"]:
                weight = adjacency.grad[idx][0].cpu().numpy()
                att = attention[idx].cpu().numpy()
                agent_features = input_dict['agent_features'][idx].cpu().numpy()
                social_features = input_dict['social_features'][idx].cpu().numpy()
                # target = target_dict['agent_labels'][idx].cpu().numpy()
                preds = batch_preds[idx][:,:,:,:2][0].cpu().detach().numpy()
                city_name = input_dict['ifc_helpers']['city'][idx]
                rotation = input_dict['ifc_helpers']['rotation'][idx].numpy()
                translation = input_dict['ifc_helpers']['translation'][idx].numpy()
                XAI_utils.draw_attention(agent_features, social_features, preds, city_name, rotation, translation, 
                                               weight = copy.deepcopy(att), draw_future=True, save_fig = True, 
                                                 = save_attention + "/" + str(batch_idx) + "_" + str(idx) + ".png")

                XAI_utils.draw(agent_features, social_features, preds, city_name, rotation, translation, 
                                   weight = copy.deepcopy(weight), draw_future=True, save_fig = True, 
                                   save_name = save_XAI + "/" + str(batch_idx) + "_" + str(idx) + ".png")

            if args["remove_high_related_score"]:
                weight = adjacency.grad[idx][0]
                for i in range(args["maximum_delete_num"]):
    #                 print(batch_idx, input_dict["social_features"].shape)
                    if len(input_dict["social_features"][0] > 1):
                        arg = function(weight)
    #                     arg_max = torch.argmin(weight[1:]).item()

                        weight = torch.cat((weight[:arg+1], weight[arg+2:]))                                                                                                                                                                                                                                                                                                                                                                             
                        input_dict["social_features"]                                        = slicing(input_dict["social_features"], arg)
                        input_dict["num_agent_mask"]                                    = slicing(input_dict["num_agent_mask"], arg+1)
                        input_dict["ifc_helpers"                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        ['social_oracle_centerline']             = slicing(input_dict["ifc_helpers"]['social_oracle_centerline'], arg)
                        input_dict["ifc_helpers"]['social_oracle_centerline_lengths']  = slicing(input_dict["ifc_helpers"]['social_oracle_centerline_lengths'], arg)
                    else:
                        break

                    with torch.no_grad():
                        preds, waypoint_preds, all_dist_params, att_weights, adjacency = model(input_dict["agent_features"],
                                                    input_dict["social_features"],
                                                    None,
                                                    input_dict["num_agent_mask"],
                                                    ifc_helpers ={
                                                        "social_oracle_centerline": input_dict["ifc_helpers"]['social_oracle_centerline'], 
                                                        "social_oracle_centerline_lengths": input_dict["ifc_helpers"]['social_oracle_centerline_lengths'],
                                                        "agent_oracle_centerline": input_dict["ifc_helpers"]["agent_oracle_centerline"],
                                                        "agent_oracle_centerline_lengths": input_dict["ifc_helpers"]["agent_oracle_centerline_lengths"]
                                                    })
                        loss, (ade, fde, mr) = model.eval_preds(preds, target_dict, waypoint_preds)
                        get_metric(DA_model_metric[i], ade,fde,mr,loss)

  0%|          | 0/39472 [00:00<?, ?it/s]

tensor(-39.4454, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-39.4454, device='cuda:0', grad_fn=<SumBackward0>)


  0%|          | 1/39472 [00:01<16:46:25,  1.53s/it]

tensor(-39.4454, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-39.4454, device='cuda:0', grad_fn=<SumBackward0>)


Exception ignored in: <Finalize object, dead>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
KeyboardInterrupt: 
  0%|          | 1/39472 [00:02<24:53:13,  2.27s/it]


KeyboardInterrupt: 

In [152]:
model = WIMP(args)
model.load_state_dict(torch.load("experiments/example/checkpoints/epoch=122.ckpt")['state_dict'])
model.cuda()
# model.eval()


WIMP(
  (encoder): WIMPEncoder(
    (xy_conv_filters): ModuleList(
      (0): Conv1d(2, 512, kernel_size=(1,), stride=(1,))
      (1): Conv1d(2, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (2): Conv1d(2, 512, kernel_size=(5,), stride=(1,), padding=(2,))
    )
    (xy_input_transform): Conv1d(1536, 512, kernel_size=(1,), stride=(1,))
    (non_linearity): ReLU()
    (lstm_input_transform): Linear(in_features=1024, out_features=512, bias=True)
    (lstm): LSTM(512, 512, num_layers=4, batch_first=True)
    (waypoint_predictor): Linear(in_features=512, out_features=2, bias=True)
    (waypoint_lstm): LSTM(512, 512, num_layers=4, batch_first=True)
    (cl_conv_filters): ModuleList(
      (0): Conv1d(2, 512, kernel_size=(1,), stride=(1,))
      (1): Conv1d(2, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (2): Conv1d(2, 512, kernel_size=(5,), stride=(1,), padding=(2,))
    )
    (cl_input_transform): Conv1d(1536, 512, kernel_size=(1,), stride=(1,))
    (leakyrelu): LeakyRe

In [157]:
# model.eval()
preds, waypoint_preds, all_dist_params, att_weights, adjacency = model(**input_dict)
loss, (ade, fde, mr) = model.eval_preds(preds, target_dict, waypoint_preds)
loss

tensor(3.4276, device='cuda:0', grad_fn=<AddBackward0>)

In [125]:
model.encoder.xy_conv_filters[0].weight.grad


tensor([[[ 0.3187],
         [ 0.0102]],

        [[-0.2003],
         [-0.0071]],

        [[ 0.0173],
         [-0.0028]],

        ...,

        [[ 0.0160],
         [-0.0020]],

        [[-0.0888],
         [-0.0026]],

        [[-0.0369],
         [-0.0020]]], device='cuda:0')