In [10]:
from experiment_util.parms_tune import Params, BestScoreParamRecorder, search_grid_generator
from model.synthetic_graph import fan, house, star, build_circle_structure_split
import numpy as np
from model.AMKS.amks import MultiHopAMKS
from running.log import get_logger
import datetime
from tqdm import tqdm
from sklearn.cluster import KMeans
import sklearn.metrics as sk_metrics
from sklearn import utils as sk_utils
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score

In [11]:
class CurParams(Params):
    def __init__(self) -> None:
        super().__init__()
        self._hop = None
        self._sigma = None
        self._step = None

    def update(self, other):
        self._hop = other._hop
        self._sigma = other._sigma
        self._step = other._step

    def to_str(self):
        return f"step:{self._step}_sigma:{self._sigma}_hop:{self._hop}"

In [12]:
g, roles = build_circle_structure_split([[fan, 6], [star, 6], [house]], 2, 16)
label = np.array(roles)
compute_model = MultiHopAMKS()

In [13]:
sigma_range = np.linspace(0, 4, 20)[1:]
hop_range = list(range(0, 6))
step_range = list(range(3, 30, 3))

time = 50
compute_model.set_g(g)
compute_model.set_time(time)

node_index = sorted(g.nodes)

In [14]:
hom_recorder = BestScoreParamRecorder(CurParams(), -1, "hom")
comp_recorder = BestScoreParamRecorder(CurParams(), -1, "comp")
sil_recorder = BestScoreParamRecorder(CurParams(), -1, "sil")
acc_recorder = BestScoreParamRecorder(CurParams(), -1, "acc")
macro_fi_recorder = BestScoreParamRecorder(CurParams(), -1, "macro_f1")

In [15]:
logger = get_logger("amks synthetic")
logger.info("find best params of synthetic graph waveRing")
logger.info("time = " + str(time))
start_time = datetime.datetime.now()
params = CurParams()

In [16]:
def unsupervised_evaluate(embedding_vec: np.ndarray, labels: np.ndarray):
    """计算3个衡量的指标"""
    colors = labels
    nb_clust = len(np.unique(colors))
    trans_data = embedding_vec
    km = KMeans(n_clusters=nb_clust)
    km.fit(trans_data)
    labels_pred = km.labels_
    hom = sk_metrics.homogeneity_score(colors, labels_pred)
    comp = sk_metrics.completeness_score(colors, labels_pred)
    sil = sk_metrics.silhouette_score(trans_data, labels_pred, metric='euclidean')

    return hom, comp, sil

In [17]:
def supervised_evaluate(embedding_vec: np.ndarray, labels: np.ndarray):
    """4最近邻和交叉验证，评估模型的有监督嵌入能力"""
    data, labels = sk_utils.shuffle(embedding_vec, labels)
    knn = KNeighborsClassifier(n_neighbors=4)
    acc_score = cross_val_score(knn, data, y=labels, cv=5, scoring='accuracy')
    macro_f1 = cross_val_score(knn, data, y=labels, cv=5, scoring='f1_macro')
    return np.mean(acc_score), np.mean(macro_f1)

In [18]:
for sigma, hop, step in tqdm(list(search_grid_generator(sigma_range, hop_range, step_range))):
    compute_model.set_maxhop(hop)
    compute_model.set_sigma(sigma)
    compute_model.set_step(step)

    compute_model.compute_emb_vec()
    embedding_vec = compute_model.get_embedding_vec(node_index)

    hom, comp, sil = unsupervised_evaluate(embedding_vec, label)

    acc, macro_f1 = supervised_evaluate(embedding_vec, label)

    params._hop = hop
    params._sigma = sigma
    params._step = step

    hom_recorder.update_params(params, hom, logger)
    comp_recorder.update_params(params, comp, logger)
    sil_recorder.update_params(params, sil, logger)
    acc_recorder.update_params(params, acc, logger)
    macro_fi_recorder.update_params(params, macro_f1, logger)

time_spent = datetime.datetime.now() - start_time
min_ = time_spent.seconds // 60
sec = time_spent.seconds % 60
logger.info("time spent: " + str(min_) + " min" + str(sec) + " s")

  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(trans_data)
  km.fit(tran