In [None]:
import sys
sys.path.append('../')

import importlib
from collections import defaultdict
from pathlib import Path
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import Markdown

from dynalign.experiments.paths import EMBEDDINGS_PATH

In [None]:
def get_all_metadata_paths(path):
    """Collect the metadata files stored one level below ``path``.

    Every direct sub-directory of ``path`` is scanned (non-recursively) and
    each entry whose *file name* contains ``"metadata"`` is returned.

    Args:
        path: Root directory to scan (``str`` or ``pathlib.Path``).

    Returns:
        list[pathlib.Path]: Matching paths, grouped by sub-directory in the
        order ``iterdir`` yields them (no sorting is applied).

    Raises:
        OSError: If ``path`` does not exist or is not a directory
            (propagated from ``Path.iterdir``).
    """
    # Honour the argument instead of silently scanning a module-level constant.
    root = Path(path)
    model_dirs = [it for it in root.iterdir() if it.is_dir()]

    all_paths = []
    for model_dir in model_dirs:
        # Test the file name only: matching against the full path would accept
        # every file whenever a parent directory happens to contain "metadata".
        all_paths += [it for it in model_dir.iterdir() if "metadata" in it.name]

    return all_paths


def read_metadata(paths):
    """Load metadata pickles into a ``{model: {dataset: object}}`` mapping.

    The model name is the name of the file's parent directory; the dataset
    name is the file name with every ``"_metadata.pkl"`` occurrence removed.

    Args:
        paths: Iterable of ``pathlib.Path`` objects pointing at pickle files.

    Returns:
        collections.defaultdict: ``model_name -> {dataset_name: unpickled object}``.
    """
    by_model = defaultdict(dict)
    for metadata_file in tqdm(paths):
        model = metadata_file.parent.name
        dataset = metadata_file.name.replace("_metadata.pkl", "")
        # NOTE: unpickling can execute arbitrary code - only load trusted files.
        by_model[model][dataset] = pd.read_pickle(metadata_file)

    return by_model

In [None]:
# Locate every metadata file below EMBEDDINGS_PATH and load them all; the
# result is keyed by model name first and dataset name second.
metadata_paths = get_all_metadata_paths(EMBEDDINGS_PATH)
metadata = read_metadata(metadata_paths)

In [None]:
# Model names are the top-level keys of `metadata`; dataset names are read
# from the first model only.
# NOTE(review): this assumes every model was run on the same datasets - confirm.
models = list(metadata)
datasets = list(metadata[models[0]])

In [None]:
import numpy as np


def parse_epoch_time_log(metrics):
    """Add up every timing series across all entries of ``metrics``.

    Args:
        metrics: Mapping whose values are themselves mappings of
            ``name -> summable values`` (anything ``np.sum`` accepts).

    Returns:
        collections.defaultdict: ``"total_<name>" -> sum`` accumulated over
        all outer entries; missing keys default to ``0``.
    """
    totals = defaultdict(int)
    for per_epoch in metrics.values():
        for name, samples in per_epoch.items():
            totals[f"total_{name}"] += np.sum(samples)
    return totals


def get_snapshot_calculation_times(metrics):
    """Collapse the timing series of a single snapshot into totals.

    Every entry except ``"epoch_time_log"`` is reduced with ``np.sum``. An
    extra ``"computation_time"`` entry is then set to the sum of whichever of
    ``"training_step_time"``, ``"loss_backward_time"`` and
    ``"optimizer_step_time"`` are present (``0.0`` when none is).

    Args:
        metrics: Mapping of ``name -> summable values``.

    Returns:
        dict: ``name -> total``, plus the derived ``"computation_time"``.
    """
    compute_keys = {"training_step_time", "loss_backward_time", "optimizer_step_time"}

    totals = {
        name: np.sum(samples)
        for name, samples in metrics.items()
        if name != "epoch_time_log"
    }
    totals["computation_time"] = np.sum(
        [total for name, total in totals.items() if name in compute_keys]
    )
    return totals


def get_run_calculation_times(metrics, model_name):
    """Accumulate the timing totals of every snapshot in one run.

    Args:
        metrics: Sequence of per-snapshot mappings; each must provide
            ``"enhanced_time_log"`` and ``"calculation_time"``.
        model_name: Name of the model the run belongs to. For ``"Node2Vec"``
            the snapshot at index 0 is left out of all totals (the reason is
            not visible in this notebook).

    Returns:
        collections.defaultdict: ``name -> total over snapshots``, including
        ``"total_calculation_time"`` (sum of the ``"calculation_time"`` values).
    """
    skip_first_snapshot = model_name == "Node2Vec"
    totals = defaultdict(int)

    for index, snapshot in enumerate(metrics):
        if skip_first_snapshot and index == 0:
            continue

        snapshot_totals = get_snapshot_calculation_times(snapshot["enhanced_time_log"])
        for name, value in snapshot_totals.items():
            totals[name] += value

        totals["total_calculation_time"] += snapshot["calculation_time"]

    return totals


def get_calculation_times(metadata, model_name, dataset_name):
    """Gather per-run timing totals for one model/dataset pair and summarise them.

    Args:
        metadata: Nested mapping ``model -> dataset -> {"metrics": runs}``.
        model_name: Key of the model to evaluate.
        dataset_name: Key of the dataset to evaluate.

    Returns:
        tuple: ``(per_run, summary)`` where ``per_run`` maps each metric name
        to the list of its per-run totals and ``summary`` maps each metric
        name to ``(mean, std)``, both rounded to 4 decimals (``np.std`` with
        its default settings).
    """
    runs = metadata[model_name][dataset_name]["metrics"]
    per_run = defaultdict(list)

    # Runs are fetched by position, exactly as the container is indexed.
    for position, _run in enumerate(runs):
        run_totals = get_run_calculation_times(runs[position], model_name=model_name)
        for name, value in run_totals.items():
            per_run[name].append(value)

    summary = {}
    for name, values in per_run.items():
        summary[name] = (np.mean(values).round(4), np.std(values).round(4))

    return per_run, summary


# calculation_times_test = get_detailed_calculation_times(test_metadata)

In [None]:
# One summary table per dataset: rows are models, columns are timing metrics,
# and every cell holds the (mean, std) tuple computed across runs.
dfs = {}
for dataset in datasets:
    dataset_times = {}
    for model in models:
        # Only the averaged half of the returned pair is tabulated.
        averaged_times = get_calculation_times(
            metadata=metadata, model_name=model, dataset_name=dataset
        )[1]
        dataset_times[model] = averaged_times

    display(Markdown(dataset))
    df = pd.DataFrame.from_dict(dataset_times, orient="index")
    display(df)
    dfs[dataset] = df


In [None]:
# Overhead of every model relative to the Node2Vec baseline, per dataset:
#   overhead_times[ds][model]       = mean computation time - Node2Vec mean
#   overhead_times_ratio[ds][model] = mean computation time / Node2Vec mean
overhead_times = {}
overhead_times_ratio = {}
for ds, ds_df in dfs.items():
    # Cells are (mean, std) tuples; element 0 is the mean.
    mean_computation_times = ds_df["computation_time"].map(lambda stats: stats[0])
    mean_n2v_time = mean_computation_times.loc["Node2Vec"]

    overhead_times[ds] = (mean_computation_times - mean_n2v_time).to_dict()
    overhead_times_ratio[ds] = (mean_computation_times / mean_n2v_time).to_dict()

In [None]:
# Computation time of each model relative to Node2Vec (rows: models, columns:
# datasets); the baseline's own row is removed from the view.
pd.DataFrame(overhead_times_ratio).drop(index=["Node2Vec"])

In [None]:
from pathlib import Path

# posthoc_times[variant][dataset] = total calculation time of a run (summed
# over its snapshots), averaged over all runs of that dataset.
posthoc_times = {}

for posthoc in ("PosthocALL", "PosthocEJ", "PosthocTB"):
    posthoc_dir = Path(f"../data/posthoc/{posthoc}/")
    posthoc_metadata_files = [
        it for it in posthoc_dir.iterdir() if "metadata" in it.name
    ]
    # Stored under its own name: rebinding the notebook-level `metadata` here
    # would break re-running the earlier cells that read it.
    # The dataset key is the file-name part before the first underscore.
    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    posthoc_metadata = {
        path.name.split("_")[0]: pd.read_pickle(path)["metrics"]
        for path in posthoc_metadata_files
    }

    posthoc_calculation_time = {}
    for ds, ds_metadata in posthoc_metadata.items():
        # Collect one total per run first; the mean has to be taken over all
        # runs, not over whichever run happened to be processed last.
        run_totals = [
            sum(snapshot["calculation_time"] for snapshot in run)
            for run in ds_metadata
        ]
        posthoc_calculation_time[ds] = np.mean(run_totals)

    posthoc_times[posthoc] = posthoc_calculation_time

In [None]:
# Calculation time per dataset (rows) for each post-hoc variant (columns).
pd.DataFrame(posthoc_times)

In [None]:
# Absolute computation-time overhead versus Node2Vec (rows: models, columns:
# datasets), without the baseline's own row.
df = pd.DataFrame(overhead_times).drop(["Node2Vec"])
df

In [None]:
# Overhead minus the PosthocALL time of the same dataset: the Series is
# indexed by dataset and is aligned against the dataset columns of `df`.
df.sub(pd.DataFrame(posthoc_times)["PosthocALL"], axis="columns")

In [None]:
# Overhead divided by the PosthocALL time of the same dataset (aligned on the
# dataset columns).
# NOTE(review): unlike the difference table, the Node2Vec row is kept here -
# confirm whether that is intended.
pd.DataFrame(overhead_times).div(
    pd.DataFrame(posthoc_times)["PosthocALL"], axis="columns"
)