In [2]:
# -*- coding: utf-8 -*-
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import copy
import json
import os
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.nn as nn
import torch


from src.data.argoverse_datamodule import ArgoverseDataModule
from src.data.argoverse_dataset import ArgoverseDataset
from src.data.dummy_datamodule import DummyDataModule
from src.models.WIMP import WIMP
import math

import torch.backends.cudnn as cudnn
import random
# Expose only physical GPU 1 to this process; must happen before CUDA is first initialized.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [71]:
# Hyper-parameters of the trained WIMP checkpoint; every key becomes a `--key` CLI option below.
args = {"IFC":True, "add_centerline":False, "attention_heads":4, "batch_norm":False, "batch_size":100, "check_val_every_n_epoch":3,
          "dataroot":'./data/LRP_adjacency', "dataset":'argoverse', "distributed_backend":'ddp', "dropout":0.0,
          "early_stop_threshold":5, "experiment_name":'example', "gpus":3, "gradient_clipping":True, "graph_iter":1,
          "hidden_dim":512, "hidden_key_generator":True, "hidden_transform":False, "input_dim":2, "k_value_threshold":10,
          "k_values":[6, 5, 4, 3, 2, 1], "lr":0.0001, "map_features":False, "max_epochs":200, "mode":'train', "model_name":'WIMP',
          "no_heuristic":False, "non_linearity":'relu', "num_layers":4, "num_mixtures":6, "num_nodes":1, "output_conv":True, "output_dim":2,
          "output_prediction":True, "precision":32, "predict_delta":False, "resume_from_checkpoint":None,
          "scheduler_step_size":[60, 90, 120, 150, 180], "seed":None, "segment_CL":False, "segment_CL_Encoder":False,
          "segment_CL_Encoder_Gaussian":False, "segment_CL_Encoder_Gaussian_Prob":False, "segment_CL_Encoder_Prob":True,
          "segment_CL_Gaussian_Prob":False, "segment_CL_Prob":False, "use_centerline_features":True, "use_oracle":False, "waypoint_step":5,
          "weight_decay":0.0, "workers":8, "wta":False, "draw_image" : False, "remove_high_related_score" : False, "maximum_delete_num" : 3,
          "save_json": False, "make_submit_file" : False, "use_hidden_feature" : True, "is_LRP": False, "adjacency_exp" : True}

from argparse import ArgumentParser


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"False"/"1"/"0".

    argparse's ``type=bool`` would turn every non-empty string (including "False") into True.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise ValueError(f"expected a boolean, got {value!r}")


def _argparse_kwargs(default):
    """Build ``add_argument`` keyword arguments that parse a CLI string back into ``default``'s type."""
    if isinstance(default, bool):  # must precede the generic case: bool is a subclass of int
        return {"default": default, "type": _str2bool}
    if isinstance(default, list):
        # One value per token (`--k_values 3 2 1`) instead of list("321") -> ['3', '2', '1'].
        element_type = type(default[0]) if default else str
        return {"default": default, "type": element_type, "nargs": "+"}
    if default is None:
        # type(None) cannot be called with a string; accept the raw string when the option is given.
        return {"default": None}
    return {"default": default, "type": type(default)}


arg_parser = ArgumentParser()
for key, default in args.items():
    arg_parser.add_argument(f"--{key}", **_argparse_kwargs(default))
arg_parser.add_argument("--XAI_lambda", default=0.2, type=float)
arg_parser.add_argument("--name", default="", type=str)

# `parser` holds the parsed Namespace (later cells read `parser.<option>`).
# args=[] ignores the Jupyter kernel's own argv, so only the defaults above are used.
parser = arg_parser.parse_args(args=[])

In [72]:
def get_metric(metric_dict, ade, fde, mr, loss, length):
    """Add one batch's metrics to the running sums in ``metric_dict`` (updated in place).

    Each tensor value is multiplied by the batch size ``length`` and added as a Python float.
    The inputs are presumably batch means, which makes the sums sample-weighted. ``length`` is
    also added to ``metric_dict["length"]`` so ``calc_mean`` can divide by the total count later.
    """
    weighted_terms = (("ade", ade), ("fde", fde), ("mr", mr), ("loss", loss))
    for key, batch_value in weighted_terms:
        metric_dict[key] += (batch_value * length).cpu().item()
    metric_dict["length"] += length

def calc_mean(metric):
    """Convert the running sums built by ``get_metric`` into per-sample averages.

    Args:
        metric: dict with summed "ade", "fde", "mr", "loss" and the sample count "length".

    Returns:
        A new dict with keys "fde", "ade", "loss", "mr" (in that order, for printing).
        If nothing was accumulated (length == 0), every average is NaN instead of
        raising ZeroDivisionError.

    ``metric`` is left unmodified, so calling this twice on the same accumulator gives the
    same answer.
    """
    count = metric["length"]
    keys = ("fde", "ade", "loss", "mr")
    if not count:
        return {key: float("nan") for key in keys}
    return {key: metric[key] / count for key in keys}

# Shared activation modules used by the normalization experiments.
Relu = torch.nn.ReLU()       # NOTE(review): not referenced by any visible cell
softmax = nn.Softmax(dim=2)  # normalizes along dim 2 (last axis if the adjacency is (batch, N, N) — assumed)


def normalize_max1(w):
    """Scale each leading-axis slice ``w[i]`` so that its largest absolute value becomes 1.

    ``w[i]`` is divided by ``max(|w[i]|)``. Slices that are entirely zero are left as zeros
    rather than becoming NaN (0 / 0). ``w`` is modified in place and also returned.
    """
    for i in range(len(w)):
        peak = torch.max(torch.abs(w[i]))
        if peak > 0:  # an all-zero slice has no scale; dividing would produce NaN
            w[i] = w[i] / peak
    return w


def _seed_everything(seed):
    """Seed every RNG the evaluation may touch and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.benchmark = False
    cudnn.deterministic = True


def _move_batch_to_cuda(input_dict, target_dict):
    """Move, in place, the batch tensors used by the forward pass and ``eval_preds`` onto the GPU."""
    for key in ("agent_features", "social_features", "social_label_features",
                "label_adjacency", "num_agent_mask", "adjacency"):
        input_dict[key] = input_dict[key].cuda()
    helpers = input_dict["ifc_helpers"]
    for key in ("agent_oracle_centerline", "agent_oracle_centerline_lengths",
                "social_oracle_centerline", "social_oracle_centerline_lengths"):
        helpers[key] = helpers[key].cuda()
    target_dict["agent_labels"] = target_dict["agent_labels"].cuda()


def Test(dataset, f, f_name=None, model=None, seed=0):
    """Evaluate ``model`` on ``dataset`` after re-weighting every batch's adjacency with ``f``.

    Args:
        dataset: iterable of batches whose first two items are ``(input_dict, target_dict)``
            (e.g. ``val_dataset``).
        f: ``f(adjacency, num_agent_mask) -> adjacency``. This decides how the adjacency matrix
            is normalized; here the adjacency holds LRP relevance values.
        f_name: label printed before the metrics.
        model: model to evaluate; defaults to the notebook-global ``model``.
        seed: value used to seed random / numpy / torch before the run.

    Returns:
        dict of sample-weighted mean "fde", "ade", "loss" and "mr" (also printed).
    """
    if model is None:
        model = globals()["model"]
    _seed_everything(seed)
    # NOTE(review): model.eval() is never called. This is harmless only if WIMP has no
    # train-mode-specific behaviour (dropout=0.0 and batch_norm=False in `args`) — confirm.

    metrics = {"ade": 0.0, "fde": 0.0, "mr": 0.0, "loss": 0.0, "length": 0}
    with torch.no_grad():
        for batch in tqdm(dataset):
            input_dict, target_dict = batch[0], batch[1]
            _move_batch_to_cuda(input_dict, target_dict)

            input_dict["adjacency"] = f(input_dict["adjacency"], input_dict["num_agent_mask"])
            # Keep entry (0, 0) — presumably the focal agent's self-weight — at least 1, whatever f produced.
            input_dict["adjacency"][:, 0, 0] = input_dict["adjacency"][:, 0, 0].clamp(min=1.0)

            preds, waypoint_preds, *_ = model(**input_dict)
            loss, (ade, fde, mr) = model.eval_preds(preds, target_dict, waypoint_preds)
            get_metric(metrics, ade, fde, mr, loss, len(input_dict["adjacency"]))

    result = calc_mean(metrics)
    print(f_name)
    print({"metric": result})
    return result


In [73]:
# Build the train/val datasets and DataLoaders, then load the trained WIMP checkpoint.
# NOTE: `*_loader` names hold Dataset objects and `*_dataset` names hold DataLoaders;
# these names are kept because later cells use `val_dataset`.
dataset_kwargs = dict(
    delta=parser.predict_delta,
    map_features_flag=parser.map_features,
    social_features_flag=True,
    heuristic=(not parser.no_heuristic),
    ifc=parser.IFC,
    is_oracle=parser.use_oracle,
)
train_loader = ArgoverseDataset(parser.dataroot, mode='train', **dataset_kwargs)
val_loader = ArgoverseDataset(parser.dataroot, mode='val', **dataset_kwargs)

loader_kwargs = dict(batch_size=parser.batch_size, num_workers=parser.workers,
                     pin_memory=True, collate_fn=ArgoverseDataset.collate)
train_dataset = DataLoader(train_loader, shuffle=True, drop_last=True, **loader_kwargs)
val_dataset = DataLoader(val_loader, shuffle=False, drop_last=False, **loader_kwargs)

CHECKPOINT_PATH = "experiments/example_old/checkpoints/epoch=122.ckpt"
model = WIMP(parser)
# Load onto the CPU regardless of which GPU the checkpoint was saved from; model.cuda() moves it below.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
# strict=False because the graph module's network for `p` did not exist when the checkpoint was trained.
# Report what was skipped so that loading the wrong checkpoint does not go unnoticed.
load_report = model.load_state_dict(checkpoint['state_dict'], strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"missing keys: {load_report.missing_keys}")
    print(f"unexpected keys: {load_report.unexpected_keys}")

model = model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)  # NOTE(review): unused by the visible cells


In [74]:
# Softmax along dim 2; the lambdas below look this global up each time they are called.
softmax = torch.nn.Softmax(dim=2)

# Candidate adjacency re-weightings f(w, mask) -> w' (the mask is ignored).
# Each applies a softmax to |w| along dim 2, at a temperature set by the multiplier; f1/f3/f5
# first rescale every sample with normalize_max1. The result is multiplied by len(w[0]), so each
# row averages to 1 if the adjacency is a square (batch, N, N) — assumed.
# NOTE(review): padded (masked-out) agent slots also receive softmax weight here.
f1 = lambda w, _: softmax(normalize_max1(abs(w))) * len(w[0])
f3 = lambda w, _: softmax(normalize_max1(abs(w)) * 2) * len(w[0])
f5 = lambda w, _: softmax(normalize_max1(abs(w)) * 0.5) * len(w[0])

f10 = lambda w, _: softmax(abs(w)) * len(w[0])
f11 = lambda w, _: softmax(abs(w) * 0.5) * len(w[0])
f12 = lambda w, _: softmax(abs(w) * 2) * len(w[0])

# Evaluate each variant under a label that matches its own formula.
candidates = {
    "softmax(normalize_max1(abs(w))*2) * len(w[0])": f3,
    "softmax(normalize_max1(abs(w))*0.5) * len(w[0])": f5,
    "softmax(normalize_max1(abs(w))) * len(w[0])": f1,
}
for label, fn in candidates.items():
    Test(val_dataset, fn, label)

  6%|▋         | 25/395 [00:07<01:55,  3.19it/s]


KeyboardInterrupt: 

In [58]:
# Re-run only the two temperature variants, each under its own label.
for label, fn in (("softmax(normalize_max1(abs(w))*2) * len(w[0])", f3),
                  ("softmax(normalize_max1(abs(w))*0.5) * len(w[0])", f5)):
    Test(val_dataset, fn, label)

  2%|▏         | 6/395 [00:02<03:09,  2.05it/s]


KeyboardInterrupt: 

In [61]:
def f1(w, mask):
    """Re-weight row 0 of each adjacency with softmax(0.25 * |w|) over its first ``num_agent`` entries.

    For every batch element ``i``:
      * ``num_agent = sum(mask[i]) + 1``;
      * ``w[i, 0, :num_agent]`` becomes ``softmax(0.25 * |w[i, 0, :num_agent]|) * num_agent``
        (mean weight 1 over those entries);
      * the rest of row 0 keeps its ``|w|`` value;
      * every other row is set to 1.

    Returns a new tensor; the input ``w`` is not modified because ``abs`` makes a copy.
    ``torch.softmax`` is called directly, so the module-level ``softmax`` (dim=2) used by the
    lambda variants is left untouched.

    NOTE(review): other cells count agents as ``sum(mask[i])`` without the ``+ 1``. In the printed
    batch-0 example, the extra slot holds a raw value of 0 (padding?). Confirm which count is right.
    """
    w = abs(w)
    for i in range(len(w)):
        num_agent = int(torch.sum(mask[i])) + 1
        w[i, 0, :num_agent] = torch.softmax(w[i, 0, :num_agent] * 0.25, dim=0) * num_agent
        w[i, 1:, :] = 1
    return w
    
# Evaluate f1 under a label describing what it actually computes.
Test(val_dataset, f1, "softmax(abs(w[i,0,:num_agent])*0.25) * num_agent, other rows = 1")

100%|██████████| 395/395 [01:45<00:00,  3.74it/s]

softmax(normalize_max1(abs(w))) * len(w[0])
{'metric': {'fde': 1.1460760357798379, 'ade': 0.7543620654432021, 'loss': 6.051380298956232, 'mr': 0.11765301784475564}}





In [70]:
# No `f1` is bound here. A one-argument lambda cannot be passed to Test, which calls
# f(adjacency, mask), and it would shadow the two-argument `f1` used by the inspection cells.
def to_gaussian(arr, mean=0, std=1):
    """Standardize ``arr`` (zero mean, unit sample std), then rescale it to ``std`` and shift it to ``mean``.

    The 1e-5 added to the std avoids division by zero for constant inputs. With fewer than two
    elements the sample std is undefined (torch.std returns NaN, which would turn the row into NaN).
    In that case the spread is treated as 0, so every element maps to ``mean``.
    """
    spread = torch.std(arr) if arr.numel() > 1 else arr.new_zeros(())
    return ((arr - torch.mean(arr)) / (spread + 0.00001)) * std + mean


softmax0 = torch.nn.Softmax(dim=0)  # softmax over a 1-D row slice


def f2(w, mask=None):
    """Re-weight row 0 of each adjacency with a softmax of its standardized |relevance|.

    For every batch element ``i``, the first ``num_agent = sum(mask[i]) + 1`` entries of row 0 become
    ``(softmax0(to_gaussian(|w| * num_agent, mean=1, std=0.12)) * num_agent) ** 1.5``. The rest of
    ``w`` is unchanged. ``w`` is modified in place and returned.

    ``mask`` is required; the ``None`` default only keeps the historical signature.
    """
    if mask is None:
        raise ValueError("f2 needs num_agent_mask to know how many agents each row holds")
    for i in range(len(w)):
        num_agent = int(torch.sum(mask[i])) + 1
        row = abs(w[i, 0, :num_agent]) * num_agent
        w[i, 0, :num_agent] = (softmax0(to_gaussian(row, mean=1, std=0.12)) * num_agent) ** 1.5
    return w


Test(val_dataset, f2, "(softmax0(to_gaussian(abs(w[i,0,:num_agent])*num_agent, mean=1, std=0.12))*num_agent)**1.5")

100%|██████████| 395/395 [01:47<00:00,  3.66it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.144936785796513, 'ade': 0.7533492711472715, 'loss': 6.033516781408389, 'mr': 0.11674097837539738}}





In [78]:
# No `f1` is bound here. A one-argument lambda cannot be passed to Test, which calls
# f(adjacency, mask), and it would shadow the two-argument `f1` used by the inspection cells.
def to_gaussian(arr, mean=0, std=1):
    """Standardize ``arr`` (zero mean, unit sample std), then rescale it to ``std`` and shift it to ``mean``.

    The 1e-5 added to the std avoids division by zero for constant inputs. With fewer than two
    elements the sample std is undefined (torch.std returns NaN, which would turn the row into NaN).
    In that case the spread is treated as 0, so every element maps to ``mean``.
    """
    spread = torch.std(arr) if arr.numel() > 1 else arr.new_zeros(())
    return ((arr - torch.mean(arr)) / (spread + 0.00001)) * std + mean


softmax0 = torch.nn.Softmax(dim=0)  # softmax over a 1-D row slice


def _make_f2(std_value):
    """Build f(w, mask) that sets row 0 to softmax0(to_gaussian(|w| * n, mean=1, std=std_value)) * n.

    ``n = sum(mask[i]) + 1`` per batch element. ``std_value`` is bound at construction time,
    not looked up from a loop variable. The returned function modifies ``w`` in place.
    """
    def f2(w, mask=None):
        for i in range(len(w)):
            num_agent = int(torch.sum(mask[i])) + 1
            row = abs(w[i, 0, :num_agent]) * num_agent
            w[i, 0, :num_agent] = softmax0(to_gaussian(row, mean=1, std=std_value)) * num_agent
        return w
    return f2


# Sweep the target std from 0.10 to 0.19, labelling each run with the std actually used.
for std in range(10, 20):
    f2 = _make_f2(std / 100)
    Test(val_dataset, f2, f"softmax0(to_gaussian(abs(w[i,0,:num_agent])*num_agent, mean=1, std={std / 100}))*num_agent")


  0%|          | 0/395 [00:00<?, ?it/s]

0.1


100%|██████████| 395/395 [01:47<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1439233589713564, 'ade': 0.7533491165182781, 'loss': 6.0448233631877795, 'mr': 0.1165636378258503}}
0.11


100%|██████████| 395/395 [01:47<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1439675648878305, 'ade': 0.7531639085784343, 'loss': 6.044360651388118, 'mr': 0.11610761848982404}}
0.12


100%|██████████| 395/395 [01:48<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1439954319317234, 'ade': 0.7534587692528251, 'loss': 6.043938089006576, 'mr': 0.11593027774699072}}
0.13


100%|██████████| 395/395 [01:47<00:00,  3.68it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1441172041968386, 'ade': 0.7532955338996482, 'loss': 6.043616700375375, 'mr': 0.11610761855022599}}
0.14


100%|██████████| 395/395 [01:48<00:00,  3.64it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1441735618260402, 'ade': 0.7535290388525572, 'loss': 6.0433739228049514, 'mr': 0.11610761847774365}}
0.15


100%|██████████| 395/395 [01:48<00:00,  3.64it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.144359085692246, 'ade': 0.7535849920302979, 'loss': 6.043218140279324, 'mr': 0.11618362159974477}}
0.16


100%|██████████| 395/395 [01:48<00:00,  3.65it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1446375454491722, 'ade': 0.7538069694883968, 'loss': 6.043156997655924, 'mr': 0.11610761835693975}}
0.17


100%|██████████| 395/395 [01:47<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1448083592659557, 'ade': 0.7537462764487945, 'loss': 6.043074806160805, 'mr': 0.11620895586904843}}
0.18


100%|██████████| 395/395 [01:48<00:00,  3.64it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1453997102110407, 'ade': 0.7541039863674565, 'loss': 6.043130600167326, 'mr': 0.11623429047660302}}
0.19


100%|██████████| 395/395 [01:48<00:00,  3.64it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1459321774690316, 'ade': 0.7545272774381293, 'loss': 6.043216008718651, 'mr': 0.11653830341158196}}





In [35]:
# Inspect f1 on the first validation batch: row 0 of sample 0's raw adjacency, then f1's output.
first_batch = next(iter(val_dataset))
input_dict, target_dict = first_batch[0], first_batch[1]
print(input_dict['adjacency'][0, 0])
print(f1(input_dict['adjacency'], input_dict['num_agent_mask'])[0, 0])


tensor([ 0.0421,  0.0122,  0.0135, -0.0362,  0.0171, -0.0192, -0.0164,  0.0150,
         0.0025,  0.0181, -0.0038, -0.0078,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000])
tensor([1.1896, 1.0306, 1.0422, 1.1675, 1.0692, 1.0836, 1.0642, 1.0540, 0.8833,
        1.0761, 0.9177, 0.9842, 0.4376, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000])


In [24]:
# No `f1` is bound here. A one-argument lambda cannot be passed to Test, which calls
# f(adjacency, mask), and it would shadow the two-argument `f1` used by the inspection cells.
def to_gaussian(arr, mean=0, std=1):
    """Standardize ``arr`` (zero mean, unit sample std), then rescale it to ``std`` and shift it to ``mean``.

    The 1e-5 added to the std avoids division by zero for constant inputs. With fewer than two
    elements the sample std is undefined (torch.std returns NaN, which would turn the row into NaN).
    In that case the spread is treated as 0, so every element maps to ``mean``.
    """
    spread = torch.std(arr) if arr.numel() > 1 else arr.new_zeros(())
    return ((arr - torch.mean(arr)) / (spread + 0.00001)) * std + mean


def f2(w, mask=None):
    """Replace the first ``sum(mask[i])`` entries of row 0 with their standardized |relevance|.

    For every batch element ``i``, those entries become ``to_gaussian(|w|, mean=1, std=0.1)``.
    ``w`` is modified in place and returned. ``mask`` is required; the ``None`` default only keeps
    the historical signature.

    NOTE(review): other variants count ``sum(mask[i]) + 1`` agents — confirm which is intended.
    """
    if mask is None:
        raise ValueError("f2 needs num_agent_mask to know how many agents each row holds")
    for i in range(len(w)):
        num_agent = int(torch.sum(mask[i]))
        w[i, 0, :num_agent] = to_gaussian(abs(w[i, 0, :num_agent]), mean=1, std=0.1)
    return w


Test(val_dataset, f2, "to_gaussian(abs(w[i,0,:num_agent]), mean=1, std=0.1)")

100%|██████████| 395/395 [01:46<00:00,  3.70it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1442772898540643, 'ade': 0.7533639042687696, 'loss': 6.041107798569783, 'mr': 0.11671564417857606}}





In [37]:
# No `f1` is bound here. A one-argument lambda cannot be passed to Test, which calls
# f(adjacency, mask), and it would shadow the two-argument `f1` used by the inspection cells.
def to_gaussian(arr, mean=0, std=1):
    """Standardize ``arr`` (zero mean, unit sample std), then rescale it to ``std`` and shift it to ``mean``.

    The 1e-5 added to the std avoids division by zero for constant inputs. With fewer than two
    elements the sample std is undefined (torch.std returns NaN, which would turn the row into NaN).
    In that case the spread is treated as 0, so every element maps to ``mean``.
    """
    spread = torch.std(arr) if arr.numel() > 1 else arr.new_zeros(())
    return ((arr - torch.mean(arr)) / (spread + 0.00001)) * std + mean


softmax0 = torch.nn.Softmax(dim=0)  # softmax over a 1-D row slice


def f2(w, mask=None):
    """Re-weight row 0 with a softmax of |relevance|, then standardize the weights to mean 1, std 0.12.

    For every batch element ``i``, the first ``num_agent = sum(mask[i]) + 1`` entries of row 0 become
    ``to_gaussian(softmax0(|w|) * num_agent, mean=1, std=0.12)``. ``w`` is modified in place and
    returned. ``mask`` is required; the ``None`` default only keeps the historical signature.
    """
    if mask is None:
        raise ValueError("f2 needs num_agent_mask to know how many agents each row holds")
    for i in range(len(w)):
        num_agent = int(torch.sum(mask[i])) + 1
        weights = softmax0(abs(w[i, 0, :num_agent])) * num_agent
        w[i, 0, :num_agent] = to_gaussian(weights, mean=1, std=0.12)
    return w


Test(val_dataset, f2, "to_gaussian(softmax0(abs(w[i,0,:num_agent]))*num_agent, mean=1, std=0.12)")

100%|██████████| 395/395 [01:45<00:00,  3.73it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.144170692878355, 'ade': 0.7532816630990872, 'loss': 6.039364519803496, 'mr': 0.11653830343574273}}





In [None]:
# No `f1` is bound here. A one-argument lambda cannot be passed to Test, which calls
# f(adjacency, mask), and it would shadow the two-argument `f1` used by the inspection cells.
def to_gaussian(arr, mean=0, std=1):
    """Standardize ``arr`` (zero mean, unit sample std), then rescale it to ``std`` and shift it to ``mean``.

    The 1e-5 added to the std avoids division by zero for constant inputs. With fewer than two
    elements the sample std is undefined (torch.std returns NaN, which would turn the row into NaN).
    In that case the spread is treated as 0, so every element maps to ``mean``.
    """
    spread = torch.std(arr) if arr.numel() > 1 else arr.new_zeros(())
    return ((arr - torch.mean(arr)) / (spread + 0.00001)) * std + mean


# Sweep the target std from 0.01 to 0.99 in steps of 0.01: one full validation pass per value.
for std in range(1, 100):
    target_std = std / 100

    def f2(w, mask = None, mean = 1, std = target_std):
        """Set the first sum(mask[i]) entries of row 0 to to_gaussian(|entries|, mean, std), in place."""
        for row_idx in range(len(w)):
            count = int(torch.sum(mask[row_idx]))
            segment = w[row_idx, 0, :count]
            w[row_idx, 0, :count] = to_gaussian(abs(segment), mean = mean, std = std)
        return w

    Test(val_dataset, f2, f"to_gaussian(w[0,0, :num_agent], mean = 1, std = {target_std})")

100%|██████████| 395/395 [01:44<00:00,  3.76it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.01)
{'metric': {'fde': 1.1459254470488105, 'ade': 0.7541138015394516, 'loss': 6.050708923834657, 'mr': 0.11777968962616799}}


100%|██████████| 395/395 [01:45<00:00,  3.75it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.02)
{'metric': {'fde': 1.1457067941140684, 'ade': 0.7540020617000671, 'loss': 6.049396532674912, 'mr': 0.11793169630506428}}


100%|██████████| 395/395 [01:45<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.03)
{'metric': {'fde': 1.1456022392075025, 'ade': 0.7538116609320545, 'loss': 6.048189703636339, 'mr': 0.11767835200533579}}


100%|██████████| 395/395 [01:47<00:00,  3.67it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.04)
{'metric': {'fde': 1.1451893668451143, 'ade': 0.7536560880253563, 'loss': 6.047007648481935, 'mr': 0.11734900463192773}}


100%|██████████| 395/395 [01:45<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.05)
{'metric': {'fde': 1.1451274054611942, 'ade': 0.7534275749795679, 'loss': 6.045887819127027, 'mr': 0.11691831969808943}}


100%|██████████| 395/395 [01:45<00:00,  3.76it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.06)
{'metric': {'fde': 1.1448225669644359, 'ade': 0.75341788235105, 'loss': 6.044649250941283, 'mr': 0.1164623003137416}}


100%|██████████| 395/395 [01:45<00:00,  3.75it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.07)
{'metric': {'fde': 1.144603864548416, 'ade': 0.7533016757632636, 'loss': 6.0436585275180725, 'mr': 0.11648763476425113}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.08)
{'metric': {'fde': 1.1445466361646957, 'ade': 0.7530758408916398, 'loss': 6.042845696247102, 'mr': 0.11658897226427942}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.09)
{'metric': {'fde': 1.1444995002220844, 'ade': 0.7534844160176599, 'loss': 6.04200772771383, 'mr': 0.11674097867740713}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.1)
{'metric': {'fde': 1.1442772898540643, 'ade': 0.7533639042687696, 'loss': 6.041107798569783, 'mr': 0.11671564417857606}}


100%|██████████| 395/395 [01:49<00:00,  3.62it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.11)
{'metric': {'fde': 1.1441264173788341, 'ade': 0.7533306694726593, 'loss': 6.040248816729462, 'mr': 0.11646230003589264}}


100%|██████████| 395/395 [01:45<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.12)
{'metric': {'fde': 1.1441575014789336, 'ade': 0.7533443949219666, 'loss': 6.0394837805071875, 'mr': 0.11651296903355478}}


100%|██████████| 395/395 [01:45<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.13)
{'metric': {'fde': 1.144088762319663, 'ade': 0.7532748039502289, 'loss': 6.03876812118821, 'mr': 0.116487634667608}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.14)
{'metric': {'fde': 1.144130922494555, 'ade': 0.7532348453056296, 'loss': 6.038126864701956, 'mr': 0.11651296886442931}}


100%|██████████| 395/395 [01:46<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.15)
{'metric': {'fde': 1.1441213430351316, 'ade': 0.7533688790700578, 'loss': 6.037506592143244, 'mr': 0.11679164749386342}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.16)
{'metric': {'fde': 1.144252773620457, 'ade': 0.7534449949987417, 'loss': 6.037009073357182, 'mr': 0.11684231635864128}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.17)
{'metric': {'fde': 1.1441901883085488, 'ade': 0.7533014768717209, 'loss': 6.036437188076209, 'mr': 0.11671564414233489}}


100%|██████████| 395/395 [01:46<00:00,  3.71it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.18)
{'metric': {'fde': 1.1442964145612466, 'ade': 0.7531704510277856, 'loss': 6.035872099512639, 'mr': 0.11651296855033917}}


100%|██████████| 395/395 [01:45<00:00,  3.75it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.19)
{'metric': {'fde': 1.1444113661414659, 'ade': 0.7534620741576282, 'loss': 6.035370227920414, 'mr': 0.11658897186562654}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.2)
{'metric': {'fde': 1.1445597850411442, 'ade': 0.7534030736290012, 'loss': 6.034875516423726, 'mr': 0.11661430627989489}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.21)
{'metric': {'fde': 1.1446437362142068, 'ade': 0.7535020423698966, 'loss': 6.034370739352718, 'mr': 0.11663964081496712}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.22)
{'metric': {'fde': 1.14472300270863, 'ade': 0.7533952301699574, 'loss': 6.033933597003394, 'mr': 0.11641163106239127}}


100%|██████████| 395/395 [01:45<00:00,  3.75it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.23)
{'metric': {'fde': 1.1448546774144785, 'ade': 0.7537027883974221, 'loss': 6.033498225929189, 'mr': 0.11646229992716912}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.24)
{'metric': {'fde': 1.145087909350571, 'ade': 0.7539708555397061, 'loss': 6.0330929959115105, 'mr': 0.11658897197435006}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.25)
{'metric': {'fde': 1.1451399837566572, 'ade': 0.7540527873547589, 'loss': 6.0326802563367705, 'mr': 0.11681698179940826}}


100%|██████████| 395/395 [01:46<00:00,  3.72it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.26)
{'metric': {'fde': 1.1454268435403603, 'ade': 0.7537265822237013, 'loss': 6.032376155227024, 'mr': 0.11666497537420015}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.27)
{'metric': {'fde': 1.145597511426031, 'ade': 0.7541485957689599, 'loss': 6.032023083888233, 'mr': 0.11676631323664015}}


100%|██████████| 395/395 [01:47<00:00,  3.67it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.28)
{'metric': {'fde': 1.1456930273015054, 'ade': 0.7540821199915793, 'loss': 6.0316335572179725, 'mr': 0.11681698216181996}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.29)
{'metric': {'fde': 1.1458623269138715, 'ade': 0.7542104263947375, 'loss': 6.0313460802713, 'mr': 0.11679164774755162}}


100%|██████████| 395/395 [01:47<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.3)
{'metric': {'fde': 1.1459956975132046, 'ade': 0.75433028976493, 'loss': 6.031032692144214, 'mr': 0.11691831981889332}}


100%|██████████| 395/395 [01:47<00:00,  3.67it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.31)
{'metric': {'fde': 1.1462657442931758, 'ade': 0.7544094501452066, 'loss': 6.030798133104786, 'mr': 0.11684231647944518}}


100%|██████████| 395/395 [01:46<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.32)
{'metric': {'fde': 1.1464330615618474, 'ade': 0.7547402061041811, 'loss': 6.030586867376915, 'mr': 0.11679164746970265}}


100%|██████████| 395/395 [01:46<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.33)
{'metric': {'fde': 1.146663884764796, 'ade': 0.7550356597807614, 'loss': 6.030336409384391, 'mr': 0.11658897211931474}}


100%|██████████| 395/395 [01:47<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.34)
{'metric': {'fde': 1.1469550229394585, 'ade': 0.7551113650728248, 'loss': 6.030210124658678, 'mr': 0.11686765078499001}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.35)
{'metric': {'fde': 1.1471590914608352, 'ade': 0.7551838422913661, 'loss': 6.030043539751576, 'mr': 0.1170703260387348}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.36)
{'metric': {'fde': 1.1472761432893566, 'ade': 0.7549211565279796, 'loss': 6.029829070110963, 'mr': 0.11704499170902918}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.37)
{'metric': {'fde': 1.1474007713586925, 'ade': 0.7551624068471648, 'loss': 6.029716509481553, 'mr': 0.11701965734308241}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.38)
{'metric': {'fde': 1.1475719804458402, 'ade': 0.7552700427364231, 'loss': 6.029593139512364, 'mr': 0.11701965750012748}}


100%|██████████| 395/395 [01:46<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.39)
{'metric': {'fde': 1.1476701835799323, 'ade': 0.7553925953947391, 'loss': 6.029489486286115, 'mr': 0.11709566080333446}}


100%|██████████| 395/395 [01:46<00:00,  3.71it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.4)
{'metric': {'fde': 1.1478208964007268, 'ade': 0.7554983221957884, 'loss': 6.029416961422515, 'mr': 0.11714632959562997}}


100%|██████████| 395/395 [01:46<00:00,  3.72it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.41)
{'metric': {'fde': 1.1479590794409051, 'ade': 0.755613097305669, 'loss': 6.029340308737939, 'mr': 0.11686765113532133}}


100%|██████████| 395/395 [01:46<00:00,  3.71it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.42)
{'metric': {'fde': 1.1481327086650848, 'ade': 0.7555645229268471, 'loss': 6.029295758578668, 'mr': 0.11696898869575158}}


100%|██████████| 395/395 [01:45<00:00,  3.74it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.43)
{'metric': {'fde': 1.148380135929338, 'ade': 0.7555524118042302, 'loss': 6.029287154248334, 'mr': 0.11704499195063699}}


100%|██████████| 395/395 [01:45<00:00,  3.73it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.44)
{'metric': {'fde': 1.1486063393978247, 'ade': 0.7556602633042909, 'loss': 6.029251652204609, 'mr': 0.11732367064047304}}


100%|██████████| 395/395 [01:46<00:00,  3.71it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.45)
{'metric': {'fde': 1.1488476470263806, 'ade': 0.7557419579571251, 'loss': 6.029225986884295, 'mr': 0.11737433934820582}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.46)
{'metric': {'fde': 1.1490765485775147, 'ade': 0.7559737992953519, 'loss': 6.02920055350747, 'mr': 0.11750101139538675}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.47)
{'metric': {'fde': 1.1492937642357637, 'ade': 0.7558581100381914, 'loss': 6.029223153307992, 'mr': 0.11760234899205818}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.48)
{'metric': {'fde': 1.1494637401566894, 'ade': 0.7556886275770406, 'loss': 6.029254688344893, 'mr': 0.1177290210392391}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.49)
{'metric': {'fde': 1.1495970873683872, 'ade': 0.755743377451284, 'loss': 6.029310434416426, 'mr': 0.11775435542934666}}


100%|██████████| 395/395 [01:47<00:00,  3.68it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.5)
{'metric': {'fde': 1.149794403241625, 'ade': 0.7559191827886154, 'loss': 6.029321877735078, 'mr': 0.11798236519400292}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.51)
{'metric': {'fde': 1.150099993428577, 'ade': 0.7562049930280261, 'loss': 6.029386993550462, 'mr': 0.11815970609388131}}


100%|██████████| 395/395 [01:46<00:00,  3.72it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.52)
{'metric': {'fde': 1.1502090472657043, 'ade': 0.7560954280454328, 'loss': 6.029428294954582, 'mr': 0.11808370271819199}}


100%|██████████| 395/395 [01:46<00:00,  3.71it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.53)
{'metric': {'fde': 1.1504347381390778, 'ade': 0.7560833441761761, 'loss': 6.029428289542567, 'mr': 0.11783035862383014}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.54)
{'metric': {'fde': 1.1506303737750168, 'ade': 0.7560422858295046, 'loss': 6.02948346271368, 'mr': 0.11815970592475585}}


100%|██████████| 395/395 [01:46<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.55)
{'metric': {'fde': 1.1508958054593956, 'ade': 0.7562428328688443, 'loss': 6.029565576894302, 'mr': 0.11823570922796282}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.56)
{'metric': {'fde': 1.1510936084139622, 'ade': 0.7564512700458553, 'loss': 6.029626719517702, 'mr': 0.11828637812898185}}


100%|██████████| 395/395 [01:46<00:00,  3.70it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.57)
{'metric': {'fde': 1.1513573579281797, 'ade': 0.756648798630602, 'loss': 6.0297407120115905, 'mr': 0.11826104378719585}}


100%|██████████| 395/395 [01:47<00:00,  3.69it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.58)
{'metric': {'fde': 1.151510543459864, 'ade': 0.7568411894737587, 'loss': 6.029837495844809, 'mr': 0.11846371908926219}}


100%|██████████| 395/395 [01:47<00:00,  3.68it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.59)
{'metric': {'fde': 1.1517590977761927, 'ade': 0.7569591278096274, 'loss': 6.029950408255179, 'mr': 0.11856505661345128}}


100%|██████████| 395/395 [01:47<00:00,  3.66it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.6)
{'metric': {'fde': 1.1519634454029022, 'ade': 0.75705187920863, 'loss': 6.030067751559353, 'mr': 0.11861572541782718}}


100%|██████████| 395/395 [01:47<00:00,  3.67it/s]
  0%|          | 0/395 [00:00<?, ?it/s]

to_gaussian(w[0,0, :num_agent], mean = 1, std = 0.61)
{'metric': {'fde': 1.1523433066013615, 'ade': 0.7573017387869486, 'loss': 6.030222504255884, 'mr': 0.11886906953634982}}


 87%|████████▋ | 345/395 [01:33<00:13,  3.84it/s]

In [16]:
# Inspect f1 on validation batch 11 (the twelfth batch, or the last one if there are fewer):
# row 0 of sample 0's raw adjacency, then f1's output.
INSPECT_BATCH = 11
for batch_index, batch in enumerate(val_dataset):
    input_dict, target_dict = batch[0], batch[1]
    if batch_index >= INSPECT_BATCH:
        break
print(input_dict['adjacency'][0, 0])
print(f1(input_dict['adjacency'], input_dict['num_agent_mask'])[0, 0])

tensor([-0.0147, -0.0019,  0.0052, -0.0009,  0.0026, -0.0006,  0.0020, -0.0007,
         0.0001, -0.0009, -0.0002, -0.0030,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000])
tensor([2.1649, 0.9039, 1.1309, 0.8452, 0.9480, 0.8276, 0.9109, 0.8326, 0.8028,
        0.8466, 0.8075, 0.9792, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])


In [20]:
# Keep validation batch 11 (the twelfth batch, or the last one if there are fewer) for the cells below.
INSPECT_BATCH = 11
for batch_index, batch in enumerate(val_dataset):
    input_dict, target_dict = batch[0], batch[1]
    if batch_index >= INSPECT_BATCH:
        break


In [44]:
def f2(w, mask):
    """Standardize the first ``sum(mask[i])`` entries of row 0 of every sample to mean 1 with a ~0 std.

    With std=1e-6 every valid entry becomes ~1, i.e. uniform weighting. ``w`` is modified in place
    and returned. Each sample ``i`` uses its own mask and row.
    """
    for i in range(len(w)):
        num_agent = int(torch.sum(mask[i]))
        w[i, 0, :num_agent] = to_gaussian(w[i, 0, :num_agent], mean=1, std=0.000001)
    return w


# Work on a copy so the inspected batch in `input_dict` is not altered for other cells.
f2(input_dict['adjacency'].clone(), input_dict['num_agent_mask'])[0][0]

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000])