In [None]:
from function import *
import numpy as np

# Experiment configuration.
# Other scenarios can be explored by changing n_classes, n_unit, exp_factor,
# p, q, BP, and so on.

# Number of communities.
n_classes = 3

# Number of nodes in each community.
# Use unequal entries to experiment with the imbalanced case.
n_unit = [50, 50, 50]

# Exponential base weight.
exp_factor = 1

# Connection probability between nodes of the same community ("homo")
# in the planted partition model.
p = 0.15

# Connection probability between nodes of different communities ("hetero")
# in the planted partition model.
q = 0.015

# Prior information ratio.
# For example, with n_unit = [50,50,50] and BP = 0.02 the number of nodes
# known a priori is [1,1,1]; that is, for each community the true label of
# only a single node out of 50 is known.
BP = 0.02

# Total number of nodes over all communities.
n_V = sum(n_unit)

# Set of nodes, as a list of consecutive integer ids 0 .. n_V - 1.
V = list(range(n_V))

while True:
    # Generate a random graph from the planted partition model using the
    # nodes V, the number of classes, the number of nodes in each community,
    # the homo connection probability p and the hetero connection
    # probability q.  The call returns the graph G, its Simplices, and the
    # target values in two encodings:
    #   Label     - scalar form: "0" means the node belongs to the first
    #               community, "1" the second community, and so on.
    #   Label_mat - one-hot form: a node of the first community has
    #               label_mat = [1,0,0] when n_classes = 3.
    G, Simplices, Classes, Label, Label_mat = Generating_graph_with_simplices(V, n_classes, n_unit, p, q)

    # Classify with the equilibrium measure (EM) method.
    #   x_init  - classification result in one-hot form
    #   pred1   - classification result in scalar form
    #   x_known - information about the nodes whose label is known a priori
    # For example,
    #   x_known = [[3,1],[4,0],[5,0],[153,0],[154,1],[155,0],[306,0],[307,0],[308,1]]
    # means node 2 belongs to the first community, node 52 to the second
    # community, and node 103 to the third community.
    # NOTE(review): the example suggests column 0 is the flat position
    # node_index * n_classes + class and column 1 the one-hot value at that
    # position - confirm against Initialization.
    x_init, pred1, x_known = Initialization(G, Classes, BP)

    # The EM method can run into computational errors such as division by
    # zero.  Accept this sample only when the second column of x_known holds
    # no NaN; otherwise loop again and draw a fresh graph.
    if not np.isnan(x_known[:, 1]).any():
        break

# Run the optimisation only when the EM initialisation produced no NaN in the
# second column of x_known.
if not np.isnan(x_known[:, 1]).any():

    # The Optimization function below needs to distinguish between known
    # information and trained information (obtained from the EM method), so
    # the entries listed in x_known are removed from x_init here.
    # np.delete without an axis works on the flattened array and returns a
    # new array, so x_init itself is left untouched.  All positions are
    # removed in a single call, which does not depend on the order in which
    # x_known lists them.
    x_init_revised = np.delete(x_init, x_known[:, 0].astype(int))

    # HOI indicates the usage of higher order interactions:
    #   HOI = 0 -> the algorithm only uses pairwise interactions between nodes
    #   HOI = 1 -> the algorithm uses higher order interactions as well as
    #              pairwise interactions
    HOI = 0

    # Keep the simplices up to 1-simplices only (the first two entries of
    # Simplices).
    Simplices2 = Simplices[:2]

    # result2 and pred2 are the prediction in vector and scalar form,
    # respectively.
    result2, pred2 = Optimization(x_init_revised, x_known, Simplices2,
                                  n_classes, HOI, exp_factor)

    HOI = 1

    # Use all simplices this time.
    Simplices3 = Simplices
    result3, pred3 = Optimization(x_init_revised, x_known, Simplices3,
                                  n_classes, HOI, exp_factor)

    # pre, rec, f1s, acc denote precision, recall, f1-score and accuracy,
    # respectively.
    # Index "1": result obtained by the EM method.
    conf_matrix1 = confusion_matrix(Label, pred1, n_classes)
    pre1, rec1, f1s1, acc1 = precision_recall_f1_accuracy(conf_matrix1)

    # Index "2": result obtained by the objective function when only
    # pairwise relations are considered.
    conf_matrix2 = confusion_matrix(Label, pred2, n_classes)
    pre2, rec2, f1s2, acc2 = precision_recall_f1_accuracy(conf_matrix2)

    # Index "3": result obtained by the objective function when all
    # simplices are considered.
    conf_matrix3 = confusion_matrix(Label, pred3, n_classes)
    pre3, rec3, f1s3, acc3 = precision_recall_f1_accuracy(conf_matrix3)