In [None]:
import sys
sys.path.append('../')

In [None]:
import os

import numpy as np
import pandas as pd

from dynalign.aligners import selectors as dynalign_selectors

In [None]:
from copy import deepcopy

In [None]:
# Input locations, relative to the notebook directory.
paths = dict(graphs="../data/graphs/")

In [None]:
# Load every pickled graph bundle in the graphs directory, keyed by file stem.
# Only files that *end* in ".pkl" are loaded (a substring test would also pick up
# e.g. "x.pkl.bak"), and the listing is sorted so dataset order is reproducible.
# NOTE(review): unpickling executes code — only point this at trusted data.
PICKLE_SUFFIX = ".pkl"
graphs = {
    graph_name[: -len(PICKLE_SUFFIX)]: pd.read_pickle(
        os.path.join(paths["graphs"], graph_name)
    )
    for graph_name in sorted(os.listdir(paths["graphs"]))
    if graph_name.endswith(PICKLE_SUFFIX)
}
# Each pickle is indexed by its own stem to get the actual bundle.
# (Presumably the pickle is a mapping {dataset_name: bundle} — TODO confirm layout.)
graphs = {name: bundle[name] for name, bundle in graphs.items()}

# Compute selector scores and reference nodes for every dataset

In [None]:
from collections import defaultdict
from typing import Dict, Any

from tqdm.autonotebook import tqdm

from dynalign.aligners.selectors import get_reference_nodes

In [None]:
def prepare_selector_args(
    percent: float, log_normalization: bool = False
) -> Dict[str, Any]:
    """Build the ``selector_args`` mapping for percent-based reference node selection.

    Args:
        percent: Value stored under ``selection_method_args["percent"]``.
        log_normalization: When true, also sets
            ``selection_method_args["log_norm_scores"] = True``.

    Returns:
        A freshly built dict with keys ``selection_method`` (always ``"percent"``)
        and ``selection_method_args``.
    """
    method_args: Dict[str, Any] = {"percent": percent}
    if log_normalization:
        method_args["log_norm_scores"] = True
    return {"selection_method": "percent", "selection_method_args": method_args}

In [None]:
def get_all_ref_nodes_and_scores(
    snapshots, selector, cache, percentages, log_normalization: bool = False
):
    """Collect selector scores and reference nodes for every intermediate snapshot.

    Each snapshot in ``snapshots[1:-1]`` is compared against ``snapshots[0]``; the
    first and last snapshots are not scored. For each scored snapshot,
    ``get_reference_nodes`` is called:

    * once with ``percent=1.0`` — only its scores are kept;
    * once per value in ``percentages`` — only the selected reference nodes are kept.

    Args:
        snapshots: Ordered sequence of graph snapshots; ``snapshots[0]`` is the
            reference graph.
        selector: Selector spec forwarded to ``get_reference_nodes`` (a
            fully-qualified class path string in this notebook).
        cache: Optional cache indexed by snapshot position in ``snapshots``.
            A falsy value disables caching.
        percentages: Iterable of selection percentages.
        log_normalization: Forwarded to ``prepare_selector_args``.

    Returns:
        ``(selector_scores, selector_scores_mapped, reference_nodes)``.
        The first two are lists with one entry per scored snapshot.
        ``reference_nodes`` maps each percentage to a list with one entry of
        selected nodes per scored snapshot.
    """
    reference_graph = snapshots[0]
    selector_scores = []
    selector_scores_mapped = []
    reference_nodes = defaultdict(list)

    # Wrap the sliced list (not the enumerate object) so tqdm can read its length
    # and render a progress bar with a known total.
    for snapshot_id, snapshot in enumerate(tqdm(snapshots[1:-1], leave=False)):
        # `snapshot_id` counts from 0 over snapshots[1:-1]; +1 maps it back to the
        # position in `snapshots`, which is how the cache is indexed.
        snapshot_cache = cache[snapshot_id + 1] if cache else None
        shared_kwargs = dict(
            selector=selector,
            graph=snapshot,
            reference_graph=reference_graph,
            reverse_node_index_mapping=None,
            cache=snapshot_cache,
        )

        # Only the scores of this percent=1.0 call are used.
        _, __, scores, scores_mapped = get_reference_nodes(
            selector_args=prepare_selector_args(
                percent=1.0, log_normalization=log_normalization
            ),
            **shared_kwargs,
        )
        selector_scores.append(scores)
        selector_scores_mapped.append(scores_mapped)

        for percent in tqdm(percentages, leave=False):
            # selector_args is rebuilt per call so no call can see another's dict.
            ref_nodes, *_ = get_reference_nodes(
                selector_args=prepare_selector_args(
                    percent=percent, log_normalization=log_normalization
                ),
                **shared_kwargs,
            )
            reference_nodes[percent].append(ref_nodes)

    return selector_scores, selector_scores_mapped, reference_nodes

In [None]:
# Selection percentages 0.1, 0.2, ..., 0.9.
# Integer division gives the nearest doubles (0.3, not the 0.30000000000000004
# produced by np.arange(0.1, 1.0, 0.1)). These values are used as dict keys
# downstream, so they must match literal lookups.
percentages = np.arange(1, 10) / 10

# Selectors that run without a pre-computed cache.
# TemporalCentralityMeasureSelector is not listed here: it is driven by the
# cached temporal scores in the scoring loop.
selectors_cls = [
    "dynalign.aligners.selectors.EdgeJaccardNodesSelector",
]

In [None]:
# Per-dataset results, keyed as results[dataset][selector_name].
all_selector_scores = {}
all_selector_scores_scaled = {}
all_selector_ref_nodes = {}

# Selector class used for every cached temporal centrality measure.
TEMPORAL_SELECTOR_CLS = "dynalign.aligners.selectors.TemporalCentralityMeasureSelector"
# Temporal measures whose scores are loaded from ../data/cached/temporal_scores/.
TEMPORAL_MEASURES = ["betweenness"]


for ds in tqdm(graphs.keys()):
    ds_graphs = graphs[ds]["graphs"]
    temporal_selectors_cache = {
        measure: pd.read_pickle(
            os.path.join(f"../data/cached/temporal_scores/{measure}/", f"{ds}.pkl")
        )
        for measure in TEMPORAL_MEASURES
    }

    # One job per selector run: (result key, selector class path, cache, log-normalize).
    # Cache-free selectors come first, then one job per cached temporal measure.
    # This keeps the insertion order of the per-dataset dicts stable.
    selector_jobs = [
        (selector.split(".")[-1], selector, None, False) for selector in selectors_cls
    ] + [
        (measure, TEMPORAL_SELECTOR_CLS, cache, True)
        for measure, cache in temporal_selectors_cache.items()
    ]

    ds_selector_scores = {}
    ds_selector_scores_scaled = {}
    ds_selector_ref_nodes = {}

    for selector_name, selector, cache, log_normalization in tqdm(
        selector_jobs, leave=False
    ):
        (
            selector_scores,
            selector_scores_scaled,
            reference_nodes,
        ) = get_all_ref_nodes_and_scores(
            snapshots=ds_graphs,
            selector=selector,
            percentages=percentages,
            cache=cache,
            log_normalization=log_normalization,
        )
        ds_selector_scores[selector_name] = selector_scores
        ds_selector_scores_scaled[selector_name] = selector_scores_scaled
        ds_selector_ref_nodes[selector_name] = reference_nodes

    all_selector_scores[ds] = ds_selector_scores
    all_selector_scores_scaled[ds] = ds_selector_scores_scaled
    all_selector_ref_nodes[ds] = ds_selector_ref_nodes

In [None]:
import pickle

In [None]:
# Persist the reference nodes for every dataset / selector / percentage.
with open("ref_nodes.pkl", mode="wb") as ref_nodes_file:
    pickle.dump(all_selector_ref_nodes, ref_nodes_file)

In [None]:
# Sanity check: which percentage keys were recorded for one dataset/selector pair.
bitcoin_betweenness_ref_nodes = all_selector_ref_nodes["bitcoin-alpha"]["betweenness"]
bitcoin_betweenness_ref_nodes.keys()

# Analysis: Jaccard overlap of reference nodes between selectors

In [None]:
# Dataset whose reference-node overlap and score distributions are analysed below.
ds = 'ogbl-collab'

In [None]:
def _as_node_set(ref_nodes):
    """Convert a collection of node ids into a set of plain, value-hashed elements."""
    # Array-likes (numpy arrays, torch tensors, pandas objects) expose ``tolist``.
    # Converting first makes elements hash by value. Iterating a torch tensor
    # directly yields 0-d tensors, which hash by identity, so equal ids would never
    # match in a set.
    if hasattr(ref_nodes, "tolist"):
        return set(ref_nodes.tolist())
    return set(ref_nodes)


def calculate_jaccard_index(ref_nodes_a, ref_nodes_b):
    """Return the Jaccard index |A ∩ B| / |A ∪ B| of two reference-node collections.

    Args:
        ref_nodes_a: Iterable of node ids (list, numpy array, tensor, ...).
        ref_nodes_b: Iterable of node ids.

    Returns:
        A float in [0, 1]. Two empty collections are considered identical and
        give 1.0 instead of raising ZeroDivisionError on the empty union.
    """
    nodes_a = _as_node_set(ref_nodes_a)
    nodes_b = _as_node_set(ref_nodes_b)
    union = nodes_a | nodes_b
    if not union:
        return 1.0
    return len(nodes_a & nodes_b) / len(union)


def calculate_all_jaccard_indices(snapshot_ref_nodes_a, snapshot_ref_nodes_b):
    """Compute the Jaccard index snapshot-by-snapshot for two selectors.

    Args:
        snapshot_ref_nodes_a: Sequence of reference-node collections, one per snapshot.
        snapshot_ref_nodes_b: Sequence aligned with ``snapshot_ref_nodes_a``.

    Returns:
        List of Jaccard indices, one per aligned pair. Like ``zip``, this stops at
        the shorter sequence.
    """
    return [
        calculate_jaccard_index(ref_nodes_a, ref_nodes_b)
        for ref_nodes_a, ref_nodes_b in zip(snapshot_ref_nodes_a, snapshot_ref_nodes_b)
    ]

In [None]:
# Pairwise reference-node overlap between every selector run on dataset `ds`.
# jaccard_index[selector_a][selector_b][percent] is a list with one Jaccard index
# per scored snapshot.
# Iterate the selector names actually stored for `ds`. The undefined `selectors`
# object is not used: it only existed in stale kernel state.
selector_names = list(all_selector_ref_nodes[ds].keys())

jaccard_index = defaultdict(dict)

for selector_a in tqdm(selector_names, leave=False):
    selector_a_ref_nodes = all_selector_ref_nodes[ds][selector_a]
    for selector_b in tqdm(selector_names, leave=False):
        selector_b_ref_nodes = all_selector_ref_nodes[ds][selector_b]
        pair_indices = jaccard_index[selector_a].setdefault(selector_b, {})
        for percent in tqdm(percentages, leave=False):
            pair_indices[percent] = calculate_all_jaccard_indices(
                selector_a_ref_nodes[percent], selector_b_ref_nodes[percent]
            )

In [None]:
import pandas as pd

def summarize_jaccard(jaccard_index, percent, decimals=3):
    """Tabulate the (mean, std) Jaccard index per selector pair at one percentage.

    Columns are ``selector_a`` and rows are ``selector_b``, the same orientation
    ``pd.DataFrame.from_dict`` gives for the nested dict. Each cell is a
    ``(mean, std)`` tuple over snapshots, rounded to ``decimals`` places. The table
    is built in one pass, so the deprecated ``DataFrame.applymap`` is not needed.
    """
    return pd.DataFrame(
        {
            selector_a: {
                selector_b: (
                    np.round(np.mean(indices[percent]), decimals),
                    np.round(np.std(indices[percent]), decimals),
                )
                for selector_b, indices in per_selector_b.items()
            }
            for selector_a, per_selector_b in jaccard_index.items()
        }
    )


for percent in percentages:
    display(percent)
    display(summarize_jaccard(jaccard_index, percent))

# Scale problem: distributions of raw vs. scaled selector scores

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Raw betweenness scores for the first scored snapshot (snapshots[1]) of dataset `ds`.
# Reads `all_selector_scores[ds]`, not the loop-leaked `ds_selector_scores`, which
# holds whichever dataset the scoring loop processed last.
fig, ax = plt.subplots()
ax.hist(list(all_selector_scores[ds]["betweenness"][0].values()))
ax.set(title=f"{ds}: raw betweenness scores", xlabel="score", ylabel="number of nodes");

In [None]:
# NOTE(review): this cell plots the same raw-score histogram as the previous cell.
# Delete one of them, or change what this one plots.
# Reads `all_selector_scores[ds]` so the plot matches the analysed dataset `ds`.
fig, ax = plt.subplots()
ax.hist(list(all_selector_scores[ds]["betweenness"][0].values()))
ax.set(title=f"{ds}: raw betweenness scores", xlabel="score", ylabel="number of nodes");

In [None]:
# Scaled betweenness scores for the first scored snapshot of dataset `ds`.
# Reads `all_selector_scores_scaled[ds]`, not the loop-leaked variable from the last dataset.
# Assumes the scaled scores are a CPU torch tensor (hence `.numpy()`) — TODO confirm.
fig, ax = plt.subplots()
ax.hist(all_selector_scores_scaled[ds]["betweenness"][0].numpy())
ax.set(title=f"{ds}: scaled betweenness scores", xlabel="score", ylabel="number of nodes");

In [None]:
# FIXME(review): `selectors` is not defined anywhere in this notebook. It must come
# from a deleted cell or leftover kernel state, so this cell fails under
# Restart & Run All. The object is used below via
# `selector.select(graph, reference_graph, cache=...)`.
selector = selectors["FILDNE"]

In [None]:
def flip_scores(scores_dict):
    """Invert scores relative to their maximum: ``1 - score / max(scores)``.

    The highest-scoring node maps to 0. Under the assumption that scores are
    non-negative, the result lies in [0, 1].

    Args:
        scores_dict: Mapping from node to numeric score.

    Returns:
        A new dict with the same keys and flipped scores.
        * An empty input returns ``{}``. Previously this raised in ``np.max``.
        * If the maximum score is 0, every node gets 1.0 instead of dividing by zero.
    """
    if not scores_dict:
        return {}
    max_value = np.max(list(scores_dict.values()))
    if max_value == 0:
        return {node: 1.0 for node in scores_dict}
    return {
        node: 1 - (node_score / max_value) for node, node_score in scores_dict.items()
    }

In [None]:
# Run the selector directly on the first two snapshots of the analysed dataset `ds`.
# Uses `graphs[ds]` rather than the loop-leaked `ds_graphs`, which holds the last
# dataset the scoring loop iterated.
analysis_graphs = graphs[ds]["graphs"]
ref_nodes, scores_dict = selector.select(
    analysis_graphs[0], analysis_graphs[1], cache=None
)

In [None]:
# Distribution of the scores returned by `selector.select` in the previous cell.
fig, ax = plt.subplots()
ax.hist(list(scores_dict.values()))
ax.set(title=f"{ds}: selector scores (snapshot 0 vs 1)", xlabel="score", ylabel="number of nodes");

In [None]:
# Display the reference nodes returned by `selector.select`.
ref_nodes