In [1]:
import torch
import random
import numpy as np
from itertools import chain
import torch.nn as nn
import itertools as it
import os, shutil
import pandas as pd
from pandas.core.common import flatten
from scipy.interpolate import make_interp_spline, interp1d
import matplotlib.pyplot as plt
from IPython.display import clear_output
import math
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
%cd GraphSSL

/home/shihan/GAT/GraphSSL


### Customized Dataset

In [21]:
random.seed(42)
t_threshold = 0.5  # in seconds
min_d_threshold = 10  # in meters
t_prediction = 15 # in frames
# if t_prediction == 5:
#     lr = 0.0003
# elif t_prediction == 10 or t_prediction == 15:
#     lr = 0.0005
sequence_len = 4 # in frames
sequence_dist = 40 # in frames

In [22]:
info_dir = '../csv_final/info/'   # directory of 'info' folder
all_info = []
for filename in os.listdir(info_dir):
    if filename.endswith('txt'):
        all_info.append(filename)
pedestrian = []
rear_end = []
switch_lane = [] # switch_lane
t_bone = []
opposite_front = []
opposite_merge = []
merging = []
other = []
for info in all_info:
    txt_file = open(info_dir+info, "r").read().split('\n') # split txt file with lines
    first_cls = txt_file[2].split(': ')[-1]
    second_cls = txt_file[9].split(': ')[-1]
    first_fault = txt_file[7].split(' ')[-1]
    second_fault = txt_file[14].split(' ')[-1]
    first_direction = txt_file[5].split(' ')[-1]
    second_direction = txt_file[12].split(' ')[-1]
    first_collision_side = txt_file[6].split(' ')[-1]
    second_collision_side = txt_file[13].split(' ')[-1]
    second_relative_direction_to_first = txt_file[15].split(' ')[-1]
    if first_cls == 'pedestrian' or second_cls == 'pedestrian':
        pedestrian.append(info)
    elif first_fault == 'rear_end' or second_fault == 'rear_end':
        rear_end.append(info)
    elif first_fault == 'switch_lane' or second_fault == 'switch_lane':
        switch_lane.append(info)
    elif first_direction == 'straight' and second_direction == 'straight':
        t_bone.append(info)
    elif second_relative_direction_to_first == 'opposite' and 'straight' in [first_direction, second_direction]:
        opposite_front.append(info)
    elif second_relative_direction_to_first == 'opposite':
        opposite_merge.append(info)
    elif first_fault == 'entrance_ramp' or second_fault == 'entrance_ramp':
        merging.append(info)
    elif second_relative_direction_to_first=='right' and first_direction=='left' and second_direction=='straight':
        merging.append(info)
    elif second_relative_direction_to_first=='left' and first_direction=='straight' and second_direction=='left':
        merging.append(info)
    elif second_relative_direction_to_first=='right' and first_direction=='straight' and second_direction=='right':
        merging.append(info)
    elif second_relative_direction_to_first=='left' and first_direction=='right' and second_direction=='straight':
        merging.append(info)
    else:
        other.append(info)
random.shuffle(pedestrian)
random.shuffle(rear_end)
random.shuffle(switch_lane)
random.shuffle(t_bone)
random.shuffle(opposite_front)
random.shuffle(opposite_merge)
random.shuffle(merging)
random.shuffle(other)
acc_all = {'pedestrian':len(pedestrian), 'rear_end':len(rear_end),
        'switch_lane':len(switch_lane), 't_bone':len(t_bone),
        'opposite_front':len(opposite_front), 'opposite_merge': len(opposite_merge),
        'merging':len(merging), 'other':len(other),
        'total': len(pedestrian+rear_end+switch_lane+t_bone+opposite_front+opposite_merge+merging+other)}
print(acc_all)
acc_train = {'pedestrian':math.ceil(len(pedestrian)*0.8), 'rear_end':math.ceil(len(rear_end)*0.8),
        'switch_lane':math.ceil(len(switch_lane)*0.8), 't_bone':math.ceil(len(t_bone)*0.8),
        'opposite_front':math.ceil(len(opposite_front)*0.8), 'opposite_merge': math.ceil(len(opposite_merge)*0.8),
        'merging':math.ceil(len(merging)*0.8), 'other':math.ceil(len(other)*0.8)}
acc_test = {'pedestrian':int(len(pedestrian)/5), 'rear_end':int(len(rear_end)/5),
        'switch_lane':int(len(switch_lane)/5), 't_bone':int(len(t_bone)/5),
        'opposite_front':int(len(opposite_front)/5), 'opposite_merge': int(len(opposite_merge)/5),
        'merging':int(len(merging)/5), 'other':int(len(other)/5)
        }
print('training dataset:')
print(acc_train)
print('testing dataset:')
print(acc_test)

{'pedestrian': 8, 'rear_end': 19, 'switch_lane': 13, 't_bone': 31, 'opposite_front': 22, 'opposite_merge': 19, 'merging': 40, 'other': 44, 'total': 196}
training dataset:
{'pedestrian': 7, 'rear_end': 16, 'switch_lane': 11, 't_bone': 25, 'opposite_front': 18, 'opposite_merge': 16, 'merging': 32, 'other': 36}
testing dataset:
{'pedestrian': 1, 'rear_end': 3, 'switch_lane': 2, 't_bone': 6, 'opposite_front': 4, 'opposite_merge': 3, 'merging': 8, 'other': 8}


In [24]:
accident_dir = '../csv_final/accident/'
normal_dir = '../csv_final/normal/'
train_scene_list = []
test_scene_list = []
accident_names = ['pedestrian','rear_end','switch_lane','t_bone','opposite_front','opposite_merge','merging','other']
all_names = accident_names+['no_accident']
major_acc_num = 36
accident_type_all = [pedestrian,rear_end,switch_lane,t_bone,opposite_front,opposite_merge,merging,other]
balanced_all_acc_train = []
for accident_type in accident_type_all:
    random.shuffle(accident_type)
    test_scene_list += accident_type[ : int(len(accident_type)/5)]
    train_num_accident_type = math.ceil(len(accident_type)*0.8)
    minor_class_multiplier = int(major_acc_num/train_num_accident_type)
    balanced_all_acc_train.append(accident_type[int(len(accident_type)/5):]*minor_class_multiplier 
                                  + accident_type[int(len(accident_type)/5):][:major_acc_num-math.ceil(len(accident_type)*0.8)*minor_class_multiplier])
for i in range(major_acc_num):
    for accident_type in balanced_all_acc_train:
        train_scene_list.append(accident_type[i])
#     train_scene_list += accident_type[int(len(accident_type)/5) : ]
normal_scene_list = []
for filename in os.listdir(normal_dir):
    if filename.endswith('csv'):
        normal_scene_list.append(filename)
random.shuffle(normal_scene_list)
# train_scene_list += normal_scene_list[len(test_scene_list) : len(test_scene_list)+len(train_scene_list)]
# test_scene_list += normal_scene_list[ : len(test_scene_list)]

In [25]:
grouped_train = []  # (8, 36, 2*sequence_len)
for accident_class in balanced_all_acc_train:
    sub_class = []
    for scene in accident_class:
        if scene in normal_scene_list:
            y = 'no_accident'
            df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
            crash_frame_id = df[-1:]['frame_id'].item()
            crash_ids = [314159, 314159]
        else:
            for accident_type in accident_type_all:
                if scene in accident_type:
                    y = accident_names[accident_type_all.index(accident_type)]
            df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
            crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
            crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
        frame_id_list = list(set(df['frame_id'].tolist()))  # all frame ids in a list
        crash_pred_frame_id = crash_frame_id - t_prediction
        # print(scene)
        edge_count = 0
        if sequence_len+sequence_dist+sequence_len <= crash_pred_frame_id:
            pred_frame_id_list = [crash_pred_frame_id-sequence_len-sequence_dist, crash_pred_frame_id]
        else:
            pred_frame_id_list = [sequence_len, crash_pred_frame_id]
        pair_seq = []
        for pred_frame_id in pred_frame_id_list:
            df_pred_frame = df[df['frame_id'] == pred_frame_id]
            pred_frame_id_list = df_pred_frame['obj_id'].tolist()  # all node ids in predicted frame
            sq_x_numeric = torch.tensor([]) # (# of nodes, 4, 6)
            for frame_id in list(range(pred_frame_id-sequence_len+1, pred_frame_id+1)):
                df_frame = df[df['frame_id'] == frame_id]
                df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
                x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
                id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
                frame_x_numeric = torch.tensor([])  # (# of nodes, 6)
                for node_id in pred_frame_id_list:
                    if node_id in id_list:
                        frame_x_numeric = torch.cat([frame_x_numeric, x_numeric[id_list.index(node_id)].unsqueeze(0)], 0)
                    else:
                        frame_x_numeric = torch.cat([frame_x_numeric, torch.full([1, 6], torch.nan)], 0)
                sq_x_numeric = torch.cat([sq_x_numeric, frame_x_numeric.unsqueeze(1)], 1)
            sq_x_numeric_normed = sq_x_numeric / torch.max(sq_x_numeric)    # (# of nodes, 4, 6)
            df_dummies = pd.get_dummies(df_pred_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                                'van',
                                                                                                'truck',
                                                                                                'motorcycle',
                                                                                                'cyclist',
                                                                                                'pedestrian'
                                                                                                ])), dtype=float)
            node_idx_list = list(range(len(id_list)))
            edge_index = [[],[]]
            edge_features = []
            edge_id = []
            for node_connection in it.combinations(node_idx_list, 2):
                x1 = df_numeric['x_center'].tolist()[node_connection[0]]
                y1 = df_numeric['y_center'].tolist()[node_connection[0]]
                v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
                v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
                x2 = df_numeric['x_center'].tolist()[node_connection[1]]
                y2 = df_numeric['y_center'].tolist()[node_connection[1]]
                v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
                v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
                delta_x = x2 - x1
                delta_y = y2 - y1
                delta_v_x = v_x1 - v_x2
                delta_v_y = v_y1 - v_y2
                if delta_v_x**2+delta_v_y**2 == 0:
                    min_d = math.sqrt(delta_x**2+delta_y**2)
                    t_accident = 0
                else:
                    t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                    if t_accident < 0:
                        min_d = math.sqrt(delta_x**2+delta_y**2)
                    else:
                        min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
                if min_d <= min_d_threshold and t_accident <= t_threshold:
                    edge_index[0].append(node_connection[0])
                    edge_index[0].append(node_connection[1])
                    edge_index[1].append(node_connection[1])
                    edge_index[1].append(node_connection[0])
                    edge_features.append([t_accident, min_d])
                    edge_features.append([t_accident, min_d])
                    edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                    edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
            edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
            if pred_frame_id == crash_pred_frame_id:
                y = 1
                if crash_ids in edge_id:
                    print('Crash Edge Built!')
                else:
                    print('Crash Edge Not Built!')
            else:
                y = 0
            data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
                x_numeric = sq_x_numeric_normed,
                x = torch.tensor(df_dummies.to_numpy(), dtype=torch.float),    # x_dummies
                edge_attr = edge_features_norm,
                y = torch.tensor([y], dtype=torch.long)
                )
            pair_seq.append(data)
        sub_class.append(pair_seq)
    grouped_train.append(sub_class)

Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Ed

In [26]:
paired_test = []
for scene in test_scene_list:
    if scene in normal_scene_list:
        y = 'no_accident'
        df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
        crash_frame_id = df[-1:]['frame_id'].item()
        crash_ids = [314159, 314159]
    else:
        for accident_type in accident_type_all:
            if scene in accident_type:
                y = accident_names[accident_type_all.index(accident_type)]
        df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
        crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
        crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
    frame_id_list = list(set(df['frame_id'].tolist()))  # all frame ids in a list
    crash_pred_frame_id = crash_frame_id - t_prediction
    # print(scene)
    edge_count = 0
    if sequence_len+sequence_dist+sequence_len <= crash_pred_frame_id:
        pred_frame_id_list = [crash_pred_frame_id-sequence_len-sequence_dist, crash_pred_frame_id]
    else:
        pred_frame_id_list = [sequence_len, crash_pred_frame_id]
    pair_seq = []
    for pred_frame_id in pred_frame_id_list:
        df_pred_frame = df[df['frame_id'] == pred_frame_id]
        pred_frame_id_list = df_pred_frame['obj_id'].tolist()  # all node ids in predicted frame
        sq_x_numeric = torch.tensor([]) # (# of nodes, 4, 6)
        for frame_id in list(range(pred_frame_id-sequence_len+1, pred_frame_id+1)):
            df_frame = df[df['frame_id'] == frame_id]
            df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
            x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
            id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
            frame_x_numeric = torch.tensor([])  # (# of nodes, 6)
            for node_id in pred_frame_id_list:
                if node_id in id_list:
                    frame_x_numeric = torch.cat([frame_x_numeric, x_numeric[id_list.index(node_id)].unsqueeze(0)], 0)
                else:
                    frame_x_numeric = torch.cat([frame_x_numeric, torch.full([1, 6], torch.nan)], 0)
            sq_x_numeric = torch.cat([sq_x_numeric, frame_x_numeric.unsqueeze(1)], 1)
        sq_x_numeric_normed = sq_x_numeric / torch.max(sq_x_numeric)    # (# of nodes, 4, 6)
        df_dummies = pd.get_dummies(df_pred_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                            'van',
                                                                                            'truck',
                                                                                            'motorcycle',
                                                                                            'cyclist',
                                                                                            'pedestrian'
                                                                                            ])), dtype=float)
        node_idx_list = list(range(len(id_list)))
        edge_index = [[],[]]
        edge_features = []
        edge_id = []
        for node_connection in it.combinations(node_idx_list, 2):
            x1 = df_numeric['x_center'].tolist()[node_connection[0]]
            y1 = df_numeric['y_center'].tolist()[node_connection[0]]
            v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
            v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
            x2 = df_numeric['x_center'].tolist()[node_connection[1]]
            y2 = df_numeric['y_center'].tolist()[node_connection[1]]
            v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
            v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
            delta_x = x2 - x1
            delta_y = y2 - y1
            delta_v_x = v_x1 - v_x2
            delta_v_y = v_y1 - v_y2
            if delta_v_x**2+delta_v_y**2 == 0:
                min_d = math.sqrt(delta_x**2+delta_y**2)
                t_accident = 0
            else:
                t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
                if t_accident < 0:
                    min_d = math.sqrt(delta_x**2+delta_y**2)
                else:
                    min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
            if min_d <= min_d_threshold and t_accident <= t_threshold:
                edge_index[0].append(node_connection[0])
                edge_index[0].append(node_connection[1])
                edge_index[1].append(node_connection[1])
                edge_index[1].append(node_connection[0])
                edge_features.append([t_accident, min_d])
                edge_features.append([t_accident, min_d])
                edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
                edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
        edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
        if pred_frame_id == crash_pred_frame_id:
            y = 1
            if crash_ids in edge_id:
                print('Crash Edge Built!')
            else:
                print('Crash Edge Not Built!')
        else:
            y = 0
        data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
            x_numeric = sq_x_numeric_normed,
            x = torch.tensor(df_dummies.to_numpy(), dtype=torch.float),    # x_dummies
            edge_attr = edge_features_norm,
            y = torch.tensor([y], dtype=torch.long)
            )
        pair_seq.append(data)
    paired_test.append(pair_seq)

Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!


In [57]:
dataset_test = []   # (70, 1)
test_scene_list += normal_scene_list[ : len(test_scene_list)]
for scene in test_scene_list:
    if scene in normal_scene_list:
        y = 0
        df = pd.read_csv(normal_dir+scene.replace('.txt','.csv'))  # read the i-th csv file
        crash_frame_id = df[-1:]['frame_id'].item()
        crash_ids = [314159, 314159]
    else:
        y = 1
        df = pd.read_csv(accident_dir+scene.replace('.txt','.csv'))
        crash_frame_id = min(df[df['acc_inv'] == 1]['frame_id'].tolist())
        crash_ids = df[df['acc_inv'] == 1]['obj_id'].tolist() 
    pred_frame_id = crash_frame_id - t_prediction
    # print(scene)
    df_pred_frame = df[df['frame_id'] == pred_frame_id]
    pred_frame_id_list = df_pred_frame['obj_id'].tolist()  # all node ids in predicted frame
    sq_x_numeric = torch.tensor([]) # (# of nodes, 4, 6)
    for frame_id in list(range(pred_frame_id-sequence_len+1, pred_frame_id+1)):
        df_frame = df[df['frame_id'] == frame_id]
        df_numeric = df_frame[['x_center', 'y_center', 'bbox_x', 'bbox_y', 'vel_x', 'vel_y']]
        x_numeric = torch.tensor(df_numeric.to_numpy(), dtype=torch.float)
        id_list = df_frame['obj_id'].tolist()  # all node ids in predicted frame
        frame_x_numeric = torch.tensor([])  # (# of nodes, 6)
        for node_id in pred_frame_id_list:
            if node_id in id_list:
                frame_x_numeric = torch.cat([frame_x_numeric, x_numeric[id_list.index(node_id)].unsqueeze(0)], 0)
            else:
                frame_x_numeric = torch.cat([frame_x_numeric, torch.full([1, 6], torch.nan)], 0)
        sq_x_numeric = torch.cat([sq_x_numeric, frame_x_numeric.unsqueeze(1)], 1)
    sq_x_numeric_normed = sq_x_numeric / torch.max(sq_x_numeric)    # (# of nodes, 4, 6)
    df_dummies = pd.get_dummies(df_pred_frame[['obj_cls']].astype(pd.CategoricalDtype(categories=['car',
                                                                                        'van',
                                                                                        'truck',
                                                                                        'motorcycle',
                                                                                        'cyclist',
                                                                                        'pedestrian'
                                                                                        ])), dtype=float)
    node_idx_list = list(range(len(id_list)))
    edge_index = [[],[]]
    edge_features = []
    edge_id = []
    for node_connection in it.combinations(node_idx_list, 2):
        x1 = df_numeric['x_center'].tolist()[node_connection[0]]
        y1 = df_numeric['y_center'].tolist()[node_connection[0]]
        v_x1 = df_numeric['vel_x'].tolist()[node_connection[0]]
        v_y1 = df_numeric['vel_y'].tolist()[node_connection[0]]
        x2 = df_numeric['x_center'].tolist()[node_connection[1]]
        y2 = df_numeric['y_center'].tolist()[node_connection[1]]
        v_x2 = df_numeric['vel_x'].tolist()[node_connection[1]]
        v_y2 = df_numeric['vel_y'].tolist()[node_connection[1]]
        delta_x = x2 - x1
        delta_y = y2 - y1
        delta_v_x = v_x1 - v_x2
        delta_v_y = v_y1 - v_y2
        if delta_v_x**2+delta_v_y**2 == 0:
            min_d = math.sqrt(delta_x**2+delta_y**2)
            t_accident = 0
        else:
            t_accident = (delta_x*delta_v_x+delta_y*delta_v_y)/(delta_v_x**2+delta_v_y**2)
            if t_accident < 0:
                min_d = math.sqrt(delta_x**2+delta_y**2)
            else:
                min_d = abs((delta_x*delta_v_y-delta_y*delta_v_x))/math.sqrt(delta_v_x**2+delta_v_y**2)
        if min_d <= min_d_threshold:
            edge_index[0].append(node_connection[0])
            edge_index[0].append(node_connection[1])
            edge_index[1].append(node_connection[1])
            edge_index[1].append(node_connection[0])
            edge_features.append([t_accident, min_d])
            edge_features.append([t_accident, min_d])
            edge_id.append([id_list[node_connection[0]], id_list[node_connection[1]]])
            edge_id.append([id_list[node_connection[1]], id_list[node_connection[0]]])
    edge_features_norm = nn.functional.normalize(torch.tensor(edge_features, dtype=torch.float), p=1.0, dim = 0)
    data = Data(edge_index = torch.tensor(edge_index, dtype=torch.long),
        x_numeric = sq_x_numeric_normed,
        x = torch.tensor(df_dummies.to_numpy(), dtype=torch.float),    # x_dummies
        edge_attr = edge_features_norm,
        y = torch.tensor([y], dtype=torch.long),
        )
    if crash_ids in edge_id:
        print('Crash Edge Built!')
    else:
        print('Crash Edge Not Built!')
    dataset_test.append(data)

Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Built!
Crash Edge Not Bui

### Classifier Training

In [27]:
args = {
    "device" : torch.device("cuda:3" if torch.cuda.is_available() else "cpu"),
    "save" : "LSTM_Node_GAT_model",
    "load" : "ssl_model",
    "lr" : 0.0004,
    "lr_edge" : 0.001,
    "epochs" : 200,
    "batch_size" : 8,
    "num_workers" : 2,
    "model" : "gat", # choices are ["gcn", "gin", "resgcn", "gat", "graphsage", "sgc"]
    "input_dim" : 12, # input dimension for node features
    # "num_classes" : len(all_names), # collision
    "num_classes" : 1,
    "feat_dim" : 8,
    "layers" : 3,
    "train_data_percent" : 0.8,
}

class AttributeDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttributeDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

args = AttributeDict(args)

In [28]:
grouped_train_dataset = grouped_train
val_dataset = list(chain.from_iterable(paired_test))
test_dataset = list(chain.from_iterable(paired_test))
# val_dataset = dataset_test
# test_dataset = test_dataset
val_loader = DataLoader(val_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

print("Dataset split: {} {} {}".format(str(len(grouped_train)*len(grouped_train[0]))+'*2', len(val_dataset), len(val_dataset)))
print("Number of classes: {}".format(args.num_classes))

Dataset split: 288*2 70 70
Number of classes: 1


In [29]:
import importlib, models_LSTM_Node_GAT_June22
importlib.reload(models_LSTM_Node_GAT_June22)

from models_LSTM_Node_GAT_June22 import *
# classification model is a GNN encoder followed by linear layer
model = GraphClassificationModel(args.input_dim, 
                                 args.feat_dim, 
                                 n_layers=args.layers, 
                                 gnn=args.model, 
                                 output_dim=args.num_classes,
                                 sequence_len = sequence_len
                                # , load=args.load
                                 )

for param in model.parameters():
    param.requires_grad = True

model = model.to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

In [30]:
import warnings
warnings.filterwarnings("ignore")

model.train()
loss_fn = torch.nn.CrossEntropyLoss()
loss_BCE = torch.nn.BCEWithLogitsLoss()

is_best_loss = False
best_epoch, best_train_loss, best_val_loss = 0, float("inf"), float("inf")
best_train_acc, best_val_acc = 0.0, 0.0
for epoch in range(args.epochs):
    losses = []
    edge_losses = []
    correct = 0
    edge_correct = 0
    for acc_type in grouped_train_dataset:
        random.shuffle(acc_type)
    train_dataset = []
    for i in range(major_acc_num):
        for acc_type in grouped_train_dataset:
            train_dataset += acc_type[i]
    flattened = train_dataset   # (1+1)*batch_size*36
    train_loader = DataLoader(flattened, num_workers=args.num_workers, batch_size=2*args.batch_size, shuffle=False)
    for data in tqdm(train_loader):
        data.to(args.device)
        data_input = data.x_numeric, data.x, data.edge_index, data.batch, data.edge_attr
        output = model(data_input).squeeze(1)
        gt = torch.tensor([0.,1.]*int(output.size(0)/2)).to(args.device)
        loss = loss_BCE(output, gt)
        losses.append(loss.item())
        pred = []
        for pr in output:
            if torch.sigmoid(pr)>0.5:
                pred.append(1)
            else:
                pred.append(0)
        correct += int((torch.tensor(pred).to(args.device) == gt).sum()) / gt.size(0)
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    losses_val = []
    edge_losses_val = []
    correct_val = 0
    acc_count = 0
    preds = []
    gts = []
    for data in tqdm(val_loader):
        data.to(args.device)
        data_input = data.x_numeric, data.x, data.edge_index, data.batch, data.edge_attr
        output = model(data_input).squeeze(1)
        gt = data.y.float()
        loss = loss_BCE(output, gt)
        losses_val.append(loss.item())
        pred = []
        for pr in output:
            if torch.sigmoid(pr)>0.5:
                pred.append(1)
            else:
                pred.append(0)
        preds = preds + output.tolist()
        gts = gts + gt.tolist()
        correct_val += int((torch.tensor(pred).to(args.device) == gt).sum()) / gt.size(0)
    # print(preds)
    # print(gts)
    epoch_loss_train = sum(losses) / len(losses)
    epoch_acc_train = correct / len(train_loader)
    epoch_loss_val = sum(losses_val) / len(losses_val)
    epoch_acc_val = correct_val / len(val_loader)

    log = "Epoch {}, Train Loss: {:.3f}, Train Accuracy: {:.3f}, Val Loss: {:.3f}, Val Accuracy: {:.3f}" 
    print(log.format(epoch, epoch_loss_train, epoch_acc_train, epoch_loss_val, epoch_acc_val))

    if epoch_acc_val > best_val_acc:
        best_epoch, best_train_loss, best_val_loss, is_best_loss = epoch, epoch_loss_train, epoch_loss_val, True
        best_train_acc, best_val_acc = epoch_acc_train, epoch_acc_val
# model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_loss, best_val_loss, is_best_loss)
print('best Epoch: ', best_epoch, ', best total train loss: ', best_train_loss, ', best total val loss: ', best_val_loss)
print('best train accuracy: ', best_train_acc, ', best validation accuracy: ', best_val_acc)


100%|██████████| 36/36 [00:00<00:00, 56.19it/s]
100%|██████████| 9/9 [00:00<00:00, 48.23it/s]


Epoch 0, Train Loss: 0.708, Train Accuracy: 0.512, Val Loss: 0.693, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 57.57it/s]
100%|██████████| 9/9 [00:00<00:00, 46.65it/s]


Epoch 1, Train Loss: 0.688, Train Accuracy: 0.556, Val Loss: 0.693, Val Accuracy: 0.481


100%|██████████| 36/36 [00:00<00:00, 58.53it/s]
100%|██████████| 9/9 [00:00<00:00, 47.40it/s]


Epoch 2, Train Loss: 0.685, Train Accuracy: 0.554, Val Loss: 0.693, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 57.39it/s]
100%|██████████| 9/9 [00:00<00:00, 49.85it/s]


Epoch 3, Train Loss: 0.682, Train Accuracy: 0.562, Val Loss: 0.693, Val Accuracy: 0.472


100%|██████████| 36/36 [00:00<00:00, 56.97it/s]
100%|██████████| 9/9 [00:00<00:00, 45.12it/s]


Epoch 4, Train Loss: 0.680, Train Accuracy: 0.561, Val Loss: 0.694, Val Accuracy: 0.500


100%|██████████| 36/36 [00:00<00:00, 59.45it/s]
100%|██████████| 9/9 [00:00<00:00, 47.16it/s]


Epoch 5, Train Loss: 0.678, Train Accuracy: 0.571, Val Loss: 0.695, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 58.64it/s]
100%|██████████| 9/9 [00:00<00:00, 48.72it/s]


Epoch 6, Train Loss: 0.677, Train Accuracy: 0.568, Val Loss: 0.698, Val Accuracy: 0.505


100%|██████████| 36/36 [00:00<00:00, 57.18it/s]
100%|██████████| 9/9 [00:00<00:00, 49.03it/s]


Epoch 7, Train Loss: 0.672, Train Accuracy: 0.569, Val Loss: 0.693, Val Accuracy: 0.500


100%|██████████| 36/36 [00:00<00:00, 58.77it/s]
100%|██████████| 9/9 [00:00<00:00, 48.72it/s]


Epoch 8, Train Loss: 0.668, Train Accuracy: 0.589, Val Loss: 0.694, Val Accuracy: 0.500


100%|██████████| 36/36 [00:00<00:00, 59.32it/s]
100%|██████████| 9/9 [00:00<00:00, 49.27it/s]


Epoch 9, Train Loss: 0.667, Train Accuracy: 0.576, Val Loss: 0.704, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 58.48it/s]
100%|██████████| 9/9 [00:00<00:00, 49.70it/s]


Epoch 10, Train Loss: 0.662, Train Accuracy: 0.582, Val Loss: 0.696, Val Accuracy: 0.546


100%|██████████| 36/36 [00:00<00:00, 57.87it/s]
100%|██████████| 9/9 [00:00<00:00, 47.35it/s]


Epoch 11, Train Loss: 0.662, Train Accuracy: 0.592, Val Loss: 0.690, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 57.73it/s]
100%|██████████| 9/9 [00:00<00:00, 47.89it/s]


Epoch 12, Train Loss: 0.657, Train Accuracy: 0.592, Val Loss: 0.699, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 58.02it/s]
100%|██████████| 9/9 [00:00<00:00, 47.19it/s]


Epoch 13, Train Loss: 0.658, Train Accuracy: 0.575, Val Loss: 0.700, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 59.85it/s]
100%|██████████| 9/9 [00:00<00:00, 49.21it/s]


Epoch 14, Train Loss: 0.648, Train Accuracy: 0.623, Val Loss: 0.702, Val Accuracy: 0.542


100%|██████████| 36/36 [00:00<00:00, 58.71it/s]
100%|██████████| 9/9 [00:00<00:00, 49.14it/s]


Epoch 15, Train Loss: 0.649, Train Accuracy: 0.589, Val Loss: 0.700, Val Accuracy: 0.486


100%|██████████| 36/36 [00:00<00:00, 57.68it/s]
100%|██████████| 9/9 [00:00<00:00, 48.12it/s]


Epoch 16, Train Loss: 0.649, Train Accuracy: 0.587, Val Loss: 0.697, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 55.28it/s]
100%|██████████| 9/9 [00:00<00:00, 47.19it/s]


Epoch 17, Train Loss: 0.641, Train Accuracy: 0.622, Val Loss: 0.693, Val Accuracy: 0.542


100%|██████████| 36/36 [00:00<00:00, 58.92it/s]
100%|██████████| 9/9 [00:00<00:00, 48.71it/s]


Epoch 18, Train Loss: 0.638, Train Accuracy: 0.616, Val Loss: 0.713, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 56.35it/s]
100%|██████████| 9/9 [00:00<00:00, 48.01it/s]


Epoch 19, Train Loss: 0.634, Train Accuracy: 0.613, Val Loss: 0.696, Val Accuracy: 0.542


100%|██████████| 36/36 [00:00<00:00, 57.64it/s]
100%|██████████| 9/9 [00:00<00:00, 48.73it/s]


Epoch 20, Train Loss: 0.630, Train Accuracy: 0.628, Val Loss: 0.696, Val Accuracy: 0.500


100%|██████████| 36/36 [00:00<00:00, 56.82it/s]
100%|██████████| 9/9 [00:00<00:00, 48.72it/s]


Epoch 21, Train Loss: 0.625, Train Accuracy: 0.646, Val Loss: 0.708, Val Accuracy: 0.556


100%|██████████| 36/36 [00:00<00:00, 58.68it/s]
100%|██████████| 9/9 [00:00<00:00, 50.44it/s]


Epoch 22, Train Loss: 0.627, Train Accuracy: 0.642, Val Loss: 0.704, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 58.33it/s]
100%|██████████| 9/9 [00:00<00:00, 47.55it/s]


Epoch 23, Train Loss: 0.622, Train Accuracy: 0.627, Val Loss: 0.698, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 57.07it/s]
100%|██████████| 9/9 [00:00<00:00, 49.40it/s]


Epoch 24, Train Loss: 0.619, Train Accuracy: 0.627, Val Loss: 0.711, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 57.53it/s]
100%|██████████| 9/9 [00:00<00:00, 48.55it/s]


Epoch 25, Train Loss: 0.618, Train Accuracy: 0.644, Val Loss: 0.703, Val Accuracy: 0.542


100%|██████████| 36/36 [00:00<00:00, 61.27it/s]
100%|██████████| 9/9 [00:00<00:00, 48.02it/s]


Epoch 26, Train Loss: 0.620, Train Accuracy: 0.613, Val Loss: 0.706, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 57.70it/s]
100%|██████████| 9/9 [00:00<00:00, 47.02it/s]


Epoch 27, Train Loss: 0.611, Train Accuracy: 0.642, Val Loss: 0.695, Val Accuracy: 0.514


100%|██████████| 36/36 [00:00<00:00, 57.83it/s]
100%|██████████| 9/9 [00:00<00:00, 45.97it/s]


Epoch 28, Train Loss: 0.610, Train Accuracy: 0.644, Val Loss: 0.720, Val Accuracy: 0.556


100%|██████████| 36/36 [00:00<00:00, 58.05it/s]
100%|██████████| 9/9 [00:00<00:00, 48.86it/s]


Epoch 29, Train Loss: 0.602, Train Accuracy: 0.658, Val Loss: 0.703, Val Accuracy: 0.583


100%|██████████| 36/36 [00:00<00:00, 58.12it/s]
100%|██████████| 9/9 [00:00<00:00, 45.99it/s]


Epoch 30, Train Loss: 0.604, Train Accuracy: 0.648, Val Loss: 0.703, Val Accuracy: 0.523


100%|██████████| 36/36 [00:00<00:00, 58.98it/s]
100%|██████████| 9/9 [00:00<00:00, 49.40it/s]


Epoch 31, Train Loss: 0.597, Train Accuracy: 0.674, Val Loss: 0.710, Val Accuracy: 0.551


100%|██████████| 36/36 [00:00<00:00, 56.46it/s]
100%|██████████| 9/9 [00:00<00:00, 48.34it/s]


Epoch 32, Train Loss: 0.594, Train Accuracy: 0.665, Val Loss: 0.698, Val Accuracy: 0.509


100%|██████████| 36/36 [00:00<00:00, 57.60it/s]
100%|██████████| 9/9 [00:00<00:00, 48.46it/s]


Epoch 33, Train Loss: 0.599, Train Accuracy: 0.655, Val Loss: 0.707, Val Accuracy: 0.611


100%|██████████| 36/36 [00:00<00:00, 57.92it/s]
100%|██████████| 9/9 [00:00<00:00, 47.95it/s]


Epoch 34, Train Loss: 0.598, Train Accuracy: 0.653, Val Loss: 0.700, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 57.56it/s]
100%|██████████| 9/9 [00:00<00:00, 46.26it/s]


Epoch 35, Train Loss: 0.580, Train Accuracy: 0.684, Val Loss: 0.696, Val Accuracy: 0.551


100%|██████████| 36/36 [00:00<00:00, 58.58it/s]
100%|██████████| 9/9 [00:00<00:00, 48.85it/s]


Epoch 36, Train Loss: 0.571, Train Accuracy: 0.686, Val Loss: 0.701, Val Accuracy: 0.556


100%|██████████| 36/36 [00:00<00:00, 57.89it/s]
100%|██████████| 9/9 [00:00<00:00, 48.19it/s]


Epoch 37, Train Loss: 0.569, Train Accuracy: 0.686, Val Loss: 0.705, Val Accuracy: 0.528


100%|██████████| 36/36 [00:00<00:00, 56.23it/s]
100%|██████████| 9/9 [00:00<00:00, 48.93it/s]


Epoch 38, Train Loss: 0.560, Train Accuracy: 0.688, Val Loss: 0.692, Val Accuracy: 0.569


100%|██████████| 36/36 [00:00<00:00, 59.42it/s]
100%|██████████| 9/9 [00:00<00:00, 47.23it/s]


Epoch 39, Train Loss: 0.550, Train Accuracy: 0.696, Val Loss: 0.731, Val Accuracy: 0.574


100%|██████████| 36/36 [00:00<00:00, 58.01it/s]
100%|██████████| 9/9 [00:00<00:00, 48.10it/s]


Epoch 40, Train Loss: 0.547, Train Accuracy: 0.684, Val Loss: 0.692, Val Accuracy: 0.569


100%|██████████| 36/36 [00:00<00:00, 56.89it/s]
100%|██████████| 9/9 [00:00<00:00, 49.16it/s]


Epoch 41, Train Loss: 0.532, Train Accuracy: 0.715, Val Loss: 0.724, Val Accuracy: 0.630


100%|██████████| 36/36 [00:00<00:00, 59.56it/s]
100%|██████████| 9/9 [00:00<00:00, 49.07it/s]


Epoch 42, Train Loss: 0.527, Train Accuracy: 0.722, Val Loss: 0.703, Val Accuracy: 0.551


100%|██████████| 36/36 [00:00<00:00, 57.91it/s]
100%|██████████| 9/9 [00:00<00:00, 50.02it/s]


Epoch 43, Train Loss: 0.535, Train Accuracy: 0.712, Val Loss: 0.715, Val Accuracy: 0.574


100%|██████████| 36/36 [00:00<00:00, 57.03it/s]
100%|██████████| 9/9 [00:00<00:00, 47.65it/s]


Epoch 44, Train Loss: 0.528, Train Accuracy: 0.717, Val Loss: 0.739, Val Accuracy: 0.588


100%|██████████| 36/36 [00:00<00:00, 56.88it/s]
100%|██████████| 9/9 [00:00<00:00, 49.26it/s]


Epoch 45, Train Loss: 0.508, Train Accuracy: 0.738, Val Loss: 0.744, Val Accuracy: 0.630


100%|██████████| 36/36 [00:00<00:00, 56.56it/s]
100%|██████████| 9/9 [00:00<00:00, 49.07it/s]


Epoch 46, Train Loss: 0.504, Train Accuracy: 0.736, Val Loss: 0.779, Val Accuracy: 0.671


100%|██████████| 36/36 [00:00<00:00, 58.37it/s]
100%|██████████| 9/9 [00:00<00:00, 47.88it/s]


Epoch 47, Train Loss: 0.489, Train Accuracy: 0.753, Val Loss: 0.747, Val Accuracy: 0.644


100%|██████████| 36/36 [00:00<00:00, 58.67it/s]
100%|██████████| 9/9 [00:00<00:00, 49.80it/s]


Epoch 48, Train Loss: 0.480, Train Accuracy: 0.750, Val Loss: 0.734, Val Accuracy: 0.602


100%|██████████| 36/36 [00:00<00:00, 59.15it/s]
100%|██████████| 9/9 [00:00<00:00, 48.62it/s]


Epoch 49, Train Loss: 0.499, Train Accuracy: 0.738, Val Loss: 0.753, Val Accuracy: 0.630


100%|██████████| 36/36 [00:00<00:00, 57.83it/s]
100%|██████████| 9/9 [00:00<00:00, 49.14it/s]


Epoch 50, Train Loss: 0.487, Train Accuracy: 0.759, Val Loss: 0.751, Val Accuracy: 0.616


100%|██████████| 36/36 [00:00<00:00, 58.28it/s]
100%|██████████| 9/9 [00:00<00:00, 49.39it/s]


Epoch 51, Train Loss: 0.492, Train Accuracy: 0.753, Val Loss: 0.758, Val Accuracy: 0.602


100%|██████████| 36/36 [00:00<00:00, 58.86it/s]
100%|██████████| 9/9 [00:00<00:00, 47.21it/s]


Epoch 52, Train Loss: 0.472, Train Accuracy: 0.778, Val Loss: 0.737, Val Accuracy: 0.648


100%|██████████| 36/36 [00:00<00:00, 59.03it/s]
100%|██████████| 9/9 [00:00<00:00, 49.03it/s]


Epoch 53, Train Loss: 0.452, Train Accuracy: 0.797, Val Loss: 0.747, Val Accuracy: 0.602


100%|██████████| 36/36 [00:00<00:00, 57.07it/s]
100%|██████████| 9/9 [00:00<00:00, 47.34it/s]


Epoch 54, Train Loss: 0.428, Train Accuracy: 0.821, Val Loss: 0.758, Val Accuracy: 0.630


100%|██████████| 36/36 [00:00<00:00, 56.96it/s]
100%|██████████| 9/9 [00:00<00:00, 47.17it/s]


Epoch 55, Train Loss: 0.438, Train Accuracy: 0.781, Val Loss: 0.754, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 59.83it/s]
100%|██████████| 9/9 [00:00<00:00, 48.44it/s]


Epoch 56, Train Loss: 0.425, Train Accuracy: 0.807, Val Loss: 0.775, Val Accuracy: 0.616


100%|██████████| 36/36 [00:00<00:00, 56.59it/s]
100%|██████████| 9/9 [00:00<00:00, 48.61it/s]


Epoch 57, Train Loss: 0.453, Train Accuracy: 0.767, Val Loss: 0.748, Val Accuracy: 0.644


100%|██████████| 36/36 [00:00<00:00, 57.27it/s]
100%|██████████| 9/9 [00:00<00:00, 48.21it/s]


Epoch 58, Train Loss: 0.421, Train Accuracy: 0.812, Val Loss: 0.767, Val Accuracy: 0.657


100%|██████████| 36/36 [00:00<00:00, 57.97it/s]
100%|██████████| 9/9 [00:00<00:00, 48.02it/s]


Epoch 59, Train Loss: 0.401, Train Accuracy: 0.832, Val Loss: 0.764, Val Accuracy: 0.662


100%|██████████| 36/36 [00:00<00:00, 55.31it/s]
100%|██████████| 9/9 [00:00<00:00, 46.98it/s]


Epoch 60, Train Loss: 0.400, Train Accuracy: 0.825, Val Loss: 0.784, Val Accuracy: 0.699


100%|██████████| 36/36 [00:00<00:00, 59.91it/s]
100%|██████████| 9/9 [00:00<00:00, 46.19it/s]


Epoch 61, Train Loss: 0.401, Train Accuracy: 0.821, Val Loss: 0.789, Val Accuracy: 0.685


100%|██████████| 36/36 [00:00<00:00, 58.52it/s]
100%|██████████| 9/9 [00:00<00:00, 47.04it/s]


Epoch 62, Train Loss: 0.392, Train Accuracy: 0.839, Val Loss: 0.767, Val Accuracy: 0.690


100%|██████████| 36/36 [00:00<00:00, 57.35it/s]
100%|██████████| 9/9 [00:00<00:00, 46.67it/s]


Epoch 63, Train Loss: 0.380, Train Accuracy: 0.839, Val Loss: 0.788, Val Accuracy: 0.685


100%|██████████| 36/36 [00:00<00:00, 57.07it/s]
100%|██████████| 9/9 [00:00<00:00, 48.07it/s]


Epoch 64, Train Loss: 0.404, Train Accuracy: 0.828, Val Loss: 0.777, Val Accuracy: 0.657


100%|██████████| 36/36 [00:00<00:00, 56.60it/s]
100%|██████████| 9/9 [00:00<00:00, 50.11it/s]


Epoch 65, Train Loss: 0.368, Train Accuracy: 0.849, Val Loss: 0.777, Val Accuracy: 0.718


100%|██████████| 36/36 [00:00<00:00, 59.94it/s]
100%|██████████| 9/9 [00:00<00:00, 50.18it/s]


Epoch 66, Train Loss: 0.366, Train Accuracy: 0.839, Val Loss: 0.832, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 59.96it/s]
100%|██████████| 9/9 [00:00<00:00, 47.57it/s]


Epoch 67, Train Loss: 0.366, Train Accuracy: 0.847, Val Loss: 0.807, Val Accuracy: 0.681


100%|██████████| 36/36 [00:00<00:00, 55.31it/s]
100%|██████████| 9/9 [00:00<00:00, 48.07it/s]


Epoch 68, Train Loss: 0.357, Train Accuracy: 0.849, Val Loss: 0.818, Val Accuracy: 0.657


100%|██████████| 36/36 [00:00<00:00, 58.92it/s]
100%|██████████| 9/9 [00:00<00:00, 47.59it/s]


Epoch 69, Train Loss: 0.345, Train Accuracy: 0.859, Val Loss: 0.765, Val Accuracy: 0.704


100%|██████████| 36/36 [00:00<00:00, 59.45it/s]
100%|██████████| 9/9 [00:00<00:00, 48.83it/s]


Epoch 70, Train Loss: 0.349, Train Accuracy: 0.856, Val Loss: 0.771, Val Accuracy: 0.648


100%|██████████| 36/36 [00:00<00:00, 59.02it/s]
100%|██████████| 9/9 [00:00<00:00, 49.81it/s]


Epoch 71, Train Loss: 0.351, Train Accuracy: 0.851, Val Loss: 0.754, Val Accuracy: 0.662


100%|██████████| 36/36 [00:00<00:00, 59.96it/s]
100%|██████████| 9/9 [00:00<00:00, 47.96it/s]


Epoch 72, Train Loss: 0.344, Train Accuracy: 0.852, Val Loss: 0.795, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 57.05it/s]
100%|██████████| 9/9 [00:00<00:00, 46.84it/s]


Epoch 73, Train Loss: 0.331, Train Accuracy: 0.877, Val Loss: 0.789, Val Accuracy: 0.704


100%|██████████| 36/36 [00:00<00:00, 59.12it/s]
100%|██████████| 9/9 [00:00<00:00, 47.23it/s]


Epoch 74, Train Loss: 0.342, Train Accuracy: 0.854, Val Loss: 0.859, Val Accuracy: 0.630


100%|██████████| 36/36 [00:00<00:00, 58.07it/s]
100%|██████████| 9/9 [00:00<00:00, 48.91it/s]


Epoch 75, Train Loss: 0.364, Train Accuracy: 0.833, Val Loss: 0.829, Val Accuracy: 0.639


100%|██████████| 36/36 [00:00<00:00, 58.91it/s]
100%|██████████| 9/9 [00:00<00:00, 48.93it/s]


Epoch 76, Train Loss: 0.326, Train Accuracy: 0.866, Val Loss: 0.793, Val Accuracy: 0.685


100%|██████████| 36/36 [00:00<00:00, 57.36it/s]
100%|██████████| 9/9 [00:00<00:00, 48.14it/s]


Epoch 77, Train Loss: 0.331, Train Accuracy: 0.866, Val Loss: 0.794, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 58.51it/s]
100%|██████████| 9/9 [00:00<00:00, 48.45it/s]


Epoch 78, Train Loss: 0.310, Train Accuracy: 0.877, Val Loss: 0.799, Val Accuracy: 0.713


100%|██████████| 36/36 [00:00<00:00, 58.11it/s]
100%|██████████| 9/9 [00:00<00:00, 48.83it/s]


Epoch 79, Train Loss: 0.303, Train Accuracy: 0.873, Val Loss: 0.771, Val Accuracy: 0.713


100%|██████████| 36/36 [00:00<00:00, 59.35it/s]
100%|██████████| 9/9 [00:00<00:00, 49.92it/s]


Epoch 80, Train Loss: 0.301, Train Accuracy: 0.889, Val Loss: 0.767, Val Accuracy: 0.731


100%|██████████| 36/36 [00:00<00:00, 57.71it/s]
100%|██████████| 9/9 [00:00<00:00, 46.40it/s]


Epoch 81, Train Loss: 0.312, Train Accuracy: 0.884, Val Loss: 0.890, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 57.71it/s]
100%|██████████| 9/9 [00:00<00:00, 49.72it/s]


Epoch 82, Train Loss: 0.304, Train Accuracy: 0.887, Val Loss: 0.836, Val Accuracy: 0.699


100%|██████████| 36/36 [00:00<00:00, 56.44it/s]
100%|██████████| 9/9 [00:00<00:00, 49.28it/s]


Epoch 83, Train Loss: 0.307, Train Accuracy: 0.870, Val Loss: 0.783, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 58.16it/s]
100%|██████████| 9/9 [00:00<00:00, 47.21it/s]


Epoch 84, Train Loss: 0.302, Train Accuracy: 0.873, Val Loss: 0.788, Val Accuracy: 0.713


100%|██████████| 36/36 [00:00<00:00, 58.50it/s]
100%|██████████| 9/9 [00:00<00:00, 49.65it/s]


Epoch 85, Train Loss: 0.299, Train Accuracy: 0.884, Val Loss: 0.807, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 57.39it/s]
100%|██████████| 9/9 [00:00<00:00, 48.93it/s]


Epoch 86, Train Loss: 0.339, Train Accuracy: 0.854, Val Loss: 0.780, Val Accuracy: 0.731


100%|██████████| 36/36 [00:00<00:00, 57.65it/s]
100%|██████████| 9/9 [00:00<00:00, 49.40it/s]


Epoch 87, Train Loss: 0.305, Train Accuracy: 0.872, Val Loss: 0.881, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 59.17it/s]
100%|██████████| 9/9 [00:00<00:00, 49.57it/s]


Epoch 88, Train Loss: 0.281, Train Accuracy: 0.905, Val Loss: 0.830, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 58.25it/s]
100%|██████████| 9/9 [00:00<00:00, 49.29it/s]


Epoch 89, Train Loss: 0.278, Train Accuracy: 0.894, Val Loss: 0.794, Val Accuracy: 0.745


100%|██████████| 36/36 [00:00<00:00, 57.75it/s]
100%|██████████| 9/9 [00:00<00:00, 47.68it/s]


Epoch 90, Train Loss: 0.264, Train Accuracy: 0.911, Val Loss: 0.798, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 59.95it/s]
100%|██████████| 9/9 [00:00<00:00, 47.49it/s]


Epoch 91, Train Loss: 0.271, Train Accuracy: 0.899, Val Loss: 0.800, Val Accuracy: 0.745


100%|██████████| 36/36 [00:00<00:00, 57.50it/s]
100%|██████████| 9/9 [00:00<00:00, 50.22it/s]


Epoch 92, Train Loss: 0.274, Train Accuracy: 0.892, Val Loss: 0.797, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 56.35it/s]
100%|██████████| 9/9 [00:00<00:00, 48.16it/s]


Epoch 93, Train Loss: 0.268, Train Accuracy: 0.892, Val Loss: 0.858, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 58.30it/s]
100%|██████████| 9/9 [00:00<00:00, 47.38it/s]


Epoch 94, Train Loss: 0.263, Train Accuracy: 0.899, Val Loss: 0.789, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 58.46it/s]
100%|██████████| 9/9 [00:00<00:00, 48.90it/s]


Epoch 95, Train Loss: 0.261, Train Accuracy: 0.908, Val Loss: 0.861, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 57.80it/s]
100%|██████████| 9/9 [00:00<00:00, 48.72it/s]


Epoch 96, Train Loss: 0.261, Train Accuracy: 0.903, Val Loss: 0.830, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 58.79it/s]
100%|██████████| 9/9 [00:00<00:00, 49.05it/s]


Epoch 97, Train Loss: 0.265, Train Accuracy: 0.891, Val Loss: 0.848, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 58.03it/s]
100%|██████████| 9/9 [00:00<00:00, 49.87it/s]


Epoch 98, Train Loss: 0.263, Train Accuracy: 0.899, Val Loss: 0.899, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 59.27it/s]
100%|██████████| 9/9 [00:00<00:00, 49.55it/s]


Epoch 99, Train Loss: 0.251, Train Accuracy: 0.906, Val Loss: 0.825, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 58.15it/s]
100%|██████████| 9/9 [00:00<00:00, 49.21it/s]


Epoch 100, Train Loss: 0.251, Train Accuracy: 0.911, Val Loss: 0.823, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 57.65it/s]
100%|██████████| 9/9 [00:00<00:00, 47.64it/s]


Epoch 101, Train Loss: 0.260, Train Accuracy: 0.901, Val Loss: 0.855, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 57.95it/s]
100%|██████████| 9/9 [00:00<00:00, 47.85it/s]


Epoch 102, Train Loss: 0.273, Train Accuracy: 0.898, Val Loss: 0.858, Val Accuracy: 0.731


100%|██████████| 36/36 [00:00<00:00, 56.92it/s]
100%|██████████| 9/9 [00:00<00:00, 47.91it/s]


Epoch 103, Train Loss: 0.267, Train Accuracy: 0.891, Val Loss: 0.824, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 57.32it/s]
100%|██████████| 9/9 [00:00<00:00, 46.65it/s]


Epoch 104, Train Loss: 0.264, Train Accuracy: 0.896, Val Loss: 0.860, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 59.61it/s]
100%|██████████| 9/9 [00:00<00:00, 48.90it/s]


Epoch 105, Train Loss: 0.241, Train Accuracy: 0.911, Val Loss: 0.817, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 58.08it/s]
100%|██████████| 9/9 [00:00<00:00, 49.01it/s]


Epoch 106, Train Loss: 0.246, Train Accuracy: 0.905, Val Loss: 0.827, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 59.86it/s]
100%|██████████| 9/9 [00:00<00:00, 48.29it/s]


Epoch 107, Train Loss: 0.239, Train Accuracy: 0.918, Val Loss: 0.805, Val Accuracy: 0.773


100%|██████████| 36/36 [00:00<00:00, 57.12it/s]
100%|██████████| 9/9 [00:00<00:00, 48.87it/s]


Epoch 108, Train Loss: 0.230, Train Accuracy: 0.918, Val Loss: 0.941, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 56.41it/s]
100%|██████████| 9/9 [00:00<00:00, 47.11it/s]


Epoch 109, Train Loss: 0.235, Train Accuracy: 0.911, Val Loss: 0.842, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 58.00it/s]
100%|██████████| 9/9 [00:00<00:00, 47.80it/s]


Epoch 110, Train Loss: 0.228, Train Accuracy: 0.925, Val Loss: 0.835, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 57.78it/s]
100%|██████████| 9/9 [00:00<00:00, 48.26it/s]


Epoch 111, Train Loss: 0.225, Train Accuracy: 0.925, Val Loss: 0.822, Val Accuracy: 0.801


100%|██████████| 36/36 [00:00<00:00, 58.16it/s]
100%|██████████| 9/9 [00:00<00:00, 49.50it/s]


Epoch 112, Train Loss: 0.225, Train Accuracy: 0.915, Val Loss: 0.894, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 57.39it/s]
100%|██████████| 9/9 [00:00<00:00, 46.30it/s]


Epoch 113, Train Loss: 0.224, Train Accuracy: 0.924, Val Loss: 0.818, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 55.64it/s]
100%|██████████| 9/9 [00:00<00:00, 48.00it/s]


Epoch 114, Train Loss: 0.244, Train Accuracy: 0.903, Val Loss: 0.828, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 57.68it/s]
100%|██████████| 9/9 [00:00<00:00, 47.73it/s]


Epoch 115, Train Loss: 0.227, Train Accuracy: 0.917, Val Loss: 0.850, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 60.42it/s]
100%|██████████| 9/9 [00:00<00:00, 47.95it/s]


Epoch 116, Train Loss: 0.218, Train Accuracy: 0.924, Val Loss: 0.916, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 58.30it/s]
100%|██████████| 9/9 [00:00<00:00, 49.16it/s]


Epoch 117, Train Loss: 0.262, Train Accuracy: 0.903, Val Loss: 0.827, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 56.90it/s]
100%|██████████| 9/9 [00:00<00:00, 48.17it/s]


Epoch 118, Train Loss: 0.217, Train Accuracy: 0.917, Val Loss: 0.912, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 57.25it/s]
100%|██████████| 9/9 [00:00<00:00, 48.01it/s]


Epoch 119, Train Loss: 0.218, Train Accuracy: 0.924, Val Loss: 0.850, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 58.95it/s]
100%|██████████| 9/9 [00:00<00:00, 49.54it/s]


Epoch 120, Train Loss: 0.214, Train Accuracy: 0.931, Val Loss: 0.872, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 55.36it/s]
100%|██████████| 9/9 [00:00<00:00, 49.61it/s]


Epoch 121, Train Loss: 0.204, Train Accuracy: 0.936, Val Loss: 0.870, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 57.83it/s]
100%|██████████| 9/9 [00:00<00:00, 46.12it/s]


Epoch 122, Train Loss: 0.207, Train Accuracy: 0.925, Val Loss: 0.843, Val Accuracy: 0.759


100%|██████████| 36/36 [00:00<00:00, 57.60it/s]
100%|██████████| 9/9 [00:00<00:00, 48.05it/s]


Epoch 123, Train Loss: 0.203, Train Accuracy: 0.934, Val Loss: 0.928, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 58.19it/s]
100%|██████████| 9/9 [00:00<00:00, 49.26it/s]


Epoch 124, Train Loss: 0.205, Train Accuracy: 0.934, Val Loss: 0.844, Val Accuracy: 0.787


100%|██████████| 36/36 [00:00<00:00, 57.08it/s]
100%|██████████| 9/9 [00:00<00:00, 50.07it/s]


Epoch 125, Train Loss: 0.207, Train Accuracy: 0.925, Val Loss: 0.866, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 57.62it/s]
100%|██████████| 9/9 [00:00<00:00, 49.57it/s]


Epoch 126, Train Loss: 0.230, Train Accuracy: 0.908, Val Loss: 0.867, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 59.85it/s]
100%|██████████| 9/9 [00:00<00:00, 48.25it/s]


Epoch 127, Train Loss: 0.205, Train Accuracy: 0.925, Val Loss: 0.876, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 56.79it/s]
100%|██████████| 9/9 [00:00<00:00, 47.50it/s]


Epoch 128, Train Loss: 0.220, Train Accuracy: 0.922, Val Loss: 0.860, Val Accuracy: 0.764


100%|██████████| 36/36 [00:00<00:00, 59.60it/s]
100%|██████████| 9/9 [00:00<00:00, 49.44it/s]


Epoch 129, Train Loss: 0.200, Train Accuracy: 0.932, Val Loss: 0.886, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 59.97it/s]
100%|██████████| 9/9 [00:00<00:00, 50.09it/s]


Epoch 130, Train Loss: 0.201, Train Accuracy: 0.931, Val Loss: 0.849, Val Accuracy: 0.759


100%|██████████| 36/36 [00:00<00:00, 56.88it/s]
100%|██████████| 9/9 [00:00<00:00, 46.69it/s]


Epoch 131, Train Loss: 0.199, Train Accuracy: 0.931, Val Loss: 0.876, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 53.16it/s]
100%|██████████| 9/9 [00:00<00:00, 45.99it/s]


Epoch 132, Train Loss: 0.193, Train Accuracy: 0.931, Val Loss: 0.831, Val Accuracy: 0.759


100%|██████████| 36/36 [00:00<00:00, 55.84it/s]
100%|██████████| 9/9 [00:00<00:00, 48.08it/s]


Epoch 133, Train Loss: 0.199, Train Accuracy: 0.932, Val Loss: 0.851, Val Accuracy: 0.815


100%|██████████| 36/36 [00:00<00:00, 58.31it/s]
100%|██████████| 9/9 [00:00<00:00, 48.62it/s]


Epoch 134, Train Loss: 0.188, Train Accuracy: 0.931, Val Loss: 0.909, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 57.93it/s]
100%|██████████| 9/9 [00:00<00:00, 47.97it/s]


Epoch 135, Train Loss: 0.207, Train Accuracy: 0.917, Val Loss: 0.855, Val Accuracy: 0.773


100%|██████████| 36/36 [00:00<00:00, 58.46it/s]
100%|██████████| 9/9 [00:00<00:00, 50.31it/s]


Epoch 136, Train Loss: 0.195, Train Accuracy: 0.929, Val Loss: 0.869, Val Accuracy: 0.745


100%|██████████| 36/36 [00:00<00:00, 59.16it/s]
100%|██████████| 9/9 [00:00<00:00, 48.79it/s]


Epoch 137, Train Loss: 0.189, Train Accuracy: 0.939, Val Loss: 0.891, Val Accuracy: 0.787


100%|██████████| 36/36 [00:00<00:00, 57.82it/s]
100%|██████████| 9/9 [00:00<00:00, 49.42it/s]


Epoch 138, Train Loss: 0.201, Train Accuracy: 0.920, Val Loss: 0.882, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 58.19it/s]
100%|██████████| 9/9 [00:00<00:00, 47.36it/s]


Epoch 139, Train Loss: 0.208, Train Accuracy: 0.915, Val Loss: 0.862, Val Accuracy: 0.787


100%|██████████| 36/36 [00:00<00:00, 57.28it/s]
100%|██████████| 9/9 [00:00<00:00, 49.38it/s]


Epoch 140, Train Loss: 0.185, Train Accuracy: 0.934, Val Loss: 0.901, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 60.03it/s]
100%|██████████| 9/9 [00:00<00:00, 47.06it/s]


Epoch 141, Train Loss: 0.197, Train Accuracy: 0.927, Val Loss: 1.018, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 59.21it/s]
100%|██████████| 9/9 [00:00<00:00, 48.53it/s]


Epoch 142, Train Loss: 0.206, Train Accuracy: 0.910, Val Loss: 0.909, Val Accuracy: 0.745


100%|██████████| 36/36 [00:00<00:00, 58.05it/s]
100%|██████████| 9/9 [00:00<00:00, 47.42it/s]


Epoch 143, Train Loss: 0.190, Train Accuracy: 0.931, Val Loss: 0.872, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 59.12it/s]
100%|██████████| 9/9 [00:00<00:00, 48.10it/s]


Epoch 144, Train Loss: 0.181, Train Accuracy: 0.939, Val Loss: 0.966, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 58.86it/s]
100%|██████████| 9/9 [00:00<00:00, 50.22it/s]


Epoch 145, Train Loss: 0.184, Train Accuracy: 0.932, Val Loss: 0.876, Val Accuracy: 0.782


100%|██████████| 36/36 [00:00<00:00, 58.98it/s]
100%|██████████| 9/9 [00:00<00:00, 48.21it/s]


Epoch 146, Train Loss: 0.249, Train Accuracy: 0.889, Val Loss: 1.017, Val Accuracy: 0.653


100%|██████████| 36/36 [00:00<00:00, 59.80it/s]
100%|██████████| 9/9 [00:00<00:00, 45.88it/s]


Epoch 147, Train Loss: 0.209, Train Accuracy: 0.932, Val Loss: 0.961, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 59.77it/s]
100%|██████████| 9/9 [00:00<00:00, 49.20it/s]


Epoch 148, Train Loss: 0.178, Train Accuracy: 0.943, Val Loss: 0.963, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 57.60it/s]
100%|██████████| 9/9 [00:00<00:00, 49.20it/s]


Epoch 149, Train Loss: 0.172, Train Accuracy: 0.941, Val Loss: 0.861, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 58.97it/s]
100%|██████████| 9/9 [00:00<00:00, 46.74it/s]


Epoch 150, Train Loss: 0.173, Train Accuracy: 0.944, Val Loss: 0.935, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 58.31it/s]
100%|██████████| 9/9 [00:00<00:00, 50.33it/s]


Epoch 151, Train Loss: 0.173, Train Accuracy: 0.941, Val Loss: 0.927, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 58.36it/s]
100%|██████████| 9/9 [00:00<00:00, 44.78it/s]


Epoch 152, Train Loss: 0.165, Train Accuracy: 0.941, Val Loss: 0.866, Val Accuracy: 0.745


100%|██████████| 36/36 [00:00<00:00, 58.96it/s]
100%|██████████| 9/9 [00:00<00:00, 48.79it/s]


Epoch 153, Train Loss: 0.188, Train Accuracy: 0.920, Val Loss: 0.911, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 59.53it/s]
100%|██████████| 9/9 [00:00<00:00, 47.49it/s]


Epoch 154, Train Loss: 0.176, Train Accuracy: 0.936, Val Loss: 0.881, Val Accuracy: 0.759


100%|██████████| 36/36 [00:00<00:00, 60.04it/s]
100%|██████████| 9/9 [00:00<00:00, 48.82it/s]


Epoch 155, Train Loss: 0.193, Train Accuracy: 0.936, Val Loss: 0.988, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 58.28it/s]
100%|██████████| 9/9 [00:00<00:00, 47.60it/s]


Epoch 156, Train Loss: 0.167, Train Accuracy: 0.944, Val Loss: 0.983, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 58.02it/s]
100%|██████████| 9/9 [00:00<00:00, 49.16it/s]


Epoch 157, Train Loss: 0.157, Train Accuracy: 0.953, Val Loss: 0.904, Val Accuracy: 0.773


100%|██████████| 36/36 [00:00<00:00, 57.68it/s]
100%|██████████| 9/9 [00:00<00:00, 48.87it/s]


Epoch 158, Train Loss: 0.167, Train Accuracy: 0.939, Val Loss: 0.981, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 60.38it/s]
100%|██████████| 9/9 [00:00<00:00, 47.39it/s]


Epoch 159, Train Loss: 0.167, Train Accuracy: 0.946, Val Loss: 0.893, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 58.74it/s]
100%|██████████| 9/9 [00:00<00:00, 47.89it/s]


Epoch 160, Train Loss: 0.159, Train Accuracy: 0.943, Val Loss: 0.910, Val Accuracy: 0.782


100%|██████████| 36/36 [00:00<00:00, 57.20it/s]
100%|██████████| 9/9 [00:00<00:00, 50.71it/s]


Epoch 161, Train Loss: 0.159, Train Accuracy: 0.951, Val Loss: 0.962, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 59.20it/s]
100%|██████████| 9/9 [00:00<00:00, 48.16it/s]


Epoch 162, Train Loss: 0.153, Train Accuracy: 0.939, Val Loss: 1.011, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 59.06it/s]
100%|██████████| 9/9 [00:00<00:00, 48.60it/s]


Epoch 163, Train Loss: 0.162, Train Accuracy: 0.946, Val Loss: 0.991, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 57.38it/s]
100%|██████████| 9/9 [00:00<00:00, 50.26it/s]


Epoch 164, Train Loss: 0.159, Train Accuracy: 0.944, Val Loss: 0.946, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 58.36it/s]
100%|██████████| 9/9 [00:00<00:00, 49.49it/s]


Epoch 165, Train Loss: 0.163, Train Accuracy: 0.946, Val Loss: 0.932, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 59.45it/s]
100%|██████████| 9/9 [00:00<00:00, 46.76it/s]


Epoch 166, Train Loss: 0.160, Train Accuracy: 0.931, Val Loss: 0.956, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 57.67it/s]
100%|██████████| 9/9 [00:00<00:00, 49.10it/s]


Epoch 167, Train Loss: 0.149, Train Accuracy: 0.950, Val Loss: 0.933, Val Accuracy: 0.759


100%|██████████| 36/36 [00:00<00:00, 57.44it/s]
100%|██████████| 9/9 [00:00<00:00, 49.52it/s]


Epoch 168, Train Loss: 0.149, Train Accuracy: 0.950, Val Loss: 0.972, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 57.90it/s]
100%|██████████| 9/9 [00:00<00:00, 49.34it/s]


Epoch 169, Train Loss: 0.158, Train Accuracy: 0.939, Val Loss: 0.894, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 55.77it/s]
100%|██████████| 9/9 [00:00<00:00, 47.34it/s]


Epoch 170, Train Loss: 0.165, Train Accuracy: 0.943, Val Loss: 0.898, Val Accuracy: 0.769


100%|██████████| 36/36 [00:00<00:00, 57.39it/s]
100%|██████████| 9/9 [00:00<00:00, 47.44it/s]


Epoch 171, Train Loss: 0.172, Train Accuracy: 0.934, Val Loss: 1.088, Val Accuracy: 0.667


100%|██████████| 36/36 [00:00<00:00, 57.86it/s]
100%|██████████| 9/9 [00:00<00:00, 49.28it/s]


Epoch 172, Train Loss: 0.159, Train Accuracy: 0.943, Val Loss: 0.942, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 58.50it/s]
100%|██████████| 9/9 [00:00<00:00, 48.61it/s]


Epoch 173, Train Loss: 0.141, Train Accuracy: 0.953, Val Loss: 0.973, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 58.32it/s]
100%|██████████| 9/9 [00:00<00:00, 48.63it/s]


Epoch 174, Train Loss: 0.149, Train Accuracy: 0.950, Val Loss: 0.928, Val Accuracy: 0.773


100%|██████████| 36/36 [00:00<00:00, 59.71it/s]
100%|██████████| 9/9 [00:00<00:00, 48.17it/s]


Epoch 175, Train Loss: 0.146, Train Accuracy: 0.948, Val Loss: 1.039, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 59.04it/s]
100%|██████████| 9/9 [00:00<00:00, 47.84it/s]


Epoch 176, Train Loss: 0.145, Train Accuracy: 0.958, Val Loss: 0.931, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 59.08it/s]
100%|██████████| 9/9 [00:00<00:00, 47.10it/s]


Epoch 177, Train Loss: 0.154, Train Accuracy: 0.936, Val Loss: 1.017, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 58.57it/s]
100%|██████████| 9/9 [00:00<00:00, 49.14it/s]


Epoch 178, Train Loss: 0.147, Train Accuracy: 0.953, Val Loss: 0.957, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 58.72it/s]
100%|██████████| 9/9 [00:00<00:00, 47.23it/s]


Epoch 179, Train Loss: 0.148, Train Accuracy: 0.951, Val Loss: 0.948, Val Accuracy: 0.736


100%|██████████| 36/36 [00:00<00:00, 59.87it/s]
100%|██████████| 9/9 [00:00<00:00, 46.24it/s]


Epoch 180, Train Loss: 0.142, Train Accuracy: 0.948, Val Loss: 0.929, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 59.79it/s]
100%|██████████| 9/9 [00:00<00:00, 45.83it/s]


Epoch 181, Train Loss: 0.147, Train Accuracy: 0.951, Val Loss: 0.999, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 58.97it/s]
100%|██████████| 9/9 [00:00<00:00, 49.57it/s]


Epoch 182, Train Loss: 0.134, Train Accuracy: 0.960, Val Loss: 0.938, Val Accuracy: 0.773


100%|██████████| 36/36 [00:00<00:00, 58.46it/s]
100%|██████████| 9/9 [00:00<00:00, 48.52it/s]


Epoch 183, Train Loss: 0.135, Train Accuracy: 0.950, Val Loss: 0.978, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 59.89it/s]
100%|██████████| 9/9 [00:00<00:00, 47.64it/s]


Epoch 184, Train Loss: 0.140, Train Accuracy: 0.948, Val Loss: 0.962, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 60.42it/s]
100%|██████████| 9/9 [00:00<00:00, 47.05it/s]


Epoch 185, Train Loss: 0.140, Train Accuracy: 0.964, Val Loss: 0.967, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 58.44it/s]
100%|██████████| 9/9 [00:00<00:00, 47.77it/s]


Epoch 186, Train Loss: 0.132, Train Accuracy: 0.951, Val Loss: 0.964, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 58.05it/s]
100%|██████████| 9/9 [00:00<00:00, 47.71it/s]


Epoch 187, Train Loss: 0.133, Train Accuracy: 0.957, Val Loss: 0.998, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 59.00it/s]
100%|██████████| 9/9 [00:00<00:00, 48.19it/s]


Epoch 188, Train Loss: 0.129, Train Accuracy: 0.957, Val Loss: 0.936, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 57.70it/s]
100%|██████████| 9/9 [00:00<00:00, 50.33it/s]


Epoch 189, Train Loss: 0.129, Train Accuracy: 0.958, Val Loss: 1.001, Val Accuracy: 0.722


100%|██████████| 36/36 [00:00<00:00, 57.86it/s]
100%|██████████| 9/9 [00:00<00:00, 49.25it/s]


Epoch 190, Train Loss: 0.127, Train Accuracy: 0.957, Val Loss: 1.005, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 55.73it/s]
100%|██████████| 9/9 [00:00<00:00, 46.23it/s]


Epoch 191, Train Loss: 0.126, Train Accuracy: 0.948, Val Loss: 0.946, Val Accuracy: 0.755


100%|██████████| 36/36 [00:00<00:00, 60.16it/s]
100%|██████████| 9/9 [00:00<00:00, 48.80it/s]


Epoch 192, Train Loss: 0.125, Train Accuracy: 0.965, Val Loss: 1.010, Val Accuracy: 0.741


100%|██████████| 36/36 [00:00<00:00, 59.47it/s]
100%|██████████| 9/9 [00:00<00:00, 45.73it/s]


Epoch 193, Train Loss: 0.123, Train Accuracy: 0.958, Val Loss: 0.990, Val Accuracy: 0.713


100%|██████████| 36/36 [00:00<00:00, 58.43it/s]
100%|██████████| 9/9 [00:00<00:00, 47.44it/s]


Epoch 194, Train Loss: 0.121, Train Accuracy: 0.960, Val Loss: 1.080, Val Accuracy: 0.694


100%|██████████| 36/36 [00:00<00:00, 57.90it/s]
100%|██████████| 9/9 [00:00<00:00, 48.36it/s]


Epoch 195, Train Loss: 0.132, Train Accuracy: 0.955, Val Loss: 1.009, Val Accuracy: 0.727


100%|██████████| 36/36 [00:00<00:00, 58.50it/s]
100%|██████████| 9/9 [00:00<00:00, 48.75it/s]


Epoch 196, Train Loss: 0.120, Train Accuracy: 0.962, Val Loss: 1.105, Val Accuracy: 0.681


100%|██████████| 36/36 [00:00<00:00, 58.61it/s]
100%|██████████| 9/9 [00:00<00:00, 50.01it/s]


Epoch 197, Train Loss: 0.151, Train Accuracy: 0.939, Val Loss: 0.985, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 56.51it/s]
100%|██████████| 9/9 [00:00<00:00, 46.78it/s]


Epoch 198, Train Loss: 0.135, Train Accuracy: 0.955, Val Loss: 0.978, Val Accuracy: 0.708


100%|██████████| 36/36 [00:00<00:00, 59.45it/s]
100%|██████████| 9/9 [00:00<00:00, 48.66it/s]

Epoch 199, Train Loss: 0.133, Train Accuracy: 0.955, Val Loss: 0.993, Val Accuracy: 0.713
best Epoch:  133 , best total train loss:  0.1994947087433603 , best total val loss:  0.8511699570549859
best train accuracy:  0.9322916666666666 , best validation accuracy:  0.8148148148148148





In [None]:
# Epoch 199, Train Loss: 0.101, Train Accuracy: 0.962, Val Loss: 0.491, Val Accuracy: 0.829
# best Epoch:  198 , best total train loss:  0.11967981037580305 , best total val loss:  0.5110236058632532
# best train accuracy:  0.9583333333333334 , best validation accuracy:  0.8425925925925926

# Epoch 199, Train Loss: 0.018, Train Accuracy: 0.998, Val Loss: 0.560, Val Accuracy: 0.843
# best Epoch:  125 , best total train loss:  0.11817064948586954 , best total val loss:  0.3765643032060729
# best train accuracy:  0.9635416666666666 , best validation accuracy:  0.9027777777777778


# Epoch 199, Train Loss: 0.039, Train Accuracy: 0.991, Val Loss: 1.571, Val Accuracy: 0.745
# best Epoch:  111 , best total train loss:  0.16000805215703118 , best total val loss:  0.8233955204486847
# best train accuracy:  0.9427083333333334 , best validation accuracy:  0.8009259259259259

#####    t_threshold = 3s
# Epoch 199, Train Loss: 0.072, Train Accuracy: 0.986, Val Loss: 0.820, Val Accuracy: 0.815
# best Epoch:  177 , best total train loss:  0.09902919233880109 , best total val loss:  0.7536483804384867
# best train accuracy:  0.9704861111111112 , best validation accuracy:  0.8425925925925926

# Epoch 199, Train Loss: 0.036, Train Accuracy: 0.995, Val Loss: 0.922, Val Accuracy: 0.801
# best Epoch:  156 , best total train loss:  0.08065216031132473 , best total val loss:  0.6996115938656859
# best train accuracy:  0.9739583333333334 , best validation accuracy:  0.8888888888888888

# Epoch 199, Train Loss: 0.057, Train Accuracy: 0.993, Val Loss: 1.001, Val Accuracy: 0.764
# best Epoch:  99 , best total train loss:  0.27886005491018295 , best total val loss:  0.5932083477576574
# best train accuracy:  0.8958333333333334 , best validation accuracy:  0.8055555555555556

In [68]:
        # gt = []
        # for i in range(int(data.y.size(0)/sequence_len)):
        #     gt.append(data.y[i*8])
        # gt = torch.tensor(gt, dtype=torch.long).to(args.device)
        # scores = model(data_input)[0]
        # loss = loss_fn(scores, gt)
        # pred = scores.argmax(dim=1)
        # correct += int((pred == gt).sum()) / gt.size(0)
        # losses.append(loss.item())

        # batch_node_label = []
        # for data_point in data.batch:
        #     if data.y[data_point] == 8:
        #         batch_node_label.append(False)
        #     else:
        #         batch_node_label.append(True)
        # batch_node_label = torch.tensor(batch_node_label, dtype=torch.bool).to(args.device)
        # batch_edge_label = batch_node_label[data.edge_index[0]]
        # edge_score = model(data_input)[1][batch_edge_label]
        # gt_edge = []
        # batch_edge_index = data.obj_ids[data.edge_index[0]][batch_edge_label]
        # for i in range(edge_score.size(0)):
        #     if batch_edge_index[i] in data.crash_ids:
        #         gt_edge.append([0,1])
        #     else:
        #         gt_edge.append([1,0])
        # gt_edge = torch.tensor(gt_edge, dtype=torch.float32).to(args.device)
        # if edge_score.size(0) != 0:
        #     loss_edge = loss_BCE(edge_score, gt_edge)*3
        #     edge_losses.append(loss_edge.item())
        #     pred_edge = edge_score.argmax(dim=1)
        #     labels_edge = gt_edge.argmax(dim=1)
        #     edge_correct += int((pred_edge == labels_edge).sum()) / labels_edge.size(0)
        #     acc_count += 1



        # backprop
        optimizer.zero_grad()
        loss.backward()
        if edge_score.size(0) != 0:
            loss_edge.backward()
        optimizer.step()

    losses_val = []
    edge_losses = []
    correct_val = 0
    edge_correct = 0
    acc_count = 0
    for data in val_loader:
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr, data.scene
        gt = []
        for i in range(int(data.y.size(0)/sequence_len)):
            gt.append(data.y[i*8])
        gt = torch.tensor(gt, dtype=torch.long).to(args.device)
        scores = model(data_input)[0]
        loss = loss_fn(scores, gt)
        pred = scores.argmax(dim=1)
        correct_val += int((pred == gt).sum()) / gt.size(0)
        losses_val.append(loss.item())

        batch_node_label = []
        for data_point in data.batch:
            if data.y[data_point] == 8:
                batch_node_label.append(False)
            else:
                batch_node_label.append(True)
        batch_node_label = torch.tensor(batch_node_label, dtype=torch.bool).to(args.device)
        batch_edge_label = batch_node_label[data.edge_index[0]]
        edge_score = model(data_input)[1][batch_edge_label]
        gt_edge = []
        batch_edge_index = data.obj_ids[data.edge_index[0]][batch_edge_label]
        for i in range(edge_score.size(0)):
            if batch_edge_index[i] in data.crash_ids:
                gt_edge.append([0,1])
            else:
                gt_edge.append([1,0])
        gt_edge = torch.tensor(gt_edge, dtype=torch.float32).to(args.device)
        if edge_score.size(0) != 0:
            loss_edge = loss_BCE(edge_score, gt_edge)*3
            edge_losses.append(loss_edge.item())
            pred_edge = edge_score.argmax(dim=1)
            labels_edge = gt_edge.argmax(dim=1)
            edge_correct += int((pred_edge == labels_edge).sum()) / labels_edge.size(0)
            acc_count += 1

    epoch_loss_train = sum(losses) / len(losses)
    epoch_acc_train = correct / len(train_loader)
    epoch_loss_val = sum(losses_val) / len(losses_val)
    epoch_acc_val = correct_val / len(val_loader)
    epoch_edge_loss_train = sum(edge_losses) / len(edge_losses)
    epoch_edge_acc_train = edge_correct / acc_count
    epoch_edge_loss_val = sum(edge_losses) / len(edge_losses)
    epoch_edge_acc_val = edge_correct / acc_count



    log = "Epoch {},   Train Loss: {:.3f},     Train Accuracy: {:.3f},     Val Loss: {:.3f},     Val Accuracy: {:.3f} \n    Edge Train Loss: {:.3f}, Edge Train Accuracy: {:.3f}, Edge Val Loss: {:.3f}, Edge Val Accuracy: {:.3f}" 
    print(log.format(epoch, 
                     epoch_loss_train, epoch_acc_train, epoch_loss_val, epoch_acc_val, epoch_edge_loss_train, epoch_edge_acc_train, epoch_edge_loss_val, epoch_edge_acc_val))

    is_best_loss = False
    if epoch_loss_val < best_val_loss:
        best_epoch, best_train_graph_loss, best_val_loss, is_best_loss = epoch, epoch_loss_train, epoch_loss_val, True

model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_graph_loss, best_val_loss, is_best_loss)
print('best Epoch: ', best_epoch, ', best total train loss: ', best_train_graph_loss, ', best total val loss: ', best_val_loss)
    

29

In [12]:
losses_test = []
edge_losses_test = []
edge_correct_test = 0
acc_count = 0
for data in tqdm(test_loader):
    data.to(args.device)
    data_input = data.x, data.edge_index, data.batch, data.edge_attr, data.edge_id_0, data.edge_id_1, data.edge_count, data.crash_ids
    output = model(data_input)
    edge_score = output[0]
    num_sq = int(len(data.crash_ids)/sequence_len)
    gt_edge = torch.tensor([]).to(args.device)
    for n in range(num_sq):
        for out_edge_id in output[1][n]:
            crash_nodes = data.crash_ids[(n+1)*sequence_len-1].tolist()
            crash_edge = [[crash_nodes[0],crash_nodes[1]], [crash_nodes[1],crash_nodes[0]]]
            if out_edge_id.tolist() in crash_edge:
                # gt_edge = torch.cat([gt_edge, data.y[(n+1)*sequence_len-1].unsqueeze(0)], 0).long()
                gt_edge = torch.cat([gt_edge, torch.tensor([1]).to(args.device)], 0)
            else:
                gt_edge = torch.cat([gt_edge, torch.tensor([0]).to(args.device)], 0)
    loss_edge = loss_fn(edge_score.squeeze(1), gt_edge)
    edge_losses_val.append(loss_edge.item())
    pred_edge = edge_score.argmax(dim=1)
    edge_correct_val += int((pred_edge == gt_edge).sum()) / gt_edge.size(0)
    print(pred_edge)
    print(gt_edge)
epoch_edge_loss_train = sum(edge_losses) / len(edge_losses)
epoch_edge_acc_train = edge_correct / len(train_loader)
epoch_edge_loss_val = sum(edge_losses_val) / len(edge_losses_val)
epoch_edge_acc_val = edge_correct_val / len(val_loader)

 40%|████      | 2/5 [00:00<00:00,  4.94it/s]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0], device='cuda:3')
tensor([1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.,
        1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.], device='cuda:3')
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0], device='cuda:3')
tensor([1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1.,
        0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:3')


 80%|████████  | 4/5 [00:00<00:00,  6.78it/s]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:3')
tensor([1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1.,
        0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.],
       device='cuda:3')
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0], device='cuda:3')
tensor([1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
        1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.], device='cuda:3')
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:3')
tensor([1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0.], device='cuda:3')


100%|██████████| 5/5 [00:00<00:00,  6.62it/s]


In [22]:
import warnings
warnings.filterwarnings("ignore")

model.train()
loss_fn = torch.nn.CrossEntropyLoss()
loss_BCE = torch.nn.BCEWithLogitsLoss()

for epoch in range(args.epochs):
    best_train_edge_loss, best_val_edge_loss = float("inf"), float("inf")
    losses = []
    edge_losses = []
    correct = 0
    edge_correct = 0
    random.shuffle(train_dataset)
    flattened = list(chain.from_iterable(train_dataset))
    train_loader = DataLoader(flattened, num_workers=args.num_workers, batch_size=args.batch_size*sequence_len, shuffle=False)
    acc_count = 0
    for data in tqdm(train_loader):
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr, data.edge_id_0, data.edge_id_1, data.edge_count, data.crash_ids
        output = model(data_input)
        edge_score = output[0]
        num_sq = int(len(data.crash_ids)/sequence_len)
        gt_edge = torch.tensor([]).to(args.device)
        for n in range(num_sq):
            for out_edge_id in output[1][n]:
                crash_nodes = data.crash_ids[(n+1)*sequence_len-1].tolist()
                crash_edge = [[crash_nodes[0],crash_nodes[1]], [crash_nodes[1],crash_nodes[0]]]
                if out_edge_id.tolist() in crash_edge:
                    # gt_edge = torch.cat([gt_edge, data.y[(n+1)*sequence_len-1].unsqueeze(0)], 0).long()
                    gt_edge = torch.cat([gt_edge, torch.tensor([1]).to(args.device)], 0).long()
                else:
                    # gt_edge = torch.cat([gt_edge, torch.tensor([8]).to(args.device)], 0).long()
                    gt_edge = torch.cat([gt_edge, torch.tensor([0]).to(args.device)], 0).long()
        loss_edge = loss_fn(edge_score, gt_edge)
        edge_losses.append(loss_edge.item())
        pred_edge = edge_score.argmax(dim=1)
        edge_correct += int((pred_edge == gt_edge).sum()) / gt_edge.size(0)
        # backprop
        optimizer.zero_grad()
        loss_edge.backward()
        optimizer.step()

    losses_val = []
    edge_losses_val = []
    edge_correct_val = 0
    acc_count = 0
    for data in tqdm(val_loader):
        data.to(args.device)
        data_input = data.x, data.edge_index, data.batch, data.edge_attr, data.edge_id_0, data.edge_id_1, data.edge_count, data.crash_ids
        output = model(data_input)
        edge_score = output[0]
        num_sq = int(len(data.crash_ids)/sequence_len)
        gt_edge = torch.tensor([]).to(args.device)
        for n in range(num_sq):
            for out_edge_id in output[1][n]:
                crash_nodes = data.crash_ids[(n+1)*sequence_len-1].tolist()
                crash_edge = [[crash_nodes[0],crash_nodes[1]], [crash_nodes[1],crash_nodes[0]]]
                if out_edge_id.tolist() in crash_edge:
                    # gt_edge = torch.cat([gt_edge, data.y[(n+1)*sequence_len-1].unsqueeze(0)], 0).long()
                    gt_edge = torch.cat([gt_edge, torch.tensor([1]).to(args.device)], 0).long()
                else:
                    gt_edge = torch.cat([gt_edge, torch.tensor([0]).to(args.device)], 0).long()
        loss_edge = loss_fn(edge_score, gt_edge)
        edge_losses_val.append(loss_edge.item())
        pred_edge = edge_score.argmax(dim=1)
        edge_correct_val += int((pred_edge == gt_edge).sum()) / gt_edge.size(0)

    epoch_edge_loss_train = sum(edge_losses) / len(edge_losses)
    epoch_edge_acc_train = edge_correct / len(train_loader)
    epoch_edge_loss_val = sum(edge_losses_val) / len(edge_losses_val)
    epoch_edge_acc_val = edge_correct_val / len(val_loader)



    log = "Epoch {}, Edge Train Loss: {:.3f}, Edge Train Accuracy: {:.3f}, Edge Val Loss: {:.3f}, Edge Val Accuracy: {:.3f}" 
    print(log.format(epoch, 
                     epoch_edge_loss_train, epoch_edge_acc_train, epoch_edge_loss_val, epoch_edge_acc_val))

    is_best_loss = False
    if epoch_edge_loss_val < best_val_edge_loss:
        best_epoch, best_train_edge_loss, best_val_edge_loss, is_best_loss = epoch, epoch_edge_loss_train, epoch_edge_loss_val, True

model.save_checkpoint(os.path.join("logs", args.save), optimizer, epoch, best_train_edge_loss, best_val_edge_loss, is_best_loss)
print('best Epoch: ', best_epoch, ', best total train loss: ', best_train_edge_loss, ', best total val loss: ', best_val_edge_loss)

0

In [42]:
best_epoch, best_train_loss, best_val_loss = model.load_checkpoint(os.path.join("logs", args.save), optimizer)
model.eval()

losses = []
edge_losses = []
reg_losses = []
correct = 0
edge_correct = 0
frame_error = 0

for data in test_loader:
    data.to(args.device)

    data_input = data.x, data.edge_index, data.batch, data.edge_attr

    labels = data.y

    gt = []
    for label in labels:
        if label == 0:
            gt.append([1,0])
        else:
            gt.append([0,1])
    gt = torch.tensor(gt, dtype=torch.float32).to(args.device)


    # get class scores from model
    scores = model(data_input)[0]

    loss = loss_fn(scores, gt)

    # Keep track of loss and accuracy
    pred = scores.argmax(dim=1)
    correct += int((pred == labels).sum()) / labels.size(0)
    losses.append(loss.item())


    batch_node_label = []
    for data_point in data.batch:
        if data.y[data_point] == 0:
            batch_node_label.append(False)
        else:
            batch_node_label.append(True)
    batch_node_label = torch.tensor(batch_node_label, dtype=torch.bool).to(args.device)
    batch_edge_label = batch_node_label[data.edge_index[0]]
    edge_score = model(data_input)[1][batch_edge_label]
    gt_edge = []
    batch_edge_index = data.obj_ids[data.edge_index[0]][batch_edge_label]
    for i in range(edge_score.size(0)):
        if batch_edge_index[i] in data.crash_ids:
            gt_edge.append([0,1])
        else:
            gt_edge.append([1,0])
    gt_edge = torch.tensor(gt_edge, dtype=torch.float32).to(args.device)
    loss_edge = loss_fn(edge_score, gt_edge)
    edge_losses.append(loss_edge.item())
    pred_edge = edge_score.argmax(dim=1)
    labels_edge = gt_edge.argmax(dim=1)
    edge_correct += int((pred_edge == labels_edge).sum()) / labels_edge.size(0)


    regression_score = model(data_input)[2][data.y == 1]
    regression_gt = data.frame_to_crash[data.y == 1]
    loss_reg = torch.sqrt(loss_mse(torch.flatten(regression_score, start_dim=0), regression_gt))/t_prediction
    frame_error += abs(torch.sub(torch.flatten(regression_score, start_dim=0), regression_gt)).sum().item() / regression_gt.size(0)
    reg_losses.append(loss_reg.item())

# gather the results for the epoch
loss_test = sum(losses) / len(losses)
correct_test = correct / len(test_loader)
epoch_edge_loss_test = sum(edge_losses) / len(edge_losses)
epoch_edge_acc_test = edge_correct / len(test_loader)
epoch_reg_loss_test = sum(reg_losses) / len(reg_losses)
epoch_reg_error_test = frame_error / len(test_loader)

print("Test Loss at epoch {}: {:.3f}, Test Accuracy: {:.3f}, \nTest Edge Loss: {:.3f}, Test Edge Accuracy: {:.3f} \nTest Regression Loss: {:.3f}, Test Regression Error: {:.3f}".format(best_epoch, loss_test, correct_test, epoch_edge_loss_test, epoch_edge_acc_test, epoch_reg_loss_test, epoch_reg_error_test))

Test Loss at epoch 99: 2.376, Test Accuracy: 0.720, 
Test Edge Loss: 0.604, Test Edge Accuracy: 0.821 
Test Regression Loss: 0.301, Test Regression Error: 7.256
