In [2]:
import torch
import random
import numpy as np
from itertools import chain
import torch.nn as nn
import itertools as it
import os, shutil
import pandas as pd
from pandas.core.common import flatten
from scipy.interpolate import make_interp_spline, interp1d
import matplotlib.pyplot as plt
from IPython.display import clear_output
import math
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
%cd GraphSSL

/home/shihan/GAT/GraphSSL


### Customized Dataset

In [4]:
random.seed(42)
t_threshold = 1.5  # in seconds
min_d_threshold = 30  # in meters
t_prediction = 15 # in frames
sequence_dist = 40 # in frames

In [5]:
info_dir = '../csv_final/info/'   # directory of 'info' folder
all_info = []
for filename in os.listdir(info_dir):
    if filename.endswith('txt'):
        all_info.append(filename)
pedestrian = []
rear_end = []
switch_lane = [] # switch_lane
t_bone = []
opposite_front = []
opposite_merge = []
merging = []
other = []
for info in all_info:
    txt_file = open(info_dir+info, "r").read().split('\n') # split txt file with lines
    first_cls = txt_file[2].split(': ')[-1]
    second_cls = txt_file[9].split(': ')[-1]
    first_fault = txt_file[7].split(' ')[-1]
    second_fault = txt_file[14].split(' ')[-1]
    first_direction = txt_file[5].split(' ')[-1]
    second_direction = txt_file[12].split(' ')[-1]
    first_collision_side = txt_file[6].split(' ')[-1]
    second_collision_side = txt_file[13].split(' ')[-1]
    second_relative_direction_to_first = txt_file[15].split(' ')[-1]
    if first_cls == 'pedestrian' or second_cls == 'pedestrian':
        pedestrian.append(info)
    elif first_fault == 'rear_end' or second_fault == 'rear_end':
        rear_end.append(info)
    elif first_fault == 'switch_lane' or second_fault == 'switch_lane':
        switch_lane.append(info)
    elif first_direction == 'straight' and second_direction == 'straight':
        t_bone.append(info)
    elif second_relative_direction_to_first == 'opposite' and 'straight' in [first_direction, second_direction]:
        opposite_front.append(info)
    elif second_relative_direction_to_first == 'opposite':
        opposite_merge.append(info)
    elif first_fault == 'entrance_ramp' or second_fault == 'entrance_ramp':
        merging.append(info)
    elif second_relative_direction_to_first=='right' and first_direction=='left' and second_direction=='straight':
        merging.append(info)
    elif second_relative_direction_to_first=='left' and first_direction=='straight' and second_direction=='left':
        merging.append(info)
    elif second_relative_direction_to_first=='right' and first_direction=='straight' and second_direction=='right':
        merging.append(info)
    elif second_relative_direction_to_first=='left' and first_direction=='right' and second_direction=='straight':
        merging.append(info)
    else:
        other.append(info)
random.shuffle(pedestrian)
random.shuffle(rear_end)
random.shuffle(switch_lane)
random.shuffle(t_bone)
random.shuffle(opposite_front)
random.shuffle(opposite_merge)
random.shuffle(merging)
random.shuffle(other)
acc_all = {'pedestrian':len(pedestrian), 'rear_end':len(rear_end),
        'switch_lane':len(switch_lane), 't_bone':len(t_bone),
        'opposite_front':len(opposite_front), 'opposite_merge': len(opposite_merge),
        'merging':len(merging), 'other':len(other),
        'total': len(pedestrian+rear_end+switch_lane+t_bone+opposite_front+opposite_merge+merging+other)}
print(acc_all)
acc_train = {'pedestrian':math.ceil(len(pedestrian)*0.8), 'rear_end':math.ceil(len(rear_end)*0.8),
        'switch_lane':math.ceil(len(switch_lane)*0.8), 't_bone':math.ceil(len(t_bone)*0.8),
        'opposite_front':math.ceil(len(opposite_front)*0.8), 'opposite_merge': math.ceil(len(opposite_merge)*0.8),
        'merging':math.ceil(len(merging)*0.8), 'other':math.ceil(len(other)*0.8)}
acc_test = {'pedestrian':int(len(pedestrian)/5), 'rear_end':int(len(rear_end)/5),
        'switch_lane':int(len(switch_lane)/5), 't_bone':int(len(t_bone)/5),
        'opposite_front':int(len(opposite_front)/5), 'opposite_merge': int(len(opposite_merge)/5),
        'merging':int(len(merging)/5), 'other':int(len(other)/5)
        }
print('training dataset:')
print(acc_train)
print('testing dataset:')
print(acc_test)

{'pedestrian': 8, 'rear_end': 19, 'switch_lane': 13, 't_bone': 31, 'opposite_front': 22, 'opposite_merge': 19, 'merging': 40, 'other': 44, 'total': 196}
training dataset:
{'pedestrian': 7, 'rear_end': 16, 'switch_lane': 11, 't_bone': 25, 'opposite_front': 18, 'opposite_merge': 16, 'merging': 32, 'other': 36}
testing dataset:
{'pedestrian': 1, 'rear_end': 3, 'switch_lane': 2, 't_bone': 6, 'opposite_front': 4, 'opposite_merge': 3, 'merging': 8, 'other': 8}


In [6]:
accident_dir = '../csv_final/accident/'
normal_dir = '../csv_final/normal/'
train_scene_list = []
test_scene_list = []
accident_names = ['pedestrian','rear_end','switch_lane','t_bone','opposite_front','opposite_merge','merging','other']
all_names = accident_names+['no_accident']
major_acc_num = 36
accident_type_all = [pedestrian,rear_end,switch_lane,t_bone,opposite_front,opposite_merge,merging,other]
balanced_all_acc_train = []
for accident_type in accident_type_all:
    random.shuffle(accident_type)
    test_scene_list += accident_type[ : int(len(accident_type)/5)]
    # train_scene_list += accident_type[int(len(accident_type)/5) : ]
    train_num_accident_type = math.ceil(len(accident_type)*0.8)
    minor_class_multiplier = int(major_acc_num/train_num_accident_type)
    balanced_all_acc_train.append(accident_type[int(len(accident_type)/5):]*minor_class_multiplier 
                                  + accident_type[int(len(accident_type)/5):][:major_acc_num-math.ceil(len(accident_type)*0.8)*minor_class_multiplier])
    train_scene_list = train_scene_list + accident_type[int(len(accident_type)/5):]*minor_class_multiplier + accident_type[int(len(accident_type)/5):][:major_acc_num-math.ceil(len(accident_type)*0.8)*minor_class_multiplier]
normal_scene_list = []
for filename in os.listdir(normal_dir):
    if filename.endswith('csv'):
        normal_scene_list.append(filename)
random.shuffle(normal_scene_list)
# train_scene_list += normal_scene_list[len(test_scene_list) : len(test_scene_list)+len(train_scene_list)]
# test_scene_list += normal_scene_list[ : len(test_scene_list)]

In [8]:
# dataset_train = []
# test_scene_list += normal_scene_list[ : len(test_scene_list)]
for scene in train_scene_list:
    if scene in normal_scene_list:
        # y = 0
        df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
        crash_frame_id = df[-1:]['frame_id'].item()
        crash_ids = [314159, 314159]
    else:
        # y = 1
        df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
        crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
        crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
    crash_pred_frame_id = crash_frame_id - 1
    if t_prediction+sequence_dist+t_prediction <= crash_pred_frame_id:
        pred_frame_id_list = [crash_pred_frame_id-t_prediction-sequence_dist, crash_pred_frame_id]
    else:
        pred_frame_id_list = [t_prediction-1, crash_pred_frame_id]
    print(scene)
    for pred_frame_id in pred_frame_id_list:
        for frame_id in range(pred_frame_id-t_prediction+1, pred_frame_id+1):
            df_frame = df[df['frame_id'] == frame_id]
            df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
            # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
            df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
            id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
            node_idx_list = list(range(len(id_list)))
            df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                        'van',
                                                                                        'truck',
                                                                                        'motorcycle',
                                                                                        'cyclist',
                                                                                        'pedestrian'
                                                                                        ])), dtype=float)
            df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
            node_features = df_features.to_numpy()
            edge_index = [[],[]]
            edge_features = []
            edge_id = []
            for node_connection in it.combinations(node_idx_list, 2):
                x1 = df_numeric['x_center'].tolist()[node_connection[0]]
                y1 = df_numeric['y_center'].tolist()[node_connection[0]]
                v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
                v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
                x2 = df_numeric['x_center'].tolist()[node_connection[1]]
                y2 = df_numeric['y_center'].tolist()[node_connection[1]]
                v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
                v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
                delta_x = x2 - x1
                delta_y = y2 - y1
                delta_v_x = v_x1 - v_x2
                delta_v_y = v_y1 - v_y2
                if delta_v_x**2+delta_v_y**2 == 0:
                    min_d = math.sqrt(delta_x**2+delta_y**2)
                    t_accident = 0
                else:
                    t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                    if t_accident < 0:
                        min_d = math.sqrt(delta_x**2+delta_y**2)
                    else:
                        min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
                if min_d <= min_d_threshold:
                    edge_index[0].append(node_connection[0])
                    edge_index[0].append(node_connection[1])
                    edge_index[1].append(node_connection[1])
                    edge_index[1].append(node_connection[0])
                    edge_features.append([t_accident, min_d])
                    edge_features.append([t_accident, min_d])
                    edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                    edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
            edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
            if pred_frame_id_list.index(pred_frame_id) == 0:
                y = torch.tensor([0], dtype=torch.long)
            else:
                y = torch.tensor([1], dtype=torch.long)
            data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                        x = torch.tensor(node_features, dtype=torch.float),
                        edge_attr = edge_features_norm,
                        y = y,
                        frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
                        # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                        obj_ids = torch.tensor(id_list, dtype=torch.long),
                        crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                        )
            dataset_train.append(data)

train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_s

In [7]:
group_train = []
# test_scene_list += normal_scene_list[ : len(test_scene_list)]
for accident_class in balanced_all_acc_train:
    sub_class = []
    for scene in accident_class:
        if scene in normal_scene_list:
            # y = 0
            df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
            crash_frame_id = df[-1:]['frame_id'].item()
            crash_ids = [314159, 314159]
        else:
            # y = 1
            df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
            crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
            crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
        crash_pred_frame_id = crash_frame_id - 1
        if t_prediction+sequence_dist+t_prediction <= crash_pred_frame_id:
            pred_frame_id_list = [crash_pred_frame_id-t_prediction-sequence_dist, crash_pred_frame_id]
        else:
            pred_frame_id_list = [t_prediction-1, crash_pred_frame_id]
        print(scene)
        pair_seq = []
        for pred_frame_id in pred_frame_id_list:
            for frame_id in range(pred_frame_id-t_prediction+1, pred_frame_id+1):
                df_frame = df[df['frame_id'] == frame_id]
                df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
                # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
                df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
                id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
                node_idx_list = list(range(len(id_list)))
                df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                            'van',
                                                                                            'truck',
                                                                                            'motorcycle',
                                                                                            'cyclist',
                                                                                            'pedestrian'
                                                                                            ])), dtype=float)
                df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
                node_features = df_features.to_numpy()
                edge_index = [[],[]]
                edge_features = []
                edge_id = []
                for node_connection in it.combinations(node_idx_list, 2):
                    x1 = df_numeric['x_center'].tolist()[node_connection[0]]
                    y1 = df_numeric['y_center'].tolist()[node_connection[0]]
                    v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
                    v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
                    x2 = df_numeric['x_center'].tolist()[node_connection[1]]
                    y2 = df_numeric['y_center'].tolist()[node_connection[1]]
                    v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
                    v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
                    delta_x = x2 - x1
                    delta_y = y2 - y1
                    delta_v_x = v_x1 - v_x2
                    delta_v_y = v_y1 - v_y2
                    if delta_v_x**2+delta_v_y**2 == 0:
                        min_d = math.sqrt(delta_x**2+delta_y**2)
                        t_accident = 0
                    else:
                        t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                        if t_accident < 0:
                            min_d = math.sqrt(delta_x**2+delta_y**2)
                        else:
                            min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
                    if min_d <= min_d_threshold and t_accident <= t_threshold:
                        edge_index[0].append(node_connection[0])
                        edge_index[0].append(node_connection[1])
                        edge_index[1].append(node_connection[1])
                        edge_index[1].append(node_connection[0])
                        edge_features.append([t_accident, min_d])
                        edge_features.append([t_accident, min_d])
                        edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                        edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
                edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
                if pred_frame_id_list.index(pred_frame_id) == 0:
                    y = torch.tensor([0], dtype=torch.long)
                else:
                    y = torch.tensor([1], dtype=torch.long)
                data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                            x = torch.tensor(node_features, dtype=torch.float),
                            edge_attr = edge_features_norm,
                            y = y,
                            frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
                            # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                            obj_ids = torch.tensor(id_list, dtype=torch.long),
                            crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                            )
                pair_seq.append(data)
        sub_class.append(pair_seq)
    group_train.append(sub_class)

train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_subtype0001_scenario00009.txt
train_04_Town04_type001_subtype0001_scenario00033.txt
train_04_Town05_type001_subtype0001_scenario00007.txt
train_09_Town05_type001_subtype0002_scenario00006.txt
train_09_Town03_type001_subtype0002_scenario00027.txt
Mini_ds_Town05_type001_subtype0001_scenario00007.txt
train_01_Town10HD_type001_subtype0002_scenario00005.txt
train_08_Town02_type001_s

In [8]:
dataset_test = []
# test_scene_list += normal_scene_list[ : len(test_scene_list)]
for scene in test_scene_list:
    if scene in normal_scene_list:
        # y = 0
        df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
        crash_frame_id = df[-1:]['frame_id'].item()
        crash_ids = [314159, 314159]
    else:
        # y = 1
        df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
        crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
        crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
    crash_pred_frame_id = crash_frame_id - 1
    if t_prediction+sequence_dist+t_prediction <= crash_pred_frame_id:
        pred_frame_id_list = [crash_pred_frame_id-t_prediction-sequence_dist, crash_pred_frame_id]
    else:
        pred_frame_id_list = [t_prediction-1, crash_pred_frame_id]
    print(scene)
    for pred_frame_id in pred_frame_id_list:
        for frame_id in range(pred_frame_id-t_prediction+1, pred_frame_id+1):
            df_frame = df[df['frame_id'] == frame_id]
            df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
            # x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
            df_numeric_norm = df_numeric / df_numeric.to_numpy().max() 
            id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
            node_idx_list = list(range(len(id_list)))
            df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                        'van',
                                                                                        'truck',
                                                                                        'motorcycle',
                                                                                        'cyclist',
                                                                                        'pedestrian'
                                                                                        ])), dtype=float)
            df_features = pd.concat([df_dummies, df_numeric_norm], axis=1)
            node_features = df_features.to_numpy()
            edge_index = [[],[]]
            edge_features = []
            edge_id = []
            for node_connection in it.combinations(node_idx_list, 2):
                x1 = df_numeric['x_center'].tolist()[node_connection[0]]
                y1 = df_numeric['y_center'].tolist()[node_connection[0]]
                v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
                v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
                x2 = df_numeric['x_center'].tolist()[node_connection[1]]
                y2 = df_numeric['y_center'].tolist()[node_connection[1]]
                v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
                v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
                delta_x = x2 - x1
                delta_y = y2 - y1
                delta_v_x = v_x1 - v_x2
                delta_v_y = v_y1 - v_y2
                if delta_v_x**2+delta_v_y**2 == 0:
                    min_d = math.sqrt(delta_x**2+delta_y**2)
                    t_accident = 0
                else:
                    t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                    if t_accident < 0:
                        min_d = math.sqrt(delta_x**2+delta_y**2)
                    else:
                        min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
                if min_d <= min_d_threshold and t_accident <= t_threshold:
                    edge_index[0].append(node_connection[0])
                    edge_index[0].append(node_connection[1])
                    edge_index[1].append(node_connection[1])
                    edge_index[1].append(node_connection[0])
                    edge_features.append([t_accident, min_d])
                    edge_features.append([t_accident, min_d])
                    edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                    edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
            edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
            if pred_frame_id_list.index(pred_frame_id) == 0:
                y = torch.tensor([0], dtype=torch.long)
            else:
                y = torch.tensor([1], dtype=torch.long)
            data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                        x = torch.tensor(node_features, dtype=torch.float),
                        edge_attr = edge_features_norm,
                        y = y,
                        frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
                        # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                        obj_ids = torch.tensor(id_list, dtype=torch.long),
                        crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                        )
            dataset_test.append(data)

val_01_Town03_type001_subtype0002_scenario00030.txt
train_09_Town03_type001_subtype0001_scenario00014.txt
train_01_Town05_type001_subtype0002_scenario00019.txt
train_04_Town07_type001_subtype0001_scenario00010.txt
train_01_Town05_type001_subtype0002_scenario00012.txt
train_03_Town05_type001_subtype0001_scenario00020.txt
train_06_Town04_type001_subtype0002_scenario00022.txt
val_01_Town04_type001_subtype0002_scenario00025.txt
train_06_Town05_type001_subtype0002_scenario00002.txt
train_03_Town03_type001_subtype0001_scenario00015.txt
train_05_Town05_type001_subtype0001_scenario00005.txt
train_05_Town03_type001_subtype0002_scenario00010.txt
train_02_Town10HD_type001_subtype0001_scenario00021.txt
train_09_Town04_type001_subtype0001_scenario00013.txt
val_01_Town10HD_type001_subtype0001_scenario00020.txt
train_10_Town07_type001_subtype0001_scenario00019.txt
train_04_Town05_type001_subtype0002_scenario00028.txt
train_07_Town04_type001_subtype0001_scenario00005.txt
train_03_Town03_type001_subtyp

In [5]:
# create a list to store all .csv files
train_scene_dir_list = ['../csv_data/train/accident/', '../csv_data/train/normal/']
dataset_train = []
for train_scene_dir in train_scene_dir_list:
  all_scenes = []
  for filename in os.listdir(train_scene_dir):
      if filename.endswith('csv'):
          all_scenes.append(filename)
  for scene in all_scenes:
    df = pd.read_csv(train_scene_dir+scene)
    frame_id_list = list(set(df['frame_id'].tolist()))
    if train_scene_dir == '../csv_data/train/accident/':
      crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
      crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist()
    else:
      crash_frame_id = max(frame_id_list)
      crash_ids = [314159, 314159]
    for frame_id in frame_id_list:
      df_frame = df[df['frame_id'] == frame_id]
      df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                               'van',
                                                                                               'truck',
                                                                                               'motorcycle',
                                                                                               'cyclist',
                                                                                               'pedestrian'])), dtype=float)
      df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
      normalized_df_numeric=(df_numeric - df_numeric.mean()) / df_numeric.std()
      df_features = pd.concat([df_dummies, normalized_df_numeric], axis=1)
      node_features = df_features.to_numpy()
      node_idx_list = list(range(node_features.shape[0]))
      edge_index = [[],[]]
      edge_features = []
      for node_connection in it.combinations(node_idx_list, 2):
        x1 = df_numeric['x_center'].tolist()[node_connection[0]]
        y1 = df_numeric['y_center'].tolist()[node_connection[0]]
        v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
        v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
        x2 = df_numeric['x_center'].tolist()[node_connection[1]]
        y2 = df_numeric['y_center'].tolist()[node_connection[1]]
        v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
        v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
        delta_x = x2 - x1
        delta_y = y2 - y1
        delta_v_x = v_x2 - v_x1
        delta_v_y = v_y2 - v_y1
        if delta_v_x**2+delta_v_y**2 == 0:
          min_d = math.sqrt(delta_x**2+delta_y**2)
          t_accident = -1
        else:
          min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x)/math.sqrt(delta_v_x**2+delta_v_y**2)) # absolute distance of smallest distance
          t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2) # if t_accident is negative, they will never crash.
        if min_d < min_d_threshold and t_accident < t_threshold and t_accident > 0:
          edge_index[0].append(node_connection[0])
          edge_index[0].append(node_connection[1])
          edge_index[1].append(node_connection[1])
          edge_index[1].append(node_connection[0])
          edge_features.append([delta_v_x/delta_x, delta_v_y/delta_y])
          edge_features.append([delta_v_x/delta_x, delta_v_y/delta_y])
      edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)

      obj_id_list = df_frame['obj_id'].to_list()

      if len(edge_index[0]) != 0:
        if frame_id >= crash_frame_id - t_prediction or frame_id < crash_frame_id - t_prediction - 30:
          if crash_ids[0] in obj_id_list and crash_ids[1] in obj_id_list:
            y = torch.tensor([1], dtype=torch.long)
          else:
            y = torch.tensor([0], dtype=torch.long)
          data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                      x = torch.tensor(node_features, dtype=torch.float),
                      edge_attr = edge_features_norm,
                      y = y,
                      frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
                      # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                      obj_ids = torch.tensor(obj_id_list, dtype=torch.long),
                      crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                      )
          dataset_train.append(data)
    print(round(all_scenes.index(scene)/len(all_scenes), 2)*100, '% done.')
    clear_output(wait=True)
print('training dataset completed!')

############################################################################################################################
# test dataset
test_scene_dir_list = ['../csv_data/test/accident/', '../csv_data/test/normal/']
dataset_test = []
for train_scene_dir in test_scene_dir_list:
  all_scenes = []
  for filename in os.listdir(train_scene_dir):
      if filename.endswith('csv'):
          all_scenes.append(filename)
  for scene in all_scenes:
    df = pd.read_csv(train_scene_dir+scene)
    frame_id_list = list(set(df['frame_id'].tolist()))
    if train_scene_dir == '../csv_data/test/accident/':
      crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
      crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist()
    else:
      crash_frame_id = max(frame_id_list)
      crash_ids = [314159, 314159]
    for frame_id in frame_id_list:
      df_frame = df[df['frame_id'] == frame_id]
      df_dummies = pd.get_dummies(df_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                               'van',
                                                                                               'truck',
                                                                                               'motorcycle',
                                                                                               'cyclist',
                                                                                               'pedestrian'])), dtype=float)
      df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
      normalized_df_numeric=(df_numeric - df_numeric.mean()) / df_numeric.std()
      df_features = pd.concat([df_dummies, normalized_df_numeric], axis=1)
      node_features = df_features.to_numpy()
      node_idx_list = list(range(node_features.shape[0]))
      edge_index = [[],[]]
      edge_features = []
      for node_connection in it.combinations(node_idx_list, 2):
        x1 = df_numeric['x_center'].tolist()[node_connection[0]]
        y1 = df_numeric['y_center'].tolist()[node_connection[0]]
        v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
        v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
        x2 = df_numeric['x_center'].tolist()[node_connection[1]]
        y2 = df_numeric['y_center'].tolist()[node_connection[1]]
        v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
        v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
        delta_x = x2 - x1
        delta_y = y2 - y1
        delta_v_x = v_x2 - v_x1
        delta_v_y = v_y2 - v_y1
        if delta_v_x**2+delta_v_y**2 == 0:
          min_d = math.sqrt(delta_x**2+delta_y**2)
          t_accident = -1
        else:
          min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x)/math.sqrt(delta_v_x**2+delta_v_y**2)) # absolute distance of smallest distance
          t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2) # if t_accident is negative, they will never crash.
        if min_d < min_d_threshold and t_accident < t_threshold and t_accident > 0:
          edge_index[0].append(node_connection[0])
          edge_index[0].append(node_connection[1])
          edge_index[1].append(node_connection[1])
          edge_index[1].append(node_connection[0])
          edge_features.append([delta_v_x/delta_x, delta_v_y/delta_y])
          edge_features.append([delta_v_x/delta_x, delta_v_y/delta_y])
      edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)

      obj_id_list = df_frame['obj_id'].to_list()

      if len(edge_index[0]) != 0:
        if frame_id >= crash_frame_id - t_prediction or frame_id < crash_frame_id - t_prediction - 30:
          if crash_ids[0] in obj_id_list and crash_ids[1] in obj_id_list:
            y = torch.tensor([1], dtype=torch.long)
          else:
            y = torch.tensor([0], dtype=torch.long)
          data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                      x = torch.tensor(node_features, dtype=torch.float),
                      edge_attr = edge_features_norm,
                      y = y,
                      frame_to_crash = torch.tensor(crash_frame_id-frame_id, dtype=torch.float),
                      # crash_node_idx = torch.tensor([crash_node_idx], dtype=torch.long)
                      obj_ids = torch.tensor(obj_id_list, dtype=torch.long),
                      crash_ids = torch.tensor(crash_ids, dtype=torch.long)
                      )
          dataset_test.append(data)
    print(round(all_scenes.index(scene)/len(all_scenes), 2)*100, '% done.')
    clear_output(wait=True)
print('testing dataset completed!')

testing dataset completed!


### Classifier Training

In [9]:
args = {
    "device" : torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "save" : "Single_Graph_BC_Reg_model",
    "load" : "ssl_model",
    "lr" : 0.0003,
    "lr_reg" : 0.00001,
    "epochs" : 200,
    "batch_size" : 64,
    "num_workers" : 2,
    "model" : "gat", # choices are ["gcn", "gin", "resgcn", "gat", "graphsage", "sgc"]
    "input_dim" : 12, # input dimension for node features
    # "num_classes" : len(all_names), # collision
    "num_classes" : 1,
    "feat_dim" : 8,
    "layers" : 3,
    "train_data_percent" : 0.8,
}

class AttributeDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttributeDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

args = AttributeDict(args)

In [11]:
# train_dataset = dataset_train
grouped_train_dataset = group_train
val_dataset = dataset_test
test_dataset = dataset_test
# train_loader = DataLoader(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True)

# print("Dataset split: {} {} {}".format(len(train_dataset), len(val_dataset), len(val_dataset)))
print("Number of classes: {}".format(args.num_classes))

Number of classes: 1


In [14]:
import importlib, models_Single_Graph_BC_Reg
importlib.reload(models_Single_Graph_BC_Reg)
from models_Single_Graph_BC_Reg import *

# classification model is a GNN encoder followed by linear layer
model = GraphClassificationModel(args.input_dim, args.feat_dim, n_layers=args.layers, gnn=args.model, output_dim=args.num_classes
                                # , load=args.load
                                 )

for param in model.parameters():
    param.requires_grad = True

model = model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer_reg = torch.optim.Adam(model.parameters(), lr=args.lr_reg)

In [15]:
model.train()
loss_fn = torch.nn.BCEWithLogitsLoss()
loss_mse = torch.nn.MSELoss()

best_train_loss, best_val_loss = float("inf"), float("inf")
is_best_loss = False
best_val_acc = 0
for epoch in range(args.epochs):
    losses = []
    edge_losses = []
    reg_losses = []
    correct = 0
    edge_correct = 0
    frame_error = 0

    for acc_type in grouped_train_dataset:
        random.shuffle(acc_type)
    train_dataset = []
    for i in range(major_acc_num):
        for acc_type in grouped_train_dataset:
            train_dataset += acc_type[i]
    flattened = train_dataset   # (15+15)*36*8
    train_loader = DataLoader(flattened, num_workers=args.num_workers, batch_size=(15+15)*8, shuffle=False)

    for data in train_loader:
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr
        labels = data.y
        
        regression_score = model(data_input)[1][data.y == 1]
        regression_gt = data.frame_to_crash[data.y == 1]
        loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/(t_prediction*10)
        frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
        reg_losses.append(loss_reg.item())
        optimizer.zero_grad()
        loss_reg.backward()
        optimizer.step()

        gt = labels.float()
        scores = model(data_input)[0].squeeze(1)
        loss = loss_fn(scores, gt)
        pred = []
        for pr in scores:
            if torch.sigmoid(pr)>0.5:
                pred.append(1)
            else:
                pred.append(0)
        correct += int((torch.tensor(pred).to(args.device) == gt).sum()) / gt.size(0)
        losses.append(loss.item())

        # backprop
        optimizer.zero_grad()
        # (loss+loss_reg).backward()
        loss.backward()
        # loss_edge.backward()
        # loss_reg.backward()
        optimizer.step()


    epoch_loss_train = sum(losses) / len(losses)
    # epoch_edge_loss_train = sum(edge_losses) / len(edge_losses)
    epoch_reg_loss_train = sum(reg_losses) / len(reg_losses)
    epoch_acc_train = correct / len(train_loader)
    # epoch_edge_acc_train = edge_correct / len(train_loader)
    epoch_reg_error_train = frame_error / len(train_loader)


    # validation
    losses = []
    edge_losses = []
    reg_losses = []
    correct = 0
    edge_correct = 0
    frame_error = 0

    for data in val_loader:
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr
        labels = data.y
        gt = labels.float()
        scores = model(data_input)[0].squeeze(1)
        loss = loss_fn(scores, gt)
        pred = []
        for pr in scores:
            if torch.sigmoid(pr)>0.5:
                pred.append(1)
            else:
                pred.append(0)
        correct += int((torch.tensor(pred).to(args.device) == gt).sum()) / gt.size(0)
        losses.append(loss.item())
        regression_score = model(data_input)[1][data.y == 1]
        regression_gt = data.frame_to_crash[data.y == 1]
        loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/(t_prediction*10)
        frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
        reg_losses.append(loss_reg.item())


    # gather the results for the epoch
    epoch_loss_val = sum(losses) / len(losses)
    epoch_acc_val = correct / len(val_loader)
    # epoch_edge_loss_val = sum(edge_losses) / len(edge_losses)
    # epoch_edge_acc_val = edge_correct / len(val_loader)
    epoch_reg_loss_val = sum(reg_losses) / len(reg_losses)
    epoch_reg_error_val = frame_error / len(val_loader)


    log = "Epoch {},   Train Loss: {:.3f},            Train Accuracy: {:.3f},            Val Loss: {:.3f}, Val Accuracy: {:.3f}, \nTrain Regression Loss: {:.3f}, Train Regression Error: {:.3f}, Val Regression Loss: {:.3f}, Val Regression Error: {:.3f}\n"
    print(log.format(epoch, 
                     epoch_loss_train, epoch_acc_train, epoch_loss_val, epoch_acc_val, 
                     epoch_reg_loss_train, epoch_reg_error_train, epoch_reg_loss_val, epoch_reg_error_val))

    # log = "Epoch {},   Train Loss: {:.3f},            Train Accuracy: {:.3f},            Val Loss: {:.3f}, Val Accuracy: {:.3f}, \n      Train Edge Loss: {:.3f},       Train Edge Accuracy: {:.3f},       Val Edge Loss: {:.3f}, Val Edge Accuracy: {:.3f},\nTrain Regression Loss: {:.3f}, Train Regression Accuracy: {:.3f}, Val Regression Loss: {:.3f}, Val Regression Accuracy: {:.3f}\n"
    # print(log.format(epoch, 
    #                  epoch_loss_train, epoch_acc_train, epoch_loss_val, epoch_acc_val, 
    #                  epoch_edge_loss_train, epoch_edge_acc_train, epoch_edge_loss_val, epoch_edge_acc_val,
    #                  epoch_reg_loss_train, epoch_reg_error_train, epoch_reg_loss_val, epoch_reg_error_val))

    if epoch_acc_val > best_val_acc:
        best_epoch, best_train_graph_loss, best_val_loss, is_best_loss = epoch, epoch_reg_loss_train, epoch_reg_loss_val, True
        best_train_acc, best_val_acc = epoch_acc_train, epoch_acc_val

# model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_graph_loss, best_val_loss, is_best_loss)
print('best Epoch: ', best_epoch, ', best total train loss: ', best_train_graph_loss, ', best total val loss: ', best_val_loss)
print('graph classification train acc: ', epoch_acc_train, 'graph classification val acc: ', epoch_acc_val)
# print('edge classification train acc: ', epoch_edge_acc_train, 'edge classification val acc: ', epoch_edge_acc_val)
print('regression train error: ', epoch_reg_error_train, 'regression test error: ', epoch_reg_error_val)

Epoch 0,   Train Loss: 0.699,            Train Accuracy: 0.520,            Val Loss: 0.689, Val Accuracy: 0.526, 
Train Regression Loss: 0.059, Train Regression Error: 7.717, Val Regression Loss: 0.048, Val Regression Error: 5.998

Epoch 1,   Train Loss: 0.681,            Train Accuracy: 0.575,            Val Loss: 0.680, Val Accuracy: 0.579, 
Train Regression Loss: 0.042, Train Regression Error: 5.219, Val Regression Loss: 0.034, Val Regression Error: 4.285

Epoch 2,   Train Loss: 0.672,            Train Accuracy: 0.577,            Val Loss: 0.685, Val Accuracy: 0.529, 
Train Regression Loss: 0.035, Train Regression Error: 4.315, Val Regression Loss: 0.033, Val Regression Error: 4.155

Epoch 3,   Train Loss: 0.662,            Train Accuracy: 0.604,            Val Loss: 0.679, Val Accuracy: 0.551, 
Train Regression Loss: 0.034, Train Regression Error: 4.273, Val Regression Loss: 0.033, Val Regression Error: 4.105

Epoch 4,   Train Loss: 0.654,            Train Accuracy: 0.611,         

best Epoch:  173 , best total train loss:  0.027313787231428757 , best total val loss:  0.0294601630200358
graph classification train acc:  0.8805555555555555 graph classification val acc:  0.7162047511312217
regression train error:  3.4533276381316007 regression test error:  3.7947435231917686

Epoch 49,   Train Loss: 0.416,            Train Accuracy: 0.835,            Val Loss: 0.696, Val Accuracy: 0.680, 
Train Regression Loss: 0.030, Train Regression Error: 3.824, Val Regression Loss: 0.029, Val Regression Error: 3.818

Epoch 89,   Train Loss: 0.270,            Train Accuracy: 0.914,            Val Loss: 0.853, Val Accuracy: 0.669, 
Train Regression Loss: 0.029, Train Regression Error: 3.717, Val Regression Loss: 0.029, Val Regression Error: 3.754

In [None]:
model.train()
loss_fn = torch.nn.BCEWithLogitsLoss()
loss_mse = torch.nn.MSELoss()

for epoch in range(args.epochs):
    best_train_loss, best_val_loss = float("inf"), float("inf")
    losses = []
    correct = 0

    for data in train_loader:
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr

        labels = data.y
        gt = []
        for label in labels:
            if label == 0:
                gt.append([1,0])
            else:
                gt.append([0,1])
        gt = torch.tensor(gt, dtype=torch.float32).to(args.device)
        scores = model(data_input)[0]
        loss = loss_fn(scores, gt)
        pred = scores.argmax(dim=1)
        correct += int((pred == labels).sum()) / labels.size(0)
        losses.append(loss.item())
 
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss_train = sum(losses) / len(losses)
    epoch_acc_train = correct / len(train_loader)

    # validation
    losses = []
    correct = 0

    for data in val_loader:
        data.to(args.device)

        data_input = data.x, data.edge_index, data.batch, data.edge_attr

        labels = data.y

        gt = []
        for label in labels:
            if label == 0:
                gt.append([1,0])
            else:
                gt.append([0,1])
        gt = torch.tensor(gt, dtype=torch.float32).to(args.device)


        # get class scores from model
        scores = model(data_input)[0]

        loss = loss_fn(scores, gt)

        # Keep track of loss and accuracy
        pred = scores.argmax(dim=1)
        correct += int((pred == labels).sum()) / labels.size(0)
        losses.append(loss.item())

    # gather the results for the epoch
    epoch_loss_val = sum(losses) / len(losses)
    epoch_acc_val = correct / len(val_loader)

    log = "Epoch {},   Train Loss: {:.3f},            Train Accuracy: {:.3f},            Val Loss: {:.3f}, Val Accuracy: {:.3f}"
    print(log.format(epoch, 
                     epoch_loss_train, epoch_acc_train, epoch_loss_val, epoch_acc_val))

    is_best_loss = False
    if epoch_loss_val < best_val_loss:
        best_epoch, best_train_graph_loss, best_val_loss, is_best_loss = epoch, epoch_loss_train, epoch_loss_val, True

model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_graph_loss, best_val_loss, is_best_loss)
print('best Epoch: ', best_epoch, ', best total train loss: ', best_train_graph_loss, ', best total val loss: ', best_val_loss)

In [42]:
best_epoch, best_train_loss, best_val_loss = model.load_checkpoint(os.path.join("logs", args.save), optimizer)
model.eval()

losses = []
edge_losses = []
reg_losses = []
correct = 0
edge_correct = 0
frame_error = 0

for data in test_loader:
    data.to(args.device)

    data_input = data.x, data.edge_index, data.batch, data.edge_attr

    labels = data.y

    gt = []
    for label in labels:
        if label == 0:
            gt.append([1,0])
        else:
            gt.append([0,1])
    gt = torch.tensor(gt, dtype=torch.float32).to(args.device)


    # get class scores from model
    scores = model(data_input)[0]

    loss = loss_fn(scores, gt)

    # Keep track of loss and accuracy
    pred = scores.argmax(dim=1)
    correct += int((pred == labels).sum()) / labels.size(0)
    losses.append(loss.item())


    batch_node_label = []
    for data_point in data.batch:
        if data.y[data_point] == 0:
            batch_node_label.append(False)
        else:
            batch_node_label.append(True)
    batch_node_label = torch.tensor(batch_node_label, dtype=torch.bool).to(args.device)
    batch_edge_label = batch_node_label[data.edge_index[0]]
    edge_score = model(data_input)[1][batch_edge_label]
    gt_edge = []
    batch_edge_index = data.obj_ids[data.edge_index[0]][batch_edge_label]
    for i in range(edge_score.size(0)):
        if batch_edge_index[i] in data.crash_ids:
            gt_edge.append([0,1])
        else:
            gt_edge.append([1,0])
    gt_edge = torch.tensor(gt_edge, dtype=torch.float32).to(args.device)
    loss_edge = loss_fn(edge_score, gt_edge)
    edge_losses.append(loss_edge.item())
    pred_edge = edge_score.argmax(dim=1)
    labels_edge = gt_edge.argmax(dim=1)
    edge_correct += int((pred_edge == labels_edge).sum()) / labels_edge.size(0)


    regression_score = model(data_input)[2][data.y == 1]
    regression_gt = data.frame_to_crash[data.y == 1]
    loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/t_prediction
    frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
    reg_losses.append(loss_reg.item())

# gather the results for the epoch
loss_test = sum(losses) / len(losses)
correct_test = correct / len(test_loader)
epoch_edge_loss_test = sum(edge_losses) / len(edge_losses)
epoch_edge_acc_test = edge_correct / len(test_loader)
epoch_reg_loss_test = sum(reg_losses) / len(reg_losses)
epoch_reg_error_test = frame_error / len(test_loader)

print("Test Loss at epoch {}: {:.3f}, Test Accuracy: {:.3f}, \nTest Edge Loss: {:.3f}, Test Edge Accuracy: {:.3f} \nTest Regression Loss: {:.3f}, Test Regression Error: {:.3f}".format(best_epoch, loss_test, correct_test, epoch_edge_loss_test, epoch_edge_acc_test, epoch_reg_loss_test, epoch_reg_error_test))

Test Loss at epoch 99: 2.376, Test Accuracy: 0.720, 
Test Edge Loss: 0.604, Test Edge Accuracy: 0.821 
Test Regression Loss: 0.301, Test Regression Error: 7.256
