In [None]:
####### If any required libraries or packages fail to import, please ensure they are installed beforehand.
import numpy as np
import sys       
import pandas as pd
import datetime  
import os     

# This is necessary to avoid warnings about memory leaks in KMeans. Correct as needed.
os.environ["OMP_NUM_THREADS"] = "12" 

import pickle
import csv    
import openpyxl 

from IPython.display import clear_output

import torch
os.environ['TORCH'] = torch.__version__
from torch_geometric.utils import to_networkx 

import networkx as nx 

from lion_pytorch import Lion 
# LION optimizer
# pip install lion-pytorch
# cf. https://github.com/lucidrains/lion-pytorch

from early_stopping import EarlyStopping
# Early Stopping
# cf. https://github.com/Bjarten/early-stopping-pytorch
# get the file "early_stopping.py"

In [None]:
def Initialization_parameters_node_edge():
    """Set the global switches that select node features and describe the model outputs.

    Every value is written to a module-level global; nothing is returned.
    """
    global feature_s_on, feature_degree_on, feature_prob_on, class_num, input_normalization
    global out1_num, out1_items, out1_weight, out1_index
    global inverse_T, closeness, betweenness, clustering

    # ---- values that must stay fixed -------------------------------------
    feature_s_on = 0     # node state as a node feature (1:Yes 0:No); keep at 0
    class_num = 2        # number of classes (labels 'in' and 'out'); keep at 2
    out1_num = 1         # number of scalar outputs (only big phi in this study); keep at 1

    # Used internally; do not change.
    out1_items, out1_weight, out1_index = ['BigPhi'], [1.0], ['Phi']

    # ---- optional node features ------------------------------------------
    feature_degree_on = 1    # node degree               1:Yes   0:No
    feature_prob_on = 2      # higher probability        2:Yes   0:No
    inverse_T = 0            # parameter T               0:Yes  -1:No  (ignore the word 'inverse' in the name)
    closeness = 1            # closeness centrality      1:Yes   0:No
    betweenness = 1          # betweenness centrality    1:Yes   0:No
    clustering = 1           # clustering coefficient    1:Yes   0:No

    # ---- preprocessing ---------------------------------------------------
    input_normalization = 1  # input normalization       1:Yes   0:No

In [None]:
def Initialization_parameters_optimization():
    """Set the global hyper-parameters of the optimisation procedure.

    Every value is written to a module-level global; nothing is returned.
    """
    global learning_rate, weight_decay_val, weights, patience_val, loss2_coef, optimizer_select

    # Optimizer: one of 'Lion', 'Adam', 'RAdam', and 'AdamW'.
    optimizer_select = 'Lion'
    learning_rate = 0.0001
    weight_decay_val = 0          # 0 in this study

    # Loss terms.
    weights = [1.8, 1.0]          # weights for class labels 'out' and 'in'
    loss2_coef = 5.0              # weight for the cross-entropy loss, 5 in this study

    # Early stopping.
    patience_val = 50             # patience in the early stopping strategy

In [None]:
def Initialization_parameters_GNN():
    """Set the global options describing the GNN architecture.

    Every value is written to a module-level global; nothing is returned.
    """
    global convolution_type, other_network_type, drop_rate

    # Graph convolution layer: one of 'Transformer', 'GraphConv', and 'GAT'.
    convolution_type = 'Transformer'

    # Pooling type:
    #   0 (proposed method): global_max, x - global_max
    #   1: global_mean, x - global_mean
    #   2: global_max, x
    other_network_type = 0

    drop_rate = 0.3

In [None]:
def Initialization_parameters_dataset():
    """Set the global options for datasets, pre-trained models, augmentation and oversampling.

    Every value is written to a module-level global; nothing is returned.
    """
    global dataset_type, random_data_num, test_data_shitei, learning_data_ratio, validation_data_ratio
    global test_data_bigN, bigN, prelearned_model, pickle_filename, file_postfix, model_folder
    global adding_split_brain_data, add_split_brain_data_rate, over_sampling, auto_bins, os_fluctuation, bins
    global learning_batch_size

    # ---- source data -------------------------------------------------------
    dataset_type = 1     # keep at 1

    # Random-connection-graph dataset:
    #   '1000_567': non-extrapolative setting    '1500': extrapolative setting
    random_data_num = '1000_567'

    # Test dataset:  [1]: N=5   [2]: N=6   [3]: N=7   0: N=5,6,7 mixed
    # Use 0 for the non-extrapolative setting and [3] for the extrapolative setting.
    test_data_shitei = 0

    # ---- split ratios ------------------------------------------------------
    # Only used when test_data_shitei is 0: all data is shuffled and this fraction
    # (a real number in [0, 1]) is used in the training process.
    learning_data_ratio = 0.9       # 0.9 (non-extrapolative setting in the proposed method)

    # Fraction of the expanded dataset (after augmentation and oversampling)
    # that is used for validation.
    validation_data_ratio = 0.1     # 0.1 in the proposed method

    learning_batch_size = 128       # 128 in this study

    # ---- N=100 test graphs and pre-trained models ---------------------------
    # Run the saved N=100 graphs as the test dataset?   0: No   2: load saved graphs and run
    # Use 0 for both the non-extrapolative and the extrapolative setting.
    # With 2, the test dataset is replaced by the saved N=100 graphs; in that case
    # learning_data_ratio should be 1.0 and prelearned_model must be 1.
    test_data_bigN = 2

    # Size of the big-N system; keep at 100 (ignored when test_data_bigN is 0).
    bigN = 100

    # Use pre-trained models?   1: Yes   0: No
    # With 1, only the test process is executed, random_data_num may be either
    # '1000_567' or '1500' (it is ignored), and the models are loaded from the
    # *****_param_*****.pth files in model_folder.
    prelearned_model = 1

    # Pickle file holding the N=100 graphs (ignored when test_data_bigN is 0).
    # Each file contains 100 graphs: the first 50 are not split-brain-like systems
    # (not described in the paper), the latter 50 are split-brain-like systems.
    pickle_filename = 'random_graph_data_N=100/bigN_data_N=100_type9_p=0.0000.pkl'

    # Postfix of the output files; only used when test_data_bigN is 2.
    file_postfix = '_N=100_type9_p=0.00'

    # Folder containing the pre-trained models.
    model_folder = 'prelearned_model/'

    # ---- data augmentation -------------------------------------------------
    adding_split_brain_data = 2         # 2: ON    0: OFF

    # Amount of augmented data, as a fraction of the number of data used in the
    # training process (ignored when adding_split_brain_data is 0).
    add_split_brain_data_rate = 0.05

    # ---- oversampling ------------------------------------------------------
    over_sampling = 2                   # 2: ON    0: OFF

    # Determine the bins automatically with k-means (the number of clusters is
    # fixed to seven in this study)?   1: ON   0: OFF   (1 is recommended)
    auto_bins = 1

    # Manual bin edges (ignored when auto_bins is 1).
    bins = [0.00001, 0.2, 0.4, 0.6, 1.1,  2, 4, np.inf]

    # Which variables receive noise in oversampling:
    #   1: big phi and the first feature    2: big phi and all features (proposed method)
    os_fluctuation = 2

In [None]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
import torch_geometric        # package for GNNs
from torch_geometric.data import Data, InMemoryDataset       
        
class BigPhis(InMemoryDataset):
    """In-memory PyG dataset built directly from a list of ``Data`` objects."""

    def __init__(self, data_def, transform = None):
        """Collate ``data_def`` (a list of graphs) into this dataset.

        The dataset root is the current directory ('.'); ``transform`` is handed
        to ``InMemoryDataset`` unchanged.
        """
        super().__init__('.', transform)
        collated = self.collate(data_def)
        self.data, self.slices = collated

In [None]:
def Set_path_candidate():
    """Choose the dataset directories and move the designated test directories to the end.

    Reads the globals ``dataset_type``, ``random_data_num`` and ``test_data_shitei``.
    Writes ``path_candidate``; when ``test_data_shitei`` is not 0 it also writes
    ``test_starting_number``, the 1-based position of the first test directory
    inside the reordered ``path_candidate``.
    Prints a message and exits when the dataset selection is not recognised.
    """
    global path_candidate, test_starting_number, random_data_num

    if dataset_type != 1:
        print('Something is wrong')
        sys.exit()

    if random_data_num == '1000_567':
        # non-extrapolative setting
        path_candidate = [
            './random_graph_data_non_extrapolative_setting/N=5',
            './random_graph_data_non_extrapolative_setting/N=6',
            './random_graph_data_non_extrapolative_setting/N=7'
        ]
    elif random_data_num == '1500':
        # extrapolative setting
        path_candidate = [
            './random_graph_data_extrapolative_setting/N=5',
            './random_graph_data_extrapolative_setting/N=6',
            './random_graph_data_extrapolative_setting/N=7'
        ]
    else:
        print('Something is wrong')
        sys.exit()

    if test_data_shitei != 0:
        # test_data_shitei holds 1-based positions; convert them to list indexes.
        held_out_positions = [number - 1 for number in test_data_shitei]

        held_out_paths = [path_candidate[position] for position in held_out_positions]
        training_paths = []
        for position, path in enumerate(path_candidate):
            if position not in held_out_positions:
                training_paths.append(path)

        # Training directories first, test directories last.
        path_candidate = training_paths + held_out_paths
        test_starting_number = len(path_candidate) - len(test_data_shitei) + 1

        display('Data path utilized in this experiment', path_candidate)
        print('Test data is after the next number (counting the first as 1)', test_starting_number)

In [None]:
def Data_generation_dataset1():
    """Read the random-graph CSV files and append one PyG ``Data`` object per sample to ``data_list``.

    For every directory in the global ``path_candidate`` the function reads
    ``<path>_meta/summary_N=<N>.csv`` and, per sample ``num``, the files
    ``connection_SN=<num>.csv``, ``edge_val_SN=<num>.csv`` and ``state_SN=<num>.csv``.

    Node feature columns are appended in this order (each controlled by a global switch):
    state (mapped by ``state * 2 - 1``), T (``inverse_T == 0``), the number of edges
    leaving the node (``feature_degree_on == 1``), closeness, betweenness, clustering
    coefficient (computed with networkx on the undirected graph), and the larger of the
    two transition probabilities (``feature_prob_on == 2``).

    Globals written: ``max_mc`` (largest number of '|'-separated entries in the 'Cut'
    column), ``data_count`` (cumulative ``len(data_list)`` after each directory).
    ``data_list`` is appended to and must already exist as a list.
    Prints a message and exits when a path or a switch value is not recognised.
    """
    global max_mc      
    max_mc = 0      
    
    import copy 
    import re   

    # Extracts the system size from a path such as '.../N=5'.
    pattern_N = r'N=(\d{1,3})'  
    
    global data_list, data_count    
    
    data_count = []  


    for path in path_candidate:
        match_N = re.search(pattern_N, path)

        if match_N:
            N = int(match_N.group(1))
            print('N:', N)
        else:
            print('N cannot be found')
            sys.exit()

            
        # Short label stored in each sample's meta: the path from 'N=' onwards
        # (up to a full-width '（' if one is present).
        pattern_path = r'(N=[^\（]*)'

        match = re.search(pattern_path, path)

        if match:
            path_short = match.group(1).strip('/')
            path_short = path_short + '(random)'
        else:
            print('No match with path')    
            sys.exit()

            
        # Summary table indexed by 'SN=<num>'; the columns 'T', 'Cut' and those in out1_items are read below.
        datafile = pd.read_csv( path + '_meta/summary_N=%d.csv' %(N), index_col=0)                         

        total_data_num = len(datafile)
        
        
        for num in range(total_data_num):
            T_val = datafile.loc['SN=%d' %(num), 'T']
            
            # Regression targets (only 'BigPhi' with the default out1_items).
            out1_list = [ datafile.loc['SN=%d' %(num), out1_items[i] ] for i in range(out1_num) ]

            # 'Cut' may hold several entries separated by '|'.
            mc_all =  datafile.loc['SN=%d' %(num), 'Cut']
            mc_list = mc_all.split('|') 
            
            # Connection matrix: an entry equal to 1 at (i, j) becomes the directed edge i -> j.
            cm = np.array( pd.read_csv( path + '_meta/connection_SN=%d.csv' %(num), header=None)  ) 

            src = []
            dst = []

            for i in range(N):  
                for j in range(N): 
                    if cm[i, j] == 1: 
                        src.append(i)
                        dst.append(j)

            # Edge values, collected in the same (i, j) order as src/dst.
            J = np.array( pd.read_csv( path + '_meta/edge_val_SN=%d.csv' %(num), header=None)  ) 

            edge_val = []

            for i in range(N):  
                for j in range(N): 
                    if cm[i, j] == 1: 
                        edge_val.append( [ J[i,j] ] )

            # First row of the state file, mapped by x -> 2x - 1
            # (presumably 0/1 -> -1/+1; TODO confirm the file encoding).
            state = pd.read_csv( path + '_meta/state_SN=%d.csv' %(num), header=None).iloc[0,:] 
            state = (state * 2 -1).tolist()
                
            edge_val_save = copy.deepcopy(edge_val)     
            
            edge_index = torch.tensor([src, dst], dtype=torch.long) 
            edge_attr = torch.tensor(edge_val, dtype=torch.float)                    
            
            # Column 0 is always the node state; T is column 1 when enabled.
            if inverse_T == 0:
                node_list = [ [x, T_val] for x in state]
            elif inverse_T == -1:
                node_list = [ [x] for x in state]
            else:
                print('inverse_T ERROR'); sys.exit()

            if feature_degree_on >= 1:      
                from collections import Counter

                # Number of times each node appears as an edge source.
                element_counts = Counter(src) 
                element_counts_list = [element_counts[i] for i in range(N)] 

                if feature_degree_on == 1:
                    node_list = [orig + [add] for orig, add in zip(node_list, element_counts_list)]
                else:
                    print('feature_degree_on is Wrong!')
                    sys.exit()


            # Undirected networkx view of the graph, used for the centrality features.
            edge_index = torch.tensor([src, dst], dtype=torch.long)
            data_nx = Data(edge_index=edge_index, num_nodes=N)
            G = to_networkx(data_nx,to_undirected=True) 

            if closeness == 1:
                closeness_centrality = nx.closeness_centrality(G)  
                closeness_list = [closeness_centrality[node] for node in sorted(G.nodes())]
                node_list = [orig + [add] for orig, add in zip(node_list, closeness_list)]                   

            if betweenness == 1:
                betweenness_centrality = nx.betweenness_centrality(G)
                betweenness_list = [betweenness_centrality[node] for node in sorted(G.nodes())]
                node_list = [orig + [add] for orig, add in zip(node_list, betweenness_list)]                   

            if clustering == 1:
                clustering_coefficients = nx.clustering(G)
                clustering_list = [clustering_coefficients[node] for node in sorted(G.nodes())]
                node_list = [orig + [add] for orig, add in zip(node_list, clustering_list)]    



            ################# 
                        
            if feature_prob_on == 2: 
                # inflow[j] = sum over edges i -> j of edge value * state of i.
                inflow = np.zeros(N) 

                for i in range(len(src)): 
                    value = edge_val_save[i][0] * node_list[src[i]][0]
                    inflow[dst[i]] += value

                node_list_first = [row[0] for row in node_list]

                # Logistic function of 2 * inflow * state / T, then the larger of p and 1 - p.
                prob_val = 1.0 / ( 1.0 + np.exp( -2 * inflow * node_list_first / T_val ) )
                prob_val = [pv if pv > 0.5 else 1 - pv for pv in prob_val]
                
                node_list = [orig + [add] for orig, add in zip(node_list, prob_val)]

            
            x = torch.tensor(node_list, dtype=torch.float) 



            #########################################
            y = torch.tensor(out1_list, dtype=torch.float)  
                    

            
            multiple_correct_label = []  


            if len(mc_list) > max_mc:
                max_mc = len(mc_list)

            total_mc = 0 

            # Collect exactly two label sets: the cut list is walked repeatedly,
            # so a single cut is simply stored twice.
            while(total_mc < 2):  

                for val in range(len(mc_list)):  
                    mc = mc_list[val]  

                    # Each cut is written as '<former>==><latter>'.
                    mc_former = mc.split('==>')[0] 
                    mc_latter = mc.split('==>')[1]

                    # NOTE(review): r'\d' matches one digit at a time, so a node index >= 10
                    # would be read as separate digits; fine while N <= 10 - confirm for larger N.
                    mc_numbers_former_str = re.findall(r'\d', mc_former)   
                    mc_numbers_latter_str = re.findall(r'\d', mc_latter)

                    mc_numbers_former = [int(num) for num in mc_numbers_former_str] 
                    mc_numbers_latter = [int(num) for num in mc_numbers_latter_str] 

                    mc_numbers_former.sort()
                    mc_numbers_latter.sort()
                    
                    correct_label = []   

                    # One-hot label per node: [0,1] when the node is listed on either
                    # side of the cut, [1,0] otherwise.
                    if class_num == 2: 
                        for _ in range(N):
                            if _ in mc_numbers_former:
                                correct_label.append([0,1])
                            elif _ in mc_numbers_latter:
                                correct_label.append([0,1])
                            else:
                                correct_label.append([1,0])
                    else:
                        print('Wrong class_num!!')
                        sys.exit()


                    multiple_correct_label.append(correct_label)

                    total_mc += 1
                    if total_mc == 2:
                        break


                            
            # (2, N, 2) -> (N, 2, 2): the node axis comes first.
            y2 = torch.tensor( multiple_correct_label, dtype=torch.float)  
            y2 = y2.permute(1,0,2)    

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, y2=y2, mc_num=len(mc_list), 
                        meta=[path_short, '-', num, T_val, '-'] )            

            data_list.append(data)

        # Cumulative number of samples after this directory.
        data_count.append( len(data_list) )
        
    print('Max Major Complex (Degeneracy Degree) is ', max_mc)
    print('In our dataset, there is no degeneracy, i.e., the above value must be 1')

In [None]:
def Data_generation_bigN():
    """Turn the saved N = ``bigN`` graphs into PyG ``Data`` objects appended to ``data_list_bigN``.

    Globals read: ``bigN``, ``total_data_num`` (the number of saved graphs),
    ``test_data_bigN`` (must be 2), ``pickle_filename`` and the feature switches
    ``inverse_T``, ``feature_degree_on``, ``closeness``, ``betweenness``,
    ``clustering`` and ``feature_prob_on``. ``data_list_bigN`` must already exist
    as a list.

    Node feature columns are appended in this order: state, T (``inverse_T == 0``),
    the number of edges leaving the node (``feature_degree_on == 1``), closeness,
    betweenness, clustering coefficient, and the larger of the two transition
    probabilities (``feature_prob_on == 2``).

    The true big phi and the true labels are unknown for these graphs, so ``y`` is
    the dummy value -100 and every node label is the dummy [0, 1].

    The pickle file is loaded once, on first use, and reused for every graph.
    Prints a message and exits when ``test_data_bigN``, ``inverse_T`` or
    ``feature_degree_on`` holds an unsupported value.
    """
    from collections import Counter

    global data_list_bigN
    global pickle_filename
    global total_data_num    # the number of saved graphs

    N = bigN
    graph_df = None   # loaded lazily so that nothing is read when there are no graphs

    for num in range(total_data_num):
        if test_data_bigN != 2:
            print('test_data_bigN is wrong!'); sys.exit()

        if graph_df is None:
            print('loading a graph')
            # NOTE: unpickling can execute arbitrary code; only load trusted files.
            graph_df = pd.read_pickle(pickle_filename)

        record = graph_df.loc[num]
        T_val = record['T_val']
        src = record['src']
        dst = record['dst']
        edge_val = record['edge_val']
        state = record['state']

        edge_index = torch.tensor([src, dst], dtype=torch.long)
        edge_attr = torch.tensor(edge_val, dtype=torch.float)

        # Column 0 is always the node state; T is column 1 when enabled.
        if inverse_T == 0:
            node_list = [ [x, T_val] for x in state]
        elif inverse_T == -1:
            node_list = [ [x] for x in state]
        else:
            print('inverse_T ERROR'); sys.exit()

        if feature_degree_on >= 1:
            # Number of times each node appears as an edge source.
            element_counts = Counter(src)
            element_counts_list = [element_counts[i] for i in range(N)]

            if feature_degree_on == 1:
                node_list = [orig + [add] for orig, add in zip(node_list, element_counts_list)]
            else:
                print('feature_degree_on is Wrong!')
                sys.exit()

        # Undirected networkx view of the graph, used for the centrality features.
        data_nx = Data(edge_index=edge_index, num_nodes=N)
        G = to_networkx(data_nx,to_undirected=True)

        if closeness == 1:
            closeness_centrality = nx.closeness_centrality(G)
            closeness_list = [closeness_centrality[node] for node in sorted(G.nodes())]
            node_list = [orig + [add] for orig, add in zip(node_list, closeness_list)]

        if betweenness == 1:
            betweenness_centrality = nx.betweenness_centrality(G)
            betweenness_list = [betweenness_centrality[node] for node in sorted(G.nodes())]
            node_list = [orig + [add] for orig, add in zip(node_list, betweenness_list)]

        if clustering == 1:
            clustering_coefficients = nx.clustering(G)
            clustering_list = [clustering_coefficients[node] for node in sorted(G.nodes())]
            node_list = [orig + [add] for orig, add in zip(node_list, clustering_list)]

        if feature_prob_on == 2:
            # inflow[j] = sum over edges i -> j of edge value * state of i.
            inflow = np.zeros(N)

            for i in range(len(src)):
                value = edge_val[i][0] * node_list[src[i]][0]
                inflow[dst[i]] += value

            node_list_first = [row[0] for row in node_list]

            # Logistic function of 2 * inflow * state / T, then the larger of p and 1 - p.
            prob_val = 1.0 / ( 1.0 + np.exp( -2 * inflow * node_list_first / T_val ) )
            prob_val = [pv if pv > 0.5 else 1 - pv for pv in prob_val]

            node_list = [orig + [add] for orig, add in zip(node_list, prob_val)]

        x = torch.tensor(node_list, dtype=torch.float)

        ########################################
        # As true big phi value is unknown, we set dummy value as -100.
        y = torch.tensor([-100], dtype=torch.float)

        # True labels are also unknown, and we set [0,1] as dummy.
        multiple_correct_label = [[[0,1]] * N]*2
        y2 = torch.tensor( multiple_correct_label, dtype=torch.float)

        # (2, N, 2) -> (N, 2, 2): the node axis comes first.
        y2 = y2.permute(1,0,2)

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, y2=y2, mc_num=2,
                    meta=['big_N', '-', num, T_val, '-'] )

        data_list_bigN.append(data)

In [None]:
def Dataset_generation():
    """Wrap ``data_list`` in a ``BigPhis`` dataset, print a summary and record the feature sizes.

    Globals written: ``dataset``, ``feature_dimension`` (the number of node features,
    reduced by one when ``feature_s_on == 0``) and ``edge_dimension`` (the number of
    columns of ``edge_attr`` in the first graph).
    """
    global data_list, dataset, feature_dimension, edge_dimension

    dataset = BigPhis(data_list)

    for line in (
        '',
        f'Dataset: {dataset}:',
        '====================',
        f'Number of graphs: {len(dataset)}',
        f'Number of features: {dataset.num_features -1}',
    ):
        print(line)

    # The first graph serves as a representative sample for the statistics below.
    first_graph = dataset[0]

    print()
    print(first_graph)
    print('=============================================================')

    for line in (
        f'Number of nodes: {first_graph.num_nodes}',
        f'Number of edges: {first_graph.num_edges}',
        f'Average node degree: {first_graph.num_edges / first_graph.num_nodes:.2f}',
    ):
        print(line)

    # One column fewer when feature_s_on is 0.
    node_feature_count = dataset.num_node_features
    feature_dimension = node_feature_count - 1 if feature_s_on == 0 else node_feature_count

    edge_dimension = first_graph.edge_attr.size(1)

    print('the number of features (Dataset_generation):', feature_dimension)
    print('the number of edge features (Dataset_generation):', edge_dimension)

In [None]:
def Dataset_generation_bigN():
    """Wrap ``data_list_bigN`` in a ``BigPhis`` test dataset, print a summary and record the feature sizes.

    Globals written: ``test_dataset``, ``feature_dimension`` (the number of node
    features, reduced by one when ``feature_s_on == 0``) and ``edge_dimension``
    (the number of columns of ``edge_attr`` in the first graph).
    """
    global data_list_bigN, test_dataset, feature_dimension, edge_dimension

    test_dataset = BigPhis(data_list_bigN)

    for line in (
        '',
        f'Dataset: {test_dataset}:',
        '====================',
        f'Number of graphs: {len(test_dataset)}',
        f'Number of features: {test_dataset.num_features -1}',
    ):
        print(line)

    # The first graph serves as a representative sample for the statistics below.
    first_graph = test_dataset[0]

    print()
    print(first_graph)
    print('==============================================')

    for line in (
        f'Number of nodes: {first_graph.num_nodes}',
        f'Number of edges: {first_graph.num_edges}',
        f'Average node degree: {first_graph.num_edges / first_graph.num_nodes:.2f}',
    ):
        print(line)

    # One column fewer when feature_s_on is 0.
    node_feature_count = test_dataset.num_node_features
    feature_dimension = node_feature_count - 1 if feature_s_on == 0 else node_feature_count

    edge_dimension = first_graph.edge_attr.size(1)

    print('the number of features (Dataset_generation):', feature_dimension)
    print('the number of edge features (Dataset_generation):', edge_dimension)

In [None]:
def Add_split_brain_data(kl_dataset):
    """Augment the data with graphs made of two disconnected sub-graphs.

    ``int(len(kl_dataset) * add_split_brain_data_rate)`` new graphs are created.
    Each one joins two randomly drawn graphs whose big phi values differ by at
    least 0.001, without adding any edge between them. The combined graph takes
    ``y``, ``mc_num`` and the node labels from the graph with the larger big phi;
    every node of the other graph is labelled [1, 0].

    Returns a new list holding the original graphs followed by the added ones.
    """
    import copy
    import random

    # Our program is made to adapt to degeneracy, but there is no degeneracy in the
    # current dataset. So, do not worry about this parameter.
    degeneracy_num = [4, 2]

    augmented = list(kl_dataset)
    base_count = len(kl_dataset)
    extra_count = int( base_count * add_split_brain_data_rate )

    print('the number of original data:', base_count )
    print('ratio of added data:', add_split_brain_data_rate, 'and then the number of added data:', extra_count )

    for _ in range(extra_count):

        # Draw pairs until the two big phi values are clearly different;
        # `major` ends up as the graph with the larger value.
        while True:
            first_pick = random.randint(0, base_count-1 )
            second_pick = random.randint(0, base_count-1 )

            major = copy.deepcopy( kl_dataset[first_pick] )
            minor = copy.deepcopy( kl_dataset[second_pick] )

            if torch.abs( major.y[0] - minor.y[0] ) < 0.001:
                continue

            if major.y[0] < minor.y[0]:
                major, minor = minor, major

            break

        major_node_count = major.x.size(0)
        minor_node_count = minor.x.size(0)

        # Node ids of the second graph are shifted behind those of the first.
        merged_edge_index = torch.cat([major.edge_index, minor.edge_index + major_node_count], dim=1)
        merged_x = torch.cat([major.x, minor.x], dim=0)
        merged_edge_attr = torch.cat([major.edge_attr, minor.edge_attr], dim=0)

        if class_num == 2:
            multiple_correct_label = [
                [[1,0] for _node in range(minor_node_count)]
                for _copy in range(degeneracy_num[dataset_type])
            ]

        # (copies, nodes, 2) -> (nodes, copies, 2), then placed after the labels of `major`.
        minor_labels = torch.tensor( multiple_correct_label, dtype=torch.float).permute(1,0,2)
        merged_y2 = torch.cat([major.y2, minor_labels], dim=0)

        merged_meta = [f"{m1}/{m2}" for m1, m2 in zip(major.meta, minor.meta)]

        augmented.append(
            Data(x=merged_x, edge_index=merged_edge_index, edge_attr=merged_edge_attr, y=major.y,
                 y2=merged_y2, mc_num=major.mc_num, meta=merged_meta)
        )

    print('the number of training data after augmentation:', len(augmented) )

    return augmented

In [None]:
def Over_sampling(kl_dataset):
    """Balance the big phi distribution by duplicating, with noise, graphs from the smaller bins.

    The graphs are grouped into bins by ``y[0]``. With ``auto_bins == 1`` the bin
    edges are the midpoints between the sorted centres of a 7-cluster k-means fit
    (with 0.00005 and infinity as outer edges) and are stored in the global
    ``bins``; otherwise the existing global ``bins`` is used. Every bin is then
    filled up to the size of the largest one with deep copies of randomly chosen
    members, whose targets and features are scaled by random factors in
    [0.95, 1.05] (``os_fluctuation``: 1 scales column 1 of ``x``, 2 scales all of ``x``).

    Returns a new list holding the original graphs followed by the added copies.
    Prints a message and exits when the number of populated bins differs from
    ``len(bins) - 1`` or when ``os_fluctuation`` is not 1 or 2.
    """
    import copy
    import random
    from collections import defaultdict

    #####################
    global bins
    from sklearn.cluster import KMeans

    if auto_bins != 1:
        print('manual bins:', bins)
    else:
        phi_column = np.array([graph.y[0].item() for graph in kl_dataset]).reshape(-1, 1)

        clusterer = KMeans(n_clusters=7, init='k-means++', n_init=10, random_state=None)
        clusterer.fit(phi_column)
        centers = np.sort(clusterer.cluster_centers_.flatten())

        # Inner edges lie halfway between neighbouring cluster centres.
        inner_edges = (centers[:-1] + centers[1:]) / 2
        bins = np.concatenate(([0.00005], inner_edges, [np.inf]))

        print('border of bins:', bins)
    ########################

    # Positions of the graphs that fall into each bin.
    members = defaultdict(list)
    for position, graph in enumerate(kl_dataset):
        members[np.digitize(graph.y[0].item(), bins) - 1].append(position)

    bin_sizes = [len(positions) for positions in members.values()]
    max_count = max(bin_sizes)
    print('the number of samples before oversampling:', bin_sizes )
    print('the number of samples in the largest bin:', max_count)

    if len(members) != len(bins) - 1:
        print('Some bins have no members(ERROR)'); sys.exit()

    data_list_local = list(kl_dataset)
    for category, positions in members.items():
        shortage = max_count - len(positions)
        print('Category', category, 'the number of missing samples', shortage)

        for _ in range( int(shortage/1.0) ):
            duplicate = copy.deepcopy( kl_dataset[random.choice(positions)] )

            # Each target and the features get their own random factor.
            for val in range(out1_num):
                duplicate.y[val] *= random.uniform(0.95, 1.05)

            if os_fluctuation == 2:
                duplicate.x[:,:] *= random.uniform(0.95, 1.05)
            elif os_fluctuation == 1:
                duplicate.x[:,1] *= random.uniform(0.95, 1.05)
            else:
                print('os_fluctuation is wrong')
                sys.exit()

            data_list_local.append(duplicate)

    print('the number of samples after oversampling:', len(data_list_local) )

    return data_list_local

In [None]:
def Split_dataset():
    """Split ``dataset`` into training, validation and test parts, then expand the learning pool.

    With ``test_data_shitei == 0`` the whole dataset is shuffled and cut using
    ``learning_data_ratio``; otherwise the samples after
    ``data_count[test_starting_number - 2]`` form the test set and the rest is
    shuffled as the learning pool. ``validation_data_ratio`` decides the
    validation share of the learning pool.

    When enabled, data augmentation (``adding_split_brain_data == 2``) and
    oversampling (``over_sampling == 2``) each replace the learning pool with an
    expanded, reshuffled one, after which training and validation are re-split.

    Globals written: ``dataset`` (first case only), ``train_dataset``,
    ``validation_dataset``, ``test_dataset`` and ``kougi_learning_dataset``.
    """
    global dataset, train_dataset, validation_dataset, test_dataset, test_starting_number, kougi_learning_dataset

    def _split_learning_pool():
        # Cut the current learning pool into its training and validation parts.
        global train_dataset, validation_dataset, kougi_learning_dataset
        cut = int( len(kougi_learning_dataset) * (1-validation_data_ratio) )
        train_dataset = kougi_learning_dataset[:cut]
        validation_dataset = kougi_learning_dataset[cut:]

    def _replace_learning_pool(graphs, message):
        # Make the expanded graph list the new, shuffled learning pool and re-split it.
        global kougi_learning_dataset
        kougi_learning_dataset = BigPhis(graphs).shuffle()
        print(message, len(kougi_learning_dataset) )
        _split_learning_pool()

    if test_data_shitei == 0:
        learning_count = int(len(dataset) * learning_data_ratio )
        training_count = int( learning_count * (1-validation_data_ratio) )

        dataset = dataset.shuffle()

        kougi_learning_dataset = dataset[:learning_count]
        train_dataset = dataset[:training_count]
        validation_dataset = dataset[training_count:learning_count]
        test_dataset = dataset[learning_count:]
    else:
        # Everything from this position on belongs to the designated test directories.
        boundary = data_count[test_starting_number -2]
        test_dataset = dataset[boundary:]
        kougi_learning_dataset = dataset[:boundary].shuffle()
        _split_learning_pool()

    # data augmentation
    if adding_split_brain_data == 2:
        _replace_learning_pool(Add_split_brain_data(kougi_learning_dataset),
                               'the number of training data after augmentation')

    # oversampling
    if over_sampling == 2:
        _replace_learning_pool(Over_sampling(kougi_learning_dataset),
                               'the number of training data after oversampling:')

    print(f'Final Number of training graphs: {len(train_dataset)}')
    print(f'Final Number of validation graphs: {len(validation_dataset)}')
    print(f'Final Number of test graphs: {len(test_dataset)}')

In [None]:
def Input_normalization_new():
    """Standardise the node features and drop the state column, as configured.

    With ``input_normalization == 1`` the per-feature mean and standard deviation
    are computed over all nodes of ``train_dataset`` and applied to the training
    and validation sets, and to the test set when ``prelearned_model == 0``.
    The statistics are stored in the globals ``mean_x`` and ``std_x``.
    A feature whose standard deviation is zero or undefined (constant over the
    training set, or a single training node) is divided by 1 instead, so it
    becomes 0 rather than NaN/inf; ``std_x`` holds the value actually used.

    With ``feature_s_on == 0`` column 0 of ``x`` (the node state) is removed from
    the same datasets. This happens whether or not normalisation is enabled,
    because ``feature_dimension`` is reduced by one whenever ``feature_s_on == 0``.

    When neither option is active the datasets are left untouched.
    """
    global train_dataset, validation_dataset, test_dataset, mean_x, std_x

    normalize = input_normalization == 1
    drop_state = feature_s_on == 0

    if not normalize and not drop_state:
        return

    mean = std = None
    if normalize:
        all_features = torch.cat([data.x for data in train_dataset], dim=0)
        mean = all_features.mean(dim=0)
        std = all_features.std(dim=0)
        # `std > 0` is False for both 0 and NaN, so either case falls back to 1.
        std = torch.where(std > 0, std, torch.ones_like(std))
        mean_x = mean
        std_x = std

    def _prepare(split):
        # Return a new BigPhis dataset with the configured transformations applied.
        graphs = list(split)
        for data in graphs:
            if normalize:
                data.x = (data.x - mean) / std
            if drop_state:
                data.x = data.x[:, 1:]
        return BigPhis(graphs)

    train_dataset = _prepare(train_dataset)
    validation_dataset = _prepare(validation_dataset)

    #################################
    # With a pre-trained model the test set is not processed here.
    if prelearned_model == 0:
        test_dataset = _prepare(test_dataset)

In [None]:
def Generation_dataloader():
    """Create the PyG data loaders for the current datasets.

    Globals written: ``train_loader`` (shuffled, batch size ``learning_batch_size``),
    ``validation_loader`` (batch size 256, not shuffled) and, only when
    ``prelearned_model == 0``, ``test_loader`` (batch size 128, not shuffled).
    """
    from torch_geometric.loader import DataLoader

    global train_dataset, validation_dataset, test_dataset, train_loader, validation_loader, test_loader

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=learning_batch_size)
    validation_loader = DataLoader(validation_dataset, shuffle=False, batch_size=256)

    # With a pre-trained model no test loader is built here.
    if prelearned_model != 0:
        return

    test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128)

In [None]:
from torch.nn import Linear     
import torch.nn.functional as F   
from torch_geometric.nn import GATConv, GraphConv, TransformerConv 
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.nn import BatchNorm

In [None]:
class GCN_reg(torch.nn.Module):
    """Four-layer graph network with a graph-level and a node-level output branch.

    Branch 1 (`lin`) maps the pooled graph embedding to `out1_num` scalars
    (estimating big phi); branch 2 (`lin2`) maps every node embedding to
    `class_num` logits (estimating the major complex).

    Parameters
    ----------
    hidden_channels1..hidden_channels4 : int
        Output widths of the four convolution layers. For 'GAT' and
        'Transformer' the first three layers use 4 concatenated heads, so their
        actual output width is four times the given value.
    model_type : str
        'GraphConv', 'GAT' or 'Transformer'. Any other value prints a message
        and exits.

    Notes
    -----
    Reads the globals `feature_dimension`, `edge_dimension`, `out1_num` and
    `class_num` at construction time, and `drop_rate` and `other_network_type`
    on every forward pass.
    """

    def __init__(self, hidden_channels1, hidden_channels2, hidden_channels3, hidden_channels4, model_type):

        super(GCN_reg, self).__init__()

        # Remember which convolution family was built, so that forward() dispatches
        # on the layers this instance owns instead of on the global `convolution_type`.
        self.model_type = model_type
        uses_attention = model_type in ('GAT', 'Transformer')

        if model_type == 'GraphConv':
            self.conv1 = GraphConv(feature_dimension, hidden_channels1)
            self.conv2 = GraphConv(hidden_channels1, hidden_channels2)
            self.conv3 = GraphConv(hidden_channels2, hidden_channels3)
            self.conv3b = GraphConv(hidden_channels3, hidden_channels4)
        elif uses_attention:
            conv_cls = GATConv if model_type == 'GAT' else TransformerConv
            self.conv1 = conv_cls(feature_dimension, hidden_channels1, heads=4, concat=True, edge_dim=edge_dimension)
            self.conv2 = conv_cls(hidden_channels1*4, hidden_channels2, heads=4, concat=True, edge_dim=edge_dimension)
            self.conv3 = conv_cls(hidden_channels2*4, hidden_channels3, heads=4, concat=True, edge_dim=edge_dimension)
            # The last layer keeps the default single head, so its width is hidden_channels4.
            self.conv3b = conv_cls(hidden_channels3*4, hidden_channels4, edge_dim=edge_dimension)
        else:
            print('Something is wrong'); sys.exit()

        # Concatenated heads multiply the width seen by the batch-norm layers.
        width_factor = 4 if uses_attention else 1
        self.batch_norm1 = BatchNorm(hidden_channels1 * width_factor)
        self.batch_norm2 = BatchNorm(hidden_channels2 * width_factor)
        self.batch_norm3 = BatchNorm(hidden_channels3 * width_factor)

        # Branch 1 (Estimating big phi)
        self.lin = Linear(hidden_channels4, out1_num)

        # Branch 2 (Estimating major complex)
        self.lin2 = Linear(hidden_channels4, class_num)

        # We do not use the following self.att. However, this setting is necessary as a matter of form to load our prelearned models.
        self.att = torch.nn.Linear(hidden_channels4*2, 1)

    def _convolve(self, conv, x, edge_index, edge_attr):
        """Apply one convolution, passing edge attributes only to the attention-based layers."""
        if self.model_type in ('GAT', 'Transformer'):
            return conv(x, edge_index, edge_attr=edge_attr)
        return conv(x, edge_index)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return `(x2, x3)`: graph-level outputs from branch 1 and per-node logits from branch 2.

        `other_network_type` selects the readout: 0 and 2 use max pooling, 1 uses
        mean pooling. For 0 and 1 the node branch receives each node embedding
        minus the pooled embedding of its graph; for 2 it receives the raw node
        embedding.
        """
        # 1. Obtain node embeddings: three conv -> batch-norm -> ReLU -> dropout stages ...
        x0 = x
        for conv, norm in ((self.conv1, self.batch_norm1),
                           (self.conv2, self.batch_norm2),
                           (self.conv3, self.batch_norm3)):
            x0 = self._convolve(conv, x0, edge_index, edge_attr)
            x0 = norm(x0)
            x0 = x0.relu()
            x0 = F.dropout(x0, p=drop_rate, training=self.training)

        # ... followed by a final convolution without normalisation or activation.
        x0 = self._convolve(self.conv3b, x0, edge_index, edge_attr)

        # 2. Readout layer
        if other_network_type == 0 or other_network_type == 2:
            x1 = global_max_pool(x0, batch)
        elif other_network_type == 1:
            x1 = global_mean_pool(x0, batch)
        else:
            print('other_network_type is wrong!!')
            sys.exit()

        # 3. Apply a final classifier
        x2 = F.dropout(x1, p=drop_rate, training=self.training)
        x2 = self.lin(x2)

        # Node branch. other_network_type was validated above, so only 0/1/2 reach this point.
        if other_network_type == 2:
            x3 = x0
        else:
            x3 = x0 - x1[batch]

        x3 = F.dropout(x3, p=drop_rate, training=self.training)
        x3 = self.lin2(x3)

        return x2, x3

In [None]:
def Weighted_MSE(prediction, target):
    """Mean squared error with each squared residual scaled by the global `out1_weight_tensor`.

    The weights are multiplied element-wise (with broadcasting) onto the squared
    residuals before the mean over all elements is taken.
    """
    residual = prediction - target
    return (residual ** 2 * out1_weight_tensor).mean()

In [None]:
def Train():
    """Run one optimisation epoch over the global `train_loader`.

    For every mini-batch the total loss is `loss1 + loss2_coef * loss2`, where
    `loss1` is `Weighted_MSE` on the graph-level output and `loss2` is
    `criterion2` on the node-level output against the first label set
    `data.y2[:, 0, :]`.

    Returns
    -------
    tuple
        `(total, loss1, loss2_coef * loss2)`, each the sum of the batch losses
        weighted by the number of graphs in the batch and divided by the size of
        the training dataset. The values are tensors, so callers can keep using
        `.item()`.
    """
    model.train()

    all_loss = 0
    all_loss1 = 0
    all_loss2 = 0

    for data in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients.

        data = data.to(device)
        out1, out2 = model(data.x, data.edge_index, data.edge_attr, data.batch)  # Perform a single forward pass.

        data.y = data.y.float()
        data.y = data.y.view(-1, out1_num)

        data.y2 = data.y2.float()

        loss1 = Weighted_MSE(out1, data.y)  # Compute the loss.
        loss2 = criterion2(out2, data.y2[:, 0, :])  # Compute the loss.

        sum_loss = loss1 + loss2_coef*loss2
        sum_loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

        # Use the explicit graph count; whether len(batch) equals it depends on the
        # installed PyG version.
        graphs_in_batch = data.num_graphs

        # Detach before accumulating so the running sums do not chain the autograd
        # history of every batch of the epoch.
        loss1_value = loss1.detach()
        loss2_value = loss2.detach()

        all_loss += (loss1_value + loss2_coef*loss2_value) * graphs_in_batch
        all_loss1 += loss1_value * graphs_in_batch
        all_loss2 += loss2_coef * loss2_value * graphs_in_batch

    dataset_size = len(train_loader.dataset)
    return all_loss / dataset_size, all_loss1 / dataset_size, all_loss2 / dataset_size

        
def Validation():
    """Compute the losses of the current `model` on the global `validation_loader`.

    Uses the same loss definition as `Train()` but performs no parameter update.

    Returns
    -------
    tuple
        `(total, loss1, loss2_coef * loss2)`, each the sum of the batch losses
        weighted by the number of graphs in the batch and divided by the size of
        the validation dataset. The values are tensors.
    """
    model.eval()

    all_loss = 0
    all_loss1 = 0
    all_loss2 = 0

    # Evaluation only: skip autograd bookkeeping.
    with torch.no_grad():
        for data in validation_loader:  # Iterate in batches over the validation dataset.
            data = data.to(device)
            out1, out2 = model(data.x, data.edge_index, data.edge_attr, data.batch)  # Perform a single forward pass.

            data.y = data.y.float()
            data.y = data.y.view(-1, out1_num)

            data.y2 = data.y2.float()

            loss1 = Weighted_MSE(out1, data.y)  # Compute the loss.
            loss2 = criterion2(out2, data.y2[:, 0, :])  # Compute the loss.

            # Use the explicit graph count; whether len(batch) equals it depends on
            # the installed PyG version.
            graphs_in_batch = data.num_graphs

            all_loss += (loss1 + loss2_coef*loss2) * graphs_in_batch
            all_loss1 += loss1 * graphs_in_batch
            all_loss2 += loss2_coef * loss2 * graphs_in_batch

    dataset_size = len(validation_loader.dataset)
    return all_loss / dataset_size, all_loss1 / dataset_size, all_loss2 / dataset_size


        
        
def Test(loader):
    """Evaluate the current `model` on `loader`.

    Node labels are compared against the first label set `data.y2[:, 0, :]` only.

    Returns
    -------
    tuple
        `(mse, mae, exact_match_rate, bit_accuracy)` where `mse` and `mae` are
        tensors with one entry per scalar output (averaged over graphs),
        `exact_match_rate` is the fraction of graphs whose node labels are all
        predicted correctly, and `bit_accuracy` is the fraction of correctly
        labelled nodes over all nodes.
    """
    model.eval()

    error1 = 0
    error2 = 0
    accuracy = 0
    exact_match_accuracy = 0
    total_node_count = 0

    # Evaluation only: skip autograd bookkeeping.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            data = data.to(device)
            out1, out2 = model(data.x, data.edge_index, data.edge_attr, data.batch)
            pred = out1

            data.y = data.y.float()
            data.y = data.y.view(-1, out1_num)

            data.y2 = data.y2.float()
            residual = pred - data.y
            error1 += (residual * residual).sum(dim=0)  # Check against ground-truth labels.
            error2 += torch.abs(residual).sum(dim=0)

            true_labels = torch.argmax(data.y2[:, 0, :], dim=1)
            predicted_labels = torch.argmax(out2, dim=1)
            node_is_correct = predicted_labels == true_labels
            accuracy += torch.sum(node_is_correct).item()

            # Exact match is decided per graph: count the mislabelled nodes of every
            # graph in the batch and keep the graphs with none. Reducing over the
            # whole batch at once would only be right for a batch size of 1.
            wrong_per_graph = torch.bincount(data.batch[~node_is_correct], minlength=data.num_graphs)
            exact_match_accuracy += torch.sum(wrong_per_graph == 0).item()

            total_node_count += data.num_nodes

    graph_count = len(loader.dataset)
    return error1 / graph_count, error2 / graph_count, \
            exact_match_accuracy / graph_count, accuracy / total_node_count  # Derive ratio of correct predictions.

In [None]:
def Detail_record(loader, output_flag, prefix, time_str):
    """Evaluate `model` on `loader` graph by graph and return the results as a DataFrame.

    For every graph the frame holds the estimated and true scalar outputs with
    their squared, absolute and relative errors, the predicted node labels, the
    best-matching true node labels, the node count, the node-level accuracy
    ('Bit_rate') and the meta information stored in `data.meta[0..4]`.

    `data.y2` may hold several alternative label sets per node (second axis).
    The set with the highest accuracy against the prediction is recorded; every
    time a set other than the first becomes the best so far, `degeneracy_count`
    is incremented (and `degeneracy_exact_count` too when that set is matched
    exactly).

    Parameters
    ----------
    loader : iterable of PyG batches
    output_flag : int
        When 1, the frame is also written to an .xlsx file in `save_folder`.
    prefix, time_str : str
        Parts of the output file name.

    Returns
    -------
    pandas.DataFrame
        One row per graph.
    """
    model.eval()

    degeneracy_count = 0
    degeneracy_exact_count = 0

    # For final results
    correct_big_phi_list = []
    predicted_big_phi_list = []
    mse_list = []
    mae_list = []
    maer_list = []
    true_labels_list = []
    predicted_labels_list = []
    node_num_list = []
    bit_rate_list = []

    # For meta info.
    path_list = []
    fr_list = []
    state_list = []
    T_list = []
    reverse_list = []

    # Inference only: none of the recorded values needs an autograd graph.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            data = data.to(device)

            out1, out2 = model(data.x, data.edge_index, data.edge_attr, data.batch)

            pred = out1

            data.y = data.y.float()
            data.y = data.y.view(-1, out1_num)

            error1 = ((pred - data.y)*(pred - data.y))  # Check against ground-truth labels.
            error2 = torch.abs( pred - data.y )

            # Relative error; the small constant avoids division by an exact zero.
            maer = torch.abs((pred - data.y) / (data.y + 1e-10))

            correct_big_phi_list += data.y.tolist()
            predicted_big_phi_list += pred.tolist()
            mse_list += error1.tolist()
            mae_list += error2.tolist()
            maer_list += maer.tolist()

            data.y = data.y.view(-1)

            ############ Major Complex
            data_list_local = data.to_data_list()
            node_num = [ graph.x.shape[0] for graph in data_list_local ]
            node_num_list += node_num

            data.y2 = data.y2.float()

            predicted_labels = torch.argmax(out2, dim=1)

            # Class index of every node for every alternative label set, computed
            # once per batch instead of once per graph and label set.
            candidate_labels = torch.argmax(data.y2, dim=2)
            node_start_index = 0

            for num_nodes in node_num:
                node_stop_index = node_start_index + num_nodes
                max_accuracy = -10.0
                best_true_labels = None
                graph_predicted_labels = predicted_labels[node_start_index:node_stop_index]
                predicted_labels_list.append(graph_predicted_labels.tolist())

                for i in range(candidate_labels.size(1)):
                    true_labels = candidate_labels[node_start_index:node_stop_index, i]

                    correct_predictions = (graph_predicted_labels == true_labels).float().mean()

                    if correct_predictions > max_accuracy:
                        # A later label set beating the earlier ones marks a degenerate case.
                        if i >= 1:
                            degeneracy_count += 1
                            if correct_predictions > 0.9999:
                                degeneracy_exact_count += 1

                        max_accuracy = correct_predictions
                        best_true_labels = true_labels

                bit_rate_list.append(max_accuracy.item())
                true_labels_list.append(best_true_labels.tolist())
                node_start_index = node_stop_index

            ############ Meta information of every graph in the batch
            path_list += [ graph.meta[0] for graph in data_list_local ]
            fr_list += [ graph.meta[1] for graph in data_list_local ]
            state_list += [ graph.meta[2] for graph in data_list_local ]
            T_list += [ graph.meta[3] for graph in data_list_local ]
            reverse_list += [ graph.meta[4] for graph in data_list_local ]

    ###### Assemble the per-graph table
    data_dict = {}

    # The per-graph records are lists with one entry per scalar output; zip(*...)
    # transposes them into one column per output.
    for name, pred_col, corr_col, mse_col, mae_col, maer_col in zip(out1_index, zip(*predicted_big_phi_list), zip(*correct_big_phi_list), zip(*mse_list), zip(*mae_list), zip(*maer_list)):
        data_dict[f'Estimate{name}'] = pred_col
        data_dict[f'True{name}'] = corr_col
        data_dict[f'MSE{name}']  = mse_col
        data_dict[f'MAE{name}']  = mae_col
        data_dict[f'MAER{name}'] = maer_col

    data_dict.update({
        'True_label': true_labels_list,
        'Est_label': predicted_labels_list,
        'Node_num': node_num_list,
        'Bit_rate': bit_rate_list
        })

    result_df = pd.DataFrame(data_dict)

    result_df['Exact_match'] = (result_df['True_label'] == result_df['Est_label']).astype(int)

    # Same comparison after mapping label 2 onto label 1.
    result_df['True_label(replaced)'] = result_df['True_label'].apply(replace_2_with_1)
    result_df['Est_label(replaced)'] = result_df['Est_label'].apply(replace_2_with_1)

    result_df['Bit_rate(replaced)'] = result_df.apply( \
        lambda row: calculate_bit_accuracy(row['True_label(replaced)'], row['Est_label(replaced)']), axis=1)

    result_df['Exact_match(replaced)'] = (result_df['True_label(replaced)'] == result_df['Est_label(replaced)']).astype(int)

    result_df['Topology'] = path_list

    if dataset_type == 0:
        result_df['fr'] = fr_list
        result_df['state_num'] = state_list
    else:
        result_df['data_num'] = state_list

    result_df['T'] = T_list
    result_df['Reverse'] = reverse_list

    ###### Output for display (node-count-weighted averages of the per-graph accuracies)
    total_nodes = result_df['Node_num'].sum()
    average_rate_normal = (result_df['Node_num']*result_df['Bit_rate']).sum() / total_nodes
    average_rate_replace = (result_df['Node_num']*result_df['Bit_rate(replaced)']).sum() / total_nodes

    print('Bit_rate(Normal/Replaced):', average_rate_normal, '/', average_rate_replace )
    print('Exact_match_rate(Normal/Replaced):', result_df['Exact_match'].mean(), '/', result_df['Exact_match(replaced)'].mean() )

    print('Degeneracy Count', degeneracy_count, 'Degeneracy Exact_match Count', degeneracy_exact_count)

    ###### File output
    if output_flag == 1:
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        file_name = save_folder+"final_"+prefix+"_"+str(degeneracy_count)+"_"+str(degeneracy_exact_count)+"_"+"result_"+time_str+".xlsx"
        result_df.to_excel(file_name)

    return result_df


###################### Function definition ##################################
def replace_2_with_1(lst):
    """Return a new list in which every element equal to 2 is replaced by 1."""
    replaced = []
    for value in lst:
        replaced.append(1 if value == 2 else value)
    return replaced

def calculate_bit_accuracy(true_list, predicted_list):
    """Return the fraction of positions at which the two label lists agree.

    The denominator is `len(true_list)`. For an empty `true_list` the accuracy
    is undefined and NaN is returned instead of raising ZeroDivisionError.
    """
    total = len(true_list)
    if total == 0:
        return float('nan')

    correct = sum(1 for a, b in zip(true_list, predicted_list) if a == b)
    return correct / total

In [None]:
def Prelearned_model_utilization():
    """Evaluate every stored model in `model_folder` on the current test dataset.

    For each file matching `model_param_*.pth` the function
      1. extracts the training timestamp from the file name,
      2. loads the parameters into the global `model`,
      3. normalises a fresh copy of the test dataset with the mean/std saved at
         training time (`train_mean_std_<timestamp>.csv`),
      4. writes the detailed per-graph results through `Detail_record`, and
      5. appends one summary row to `<save_folder><summary_result_filename>.csv`.

    Columns describing the learning phase are filled with the dummy value -9999
    because no training is executed here.

    Side effects: rebinds the globals `model`, `test_dataset`, `file_paths` and
    `initial_output_flag`.
    """
    import glob
    import re
    import copy
    from torch_geometric.loader import DataLoader

    global initial_output_flag, model, test_dataset, file_paths

    # Never mutated below: every model works on its own deep copy, so the
    # normalisation of one model cannot leak into the next one.
    pristine_test_dataset = test_dataset

    timestamp_pattern = r"\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}"

    file_paths = glob.glob(os.path.join(model_folder, 'model_param_*.pth'))

    if not file_paths:
        print('No prelearned model files found in', model_folder)

    for file_path in file_paths:
        train_time = re.search(timestamp_pattern, file_path)

        if train_time:
            print("Extracted datetime:", train_time.group())
        else:
            # Without the timestamp the matching mean/std file cannot be located.
            print("No match found.")
            continue

        train_time_str = train_time.group()

        # map_location lets parameters saved on another device load on the current one.
        model.load_state_dict(torch.load(file_path, map_location=device))

        model = model.to(device)

        # Same path construction as for the model files above.
        tmp_data = pd.read_csv(os.path.join(model_folder, 'train_mean_std_' + train_time_str + '.csv'), header=None)

        mean = tmp_data[0].values
        std = tmp_data[1].values
        print('Mean and Std in the learnig process', mean[1:], std[1:])

        test_dataset = copy.deepcopy(pristine_test_dataset)
        data_list_local = list(test_dataset)

        for data in data_list_local:
            data.x = (data.x - mean) / std
            data.x = data.x.float()

            # Drop the first feature column when the node-state feature is switched off.
            if feature_s_on == 0:
                data.x = data.x[:, 1:]

        test_dataset = BigPhis(data_list_local)

        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

        ########### Detailed results output
        current_time = datetime.datetime.now()
        time_str = current_time.strftime("%Y-%m-%d_%H-%M-%S")

        print('\nTest dataset')
        result_test_df = Detail_record(test_loader, 1, 'test', train_time_str + '_TO_' + time_str + file_postfix)

        ######## Summary of results output
        output_names = [out1_index[val] for val in range(out1_num)]

        # l: learning. As learning process is not executed, we set -9999 as dummy
        msel = [-9999] * out1_num
        mael = [-9999] * out1_num
        maerl = [-9999] * out1_num
        corl = [-9999] * out1_num

        ba1l = -9999
        em1l = -9999
        ba2l = -9999
        em2l = -9999

        # t: test
        mset = [result_test_df['MSE' + name].mean() for name in output_names]
        maet = [result_test_df['MAE' + name].mean() for name in output_names]
        maert = [result_test_df['MAER' + name].mean() for name in output_names]
        cort = [result_test_df[['Estimate' + name, 'True' + name]].corr().iloc[0, 1] for name in output_names]

        ba1t = result_test_df['Bit_rate'].mean()
        em1t = result_test_df['Exact_match'].mean()
        ba2t = result_test_df['Bit_rate(replaced)'].mean()
        em2t = result_test_df['Exact_match(replaced)'].mean()

        mseL_label = ['mseL_' + name for name in output_names]
        maeL_label = ['maeL_' + name for name in output_names]
        maerL_label = ['maerL_' + name for name in output_names]
        corL_label = ['corrL_' + name for name in output_names]
        mseT_label = ['mseT_' + name for name in output_names]
        maeT_label = ['maeT_' + name for name in output_names]
        maerT_label = ['maerT_' + name for name in output_names]
        corT_label = ['corrT_' + name for name in output_names]

        # Dummies for quantities that only exist when a training run was executed.
        epoch = -9999
        val_loss = torch.tensor(-9999)
        data_list = []
        train_dataset = []
        validation_dataset = []

        ###############################
        file_name = save_folder + summary_result_filename + '.csv'
        headers = ['time', 'epoch', *mseL_label, *maeL_label, *maerL_label, *corL_label,
                   'Bit_rateRowL', 'Exact_matchRowL', 'Bit_rateRepL', 'Exact_matchRepL',
                   *mseT_label, *maeT_label, *maerT_label, *corT_label,
                   'Bit_rateRowT', 'Exact_matchRowT', 'Bit_rateRepT', 'Exact_matchRepT', 'val_loss',
                   'dataset_type', 'f_s_on',
                   'f_deg_on', 'f_prob_on', 'class_num', 'input_norm', 'inverse_T',
                   'closeness', 'betweenness', 'clustering', 'optimizer', 'lr', 'wei_decay',
                   'out1_num', 'out1_weight', 'class_weight', 'patience', 'loss2_coef', 'batch',
                   'test_shitei', 'random_num', 'learning_data_ratio', 'val_data_ratio',
                   'augmentation', 'aug_rate', 'over_sample', 'auto_bin', 'os_bin', 'os_fluc',
                   'time', 'convolution', 'other', 'model', 'drop_p', 'orig_data_num', 'trainset_num', 'valset_num', 'testset_num']

        data_to_save = [time_str, epoch, *msel, *mael, *maerl, *corl, ba1l, em1l, ba2l, em2l,
                        *mset, *maet, *maert, *cort, ba1t, em1t, ba2t, em2t, val_loss.item(),
                        dataset_type, feature_s_on, feature_degree_on, feature_prob_on,
                        class_num, input_normalization, inverse_T, closeness, betweenness,
                        clustering,
                        optimizer_select, learning_rate, weight_decay_val, out1_num, str(out1_weight),
                        str(weights), patience_val, loss2_coef,
                        learning_batch_size,
                        str(test_data_shitei), str(random_data_num), learning_data_ratio, validation_data_ratio,
                        adding_split_brain_data, add_split_brain_data_rate, over_sampling, auto_bins,
                        str(np.round(bins, 3)), os_fluctuation,
                        time_str, convolution_type, other_network_type, str(model).replace('\n', '; '), drop_rate,
                        len(data_list), len(train_dataset), len(validation_dataset), len(test_dataset)]

        try:
            # A new file always starts with the header row. An existing file gets one
            # only for the first row written since `initial_output_flag` was set to 1.
            file_is_new = not os.path.exists(file_name)

            with open(file_name, 'w' if file_is_new else 'a', newline='', encoding='shift-jis') as f:
                writer = csv.writer(f)

                if file_is_new or initial_output_flag == 1:
                    writer.writerow(headers)
                    initial_output_flag = 0

                writer.writerow(data_to_save)

        except IOError:
            # The summary file could not be written: fall back to a separate,
            # timestamped file so the row is not lost.
            new_file_name = f'{save_folder+"TMP_"+summary_result_filename+"_"+time_str+".csv"}'
            with open(new_file_name, 'w', newline='', encoding='shift-jis') as f:
                writer = csv.writer(f)
                writer.writerow(headers)
                writer.writerow(data_to_save)

In [None]:
################### Main Cell
# Output widths of the four convolution layers (passed to GCN_reg as hidden_channels1..4)
hc1 = 50
hc2 = 150
hc3 = 150
hc4 = 50

# Per-epoch evaluation of the training and test sets
#   1: OFF (time-saving; recommended) -- placeholder values of -10 are printed instead
#   0: ON  (time-consuming)           -- Test() is run on both loaders every epoch
time_save = 1

# folder name where output files are generated
save_folder = 'result/'

# name for summary of results (CSV files are output)
summary_result_filename = 'result_summary'


# flag utilized internally, do not change it
# (1 until the first summary row has been written; controls whether the header row is emitted)
initial_output_flag = 1


# NOTE: no `global` declarations are needed here. Every name assigned at the top
# level of a cell is already a module-level global, and Python rejects a `global`
# statement for a name that was assigned earlier in the same code block
# (save_folder, summary_result_filename and initial_output_flag above).





Initialization_parameters_node_edge()
Initialization_parameters_optimization()
Initialization_parameters_GNN()
Initialization_parameters_dataset()
                
            
##############################################################################
#### Uncomment the one section you want to run and comment out the other. ####
##############################################################################

### [1] For comparison experiments (the first and the second experiments in the paper)  
'''  
test_data_bigN = 0; prelearned_model = 0

# specify random_connection_graph dataset   
# CHOOSE ONE  ==>  '1000_567':non_extrapolative setting  or '1500':extrapolative setting
#### random_data_num = '1000_567'; test_data_shitei = 0 
#### random_data_num = '1500'; test_data_shitei = [3] 

for inverse_T, feature_prob_on, feature_degree_on, closeness, betweenness, clustering, loss2_coef, convolution_type, optimizer_select, other_network_type in [
            (0, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'Lion', 0),
            (-1, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'Lion', 0), 
            (0, 0, 1, 1, 1, 1, 5.0, 'Transformer', 'Lion', 0),     
            (0, 2, 0, 1, 1, 1, 5.0, 'Transformer', 'Lion', 0),
            (0, 2, 1, 0, 1, 1, 5.0, 'Transformer', 'Lion', 0), 
            (0, 2, 1, 1, 0, 1, 5.0, 'Transformer', 'Lion', 0),
            (0, 2, 1, 1, 1, 0, 5.0, 'Transformer', 'Lion', 0),
            (0, 2, 1, 1, 1, 1, 0.0, 'Transformer', 'Lion', 0),
            (0, 2, 1, 1, 1, 1, 100000.0, 'Transformer', 'Lion', 0),
            (0, 2, 1, 1, 1, 1, 5.0, 'GAT', 'Lion', 0), 
            (0, 2, 1, 1, 1, 1, 5.0, 'GraphConv', 'Lion', 0), 
            (0, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'Adam', 0),
            (0, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'RAdam', 0),
            (0, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'AdamW', 0), 
            (0, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'Lion', 1),
            (0, 2, 1, 1, 1, 1, 5.0, 'Transformer', 'Lion', 2) 
            ]:
    for iteration in range(100):
'''



### [2] For experiments with N=100 systems
'''
test_data_bigN = 2; prelearned_model = 1; total_data_num = 100 

for p_e in [0.0000, 0.0004, 0.002, 0.004, 0.006, 0.008, 0.01, 0.012, 0.014, 0.016, 0.018, 0.02, 0.04, 0.06, 0.08, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]:
                                pickle_filename = 'random_graph_data_N=100/bigN_data_N=100_type9_p=%.4lf.pkl' %(p_e)
                                file_postfix = '_N=100_type9_p=%.4lf' %(p_e) 
                                
                                print(pickle_filename)
'''                            
                            



### [3] For experiments on scaling behavior with three topologies (N=10-100)
'''
test_data_bigN = 2; prelearned_model = 1; total_data_num = 30 

for N in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
                                pickle_filename = 'three_topologies_data_N=10-100/bigN_data_N=%d.pkl' %(N)
                                file_postfix = '_N=%d' %(N) 

                                bigN = N
        
                                print(pickle_filename)
'''


                                clear_output(True) 
                                
                                Set_path_candidate()


            
            
                                if prelearned_model == 0 or (prelearned_model==1 and test_data_bigN == 0):
                                    data_list = []

                                    if dataset_type == 1:
                                        Data_generation_dataset1()
                                    else:
                                        print('dataset_type is wrong!'); sys.exit()
                                        
                                    Dataset_generation()
                                    Split_dataset()
                                
                                if test_data_bigN == 2:
                                    data_list_bigN = []
                                    Data_generation_bigN()
                                    Dataset_generation_bigN()
                                
                                if prelearned_model == 0 or (prelearned_model==1 and test_data_bigN == 0):
                                    Input_normalization_new()
                                    Generation_dataloader()





                                ######################################################################
                                ######################################################################
                                model = GCN_reg(hidden_channels1=hc1, hidden_channels2=hc2, hidden_channels3=hc3, hidden_channels4=hc4, model_type=convolution_type)  
                                
                                if optimizer_select == 'Adam':
                                    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay_val)
                                elif optimizer_select == 'Lion':    
                                    optimizer = Lion(model.parameters(), lr=learning_rate, weight_decay=weight_decay_val)
                                elif optimizer_select == 'RAdam':
                                    optimizer = torch.optim.RAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay_val)
                                elif optimizer_select == 'AdamW':
                                    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_val)
                                else:
                                    print('Optimizer select is wrong!'); sys.exit()


                                out1_weight_tensor = torch.tensor( out1_weight )
                                out1_weight_tensor = out1_weight_tensor.to(device)   
                                
                                class_weights = torch.FloatTensor(weights).cuda()
                                criterion2 = torch.nn.CrossEntropyLoss(weight=class_weights)

                                early_stopping = EarlyStopping(patience=patience_val)   
                                
                                if prelearned_model == 1:  
                                    Prelearned_model_utilization()     
                                    continue
                                
                                model = model.to(device)   

        
                                ##########################################################
                                ##################### Epoch loop ########################
                                #########################################################
                                loss_history = []

                                for epoch in range(1, 5000):
                                    train_loss, train_loss1, train_loss2 = Train()
                                    
                                    
                                    if time_save == 1:   
                                        train_error1, train_error2, train_exact_match, train_bit_accuracy = [-10], [-10], -10, -10
                                        test_error1, test_error2, test_exact_match, test_bit_accuracy = [-10], [-10], -10, -10
                                    else:
                                        train_error1, train_error2, train_exact_match, train_bit_accuracy = Test(train_loader)
                                        test_error1, test_error2, test_exact_match, test_bit_accuracy = Test(test_loader)
                                    
                                    val_loss, val_loss1, val_loss2 = Validation()
                                    
                                    train_error1_print = ", ".join(f"{error:.3f}" for error in train_error1)
                                    train_error2_print = ", ".join(f"{error:.3f}" for error in train_error2)
                                    test_error1_print = ", ".join(f"{error:.3f}" for error in test_error1)
                                    test_error2_print = ", ".join(f"{error:.3f}" for error in test_error2)
                                    
                                    print(f'Epoch: {epoch:03d}, [Train] mse: {train_error1_print}, mae: {train_error2_print}, em: {train_exact_match:.3f}, ba: {train_bit_accuracy:.3f}' )
                                    print(f'[Test] MSE: {test_error1_print}, MAE: {test_error2_print}, EM: {test_exact_match:.3f}, BA: {test_bit_accuracy:.3f}')
                                    if time_save == 0:
                                        print('NOTE: Before the early stop, exact-match (em) is correctly obtained only if mini-batch size is set as 1.')
                                    print(f'[Val] loss: {val_loss} \n')

                                    loss_history.append( 
                                        [train_loss.item(), train_loss1.item(), train_loss2.item(), val_loss.item(), val_loss1.item(), val_loss2.item()] )

                                    early_stopping(val_loss, model)

                                    if early_stopping.early_stop:
                                        print('Early Stop!!')
                                        break


                                ######## After early stopping, load the optimal model
                                # NOTE(review): 'checkpoint.pt' is hard-coded here; presumably it is the file the
                                # EarlyStopping helper writes -- confirm it matches that helper's save path.
                                model.load_state_dict(torch.load('checkpoint.pt'))

                                # Re-evaluate all three splits with the restored weights.
                                train_error1, train_error2, train_exact_match, train_bit_accuracy = Test(train_loader)
                                test_error1, test_error2, test_exact_match, test_bit_accuracy = Test(test_loader)
                                val_loss, val_loss1, val_loss2 = Validation()

                                # Three-decimal rendering of each per-output error for the report lines below.
                                train_error1_print = ", ".join(format(error, ".3f") for error in train_error1)
                                train_error2_print = ", ".join(format(error, ".3f") for error in train_error2)
                                test_error1_print = ", ".join(format(error, ".3f") for error in test_error1)
                                test_error2_print = ", ".join(format(error, ".3f") for error in test_error2)

                                print('\n')
                                print(f'Optimal, [Train] mse: {train_error1_print}, mae: {train_error2_print}, em: {train_exact_match:.3f}, ba: {train_bit_accuracy:.3f}') 
                                print(f'Optimal, [Test] MSE: {test_error1_print}, MAE: {test_error2_print}, EM: {test_exact_match:.3f}, BA: {test_bit_accuracy:.3f}')
                                print(f'[Val] loss: {val_loss} \n')

                                
                                ########### Detailed results output

                                # exist_ok=True makes the directory creation idempotent and removes the
                                # check-then-create race of a separate os.path.exists() test.
                                os.makedirs(save_folder, exist_ok=True)

                                # A single timestamp tags every file written for this run.
                                current_time = datetime.datetime.now()
                                time_str = current_time.strftime("%Y-%m-%d_%H-%M-%S")

                                # Per-sample records for the training and test loaders; the returned
                                # DataFrames are aggregated in the summary section that follows.
                                print('Learning Dataset')
                                result_learning_df = Detail_record(train_loader, 1, 'train', time_str)

                                print('\nTest Dataset')
                                result_test_df = Detail_record(test_loader, 1, 'test', time_str)

                                
                                ########### Model output
                                # Three artifacts share the run timestamp: the state_dict, the whole pickled
                                # model object, and a header-less CSV with the mean_x / std_x values.
                                file_name_parameter = f'{save_folder}model_param_{time_str}.pth'
                                file_name_whole = f'{save_folder}model_whole_{time_str}.pth'
                                file_name_mean_std = f'{save_folder}train_mean_std_{time_str}.csv'

                                torch.save(model.state_dict(), file_name_parameter)
                                torch.save(model, file_name_whole)

                                # NOTE(review): mean_x / std_x are presumably the input-normalization statistics
                                # of the training set (per the file name) -- confirm where they are computed.
                                pd.DataFrame({'mean_x': mean_x, 'std_x': std_x}).to_csv(
                                    file_name_mean_std, index=False, header=False)
                                
                                
                                ######## Summary of results output
                                # Suffix convention: l = learning (training) records, t = test records.
                                # mse / mae / maer (mean absolute error rate) / cor (correlation between the
                                # 'Estimate...' and 'True...' columns), one entry per scalar output.
                                msel, mael, maerl, corl = [], [], [], []
                                mset, maet, maert, cort = [], [], [], []

                                # One pass per scalar output fills the learning and the test lists together.
                                for val in range(out1_num):
                                    msel.append(result_learning_df[f'MSE{out1_index[val]}'].mean())
                                    mset.append(result_test_df[f'MSE{out1_index[val]}'].mean())

                                    mael.append(result_learning_df[f'MAE{out1_index[val]}'].mean())
                                    maet.append(result_test_df[f'MAE{out1_index[val]}'].mean())

                                    maerl.append(result_learning_df[f'MAER{out1_index[val]}'].mean())
                                    maert.append(result_test_df[f'MAER{out1_index[val]}'].mean())

                                    # Off-diagonal element of the 2x2 correlation matrix of estimate vs. truth.
                                    corl.append(
                                        result_learning_df[[f'Estimate{out1_index[val]}', f'True{out1_index[val]}']].corr().iloc[0, 1])
                                    cort.append(
                                        result_test_df[[f'Estimate{out1_index[val]}', f'True{out1_index[val]}']].corr().iloc[0, 1])

                                # Column means of the bit-rate / exact-match records, for the plain columns
                                # (1) and the '(replaced)' columns (2).
                                ba1l = result_learning_df['Bit_rate'].mean()
                                em1l = result_learning_df['Exact_match'].mean()
                                ba2l = result_learning_df['Bit_rate(replaced)'].mean()
                                em2l = result_learning_df['Exact_match(replaced)'].mean()

                                ba1t = result_test_df['Bit_rate'].mean()
                                em1t = result_test_df['Exact_match'].mean()
                                ba2t = result_test_df['Bit_rate(replaced)'].mean()
                                em2t = result_test_df['Exact_match(replaced)'].mean()

                                # Header labels for the per-output statistics (L: learning, T: test).
                                mseL_label = ['mseL_' + out1_index[val] for val in range(out1_num)]
                                maeL_label = ['maeL_' + out1_index[val] for val in range(out1_num)]
                                maerL_label = ['maerL_' + out1_index[val] for val in range(out1_num)]
                                corL_label = ['corrL_' + out1_index[val] for val in range(out1_num)]
                                mseT_label = ['mseT_' + out1_index[val] for val in range(out1_num)]
                                maeT_label = ['maeT_' + out1_index[val] for val in range(out1_num)]
                                maerT_label = ['maerT_' + out1_index[val] for val in range(out1_num)]
                                corT_label = ['corrT_' + out1_index[val] for val in range(out1_num)]
                                    
                                                

                                # Append one row describing this run (metrics + hyper-parameters) to the
                                # shared summary CSV.
                                file_name = save_folder + summary_result_filename + '.csv' 

                                # Column names; their order must stay in step with data_to_save below.
                                # NOTE(review): 'time' appears twice (first column and again before
                                # 'convolution'); it is left unchanged so new rows stay aligned with
                                # summary files written by earlier runs.
                                headers = ['time', 'epoch', *mseL_label, *maeL_label, *maerL_label, *corL_label, 
                                           'Bit_rateRowL', 'Exact_matchRowL', 'Bit_rateRepL', 'Exact_matchRepL', 
                                           *mseT_label, *maeT_label, *maerT_label, *corT_label, 
                                           'Bit_rateRowT', 'Exact_matchRowT', 'Bit_rateRepT', 'Exact_matchRepT', 'val_loss', 
                                           'dataset_type', 'f_s_on',
                                           'f_deg_on', 'f_prob_on', 'class_num', 'input_norm', 'inverse_T', 
                                           'closeness', 'betweenness', 'clustering', 'optimizer', 'lr', 'wei_decay', 
                                           'out1_num', 'out1_weight', 'class_weight', 'patience', 'loss2_coef', 'batch', 
                                           'test_shitei', 'random_num', 'learning_data_ratio', 'val_data_ratio', 
                                           'augmentation', 'aug_rate', 'over_sample', 'auto_bin', 'os_bin', 'os_fluc', 
                                           'time', 'convolution', 'other', 'model', 'drop_p', 'orig_data_num', 'trainset_num', 'valset_num', 'testset_num']

                                # Values in header order. List-like settings are stringified so each one
                                # occupies a single CSV cell; the model repr is flattened onto one line.
                                data_to_save = [time_str, epoch, *msel, *mael, *maerl, *corl, ba1l, em1l, ba2l, em2l, 
                                                *mset, *maet, *maert, *cort, ba1t, em1t, ba2t, em2t, val_loss.item(), 
                                                dataset_type, feature_s_on, feature_degree_on, feature_prob_on, 
                                                class_num, input_normalization, inverse_T, closeness, betweenness, clustering, 
                                                optimizer_select, learning_rate, weight_decay_val, out1_num, str(out1_weight), 
                                                str(weights), patience_val, loss2_coef, 
                                                learning_batch_size,
                                                str(test_data_shitei), str(random_data_num), learning_data_ratio, validation_data_ratio, 
                                                adding_split_brain_data, add_split_brain_data_rate, over_sampling, auto_bins,
                                                str(np.round(bins, 3)), os_fluctuation,
                                                time_str, convolution_type, other_network_type, str(model).replace('\n', '; '), drop_rate,
                                                len(data_list), len(train_dataset), len(validation_dataset), len(test_dataset)]


                                try:
                                    if not os.path.exists(file_name):
                                        # First write: create the file with the header row followed by the data row.
                                        with open(file_name, 'w', newline='', encoding='utf-8') as f:   
                                            writer = csv.writer(f)
                                            writer.writerow(headers)
                                            writer.writerow(data_to_save)
                                            initial_output_flag = 0
                                    else:
                                        with open(file_name, 'a', newline='', encoding='utf-8') as f:
                                            writer = csv.writer(f)

                                            # While initial_output_flag is still 1 the header row is repeated
                                            # once inside the existing file, then the flag is cleared.
                                            # (Presumably the flag is set to 1 at the start of a run -- confirm.)
                                            if initial_output_flag == 1:   
                                                writer.writerow(headers)
                                                initial_output_flag = 0

                                            writer.writerow(data_to_save)

                                except IOError:
                                    # The summary file could not be opened or written (e.g. locked by another
                                    # program): keep the result in a timestamped side file instead.
                                    # The side file now uses UTF-8, like the main file; the former 'shift-jis'
                                    # could raise UnicodeEncodeError here and lose the row.
                                    new_file_name = save_folder + 'TMP_' + summary_result_filename + '_' + time_str + '.csv'
                                    with open(new_file_name, 'w', newline='', encoding='utf-8') as f:
                                        writer = csv.writer(f)
                                        writer.writerow(headers)
                                        writer.writerow(data_to_save)



                                ######## output of loss_history
                                # One spreadsheet row per entry of loss_history; the column order mirrors the
                                # order in which the six losses are appended during training.
                                file_name = f'{save_folder}history_{time_str}.xlsx'

                                loss_history_df = pd.DataFrame(
                                    data=loss_history,
                                    columns=[
                                        'train_total_loss', 'train_loss1', 'train_loss2',
                                        'val_total_loss', 'val_loss1', 'val_loss2',
                                    ],
                                )

                                loss_history_df.to_excel(file_name)