In [None]:
from IPython.display import Markdown
import wandb
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns

# Shared W&B API client; later cells use it to list projects and to fetch runs.
api = wandb.Api()


In [None]:
# Print the name of every project returned for the "haraghi" entity.
projects = api.projects(entity="haraghi")
for project in projects:
    print(project.name)

In [None]:
# Lookup table used by the table writers: dataset key -> display name and
# number of classes. Built from compact (key, display name, #classes) records.
datasets_name_and_num_classes = {
    dataset_key: {"name": display_name, "num_classes": class_count}
    for dataset_key, display_name, class_count in (
        ("NCARS", "N-Cars", 2),
        ("NASL", "N-ASL", 24),
        ("NCALTECH101", "N-Caltech101", 101),
        ("DVSGESTURE_TONIC", "DVS-Gesture", 11),
        ("FAN1VS3", "Fan1vs3", 2),
    )
}

In [None]:
# W&B project names whose runs are fetched by the collection loop below
# (each is combined with the entity in the api.runs(...) call).
# NOTE(review): the "varyinig" spelling is part of the string sent to W&B and
# presumably matches the remote project names -- do not correct it here without
# confirming/renaming the projects.
dataset_projects = [
        "FINAL-NASL-varyinig-sparsity",
        "FINAL-NCARS-varyinig-sparsity",
        "FINAL-DVSGESTURE_TONIC-HP-varyinig-sparsity",
        "FINAL-FAN1VS3-varyinig-sparsity",
        "FINAL-NCALTECH101-varyinig-sparsity",
]

In [None]:
def find_val_and_test_acc_keys(run):
    """Locate the summary keys holding the mean validation / test accuracy.

    A key qualifies for a split when its name contains the split tag
    ("val" or "test") together with both "acc" and "mean".

    Args:
        run: Object exposing a ``summary`` mapping (only ``summary.keys()``
            is used).

    Returns:
        A ``(val_key, test_key)`` tuple; an element is ``None`` when no
        matching key exists for that split.

    Raises:
        AssertionError: If more than one key matches for either split.
    """
    def _keys_for(split_tag):
        # Substring match on the split tag plus the two fixed markers.
        return [
            name
            for name in run.summary.keys()
            if split_tag in name and "acc" in name and "mean" in name
        ]

    val_matches = _keys_for("val")
    test_matches = _keys_for("test")

    assert len(val_matches) <= 1, f"More than one val acc key found: {val_matches}"
    assert len(test_matches) <= 1, f"More than one test acc key found: {test_matches}"

    val_key = val_matches[0] if len(val_matches) == 1 else None
    test_key = test_matches[0] if len(test_matches) == 1 else None
    return val_key, test_key

In [None]:
# Output locations for the generated LaTeX / Markdown tables.
folder_name = 'paper'
subfolder_name = os.path.join('images', folder_name, 'sparsity_vs_acc')
entity = 'haraghi'
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(subfolder_name, exist_ok=True)
file_path = os.path.join(subfolder_name, "sparsity_vs_acc.tex")
file_path_md = os.path.join(subfolder_name, "sparsity_vs_acc.md")


# dataset name -> {num_events_per_sample -> list of per-run accuracies (None if missing)}
val_dict = {}
test_dict = {}
# Union of every num_events_per_sample value seen across all projects.
num_events_set = set()

for project_name in dataset_projects:
    runs = api.runs(f"{entity}/{project_name}")
    # Keep only finished runs whose config has a "transform" section.
    runs = [r for r in runs if r.state == "finished" and "transform" in r.config]
    if len(runs) == 0:
        print(f"No runs found for {project_name}")
        continue
    num_events = np.unique([run.config['transform']['train']['num_events_per_sample'] for run in runs])
    runs_per_num_events = {num_event: [run for run in runs if run.config['transform']['train']['num_events_per_sample'] == num_event] for num_event in num_events}
    # The dataset name is taken from the first run of the project.
    dataset_name = runs[0].config["dataset"]["name"]

    num_events_set = num_events_set.union(set(num_events))

    val_mean = {}
    test_mean = {}
    lr = {}
    batch_size = {}
    weight_decay = {}

    for num_event in num_events:
        val_mean[num_event] = []
        test_mean[num_event] = []
        lr[num_event] = []
        batch_size[num_event] = []
        weight_decay[num_event] = []
        for run in runs_per_num_events[num_event]:
            val_key, test_key = find_val_and_test_acc_keys(run)
            # A key of None means the run logged no matching metric; record None
            # without probing the summary with a None key.
            val_mean[num_event].append(
                run.summary[val_key] if val_key is not None and val_key in run.summary else None
            )
            test_mean[num_event].append(
                run.summary[test_key] if test_key is not None and test_key in run.summary else None
            )
            lr[num_event].append(run.config['optimize']['lr'])
            batch_size[num_event].append(run.config['train']['batch_size'])
            if 'weight_decay' in run.config['optimize']:
                weight_decay[num_event].append(run.config['optimize']['weight_decay'])

        # Coverage report. num_events is derived from the runs themselves, so
        # every group has at least one run and n_runs is never zero.
        n_runs = len(val_mean[num_event])
        n_val = sum(v is not None for v in val_mean[num_event])
        n_test = sum(v is not None for v in test_mean[num_event])
        print(f"runs with val acc for {num_event} events: {n_val} out of {n_runs} runs ({100 * n_val / n_runs:.1f}%)")
        print(f"runs with test acc for {num_event} events: {n_test} out of {n_runs} runs ({100 * n_test / n_runs:.1f}%)")

    val_dict[dataset_name] = val_mean
    test_dict[dataset_name] = test_mean

# Sorted x-axis shared by all datasets in the tables.
num_events_list = sorted(list(num_events_set))

In [None]:
def create_full_metric_mean_std(metric_dict, num_events_list):
    """Aggregate per-run metrics into (mean, std) pairs aligned to a common axis.

    Args:
        metric_dict: Mapping ``dataset name -> {num_events -> list of values}``
            where a value may be ``None`` for a run that did not log the metric.
        num_events_list: Ordered event counts; each output list has one entry
            per element, in this order.

    Returns:
        Mapping ``dataset name -> list of (mean, std)`` tuples. An entry is
        ``(None, None)`` when the dataset has no values for that event count,
        when the value list is empty, or when any value in it is ``None``.
        The standard deviation is ``np.std`` with its default arguments.
    """
    full_metric_mean_std = {}
    for dataset_name, num_events_dict in metric_dict.items():
        stats = []
        for num_events in num_events_list:
            values = num_events_dict.get(num_events)
            # An empty list would make np.mean/np.std return nan (with a
            # RuntimeWarning); report it as missing instead.
            if values is None or len(values) == 0 or any(v is None for v in values):
                stats.append((None, None))
            else:
                stats.append((np.mean(values), np.std(values)))
        full_metric_mean_std[dataset_name] = stats
    return full_metric_mean_std


In [None]:
# Aggregate test first, then validation, onto the shared num_events_list axis.
full_test_mean_std, full_val_mean_std = (
    create_full_metric_mean_std(per_dataset_metric, num_events_list)
    for per_dataset_metric in (test_dict, val_dict)
)

In [None]:
def _latex_row(lead, cells):
    """Return ``lead`` + cells joined by `` & ``, terminated with a LaTeX row break."""
    return lead + " & ".join(cells) + " \\\\\n"


def write_sparsity_vs_acc_table(file_path, full_metric_mean_std, num_events_list, datasets_name_and_num_classes):
    """Write the accuracy-vs-events table as a LaTeX ``tabular`` to ``file_path``.

    For each dataset three rows are written: mean test accuracy (in %),
    standard deviation (in %), and p-value text.

    Args:
        file_path: Destination file; it is overwritten.
        full_metric_mean_std: Mapping ``dataset key -> list of tuples``, one
            tuple per entry of ``num_events_list``. Each tuple is
            ``(mean, std)`` or ``(mean, std, p_value_text)``; ``mean``/``std``
            are fractions (multiplied by 100 here) or ``None``;
            ``p_value_text`` is a string or ``None``.
        num_events_list: Column headers (event counts).
        datasets_name_and_num_classes: Mapping ``dataset key ->
            {"name": ..., "num_classes": ...}``.

    Raises:
        KeyError: If a dataset key is absent from
            ``datasets_name_and_num_classes``.
    """
    missing = "--"
    with open(file_path, "w") as file:
        # Table header: three leading columns plus one per event count.
        file.write("\\begin{tabular}{" + ("c" * (3 + len(num_events_list))) + "}\n")
        file.write("\\toprule\n")
        file.write(" & & & \\multicolumn{" + str(len(num_events_list)) + "}{c}{\\# events per video}\\\\\n")
        file.write("Dataset & \\# classes & & " +
                   " & ".join([str(num_events) for num_events in num_events_list]) +
                   "\\\\\n")
        file.write("\\midrule\n")

        for dataset, values in full_metric_mean_std.items():
            info = datasets_name_and_num_classes[dataset]

            # Row 1: dataset name, class count and mean accuracies.
            acc_cells = [
                "${:.2f}$".format(entry[0] * 100) if entry[0] is not None else missing
                for entry in values
            ]
            file.write(_latex_row(
                "",
                [info["name"], str(info["num_classes"]), "Test Acc. (\\%)"] + acc_cells,
            ))

            # Row 2: standard deviations.
            std_cells = [
                "\\textcolor{{WildStrawberry}}{{\\scriptsize$\\pm {:.2f}$}}".format(entry[1] * 100)
                if entry[1] is not None else missing
                for entry in values
            ]
            file.write(_latex_row(
                " & & ",
                ["\\textcolor{WildStrawberry}{\\scriptsize Std. Dev. (\\%)}"] + std_cells,
            ))

            # Row 3: p-values. Tuples without a third element (plain
            # (mean, std) pairs) are rendered as missing instead of raising
            # IndexError halfway through writing the file.
            p_cells = []
            for entry in values:
                p_text = entry[2] if len(entry) > 2 else None
                if p_text is not None:
                    p_cells.append(r"\textcolor{Cerulean}{\scriptsize " + p_text + "}")
                else:
                    p_cells.append(missing)
            file.write(_latex_row(
                " & & ",
                ["\\textcolor{Cerulean}{\\scriptsize p-value}"] + p_cells,
            ))

        # Table footer
        file.write("\\bottomrule\n")
        file.write("\\end{tabular}\n")


In [None]:
# Scratch check of str.format escaping: "{{" renders a literal "{" and
# "{:.2f}" formats the value with two decimals.
print("{{\\scriptsize($\\pm {:.2f}$)".format(7.654))

In [None]:
def write_sparsity_vs_acc_table_md(file_path, full_metric_mean_std, num_events_list, datasets_name_and_num_classes):
    """Write the accuracy-vs-events table as a Markdown pipe table.

    Each data cell shows ``mean (± std)`` in percent, or ``--`` when the mean
    is ``None``.

    Args:
        file_path: Destination file; it is overwritten.
        full_metric_mean_std: Mapping ``dataset key -> list of tuples`` whose
            first two elements are ``(mean, std)`` as fractions, or ``None``.
        num_events_list: Column headers (event counts).
        datasets_name_and_num_classes: Mapping ``dataset key ->
            {"name": ..., "num_classes": ...}``.

    Raises:
        KeyError: If a dataset key is absent from
            ``datasets_name_and_num_classes``.
    """
    with open(file_path, "w") as file:
        # Header row; closed with a trailing pipe so it matches the separator
        # and data rows written below.
        header_cells = ["Dataset", "# classes"] + [str(num_events) for num_events in num_events_list]
        file.write("| " + " | ".join(header_cells) + " |\n")
        file.write("| --- " * (2 + len(num_events_list)) + "|\n")

        # One row per dataset.
        for dataset, values in full_metric_mean_std.items():
            info = datasets_name_and_num_classes[dataset]
            cells = [info["name"], str(info["num_classes"])]
            for mean_std_tuple in values:
                if mean_std_tuple[0] is not None:
                    cells.append(
                        "${:.2f}$".format(mean_std_tuple[0] * 100)
                        + " ($\\pm {:.2f}$)".format(mean_std_tuple[1] * 100)
                    )
                else:
                    cells.append("--")
            file.write("| " + " | ".join(cells) + " |\n")



In [None]:
import pickle

# Save each aggregated table together with num_events_list, the axis its
# entries are aligned to.
for pickle_name, stats in (
    ("full_val_mean_std.pickle", full_val_mean_std),
    ("full_test_mean_std.pickle", full_test_mean_std),
):
    with open(os.path.join(subfolder_name, pickle_name), "wb") as f:
        pickle.dump([stats, num_events_list], f)

# Load the p-value strings used for the LaTeX table.
# NOTE(review): nothing visible in this notebook writes p_values_text.pkl, so it
# is presumably produced elsewhere and must already exist in subfolder_name.
# pickle.load can execute arbitrary code -- only load files you trust.
p_values_path = os.path.join(subfolder_name, "p_values_text.pkl")
with open(p_values_path, "rb") as f:
    p_values_text = pickle.load(f)

In [None]:
# Extend every (mean, std) pair with the p-value text at the same position.
# Positional indexing (rather than zip) keeps a length mismatch an error.
full_test_mean_std_p_value = {
    name: [
        (mean, std, p_values_text[name][position])
        for position, (mean, std) in enumerate(mean_std_pairs)
    ]
    for name, mean_std_pairs in full_test_mean_std.items()
}
print(full_test_mean_std_p_value)

In [None]:
# LaTeX table gets the p-value-augmented tuples; the Markdown table uses the
# plain (mean, std) pairs.
write_sparsity_vs_acc_table(
    file_path=file_path,
    full_metric_mean_std=full_test_mean_std_p_value,
    num_events_list=num_events_list,
    datasets_name_and_num_classes=datasets_name_and_num_classes,
)
write_sparsity_vs_acc_table_md(
    file_path=file_path_md,
    full_metric_mean_std=full_test_mean_std,
    num_events_list=num_events_list,
    datasets_name_and_num_classes=datasets_name_and_num_classes,
)

# Read the generated Markdown back and render it as this cell's output.
with open(file_path_md, "r") as md_file:
    markdown_content = md_file.read()

Markdown(markdown_content)

In [None]:
def plot_acc_vs_sparsity(data):
    """Plot mean accuracy with std error bars against sparsity, one line per dataset.

    Args:
        data: Mapping ``dataset -> {sparsity -> list of accuracies}``. A
            sparsity level whose list is empty or contains ``None`` is skipped
            (the table builder treats such levels as missing as well); a
            dataset with no plottable level is omitted from the figure.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for dataset, sparsity_data in data.items():
        sparsities = []
        mean_accuracies = []
        std_accuracies = []

        for sparsity, accuracies in sparsity_data.items():
            # np.mean raises TypeError on None entries and yields nan on an
            # empty list, so drop levels that cannot be aggregated.
            if len(accuracies) == 0 or any(acc is None for acc in accuracies):
                continue
            sparsities.append(sparsity)
            mean_accuracies.append(np.mean(accuracies))
            std_accuracies.append(np.std(accuracies))

        if not sparsities:
            continue

        ax.errorbar(sparsities, mean_accuracies, yerr=std_accuracies, label=dataset, capsize=5, marker='o', linestyle='--')

    ax.set_xlabel('Sparsity')
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy vs Sparsity')
    ax.set_xscale('log')  # Log scale if sparsity values span several orders of magnitude
    ax.legend()
    ax.grid(True)
    plt.show()


plot_acc_vs_sparsity(test_dict)
