In [None]:
%load_ext autoreload
%autoreload 2
import sys

sys.path.append("../")
from src.utils import *
from src.graph2vec import ExtendedGraph2Vec, Ensemble_G2V
from sklearn.model_selection import LeaveOneOut
import pandas as pd

In [None]:
def which_to_generate(num: int):
    """Load one of the experiment datasets together with its stored orderings.

    Args:
        num: Dataset selector. ``0`` loads the GitHub stargazers dataset,
            ``1`` the regular graphs, and any other value the partition graphs.

    Returns:
        A tuple ``(graphs, labels, order_dict, path)`` where ``graphs`` and
        ``labels`` come from the dataset reader, ``order_dict`` is the object
        stored in ``orderings.npy`` inside the dataset directory, and ``path``
        is that directory (with a trailing slash).
    """
    if num == 0:
        path = "../datasets/github_stargazers/"
        graphs, labels = read_stargazers_dataset("../datasets/github_stargazers")
    else:
        if num == 1:
            path = "../datasets/regular_graphs/"
        else:
            path = "../datasets/partition_graphs/"
        graphs, labels = load_artificial(path)

    # The orderings file holds a pickled Python object (hence allow_pickle=True,
    # which should only be used on trusted files). Indexing with `()` unwraps
    # the 0-d object array that np.load returns for such a file.
    stored = np.load(f"{path}orderings.npy", allow_pickle=True)
    order_dict = stored[()]
    return graphs, labels, order_dict, path

In [None]:
# Selects which dataset the notebook runs on (see `which_to_generate`):
# 0 -> github_stargazers, 1 -> regular_graphs, any other value -> partition_graphs.
DATASET_NUMBER = 2

In [None]:
# Load the selected dataset, its stored orderings, and the dataset directory
# (reused later as the location for the results file).
graphs, labels, order_dict, path_to_save = which_to_generate(DATASET_NUMBER)

# Number of cross-validation splits.
cv_fold = 10

# Embedding size of each individual model, chosen per dataset.
if DATASET_NUMBER == 0:
    single_emb_size = 128
elif DATASET_NUMBER == 1:
    single_emb_size = 2
else:
    single_emb_size = 15

# Five evenly spaced values in [0, 1] used as the parameter of the weighted variants.
tested_range = np.linspace(0, 1, 5)

# Label -> (weighting function name, positional arguments). Each entry is later
# passed to `ens_model.set_weighting_function(name, *args)`.
WEIGHTING_FUNCTIONS = {
    **{f"w_mean_{alpha:.2f}": ("w_mean", [alpha]) for alpha in tested_range},
    "concat": ("concatenate", []),
    **{
        f"p_proj_{flag}": ("partial_projection", [0, flag])
        for flag in (True, False)
    },
    **{f"w_proj_{alpha:.2f}": ("w_projection", [alpha]) for alpha in tested_range},
}

# Ensemble of two ExtendedGraph2Vec embedders with the same dimensionality:
# one with default settings and one with `use_pv_dm=True`.
ens_model = Ensemble_G2V(
    ExtendedGraph2Vec(dimensions=single_emb_size),
    ExtendedGraph2Vec(use_pv_dm=True, dimensions=single_emb_size),
)

In [None]:
# Cross-validate every (ordering, weighting function) combination and collect
# one result object per combination in `res`.
res = []
for measure_name, ordering in tqdm(order_dict.items()):
    weighting_items = WEIGHTING_FUNCTIONS.items()
    for weighting_label, (weighting_fn, weighting_args) in tqdm(
        weighting_items, leave=False
    ):
        # The shared ensemble is reconfigured in place before each evaluation.
        ens_model.set_weighting_function(weighting_fn, *weighting_args)
        fold_results = cross_validate_graphs(
            graphs=graphs,
            ordering=ordering,
            labels=labels,
            n_splits=cv_fold,
            embedder=ens_model,
            cls=LogisticRegression(max_iter=1000),
            method=f"{measure_name}_{weighting_label}",
        )
        res.append(fold_results)

In [None]:
# Persist the combined cross-validation results, then reload them from disk so
# the preview below reflects exactly what was written.
results_csv = f"{path_to_save}results2.csv"
pd.concat(res, ignore_index=True).to_csv(results_csv)
saved_results = pd.read_csv(results_csv, index_col=0)
# Preview up to 15 random rows. Capping at the table length avoids the
# ValueError that `sample(15)` raises when fewer than 15 rows were produced.
saved_results.sample(min(15, len(saved_results)))