In [None]:
from func import *
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx, from_networkx
import torch_geometric.transforms as T

# Prefer a CUDA device whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Candidate datasets for the experiment; switch the index to pick another one.
name_data_list = ['Cora', 'Citeseer', 'Pubmed']
name_data = name_data_list[0]  # index 0 -> 'Cora'

# Seed both PyTorch and NumPy so the GAT run below is repeatable.
seed_number = 42
torch.manual_seed(seed_number)
np.random.seed(seed_number)

# Load the selected dataset (via the project helper) and take its single graph.
dataset = generating_dataset(name_data)
data = dataset[0]
num_nodes = data.x.size(0)  # number of rows of the node-feature matrix
num_classes, num_features = dataset.num_classes, dataset.num_features

# Widen the evaluation split: every node that is in neither the train nor the
# validation split is treated as a test node, overriding the dataset's own
# test_mask. `new_test_mask` is kept as a module-level name as well.
data.test_mask = new_test_mask = ~(data.train_mask | data.val_mask)

#====================================================================================================
# GAT model
# Build the GAT baseline from the dataset's feature/class counts.
gat_model = GAT(num_features, num_classes)

# Optimisation settings handed to gat_training. eval_interval is the reporting
# period in epochs (the original note: "print error per 200 epochs").
lr, weight_decay = 0.001, 0.0005
epochs, eval_interval = 1000, 200

# gat_training returns two values. softmax_output feeds the HOI stage later;
# softmax_pred is scored here on the test mask.
softmax_output, softmax_pred = gat_training(
    gat_model, data, device, lr, weight_decay, epochs, eval_interval
)
gat_accuracy = calculate_accuracy(softmax_pred, data.y, data.test_mask)
#====================================================================================================


#====================================================================================================
# Graph from NetworkX
# Undirected NetworkX view of the PyG graph with self-loops dropped; the
# simplex generators below all work on this object.
G = to_networkx(data, to_undirected=True, remove_self_loops=True)

# Three simplex families compared in the experiment (labels as used below):
general_simplices = generating_general_simplices(G)   # "ALL"
maximal_simplices = generating_maximal_simplices(G)   # "MAX"
# "Aug-MAX": built from both families above.
balance_simplices = augment_maximal_cliques(general_simplices, maximal_simplices)

# "Prior-informed ratio": computed from the train/validation masks and the class count.
x_known = get_x_known(data, data.train_mask, data.val_mask, num_classes)

# HOI training starts from the GAT's softmax_output, moved off the autograd
# graph and the device into a NumPy array.
initial_data = softmax_output.detach().cpu().numpy()

# HOI_training hyperparameters. NOTE: this deliberately rebinds `epochs` and
# `lr`, replacing the GAT values used earlier.
epochs, exp_base, lr = 15, 1, 0.4
#====================================================================================================


#====================================================================================================
# Using ALL: run HOI training over the complete ("ALL") simplex family.
simplices = general_simplices
final_P, final_pred = HOI_training(
    epochs, device, simplices, initial_data, x_known, lr, exp_base
)
# Score the HOI predictions on the widened test split.
general_accuracy = calculate_accuracy(final_pred, data.y, data.test_mask)
#====================================================================================================


#====================================================================================================
# Using MAX: run HOI training over the maximal-simplex ("MAX") family.
simplices = maximal_simplices
final_P, final_pred = HOI_training(
    epochs, device, simplices, initial_data, x_known, lr, exp_base
)
# Score the HOI predictions on the widened test split.
maximal_accuracy = calculate_accuracy(final_pred, data.y, data.test_mask)
#====================================================================================================


#====================================================================================================
# Using Aug-MAX: run HOI training over the augmented maximal ("Aug-MAX") family.
# The Aug-MAX simplices are stored under the name `balance_simplices` when they
# are built; referencing `balanced_simplices` raises NameError.
simplices = balance_simplices
final_P, final_pred = HOI_training(
    epochs, device, simplices, initial_data, x_known, lr, exp_base
)
# Score the HOI predictions on the widened test split.
balanced_accuracy = calculate_accuracy(final_pred, data.y, data.test_mask)
#====================================================================================================