In [8]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle as pkl
import wandb

import models as m
import functions as f
from functions import dict_to_array, normalize_array

import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_networkx

import networkx as nx
from networkx.convert_matrix import from_numpy_array

from sklearn.model_selection import ParameterGrid

In [15]:
# Generating the fake dataset. 
# We decided that patients without Alzheimer's (class 0) would have a stronger connectivity of their brain regions in the top left corner of the correlation matrix. 
# This is a simplification of the real data, but it will allow us to test the model without the complications from the data.

def generate_correlation_matrix(dimensions, stronger_top_left=True):
    """Build one random, symmetric, binary connectivity matrix.

    Args:
        dimensions: Number of rows/columns of the square matrix.
        stronger_top_left: When truthy, the top-left quadrant
            (``dimensions // 2`` rows/columns) receives additional
            connections. ``generate_dataset`` enables this for class 0 only.

    Returns:
        A ``(dimensions, dimensions)`` float array holding only 0s and 1s,
        symmetric, with ones on the diagonal.
    """
    upper = np.zeros((dimensions, dimensions))

    # Background connectivity: every cell of the upper triangle (diagonal
    # included) is switched on with probability 0.25.
    for r in range(dimensions):
        for c in range(r, dimensions):
            if random.random() < 0.25:
                upper[r, c] = 1

    # Class signature: the upper triangle of the top-left quadrant gets a
    # second, 75% chance of being switched on.
    if stronger_top_left:
        quadrant = dimensions // 2
        for r in range(quadrant):
            for c in range(r, quadrant):
                if random.random() < 0.75:
                    upper[r, c] = 1

    # Mirror the upper triangle onto the lower one, then force the diagonal to 1.
    symmetric = np.maximum(upper, upper.T)
    np.fill_diagonal(symmetric, 1)

    return symmetric

def generate_dataset(num_samples, dimensions, stronger_top_left=True, class_ratio=0.5):
    """Generate a shuffled two-class dataset of fake connectivity matrices.

    Args:
        num_samples: Total number of matrices to generate (>= 0).
        dimensions: Size of each square matrix.
        stronger_top_left: Forwarded to ``generate_correlation_matrix`` for the
            class-0 samples; class-1 samples are always generated without the
            stronger top-left corner.
        class_ratio: Fraction of the samples assigned to class 0, in [0, 1].

    Returns:
        ``(corr_matrices, labels)``: two tuples of equal length, shuffled
        together. Both are empty tuples when ``num_samples`` is 0.

    Raises:
        ValueError: If ``num_samples`` is negative or ``class_ratio`` is
            outside [0, 1].
    """
    if num_samples < 0:
        raise ValueError(f'num_samples must be non-negative, got {num_samples}')
    if not 0 <= class_ratio <= 1:
        raise ValueError(f'class_ratio must be within [0, 1], got {class_ratio}')

    num_class0 = int(num_samples * class_ratio)
    num_class1 = num_samples - num_class0

    class0_matrices = [generate_correlation_matrix(dimensions, stronger_top_left) for _ in range(num_class0)]
    class1_matrices = [generate_correlation_matrix(dimensions, stronger_top_left=False) for _ in range(num_class1)]

    labels = [0] * num_class0 + [1] * num_class1

    combined = list(zip(class0_matrices + class1_matrices, labels))
    if not combined:
        # zip(*[]) yields nothing to unpack, so return the empty result explicitly.
        return (), ()

    # Shuffle matrices and labels together so the pairs stay aligned
    random.shuffle(combined)
    corr_matrices, labels = zip(*combined)

    return corr_matrices, labels

In [16]:
# Defining the properties of our dataset
SEED = 42  # fixed so that "Restart & Run All" regenerates the same fake cohort
num_samples = 1000
dimensions = 116
class_ratio = 0.5  # Fraction of samples assigned to class 0 (stronger top-left corner); the rest is class 1
random.seed(SEED)  # the generators draw from (and shuffle with) the `random` module
corr_matrices, labels = generate_dataset(num_samples, dimensions, stronger_top_left=True, class_ratio=class_ratio)
# NOTE(review): graph datasets already cached on disk were presumably built from an earlier,
# unseeded draw; delete their folders if they must match these matrices.

# corr_matrices contains the generated correlation matrices
# labels contains the corresponding class labels

# Preview the first few samples (fewer if the dataset is smaller than 5)
for i in range(min(5, len(corr_matrices))):
    print(corr_matrices[i])
    print(labels[i])
    print()

[[1. 0. 1. ... 1. 0. 0.]
 [0. 1. 1. ... 0. 0. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 1.]
 [0. 1. 0. ... 0. 1. 1.]]
0

[[1. 1. 1. ... 0. 1. 0.]
 [1. 1. 1. ... 0. 1. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 1. 0.]
 [1. 1. 0. ... 1. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]]
0

[[1. 0. 0. ... 1. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [1. 0. 0. ... 1. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 1.]]
1

[[1. 0. 0. ... 0. 0. 0.]
 [0. 1. 1. ... 1. 0. 0.]
 [0. 1. 1. ... 1. 1. 0.]
 ...
 [0. 1. 1. ... 1. 0. 0.]
 [0. 0. 1. ... 0. 1. 1.]
 [0. 0. 0. ... 0. 1. 1.]]
0

[[1. 1. 0. ... 0. 1. 0.]
 [1. 1. 1. ... 1. 0. 0.]
 [0. 1. 1. ... 0. 0. 1.]
 ...
 [0. 1. 0. ... 1. 0. 1.]
 [1. 0. 0. ... 0. 1. 0.]
 [0. 0. 1. ... 1. 0. 1.]]
0



# Running different models on this data

In [46]:
# Training function
stratify = True
def train_fake(model, filename, method_wandb, optimizer, criterion, w_decay, parameters, train_loader, valid_loader, test_loader=False, testing=False, n_epochs=80):
    """Train `model`, log every epoch to Weights & Biases and save the learning curves.

    The per-epoch work is delegated to ``f.epochs_training``, which receives the
    running histories and returns them updated (its exact behaviour lives in
    the ``functions`` module).

    Args:
        model, optimizer, criterion: Forwarded unchanged to ``f.epochs_training``.
        filename: Path of the image file the loss/accuracy figure is saved to.
            Missing parent folders are created.
        method_wandb: Logged as "model type" in the wandb run config.
        w_decay: Logged as "weight_decay" (the optimizer is configured by the caller).
        parameters: Sequence whose items 0-3 are logged as learning rate,
            hidden channels, number of layers and dropout.
        train_loader, valid_loader, test_loader: Forwarded to ``f.epochs_training``.
        testing: When truthy, ``f.epochs_training`` is expected to also return a
            test accuracy, which is logged and returned.
        n_epochs: Number of epochs to run.

    Returns:
        ``(train_losses, train_accuracies, valid_losses, valid_accuracies,
        max_valid_accuracy)``, followed by ``test_accuracy`` when `testing` is truthy.

    Note:
        The module-level ``stratify`` flag is read for the logged config only.
    """
    train_losses = []
    train_accuracies = []
    valid_losses = []
    valid_accuracies = []
    max_valid_accuracy = 0
    test_accuracy = 0

    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project = "Fake_Alzheimers",
        # track hyperparameters and run metadata
        config = {
        "model type": method_wandb,
        "strat + w loss": stratify,
        "weight_decay": w_decay,
        "learning_rate": parameters[0],
        "hidden_channels": parameters[1],
        "num_layers": parameters[2],
        "dropout": parameters[3],
        "epochs": n_epochs},)

    try:
        for epoch in range(n_epochs):
            epoch_results = f.epochs_training(model, optimizer, criterion, train_loader, valid_loader, test_loader, testing, test_accuracy, train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy)
            # The helper returns one extra value (the test accuracy) in testing mode.
            if testing:
                train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy, test_accuracy = epoch_results
            else:
                train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy = epoch_results

            epoch_metrics = {"Train Loss": train_losses[-1], "Train Accuracy": train_accuracies[-1], "Validation Loss": valid_losses[-1], "Validation Accuracy": valid_accuracies[-1], "Max Valid Accuracy": max_valid_accuracy}
            if testing:
                epoch_metrics["Test Accuracy"] = test_accuracy
            wandb.log(epoch_metrics)

            print(f'Epoch {epoch+1}/{n_epochs}')
            print(f'Train Loss: {train_losses[-1]:.4f}, Validation Loss: {valid_losses[-1]:.4f}')
            print(f'Train Accuracy: {train_accuracies[-1]:.4f}, Validation Accuracy: {valid_accuracies[-1]:.4f}')
            print(f'Max Validation Accuracy: {max_valid_accuracy:.4f}')

        if testing:
            print('Test Accuracy:', test_accuracy)

        fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 5))

        # Plot Losses
        loss_ax.plot(train_losses, label='Train Loss')
        loss_ax.plot(valid_losses, label='Validation Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()

        # Plot Accuracy
        accuracy_ax.plot(train_accuracies, label='Train Accuracy')
        accuracy_ax.plot(valid_accuracies, label='Validation Accuracy')
        accuracy_ax.set_xlabel('Epoch')
        accuracy_ax.set_ylabel('Accuracy')
        accuracy_ax.legend()

        # Save the plot; create the destination folder first so savefig cannot fail on a missing path
        output_dir = os.path.dirname(filename)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(filename)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across a hyperparameter search
    except BaseException:
        # Close the run as failed so that an error or interrupt does not leave it dangling.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()

    if testing:
        return train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy, test_accuracy
    else:
        return train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy

## Graph Neural Networks

In [22]:
# Defining a class to preprocess the fake connectivity matrices into a format suitable for training Graph Neural Networks (GNNs).
## Each matrix becomes one graph whose nodes carry five normalized graph-theoretic features.

class Fake2C_Raw_to_graph(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from binary connectivity matrices.

    Every matrix is turned into a NetworkX graph (one node per row/column, one
    edge per non-zero entry). Each node gets five features, in this order:
    degree, betweenness centrality, local efficiency, clustering coefficient
    and the ratio of local to global efficiency, each passed through
    ``normalize_array``.

    Args:
        root: Folder where the processed dataset (``data.pt``) is stored.
            NOTE: if a processed file is already present there, it is loaded
            as is and `corr_matrices` / `labels` are not processed again.
        corr_matrices: Sequence of square connectivity matrices.
        labels: Sequence of class labels aligned with `corr_matrices`.
        transform, pre_transform: Forwarded to ``InMemoryDataset``.
    """
    def __init__(self, root, corr_matrices, labels, transform=None, pre_transform=None):
        # These must be set before super().__init__(), which runs process() when no processed file exists.
        self.corr_matrices = corr_matrices
        self.labels = labels
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the file, inside the processed folder, that holds the collated graphs."""
        return ['data.pt']

    # This function is used to process the raw data into a format suitable for GNNs, by constructing graphs out of the connectivity matrices.
    def process(self):
        """Convert every matrix given to the constructor into a graph and save the collated dataset."""
        graphs = []
        # Use the matrices and labels given to the constructor, not same-named notebook globals.
        for patient_idx, patient_matrix in enumerate(self.corr_matrices):
            # Float copy of the connectivity matrix, used as the adjacency matrix
            # (ROIs stands for Regions of Interest: one node per row/column)
            edge_matrix = np.array(patient_matrix, dtype=float)

            # Create a NetworkX graph from the edge matrix
            NetworkX_graph = from_numpy_array(edge_matrix)

            # Compute the degree, betweenness centrality, clustering coefficient, local efficiency for each node of the graph and the global efficiency of the graph
            degree_dict = dict(NetworkX_graph.degree())
            between_central_dict = nx.betweenness_centrality(NetworkX_graph)
            cluster_coeff_dict = nx.clustering(NetworkX_graph)
            global_eff = nx.global_efficiency(NetworkX_graph)
            local_eff_dict = {}
            for node in NetworkX_graph.nodes():
                # Local efficiency of a node = global efficiency of the subgraph induced by its neighbours
                subgraph_neighb = NetworkX_graph.subgraph(NetworkX_graph.neighbors(node))
                if subgraph_neighb.number_of_nodes() > 1:
                    efficiency = nx.global_efficiency(subgraph_neighb)
                else:
                    efficiency = 0.0
                local_eff_dict[node] = efficiency

            # Convert the degree, betweenness centrality, local efficiency, clustering coefficient and ratio of local to global efficiency dictionaries to NumPy arrays then normalize them
            degree_array = dict_to_array(degree_dict)
            degree_array_norm = normalize_array(degree_array)

            between_central_array = dict_to_array(between_central_dict)
            between_central_array_norm = normalize_array(between_central_array)

            local_efficiency_array = dict_to_array(local_eff_dict)
            local_eff_array_norm = normalize_array(local_efficiency_array)

            if global_eff > 0:
                ratio_local_global_array = dict_to_array(local_eff_dict) / global_eff
            else:
                # A graph with zero global efficiency would turn the ratio into inf/nan; use zeros instead.
                ratio_local_global_array = np.zeros_like(dict_to_array(local_eff_dict), dtype=float)
            ratio_local_global_array_norm = normalize_array(ratio_local_global_array)

            cluster_coeff_array = dict_to_array(cluster_coeff_dict)
            cluster_coeff_array_norm = normalize_array(cluster_coeff_array)

            # Stack the five normalized metrics column-wise: one row per node, one column per feature
            x_array = np.stack([degree_array_norm, between_central_array_norm, local_eff_array_norm, cluster_coeff_array_norm, ratio_local_global_array_norm], axis=-1)
            x_array = x_array.astype(np.float32)

            # Convert the feature matrix to a float tensor
            x = torch.tensor(x_array, dtype=torch.float)

            # Create a Pytorch Geometric Data object from the NetworkX graph
            graph_data = from_networkx(NetworkX_graph)
            ## The feature matrix of the graph is the degree, betweenness centrality, local efficiency, clustering coefficient and ratio of local to global efficiency of each node
            graph_data.x = x
            ## The target/output variable that we want to predict is the class label of the sample
            graph_data.y = float(self.labels[patient_idx])
            graphs.append(graph_data)
            print('done with patient', patient_idx)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])


### GCN

In [23]:
# Name under which wandb records the originating notebook, and model selected for the cells below.
os.environ['WANDB_NOTEBOOK_NAME'] = "Fake2C_GCN.ipynb"
method = 'GCN'

In [19]:
# Build the graph dataset for the selected model; its processed file lives under `root`.
root = f'Fake2C_Raw_to_graph/model{method}'
dataset = Fake2C_Raw_to_graph(root=root, corr_matrices=corr_matrices, labels=labels)

# Take the first graph as a sanity check and show its summary.
data = dataset[0]
print('', data, '=============================================================', sep='\n')

# Some statistics about the first graph.
print(
    f'Number of nodes: {data.num_nodes}',
    f'Number of edges: {data.num_edges}',
    f'Average node degree: {data.num_edges / data.num_nodes:.2f}',
    f'Has isolated nodes: {data.has_isolated_nodes()}',
    f'Has self-loops: {data.has_self_loops()}',
    f'Is undirected: {data.is_undirected()}',
    sep='\n',
)


Data(edge_index=[2, 5246], weight=[5246], x=[116, 5], y=[1], num_nodes=116)
Number of nodes: 116
Number of edges: 5246
Average node degree: 45.22
Has isolated nodes: False
Has self-loops: True
Is undirected: True


In [49]:
# Hyperparameter grid explored for the GCN
param_grid = {
    'learning_rate': [0.0001],
    'hidden_channels': [64],
    'num_layers': [1],
    'dropout_rate': [0.0],
    'weight_decay': [0.0001]
}

# Run configuration (defined before first use so the cell does not depend on leftover kernel state)
stratify = True
# NOTE: this used to be 800 while train_fake was called with a literal 80, so plots were saved
# under "epochs800" after 80 epochs. The value that was actually trained is kept and is now
# the single source of truth for both the filename and the training call.
n_epochs = 80
in_channels = 5  # number of node features produced by Fake2C_Raw_to_graph
method_wandb = 'GCN'

# Creating the train, validation and test sets
train_loader, valid_loader, test_loader, nbr_classes, y_train = f.create_train_test_valid(dataset, stratify)
nbr_classes = 2  # binary task; overrides the value returned by the split helper, as before

# Create combinations of hyperparameters
param_combinations = ParameterGrid(param_grid)

# Class weights for the loss: inverse class frequency in the training set, normalized to sum to 1.
# They do not depend on the hyperparameters, so they are computed once.
if stratify:
    diag_lab = [0 , 1]
    class_freq = torch.FloatTensor([np.count_nonzero(torch.Tensor(y_train) == label) for label in diag_lab])
    class_weights = 1 / class_freq
    class_weights /= class_weights.sum()

# Train using each combination
for params in param_combinations:
    filename = f'Fake2C_Models/GCN/lr{params["learning_rate"]}_hc{params["hidden_channels"]}_nl{params["num_layers"]}_d{params["dropout_rate"]}_epochs{n_epochs}_wdecay{params["weight_decay"]}.png'
    if os.path.exists(filename):
        # A saved plot means this combination was already trained: skip it.
        continue
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    parameters = [params['learning_rate'], params['hidden_channels'], params['num_layers'], params['dropout_rate']]
    model = m.GCN(in_channels=in_channels, hidden_channels=parameters[1], out_channels=nbr_classes, num_layers=parameters[2], dropout=parameters[3], nbr_classes=nbr_classes)
    if stratify:
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    w_decay = params.get('weight_decay', 0)
    optimizer = torch.optim.Adam(model.parameters(), lr=parameters[0], weight_decay=w_decay)
    train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy, test_accuracy = train_fake(model, filename, method_wandb, optimizer, criterion, w_decay, parameters, train_loader, valid_loader, test_loader, testing=True, n_epochs=n_epochs)

Number of training graphs: 700
Number of validation graphs: 100
Number of test graphs: 200
Number of classes: 2


  from IPython.core.display import HTML, display  # type: ignore


wandb: ERROR Error while calling W&B API: run alzheimers-cl/Fake_Alzheimers/2qixmudb was previously created and deleted; try a new run name (<Response [409]>)
Thread SenderThread:
Traceback (most recent call last):
  File "/Users/mathilde/anaconda3/envs/alzheimers-cl/lib/python3.11/site-packages/wandb/apis/normalize.py", line 41, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mathilde/anaconda3/envs/alzheimers-cl/lib/python3.11/site-packages/wandb/sdk/internal/internal_api.py", line 2216, in upsert_run
    response = self.gql(
               ^^^^^^^^^
  File "/Users/mathilde/anaconda3/envs/alzheimers-cl/lib/python3.11/site-packages/wandb/sdk/internal/internal_api.py", line 341, in gql
    ret = self._retry_gql(
          ^^^^^^^^^^^^^^^^
  File "/Users/mathilde/anaconda3/envs/alzheimers-cl/lib/python3.11/site-packages/wandb/sdk/lib/retry.py", line 131, in __call__
    result = self._call_fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^

Problem at: /Users/mathilde/anaconda3/envs/alzheimers-cl/lib/python3.11/site-packages/wandb/sdk/wandb_init.py 854 getcaller


MailboxError: transport failed

### GAT

In [None]:
# Name under which wandb records the originating notebook, and model selected for the cells below.
os.environ['WANDB_NOTEBOOK_NAME'] = "Fake2C_GAT.ipynb"
method = 'GAT'

In [40]:
# Graph dataset for the selected model; the root folder is model-specific, so each model keeps its own processed copy.
root = f'Fake2C_Raw_to_graph/model{method}'
dataset = Fake2C_Raw_to_graph(root=root, corr_matrices=corr_matrices, labels=labels)

Processing...


done with patient 0
done with patient 1
done with patient 2
done with patient 3
done with patient 4
done with patient 5
done with patient 6
done with patient 7
done with patient 8
done with patient 9
done with patient 10
done with patient 11
done with patient 12
done with patient 13
done with patient 14
done with patient 15
done with patient 16
done with patient 17
done with patient 18
done with patient 19
done with patient 20
done with patient 21
done with patient 22
done with patient 23
done with patient 24
done with patient 25
done with patient 26
done with patient 27
done with patient 28
done with patient 29
done with patient 30
done with patient 31
done with patient 32
done with patient 33
done with patient 34
done with patient 35
done with patient 36
done with patient 37
done with patient 38
done with patient 39
done with patient 40
done with patient 41
done with patient 42
done with patient 43
done with patient 44
done with patient 45
done with patient 46
done with patient 47
do

Done!


AttributeError: 'Fake2C_Raw_to_graph' object has no attribute 'threshold'

In [None]:
# Hyperparameter grid explored for the GAT
param_grid = {
    'learning_rate': [0.0001],
    'hidden_channels': [64],
    'num_layers': [1],
    'dropout_rate': [0.0],
    'weight_decay': [0.0001],
    'heads': [3, 4]
}

# Create combinations of hyperparameters
param_combinations = ParameterGrid(param_grid)
# NOTE: this used to be 800 while train_fake was called with a literal 80, so plots would be
# saved under "epochs800" after 80 epochs. The value that was actually passed is kept and is
# now the single source of truth for both the filename and the training call.
n_epochs = 80
in_channels = 5  # number of node features produced by Fake2C_Raw_to_graph
nbr_classes = 2
stratify = True
method_wandb = 'GAT'

# Creating the train, validation and test sets
train_loader, valid_loader, test_loader, nbr_classes, y_train = f.create_train_test_valid(dataset, stratify)

# Class weights for the loss: inverse class frequency in the training set, normalized to sum to 1.
# They do not depend on the hyperparameters, so they are computed once.
if stratify:
    diag_lab = [0 , 1]
    class_freq = torch.FloatTensor([np.count_nonzero(torch.Tensor(y_train) == label) for label in diag_lab])
    class_weights = 1 / class_freq
    class_weights /= class_weights.sum()

# Train using each combination
for params in param_combinations:
    filename = f'Fake2C_Models/GAT/lr{params["learning_rate"]}_hc{params["hidden_channels"]}_nl{params["num_layers"]}_d{params["dropout_rate"]}_epochs{n_epochs}_heads{params["heads"]}_wdecay{params["weight_decay"]}.png'
    if os.path.exists(filename):
        # A saved plot means this combination was already trained: skip it.
        continue
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    parameters = [params['learning_rate'], params['hidden_channels'], params['num_layers'], params['dropout_rate'], params['heads']]
    model = m.GAT(in_channels=in_channels, hidden_channels=parameters[1], out_channels=nbr_classes, num_layers=parameters[2], dropout=parameters[3], heads=parameters[4], nbr_classes=nbr_classes)
    if stratify:
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    w_decay = params.get('weight_decay', 0)
    optimizer = torch.optim.Adam(model.parameters(), lr=parameters[0], weight_decay=w_decay)
    # `parameters` must come right after `w_decay`: train_fake expects
    # (..., w_decay, parameters, train_loader, valid_loader, test_loader). The previous call
    # passed the loaders first, which shifted every positional argument after w_decay.
    train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy, test_accuracy = train_fake(model, filename, method_wandb, optimizer, criterion, w_decay, parameters, train_loader, valid_loader, test_loader, testing=True, n_epochs=n_epochs)



Number of training graphs: 700
Number of validation graphs: 100
Number of test graphs: 200
Number of classes: 2


  from IPython.core.display import HTML, display  # type: ignore




  from IPython.core.display import HTML, display  # type: ignore


TypeError: 'bool' object is not iterable

## Hypergraph Neural Networks

In [None]:
class Fake2C_Raw_to_Hypergraph(InMemoryDataset):
    """In-memory PyTorch Geometric dataset meant to hold one hypergraph per fake patient.

    NOTE(review): this class looks like it was copied from the real-data pipeline and not
    yet adapted to the fake dataset. It depends on several names that no visible cell
    defines or imports (`weight`, `threshold`, `age`, `sex`, `hnx`, `Data`,
    `diagnostic_label`), and it never stores the `labels` argument. See the notes below;
    confirm against the rest of the project before use.
    """
    def __init__(self, root, hg_data_path, labels, transform=None, pre_transform=None):
        """Store the hypergraph location, then load the processed dataset from disk.

        Args:
            root: Folder where the processed dataset (``data.pt``) is stored.
            hg_data_path: Location passed to ``f.load_hg_dict`` in ``process``.
            labels: Class labels. NOTE(review): currently unused; ``process`` reads
                ``diagnostic_label`` instead.
            transform, pre_transform: Forwarded to ``InMemoryDataset``.
        """
        # NOTE(review): the five assignments below read notebook globals, not constructor
        # arguments. `method` is set in other cells, but no visible cell defines `weight`,
        # `threshold`, `age` or `sex`, so construction is expected to fail with NameError.
        self.method = method
        self.weight = weight
        self.threshold = threshold
        self.age = age
        self.sex = sex
        self.hg_data_path = hg_data_path
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """Name of the file, inside the processed folder, that holds the collated hypergraphs."""
        return ['data.pt']

    # This function builds one Data object per patient from the prebuilt hypergraphs and saves the collated result.
    def process(self):
        """Build one ``Data`` object per patient and save the collated dataset."""
        # Loading the prebuilt hypergraphs (presumably one dict per patient, indexable by patient position)
        hg_dict_list = f.load_hg_dict(self.hg_data_path)

        graphs=[]
        # NOTE(review): iterates over the notebook-level `corr_matrices`, not over data owned by the dataset.
        for patient_idx, patient_matrix in enumerate(corr_matrices):
            # Wrap this patient's hypergraph description in a hypergraph object
            # NOTE(review): `hnx` is not imported in the visible import cell (presumably hypernetx).
            patient_hg = hg_dict_list[patient_idx]
            hypergraph = hnx.Hypergraph(patient_hg)

            # Loading node features from a pickled matrix profile
            # NOTE(review): the path points into the real-data folder and interpolates
            # `patient_matrix` as if it were a file name; in this notebook it is a matrix,
            # so `.endswith` below is expected to fail — confirm the intended feature source.
            path = f'ADNI_full/matrix_profiles/matrix_profile_pearson/{patient_matrix}'
            if patient_matrix.endswith('.DS_Store'):
                continue  # Skip hidden system files like .DS_Store
            # NOTE(review): pickle.load executes arbitrary code from the file; only load trusted files.
            with open(path, "rb") as fl:
                patient_dict = pkl.load(fl)
            # Flatten everything after the first axis: one row per entry of patient_dict['mp']
            features = np.array(patient_dict['mp']).reshape(len(patient_dict['mp']),-1)
            features = features.astype(np.float32)

            # Convert the feature matrix to a float tensor
            x = torch.tensor(features, dtype=torch.float)

            # Build a two-row index: row 0 lists the nodes of every hyperedge back to back,
            # row 1 repeats the running index `i` of the hyperedge each of those nodes belongs to
            edge_index0 = []
            edge_index1 = []
            i = 0
            for hyperedge, nodes in hypergraph.incidence_dict.items():
                edge_index0 = np.concatenate((edge_index0, nodes), axis=0)
                for j in range(len(nodes)):
                    edge_index1.append(i)
                i += 1
            # The node ids are cast back to int before stacking (presumably because concatenating onto an empty list yields floats)
            edge_index = np.stack([[int(x) for x in edge_index0], edge_index1], axis=0)
            # NOTE(review): `diagnostic_label` is not defined in the visible notebook; the `labels` argument is probably what was meant.
            y = torch.tensor(float(diagnostic_label[patient_idx]))
            # NOTE(review): `Data` is not imported in the visible import cell (presumably torch_geometric.data.Data).
            hg_data = Data(x=x, edge_index=torch.tensor(edge_index, dtype=torch.long), y=y)
            graphs.append(hg_data)

        data, slices = self.collate(graphs)
        torch.save((data, slices), self.processed_paths[0])

In [None]:
# Defining functions to save the fake hypergraphs
def save_fake_hypergraph(hg_dict, directory, method, id):
    """Pickle one hypergraph to ``<directory>/<method>/<id>.pkl``.

    Args:
        hg_dict: Hypergraph description to serialize (any picklable object).
        directory: Base output folder.
        method: Name of the hypergraph construction method; used as sub-folder.
        id: Identifier of the sample; used as file name.

    Returns:
        None.
    """
    # The previous version created <directory>/<method> but then wrote to
    # <directory>/<method>/<method>/<id>.pkl, a folder that was never created.
    target_dir = f'{directory}/{method}'
    os.makedirs(target_dir, exist_ok=True)
    with open(f'{target_dir}/{id}.pkl', 'wb') as out_file:
        pkl.dump(hg_dict, out_file)

def save_all_fake_hypergraphs(method_list, corr_matrices, labels):
    """Build and save a hypergraph for every sample and every requested method.

    Args:
        method_list: Iterable of method names; supported values are
            ``'maximal_clique'`` and ``'knn'``.
        corr_matrices: Sequence of connectivity matrices, one per sample.
        labels: Class labels aligned with `corr_matrices`; only needed to build
            the graph dataset used by the ``'maximal_clique'`` method.

    Returns:
        None. Results are written through ``save_fake_hypergraph`` under
        ``Fake_hypergraphs/<method>/``.

    Raises:
        ValueError: If `method_list` contains an unsupported method name.
    """
    method_list = list(method_list)
    supported_methods = ('maximal_clique', 'knn')
    unsupported = [method for method in method_list if method not in supported_methods]
    if unsupported:
        # Previously an unknown method left `hg_dict` unbound, or reused the one from the previous method.
        raise ValueError(f'Unsupported hypergraph method(s): {unsupported}; expected one of {supported_methods}')

    # Built lazily and only once: it does not depend on the patient being processed.
    graph_dataset = None
    for i, patient_matrix in enumerate(corr_matrices):
        print(f'Processing patient {i}')
        for method in method_list:
            if method == 'maximal_clique':
                if graph_dataset is None:
                    root = f'Fake2C_Raw_to_graph/model{method}'
                    graph_dataset = Fake2C_Raw_to_graph(root, corr_matrices, labels)
                # Convert the stored graph back to NetworkX before extracting the hyperedges
                graph = f.r2g_to_nx(graph_dataset[i])
                _, hg_dict = m.graph_to_hypergraph_max_cliques(graph)
            else:  # method == 'knn'
                k_neighbors = 3
                _, hg_dict = m.generate_hypergraph_from_knn(patient_matrix, k_neighbors)
            save_fake_hypergraph(hg_dict, 'Fake_hypergraphs', method, i)
            print(f'Patient {i} processed and saved for the {method}')


In [None]:
# Name under which wandb records the originating notebook, and model selected for the cells below.
os.environ['WANDB_NOTEBOOK_NAME'] = "Fake2C_HGKNN.ipynb"
method = 'HGKNN'

In [None]:
# Building the hypergraph dataset from the hypergraphs previously saved on disk
# NOTE(review): save_all_fake_hypergraphs writes under 'Fake_hypergraphs/knn' or
# 'Fake_hypergraphs/maximal_clique', while `method` is 'HGKNN' at this point — confirm that
# this folder is the one the hypergraphs were saved to.
hg_data_path = f'Fake_hypergraphs/{method}/'
root = f'Fake2C_Raw_to_hypergraph/model_{method}'
dataset = Fake2C_Raw_to_Hypergraph(root=root, hg_data_path=hg_data_path, labels=labels)

In [None]:
# HGConv with KNN method

# Hyperparameter grid explored for the hypergraph convolution model
param_grid = {
    'learning_rate': [0.0001, 0.001],
    'hidden_channels': [64, 128],
    'num_layers': [1, 2, 3],
    'dropout_rate': [0.0, 0.1, 0.2],
    'weight_decay': [0.0001, 0.001]
}

# Creating the train, validation and test sets
stratify = True
train_loader, valid_loader, test_loader, nbr_classes, y_train = f.create_train_test_valid(dataset, stratify)

# Create combinations of hyperparameters
param_combinations = ParameterGrid(param_grid)
# NOTE: this used to be 80 while train_fake was called with a literal 800, so plots would be
# saved under "epochs80" after 800 epochs. The value that was actually passed is kept and is
# now the single source of truth for both the filename and the training call.
n_epochs = 800
in_channels = 5
nbr_classes = 2
method_wandb = 'HGKNN'

# Class weights for the loss: inverse class frequency in the training set, normalized to sum to 1.
# They do not depend on the hyperparameters, so they are computed once.
if stratify:
    diag_lab = [0 , 1]
    class_freq = torch.FloatTensor([np.count_nonzero(torch.Tensor(y_train) == label) for label in diag_lab])
    class_weights = 1 / class_freq
    class_weights /= class_weights.sum()

# Train using each combination
for params in param_combinations:
    # The former "_w{weight}" suffix was dropped: no visible cell defines `weight`, so building
    # the filename raised NameError before any training could start.
    filename = f'Fake2C_Models/HGKNN/lr{params["learning_rate"]}_hc{params["hidden_channels"]}_nl{params["num_layers"]}_d{params["dropout_rate"]}_epochs{n_epochs}_wdecay{params["weight_decay"]}.png'
    if os.path.exists(filename):
        # A saved plot means this combination was already trained: skip it.
        continue
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    parameters = [params['learning_rate'], params['hidden_channels'], params['num_layers'], params['dropout_rate']]
    model = m.HGConv(in_channels=in_channels, hidden_channels=parameters[1], out_channels=nbr_classes, num_layers=parameters[2], dropout=parameters[3], nbr_classes=nbr_classes)
    if stratify:
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    w_decay = params.get('weight_decay', 0)
    optimizer = torch.optim.Adam(model.parameters(), lr=parameters[0], weight_decay=w_decay)
    train_losses, train_accuracies, valid_losses, valid_accuracies, max_valid_accuracy, test_accuracy = train_fake(model, filename, method_wandb, optimizer, criterion, w_decay, parameters, train_loader, valid_loader, test_loader, testing=True, n_epochs=n_epochs)

In [None]:
# HGConv with maximal clique method

# Name under which wandb records the originating notebook for the maximal-clique runs.
os.environ['WANDB_NOTEBOOK_NAME'] = "Fake2C_HGMC.ipynb"
