# In [5]:  (Jupyter cell prompt, kept as a comment so the file parses as Python)
from itertools import product
from typing import List, Literal, Optional

import numpy as np

from average_results import ret_avg_results

def generate_table_exp1(datasets: Optional[List[str]] = None, knn_algo: Literal['brute', 'kd_tree', 'ball_tree'] = 'brute'):
    """Write the experiment-1 accuracy table to ``exp1_tab.txt``.

    For every dataset and classifier the "smart" accuracy at threshold 1.0 is
    reported per k, with all k <= 10 averaged into a single leading cell,
    followed by one cell per k > 10 and finally the baseline accuracy. Values
    are multiplied by 100 and formatted with one decimal.

    Each (dataset, classifier) pair contributes a 2x3 block of cells, so the
    results must contain exactly four k values greater than 10. Blocks of one
    dataset are placed side by side (classifiers in sorted order) and datasets
    are stacked vertically (in sorted order). Cells are tab-separated.

    Args:
        datasets: Dataset names to include. Defaults to
            ``['mnist', 'covertype', 'yeast', 'skin', 'statlog']``.
        knn_algo: Forwarded unchanged to ``ret_avg_results``.

    Returns:
        None. Prints ``"No results found"`` and writes nothing when
        ``ret_avg_results`` returns ``None`` or an empty mapping.

    Raises:
        KeyError: If a requested dataset is missing from the results.
        ValueError: If threshold 1.0 is absent or a block cannot be shaped 2x3.
    """
    if datasets is None:
        datasets = ['mnist', 'covertype', 'yeast', 'skin', 'statlog']

    ret = ret_avg_results(datasets=datasets, knn_algo=knn_algo)
    if not ret:
        print("No results found")
        return

    # dataset -> list of (2, 3) cell blocks, one per classifier in sorted order.
    # Classifiers are tracked per dataset instead of reusing whichever list the
    # last processed dataset happened to carry.
    blocks = {}
    for dataset, results_dict in ret.items():
        clfs = results_dict["clfs"]
        ks = results_dict['ks']
        t1_idx = list(results_dict['thresholds']).index(1.0)

        # Indexed as [classifier, k, threshold].
        smart = np.asarray(results_dict["smart_acc"])

        small_k_idx = [idx for idx, k in enumerate(ks) if k <= 10]
        large_k_idx = [idx for idx, k in enumerate(ks) if k > 10]

        per_clf = {}
        for iclf, clf in enumerate(clfs):
            baseline = "{:.1f}".format(results_dict["baseline_acc"][iclf] * 100)

            acc_by_k = smart[iclf, :, t1_idx]
            means = [np.mean(acc_by_k[small_k_idx])] + [acc_by_k[idx] for idx in large_k_idx]
            cells = ["{:.1f}".format(mean * 100) for mean in means] + [baseline]

            # reshape doubles as a check that exactly six cells were produced.
            per_clf[clf] = np.array(cells, dtype=object).reshape((2, 3))

        blocks[dataset] = [per_clf[clf] for clf in sorted(clfs)]

    table = np.vstack([np.hstack(blocks[dataset]) for dataset in sorted(datasets)])
    lines = ["\t".join(row) for row in table.tolist()]

    # Open the file only once the table is complete, so a failure above does
    # not leave a truncated output file behind.
    with open('exp1_tab.txt', 'w') as f:
        f.write("\n".join(lines))
                
def generate_table_exp2(datasets: Optional[List[str]] = None, knn_algo: Literal['brute', 'kd_tree', 'ball_tree'] = 'brute'):
    """Write the experiment-2 accuracy/time table to ``exp2_tab.txt``.

    For every dataset and classifier, and for each of the thresholds 0.6 and
    1.0, the k with the highest "smart" accuracy is selected (the smallest such
    k on ties). The accuracy and time measured at that k are reported next to
    the baseline accuracy and time.

    Each (dataset, classifier) pair contributes a 2x3 block::

        acc@0.6   acc@1.0   baseline_acc
        time@0.6  time@1.0  baseline_time

    Accuracies are multiplied by 100 and formatted with one decimal; times are
    formatted with four decimals and cut to at most seven characters. Blocks of
    one dataset are placed side by side (classifiers in sorted order), datasets
    are stacked vertically (in sorted order), and cells are tab-separated.

    Args:
        datasets: Dataset names to include. Defaults to
            ``['mnist', 'covertype', 'yeast', 'skin', 'statlog']``.
        knn_algo: Forwarded unchanged to ``ret_avg_results``.

    Returns:
        None. Prints ``"No results found"`` and writes nothing when
        ``ret_avg_results`` returns ``None`` or an empty mapping.

    Raises:
        KeyError: If a requested dataset is missing from the results.
        ValueError: If threshold 0.6 or 1.0 is absent from a dataset's results.
    """
    if datasets is None:
        datasets = ['mnist', 'covertype', 'yeast', 'skin', 'statlog']

    ret = ret_avg_results(datasets=datasets, knn_algo=knn_algo)
    if not ret:
        print("No results found")
        return

    reported_thresholds = (0.6, 1.0)

    # dataset -> list of (2, 3) cell blocks, one per classifier in sorted order.
    blocks = {}
    for dataset, results_dict in ret.items():
        clfs = results_dict["clfs"]
        ks = list(results_dict['ks'])
        thresholds = list(results_dict['thresholds'])

        # Both arrays are indexed as [classifier, k, threshold].
        smart_acc = np.asarray(results_dict["smart_acc"])
        smart_time = np.asarray(results_dict["smart_time"])

        # Look up where each reported threshold sits on the threshold axis.
        # Counting positions within the filtered thresholds would read the
        # first two columns regardless of which thresholds they hold.
        t_indices = [thresholds.index(t) for t in reported_thresholds]

        per_clf = {}
        for iclf, clf in enumerate(clfs):
            accs = []
            times = []
            for t_idx in t_indices:
                acc_by_k = smart_acc[iclf, :, t_idx]
                time_by_k = smart_time[iclf, :, t_idx]

                # Among the k values reaching the top accuracy, keep the smallest k.
                top = np.flatnonzero(acc_by_k == np.amax(acc_by_k))
                best_k_idx = min(top, key=lambda idx: ks[idx])

                accs.append("{:.1f}".format(acc_by_k[best_k_idx] * 100))
                times.append("{:.4f}".format(time_by_k[best_k_idx])[0:7])

            accs.append("{:.1f}".format(results_dict["baseline_acc"][iclf] * 100))
            times.append("{:.4f}".format(results_dict["baseline_time"][iclf])[0:7])

            per_clf[clf] = np.array([accs, times], dtype=object)

        blocks[dataset] = [per_clf[clf] for clf in sorted(clfs)]

    table = np.vstack([np.hstack(blocks[dataset]) for dataset in sorted(datasets)])
    lines = ["\t".join(row) for row in table.tolist()]

    # Open the file only once the table is complete, so a failure above does
    # not leave a truncated output file behind.
    with open('exp2_tab.txt', 'w') as f:
        f.write("\n".join(lines))

def generate_table_exp3(datasets: Optional[List[str]] = None, knn_algos: Optional[List[Literal['brute', 'kd_tree', 'ball_tree']]] = None):
    """Write the experiment-3 kNN-algorithm comparison table to ``exp3_tab.txt``.

    Results are fetched once per entry of ``knn_algos``. For every dataset,
    classifier, threshold (0.6 and 1.0) and kNN algorithm, the k with the
    highest "smart" accuracy is selected (the smallest such k on ties) and the
    accuracy and time at that k are recorded.

    Each (dataset, classifier) pair contributes a block of two rows, accuracies
    on top and times below. Columns are ordered by threshold first and then by
    kNN algorithm, e.g. ``0.6/brute, 0.6/ball_tree, 1.0/brute, 1.0/ball_tree``.
    Accuracies are multiplied by 100 and formatted with one decimal; times are
    formatted with four decimals and cut to at most seven characters. Blocks of
    one dataset are placed side by side (classifiers in sorted order), datasets
    are stacked vertically (in sorted order), and cells are tab-separated.

    Args:
        datasets: Dataset names to include. Defaults to
            ``['mnist', 'covertype', 'yeast', 'skin', 'statlog']``.
        knn_algos: kNN algorithms to compare; each is forwarded to
            ``ret_avg_results``. Defaults to ``['brute', 'ball_tree']``.

    Returns:
        None. Prints ``"No results found"`` and writes nothing as soon as
        ``ret_avg_results`` returns ``None`` or an empty mapping for any
        algorithm.

    Raises:
        KeyError: If a requested dataset is missing from the results.
        ValueError: If threshold 0.6 or 1.0 is absent from a dataset's results.
    """
    if datasets is None:
        datasets = ['mnist', 'covertype', 'yeast', 'skin', 'statlog']
    if knn_algos is None:
        knn_algos = ['brute', 'ball_tree']

    reported_thresholds = (0.6, 1.0)

    # (dataset, clf, threshold, knn_algo) -> (accuracy cell, time cell).
    # Tuple keys avoid depending on how a threshold happens to be rendered
    # as text ("1" versus "1.0").
    best = {}
    # Classifier list per dataset; if algorithms disagree, the last one wins.
    clfs_by_dataset = {}

    for knn_algo in knn_algos:
        ret = ret_avg_results(datasets=datasets, knn_algo=knn_algo)
        if not ret:
            print("No results found")
            return

        for dataset, results_dict in ret.items():
            clfs = results_dict["clfs"]
            ks = list(results_dict['ks'])
            thresholds = list(results_dict['thresholds'])
            clfs_by_dataset[dataset] = clfs

            # Both arrays are indexed as [classifier, k, threshold].
            smart_acc = np.asarray(results_dict["smart_acc"])
            smart_time = np.asarray(results_dict["smart_time"])

            for iclf, clf in enumerate(clfs):
                for t in reported_thresholds:
                    # Real position of this threshold on the threshold axis.
                    t_idx = thresholds.index(t)
                    acc_by_k = smart_acc[iclf, :, t_idx]
                    time_by_k = smart_time[iclf, :, t_idx]

                    # Among the k values reaching the top accuracy, keep the smallest k.
                    top = np.flatnonzero(acc_by_k == np.amax(acc_by_k))
                    best_k_idx = min(top, key=lambda idx: ks[idx])

                    best[(dataset, clf, t, knn_algo)] = (
                        "{:.1f}".format(acc_by_k[best_k_idx] * 100),
                        "{:.4f}".format(time_by_k[best_k_idx])[0:7],
                    )

    dataset_rows = []
    for dataset in sorted(datasets):
        per_clf = []
        for clf in sorted(clfs_by_dataset[dataset]):
            cells = [best[(dataset, clf, t, knn_algo)] for t in reported_thresholds for knn_algo in knn_algos]
            acc_row = [acc for acc, _ in cells]
            time_row = [time for _, time in cells]
            per_clf.append(np.array([acc_row, time_row], dtype=object))
        dataset_rows.append(np.hstack(per_clf))

    table = np.vstack(dataset_rows)
    lines = ["\t".join(row) for row in table.tolist()]

    # Open the file only once the table is complete, so a failure above does
    # not leave a truncated output file behind.
    with open('exp3_tab.txt', 'w') as f:
        f.write("\n".join(lines))

# Regenerate all three tables when this cell/script runs. Experiments 1 and 2
# use the same eight datasets (a fresh list is built for each call);
# experiment 3 is run on USPS only.
for _generate_table in (generate_table_exp1, generate_table_exp2):
    _generate_table(
        datasets=['covertype', 'glass', 'mnist', 'skin', 'statlog', 'usps', 'wine', 'yeast'],
    )
generate_table_exp3(datasets=['usps'])