In [1]:
import os
import sys
import cv2
import json
import ndjson
import shutil
import random
import pickle
import imageio
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
from pprint import pprint
import matplotlib.pyplot as plt
sys.path.append('../')
from utils import normalize_image, plot_image, plot_images, get_sequences

In [52]:
def parse_results_txt(results_txt):
    """Parse a ``key: value`` metrics file into a dict of floats.

    Every line containing ``": "`` is split at its first ``": "`` into a key
    and a value, and both are whitespace-stripped. Lines without the separator
    are ignored. Lines whose value is empty or parses to NaN are skipped, so
    metrics that were not computed are *absent* from the result rather than
    stored as NaN. Callers should use ``.get`` when a metric may be missing.

    Args:
        results_txt: Path to the results text file.

    Returns:
        dict mapping metric name (str) to its value (float).

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a non-empty value cannot be converted to float.
    """
    import math

    results = {}
    with open(results_txt, "r") as f:
        for line in f:
            if ": " not in line:
                continue
            key, raw_value = line.split(": ", 1)
            # The raw value still carries the trailing newline; strip it first,
            # otherwise the emptiness / NaN checks below can never match.
            key, raw_value = key.strip(), raw_value.strip()
            if not raw_value:
                continue
            value = float(raw_value)
            if math.isnan(value):
                continue
            results[key] = value
    return results

def _collect_set_metric(row, pred_root_dir, model, eval_sets, size_list, metric):
    """Write ``metric`` for every (eval set, size) pair of one model into ``row``.

    Reads ``<pred_root_dir>/<model>/<eval_set>_<size>/results.txt`` and stores
    the value under ``"<shortname>_<size>"``. A metric absent from the file
    becomes NaN so one incomplete run does not abort the whole table.
    """
    for eval_set, shortname in eval_sets.items():
        for size in size_list:
            results_txt = os.path.join(pred_root_dir, model, eval_set + "_" + size, "results.txt")
            results = parse_results_txt(results_txt)
            row[shortname + "_" + size] = results.get(metric, np.nan)


def _mean_or_nan(values):
    """Mean of ``values``, or NaN (without a RuntimeWarning) when it is empty."""
    return np.mean(values) if len(values) > 0 else np.nan


def collect_results(pred_root_dir, models, safety_sets, prod_sets, size_list,
                    safety_metric="recall_image", prod_metric="productivity_image"):
    """Build a per-model summary table of safety and productivity metrics.

    Args:
        pred_root_dir: Root directory with one sub-directory per model.
        models: Iterable of model directory names; one output row each.
        safety_sets: dict of full safety-set name -> short column prefix.
        prod_sets: dict of full productivity-set name -> short column prefix.
        size_list: List of size strings appended to set names (e.g. "640").
        safety_metric: Key read from each safety ``results.txt``.
        prod_metric: Key read from each productivity ``results.txt``.

    Returns:
        DataFrame with columns ``model, safety_avg, prod_avg`` followed by the
        lexicographically sorted ``<prefix>_<size>`` safety columns, then the
        sorted productivity columns. The ``*_avg`` columns are the plain mean
        over that group's columns (NaN if any value in the group is missing).
    """
    safety_cols = sorted(short + "_" + size for short in safety_sets.values() for size in size_list)
    prod_cols = sorted(short + "_" + size for short in prod_sets.values() for size in size_list)
    columns = ["model", "safety_avg", "prod_avg"] + safety_cols + prod_cols

    records = []
    for model in models:
        row = {"model": model}
        _collect_set_metric(row, pred_root_dir, model, safety_sets, size_list, safety_metric)
        _collect_set_metric(row, pred_root_dir, model, prod_sets, size_list, prod_metric)
        row["safety_avg"] = _mean_or_nan([row[col] for col in safety_cols])
        row["prod_avg"] = _mean_or_nan([row[col] for col in prod_cols])
        records.append(row)
    # Passing `columns=` fixes the column order and keeps the schema even when
    # `models` is empty (indexing an empty frame by `columns` raised KeyError).
    return pd.DataFrame(records, columns=columns)

In [55]:
# Root directory holding one sub-directory per model; collect_results reads
# "<pred_root_dir>/<model>/<eval_set>_<size>/results.txt" from it.
# Set the PRED_ROOT_DIR environment variable to point elsewhere instead of
# editing this user-specific default path.
pred_root_dir = os.environ.get(
    "PRED_ROOT_DIR", "/data/jupiter/li.yu/exps/driveable_terrain_model/"
)

# Model directory names to compare; one row each in the summary table.
models = [
    "rgb_baseline_sample_a_v3_2",
    "sa3_rgb_8cls_1002",
    "sa3_rgbnir0822_rgb_8cls_1002",
    "sa3_rgbnir0822rd_rgb_8cls_1002",
    "sa3_rgbnir0822rdcc_rgb_8cls_1003",
]

# Safety eval sets: full set name -> short column prefix.
safety_sets = {
    "humans_on_path_v5_2023_halo_test_set_anno_with_ocal": "human_v5",
}

# Productivity eval sets: full set name -> short column prefix.
prod_sets = {
    "20230912_halo_rgb_productivity_day_candidate_1_cleaned_v3_no_ocal": "day_1",
    "20230929_halo_rgb_productivity_day_candidate_4_cleaned_v2_no_ocal": "day_4",
    "20230929_halo_rgb_productivity_day_candidate_8_cleaned_v1_with_ocal_no_drop": "day_8",
    "20230929_halo_rgb_productivity_day_candidate_10_cleaned_v1_no_ocal": "day_10",
    "20230929_halo_rgb_productivity_day_candidate_12_dirty_cleaned_v0_no_ocal": "day_12",
    "20230929_halo_rgb_productivity_day_candidate_13_dirty_no_ocal": "day_13",
    "20230912_halo_rgb_productivity_night_candidate_0_no_ocal_rgb_branch": "night_0",
    "20230929_halo_rgb_productivity_night_candidate_4_cleaned_v1_no_ocal": "night_4"
}

# Size suffixes appended to each set name ("<eval_set>_<size>").
size_list = ["640", "768"]

In [60]:
# Build the per-model summary table and display it rounded to 5 decimals.
df = collect_results(
    pred_root_dir=pred_root_dir,
    models=models,
    safety_sets=safety_sets,
    prod_sets=prod_sets,
    size_list=size_list,
)
df.round(5)

Unnamed: 0,model,safety_avg,prod_avg,human_v5_640,human_v5_768,day_10_640,day_10_768,day_12_640,day_12_768,day_13_640,...,day_1_640,day_1_768,day_4_640,day_4_768,day_8_640,day_8_768,night_0_640,night_0_768,night_4_640,night_4_768
0,rgb_baseline_sample_a_v3_2,0.99755,0.98998,0.9984,0.9967,0.99727,0.99317,0.99839,0.99841,0.99758,...,1.0,0.99787,0.99443,0.9911,0.94434,0.93008,1.0,1.0,0.99928,0.9992
1,sa3_rgb_8cls_1002,0.9951,0.99037,0.99681,0.9934,0.99864,0.99488,0.99901,0.99874,0.99758,...,1.0,0.99681,0.99461,0.99165,0.9471,0.92919,1.0,1.0,0.99986,0.99967
2,sa3_rgbnir0822_rgb_8cls_1002,0.99673,0.9884,0.9984,0.99505,0.98636,0.98179,0.99858,0.99836,0.99677,...,1.0,0.99362,0.9921,0.9909,0.94848,0.93804,1.0,1.0,0.9927,0.99814
3,sa3_rgbnir0822rd_rgb_8cls_1002,0.99755,0.98904,0.9984,0.9967,0.99045,0.98862,0.99907,0.99834,0.99597,...,0.99718,0.99362,0.99461,0.99199,0.95078,0.93698,1.0,1.0,0.99147,0.99648
4,sa3_rgbnir0822rdcc_rgb_8cls_1003,0.99796,0.98932,0.9984,0.99752,0.99591,0.98919,0.99716,0.99736,0.99677,...,0.99435,0.99469,0.99497,0.9922,0.95032,0.93733,1.0,1.0,0.99263,0.99745
