In [None]:
# Split the data for cross validation

# Hold out a stratified 15% test set; the remaining samples are kept for cross validation
cv_data_sc, test_data_sc, cv_labels_sc, test_labels_sc = train_test_split(
    graphs_sc,
    labels_sc,
    test_size=0.15,
    random_state=42,
    stratify=labels_sc,
)
# Same split settings (size, seed, stratification) applied to the *_sc_combined data
(
    cv_data_sc_combined,
    test_data_sc_combined,
    cv_labels_sc_combined,
    test_labels_sc_combined,
) = train_test_split(
    graphs_sc_combined,
    labels_sc_combined,
    test_size=0.15,
    random_state=42,
    stratify=labels_sc_combined,
)

# Report the size and container type of the two *_sc data splits
print(f'cv_data_sc len: {len(cv_data_sc)}, type: {type(cv_data_sc)}')
print(f'test_data_sc len: {len(test_data_sc)}, type: {type(test_data_sc)}')

In [None]:
from sklearn.model_selection import StratifiedKFold, ParameterGrid

# Define the parameter grid
#
# Full grid, kept for reference only (not executed):
# param_grid = {
#     'num_heads': [1, 3, 5],
#     'hidden_channels': [8, 16],
#     'out_channels': [8, 16],
#     'num_epochs': [50, 100],
#     'learning_rate': [1e-5, 1e-4, 1e-3]
# }

# Simplified grid used for the search.
# 'hidden_channels' is listed so that every combination carries all five keys of
# the full grid above. Its single value is taken from that grid and does not
# enlarge the search.
# NOTE(review): confirm 16 matches the hidden size the model is actually built with.
param_grid = {
    'num_heads': [1],
    'hidden_channels': [16],
    'out_channels': [16, 32],
    'num_epochs': [30, 50],
    'learning_rate': [1e-5, 1e-4],
}

def perform_grid_search(graphs, labels, num_splits, param_grid, model, criterion, optimizer, device, *, batch_size=None):
    """Grid-search training hyperparameters with stratified k-fold cross validation.

    For every combination in ``param_grid`` the model is trained and validated on
    each fold, the per-epoch metrics are averaged over the folds, and the
    final-epoch averages are recorded. The combination with the highest
    final-epoch average validation F1 is returned.

    Only ``num_epochs``, ``learning_rate`` and ``batch_size`` are applied here.
    This function never rebuilds ``model``, so any other grid entry (for example
    ``num_heads`` or ``out_channels``) only labels the results.

    Args:
        graphs: Indexable collection of graphs.
        labels: Indexable collection of class labels aligned with ``graphs``.
        num_splits: Number of stratified folds.
        param_grid: Grid accepted by ``ParameterGrid``. Every combination must
            provide ``num_epochs`` and ``learning_rate``; ``batch_size`` is
            optional when the ``batch_size`` argument is given.
        model: Model to train. Assumed to expose ``state_dict()`` and
            ``load_state_dict()`` (e.g. a ``torch.nn.Module``) -- TODO confirm.
        criterion: Loss object passed through to ``train`` and ``validate``.
        optimizer: Optimizer bound to ``model``. Assumed to expose
            ``state_dict()``, ``load_state_dict()`` and ``param_groups``
            (e.g. a ``torch.optim.Optimizer``) -- TODO confirm.
        device: Passed through to ``validate``.
        batch_size: Keyword-only fallback batch size, used for combinations
            that do not define ``'batch_size'`` themselves.

    Returns:
        Tuple ``(best_params, best_avg_val_f1, best_avg_val_accuracy)``.

    Raises:
        ValueError: If a combination has no ``'batch_size'`` entry and no
            ``batch_size`` argument was given.
    """
    import copy  # local import so this cell does not depend on another cell's imports

    metric_names = ('train_loss', 'val_loss', 'train_accuracy', 'val_accuracy', 'train_f1', 'val_f1')

    # Initialize StratifiedKFold
    kf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)

    # Snapshot the state the model and optimizer were handed in with. Every fold
    # restarts from this snapshot; otherwise weights trained on one fold's
    # training data (which contains the next fold's validation data) would carry
    # over and inflate the validation scores.
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

    overall_results = []

    # Loop over each parameter combination
    for params in ParameterGrid(param_grid):
        num_epochs = params['num_epochs']
        learning_rate = params['learning_rate']

        # A grid entry wins over the keyword fallback
        current_batch_size = params.get('batch_size', batch_size)
        if current_batch_size is None:
            raise ValueError(
                "No batch size available: add 'batch_size' to param_grid "
                "or pass batch_size=... to perform_grid_search"
            )

        # For each metric: one per-epoch history per fold
        fold_histories = {name: [] for name in metric_names}

        for train_index, val_index in kf.split(graphs, labels):
            # Start this fold from the untouched initial weights and optimizer state
            model.load_state_dict(initial_model_state)
            optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))
            # Apply the learning rate of the current combination
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

            # Split the dataset into training and validation sets for this fold
            train_dataset = GraphDataset(
                [graphs[i] for i in train_index],
                [labels[i] for i in train_index],
            )
            val_dataset = GraphDataset(
                [graphs[i] for i in val_index],
                [labels[i] for i in val_index],
            )

            # Create data loaders
            train_loader = DataLoader(train_dataset, batch_size=current_batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=current_batch_size, shuffle=False)

            epoch_history = {name: [] for name in metric_names}
            for _ in range(num_epochs):
                train_loss, train_accuracy, train_f1 = train(train_loader, model, criterion, optimizer)
                val_loss, val_accuracy, val_f1 = validate(val_loader, model, criterion, device)

                # Same order as metric_names
                epoch_values = (train_loss, val_loss, train_accuracy, val_accuracy, train_f1, val_f1)
                for name, value in zip(metric_names, epoch_values):
                    epoch_history[name].append(value)

            # Aggregate metrics for this fold
            for name in metric_names:
                fold_histories[name].append(epoch_history[name])

        # Average every epoch over the folds, then keep the final epoch only
        result = {'params': params}
        for name in metric_names:
            result[f'avg_{name}'] = np.mean(fold_histories[name], axis=0)[-1]
        overall_results.append(result)

    best_result = max(overall_results, key=lambda x: x['avg_val_f1'])
    return best_result['params'], best_result['avg_val_f1'], best_result['avg_val_accuracy']


In [None]:
# Perform grid search
num_splits = 5

# perform_grid_search takes eight positional arguments; passing batch_size as a
# ninth one raises a TypeError. The notebook-level batch_size is therefore
# supplied as a single-value grid entry instead. It is placed first in the
# dict so that an explicit 'batch_size' entry in param_grid takes precedence.
best_params_sc_combined, best_val_f1_sc_combined, best_val_accuracy_sc_combined = perform_grid_search(
    cv_data_sc_combined,
    cv_labels_sc_combined,
    num_splits,
    {'batch_size': [batch_size], **param_grid},
    model,
    criterion,
    optimizer,
    device,
)
print(f'Best parameters for combined connectivity: {best_params_sc_combined}, Best validation F1: {best_val_f1_sc_combined}, Best validation accuracy: {best_val_accuracy_sc_combined}')
