# Preparing Dataset

In [None]:
import os
import glob
import xml.etree.ElementTree as ET
import numpy as np
from tqdm import tqdm
import functools 
from functools import reduce
import shutil
from operator import mul
from multiprocessing import Process, Manager, Pool
from operator import itemgetter
from itertools import groupby
import ujson, json
import math
import matplotlib.pyplot as plt
# Report how many CPU cores are available; later cells spawn several worker
# processes (multiprocessing.Process and Pool).
print(os.cpu_count())

## 1. Define data paths

In [None]:
# Set this to the absolute path of the root data folder.
root_folder = "/absolute/path/to/data/folder"

# Set this to the experiment name of the evaluation: the sub-folder of
# `root_folder` that holds the "crash" and "tested_and_safe" folders of
# *.fcd.json / *.json files.
exp_name = "Experiment-testing_DateOfRunning"

# Relative folder (resolved against the current working directory) where the
# intermediate per-process JSON and .npy results are written.
processed_data_tmp_folder_path = "processed_data_tmp"


## 2. Process data

In [None]:
def _read_fcd_episode_lines(file_fcd):
    """Return (episode ids in file order, {id: line number of first occurrence}) for an fcd file."""
    id_list = []
    line_of_id = {}
    with open(file_fcd) as fcd_obj:
        # enumerate counts every physical line, so a line that fails to parse
        # does not shift the line numbers recorded for the lines after it.
        for line_no, line in enumerate(fcd_obj):
            try:
                ep_id = int(json.loads(line)["original_name"])
            except (ValueError, KeyError, TypeError):
                print("can't load", file_fcd)
                continue
            id_list.append(ep_id)
            line_of_id.setdefault(ep_id, line_no)
    return id_list, line_of_id


def _find_init_time_step(weight_list_step):
    """Return the first step key whose cumulative weight product drops below 1.

    The weights recorded for one step are multiplied together first. Falls back
    to the last key when the cumulative product never drops below 1, and
    returns None for an empty mapping.
    """
    last_key = None
    cum_weight = 1.0
    for key, step_weights in weight_list_step.items():
        cum_weight *= reduce(mul, step_weights, 1.0)
        last_key = key
        if cum_weight < 1:
            return key
    return last_key


def _read_json_episode_records(file_json):
    """Parse an episode ``.json`` file into (ids in order, {id: record of first occurrence}).

    An episode is kept only when its initial time step (see
    ``_find_init_time_step``) is one of its ``CAV_info`` keys.
    """
    id_list = []
    record_of_id = {}
    with open(file_json) as json_obj:
        for line_index, line in enumerate(json_obj):
            try:
                iter_info = json.loads(line)
            except ValueError:
                print("can't load", file_json, line_index)
                continue
            init_time_step = _find_init_time_step(iter_info["weight_list_step"])
            cav_info = iter_info["CAV_info"]
            if init_time_step is None or init_time_step not in cav_info:
                continue
            ep_id = int(iter_info["episode_info"]["id"])
            id_list.append(ep_id)
            if ep_id in record_of_id:
                continue  # the first occurrence wins for repeated ids
            record_of_id[ep_id] = {
                "weight": np.clip(iter_info["weight_episode"], 0, np.inf),
                "line_index": line_index,
                "num_cav_steps": len(cav_info),
                "crit001": [t for t in cav_info if cav_info[t]["criticality"] > 0.01],
                "crit0": [t for t in cav_info if cav_info[t]["criticality"] > 0.0],
                "crit": [[t, cav_info[t]["criticality"]] for t in cav_info],
                "nade_info": [
                    [t] + list(nade_step.values())[0]
                    for t, nade_step in iter_info["NADE_info"].items()
                ],
            }
    return id_list, record_of_id


def process_func_new(fcd_files, index, keyid):
    """Process the data for the safety-critical events analysis.

    For every ``.fcd.json`` trajectory file, the sibling ``.json`` episode
    file is read as well. Each episode id present in both files adds one entry
    to ``data_info[keyid + "_id_list"]``::

        [episode_id, file_stem, fcd_line_index, json_line_index,
         num_cav_time_steps, times_with_criticality_gt_0.01,
         times_with_criticality_gt_0, nade_info, [[time, criticality], ...]]

    Its episode weight, clipped to be non-negative, is appended to
    ``data_info[keyid + "_weight"]``. The dict is dumped to
    ``<processed_data_tmp_folder_path>/complete_info_offlinecollect-tmp_<keyid>_<index>.json``.

    Args:
        fcd_files (list): List of fcd files.
        index (int): Index of the process (used in the output file name).
        keyid (str): Type of the safety-critical event, should be "crash" or "safe".
    """
    data_info = {
        "safe_id_list": [], "safe_weight": [],
        "crash_id_list": [], "crash_weight": []
    }
    for file_fcd in tqdm(fcd_files):
        file_json = file_fcd.replace(".fcd.json", ".json")
        fcd_id_list, fcd_line_of_id = _read_fcd_episode_lines(file_fcd)
        json_id_list, json_record_of_id = _read_json_episode_records(file_json)

        if set(fcd_id_list) != set(json_id_list):
            print(file_fcd, len(fcd_id_list), len(json_id_list))
        file_stem = file_json.split("/")[-1].replace(".json", "")
        for ep_id in sorted(fcd_line_of_id.keys() & json_record_of_id.keys()):
            record = json_record_of_id[ep_id]
            info = [
                ep_id,
                file_stem,
                fcd_line_of_id[ep_id],
                record["line_index"],
                record["num_cav_steps"],
                record["crit001"],
                record["crit0"],
                record["nade_info"],
                record["crit"],
            ]
            data_info[keyid + "_id_list"].append(info)
            data_info[keyid + "_weight"].append(record["weight"])
    print([len(data_info[k]) for k in data_info])
    os.makedirs(processed_data_tmp_folder_path, exist_ok=True)
    out_name = f"complete_info_offlinecollect-tmp_{keyid}_{str(index)}.json"
    with open(processed_data_tmp_folder_path + "/" + out_name, 'w') as fp:
        print("saved", out_name)
        json.dump(data_info, fp, indent=4)

In [None]:
def densify(fcd_dict, init_time_step):
    """Remove unused information and pre-``init_time_step`` frames from one trajectory.

    Args:
        fcd_dict (dict): Raw trajectory record with ``"original_name"`` and
            ``"fcd-export" -> "timestep"``, a list of frames that each have an
            ``"@time"`` and a ``"vehicle"`` list.
        init_time_step (str): Initial time step. Frames whose time, rounded to
            one decimal, is earlier than it are dropped.

    Returns:
        dict: A new record with the same ``"original_name"`` and only the kept
        frames. In each vehicle, ``@pos``, ``@lane``, ``@slope`` and
        ``@accelerationLat`` are removed (when present), and ``@type`` is cut
        at its first ``"@"``. The input is not modified.
    """
    dropped_keys = ("@pos", "@lane", "@slope", "@accelerationLat")
    start_time = round(float(init_time_step), 1)  # loop-invariant
    kept_steps = []
    for time_info in fcd_dict["fcd-export"]["timestep"]:
        if round(float(time_info["@time"]), 1) < start_time:
            continue
        vehicles = []
        for veh in time_info["vehicle"]:
            # Copy without the dropped fields; remaining key order is preserved.
            new_veh_info = {k: v for k, v in veh.items() if k not in dropped_keys}
            new_veh_info["@type"] = new_veh_info["@type"].split("@")[0]
            vehicles.append(new_veh_info)
        kept_steps.append({"@time": time_info["@time"], "vehicle": vehicles})
    return {"fcd-export": {"timestep": kept_steps}, "original_name": fcd_dict["original_name"]}

def _trim_leading_zero_criticality(cav_info):
    """Strip leading zero-criticality steps from ``cav_info`` in place and return the initial time step.

    Returns ``"0.0"`` for an empty ``cav_info``. If the first step's
    criticality is non-zero, nothing is removed and its key is returned.
    Otherwise, every zero-criticality step before the first later step with
    positive criticality is removed, and that step's key is returned. When no
    later step is positive, the first key is returned and all zero-criticality
    steps are removed.
    """
    if not cav_info:
        return "0.0"
    first_step = next(iter(cav_info))
    if cav_info[first_step]["criticality"] != 0:
        return first_step
    init_time_step = first_step
    zero_steps = []
    for t, step_info in cav_info.items():
        crit = step_info["criticality"]
        if crit == 0:
            zero_steps.append(t)
        if crit > 0 and float(t) > float(first_step):
            init_time_step = t
            break
    for t in zero_steps:
        del cav_info[t]
    return init_time_step


def _load_trimmed_json_episodes(file_json):
    """Read an episode ``.json`` file, trim each episode and index it by id.

    Returns:
        tuple: ({episode_id: trimmed record}, {episode_id: initial time step}).
        For a repeated id, the first record and its initial step are kept.
    """
    unused_keys = (
        "weight_step_info", "current_weight", "crash_decision_info", "decision_time_info",
        "drl_epsilon_step_info", "real_epsilon_step_info", "drl_obs_step_info",
        "ndd_step_info", "criticality_step_info", "cav_mean_speed", "RSS_rate",
    )
    episodes, init_steps = {}, {}
    with open(file_json) as json_obj:
        for line_json in json_obj:
            episode = json.loads(line_json)
            init_time_step = _trim_leading_zero_criticality(episode["CAV_info"])
            for key in unused_keys:
                episode.pop(key, None)  # tolerate records that already lack a key
            ep_id = episode["episode_info"]["id"]
            if ep_id not in episodes:
                episodes[ep_id] = episode
                init_steps[ep_id] = init_time_step
    return episodes, init_steps


def _load_densified_fcd_episodes(file_fcd, init_steps):
    """Return {episode_id: densified trajectory} for the fcd lines whose id is in ``init_steps``."""
    trajectories = {}
    with open(file_fcd) as fcd_obj:
        for line_fcd in fcd_obj:
            raw = json.loads(line_fcd)
            ep_id = int(raw["original_name"])
            if ep_id in init_steps and ep_id not in trajectories:
                trajectories[ep_id] = densify(raw, init_steps[ep_id])
    return trajectories


def _densified_output_path(path, input_exp, output_exp, file_name):
    """Move ``path`` from the ``input_exp`` folder to ``output_exp`` and rename it to ``file_name``.

    Only the last occurrence of ``input_exp`` in the directory part is
    replaced. The file name is set explicitly rather than via substring
    replacement.
    """
    directory = os.path.dirname(path)
    head, sep, tail = directory.rpartition(input_exp)
    if not sep:
        # Without this check the output would land in (and pollute) the input folder.
        raise ValueError(f"{input_exp!r} does not occur in {directory!r}")
    return os.path.join(head + output_exp + tail, file_name)


def process_fcd_json(root_folder, input_exp, output_exp, fcd_files, index, keyid, offset=0,
                     num_lines_section=50):
    """Process the trajectory data and save the processed data.

    For every ``<n>.fcd.json`` that has a sibling ``<n>.json``, each episode:

    * has its leading zero-criticality ``CAV_info`` steps trimmed and its
      unused keys dropped;
    * has its trajectory densified from the resulting initial time step.

    Episodes present in both files are written, in id order and in chunks of
    ``num_lines_section``, to ``<n + offset>_<chunk>.fcd.json`` and
    ``<n + offset>_<chunk>.json`` under the ``output_exp`` folder. An episode
    whose trajectory frame count differs from its ``CAV_info`` step count is
    reported with "problem" and skipped. Written episode ids are shifted by
    ``offset * 20000``. File pairs with a missing member are skipped.

    Args:
        root_folder (str): Path to the data folder (accepted for compatibility; unused).
        input_exp (str): Experiment name of the evaluation.
        output_exp (str): Experiment name of the processed data.
        fcd_files (list): List of trajectory data files.
        index (int): Index of the parallel process (accepted for compatibility; unused).
        keyid (str): The type of the data, crash or safe (accepted for compatibility; unused).
        offset (int, optional): Offset of file index. Defaults to 0.
        num_lines_section (int, optional): Episodes per output file. Defaults to 50.

    Raises:
        ValueError: if ``input_exp`` is not part of a processed file's directory.
    """
    id_offset = offset * 20000
    for file_fcd in tqdm(fcd_files):
        file_json = file_fcd.replace(".fcd.json", ".json")
        if not os.path.exists(file_fcd) or not os.path.exists(file_json):
            continue
        json_episodes, init_steps = _load_trimmed_json_episodes(file_json)
        fcd_episodes = _load_densified_fcd_episodes(file_fcd, init_steps)
        # fcd_episodes only holds ids that also occur in the json file.
        inter_ids = sorted(fcd_episodes)

        new_file_index = int(os.path.basename(file_fcd).split(".")[0]) + offset
        for section, start in enumerate(range(0, len(inter_ids), num_lines_section)):
            stored_ids = inter_ids[start:start + num_lines_section]
            file_fcd_output = _densified_output_path(
                file_fcd, input_exp, output_exp, f"{new_file_index}_{section}.fcd.json")
            file_json_output = _densified_output_path(
                file_json, input_exp, output_exp, f"{new_file_index}_{section}.json")

            # filter problematic trajs (different length)
            filtered_ids = []
            for ep_id in stored_ids:
                num_frames = len(fcd_episodes[ep_id]["fcd-export"]["timestep"])
                num_cav_steps = len(json_episodes[ep_id]["CAV_info"])
                if num_frames != num_cav_steps:
                    print("problem", fcd_episodes[ep_id]["original_name"], file_fcd_output,
                          num_frames, num_cav_steps)
                else:
                    filtered_ids.append(ep_id)

            with open(file_fcd_output, 'w') as fp1:
                for ep_id in filtered_ids:
                    trajectory = fcd_episodes[ep_id]
                    trajectory["original_name"] = str(int(trajectory["original_name"]) + id_offset)
                    json.dump(trajectory, fp1)
                    fp1.write('\n')
            with open(file_json_output, 'w') as fp2:
                for ep_id in filtered_ids:
                    episode = json_episodes[ep_id]
                    episode["episode_info"]["id"] += id_offset
                    json.dump(episode, fp2)
                    fp2.write('\n')

def process_fcd_json_mp(fcd_files, offset=0):
    """Process the trajectory data of one batch of files; intended as a Pool worker.

    This is a thin wrapper around ``process_fcd_json``, whose loop this
    function previously duplicated verbatim. The input experiment name is the
    third-to-last path component of the first file
    (``<root>/<exp>/<category>/<n>.fcd.json``). Output goes to
    ``<exp>/densified_exps``.

    Args:
        fcd_files (list): List of trajectory data files.
        offset (int, optional): Offset of file index. Defaults to 0.

    Returns:
        str: A short completion message (printed by the Pool driver loop).
    """
    if not fcd_files:
        return "finished (empty batch)"
    input_exp = fcd_files[0].split("/")[-3]
    output_exp = input_exp + "/densified_exps"
    # root_folder, index and keyid are not used by process_fcd_json.
    process_fcd_json(None, input_exp, output_exp, fcd_files, 0, None, offset=offset)
    return f"finished {fcd_files[0]}"


### Step 1: Remove unnecessary data

In [None]:
## Densify fcd files
# Densified copies go to <root_folder>/<exp_name>/densified_exps/{crash,tested_and_safe}.
out_exp = exp_name+"/densified_exps"
output_space = [os.path.join(root_folder, out_exp, "crash"), os.path.join(root_folder, out_exp, "tested_and_safe")]
for fo in output_space:
    os.makedirs(fo, exist_ok=True)

safe_folder = os.path.join(root_folder, exp_name, "tested_and_safe")
# Safe files are enumerated by name (0..999) instead of globbed;
# process_fcd_json skips the ones that do not exist.
safe_fcd_files = [os.path.join(safe_folder, f"{i}.fcd.json") for i in range(1000)]
print(len(safe_fcd_files))
crash_folder = os.path.join(root_folder, exp_name, "crash")
crash_fcd_files = sorted(glob.glob(crash_folder+"/*.fcd.json"))
print(len(crash_fcd_files))
num_each = 100

process_fcd_json(root_folder, exp_name, out_exp, crash_fcd_files, 0, "crash", offset=0)

split_safe_fcd_files = [safe_fcd_files[start:start + num_each]
                        for start in range(0, len(safe_fcd_files), num_each)]
print(len(split_safe_fcd_files))
# Each batch (including the first) is processed exactly once by the pool. The
# removed serial process_fcd_json_mp(split_safe_fcd_files[0]) call used to redo
# batch 0. The context manager shuts the workers down once all results are consumed.
with Pool(5) as pool:
    for result in pool.imap_unordered(process_fcd_json_mp, split_safe_fcd_files):
        print(result)

### Step 2: Analyze the safety-critical events

In [None]:
# Run process_func_new in parallel over the densified files: ten workers split
# the safe files into contiguous chunks, and one extra worker takes every crash file.
folder = os.path.join(root_folder, out_exp)
safe_folder = os.path.join(folder, "tested_and_safe")
crash_folder = os.path.join(folder, "crash")
safe_fcd_files = sorted(glob.glob(safe_folder+"/*.fcd.json"))
crash_fcd_files = sorted(glob.glob(crash_folder+"/*.fcd.json"))
print(len(crash_fcd_files), len(safe_fcd_files))

num_safe_workers = 10
num_each = len(safe_fcd_files) // num_safe_workers + 1
# num_each * num_safe_workers exceeds the file count, so the last slice
# already reaches the end of the list.
p_list = [
    Process(target=process_func_new,
            args=(safe_fcd_files[k*num_each:(k+1)*num_each], k, "safe"))
    for k in range(num_safe_workers)
]
p_list.append(Process(target=process_func_new, args=(crash_fcd_files, 0, "crash")))
for proc in p_list:
    proc.start()
for proc in p_list:
    proc.join()

In [None]:
def find_training_range_crash_new_nnmetric(crit_info, nade_info):
    """Find the interesting time step range for the crash event.

    The range starts at the first step whose criticality is at least 0.9 and
    from which criticality never falls below 0.001 through the end of the
    episode. It ends at the last step.

    If the final step's criticality is below 0.001 while the step before it is
    positive, the final step is treated as criticality 1. (Presumably the
    crash step itself reports low criticality; TODO confirm.)

    Args:
        crit_info (list): ``[time, criticality]`` pairs in time order (not modified).
        nade_info (list): List of NADE information (currently unused).

    Returns:
        tuple: ``(start_time, end_time)`` taken from ``crit_info``, or
        ``(None, None)`` when no such start exists (including empty input).
    """
    threshold = 0.9
    min_crit = 0.001
    t_list = [info[0] for info in crit_info]
    crit_list = [info[1] for info in crit_info]
    if not crit_list:
        return None, None
    if len(crit_list) >= 2 and crit_list[-1] < min_crit and crit_list[-2] > 0:
        crit_list[-1] = 1
    # suffix_min[k] == min(crit_list[k:]); computed once instead of per step (was O(n^2)).
    suffix_min = list(crit_list)
    for k in range(len(suffix_min) - 2, -1, -1):
        suffix_min[k] = min(suffix_min[k], suffix_min[k + 1])
    for t, crit, tail_min in zip(t_list, crit_list, suffix_min):
        if crit >= threshold and tail_min >= min_crit:
            return t, t_list[-1]
    return None, None

def find_training_range_safe_new_nnmetric(crit_info, nade_info, debug_flag=False):
    """Find the interesting time step range for the safe event.

    The range starts at the first step ``i`` whose criticality is at least
    0.9, provided it is not the last step. It ends at the first later step
    whose criticality is below 0.001, or at the last step if there is none.

    If the final step's criticality is exactly 0 while the step before it is
    positive, the final step is treated as criticality 1.

    Args:
        crit_info (list): ``[time, criticality]`` pairs in time order (not modified).
        nade_info (list): List of NADE information (currently unused).
        debug_flag (bool, optional): Print the scan state at each candidate end. Defaults to False.

    Returns:
        tuple: ``(start_time, end_time)`` taken from ``crit_info``, or
        ``(None, None)`` when no range is found (including empty input).
    """
    threshold = 0.9
    min_crit = 0.001
    t_list = [info[0] for info in crit_info]
    crit_list = [info[1] for info in crit_info]
    n = len(crit_list)
    if n >= 2 and crit_list[-1] == 0 and crit_list[-2] > 0:
        crit_list[-1] = 1
    for i in range(n):
        if not crit_list[i] >= threshold:
            continue
        for j in range(i + 1, n):
            if debug_flag:
                print(i, j, min(crit_list[i:j]), min(crit_list[i:j + 1]))
            # Steps i..j-1 are all >= min_crit here: step i is >= threshold and
            # we return at the first step below min_crit. So "min(crit_list[i:j+1])
            # < min_crit" reduces to a check on step j alone (was O(n^3)).
            if crit_list[j] < min_crit or j + 1 == n:
                return t_list[i], t_list[j]
    return None, None
    
def debug(info):
    """Print an episode id and plot its criticality curve with lane-change markers.

    Args:
        info (list): One ``*_id_list`` entry as built by ``process_func_new``.
            Index 0 is the episode id, index 7 the NADE info rows
            ``[time, action, value, ...]``, and index 8 the
            ``[time, criticality]`` pairs.
    """
    print(info[0])
    times = [float(entry[0]) for entry in info[8]]
    crits = [entry[1] for entry in info[8]]

    # NADE rows whose second field is 0 or 1 become vertical markers, keeping
    # only those at least 1 s after the previous marker. They are red when the
    # third field is positive, blue otherwise.
    marker_times, marker_colors = [], []
    for entry in info[7]:
        if entry[1] in [0, 1]:
            when = float(entry[0])
            if not marker_times or when - marker_times[-1] >= 1:
                marker_times.append(when)
                marker_colors.append("red" if entry[2] > 0 else "blue")

    plt.figure(dpi=100)
    plt.plot(times, crits)
    for when, color in zip(marker_times, marker_colors):
        plt.plot([when]*2, [0, 1], c=color)
    for level in (0, 0.01, 0.5, 0.9):
        plt.plot(times, [level]*len(times), "--", c="k", alpha=0.5)
    plt.xlabel("time (s)")
    plt.ylabel("criticality")
    plt.show()

In [None]:
# Merge the per-process results of process_func_new. For every episode, keep
# only the time-step range chosen by the find_training_range_* functions.
files = sorted(glob.glob(processed_data_tmp_folder_path+"/complete_info_offlinecollect-tmp_*"))
# Counters of safe / crash episodes dropped because no training range was found.
num_crash = 0
num_safe = 0

data_info = {
    "safe_id_list": [], "safe_weight": [], "safe_ep_info": [],
    "crash_id_list": [], "crash_weight": [], "crash_ep_info": [],
}

# Entry index of the "times with criticality > 0.01" list; episodes where it
# is empty are skipped.
int_index = 5
output_file = "offline_av_alldata.json"
critic_len_list = []  # not filled in this cell


def _time_to_step(time_value):
    """Convert a time in seconds (str or float) to a 0.1 s step index.

    Rounds instead of truncating, so time strings with float noise (e.g.
    "2.2999999999999998") still map to the intended step, as in get_smallest_dist.
    """
    return int(round(float(time_value) * 10))


for file in tqdm(files):
    print(file)
    with open(file) as fp:
        json_obj = json.load(fp)
    for info, weight in zip(json_obj["safe_id_list"], json_obj["safe_weight"]):
        if len(info[int_index]) == 0:
            continue
        start, end = find_training_range_safe_new_nnmetric(info[8], info[7])
        if start is None:
            num_safe += 1
            continue
        info_out = [info[0], info[1], info[2], info[3], _time_to_step(start), _time_to_step(end)]
        data_info["safe_id_list"].append(info_out)
        data_info["safe_weight"].append(weight)
        data_info["safe_ep_info"].append(tuple(info_out + [weight]))

    for info, weight in zip(json_obj["crash_id_list"], json_obj["crash_weight"]):
        if len(info[int_index]) == 0:
            continue
        start, end = find_training_range_crash_new_nnmetric(info[8], info[7])
        if start is None:
            num_crash += 1
            print("remove")
            debug(info)
            continue
        info_out = [info[0], info[1], info[2], info[3], _time_to_step(start), _time_to_step(end)]
        data_info["crash_id_list"].append(info_out)
        data_info["crash_weight"].append(weight)
        data_info["crash_ep_info"].append(tuple(info_out + [weight]))

print(num_safe, num_crash)
# np.mean of an empty list prints nan with a RuntimeWarning; only report real data.
if critic_len_list:
    print(np.mean(critic_len_list))
with open(folder+f"/{output_file}", 'w') as fp:
    ujson.dump(data_info, fp)

In [None]:
# Drop episodes whose training range spans fewer than 2 steps
# (end_step - start_step < 2; steps are 0.1 s apart).
json_file_path = folder+"/offline_av_alldata.json"
with open(json_file_path) as fp:
    data_info_origin = ujson.load(fp)
print(len(data_info_origin["crash_id_list"]))
print(data_info_origin.keys())
print(len(data_info_origin["safe_id_list"]))

min_range_steps = 2
partial_data_info_origin = {}
for keyid in ("crash", "safe"):
    kept_ids, kept_weights, kept_ep_infos = [], [], []
    for entry, weight, ep_info in zip(data_info_origin[keyid + "_id_list"],
                                      data_info_origin[keyid + "_weight"],
                                      data_info_origin[keyid + "_ep_info"]):
        if entry[5] - entry[4] >= min_range_steps:
            kept_ids.append(entry)
            kept_weights.append(weight)
            kept_ep_infos.append(ep_info)
        elif keyid == "crash":
            print(entry)  # only removed crash episodes are reported
    partial_data_info_origin[keyid + "_id_list"] = kept_ids
    partial_data_info_origin[keyid + "_weight"] = kept_weights
    partial_data_info_origin[keyid + "_ep_info"] = kept_ep_infos
print(len(partial_data_info_origin["crash_id_list"]), len(partial_data_info_origin["safe_id_list"]))
with open(folder+"/offline_av_alldata_new.json", "w") as fp:
    ujson.dump(partial_data_info_origin, fp)

# Histograms of log10(weight) over the non-zero weights: crash, then safe.
for keyid in ("crash", "safe"):
    plt.figure()
    weight_list = np.array(partial_data_info_origin[keyid + "_weight"])
    plt.hist(np.log10(weight_list[np.nonzero(weight_list)]), bins=[2*i/10-15 for i in range(81)])
    plt.show()

# Zoom into log10(safe weight) in [0, 2] and report how many safe weights exceed 3.
plt.figure()
weight_list = np.array(partial_data_info_origin["safe_weight"])
plt.hist(np.log10(weight_list[np.nonzero(weight_list)]), bins=[2*i/10 for i in range(11)])
print(sum(weight_list>3))
plt.show()

### Step 3: Analyze the near-miss events

In [None]:
# calculate center distance between vehicles
# Each vehicle footprint (veh_length x veh_width, in the units of the
# trajectory @x/@y, presumably metres) is covered by three circles of radius
# circle_r; see find_three_circle_centers.
veh_length = 5.0
veh_width = 1.8
circle_r = 1.227  # NOTE(review): choice of radius not explained here; confirm its origin
# Half-chord of a circle of radius circle_r cut by a line at distance
# veh_width/2 from its centre: sqrt(r^2 - (w/2)^2). find_three_circle_centers
# places the two outer circle centres this far from the vehicle's two ends
# (assuming @x/@y marks one end of the vehicle).
tem_len = math.sqrt(circle_r**2-(veh_width/2)**2)

def read_json_fcd_from_json(fcd_file, safe_data):
    """Load the trajectories of the selected episodes from one JSON-lines fcd file.

    Args:
        fcd_file (str): Path to the fcd file. Each line is a JSON object whose
            ``"original_name"`` converts to an int episode id.
        safe_data (list): Episode rows whose first element is the episode id.

    Returns:
        list: Parsed trajectory dicts whose episode id appears in
        ``safe_data``, in file order.
    """
    # A set gives O(1) membership tests instead of scanning a list per line.
    wanted_ids = {ep[0] for ep in safe_data}
    fcdjson_obj_list = []
    with open(fcd_file) as fp:
        for line in fp:
            fcdjson_obj = json.loads(line)
            if int(fcdjson_obj["original_name"]) in wanted_ids:
                fcdjson_obj_list.append(fcdjson_obj)
    return fcdjson_obj_list

def find_three_circle_centers(veh_info):
    """Return the centres of the three circles covering one vehicle.

    Args:
        veh_info (dict): Vehicle record with ``"@x"``, ``"@y"`` and
            ``"@angle"`` (in degrees) fields convertible to float.

    Returns:
        list: Three ``(x, y)`` tuples. The middle one lies ``veh_length / 2``
        behind ``(@x, @y)`` along the heading direction ``(sin a, cos a)``. The
        first is shifted ``veh_length / 2 - tem_len`` from it towards
        ``(@x, @y)``, and the last is shifted by the same amount the other way.
    """
    angle_rad = float(veh_info["@angle"]) / 180 * math.pi
    dir_x, dir_y = math.sin(angle_rad), math.cos(angle_rad)
    mid_x = float(veh_info["@x"]) - veh_length / 2 * dir_x
    mid_y = float(veh_info["@y"]) - veh_length / 2 * dir_y
    shift = veh_length / 2 - tem_len
    return [
        (mid_x + shift * dir_x, mid_y + shift * dir_y),
        (mid_x, mid_y),
        (mid_x - shift * dir_x, mid_y - shift * dir_y),
    ]

def get_smallest_dist(fcdjson_obj, safe_data):
    """Get the smallest distance between the centres of the vehicles' covering circles.

    Only frames whose step index ``int(round(time, 1) * 10)`` lies inside the
    episode's stored ``[start_step, end_step]`` are considered; these are
    entries 4 and 5 of its ``safe_data`` row. In each such frame, the CAV's
    three circles are compared with every background vehicle's three circles.

    Args:
        fcdjson_obj (dict): Trajectory data (``"original_name"`` and ``"fcd-export"``).
        safe_data (list): Episode rows ``[id, ..., start_step, end_step, ...]``.

    Returns:
        tuple: ``(smallest distance, episode id)``. The distance is 100 when no
        frame inside the range contains a background vehicle.

    Raises:
        ValueError: if the episode id is not in ``safe_data``, or a considered
            frame's last vehicle is not the ``"CAV"``.
    """
    traj_info = fcdjson_obj["fcd-export"]
    ep_id = int(fcdjson_obj["original_name"])
    safe_id_list = [ep[0] for ep in safe_data]
    try:
        index = safe_id_list.index(ep_id)
    except ValueError:
        # Explicit check instead of assert, which is stripped under python -O.
        raise ValueError(f"episode {ep_id} is not in safe_data") from None
    clip = (safe_data[index][4], safe_data[index][5])
    min_cc_distance_list_time = []
    for m in traj_info["timestep"]:
        t = int(round(float(m["@time"]), 1)*10)
        if t < clip[0] or t > clip[1]:
            continue
        vehs = m["vehicle"]
        if not vehs or vehs[-1]["@id"] != "CAV":
            raise ValueError(f"episode {ep_id}: last vehicle at time {m['@time']} is not the CAV")
        cav_three_circles = find_three_circle_centers(vehs[-1])
        min_cc_distance_list_vehicles = []
        for bv in vehs:
            if bv["@id"] == "CAV":
                continue
            bv_three_circles = find_three_circle_centers(bv)
            min_cc_distance_list_vehicles.append(min(
                math.sqrt((cav_c[0]-bv_c[0])**2+(cav_c[1]-bv_c[1])**2)
                for cav_c in cav_three_circles
                for bv_c in bv_three_circles
            ))
        # A frame holding only the CAV has no distance to report (min() of an
        # empty list used to raise here).
        if min_cc_distance_list_vehicles:
            min_cc_distance_list_time.append(min(min_cc_distance_list_vehicles))
    if not min_cc_distance_list_time:
        print(safe_data[index], len(traj_info["timestep"]))
        min_cc_distance_list_time = [100]
    return min(min_cc_distance_list_time), ep_id

def main(fcd_files, index, all_data_path):
    """Compute per-episode minimum covering-circle distances for a chunk of fcd files.

    Intended as one worker of the multiprocessing fan-out. For every episode
    object returned by ``read_json_fcd_from_json`` for each file, it computes
    ``get_smallest_dist`` and saves the ``[min_distance, episode_id]`` rows to
    ``{processed_data_tmp_folder_path}/min_center_distance_{index}.npy``.
    A histogram of the distances is shown when there is at least one result.

    Args:
        fcd_files (list): List of fcd files handled by this worker (may be empty).
        index (int): Index of the process; used in the output file name.
        all_data_path (str): Path to the JSON data file containing a
            ``"safe_id_list"`` entry.
    """
    with open(all_data_path) as fp:
        safe_data = json.load(fp)["safe_id_list"]

    results = []
    for fcd_file in tqdm(fcd_files):
        for fcdjson_obj in read_json_fcd_from_json(fcd_file, safe_data):
            min_dist, ep_id = get_smallest_dist(fcdjson_obj, safe_data)
            results.append([min_dist, ep_id])
    print(len(results))

    # Always an (n, 2) float array, even for n == 0, so the histogram slice and
    # downstream readers can index [:, 0] / [:, 1] without an IndexError.
    results_arr = np.asarray(results, dtype=float).reshape(-1, 2)
    if len(results_arr):
        plt.figure(dpi=100)
        plt.hist(results_arr[:, 0], bins=100)
        plt.show()

    os.makedirs(processed_data_tmp_folder_path, exist_ok=True)
    out_path = f"{processed_data_tmp_folder_path}/min_center_distance_{index}.npy"
    with open(out_path, 'wb') as f:
        print(out_path)
        np.save(f, results_arr)

In [None]:
# smallest distance for three circles
all_data_path = folder+"/offline_av_alldata_new.json"
tested_and_safe_folder = folder+"/tested_and_safe/"
safe_fcd_files = sorted(glob.glob(tested_and_safe_folder+"*.fcd.json"))
print(len(safe_fcd_files))
num_processes = 20
# Ceil division spreads the files over the workers; int(n / k) + 1 over-sizes
# each chunk and can leave the trailing workers with nothing to do.
num_each = math.ceil(len(safe_fcd_files) / num_processes)
chunks = [safe_fcd_files[i*num_each:(i+1)*num_each] for i in range(num_processes)]
# NOTE: min_center_distance_*.npy files left in processed_data_tmp_folder_path
# by earlier runs would also be picked up by the next cell's glob.
p_list = [
    Process(target=main, args=(chunk, i, all_data_path))
    for i, chunk in enumerate(chunks)
    if chunk  # no worker for an empty chunk
]
for p_ind in p_list:
    p_ind.start()
for p_ind in p_list:
    p_ind.join()

In [None]:
# find near-miss, minimum three-circle-distance<2.5
np_files = sorted(glob.glob(f"{processed_data_tmp_folder_path}/min_center_distance_*.npy"))
print(np_files)

# Each worker file holds [min_distance, episode_id] rows. A single concatenate
# replaces repeated np.append (which copies the whole array every iteration),
# and reshape(-1, 2) keeps empty saves usable as 2-D arrays.
all_results = np.concatenate(
    [np.empty((0, 2))] + [np.load(f).reshape(-1, 2) for f in np_files]
)
results_min_dist = all_results[:, 0]
ep_id_list = all_results[:, 1]
print(len(results_min_dist))

file_name_alldata = "offline_av_alldata_new.json"

with open(os.path.join(folder, file_name_alldata)) as fp1:
    summary_info = ujson.load(fp1)

dist_threshold = 2.5
output_data = {
    "safe2crash_id_list": [],
    "safe2crash_weight": [],
    "safe2crash_ep_info": [],
}
# Kept as module-level names in case a later cell references them; the
# results themselves are collected in ``output_data``.
safe2crash_id_list = []
safe2crash_weight = []
safe_id_list = [info[0] for info in summary_info["safe_id_list"]]
# Episode id -> index of its FIRST record (same answer as list.index), so each
# lookup is O(1) instead of a linear scan over all safe episodes.
safe_index_by_id = {}
for j, safe_ep_id in enumerate(safe_id_list):
    safe_index_by_id.setdefault(safe_ep_id, j)

for min_dist, ep_id in tqdm(zip(results_min_dist, ep_id_list), total=len(ep_id_list)):
    if min_dist < dist_threshold:
        # Ids come back from the saved float array; convert for the lookup.
        ep_id = int(ep_id)
        j = safe_index_by_id.get(ep_id)
        if j is None:
            print(ep_id)
            continue
        ep_info = summary_info["safe_id_list"][j]
        weight = summary_info["safe_weight"][j]
        output_data["safe2crash_id_list"].append(ep_info)
        output_data["safe2crash_weight"].append(weight)
        output_data["safe2crash_ep_info"].append(tuple(ep_info + [weight]))
print(len(output_data["safe2crash_id_list"]), len(output_data["safe2crash_weight"]))

file_name_output = "offline_av_nearmiss_new.json"
with open(os.path.join(folder, file_name_output), "w") as fp2:
    ujson.dump(output_data, fp2)

# Histogram of log10(weight) over the non-zero near-miss weights, on two ranges.
weight_list = np.array(output_data["safe2crash_weight"])
log_weights = np.log10(weight_list[np.nonzero(weight_list)])

plt.figure()
plt.hist(log_weights, bins=[2*i/10-15 for i in range(81)])
plt.show()

plt.figure()
plt.hist(log_weights, bins=[2*i/10 for i in range(11)])
plt.show()


### Step 4: Store the processed information of all safety-critical events, including crashes and near-miss events

In [None]:
json_file_path = folder+"/offline_av_alldata_new.json"
with open(json_file_path) as fp:
    data_info_origin = ujson.load(fp)
    print(len(data_info_origin["crash_id_list"]))
    print(len(data_info_origin["safe_id_list"]))

json_file_path = folder+"/offline_av_nearmiss_new.json"
with open(json_file_path) as fp:
    data_info = ujson.load(fp)
    print(len(data_info["safe2crash_id_list"]))
print(data_info.keys())


def _freeze(value):
    """Return a hashable copy of a JSON value by recursively turning lists into tuples."""
    if isinstance(value, list):
        return tuple(_freeze(item) for item in value)
    return value


data_info_new = {}
data_info_new["crash_id_list"] = data_info_origin["crash_id_list"]
data_info_new["crash_weight"] = data_info_origin["crash_weight"]
data_info_new["crash_ep_info"] = data_info_origin["crash_ep_info"]

# Drop near-miss entries whose episode record is not among the safe episodes.
# A set of frozen records makes each membership test O(1) instead of a scan.
# NOTE(review): assumes records hold only scalars / nested lists (hashable
# after _freeze); a dict inside a record would raise TypeError here.
safe_records = {_freeze(ep) for ep in data_info_origin["safe_id_list"]}
del_index_list = []
for i, ep in enumerate(data_info["safe2crash_id_list"]):
    if _freeze(ep) not in safe_records:
        print(i)
        del_index_list.append(i)
# Filter into new lists: popping ascending indices one at a time shifts the
# remaining elements, so every later pop would remove the wrong entry.
del_index_set = set(del_index_list)
for key in ("safe2crash_id_list", "safe2crash_weight", "safe2crash_ep_info"):
    data_info[key] = [item for k, item in enumerate(data_info[key]) if k not in del_index_set]
print(len(data_info["safe2crash_id_list"]))

data_info_new["safe2crash_id_list"] = data_info["safe2crash_id_list"]
data_info_new["safe2crash_weight"] = data_info["safe2crash_weight"]
data_info_new["safe2crash_ep_info"] = data_info["safe2crash_ep_info"]

# Build independent score rows (not [[...]] * n) so an in-place edit of one
# episode's score cannot silently change every other row.
history = {}
history["crash_id_list"] = data_info_origin["crash_id_list"]
history["crash_weight"] = data_info_origin["crash_weight"]
history["crash_ep_info"] = data_info_origin["crash_ep_info"]
history["crash_score_list"] = [[1, 0, 0, 0] for _ in range(len(data_info_origin["crash_id_list"]))]
history["safe2crash_id_list"] = data_info["safe2crash_id_list"]
history["safe2crash_weight"] = data_info["safe2crash_weight"]
history["safe2crash_ep_info"] = data_info["safe2crash_ep_info"]
history["safe2crash_score_list"] = [[0, 1, 0, 1] for _ in range(len(data_info["safe2crash_id_list"]))]
data_info_new["crashnearmiss_history"] = history

# The same content is written twice: an "_origin" snapshot and the working file.
json_file_path_new = folder+"/offline_av_neweval_crashnearmiss_new_origin.json"
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_new, fp2)

json_file_path_new = folder+"/offline_av_neweval_crashnearmiss_new.json"
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_new, fp2)

plt.figure()
weight_list = np.array(data_info_new["crash_weight"])
plt.hist(np.log10(weight_list[np.nonzero(weight_list)]), bins=[2*i/10-15 for i in range(81)],alpha=0.5,label="crash")
weight_list = np.array(data_info_new["safe2crash_weight"])
plt.hist(np.log10(weight_list[np.nonzero(weight_list)]), bins=[2*i/10-15 for i in range(81)],alpha=0.5,label="near-miss")
plt.legend()
plt.show()
plt.figure()
plt.hist(np.log10(weight_list[np.nonzero(weight_list)]), bins=[2*i/10 for i in range(9)],alpha=0.5,label="near-miss")
plt.legend()
plt.show()