In [None]:
import sys
sys.path.append('../src/mane/prototype/')
import numpy as np
import graph as g
import pickle as p

from sklearn.preprocessing import normalize, scale
from sklearn.metrics import f1_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegressionCV

In [None]:
def _average_embeddings(w, index):
    """Return the element-wise mean of ``w[i]`` over every ``i`` in *index*.

    The result is built with out-of-place arithmetic so the arrays stored in
    *w* are never modified. Accumulating in place into ``w[index[0]]`` would
    alias and corrupt it, giving a wrong mean whenever an index repeats.

    Raises:
        ValueError: if *index* selects no weights.
    """
    selected = [w[i] for i in index]
    if not selected:
        raise ValueError('index must contain at least one weight index')
    total = selected[0]
    for part in selected[1:]:
        total = total + part  # new object each step; never writes into w
    # True division also promotes integer weights to float instead of raising
    # the casting error an in-place ``/=`` would.
    return total / len(selected)


def lg(exp_id, graph_name, index=[0], norm=False, split=0.5, use_bias=False,
       max_iter=100, C=1e9, ic=200, test_with_training_data=True, cv=None):
    """Fit a logistic-regression community classifier on saved embeddings.

    Loads the pickled weights ``../src/mane/prototype/embeddings/<exp_id>.weights``
    and the graph ``../src/mane/data/<graph_name>.graph`` / ``.community``.
    It trains on the nodes returned by ``graph.gen_training_community(split)``
    and prints macro and micro F1 of the predicted communities.

    Args:
        exp_id: experiment id; selects the ``.weights`` file.
        graph_name: base file name of the graph/community pickles.
        index: indices into the loaded weights whose arrays are averaged to
            form the embedding. ``None`` uses the loaded object as-is. The
            default list is only read, never mutated.
        norm: if true, rows are normalized with sklearn ``normalize``.
        split: forwarded to ``graph.gen_training_community``; presumably the
            training fraction. TODO confirm.
        use_bias: if true, the last embedding column is overwritten with
            ``w[2]`` flattened.
        max_iter: maximum solver iterations for the classifier.
        C: inverse regularization strength; ignored when ``cv`` is set
            (``LogisticRegressionCV`` searches its own grid).
        ic: ``intercept_scaling`` passed to the classifier.
        test_with_training_data: if true, evaluate on all graph nodes.
            Otherwise evaluate only on nodes not used for training.
        cv: if truthy, use ``LogisticRegressionCV`` with this ``cv`` value.

    Raises:
        ValueError: if ``index`` is empty or there are no nodes to evaluate.
    """
    weightfile = '../src/mane/prototype/embeddings/' + exp_id + '.weights'
    graphfile = '../src/mane/data/' + graph_name
    # NOTE: pickle.load can execute arbitrary code; only load weight files
    # produced by this project.
    with open(weightfile, 'rb') as f:
        w = p.load(f)
    graph = g.graph_from_pickle(graphfile + '.graph', graphfile + '.community')

    emb = w if index is None else _average_embeddings(w, index)
    if use_bias:
        # Presumably w[2] holds per-node bias terms. TODO confirm.
        emb[:, -1] = w[2].reshape((-1,))
    if norm:
        emb = normalize(emb)

    xids, y_train = graph.gen_training_community(split)
    X = [emb[i] for i in xids]
    if cv:
        learner = LogisticRegressionCV(fit_intercept=True, cv=cv,
                                       solver='lbfgs', max_iter=max_iter,
                                       intercept_scaling=ic)
    else:
        learner = LogisticRegression(C=C, max_iter=max_iter,
                                     intercept_scaling=ic)
    # Fit exactly once; the non-CV branch previously fitted twice.
    predictor = learner.fit(X, y_train)

    if test_with_training_data:
        eval_list = list(graph.nodes())
    else:
        train_ids = set(xids)  # O(1) membership instead of scanning xids per node
        eval_list = [i for i in graph.nodes() if i not in train_ids]
    if not eval_list:
        raise ValueError('no nodes to evaluate')
    y_true = [graph._communities[i] for i in eval_list]
    # One batched predict call instead of one call per node; assumes emb rows
    # are dense arrays. TODO confirm for sparse embeddings.
    y_pred = predictor.predict(np.vstack([emb[i] for i in eval_list]))
    print('Experiment ', exp_id, ' ', graph_name, ' ', str(index))
    if cv:
        print('With', cv, '-fold cross-validation')
    print('f1_macro: ', f1_score(y_true, y_pred, average='macro'))
    print('f1_micro: ', f1_score(y_true, y_pred, average='micro'))