# In [None]:
import numpy as np
import torch
from func import *

seed = 42  # random seed used throughout the experiment

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Node layout per label: the list has one entry per label, and entry i is the
# number of nodes that carry label i. Equal entries give a balanced experiment;
# edit the entries to run imbalanced ones.
n_units = [100] * 5
n_L = len(n_units)  # number of labels

# Experiment parameters (the sweep range used in the study is noted for each).
pi = 0.05   # prior-informed ratio; swept from 0.01 to 0.1 in steps of 0.01
p2 = 0.05   # intra-connection probability; swept from 0.01 to 0.1 in steps of 0.01
q2 = 0.005  # inter-connection probability; swept from 0.001 to 0.01 in steps of 0.001

# Training hyper-parameters, forwarded unchanged to every HOI_training call.
epochs = 15
exp_base = 1  # W_k in Eqs. (3), (5) and (7); fixed to 1 for all experiments in this study
lr = 0.4      # presumably the learning rate used inside HOI_training


# Seed every random-number generator the experiment may draw from, so that
# re-running the script with the same `seed` is reproducible.
#
# Python's stdlib `random` is seeded too, in case any helper (e.g. in `func`)
# draws from it. It is imported under an alias right here so that the earlier
# `from func import *` cannot shadow the name.
import random as _stdlib_random

_stdlib_random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed every visible GPU explicitly, not just the current device.
    torch.cuda.manual_seed_all(seed)
    # Use deterministic cuDNN kernels and disable cuDNN autotuning, which may
    # otherwise select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Label-by-label connection-probability matrix for the PPL model: entry (i, j)
# is the probability of an edge between a node of label i and a node of label j.
# Diagonal entries use the intra-connection probability p2; all off-diagonal
# entries use the inter-connection probability q2.
prob_mat2 = np.where(np.eye(n_L, dtype=bool), p2, q2)

# Sample the graph from the PPL model defined by `n_units` and `prob_mat2`.
# The helper returns, in order: the graph G, its corresponding simplices,
# the node clustering (`classes`) and the node labels. `seed` is forwarded,
# presumably so the sampled graph is reproducible — confirm in `func`.
# NOTE(review): `labels` is not used anywhere in this visible section.
G, simplices, classes, labels = generating_graph_with_simplices(n_units, prob_mat2, seed)

# Build the three simplex families that the training runs below compare.
# ALL: the general simplex family computed from G
general_simplices = generating_general_simplices(G)
# MAX: the maximal simplices of G
maximal_simplices = generating_maximal_simplices(G)
# Aug-MAX: built from both families above; judging by the helper's name it
# augments the maximal cliques using the general simplices — confirm in `func`.
balance_simplices = augment_maximal_cliques(general_simplices, maximal_simplices)

# Initialization by the equilibrium measure. Returns, in order: the starting
# data handed to every training run below, an initial prediction, and
# `x_known` (presumably the prior-informed nodes chosen with ratio `pi` —
# confirm in `initialization`).
initial_data, initial_pred, x_known = initialization(G, classes, simplices, pi, seed)

# One training run per simplex family, with identical hyper-parameters and the
# same starting point, so that the runs are meant to differ only in the simplex
# family. Each call returns a pair (probabilities, predictions).
# NOTE(review): all three runs receive the very same `initial_data` / `x_known`
# objects. If HOI_training modifies them in place, the MAX and Aug-MAX runs do
# not start from the equilibrium-measure initialization — confirm, or pass copies.
# NOTE(review): the RNGs are seeded only once, before these runs; any randomness
# inside HOI_training therefore depends on what the earlier runs consumed.
# Reseed before each run if a strictly controlled comparison is required.
# Using ALL
final_P1, final_pred1 = HOI_training(epochs, device, general_simplices, initial_data, x_known, lr, exp_base)
# Using MAX
final_P2, final_pred2 = HOI_training(epochs, device, maximal_simplices, initial_data, x_known, lr, exp_base)
# Using Aug-MAX
final_P3, final_pred3 = HOI_training(epochs, device, balance_simplices, initial_data, x_known, lr, exp_base)

# Each HOI_training call above returns a pair (final_P, final_pred):
# - final_P: the label probability distribution of every node, with shape
#   (number of nodes, number of labels) = (len(G.nodes), n_L);
# - final_pred: the predicted label of every node, with shape (len(G.nodes),).