In [1]:
import os
from collections import defaultdict
import pandas as pd

def update_dict(line, stat_dict, file, folder_loc):
    """Parse the trailing number on ``line`` and append it under a path component.

    Args:
        line: a text line whose last whitespace-separated token is a number.
        stat_dict: mapping from key to list of floats; updated in place. It
            must supply a list for new keys (e.g. ``defaultdict(list)``).
        file: a '/'-separated path; component ``folder_loc`` is used as key.
        folder_loc: index into ``file.split('/')`` selecting the key.

    Returns:
        ``stat_dict`` itself (the same object), so callers can reassign it.

    Raises:
        ValueError: the last token of ``line`` is not a number.
        IndexError: ``file`` has too few components, or ``line`` is blank.
        KeyError: ``stat_dict`` is a plain dict lacking the key.
    """
    # split() with no argument drops trailing spaces/tabs/newline; split(" ")
    # would leave e.g. "\n" as the last token when the line ends "0.5 \n".
    stat = line.split()[-1]
    stat_dict[file.split('/')[folder_loc]].append(float(stat))
    return stat_dict

def compute_mean(mean, k, precs, recs, f1s):
    """Store the arithmetic mean of each metric for key ``k`` in ``mean[k]``.

    ``mean[k]`` is replaced by a fresh dict with keys 'precision', 'recall'
    and 'f1_score', holding the averages of ``precs[k]``, ``recs[k]`` and
    ``f1s[k]`` respectively. ``mean`` is modified in place and returned.

    Raises ZeroDivisionError if any of the three lists is empty; in that case
    ``mean[k]`` is left holding whatever entries were filled before the error.
    """
    entry = mean[k] = {}
    metric_sources = (('precision', precs), ('recall', recs), ('f1_score', f1s))
    for metric_name, source in metric_sources:
        values = source[k]
        entry[metric_name] = sum(values) / len(values)
    return mean

def compute_max(best, k, precs, recs, f1s):
    """Store the maximum of each metric for key ``k`` in ``best[k]``.

    ``best[k]`` is replaced by a fresh dict with keys 'precision', 'recall'
    and 'f1_score', holding ``max`` of ``precs[k]``, ``recs[k]`` and
    ``f1s[k]`` respectively. ``best`` is modified in place and returned.

    Raises ValueError (from ``max``) if any of the three lists is empty; in
    that case ``best[k]`` keeps whatever entries were filled before the error.
    """
    entry = best[k] = {}
    metric_sources = (('precision', precs), ('recall', recs), ('f1_score', f1s))
    for metric_name, source in metric_sources:
        entry[metric_name] = max(source[k])
    return best

def get_output_details(output_dir):
    """Collect precision/recall/F1 values from every ``MILP_eval.csv`` file.

    Walks ``output_dir`` recursively and reads each file whose name contains
    ``"MILP_eval.csv"``. A file is used only if component 2 of its
    '/'-separated path contains ``'sample'``. Component indices count from the
    start of the joined path string, so they depend on how ``output_dir`` is
    written. NOTE(review): this assumes a layout like
    ``<x>/<y>/sample_*/<ratio>/...MILP_eval.csv``; confirm against the caller.

    Every line containing "precision", "recall" or "f1" has its trailing
    number appended, via ``update_dict``, under two keys:
    * path component 2 (the sample folder), and
    * path component 3 (presumably a ratio/setting folder; confirm).

    Returns:
        ``(precs, recs, f1s, prec_ratio, rec_ratio, f1_ratio)``: six
        ``defaultdict(list)``. The first three are keyed by component 2 and the
        last three by component 3.
    """
    files = [
        os.path.join(dp, f)
        for dp, dn, filenames in os.walk(output_dir)
        for f in filenames
        if "MILP_eval.csv" in f
    ]
    precs, recs, f1s = defaultdict(list), defaultdict(list), defaultdict(list)
    prec_ratio, rec_ratio, f1_ratio = defaultdict(list), defaultdict(list), defaultdict(list)

    for file in files:
        parts = file.split('/')
        # Components 2 and 3 are both used as keys below; skip paths too
        # shallow to contain them rather than failing with IndexError.
        if len(parts) < 4 or 'sample' not in parts[2]:
            continue
        with open(file, 'r') as fh:
            for line in fh:
                # Independent checks (not elif): a line matching several
                # keywords contributes to each corresponding metric.
                if "precision" in line:
                    precs = update_dict(line, precs, file, 2)
                    prec_ratio = update_dict(line, prec_ratio, file, 3)
                if "recall" in line:
                    recs = update_dict(line, recs, file, 2)
                    rec_ratio = update_dict(line, rec_ratio, file, 3)
                if "f1" in line:
                    f1s = update_dict(line, f1s, file, 2)
                    f1_ratio = update_dict(line, f1_ratio, file, 3)
    return precs, recs, f1s, prec_ratio, rec_ratio, f1_ratio

def get_mean(precs, recs, f1s):
    """Summarise per-sample scores by their mean and by their maximum.

    Args:
        precs, recs, f1s: mappings from sample id to a list of precision,
            recall and F1 values. Every sample id in ``precs`` must also have a
            non-empty list in ``recs`` and ``f1s``.

    Returns:
        ``(mean, best, mean_scores, best_scores)``:
        * ``mean`` and ``best`` map each sample id to a dict with keys
          'precision', 'recall' and 'f1_score' (mean and max respectively).
        * ``mean_scores`` and ``best_scores`` are DataFrames of the same
          numbers, one row per sample, with columns
          'Sample', 'Precision', 'Recall', 'F1 score'.

        The DataFrames keep those columns even when ``precs`` is empty.

    Raises:
        ZeroDivisionError / ValueError / KeyError if a sample in ``precs`` is
        missing from, or has an empty list in, one of the mappings.
    """
    columns = ['Sample', 'Precision', 'Recall', 'F1 score']
    mean, best = {}, {}
    mean_rows, best_rows = [], []

    for sample_id in precs:
        p, r, f = precs[sample_id], recs[sample_id], f1s[sample_id]

        mean[sample_id] = {
            'precision': sum(p) / len(p),
            'recall': sum(r) / len(r),
            'f1_score': sum(f) / len(f),
        }
        best[sample_id] = {
            'precision': max(p),
            'recall': max(r),
            'f1_score': max(f),
        }
        for summary, rows in ((mean, mean_rows), (best, best_rows)):
            s = summary[sample_id]
            rows.append([sample_id, s['precision'], s['recall'], s['f1_score']])

    # Passing columns= (instead of renaming positional columns afterwards)
    # keeps the labelled columns even when there are no rows.
    mean_scores = pd.DataFrame(mean_rows, columns=columns)
    best_scores = pd.DataFrame(best_rows, columns=columns)
    return mean, best, mean_scores, best_scores

def get_greedy_stats(greedy_dir, ids):
    """Collect scores from every ``greedy_mean.csv`` file under ``greedy_dir``.

    A file is used only if component 2 of its '/'-separated path is ``in ids``.
    For each such file, the first line containing each marker is parsed with
    ``update_dict``, keyed by path component 2:
    * "score_ref_coverage" is stored as recall.
    * "score_pred_coverage" is stored as precision.
    * "overall_score" is stored as F1.

    Args:
        greedy_dir: directory to walk recursively.
        ids: container tested with ``in`` against path component 2.

    Returns:
        ``(greedy_precs, greedy_recs, greedy_f1s)``, each a
        ``defaultdict(list)`` keyed by path component 2.

    Raises:
        ValueError: a selected file lacks one of the three marker lines. The
            check happens before any of that file's values are recorded.
    """
    greedy_files = [
        os.path.join(dp, f)
        for dp, dn, filenames in os.walk(greedy_dir)
        for f in filenames
        if "greedy_mean.csv" in f
    ]
    greedy_precs, greedy_recs, greedy_f1s = defaultdict(list), defaultdict(list), defaultdict(list)
    markers = ("score_ref_coverage", "score_pred_coverage", "overall_score")

    for file in greedy_files:
        parts = file.split('/')
        # Skip paths too shallow to have a component 2 instead of IndexError.
        if len(parts) < 3 or parts[2] not in ids:
            continue
        # Only the first matching line per marker is used, so keep just that.
        first = dict.fromkeys(markers)
        with open(file, 'r') as fh:
            for line in fh:
                for marker in markers:
                    if first[marker] is None and marker in line:
                        first[marker] = line
        missing = [m for m in markers if first[m] is None]
        if missing:
            raise ValueError(f"{file}: no line containing {', '.join(missing)}")
        greedy_recs = update_dict(first["score_ref_coverage"], greedy_recs, file, 2)
        greedy_precs = update_dict(first["score_pred_coverage"], greedy_precs, file, 2)
        greedy_f1s = update_dict(first["overall_score"], greedy_f1s, file, 2)
    return greedy_precs, greedy_recs, greedy_f1s

def get_ilp_stats(ilp_dir):
    """Collect precision/recall/F1 values from evaluation files under ``ilp_dir``.

    Reads every file whose name contains ``"eval.csv"``. This is a substring
    test, so names such as ``MILP_eval.csv`` also match. A file is used only if
    component 4 of its '/'-separated path contains ``'sample'``. Each line
    containing "precision", "recall" or "f1" has its trailing number appended,
    via ``update_dict``, under path component 2.

    NOTE(review): the filter inspects component 4 but values are keyed by
    component 2, so all samples beneath the same component-2 folder are pooled
    together. Confirm this is intended for the ILP directory layout.

    Returns:
        ``(ilp_precs, ilp_recs, ilp_f1s)``, each a ``defaultdict(list)``.
    """
    ilp_files = [
        os.path.join(dp, f)
        for dp, dn, filenames in os.walk(ilp_dir)
        for f in filenames
        if "eval.csv" in f
    ]
    ilp_precs, ilp_recs, ilp_f1s = defaultdict(list), defaultdict(list), defaultdict(list)

    for file in ilp_files:
        parts = file.split('/')
        # Skip paths too shallow to have a component 4 instead of IndexError.
        if len(parts) < 5 or 'sample' not in parts[4]:
            continue
        with open(file, 'r') as fh:
            for line in fh:
                if "precision" in line:
                    ilp_precs = update_dict(line, ilp_precs, file, 2)
                if "recall" in line:
                    ilp_recs = update_dict(line, ilp_recs, file, 2)
                if "f1" in line:
                    ilp_f1s = update_dict(line, ilp_f1s, file, 2)
    return ilp_precs, ilp_recs, ilp_f1s