In [8]:
import copy
import glob
import json, ujson
import math
import matplotlib.pyplot as plt
from multiprocessing import Process, Manager, Pool
import numpy as np
import os
from tqdm import tqdm


In [9]:
# --- User configuration ---------------------------------------------------
# Set this to the absolute path of the data folder.
root_folder = "/absolute/path/to/data/folder"

# Set this to the experiment name of the evaluation.
exp_name = "Experiment-testing_DateOfRunning"

# --- Derived locations ------------------------------------------------------
# Relative folder holding intermediate files read and written by later cells.
processed_data_tmp_folder_path = "processed_data_tmp"
# "<root_folder>/<exp_name>/densified_exps": where the summary JSON files live.
processed_data_folder = os.path.join(root_folder, exp_name, "densified_exps")
# Summary file that most ablation cells start from.
original_summary_file = os.path.join(
    processed_data_folder, "offline_av_alldata_new.json"
)


In [None]:
# ablation study: No episodic data densification (NEDD)
# Step 1: re-save the all-data summary unchanged under the NEDD file name.
with open(os.path.join(processed_data_folder, "offline_av_alldata_new.json")) as fp:
    data_info_origin = ujson.load(fp)
with open(os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NEDD.json"), "w") as fp2:
    ujson.dump(data_info_origin, fp2)

# Step 2: reload the summary and report the number of crash / safe episodes.
with open(original_summary_file) as fp:
    data_info_origin = ujson.load(fp)
for id_key in ("crash_id_list", "safe_id_list"):
    print(len(data_info_origin[id_key]))

# Step 3: build the evaluation dictionary. Every weight is replaced by 1 and
# every safe episode is stored under the "safe2crash" keys.
crash_ids = data_info_origin["crash_id_list"]
safe_ids = data_info_origin["safe_id_list"]
data_info_new = {
    "crash_id_list": crash_ids,
    "crash_weight": [1] * len(data_info_origin["crash_weight"]),
    "crash_ep_info": data_info_origin["crash_ep_info"],
    "safe2crash_id_list": safe_ids,
    "safe2crash_weight": [1] * len(data_info_origin["safe_weight"]),
    "safe2crash_ep_info": data_info_origin["safe_ep_info"],
    "crashnearmiss_history": {
        "crash_id_list": crash_ids,
        "crash_weight": [1] * len(data_info_origin["crash_weight"]),
        "crash_ep_info": data_info_origin["crash_ep_info"],
        "crash_score_list": [[1, 0, 0, 0]] * len(crash_ids),
        "safe2crash_id_list": safe_ids,
        "safe2crash_weight": [1] * len(data_info_origin["safe_weight"]),
        "safe2crash_ep_info": data_info_origin["safe_ep_info"],
        "safe2crash_score_list": [[0, 1, 0, 1]] * len(safe_ids),
    },
}

json_file_path_new = os.path.join(root_folder, exp_name, "densified_exps", "offline_av_neweval_crashnearmiss_new_ablationstudy_NEDD.json")
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_new, fp2)

# Step 4: histograms of log10(weight) over the non-zero weights.
plt.figure()
for weight_key, series_label in (("crash_weight", "crash"), ("safe2crash_weight", "near-miss")):
    weight_list = np.array(data_info_new[weight_key])
    plt.hist(
        np.log10(weight_list[np.nonzero(weight_list)]),
        bins=[2*i/10-15 for i in range(81)],
        alpha=0.5,
        label=series_label,
    )
plt.legend()
plt.show()

# Second figure: the near-miss weights only (weight_list still holds them),
# drawn with a coarser set of bins.
plt.figure()
plt.hist(
    np.log10(weight_list[np.nonzero(weight_list)]),
    bins=[2*i/10 for i in range(9)],
    alpha=0.5,
    label="near-miss",
)
plt.legend()
plt.show()

In [11]:
# Constants for the center-distance calculation between vehicles.
# NOTE(review): units are presumably metres — confirm against the data source.
veh_length, veh_width = 5.0, 1.8
# Presumably the radius of each of the three circles used to cover a vehicle.
circle_r = 1.227
# Length of the other leg of a right triangle whose hypotenuse is circle_r and
# whose first leg is half the vehicle width.
tem_len = math.sqrt(circle_r ** 2 - (veh_width / 2) ** 2)

def read_json_fcd_from_json(fcd_file, safe_data):
    """Load the trajectory records of the listed episodes from a JSON-lines file.

    Each non-blank line of ``fcd_file`` is parsed as one JSON object. An object
    is kept when ``int(obj["original_name"])`` equals the first element of some
    entry in ``safe_data``.

    Args:
        fcd_file (str): Path to the fcd file (one JSON object per line).
        safe_data (list): List of safe-episode entries; element 0 of each entry
            is the episode id.

    Returns:
        list: The matching trajectory objects, in file order.
    """
    # A set gives constant-time membership tests; the original list lookup
    # was linear in the number of episodes for every line of the file.
    safe_ids = {ep[0] for ep in safe_data}
    fcdjson_obj_list = []
    with open(fcd_file) as fp:
        for line in fp:
            # Blank lines (e.g. a trailing newline) carry no record.
            if not line.strip():
                continue
            fcdjson_obj = json.loads(line)
            if int(fcdjson_obj["original_name"]) in safe_ids:
                fcdjson_obj_list.append(fcdjson_obj)
    return fcdjson_obj_list

def find_three_circle_centers(veh_info):
    """Compute the centers of the three circles used to cover a vehicle.

    The middle circle is placed half a vehicle length away from the reported
    position, opposite to the (sin, cos) direction of the heading. The other
    two circles sit ``veh_length/2 - tem_len`` ahead of and behind it along
    that same direction.

    Args:
        veh_info (dict): Vehicle information with the keys "@x", "@y" and
            "@angle" (angle in degrees).

    Returns:
        list: ``[center0, center1, center2]`` as (x, y) tuples, where center1
        is the middle circle, center0 the one ahead and center2 the one behind.
    """
    # NOTE(review): presumably (@x, @y) is the front-bumper position, which is
    # why the middle circle is shifted back by half a length — confirm.
    pos_x = float(veh_info["@x"])
    pos_y = float(veh_info["@y"])
    heading = float(veh_info["@angle"])/180*math.pi
    sin_h = math.sin(heading)
    cos_h = math.cos(heading)

    half_len = veh_length/2
    offset = half_len-tem_len

    mid_x = pos_x-half_len*sin_h
    mid_y = pos_y-half_len*cos_h

    return [
        (mid_x+offset*sin_h, mid_y+offset*cos_h),
        (mid_x, mid_y),
        (mid_x-offset*sin_h, mid_y-offset*cos_h),
    ]

def get_smallest_dist(fcdjson_obj, safe_data):
    """Get the smallest distance between the centers of the vehicles' covering circles.

    Only timesteps whose time index (time rounded to one decimal, times 10)
    lies inside the episode's clip ``[safe_data[k][4], safe_data[k][5]]`` are
    considered. For each such timestep the distance between every circle of
    the "CAV" vehicle and every circle of every other vehicle is evaluated.

    Args:
        fcdjson_obj (dict): Trajectory data of one episode.
        safe_data (list): List of safe data; element 0 of each entry is the
            episode id, elements 4 and 5 are the clip bounds.

    Returns:
        tuple: Smallest distance and the episode id. The distance is 100 when
        no distance could be measured inside the clip.

    Raises:
        AssertionError: If the episode is not listed in ``safe_data`` or the
            last vehicle of an in-clip timestep is not the "CAV".
    """
    ep_id = int(fcdjson_obj["original_name"])
    safe_id_list = [ep[0] for ep in safe_data]
    assert ep_id in safe_id_list
    index = safe_id_list.index(ep_id)
    clip_start, clip_end = safe_data[index][4], safe_data[index][5]

    traj_info = fcdjson_obj["fcd-export"]
    min_cc_distance_list_time = []
    for step in traj_info["timestep"]:
        t = int(round(float(step["@time"]), 1)*10)
        if t < clip_start or t > clip_end:
            continue
        # NOTE(review): assumes "vehicle" is always a list — confirm it cannot
        # be a single dict when only one vehicle is present.
        vehs = step["vehicle"]
        assert vehs[-1]["@id"] == "CAV"
        cav_three_circles = find_three_circle_centers(vehs[-1])

        step_distances = []
        for bv in vehs:
            if bv["@id"] == "CAV":
                continue
            for bv_c in find_three_circle_centers(bv):
                for cav_c in cav_three_circles:
                    step_distances.append(
                        math.sqrt((cav_c[0]-bv_c[0])**2+(cav_c[1]-bv_c[1])**2)
                    )
        # A timestep containing only the CAV has nothing to measure; taking
        # min() of the empty list here used to raise ValueError.
        if step_distances:
            min_cc_distance_list_time.append(min(step_distances))

    if not min_cc_distance_list_time:
        # No measurable timestep inside the clip: report the episode and fall
        # back to a distance of 100.
        print(safe_data[index], len(traj_info["timestep"]))
        min_cc_distance_list_time = [100]
    return min(min_cc_distance_list_time), ep_id

def main(fcd_files, index, all_data_path):
    """Compute the minimum circle-center distance of every safe episode in the given files.

    Loads ``safe_id_list`` from ``all_data_path``, evaluates
    ``get_smallest_dist`` for every matching trajectory in ``fcd_files``,
    shows a histogram of the distances and saves the ``[min_dist, ep_id]``
    rows to
    ``{processed_data_tmp_folder_path}/min_center_distance_ablationstudy_NSLDD_{index}.npy``.

    Args:
        fcd_files (list): List of fcd files.
        index (int): Index of the process; used in the output file name.
        all_data_path (str): Path to the data file.
    """
    with open(all_data_path) as fp:
        safe_data = json.load(fp)["safe_id_list"]

    results = []
    for fcd_file in tqdm(fcd_files):
        for fcdjson_obj in read_json_fcd_from_json(fcd_file, safe_data):
            min_dist, ep_id = get_smallest_dist(fcdjson_obj, safe_data)
            results.append([min_dist, ep_id])
    print(len(results))

    # Force an (n, 2) float array. With no results, np.array([]) is 1-D and
    # the column slice below used to raise IndexError, so no file was written.
    results_arr = np.array(results, dtype=float).reshape(-1, 2)

    if len(results_arr):
        plt.figure(dpi=100)
        plt.hist(results_arr[:, 0], bins=100)
        plt.show()

    out_path = f"{processed_data_tmp_folder_path}/min_center_distance_ablationstudy_NSLDD_{index}.npy"
    with open(out_path, 'wb') as f:
        print(out_path)
        np.save(f, results_arr)

In [None]:
# ablation study: No state-level data densification (NSLDD)
def find_training_range(crit_info):
    """Return the first element of the first and of the last entry.

    Args:
        crit_info (list): Sequence of entries; element 0 of each entry is the
            value of interest (presumably a timestamp — TODO confirm).

    Returns:
        tuple: ``(first, last)``, or ``(None, None)`` when ``crit_info`` is
        empty so that callers can test ``result[0] is None``.
    """
    # Indexing an empty sequence used to raise IndexError here, although the
    # callers check for a None result.
    if not crit_info:
        return None, None
    return crit_info[0][0], crit_info[-1][0]
    

# Gather the per-file episode summaries from the temporary folder.
files = sorted(glob.glob(os.path.join(processed_data_tmp_folder_path, "complete_info_offlinecollect-tmp_*")))
num_crash = 0
num_safe = 0

data_info_origin = {
    "safe_id_list": [], "safe_weight": [], "safe_ep_info": [],
    "crash_id_list": [], "crash_weight": [], "crash_ep_info": [],
}

# Position inside an episode entry that must be non-empty for it to be used.
int_index = 5
ablationstudy_original_summary_file = os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NSLDD.json")
# Never filled by this cell (kept because the mean is still reported below).
critic_len_list = []

for file in tqdm(files):
    with open(file) as fp:
        json_obj = json.load(fp)
    # Safe and crash episodes are handled identically except for the removal
    # counter and message, so one loop serves both (safe first, then crash).
    for kind in ("safe", "crash"):
        for i, info in enumerate(json_obj[f"{kind}_id_list"]):
            if len(info[int_index]) == 0:
                continue
            train_range = find_training_range(info[8])
            if train_range[0] is None:
                if kind == "safe":
                    num_safe += 1
                else:
                    num_crash += 1
                    print("remove")
                continue
            weight = json_obj[f"{kind}_weight"][i]
            # Range bounds are scaled by 10 and truncated to integers
            # (presumably seconds -> 0.1 s steps — TODO confirm).
            info_out = [
                info[0], info[1], info[2], info[3],
                int(float(train_range[0])*10), int(float(train_range[1])*10)
            ]
            data_info_origin[f"{kind}_id_list"].append(info_out)
            data_info_origin[f"{kind}_weight"].append(weight)
            data_info_origin[f"{kind}_ep_info"].append(tuple(info_out+[weight]))

print(num_safe, num_crash)
# np.mean([]) emits a RuntimeWarning and yields nan; report nan directly.
print(np.mean(critic_len_list) if critic_len_list else float("nan"))

with open(os.path.join(processed_data_folder, "offline_av_alldata_ablationstudy_NSLDD.json"), "w") as fp:
    ujson.dump(data_info_origin, fp)

# Episode ids listed here are excluded from the filtered summary.
filtered_ep_id = []

print(len(data_info_origin["crash_id_list"]))
print(data_info_origin.keys())
print(len(data_info_origin["safe_id_list"]))

partial_data_info_origin = {
    "crash_id_list": [], "crash_weight": [], "crash_ep_info": [],
    "safe_id_list": [], "safe_weight": [], "safe_ep_info": [],
}

# Keep only episodes whose range (element 5 minus element 4) is at least 2.
# Crash episodes that are dropped for being too short are printed.
for i, crash_ep in enumerate(data_info_origin["crash_id_list"]):
    if crash_ep[0] in filtered_ep_id:
        print("crash:", crash_ep)
        continue
    if crash_ep[5]-crash_ep[4] >= 2:
        partial_data_info_origin["crash_id_list"].append(crash_ep)
        partial_data_info_origin["crash_weight"].append(data_info_origin["crash_weight"][i])
        partial_data_info_origin["crash_ep_info"].append(data_info_origin["crash_ep_info"][i])
    else:
        print(crash_ep)

# Same filter for the safe episodes (short ones are dropped silently).
for i, safe_ep in enumerate(data_info_origin["safe_id_list"]):
    if safe_ep[0] in filtered_ep_id:
        print("safe:", safe_ep)
        continue
    if safe_ep[5]-safe_ep[4] >= 2:
        partial_data_info_origin["safe_id_list"].append(safe_ep)
        partial_data_info_origin["safe_weight"].append(data_info_origin["safe_weight"][i])
        partial_data_info_origin["safe_ep_info"].append(data_info_origin["safe_ep_info"][i])

print(len(partial_data_info_origin["crash_id_list"]), len(partial_data_info_origin["safe_id_list"]))

with open(ablationstudy_original_summary_file, "w") as fp:
    ujson.dump(partial_data_info_origin, fp)

# smallest distance for three circles
# The safe fcd files are split into up to 20 chunks; one worker process per
# chunk runs main() and writes its own .npy result file.
all_data_path = ablationstudy_original_summary_file
tested_and_safe_folder = os.path.join(processed_data_folder,"tested_and_safe")
print(tested_and_safe_folder)
safe_fcd_files = sorted(glob.glob(os.path.join(tested_and_safe_folder,"*.fcd.json")))
print("len(safe_fcd_files)",len(safe_fcd_files))
num_each = int(len(safe_fcd_files)/20)+1
p_list = []
for i in range(20):
    file_chunk = safe_fcd_files[i*num_each:(i+1)*num_each]
    if not file_chunk:
        # All later chunks are empty as well; a worker without files has
        # nothing to compute.
        break
    p_list.append(Process(target=main, args=(file_chunk, i, all_data_path)))
for p_ind in p_list:
    p_ind.start()
for p_ind in p_list:
    p_ind.join()

# A worker that raised exits with a non-zero code and leaves no (or a stale)
# result file behind; make that visible instead of continuing silently.
failed_workers = [p_ind.name for p_ind in p_list if p_ind.exitcode != 0]
if failed_workers:
    print("WARNING: worker processes exited abnormally:", failed_workers)

# find near-miss, minimum three-circle-distance<2.5
# Sorted so the order of the collected episodes is reproducible across runs.
# NOTE(review): result files left over from an earlier run match this pattern
# too — clear the tmp folder first if the number of workers changed.
np_files = sorted(glob.glob(f"{processed_data_tmp_folder_path}/min_center_distance_ablationstudy_NSLDD_*.npy"))
print(np_files)

results_min_dist = []
ep_id_list = []
for f in np_files:
    # Rows are [min_dist, ep_id]; reshape keeps an empty file two-dimensional.
    new_results = np.load(f).reshape(-1, 2)
    results_min_dist = np.append(results_min_dist, new_results[:,0])
    ep_id_list = np.append(ep_id_list, new_results[:,1])
print(len(results_min_dist))

file_name_alldata = "offline_av_alldata_new.json"

summary_info = partial_data_info_origin

dist_threshold = 2.5
output_data = {
    "safe2crash_id_list": [],
    "safe2crash_weight": [],
    "safe2crash_ep_info": [],
}
safe2crash_id_list = []
safe2crash_weight = []
safe_id_list = [info[0] for info in summary_info["safe_id_list"]]
# Map each episode id to the position of its first occurrence (the same entry
# list.index would return) so every lookup is constant time.
safe_index_by_id = {}
for j, safe_id in enumerate(safe_id_list):
    safe_index_by_id.setdefault(safe_id, j)

for i in tqdm(range(len(ep_id_list))):
    if results_min_dist[i] < dist_threshold:
        ep_id = ep_id_list[i]
        # Ids come back from the .npy files as floats; the summary stores ints.
        j = safe_index_by_id.get(int(ep_id))
        if j is None:
            # Episode is not part of the filtered summary: report and skip it.
            print(ep_id)
            continue
        ep_info = summary_info["safe_id_list"][j]
        output_data["safe2crash_id_list"].append(ep_info)
        output_data["safe2crash_weight"].append(summary_info["safe_weight"][j])
        output_data["safe2crash_ep_info"].append(tuple(ep_info+[summary_info["safe_weight"][j]]))
print(len(output_data["safe2crash_id_list"]), len(output_data["safe2crash_weight"]))

file_name_output = "offline_av_nearmiss_new_ablationstudy_NSLDD.json"
with open(os.path.join(processed_data_folder, file_name_output), "w") as fp2:
    ujson.dump(output_data, fp2)

# Reload the filtered NSLDD summary and the near-miss selection from disk.
with open(ablationstudy_original_summary_file) as fp:
    data_info_origin = ujson.load(fp)
    print(len(data_info_origin["crash_id_list"]))
    print(len(data_info_origin["safe_id_list"]))

nearmiss_json_file_path = os.path.join(processed_data_folder, "offline_av_nearmiss_new_ablationstudy_NSLDD.json")

with open(nearmiss_json_file_path) as fp:
    data_info = ujson.load(fp)
    print(len(data_info["safe2crash_id_list"]))
print(data_info.keys())

data_info_new = {}
data_info_new["crash_id_list"] = data_info_origin["crash_id_list"]
data_info_new["crash_weight"] = data_info_origin["crash_weight"]
data_info_new["crash_ep_info"] = data_info_origin["crash_ep_info"]

# Collect the near-miss entries that do not appear among the safe episodes.
del_index_list = []
for i in tqdm(range(len(data_info["safe2crash_id_list"]))):
    if data_info["safe2crash_id_list"][i] not in data_info_origin["safe_id_list"]:
        print(i,data_info["safe2crash_id_list"][i])
        del_index_list.append(i)
# Remove from the highest index down. Popping in ascending order shifts every
# later element one slot left, so the remaining indices would point at the
# wrong entries (or past the end of the list).
for i in reversed(del_index_list):
    data_info["safe2crash_id_list"].pop(i)
    data_info["safe2crash_weight"].pop(i)
    data_info["safe2crash_ep_info"].pop(i)
print(len(data_info["safe2crash_id_list"]))

data_info_new["safe2crash_id_list"] = data_info["safe2crash_id_list"]
data_info_new["safe2crash_weight"] = data_info["safe2crash_weight"]
data_info_new["safe2crash_ep_info"] = data_info["safe2crash_ep_info"]
data_info_new["crashnearmiss_history"] = {}
data_info_new["crashnearmiss_history"]["crash_id_list"] = data_info_origin["crash_id_list"]
data_info_new["crashnearmiss_history"]["crash_weight"] = data_info_origin["crash_weight"]
data_info_new["crashnearmiss_history"]["crash_ep_info"] = data_info_origin["crash_ep_info"]
data_info_new["crashnearmiss_history"]["crash_score_list"] = [[1,0,0,0]]*len(data_info_origin["crash_id_list"])
data_info_new["crashnearmiss_history"]["safe2crash_id_list"] = data_info["safe2crash_id_list"]
data_info_new["crashnearmiss_history"]["safe2crash_weight"] = data_info["safe2crash_weight"]
data_info_new["crashnearmiss_history"]["safe2crash_ep_info"] = data_info["safe2crash_ep_info"]
data_info_new["crashnearmiss_history"]["safe2crash_score_list"] = [[0,1,0,1]]*len(data_info["safe2crash_id_list"])

json_file_path_new = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_ablationstudy_NSLDD.json")
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_new, fp2)
    
import matplotlib.pyplot as plt
import numpy as np
# Histograms of log10(weight) over the non-zero weights: crash and near-miss
# episodes overlaid in the first figure.
plt.figure()
for weight_key, series_label in (("crash_weight", "crash"), ("safe2crash_weight", "near-miss")):
    weight_list = np.array(data_info_new[weight_key])
    plt.hist(
        np.log10(weight_list[np.nonzero(weight_list)]),
        bins=[2*i/10-15 for i in range(81)],
        alpha=0.5,
        label=series_label,
    )
plt.legend()
plt.show()

# Second figure: the near-miss weights only (weight_list still holds them),
# drawn with a coarser set of bins.
plt.figure()
plt.hist(
    np.log10(weight_list[np.nonzero(weight_list)]),
    bins=[2*i/10 for i in range(9)],
    alpha=0.5,
    label="near-miss",
)
plt.legend()
plt.show()

In [None]:
# ablation study: No near-miss episodes (NNME)
with open(original_summary_file) as fp:
    data_info_origin = ujson.load(fp)
for id_key in ("crash_id_list", "safe_id_list"):
    print(len(data_info_origin[id_key]))

# The all-data summary is saved unchanged (as a deep copy) under the NNME name.
data_info_origin_new = copy.deepcopy(data_info_origin)

json_file_path_new = os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NNME.json")
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_origin_new, fp2)

# Evaluation dictionary: crash episodes are kept as they are, every
# "safe2crash" list is left empty.
data_info_new = {
    "crash_id_list": data_info_origin["crash_id_list"],
    "crash_weight": data_info_origin["crash_weight"],
    "crash_ep_info": data_info_origin["crash_ep_info"],
    "safe2crash_id_list": [],
    "safe2crash_weight": [],
    "safe2crash_ep_info": [],
    "crashnearmiss_history": {
        "crash_id_list": data_info_origin["crash_id_list"],
        "crash_weight": data_info_origin["crash_weight"],
        "crash_ep_info": data_info_origin["crash_ep_info"],
        "crash_score_list": [[1, 0, 0, 0]] * len(data_info_origin["crash_id_list"]),
        "safe2crash_id_list": [],
        "safe2crash_weight": [],
        "safe2crash_ep_info": [],
        "safe2crash_score_list": [],
    },
}

json_file_path_new = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_ablationstudy_NNME.json")
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_new, fp2)

# Histograms of log10(weight) over the non-zero weights.
plt.figure()
for weight_key, series_label in (("crash_weight", "crash"), ("safe2crash_weight", "near-miss")):
    weight_list = np.array(data_info_new[weight_key])
    plt.hist(
        np.log10(weight_list[np.nonzero(weight_list)]),
        bins=[2*i/10-15 for i in range(81)],
        alpha=0.5,
        label=series_label,
    )
plt.legend()
plt.show()

# Second figure: the (empty) near-miss weights with a coarser set of bins.
plt.figure()
plt.hist(
    np.log10(weight_list[np.nonzero(weight_list)]),
    bins=[2*i/10 for i in range(9)],
    alpha=0.5,
    label="near-miss",
)
plt.legend()
plt.show()

In [None]:
# ablation study: No retrospective data densification (NRDD)
# Both source files are passed through a load/dump round trip and stored
# under their NRDD names; no content is modified in this cell.
alldata_src_path = os.path.join(processed_data_folder, "offline_av_alldata_new.json")
alldata_dst_path = os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NRDD.json")
with open(alldata_src_path) as fp:
    data_info_origin = ujson.load(fp)
with open(alldata_dst_path, "w") as fp2:
    ujson.dump(data_info_origin, fp2)

neweval_src_path = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_origin.json")
neweval_dst_path = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_ablationstudy_NRDD.json")
with open(neweval_src_path) as fp:
    data_info_new = ujson.load(fp)
with open(neweval_dst_path, "w") as fp2:
    ujson.dump(data_info_new, fp2)

In [15]:
# ablation study: No near-miss episodes and retrospective data densification (NNME_NRDD)
# The all-data summary is stored unchanged under the NNME_NRDD name.
alldata_src_path = os.path.join(processed_data_folder, "offline_av_alldata_new.json")
alldata_dst_path = os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NNME_NRDD.json")
with open(alldata_src_path) as fp:
    data_info_origin = ujson.load(fp)
with open(alldata_dst_path, "w") as fp2:
    ujson.dump(data_info_origin, fp2)

# The evaluation file is loaded and every "safe2crash" list is emptied, both
# at the top level and inside "crashnearmiss_history".
neweval_src_path = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_origin.json")
neweval_dst_path = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_ablationstudy_NNME_NRDD.json")
with open(neweval_src_path) as fp:
    data_info_new = ujson.load(fp)
for emptied_key in ("safe2crash_id_list", "safe2crash_weight", "safe2crash_ep_info"):
    data_info_new[emptied_key] = []
for emptied_key in ("safe2crash_id_list", "safe2crash_weight", "safe2crash_ep_info", "safe2crash_score_list"):
    data_info_new["crashnearmiss_history"][emptied_key] = []
with open(neweval_dst_path, "w") as fp2:
    ujson.dump(data_info_new, fp2)

In [None]:
# ablation study: No trajectory resampling by probability of occurrence in NDE (NRNDE)
# Every weight, and the last element of every ep_info entry, is set to 1.
with open(original_summary_file) as fp:
    data_info_origin = ujson.load(fp)
print(len(data_info_origin["crash_id_list"]))
print(len(data_info_origin["safe_id_list"]))
for prefix in ("crash", "safe"):
    for i in range(len(data_info_origin[f"{prefix}_id_list"])):
        data_info_origin[f"{prefix}_weight"][i] = 1.
        data_info_origin[f"{prefix}_ep_info"][i][-1] = 1.

json_file_path_new = os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NRNDE.json")
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_origin, fp2)

with open(os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_origin.json")) as fp:
    data_info_new = ujson.load(fp)

# NOTE(review): these two lines report the sizes of data_info_origin again,
# not of the freshly loaded data_info_new — confirm that this is intended.
print(len(data_info_origin["crash_id_list"]))
print(len(data_info_origin["safe_id_list"]))

# Same reset for the evaluation file: first the top-level lists, then the
# copies stored under "crashnearmiss_history".
for container in (data_info_new, data_info_new["crashnearmiss_history"]):
    for prefix in ("crash", "safe2crash"):
        for i in range(len(container[f"{prefix}_id_list"])):
            container[f"{prefix}_weight"][i] = 1.
            container[f"{prefix}_ep_info"][i][-1] = 1.

json_file_path_new = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_ablationstudy_NRNDE.json")
with open(json_file_path_new, "w") as fp2:
    ujson.dump(data_info_new, fp2)

# Histograms of log10(weight) over the non-zero weights.
plt.figure()
for weight_key, series_label in (("crash_weight", "crash"), ("safe2crash_weight", "near-miss")):
    weight_list = np.array(data_info_new[weight_key])
    plt.hist(
        np.log10(weight_list[np.nonzero(weight_list)]),
        bins=[2*i/10-15 for i in range(81)],
        alpha=0.5,
        label=series_label,
    )
plt.legend()
plt.show()

# Second figure: the near-miss weights only (weight_list still holds them),
# drawn with a coarser set of bins.
plt.figure()
plt.hist(
    np.log10(weight_list[np.nonzero(weight_list)]),
    bins=[2*i/10 for i in range(9)],
    alpha=0.5,
    label="near-miss",
)
plt.legend()
plt.show()

In [17]:
# ablation study: No reconnection of informative states in Markov process (NRSMDP)
# Both source files are passed through a load/dump round trip and stored
# under their NRSMDP names; no content is modified in this cell.
alldata_src_path = os.path.join(processed_data_folder, "offline_av_alldata_new.json")
alldata_dst_path = os.path.join(processed_data_folder, "offline_av_alldata_new_ablationstudy_NRSMDP.json")
with open(alldata_src_path) as fp:
    data_info_origin = ujson.load(fp)
with open(alldata_dst_path, "w") as fp2:
    ujson.dump(data_info_origin, fp2)

neweval_src_path = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_origin.json")
neweval_dst_path = os.path.join(processed_data_folder, "offline_av_neweval_crashnearmiss_new_ablationstudy_NRSMDP.json")
with open(neweval_src_path) as fp:
    data_info_new = ujson.load(fp)
with open(neweval_dst_path, "w") as fp2:
    ujson.dump(data_info_new, fp2)