In [1]:
import torch
import random
import numpy as np
from itertools import chain
import torch.nn as nn
import itertools as it
import os, shutil
import pandas as pd
from pandas.core.common import flatten
from scipy.interpolate import make_interp_spline, interp1d
import matplotlib.pyplot as plt
from IPython.display import clear_output
import math
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
%cd GraphSSL

/home/shihan/GAT/GraphSSL


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Example binary classification results
y_true = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]  # True labels
y_pred = [0, 0, 0, 1, 0, 1, 1, 1, 0, 1]  # Predicted labels

# Compute the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Display the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["safe", "dangerous"])
disp.plot(cmap=plt.cm.Blues)
plt.show()

### Customized Dataset

In [24]:
random.seed(42)
t_threshold = 0.5  # in seconds
min_d_threshold = 10  # in meters
# t_prediction = 15 # in frames
sequence_dist = 40 # in frames
lead_time = 15 # in frames

In [25]:
info_dir = '../csv_final/info/'   # directory of 'info' folder
all_info = []
for filename in os.listdir(info_dir):
    if filename.endswith('txt'):
        all_info.append(filename)
pedestrian = []
rear_end = []
switch_lane = [] # switch_lane
t_bone = []
opposite_front = []
opposite_merge = []
merging = []
other = []
for info in all_info:
    txt_file = open(info_dir+info, "r").read().split('\n') # split txt file with lines
    first_cls = txt_file[2].split(': ')[-1]
    second_cls = txt_file[9].split(': ')[-1]
    first_fault = txt_file[7].split(' ')[-1]
    second_fault = txt_file[14].split(' ')[-1]
    first_direction = txt_file[5].split(' ')[-1]
    second_direction = txt_file[12].split(' ')[-1]
    first_collision_side = txt_file[6].split(' ')[-1]
    second_collision_side = txt_file[13].split(' ')[-1]
    second_relative_direction_to_first = txt_file[15].split(' ')[-1]
    if first_cls == 'pedestrian' or second_cls == 'pedestrian':
        pedestrian.append(info)
    elif first_fault == 'rear_end' or second_fault == 'rear_end':
        rear_end.append(info)
    elif first_fault == 'switch_lane' or second_fault == 'switch_lane':
        switch_lane.append(info)
    elif first_direction == 'straight' and second_direction == 'straight':
        t_bone.append(info)
    elif second_relative_direction_to_first == 'opposite' and 'straight' in [first_direction, second_direction]:
        opposite_front.append(info)
    elif second_relative_direction_to_first == 'opposite':
        opposite_merge.append(info)
    elif first_fault == 'entrance_ramp' or second_fault == 'entrance_ramp':
        merging.append(info)
    elif second_relative_direction_to_first=='right' and first_direction=='left' and second_direction=='straight':
        merging.append(info)
    elif second_relative_direction_to_first=='left' and first_direction=='straight' and second_direction=='left':
        merging.append(info)
    elif second_relative_direction_to_first=='right' and first_direction=='straight' and second_direction=='right':
        merging.append(info)
    elif second_relative_direction_to_first=='left' and first_direction=='right' and second_direction=='straight':
        merging.append(info)
    else:
        other.append(info)
random.shuffle(pedestrian)
random.shuffle(rear_end)
random.shuffle(switch_lane)
random.shuffle(t_bone)
random.shuffle(opposite_front)
random.shuffle(opposite_merge)
random.shuffle(merging)
random.shuffle(other)
acc_all = {'pedestrian':len(pedestrian), 'rear_end':len(rear_end),
        'switch_lane':len(switch_lane), 't_bone':len(t_bone),
        'opposite_front':len(opposite_front), 'opposite_merge': len(opposite_merge),
        'merging':len(merging), 'other':len(other),
        'total': len(pedestrian+rear_end+switch_lane+t_bone+opposite_front+opposite_merge+merging+other)}
print(acc_all)
acc_train = {'pedestrian':math.ceil(len(pedestrian)*0.8), 'rear_end':math.ceil(len(rear_end)*0.8),
        'switch_lane':math.ceil(len(switch_lane)*0.8), 't_bone':math.ceil(len(t_bone)*0.8),
        'opposite_front':math.ceil(len(opposite_front)*0.8), 'opposite_merge': math.ceil(len(opposite_merge)*0.8),
        'merging':math.ceil(len(merging)*0.8), 'other':math.ceil(len(other)*0.8)}
acc_test = {'pedestrian':int(len(pedestrian)/5), 'rear_end':int(len(rear_end)/5),
        'switch_lane':int(len(switch_lane)/5), 't_bone':int(len(t_bone)/5),
        'opposite_front':int(len(opposite_front)/5), 'opposite_merge': int(len(opposite_merge)/5),
        'merging':int(len(merging)/5), 'other':int(len(other)/5)
        }
print('training dataset:')
print(acc_train)
print('testing dataset:')
print(acc_test)

{'pedestrian': 8, 'rear_end': 19, 'switch_lane': 13, 't_bone': 31, 'opposite_front': 22, 'opposite_merge': 19, 'merging': 40, 'other': 44, 'total': 196}
training dataset:
{'pedestrian': 7, 'rear_end': 16, 'switch_lane': 11, 't_bone': 25, 'opposite_front': 18, 'opposite_merge': 16, 'merging': 32, 'other': 36}
testing dataset:
{'pedestrian': 1, 'rear_end': 3, 'switch_lane': 2, 't_bone': 6, 'opposite_front': 4, 'opposite_merge': 3, 'merging': 8, 'other': 8}


In [26]:
accident_dir = '../csv_final/accident/'
normal_dir = '../csv_final/normal/'
train_scene_list = []
test_scene_list = []
accident_names = ['pedestrian','rear_end','switch_lane','t_bone','opposite_front','opposite_merge','merging','other']
all_names = accident_names+['no_accident']
major_acc_num = 36
accident_type_all = [pedestrian,rear_end,switch_lane,t_bone,opposite_front,opposite_merge,merging,other]
balanced_all_acc_train = []
for accident_type in accident_type_all:
    random.shuffle(accident_type)
    test_scene_list += accident_type[ : int(len(accident_type)/5)]
    # train_scene_list += accident_type[int(len(accident_type)/5) : ]
    train_num_accident_type = math.ceil(len(accident_type)*0.8)
    minor_class_multiplier = int(major_acc_num/train_num_accident_type)
    balanced_all_acc_train.append(accident_type[int(len(accident_type)/5):]*minor_class_multiplier 
                                  + accident_type[int(len(accident_type)/5):][:major_acc_num-math.ceil(len(accident_type)*0.8)*minor_class_multiplier])
    train_scene_list = train_scene_list + accident_type[int(len(accident_type)/5):]*minor_class_multiplier + accident_type[int(len(accident_type)/5):][:major_acc_num-math.ceil(len(accident_type)*0.8)*minor_class_multiplier]
normal_scene_list = []
for filename in os.listdir(normal_dir):
    if filename.endswith('csv'):
        normal_scene_list.append(filename)
random.shuffle(normal_scene_list)
# train_scene_list += normal_scene_list[len(test_scene_list) : len(test_scene_list)+len(train_scene_list)]
# test_scene_list += normal_scene_list[ : len(test_scene_list)]

In [29]:
# dataset_train = []
# # test_scene_list += normal_scene_list[ : len(test_scene_list)]
# for scene in train_scene_list:
#     if scene in normal_scene_list:
#         # y = 0
#         df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
#         crash_frame_id = df[-1:]['frame_id'].item()
#         crash_ids = [314159, 314159]
#     else:
#         # y = 1
#         df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
#         crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
#         crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
#     crash_pred_frame_id = crash_frame_id - 1
#     if t_prediction+sequence_dist+t_prediction <= crash_pred_frame_id:
#         pred_frame_id_list = [crash_pred_frame_id-t_prediction-sequence_dist, crash_pred_frame_id]
#     else:
#         pred_frame_id_list = [t_prediction-1, crash_pred_frame_id]
#     print(scene)
#     for pred_frame_id in pred_frame_id_list:
#         for frame_id in range(pred_frame_id-t_prediction+1, pred_frame_id+1):
#             df_frame = df[df['frame_id'] == frame_id]
#             df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
#             # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
#             df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
#             id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
#             node_idx_list = list(range(len(id_list)))
#             df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
#                                                                                         'van',
#                                                                                         'truck',
#                                                                                         'motorcycle',
#                                                                                         'cyclist',
#                                                                                         'pedestrian'
#                                                                                         ])), dtype=float)
#             df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
#             node_features = df_features.to_numpy()
#             edge_index = [[],[]]
#             edge_features = []
#             edge_id = []
#             for node_connection in it.combinations(node_idx_list, 2):
#                 x1 = df_numeric['x_center'].tolist()[node_connection[0]]
#                 y1 = df_numeric['y_center'].tolist()[node_connection[0]]
#                 v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
#                 v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
#                 x2 = df_numeric['x_center'].tolist()[node_connection[1]]
#                 y2 = df_numeric['y_center'].tolist()[node_connection[1]]
#                 v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
#                 v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
#                 delta_x = x2 - x1
#                 delta_y = y2 - y1
#                 delta_v_x = v_x1 - v_x2
#                 delta_v_y = v_y1 - v_y2
#                 if delta_v_x**2+delta_v_y**2 == 0:
#                     min_d = math.sqrt(delta_x**2+delta_y**2)
#                     t_accident = 0
#                 else:
#                     t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
#                     if t_accident < 0:
#                         min_d = math.sqrt(delta_x**2+delta_y**2)
#                     else:
#                         min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
#                 if min_d <= min_d_threshold:
#                     edge_index[0].append(node_connection[0])
#                     edge_index[0].append(node_connection[1])
#                     edge_index[1].append(node_connection[1])
#                     edge_index[1].append(node_connection[0])
#                     edge_features.append([t_accident, min_d])
#                     edge_features.append([t_accident, min_d])
#                     edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
#                     edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
#             edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
#             if pred_frame_id_list.index(pred_frame_id) == 0:
#                 y = torch.tensor([0], dtype=torch.long)
#             else:
#                 y = torch.tensor([1], dtype=torch.long)
#             data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
#                         x = torch.tensor(node_features, dtype=torch.float),
#                         edge_attr = edge_features_norm,
#                         y = y,
#                         frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
#                         # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
#                         obj_ids = torch.tensor(id_list, dtype=torch.long),
#                         crash_ids = torch.tensor(crash_ids, dtype=torch.long)
#                         )
#             dataset_train.append(data)

train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_

In [None]:
# group_train = []
# # test_scene_list += normal_scene_list[ : len(test_scene_list)]
# for accident_class in balanced_all_acc_train:
#     sub_class = []
#     for scene in accident_class:
#         if scene in normal_scene_list:
#             # y = 0
#             df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
#             crash_frame_id = df[-1:]['frame_id'].item()
#             crash_ids = [314159, 314159]
#         else:
#             # y = 1
#             df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
#             crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
#             crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
#         crash_pred_frame_id = crash_frame_id - 1
#         if t_prediction+sequence_dist+t_prediction <= crash_pred_frame_id:
#             pred_frame_id_list = [crash_pred_frame_id-t_prediction-sequence_dist, crash_pred_frame_id]
#         else:
#             pred_frame_id_list = [t_prediction-1, crash_pred_frame_id]
#         print(scene)
#         pair_seq = []
#         for pred_frame_id in pred_frame_id_list:
#             for frame_id in range(pred_frame_id-t_prediction+1, pred_frame_id+1):
#                 df_frame = df[df['frame_id'] == frame_id]
#                 df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
#                 # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
#                 df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
#                 id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
#                 node_idx_list = list(range(len(id_list)))
#                 df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
#                                                                                             'van',
#                                                                                             'truck',
#                                                                                             'motorcycle',
#                                                                                             'cyclist',
#                                                                                             'pedestrian'
#                                                                                             ])), dtype=float)
#                 df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
#                 node_features = df_features.to_numpy()
#                 edge_index = [[],[]]
#                 edge_features = []
#                 edge_id = []
#                 for node_connection in it.combinations(node_idx_list, 2):
#                     x1 = df_numeric['x_center'].tolist()[node_connection[0]]
#                     y1 = df_numeric['y_center'].tolist()[node_connection[0]]
#                     v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
#                     v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
#                     x2 = df_numeric['x_center'].tolist()[node_connection[1]]
#                     y2 = df_numeric['y_center'].tolist()[node_connection[1]]
#                     v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
#                     v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
#                     delta_x = x2 - x1
#                     delta_y = y2 - y1
#                     delta_v_x = v_x1 - v_x2
#                     delta_v_y = v_y1 - v_y2
#                     if delta_v_x**2+delta_v_y**2 == 0:
#                         min_d = math.sqrt(delta_x**2+delta_y**2)
#                         t_accident = 0
#                     else:
#                         t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
#                         if t_accident < 0:
#                             min_d = math.sqrt(delta_x**2+delta_y**2)
#                         else:
#                             min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
#                     if min_d <= min_d_threshold:
#                         edge_index[0].append(node_connection[0])
#                         edge_index[0].append(node_connection[1])
#                         edge_index[1].append(node_connection[1])
#                         edge_index[1].append(node_connection[0])
#                         edge_features.append([t_accident, min_d])
#                         edge_features.append([t_accident, min_d])
#                         edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
#                         edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
#                 edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
#                 if pred_frame_id_list.index(pred_frame_id) == 0:
#                     y = torch.tensor([0], dtype=torch.long)
#                 else:
#                     y = torch.tensor([1], dtype=torch.long)
#                 data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
#                             x = torch.tensor(node_features, dtype=torch.float),
#                             edge_attr = edge_features_norm,
#                             y = y,
#                             frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
#                             # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
#                             obj_ids = torch.tensor(id_list, dtype=torch.long),
#                             crash_ids = torch.tensor(crash_ids, dtype=torch.long)
#                             )
#                 pair_seq.append(data)
#         sub_class.append(pair_seq)
#     group_train.append(sub_class)

In [27]:
group_train = []
# test_scene_list += normal_scene_list[ : len(test_scene_list)]
for accident_class in balanced_all_acc_train:
    sub_class = []
    for scene in accident_class:
        if scene in normal_scene_list:
            # y = 0
            df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
            crash_frame_id = df[-1:]['frame_id'].item()
            crash_ids = [314159, 314159]
        else:
            # y = 1
            df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
            crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
            crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
        crash_pred_frame_id = crash_frame_id
        if sequence_dist+lead_time <= crash_pred_frame_id:
            pred_frame_id_list = [crash_pred_frame_id-lead_time-sequence_dist, crash_pred_frame_id-lead_time]
        else:
            pred_frame_id_list = [0, crash_pred_frame_id-lead_time]
        print(scene)
        pair_seq = []
        for pred_frame_id in pred_frame_id_list:
            df_frame = df[df['frame_id'] == pred_frame_id]
            df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
            # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
            df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
            id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
            node_idx_list = list(range(len(id_list)))
            df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                        'van',
                                                                                        'truck',
                                                                                        'motorcycle',
                                                                                        'cyclist',
                                                                                        'pedestrian'
                                                                                        ])), dtype=float)
            df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
            node_features = df_features.to_numpy()
            edge_index = [[],[]]
            edge_features = []
            edge_id = []
            for node_connection in it.combinations(node_idx_list, 2):
                x1 = df_numeric['x_center'].tolist()[node_connection[0]]
                y1 = df_numeric['y_center'].tolist()[node_connection[0]]
                v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
                v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
                x2 = df_numeric['x_center'].tolist()[node_connection[1]]
                y2 = df_numeric['y_center'].tolist()[node_connection[1]]
                v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
                v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
                delta_x = x2 - x1
                delta_y = y2 - y1
                delta_v_x = v_x1 - v_x2
                delta_v_y = v_y1 - v_y2
                if delta_v_x**2+delta_v_y**2 == 0:
                    min_d = math.sqrt(delta_x**2+delta_y**2)
                    t_accident = 0
                else:
                    t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                    if t_accident < 0:
                        min_d = math.sqrt(delta_x**2+delta_y**2)
                    else:
                        min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
                if min_d <= min_d_threshold and t_accident <= t_threshold:
                    edge_index[0].append(node_connection[0])
                    edge_index[0].append(node_connection[1])
                    edge_index[1].append(node_connection[1])
                    edge_index[1].append(node_connection[0])
                    edge_features.append([t_accident, min_d])
                    edge_features.append([t_accident, min_d])
                    edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                    edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
            edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
            if pred_frame_id_list.index(pred_frame_id) == 0:
                y = torch.tensor([0], dtype=torch.long)
            else:
                y = torch.tensor([1], dtype=torch.long)
            data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                        x = torch.tensor(node_features, dtype=torch.float),
                        edge_attr = edge_features_norm,
                        y = y,
                        frame_to_crash = torch.tensor(crash_frame_id-pred_frame_id, dtype=torch.float),
                        # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                        obj_ids = torch.tensor(id_list, dtype=torch.long),
                        crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                        )
            pair_seq.append(data)
        sub_class.append(pair_seq)
    group_train.append(sub_class)

train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_s

In [28]:
dataset_test = []
# test_scene_list += normal_scene_list[ : len(test_scene_list)]
for scene in test_scene_list:
    if scene in normal_scene_list:
        # y = 0
        df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
        crash_frame_id = df[-1:]['frame_id'].item()
        crash_ids = [314159, 314159]
    else:
        # y = 1
        df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
        crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
        crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
    crash_pred_frame_id = crash_frame_id
    if sequence_dist+lead_time <= crash_pred_frame_id:
        pred_frame_id_list = [crash_pred_frame_id-lead_time-sequence_dist, crash_pred_frame_id-lead_time]
    else:
        pred_frame_id_list = [0, crash_pred_frame_id-lead_time]
    print(scene)
    for pred_frame_id in pred_frame_id_list:
        df_frame = df[df['frame_id'] == pred_frame_id]
        df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
        # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
        df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
        id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
        node_idx_list = list(range(len(id_list)))
        df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                    'van',
                                                                                    'truck',
                                                                                    'motorcycle',
                                                                                    'cyclist',
                                                                                    'pedestrian'
                                                                                    ])), dtype=float)
        df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
        node_features = df_features.to_numpy()
        edge_index = [[],[]]
        edge_features = []
        edge_id = []
        for node_connection in it.combinations(node_idx_list, 2):
            x1 = df_numeric['x_center'].tolist()[node_connection[0]]
            y1 = df_numeric['y_center'].tolist()[node_connection[0]]
            v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
            v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
            x2 = df_numeric['x_center'].tolist()[node_connection[1]]
            y2 = df_numeric['y_center'].tolist()[node_connection[1]]
            v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
            v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
            delta_x = x2 - x1
            delta_y = y2 - y1
            delta_v_x = v_x1 - v_x2
            delta_v_y = v_y1 - v_y2
            if delta_v_x**2+delta_v_y**2 == 0:
                min_d = math.sqrt(delta_x**2+delta_y**2)
                t_accident = 0
            else:
                t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                if t_accident < 0:
                    min_d = math.sqrt(delta_x**2+delta_y**2)
                else:
                    min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
            if min_d <= min_d_threshold and t_accident <= t_threshold:
                edge_index[0].append(node_connection[0])
                edge_index[0].append(node_connection[1])
                edge_index[1].append(node_connection[1])
                edge_index[1].append(node_connection[0])
                edge_features.append([t_accident, min_d])
                edge_features.append([t_accident, min_d])
                edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
        edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
        if pred_frame_id_list.index(pred_frame_id) == 0:
            y = torch.tensor([0], dtype=torch.long)
        else:
            y = torch.tensor([1], dtype=torch.long)
        data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                    x = torch.tensor(node_features, dtype=torch.float),
                    edge_attr = edge_features_norm,
                    y = y,
                    frame_to_crash = torch.tensor(crash_frame_id-pred_frame_id, dtype=torch.float),
                    # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                    obj_ids = torch.tensor(id_list, dtype=torch.long),
                    crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                    )
        dataset_test.append(data)

val_01_Town03_type001_subtype0002_scenario00030.txt
train_09_Town03_type001_subtype0001_scenario00014.txt
train_01_Town05_type001_subtype0002_scenario00019.txt
train_04_Town07_type001_subtype0001_scenario00010.txt
train_01_Town05_type001_subtype0002_scenario00012.txt
train_03_Town05_type001_subtype0001_scenario00020.txt
train_06_Town04_type001_subtype0002_scenario00022.txt
val_01_Town04_type001_subtype0002_scenario00025.txt
train_06_Town05_type001_subtype0002_scenario00002.txt
train_03_Town03_type001_subtype0001_scenario00015.txt
train_05_Town05_type001_subtype0001_scenario00005.txt
train_05_Town03_type001_subtype0002_scenario00010.txt
train_02_Town10HD_type001_subtype0001_scenario00021.txt
train_09_Town04_type001_subtype0001_scenario00013.txt
val_01_Town10HD_type001_subtype0001_scenario00020.txt
train_10_Town07_type001_subtype0001_scenario00019.txt
train_04_Town05_type001_subtype0002_scenario00028.txt
train_07_Town04_type001_subtype0001_scenario00005.txt
train_03_Town03_type001_subtyp

### Classifier Training

In [29]:
args = {
    "device" : torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "save" : "Single_Graph_BC_model",
    "load" : "ssl_model",
    "lr" : 0.001,
    "lr_reg" : 0.00001,
    "epochs" : 200,
    "batch_size" : 32,
    "num_workers" : 2,
    "model" : "gat", # choices are ["gcn", "gin", "resgcn", "gat", "graphsage", "sgc"]
    "input_dim" : 12, # input dimension for node features
    # "num_classes" : len(all_names), # collision
    "num_classes" : 1,
    "feat_dim" : 8,
    "layers" : 3,
    "train_data_percent" : 0.8,
}

class AttributeDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttributeDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

args = AttributeDict(args)

In [30]:
# train_dataset = dataset_train
grouped_train_dataset = group_train
val_dataset = dataset_test
test_dataset = dataset_test
# train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)

print("Dataset split: {} {} {}".format(len(grouped_train_dataset), len(val_dataset), len(val_dataset)))
print("Number of classes: {}".format(args.num_classes))

Dataset split: 8 70 70
Number of classes: 1


In [35]:
import importlib, models_Single_Graph_BC_Reg
importlib.reload(models_Single_Graph_BC_Reg)
from models_Single_Graph_BC_Reg import *

# classification model is a GNN encoder followed by linear layer
model = GraphClassificationModel(args.input_dim, args.feat_dim, n_layers=args.layers, gnn=args.model, output_dim=args.num_classes
                                # , load=args.load
                                 )

for param in model.parameters():
    param.requires_grad = True

model = model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer_reg = torch.optim.Adam(model.parameters(), lr=args.lr_reg)

In [36]:
model.train()
loss_fn = torch.nn.BCEWithLogitsLoss()
loss_mse = torch.nn.MSELoss()

best_train_loss, best_val_loss = float("inf"), float("inf")
is_best_loss = False
best_val_acc = 0
for epoch in range(args.epochs):
    losses = []
    edge_losses = []
    reg_losses = []
    correct = 0
    edge_correct = 0
    frame_error = 0

    for acc_type in grouped_train_dataset:
        random.shuffle(acc_type)
    train_dataset = []
    for i in range(major_acc_num):
        for acc_type in grouped_train_dataset:
            train_dataset += acc_type[i]
    flattened = train_dataset   # (15+15)*36*8
    train_loader = DataLoader(flattened, num_workers=args.num_workers, batch_size=32, shuffle=False)

    for data in train_loader:
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr
        labels = data.y
        
        # regression_score = model(data_input)[1][data.y == 1]
        # regression_gt = data.frame_to_crash[data.y == 1]
        # loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/(t_prediction*10)
        # frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
        # reg_losses.append(loss_reg.item())
        # optimizer.zero_grad()
        # loss_reg.backward()
        # optimizer.step()

        gt = labels.float()
        scores = model(data_input)[0].squeeze(1)
        loss = loss_fn(scores, gt)
        pred = []
        for pr in scores:
            if torch.sigmoid(pr)>0.5:
                pred.append(1)
            else:
                pred.append(0)
        correct += int((torch.tensor(pred).to(args.device) == gt).sum()) / gt.size(0)
        losses.append(loss.item())

        # backprop
        optimizer.zero_grad()
        # (loss+loss_reg).backward()
        loss.backward()
        # loss_edge.backward()
        # loss_reg.backward()
        optimizer.step()


    epoch_loss_train = sum(losses) / len(losses)
    # epoch_edge_loss_train = sum(edge_losses) / len(edge_losses)
    # epoch_reg_loss_train = sum(reg_losses) / len(reg_losses)
    epoch_acc_train = correct / len(train_loader)
    # epoch_edge_acc_train = edge_correct / len(train_loader)
    # epoch_reg_error_train = frame_error / len(train_loader)


    # validation
    losses = []
    edge_losses = []
    reg_losses = []
    correct = 0
    edge_correct = 0
    frame_error = 0

    for data in val_loader:
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr
        labels = data.y
        gt = labels.float()
        scores = model(data_input)[0].squeeze(1)
        loss = loss_fn(scores, gt)
        pred = []
        for pr in scores:
            if torch.sigmoid(pr)>0.5:
                pred.append(1)
            else:
                pred.append(0)
        correct += int((torch.tensor(pred).to(args.device) == gt).sum()) / gt.size(0)
        losses.append(loss.item())
        # regression_score = model(data_input)[1][data.y == 1]
        # regression_gt = data.frame_to_crash[data.y == 1]
        # loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/(t_prediction*10)
        # frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
        # reg_losses.append(loss_reg.item())


    # gather the results for the epoch
    epoch_loss_val = sum(losses) / len(losses)
    epoch_acc_val = correct / len(val_loader)
    # epoch_edge_loss_val = sum(edge_losses) / len(edge_losses)
    # epoch_edge_acc_val = edge_correct / len(val_loader)
    # epoch_reg_loss_val = sum(reg_losses) / len(reg_losses)
    # epoch_reg_error_val = frame_error / len(val_loader)


    log = "Epoch {},   Train Loss: {:.3f},            Train Accuracy: {:.3f},            Val Loss: {:.3f}, Val Accuracy: {:.3f}, \n"
    print(log.format(epoch, 
                     epoch_loss_train, epoch_acc_train, epoch_loss_val, epoch_acc_val, 
                     ))
    
    if epoch_acc_val > best_val_acc:
        best_epoch, best_train_loss, best_val_loss, is_best_loss = epoch, epoch_loss_train, epoch_loss_val, True
        best_train_acc, best_val_acc = epoch_acc_train, epoch_acc_val
# model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_loss, best_val_loss, is_best_loss)
print('best Epoch: ', best_epoch, ', best total train loss: ', best_train_loss, ', best total val loss: ', best_val_loss)
print('best train accuracy: ', best_train_acc, ', best validation accuracy: ', best_val_acc)

Epoch 0,   Train Loss: 0.713,            Train Accuracy: 0.526,            Val Loss: 0.684, Val Accuracy: 0.566, 

Epoch 1,   Train Loss: 0.690,            Train Accuracy: 0.557,            Val Loss: 0.712, Val Accuracy: 0.444, 

Epoch 2,   Train Loss: 0.674,            Train Accuracy: 0.590,            Val Loss: 0.740, Val Accuracy: 0.476, 

Epoch 3,   Train Loss: 0.677,            Train Accuracy: 0.562,            Val Loss: 0.794, Val Accuracy: 0.441, 

Epoch 4,   Train Loss: 0.683,            Train Accuracy: 0.557,            Val Loss: 0.717, Val Accuracy: 0.576, 

Epoch 5,   Train Loss: 0.684,            Train Accuracy: 0.549,            Val Loss: 0.668, Val Accuracy: 0.601, 

Epoch 6,   Train Loss: 0.681,            Train Accuracy: 0.575,            Val Loss: 0.690, Val Accuracy: 0.562, 

Epoch 7,   Train Loss: 0.676,            Train Accuracy: 0.587,            Val Loss: 0.870, Val Accuracy: 0.410, 

Epoch 8,   Train Loss: 0.660,            Train Accuracy: 0.597,            Val L

  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 17,   Train Loss: 0.629,            Train Accuracy: 0.615,            Val Loss: 0.723, Val Accuracy: 0.556, 

Epoch 18,   Train Loss: 0.605,            Train Accuracy: 0.660,            Val Loss: 0.753, Val Accuracy: 0.434, 

Epoch 19,   Train Loss: 0.631,            Train Accuracy: 0.625,            Val Loss: 0.678, Val Accuracy: 0.628, 



  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 20,   Train Loss: 0.620,            Train Accuracy: 0.656,            Val Loss: 0.757, Val Accuracy: 0.556, 

Epoch 21,   Train Loss: 0.600,            Train Accuracy: 0.667,            Val Loss: 0.815, Val Accuracy: 0.434, 

Epoch 22,   Train Loss: 0.586,            Train Accuracy: 0.663,            Val Loss: 0.719, Val Accuracy: 0.476, 

Epoch 23,   Train Loss: 0.586,            Train Accuracy: 0.681,            Val Loss: 0.841, Val Accuracy: 0.510, 

Epoch 24,   Train Loss: 0.586,            Train Accuracy: 0.655,            Val Loss: 0.779, Val Accuracy: 0.576, 

Epoch 25,   Train Loss: 0.570,            Train Accuracy: 0.698,            Val Loss: 0.717, Val Accuracy: 0.611, 

Epoch 26,   Train Loss: 0.598,            Train Accuracy: 0.646,            Val Loss: 0.769, Val Accuracy: 0.521, 

Epoch 27,   Train Loss: 0.568,            Train Accuracy: 0.675,            Val Loss: 0.728, Val Accuracy: 0.587, 

Epoch 28,   Train Loss: 0.562,            Train Accuracy: 0.696,        

  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 30,   Train Loss: 0.548,            Train Accuracy: 0.691,            Val Loss: 0.792, Val Accuracy: 0.465, 

Epoch 31,   Train Loss: 0.551,            Train Accuracy: 0.689,            Val Loss: 0.813, Val Accuracy: 0.622, 

Epoch 32,   Train Loss: 0.543,            Train Accuracy: 0.689,            Val Loss: 0.880, Val Accuracy: 0.465, 

Epoch 33,   Train Loss: 0.546,            Train Accuracy: 0.681,            Val Loss: 0.757, Val Accuracy: 0.542, 

Epoch 34,   Train Loss: 0.533,            Train Accuracy: 0.686,            Val Loss: 0.771, Val Accuracy: 0.521, 

Epoch 35,   Train Loss: 0.520,            Train Accuracy: 0.724,            Val Loss: 0.865, Val Accuracy: 0.542, 

Epoch 36,   Train Loss: 0.545,            Train Accuracy: 0.701,            Val Loss: 0.863, Val Accuracy: 0.420, 

Epoch 37,   Train Loss: 0.524,            Train Accuracy: 0.717,            Val Loss: 0.852, Val Accuracy: 0.562, 

Epoch 38,   Train Loss: 0.543,            Train Accuracy: 0.686,        

  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 109,   Train Loss: 0.231,            Train Accuracy: 0.892,            Val Loss: 0.922, Val Accuracy: 0.729, 

Epoch 110,   Train Loss: 0.246,            Train Accuracy: 0.875,            Val Loss: 1.004, Val Accuracy: 0.608, 

Epoch 111,   Train Loss: 0.235,            Train Accuracy: 0.873,            Val Loss: 0.890, Val Accuracy: 0.694, 

Epoch 112,   Train Loss: 0.217,            Train Accuracy: 0.889,            Val Loss: 1.116, Val Accuracy: 0.618, 

Epoch 113,   Train Loss: 0.256,            Train Accuracy: 0.863,            Val Loss: 0.932, Val Accuracy: 0.708, 

Epoch 114,   Train Loss: 0.244,            Train Accuracy: 0.872,            Val Loss: 1.034, Val Accuracy: 0.604, 

Epoch 115,   Train Loss: 0.240,            Train Accuracy: 0.877,            Val Loss: 1.122, Val Accuracy: 0.573, 

Epoch 116,   Train Loss: 0.229,            Train Accuracy: 0.875,            Val Loss: 1.171, Val Accuracy: 0.628, 

Epoch 117,   Train Loss: 0.232,            Train Accuracy: 0.870

  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 123,   Train Loss: 0.289,            Train Accuracy: 0.840,            Val Loss: 1.532, Val Accuracy: 0.542, 

Epoch 124,   Train Loss: 0.278,            Train Accuracy: 0.833,            Val Loss: 1.094, Val Accuracy: 0.740, 

Epoch 125,   Train Loss: 0.255,            Train Accuracy: 0.849,            Val Loss: 1.770, Val Accuracy: 0.573, 

Epoch 126,   Train Loss: 0.259,            Train Accuracy: 0.849,            Val Loss: 1.366, Val Accuracy: 0.625, 

Epoch 127,   Train Loss: 0.243,            Train Accuracy: 0.868,            Val Loss: 1.216, Val Accuracy: 0.639, 

Epoch 128,   Train Loss: 0.212,            Train Accuracy: 0.889,            Val Loss: 1.181, Val Accuracy: 0.628, 

Epoch 129,   Train Loss: 0.217,            Train Accuracy: 0.889,            Val Loss: 0.996, Val Accuracy: 0.715, 

Epoch 130,   Train Loss: 0.219,            Train Accuracy: 0.873,            Val Loss: 1.114, Val Accuracy: 0.715, 

Epoch 131,   Train Loss: 0.237,            Train Accuracy: 0.851

  value = torch.cat(values, dim=cat_dim or 0, out=out)


Epoch 138,   Train Loss: 0.223,            Train Accuracy: 0.863,            Val Loss: 1.234, Val Accuracy: 0.604, 

Epoch 139,   Train Loss: 0.222,            Train Accuracy: 0.880,            Val Loss: 1.374, Val Accuracy: 0.573, 

Epoch 140,   Train Loss: 0.202,            Train Accuracy: 0.891,            Val Loss: 1.286, Val Accuracy: 0.573, 

Epoch 141,   Train Loss: 0.190,            Train Accuracy: 0.891,            Val Loss: 2.026, Val Accuracy: 0.618, 

Epoch 142,   Train Loss: 0.204,            Train Accuracy: 0.885,            Val Loss: 1.460, Val Accuracy: 0.597, 

Epoch 143,   Train Loss: 0.227,            Train Accuracy: 0.865,            Val Loss: 1.502, Val Accuracy: 0.625, 

Epoch 144,   Train Loss: 0.222,            Train Accuracy: 0.873,            Val Loss: 1.207, Val Accuracy: 0.684, 

Epoch 145,   Train Loss: 0.233,            Train Accuracy: 0.894,            Val Loss: 2.037, Val Accuracy: 0.670, 

Epoch 146,   Train Loss: 0.197,            Train Accuracy: 0.905

In [None]:
best Epoch:  179 , best total train loss:  0.2608414151602321 , best total val loss:  0.6423882593711218
best train accuracy:  0.8402777777777778 , best validation accuracy:  0.8020833333333334

# d=20 t=1.5 lt=?
best Epoch:  185 , best total train loss:  0.18674196095930207 , best total val loss:  1.5184612572193146
best train accuracy:  0.8819444444444444 , best validation accuracy:  0.7604166666666666

best Epoch:  183 , best total train loss:  0.26141541947921115 , best total val loss:  0.7638967633247375
best train accuracy:  0.859375 , best validation accuracy:  0.78125

Epoch 49,   Train Loss: 0.416,            Train Accuracy: 0.835,            Val Loss: 0.696, Val Accuracy: 0.680, 
Train Regression Loss: 0.030, Train Regression Error: 3.824, Val Regression Loss: 0.029, Val Regression Error: 3.818

Epoch 89,   Train Loss: 0.270,            Train Accuracy: 0.914,            Val Loss: 0.853, Val Accuracy: 0.669, 
Train Regression Loss: 0.029, Train Regression Error: 3.717, Val Regression Loss: 0.029, Val Regression Error: 3.754

In [42]:
best_epoch, best_train_loss, best_val_loss = model.load_checkpoint(os.path.join("logs", args.save), optimizer)
model.eval()

losses = []
edge_losses = []
reg_losses = []
correct = 0
edge_correct = 0
frame_error = 0

for data in test_loader:
    data.to(args.device)

    data_input = data.x, data.edge_index, data.batch, data.edge_attr

    labels = data.y

    gt = []
    for label in labels:
        if label == 0:
            gt.append([1,0])
        else:
            gt.append([0,1])
    gt = torch.tensor(gt, dtype=torch.float32).to(args.device)


    # get class scores from model
    scores = model(data_input)[0]

    loss = loss_fn(scores, gt)

    # Keep track of loss and accuracy
    pred = scores.argmax(dim=1)
    correct += int((pred == labels).sum()) / labels.size(0)
    losses.append(loss.item())


    batch_node_label = []
    for data_point in data.batch:
        if data.y[data_point] == 0:
            batch_node_label.append(False)
        else:
            batch_node_label.append(True)
    batch_node_label = torch.tensor(batch_node_label, dtype=torch.bool).to(args.device)
    batch_edge_label = batch_node_label[data.edge_index[0]]
    edge_score = model(data_input)[1][batch_edge_label]
    gt_edge = []
    batch_edge_index = data.obj_ids[data.edge_index[0]][batch_edge_label]
    for i in range(edge_score.size(0)):
        if batch_edge_index[i] in data.crash_ids:
            gt_edge.append([0,1])
        else:
            gt_edge.append([1,0])
    gt_edge = torch.tensor(gt_edge, dtype=torch.float32).to(args.device)
    loss_edge = loss_fn(edge_score, gt_edge)
    edge_losses.append(loss_edge.item())
    pred_edge = edge_score.argmax(dim=1)
    labels_edge = gt_edge.argmax(dim=1)
    edge_correct += int((pred_edge == labels_edge).sum()) / labels_edge.size(0)


    regression_score = model(data_input)[2][data.y == 1]
    regression_gt = data.frame_to_crash[data.y == 1]
    loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/t_prediction
    frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
    reg_losses.append(loss_reg.item())

# gather the results for the epoch
loss_test = sum(losses) / len(losses)
correct_test = correct / len(test_loader)
epoch_edge_loss_test = sum(edge_losses) / len(edge_losses)
epoch_edge_acc_test = edge_correct / len(test_loader)
epoch_reg_loss_test = sum(reg_losses) / len(reg_losses)
epoch_reg_error_test = frame_error / len(test_loader)

print("Test Loss at epoch {}: {:.3f}, Test Accuracy: {:.3f}, \nTest Edge Loss: {:.3f}, Test Edge Accuracy: {:.3f} \nTest Regression Loss: {:.3f}, Test Regression Error: {:.3f}".format(best_epoch, loss_test, correct_test, epoch_edge_loss_test, epoch_edge_acc_test, epoch_reg_loss_test, epoch_reg_error_test))

Test Loss at epoch 99: 2.376, Test Accuracy: 0.720, 
Test Edge Loss: 0.604, Test Edge Accuracy: 0.821 
Test Regression Loss: 0.301, Test Regression Error: 7.256
