In [1]:
import json
import os
import re
from pathlib import Path
from typing import Dict, Optional

from core.CaS.cas_params import CaSParams
from core.CaS.recas_params import ReCaSParams

# Maps the model identifier taken from a result file name to the name stored
# in the parameter table. Both "mlp" and "mlp_node2vec" map to "MLP".
MODEL_NAME_MAP = dict(
    mlp="MLP",
    gcn="GCN",
    sage="SAGE",
    revgat="RevGAT",
    none="None",
    mlp_node2vec="MLP",
)


def parse_method_str(method: str, model_name: str) -> Dict[str, Optional[str]]:
    """Split a ``+``-separated method string into its descriptive fields.

    How many fields are expected, and which one holds the feature type,
    depends on ``model_name``:

    * ``"none"``: 2 or 3 fields; the feature type is the second field.
    * a name containing ``"_"``: 4 or 5 fields; the feature type is the third
      field and the embedding is the fourth.
    * any other name: 3 or 4 fields; the feature type is the third field.

    Args:
        method: The method string, with fields joined by ``"+"``.
        model_name: Model identifier; must be a key of ``MODEL_NAME_MAP``.

    Returns:
        A dict with the keys ``"gnn_name"``, ``"feature_type"`` and ``"emb"``.
        ``"emb"`` is the string ``"None"`` when the method carries no
        embedding field.

    Raises:
        ValueError: If the number of fields does not fit the model kind
            (raised by the tuple unpacking).
        KeyError: If ``model_name`` is not a key of ``MODEL_NAME_MAP``.
    """
    fields = method.split("+")
    n_fields = len(fields)
    embedding = "None"

    if model_name == "none":
        if n_fields != 2:
            _, feat, _ = fields
        else:
            _, feat = fields
    elif "_" not in model_name:
        if n_fields != 3:
            _, _, feat, _ = fields
        else:
            _, _, feat = fields
    else:
        # Composite model names (e.g. "mlp_node2vec") carry an embedding field.
        if n_fields != 4:
            _, _, feat, embedding, _ = fields
        else:
            _, _, feat, embedding = fields

    return dict(
        gnn_name=MODEL_NAME_MAP[model_name],
        feature_type=feat,
        emb=embedding,
    )


def create_cas_params(result_dir: Path) -> CaSParams:
    """Collect the best parameters found in a directory of result logs.

    Every sub-directory of ``result_dir`` is treated as a dataset. Inside it,
    every ``*.txt`` file is scanned for lines of the form
    ``Best parameters for '<method>': <params>``. The model name is taken from
    the file name by dropping the dataset prefix, one separator character and
    the ``.txt`` suffix (same convention as ``create_recas_params``).

    Args:
        result_dir: Directory containing one sub-directory per dataset.

    Returns:
        A ``CaSParams`` instance with one ``add`` call per matching log line.

    Raises:
        json.JSONDecodeError: If a parameter payload cannot be parsed.
        KeyError, ValueError: Propagated from ``parse_method_str`` for an
            unknown model name or an unexpected method string.
    """
    result_dir = Path(result_dir)
    cas_params = CaSParams()
    best_params_re = re.compile(r"Best parameters for '(.+?)': (.+)")
    suffix = ".txt"

    # Sorted so that the order of `add` calls does not depend on the
    # arbitrary order returned by os.listdir.
    for dataset in sorted(os.listdir(result_dir)):
        dataset_dir = result_dir / dataset
        # Test the path inside result_dir, not relative to the working directory.
        if not dataset_dir.is_dir():
            continue
        # List the dataset's own directory, not the parent.
        for fname in sorted(os.listdir(dataset_dir)):
            if not fname.endswith(suffix):
                continue
            # "<dataset><sep><model_name>.txt" -> "<model_name>"
            model_name = fname[len(dataset) + 1: -len(suffix)]
            with open(dataset_dir / fname) as f:
                for line in f:
                    match = best_params_re.search(line.strip())
                    if match is None:
                        continue
                    method_str, params_str = match.groups()
                    try:
                        params = json.loads(params_str)
                    except json.JSONDecodeError:
                        # Accept single-quoted payloads as well, as
                        # create_recas_params does.
                        params = json.loads(params_str.replace("'", '"'))
                    method = parse_method_str(method_str, model_name)
                    cas_params.add(params, dataset=dataset, **method)
    return cas_params


def create_recas_params(result_dir: Path) -> ReCaSParams:
    """Collect the best parameters found in a directory of result logs.

    Every sub-directory of ``result_dir`` is treated as a dataset. Inside it,
    every ``*.txt`` file is scanned for lines of the form
    ``Best parameters for '<method>': <params>``. The model name is taken from
    the file name by dropping the dataset prefix, one separator character and
    the ``.txt`` suffix.

    Args:
        result_dir: Directory containing one sub-directory per dataset.

    Returns:
        A ``ReCaSParams`` instance with one ``add`` call per matching log line.

    Raises:
        json.JSONDecodeError: If a parameter payload cannot be parsed after
            single quotes are replaced with double quotes.
        KeyError, ValueError: Propagated from ``parse_method_str`` for an
            unknown model name or an unexpected method string.
    """
    result_dir = Path(result_dir)
    recas_params = ReCaSParams()
    best_params_re = re.compile(r"Best parameters for '(.+?)': (.+)")
    suffix = ".txt"

    # Sorted so that the order of `add` calls does not depend on the
    # arbitrary order returned by os.listdir.
    for dataset in sorted(os.listdir(result_dir)):
        dataset_dir = result_dir / dataset
        if not dataset_dir.is_dir():
            continue
        for fname in sorted(os.listdir(dataset_dir)):
            if not fname.endswith(suffix):
                continue
            # "<dataset><sep><model_name>.txt" -> "<model_name>"
            model_name = fname[len(dataset) + 1: -len(suffix)]
            # errors="replace" substitutes undecodable bytes instead of
            # raising, so a stray bad byte cannot abort or derail the scan
            # and no manual readline()/retry loop is needed.
            with open(dataset_dir / fname, errors="replace") as f:
                for line in f:
                    match = best_params_re.search(line.strip())
                    if match is None:
                        continue
                    method_str, params_str = match.groups()
                    # The payload is written with single quotes; convert it
                    # to JSON-style double quotes before parsing.
                    params = json.loads(params_str.replace("'", '"'))
                    method = parse_method_str(method_str, model_name)
                    recas_params.add(params, dataset=dataset, **method)
    return recas_params

In [None]:
# Parse the logs under "results_analysis_cs" and save the collected parameters
# to core/CaS/cas_params_cs.json (note: `target_dir` is a .json file path).
# NOTE(review): this writes a cas_params_* file but calls create_recas_params
# rather than create_cas_params — confirm this is intentional.
result_dir = "results_analysis_cs"
target_dir = "core/CaS/cas_params_cs.json"
create_recas_params(Path(result_dir)).save(target_dir)

In [None]:
# Parse the logs under "results_analysis_s" and save the collected parameters
# to core/CaS/cas_params_s.json (note: `target_dir` is a .json file path).
# NOTE(review): this writes a cas_params_* file but calls create_recas_params
# rather than create_cas_params — confirm this is intentional.
result_dir = "results_analysis_s"
target_dir = "core/CaS/cas_params_s.json"
create_recas_params(Path(result_dir)).save(target_dir)

In [None]:
# Parse the logs under "results_recas_sc" and save the collected parameters
# to core/CaS/recas_params_sc.json (note: `target_dir` is a .json file path).
result_dir = "results_recas_sc"
target_dir = "core/CaS/recas_params_sc.json"
create_recas_params(Path(result_dir)).save(target_dir)

In [None]:
# Parse the logs under "results_recas_scs" and save the collected parameters
# to core/CaS/recas_params_scs.json (note: `target_dir` is a .json file path).
result_dir = "results_recas_scs"
target_dir = "core/CaS/recas_params_scs.json"
create_recas_params(Path(result_dir)).save(target_dir)

In [None]:
# Parse the logs under "results_recas_cscs" and save the collected parameters
# to core/CaS/recas_params_cscs.json (note: `target_dir` is a .json file path).
result_dir = "results_recas_cscs"
target_dir = "core/CaS/recas_params_cscs.json"
create_recas_params(Path(result_dir)).save(target_dir)