In [7]:
import numpy as np
import pandas as pd
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# DATA LOADING (Same as GCN)
# ============================================================================

def categorize_pcl5(pcl5_score):
    """Categorize PCL5 scores into severity groups"""
    if 0 <= pcl5_score <= 25:
        return 0
    elif 26 <= pcl5_score <= 50:
        return 1
    elif 51 <= pcl5_score <= 75:
        return 2
    elif 76 <= pcl5_score <= 100:
        return 3
    else:
        return -1

def load_ptsd_data_with_pcl5(mat_file, behavioral_csv):
    """Load PTSD data"""
    print("Loading PTSD dataset with PCL5 scores...")
    print("="*50)
    
    behavioral_df = pd.read_csv(behavioral_csv)
    behavioral_df['PCL5_Category'] = behavioral_df['PCL5 score'].apply(categorize_pcl5)
    behavioral_df = behavioral_df[behavioral_df['PCL5_Category'] != -1]
    
    data = sio.loadmat(mat_file)
    
    controls = np.transpose(data['conn_sfc_dc_filt_controls'], (2, 0, 1))
    ptsd = np.transpose(data['conn_sfc_dc_filt_ptsd'], (2, 0, 1))
    pcs_ptsd = np.transpose(data['conn_sfc_dc_filt_pcsptsd'], (2, 0, 1))
    
    fc_matrices = np.concatenate([controls, ptsd, pcs_ptsd], axis=0)
    
    if len(behavioral_df) != fc_matrices.shape[0]:
        min_len = min(len(behavioral_df), fc_matrices.shape[0])
        fc_matrices = fc_matrices[:min_len]
        behavioral_df = behavioral_df.iloc[:min_len]
    
    labels = behavioral_df['PCL5_Category'].values
    pcl5_scores = behavioral_df['PCL5 score'].values
    
    print(f"Final dataset: {len(labels)} subjects")
    print("="*50)
    
    return fc_matrices, labels, pcl5_scores

# ============================================================================
# GENERATOR
# ============================================================================

class Generator(nn.Module):
    """Generator for creating synthetic FC matrices"""
    def __init__(self, latent_dim=50, n_regions=125, embed_dim=10, n_classes=4):
        super(Generator, self).__init__()
        
        self.n_regions = n_regions
        self.embed_dim = embed_dim
        self.n_classes = n_classes
        
        self.fc1 = nn.Linear(latent_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, n_regions * embed_dim)
        
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(1024)
    
    def forward(self, z):
        x = F.leaky_relu(self.bn1(self.fc1(z)), 0.2)
        x = F.leaky_relu(self.bn2(self.fc2(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.fc3(x)), 0.2)
        x = self.fc4(x)
        
        x = x.view(-1, self.n_regions, self.embed_dim)
        fc_matrix = torch.bmm(x, x.transpose(1, 2))
        fc_matrix = torch.tanh(fc_matrix)
        
        return fc_matrix

# ============================================================================
# DISCRIMINATOR
# ============================================================================

class BrainNetCNN_Discriminator(nn.Module):
    """Discriminator based on BrainNetCNN"""
    def __init__(self, n_regions=125, n_classes=4, dropout_rate=0.5):
        super(BrainNetCNN_Discriminator, self).__init__()
        
        self.n_regions = n_regions
        
        self.e2e_conv = nn.Conv2d(1, 32, kernel_size=(1, n_regions), padding=0)
        self.e2n_conv = nn.Conv2d(32, 64, kernel_size=(n_regions, 1), padding=0)
        self.n2g_fc = nn.Linear(64, 128)
        
        self.combined_fc1 = nn.Linear(128, 64)
        self.combined_fc2 = nn.Linear(64, 32)
        
        self.validity_head = nn.Linear(32, 1)
        self.class_head = nn.Linear(32, n_classes)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(64)
    
    def forward(self, fc_matrix):
        x = fc_matrix.unsqueeze(1)
        
        x = self.e2e_conv(x)
        x = self.bn1(x)
        x = F.leaky_relu(x, 0.2)
        x = self.dropout(x)
        
        x = self.e2n_conv(x)
        x = self.bn2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.dropout(x)
        
        x = x.view(x.size(0), -1)
        x = self.n2g_fc(x)
        x = self.bn3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.dropout(x)
        
        x = self.combined_fc1(x)
        x = self.bn4(x)
        x = F.leaky_relu(x, 0.2)
        x = self.dropout(x)
        
        x = self.combined_fc2(x)
        x = F.leaky_relu(x, 0.2)
        
        validity = torch.sigmoid(self.validity_head(x))
        class_logits = self.class_head(x)
        
        return validity, class_logits

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

def train_gan(generator, discriminator, train_loader, 
              g_optimizer, d_optimizer, device, n_classes=4, latent_dim=50):
    """Training function for GAN"""
    generator.train()
    discriminator.train()
    
    g_losses = []
    d_losses = []
    
    for fc_batch, labels in train_loader:
        fc_batch = fc_batch.to(device)
        labels = labels.to(device)
        
        batch_size = fc_batch.size(0)
        
        real_labels_validity = torch.ones(batch_size, 1).to(device)
        fake_labels_validity = torch.zeros(batch_size, 1).to(device)
        
        # Train Discriminator
        d_optimizer.zero_grad()
        
        real_validity, real_class = discriminator(fc_batch)
        d_real_loss = F.binary_cross_entropy(real_validity, real_labels_validity)
        d_real_class_loss = F.cross_entropy(real_class, labels)
        
        z_noise = torch.randn(batch_size, latent_dim - n_classes).to(device)
        z_labels = F.one_hot(labels, n_classes).float()
        z = torch.cat([z_noise, z_labels], dim=1)
        
        fake_fc = generator(z)
        fake_validity, fake_class = discriminator(fake_fc.detach())
        d_fake_loss = F.binary_cross_entropy(fake_validity, fake_labels_validity)
        
        d_loss = d_real_loss + d_fake_loss + d_real_class_loss
        d_loss.backward()
        d_optimizer.step()
        
        # Train Generator
        g_optimizer.zero_grad()
        
        z_noise = torch.randn(batch_size, latent_dim - n_classes).to(device)
        z_labels = F.one_hot(labels, n_classes).float()
        z = torch.cat([z_noise, z_labels], dim=1)
        
        fake_fc = generator(z)
        fake_validity, fake_class = discriminator(fake_fc)
        
        g_loss = F.binary_cross_entropy(fake_validity, real_labels_validity)
        g_class_loss = F.cross_entropy(fake_class, labels)
        g_total_loss = g_loss + g_class_loss
        
        g_total_loss.backward()
        g_optimizer.step()
        
        g_losses.append(g_total_loss.item())
        d_losses.append(d_loss.item())
    
    return np.mean(g_losses), np.mean(d_losses)

def generate_synthetic_data(generator, n_samples, labels, device, n_classes=4, latent_dim=50):
    """Generate synthetic FC matrices - FIXED"""
    generator.eval()
    
    synthetic_fc = []
    
    with torch.no_grad():
        for i in range(0, n_samples, 32):
            batch_size = min(32, n_samples - i)
            batch_labels = labels[i:i+batch_size]
            
            z_noise = torch.randn(batch_size, latent_dim - n_classes).to(device)
            z_labels = F.one_hot(torch.tensor(batch_labels, dtype=torch.long).to(device), n_classes).float()
            z = torch.cat([z_noise, z_labels], dim=1)
            
            fake_fc = generator(z)
            synthetic_fc.append(fake_fc.cpu().numpy())
    
    return np.concatenate(synthetic_fc, axis=0)

def evaluate_discriminator(discriminator, loader, device):
    """Evaluate discriminator as classifier"""
    discriminator.eval()
    predictions = []
    true_labels = []
    probabilities = []
    
    with torch.no_grad():
        for fc_batch, labels in loader:
            fc_batch = fc_batch.to(device)
            
            _, class_logits = discriminator(fc_batch)
            pred = class_logits.argmax(dim=1)
            prob = F.softmax(class_logits, dim=1)
            
            predictions.extend(pred.cpu().numpy())
            true_labels.extend(labels.numpy())
            probabilities.extend(prob.cpu().numpy())
    
    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    probabilities = np.array(probabilities)
    
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='macro', zero_division=0
    )
    
    try:
        auc = roc_auc_score(true_labels, probabilities, multi_class='ovr', average='macro')
    except:
        auc = 0.0
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'predictions': predictions,
        'true_labels': true_labels,
        'probabilities': probabilities
    }

# ============================================================================
# HYPERPARAMETER TUNING FOR GAN
# ============================================================================

def hyperparameter_tuning_gan(train_fc, train_labels, val_fc, val_labels, n_classes, device):
    """
    Perform hyperparameter tuning for GAN
    """
    print("\n" + "="*70)
    print("HYPERPARAMETER TUNING FOR GAN")
    print("="*70)
    
    # Define hyperparameter grid
    param_grid = {
        'g_lr': [0.00005, 0.0001, 0.0002],
        'd_lr': [0.00005, 0.0001, 0.0002],
        'batch_size': [8, 16, 32],
        'dropout_rate': [0.3, 0.5, 0.7],
        'latent_dim': [50, 100, 150],
        'embed_dim': [8, 10, 12]
    }
    
    print(f"\nSearching over hyperparameters...")
    print("Using Random Search with 25 configurations\n")
    
    best_score = 0
    best_params = None
    best_gen_state = None
    best_disc_state = None
    all_results = []
    
    n_random_search = 25
    
    for trial in range(n_random_search):
        # Random sample - FIX: Convert to Python types
        params = {
            'g_lr': float(np.random.choice(param_grid['g_lr'])),
            'd_lr': float(np.random.choice(param_grid['d_lr'])),
            'batch_size': int(np.random.choice(param_grid['batch_size'])),
            'dropout_rate': float(np.random.choice(param_grid['dropout_rate'])),
            'latent_dim': int(np.random.choice(param_grid['latent_dim'])),
            'embed_dim': int(np.random.choice(param_grid['embed_dim']))
        }
        
        print(f"Trial {trial+1}/{n_random_search}")
        print(f"  Params: g_lr={params['g_lr']}, d_lr={params['d_lr']}, "
              f"bs={params['batch_size']}, dropout={params['dropout_rate']}, "
              f"latent={params['latent_dim']}, embed={params['embed_dim']}")
        
        # Create data loaders
        train_dataset = TensorDataset(
            torch.tensor(train_fc, dtype=torch.float32),
            torch.tensor(train_labels, dtype=torch.long)
        )
        val_dataset = TensorDataset(
            torch.tensor(val_fc, dtype=torch.float32),
            torch.tensor(val_labels, dtype=torch.long)
        )
        
        train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=params['batch_size'], shuffle=False)
        
        # Initialize models
        generator = Generator(
            latent_dim=params['latent_dim'],
            n_regions=125,
            embed_dim=params['embed_dim'],
            n_classes=n_classes
        ).to(device)
        
        discriminator = BrainNetCNN_Discriminator(
            n_regions=125,
            n_classes=n_classes,
            dropout_rate=params['dropout_rate']
        ).to(device)
        
        g_optimizer = torch.optim.Adam(generator.parameters(), lr=params['g_lr'], betas=(0.5, 0.999))
        d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=params['d_lr'], betas=(0.5, 0.999))
        
        # Train GAN
        best_val_acc = 0
        best_val_f1 = 0
        n_epochs = 100
        
        for epoch in range(n_epochs):
            g_loss, d_loss = train_gan(
                generator, discriminator, train_loader,
                g_optimizer, d_optimizer, device,
                n_classes=n_classes, latent_dim=params['latent_dim']
            )
            
            if (epoch + 1) % 20 == 0:
                val_metrics = evaluate_discriminator(discriminator, val_loader, device)
                
                if val_metrics['f1'] > best_val_f1:
                    best_val_acc = val_metrics['accuracy']
                    best_val_f1 = val_metrics['f1']
                    temp_gen_state = {k: v.cpu().clone() for k, v in generator.state_dict().items()}
                    temp_disc_state = {k: v.cpu().clone() for k, v in discriminator.state_dict().items()}
        
        print(f"  Best Val Acc: {best_val_acc:.4f}, Best Val F1: {best_val_f1:.4f}")
        
        # Store results
        result = {
            'params': params.copy(),
            'val_accuracy': best_val_acc,
            'val_f1': best_val_f1
        }
        all_results.append(result)
        
        # Update best
        if best_val_f1 > best_score:
            best_score = best_val_f1
            best_params = params.copy()
            best_gen_state = temp_gen_state
            best_disc_state = temp_disc_state
            print(f"  *** New best F1: {best_score:.4f} ***")
        
        print()
    
    # Display top 5
    print("\n" + "="*70)
    print("TOP 5 CONFIGURATIONS")
    print("="*70)
    
    sorted_results = sorted(all_results, key=lambda x: x['val_f1'], reverse=True)
    for i, result in enumerate(sorted_results[:5]):
        print(f"\nRank {i+1}:")
        print(f"  Val Acc: {result['val_accuracy']:.4f}, Val F1: {result['val_f1']:.4f}")
        print(f"  Params: {result['params']}")
    
    print("\n" + "="*70)
    print("BEST HYPERPARAMETERS")
    print("="*70)
    print(f"Val F1 Score: {best_score:.4f}")
    print(f"Parameters:")
    for key, value in best_params.items():
        print(f"  {key}: {value}")
    print("="*70 + "\n")
    
    return best_params, best_gen_state, best_disc_state, all_results

# ============================================================================
# MAIN FUNCTION
# ============================================================================

def main_ptsd_gan_pcl5_tuning():
    """Main pipeline with hyperparameter tuning for GAN"""
    print(f"\nUsing device: {device}\n")
    
    # Load data
    fc_matrices, labels, pcl5_scores = load_ptsd_data_with_pcl5(
        'Static_functional_connectivity_ptsd_dc_filt.mat',
        'PTSD_Behavioral_measures.csv'
    )
    
    # Split data: 70% train, 15% val, 15% test
    indices = np.arange(len(labels))
    train_idx, test_val_idx = train_test_split(
        indices, test_size=0.3, stratify=labels, random_state=42
    )
    
    val_idx, test_idx = train_test_split(
        test_val_idx, test_size=0.5, stratify=labels[test_val_idx], random_state=42
    )
    
    print(f"Data split:")
    print(f"  Train: {len(train_idx)} ({len(train_idx)/len(labels)*100:.1f}%)")
    print(f"  Val:   {len(val_idx)} ({len(val_idx)/len(labels)*100:.1f}%)")
    print(f"  Test:  {len(test_idx)} ({len(test_idx)/len(labels)*100:.1f}%)")
    
    train_fc = fc_matrices[train_idx]
    train_labels = labels[train_idx]
    
    val_fc = fc_matrices[val_idx]
    val_labels = labels[val_idx]
    
    test_fc = fc_matrices[test_idx]
    test_labels = labels[test_idx]
    test_pcl5 = pcl5_scores[test_idx]
    
    n_classes = len(np.unique(labels))
    
    # Hyperparameter tuning
    best_params, best_gen_state, best_disc_state, all_results = hyperparameter_tuning_gan(
        train_fc, train_labels, val_fc, val_labels, n_classes, device
    )
    
    # Train final model with best hyperparameters
    print("\n" + "="*70)
    print("TRAINING FINAL GAN MODEL")
    print("="*70)
    
    # Combine train and val
    final_train_fc = np.concatenate([train_fc, val_fc], axis=0)
    final_train_labels = np.concatenate([train_labels, val_labels], axis=0)
    
    final_train_dataset = TensorDataset(
        torch.tensor(final_train_fc, dtype=torch.float32),
        torch.tensor(final_train_labels, dtype=torch.long)
    )
    final_train_loader = DataLoader(
        final_train_dataset, 
        batch_size=best_params['batch_size'], 
        shuffle=True
    )
    
    # Initialize final models
    final_generator = Generator(
        latent_dim=best_params['latent_dim'],
        n_regions=125,
        embed_dim=best_params['embed_dim'],
        n_classes=n_classes
    ).to(device)
    
    final_discriminator = BrainNetCNN_Discriminator(
        n_regions=125,
        n_classes=n_classes,
        dropout_rate=best_params['dropout_rate']
    ).to(device)
    
    g_optimizer = torch.optim.Adam(
        final_generator.parameters(), 
        lr=best_params['g_lr'], 
        betas=(0.5, 0.999)
    )
    d_optimizer = torch.optim.Adam(
        final_discriminator.parameters(), 
        lr=best_params['d_lr'], 
        betas=(0.5, 0.999)
    )
    
    # Train GAN
    print("\nTraining GAN for 150 epochs...")
    for epoch in range(150):
        g_loss, d_loss = train_gan(
            final_generator, final_discriminator, final_train_loader,
            g_optimizer, d_optimizer, device,
            n_classes=n_classes, latent_dim=best_params['latent_dim']
        )
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}: G_Loss={g_loss:.4f}, D_Loss={d_loss:.4f}")
    
    # Generate synthetic data
    print("\nGenerating synthetic data...")
    n_synthetic = len(final_train_labels)
    synthetic_fc = generate_synthetic_data(
        final_generator, n_synthetic, final_train_labels, device,
        n_classes=n_classes, latent_dim=best_params['latent_dim']
    )
    
    # Augment and retrain discriminator
    print("Augmenting training data and retraining discriminator...")
    augmented_fc = np.concatenate([final_train_fc, synthetic_fc], axis=0)
    augmented_labels = np.concatenate([final_train_labels, final_train_labels], axis=0)
    
    augmented_dataset = TensorDataset(
        torch.tensor(augmented_fc, dtype=torch.float32),
        torch.tensor(augmented_labels, dtype=torch.long)
    )
    augmented_loader = DataLoader(
        augmented_dataset, 
        batch_size=best_params['batch_size'], 
        shuffle=True
    )
    
    d_optimizer_retrain = torch.optim.Adam(
        final_discriminator.parameters(), 
        lr=best_params['d_lr']
    )
    
    for epoch in range(100):
        final_discriminator.train()
        for fc_batch, label_batch in augmented_loader:
            fc_batch = fc_batch.to(device)
            label_batch = label_batch.to(device)
            
            d_optimizer_retrain.zero_grad()
            _, class_logits = final_discriminator(fc_batch)
            loss = F.cross_entropy(class_logits, label_batch)
            loss.backward()
            d_optimizer_retrain.step()
        
        if (epoch + 1) % 20 == 0:
            print(f"Retrain Epoch {epoch+1}")
    
    # Save models
    torch.save(final_generator.state_dict(), 'final_generator_pcl5.pt')
    torch.save(final_discriminator.state_dict(), 'final_discriminator_pcl5.pt')
    
    # Test evaluation
    print("\n" + "="*70)
    print("TEST SET EVALUATION")
    print("="*70)
    
    test_dataset = TensorDataset(
        torch.tensor(test_fc, dtype=torch.float32),
        torch.tensor(test_labels, dtype=torch.long)
    )
    test_loader = DataLoader(test_dataset, batch_size=best_params['batch_size'], shuffle=False)
    
    test_results = evaluate_discriminator(final_discriminator, test_loader, device)
    
    print(f"\nTest Results:")
    print(f"  Accuracy:  {test_results['accuracy']:.4f}")
    print(f"  Precision: {test_results['precision']:.4f}")
    print(f"  Recall:    {test_results['recall']:.4f}")
    print(f"  F1 Score:  {test_results['f1']:.4f}")
    print(f"  AUC:       {test_results['auc']:.4f}")
    
    cm_test = confusion_matrix(test_results['true_labels'], test_results['predictions'])
    print(f"\nConfusion Matrix:\n{cm_test}")
    
    # Plot
    severity_labels = ['Minimal\n(0-25)', 'Mild\n(26-50)', 'Moderate\n(51-75)', 'Severe\n(76-100)']
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_test, annot=True, fmt='d', cmap='Greens',
                xticklabels=severity_labels[:n_classes],
                yticklabels=severity_labels[:n_classes])
    plt.title('GAN Test Set Confusion Matrix\n(PCL5 Severity Categories)', fontsize=14)
    plt.ylabel('True Severity', fontsize=12)
    plt.xlabel('Predicted Severity', fontsize=12)
    plt.tight_layout()
    plt.savefig('confusion_matrix_gan_pcl5_tuned.png', dpi=300)
    print("\nConfusion matrix saved")
    
    # PCL5 Analysis
    print("\n" + "="*70)
    print("PCL5 SCORE ANALYSIS")
    print("="*70)
    
    for cat in range(n_classes):
        cat_mask = test_labels == cat
        if np.sum(cat_mask) > 0:
            cat_scores = test_pcl5[cat_mask]
            cat_predictions = test_results['predictions'][cat_mask]
            correct = np.sum(cat_predictions == cat)
            
            print(f"\nCategory {cat} ({severity_labels[cat].strip()}):")
            print(f"  N: {np.sum(cat_mask)}")
            print(f"  Correct: {correct}/{np.sum(cat_mask)} ({100*correct/np.sum(cat_mask):.1f}%)")
            print(f"  PCL5: {cat_scores.mean():.1f} ± {cat_scores.std():.1f}")
    
    # Plot hyperparameter results
    plot_hyperparameter_results_gan(all_results)
    
    return best_params, test_results, final_discriminator, final_generator, all_results

def plot_hyperparameter_results_gan(all_results):
    """Plot GAN hyperparameter search results"""
    
    g_lrs = [r['params']['g_lr'] for r in all_results]
    d_lrs = [r['params']['d_lr'] for r in all_results]
    latent_dims = [r['params']['latent_dim'] for r in all_results]
    accuracies = [r['val_accuracy'] for r in all_results]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Generator LR vs Accuracy
    axes[0].scatter(g_lrs, accuracies, alpha=0.6, s=100)
    axes[0].set_xscale('log')
    axes[0].set_xlabel('Generator Learning Rate', fontsize=12)
    axes[0].set_ylabel('Validation Accuracy', fontsize=12)
    axes[0].set_title('Generator LR vs Accuracy', fontsize=13)
    axes[0].grid(alpha=0.3)
    
    # Discriminator LR vs Accuracy
    axes[1].scatter(d_lrs, accuracies, alpha=0.6, s=100, color='orange')
    axes[1].set_xscale('log')
    axes[1].set_xlabel('Discriminator Learning Rate', fontsize=12)
    axes[1].set_ylabel('Validation Accuracy', fontsize=12)
    axes[1].set_title('Discriminator LR vs Accuracy', fontsize=13)
    axes[1].grid(alpha=0.3)
    
    # Latent Dimension vs Accuracy
    axes[2].scatter(latent_dims, accuracies, alpha=0.6, s=100, color='green')
    axes[2].set_xlabel('Latent Dimension', fontsize=12)
    axes[2].set_ylabel('Validation Accuracy', fontsize=12)
    axes[2].set_title('Latent Dim vs Accuracy', fontsize=13)
    axes[2].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('hyperparameter_search_results_gan.png', dpi=300)
    print("Hyperparameter search plots saved")
    plt.show()

# Run
if __name__ == "__main__":
    best_params, test_results, final_disc, final_gen, search_results = main_ptsd_gan_pcl5_tuning()

Using device: cpu

Using device: cpu

Loading PTSD dataset with PCL5 scores...
Final dataset: 174 subjects
Data split:
  Train: 121 (69.5%)
  Val:   26 (14.9%)
  Test:  27 (15.5%)

HYPERPARAMETER TUNING FOR GAN

Searching over hyperparameters...
Using Random Search with 25 configurations

Trial 1/25
  Params: g_lr=0.0002, d_lr=5e-05, bs=32, dropout=0.7, latent=50, embed=8
  Best Val Acc: 0.2692, Best Val F1: 0.1061
  *** New best F1: 0.1061 ***

Trial 2/25
  Params: g_lr=0.0002, d_lr=0.0001, bs=32, dropout=0.7, latent=150, embed=12
  Best Val Acc: 0.3462, Best Val F1: 0.2571
  *** New best F1: 0.2571 ***

Trial 3/25
  Params: g_lr=5e-05, d_lr=0.0002, bs=16, dropout=0.3, latent=100, embed=10
  Best Val Acc: 0.8077, Best Val F1: 0.8093
  *** New best F1: 0.8093 ***

Trial 4/25
  Params: g_lr=0.0001, d_lr=0.0001, bs=8, dropout=0.3, latent=100, embed=10


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 64, 1, 1])

In [None]:
import numpy as np
import pandas as pd
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import product
import warnings
warnings.filterwarnings('ignore')

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================================================
# DATA LOADING WITH PCL5 SCORES
# ============================================================================

def categorize_pcl5(pcl5_score):
    """Categorize PCL5 scores into severity groups"""
    if 0 <= pcl5_score <= 25:
        return 0  # Minimal/No PTSD
    elif 26 <= pcl5_score <= 50:
        return 1  # Mild PTSD
    elif 51 <= pcl5_score <= 75:
        return 2  # Moderate PTSD
    elif 76 <= pcl5_score <= 100:
        return 3  # Severe PTSD
    else:
        return -1  # Invalid score

def load_ptsd_data_with_pcl5(mat_file, behavioral_csv):
    """Load PTSD data from .mat file and merge with PCL5 scores"""
    print("Loading PTSD dataset with PCL5 scores...")
    print("="*50)
    
    behavioral_df = pd.read_csv(behavioral_csv)
    print(f"Behavioral data loaded: {len(behavioral_df)} subjects")
    
    behavioral_df['PCL5_Category'] = behavioral_df['PCL5 score'].apply(categorize_pcl5)
    
    invalid_scores = behavioral_df[behavioral_df['PCL5_Category'] == -1]
    if len(invalid_scores) > 0:
        print(f"Warning: {len(invalid_scores)} subjects have invalid PCL5 scores")
        behavioral_df = behavioral_df[behavioral_df['PCL5_Category'] != -1]
    
    print(f"\nPCL5 Category Distribution:")
    print(behavioral_df['PCL5_Category'].value_counts().sort_index())
    
    data = sio.loadmat(mat_file)
    
    controls = data['conn_sfc_dc_filt_controls']
    ptsd = data['conn_sfc_dc_filt_ptsd']
    pcs_ptsd = data['conn_sfc_dc_filt_pcsptsd']
    
    controls = np.transpose(controls, (2, 0, 1))
    ptsd = np.transpose(ptsd, (2, 0, 1))
    pcs_ptsd = np.transpose(pcs_ptsd, (2, 0, 1))
    
    fc_matrices = np.concatenate([controls, ptsd, pcs_ptsd], axis=0)
    
    if len(behavioral_df) != fc_matrices.shape[0]:
        print(f"\nWarning: Behavioral data ({len(behavioral_df)}) and FC matrices ({fc_matrices.shape[0]}) mismatch!")
        min_len = min(len(behavioral_df), fc_matrices.shape[0])
        fc_matrices = fc_matrices[:min_len]
        behavioral_df = behavioral_df.iloc[:min_len]
    
    labels = behavioral_df['PCL5_Category'].values
    pcl5_scores = behavioral_df['PCL5 score'].values
    
    print(f"\nFinal dataset: {len(labels)} subjects")
    print("="*50)
    
    return fc_matrices, labels, pcl5_scores

# ============================================================================
# ADJACENCY MATRIX FUNCTIONS
# ============================================================================

def compute_mean_fc_and_threshold(fc_matrices, percentage=16.19):
    """Compute mean FC and threshold"""
    mean_fc = np.mean(fc_matrices, axis=0)
    
    n = mean_fc.shape[0]
    upper_tri_indices = np.triu_indices(n, k=1)
    correlations = mean_fc[upper_tri_indices]
    
    sorted_corrs = np.sort(correlations)[::-1]
    n_edges_to_keep = int(len(sorted_corrs) * percentage / 100)
    threshold = sorted_corrs[n_edges_to_keep] if n_edges_to_keep < len(sorted_corrs) else sorted_corrs[-1]
    
    return mean_fc, threshold

def create_adjacency_matrix(mean_fc, threshold):
    """Create binary adjacency matrix"""
    adj_matrix = (mean_fc > threshold).astype(np.float32)
    adj_matrix = np.maximum(adj_matrix, adj_matrix.T)
    np.fill_diagonal(adj_matrix, 1)
    return adj_matrix

def normalize_adjacency(adj_matrix):
    """Normalize adjacency matrix"""
    adj_with_self_loops = adj_matrix + np.eye(adj_matrix.shape[0])
    degree = np.sum(adj_with_self_loops, axis=1)
    degree_inv_sqrt = np.power(degree, -0.5)
    degree_inv_sqrt[np.isinf(degree_inv_sqrt)] = 0.
    D_inv_sqrt = np.diag(degree_inv_sqrt)
    normalized_adj = D_inv_sqrt @ adj_with_self_loops @ D_inv_sqrt
    return normalized_adj

def prepare_graph_data(fc_matrices, adj_matrix, labels):
    """Convert FC matrices to PyTorch Geometric Data objects"""
    data_list = []
    
    edge_index = torch.tensor(np.array(np.nonzero(adj_matrix)), dtype=torch.long)
    norm_adj = normalize_adjacency(adj_matrix)
    edge_weight = torch.tensor(norm_adj[np.nonzero(adj_matrix)], dtype=torch.float)
    
    for i in range(fc_matrices.shape[0]):
        x = torch.tensor(fc_matrices[i], dtype=torch.float)
        y = torch.tensor([labels[i]], dtype=torch.long)
        
        graph_data = Data(
            x=x,
            edge_index=edge_index.clone(),
            edge_weight=edge_weight.clone(),
            y=y
        )
        data_list.append(graph_data)
    
    return data_list

# ============================================================================
# GCN MODEL
# ============================================================================

class PTSD_GCN(nn.Module):
    """GCN for PTSD severity classification"""
    def __init__(self, n_features=125, hidden_dim1=64, hidden_dim2=32, 
                 hidden_dim3=16, n_classes=4, dropout_rate=0.5):
        super(PTSD_GCN, self).__init__()
        
        self.conv1 = GCNConv(n_features, hidden_dim1)
        self.conv2 = GCNConv(hidden_dim1, hidden_dim2)
        self.conv3 = GCNConv(hidden_dim2, hidden_dim3)
        
        self.fc1 = nn.Linear(hidden_dim3, 64)
        self.fc2 = nn.Linear(64, n_classes)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.bn1 = nn.BatchNorm1d(hidden_dim1)
        self.bn2 = nn.BatchNorm1d(hidden_dim2)
        self.bn3 = nn.BatchNorm1d(hidden_dim3)
    
    def forward(self, data):
        x, edge_index, edge_weight, batch = data.x, data.edge_index, data.edge_weight, data.batch
        
        x = self.conv1(x, edge_index, edge_weight)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.conv2(x, edge_index, edge_weight)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = self.conv3(x, edge_index, edge_weight)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.dropout(x)
        
        x = global_mean_pool(x, batch)
        
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return F.log_softmax(x, dim=1)

# ============================================================================
# TRAINING & EVALUATION
# ============================================================================

def train_gcn(model, loader, optimizer, device):
    """Training function"""
    model.train()
    total_loss = 0
    
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    
    return total_loss / len(loader.dataset)

def evaluate_gcn(model, loader, device):
    """Evaluation function"""
    model.eval()
    predictions = []
    true_labels = []
    probabilities = []
    
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            prob = torch.exp(output)
            
            predictions.extend(pred.cpu().numpy())
            true_labels.extend(data.y.cpu().numpy())
            probabilities.extend(prob.cpu().numpy())
    
    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    probabilities = np.array(probabilities)
    
    accuracy = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predictions, average='macro', zero_division=0
    )
    
    try:
        auc = roc_auc_score(true_labels, probabilities, multi_class='ovr', average='macro')
    except:
        auc = 0.0
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
        'predictions': predictions,
        'true_labels': true_labels,
        'probabilities': probabilities
    }

# ============================================================================
# HYPERPARAMETER TUNING
# ============================================================================

def hyperparameter_tuning_gcn(train_graphs, val_graphs, n_classes, device):
    """
    Perform grid search for hyperparameter tuning
    """
    print("\n" + "="*70)
    print("HYPERPARAMETER TUNING")
    print("="*70)
    
    # Define hyperparameter grid
    param_grid = {
        'hidden_dim1': [32, 64, 128],
        'hidden_dim2': [16, 32, 64],
        'hidden_dim3': [8, 16, 32],
        'dropout_rate': [0.3, 0.5, 0.7],
        'learning_rate': [0.0001, 0.001, 0.01],
        'batch_size': [8, 16, 32],
        'weight_decay': [1e-5, 5e-4, 1e-3]
    }
    
    print(f"\nSearching over {np.prod([len(v) for v in param_grid.values()])} configurations...")
    print("This may take a while...\n")
    
    best_score = 0
    best_params = None
    best_model_state = None
    all_results = []
    
    # Random search (sample 30 configurations)
    n_random_search = 30
    print(f"Using Random Search with {n_random_search} configurations\n")
    
    for trial in range(n_random_search):
        # Randomly sample hyperparameters - FIX: Convert to Python int
        params = {
            'hidden_dim1': int(np.random.choice(param_grid['hidden_dim1'])),
            'hidden_dim2': int(np.random.choice(param_grid['hidden_dim2'])),
            'hidden_dim3': int(np.random.choice(param_grid['hidden_dim3'])),
            'dropout_rate': float(np.random.choice(param_grid['dropout_rate'])),
            'learning_rate': float(np.random.choice(param_grid['learning_rate'])),
            'batch_size': int(np.random.choice(param_grid['batch_size'])),
            'weight_decay': float(np.random.choice(param_grid['weight_decay']))
        }
        
        print(f"Trial {trial+1}/{n_random_search}")
        print(f"  Params: hidden=[{params['hidden_dim1']}, {params['hidden_dim2']}, {params['hidden_dim3']}], "
              f"dropout={params['dropout_rate']}, lr={params['learning_rate']}, "
              f"bs={params['batch_size']}, wd={params['weight_decay']}")
        
        # Create data loaders with current batch size
        train_loader = DataLoader(train_graphs, batch_size=params['batch_size'], shuffle=True)
        val_loader = DataLoader(val_graphs, batch_size=params['batch_size'], shuffle=False)
        
        # Initialize model
        model = PTSD_GCN(
            n_features=125,
            hidden_dim1=params['hidden_dim1'],
            hidden_dim2=params['hidden_dim2'],
            hidden_dim3=params['hidden_dim3'],
            n_classes=n_classes,
            dropout_rate=params['dropout_rate']
        ).to(device)
        
        optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=params['learning_rate'],
            weight_decay=params['weight_decay']
        )
        
        # Train for fixed epochs
        n_epochs = 100
        best_val_f1 = 0
        patience = 20
        patience_counter = 0
        
        for epoch in range(n_epochs):
            train_loss = train_gcn(model, train_loader, optimizer, device)
            val_metrics = evaluate_gcn(model, val_loader, device)
            
            if val_metrics['f1'] > best_val_f1:
                best_val_f1 = val_metrics['f1']
                patience_counter = 0
                # Save this model state
                temp_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break
        
        print(f"  Best Val F1: {best_val_f1:.4f} (stopped at epoch {epoch+1})")
        
        # Store results
        result = {
            'params': params.copy(),
            'val_f1': best_val_f1,
            'val_accuracy': val_metrics['accuracy']
        }
        all_results.append(result)
        
        # Update best
        if best_val_f1 > best_score:
            best_score = best_val_f1
            best_params = params.copy()
            best_model_state = temp_model_state
            print(f"  *** New best F1: {best_score:.4f} ***")
        
        print()
    
    # Display top 5 configurations
    print("\n" + "="*70)
    print("TOP 5 CONFIGURATIONS")
    print("="*70)
    
    sorted_results = sorted(all_results, key=lambda x: x['val_f1'], reverse=True)
    for i, result in enumerate(sorted_results[:5]):
        print(f"\nRank {i+1}:")
        print(f"  Val F1: {result['val_f1']:.4f}, Val Acc: {result['val_accuracy']:.4f}")
        print(f"  Params: {result['params']}")
    
    print("\n" + "="*70)
    print("BEST HYPERPARAMETERS")
    print("="*70)
    print(f"Val F1 Score: {best_score:.4f}")
    print(f"Parameters:")
    for key, value in best_params.items():
        print(f"  {key}: {value}")
    print("="*70 + "\n")
    
    return best_params, best_model_state, all_results

# ============================================================================
# MAIN FUNCTION
# ============================================================================

def main_ptsd_pcl5_tuning():
    """Main pipeline with hyperparameter tuning"""
    print(f"\nUsing device: {device}\n")
    
    # Load data
    fc_matrices, labels, pcl5_scores = load_ptsd_data_with_pcl5(
        'Static_functional_connectivity_ptsd_dc_filt.mat',
        'PTSD_Behavioral_measures.csv'
    )
    
    # Compute adjacency matrix
    print("\nComputing adjacency matrix...")
    mean_fc, threshold = compute_mean_fc_and_threshold(fc_matrices, percentage=16.19)
    adj_matrix = create_adjacency_matrix(mean_fc, threshold)
    print(f"Threshold: {threshold:.4f}, Edges: {np.sum(adj_matrix) / 2:.0f}\n")
    
    # Split data: 70% train, 15% validation, 15% test
    indices = np.arange(len(labels))
    train_idx, test_val_idx = train_test_split(
        indices, test_size=0.3, stratify=labels, random_state=42
    )
    
    val_idx, test_idx = train_test_split(
        test_val_idx, test_size=0.66, stratify=labels[test_val_idx], random_state=42
    )
    
    print(f"Data split:")
    print(f"  Train: {len(train_idx)} ({len(train_idx)/len(labels)*100:.1f}%)")
    print(f"  Val:   {len(val_idx)} ({len(val_idx)/len(labels)*100:.1f}%)")
    print(f"  Test:  {len(test_idx)} ({len(test_idx)/len(labels)*100:.1f}%)")
    
    train_fc = fc_matrices[train_idx]
    train_labels = labels[train_idx]
    
    val_fc = fc_matrices[val_idx]
    val_labels = labels[val_idx]
    
    test_fc = fc_matrices[test_idx]
    test_labels = labels[test_idx]
    test_pcl5 = pcl5_scores[test_idx]
    
    print(f"\nTrain dist: {np.bincount(train_labels)}")
    print(f"Val dist:   {np.bincount(val_labels)}")
    print(f"Test dist:  {np.bincount(test_labels)}")
    
    n_classes = len(np.unique(labels))
    
    # Prepare graph data
    print("\nPreparing graph data...")
    train_graphs = prepare_graph_data(train_fc, adj_matrix, train_labels)
    val_graphs = prepare_graph_data(val_fc, adj_matrix, val_labels)
    test_graphs = prepare_graph_data(test_fc, adj_matrix, test_labels)
    
    # Hyperparameter tuning
    best_params, best_model_state, all_results = hyperparameter_tuning_gcn(
        train_graphs, val_graphs, n_classes, device
    )
    
    # Train final model with best hyperparameters on train+val
    print("\n" + "="*70)
    print("TRAINING FINAL MODEL ON TRAIN+VAL DATA")
    print("="*70)
    
    # Combine train and val for final training
    final_train_graphs = train_graphs + val_graphs
    final_train_loader = DataLoader(
        final_train_graphs, 
        batch_size=best_params['batch_size'], 
        shuffle=True
    )
    test_loader = DataLoader(test_graphs, batch_size=best_params['batch_size'], shuffle=False)
    
    # Initialize final model
    final_model = PTSD_GCN(
        n_features=125,
        hidden_dim1=best_params['hidden_dim1'],
        hidden_dim2=best_params['hidden_dim2'],
        hidden_dim3=best_params['hidden_dim3'],
        n_classes=n_classes,
        dropout_rate=best_params['dropout_rate']
    ).to(device)
    
    final_optimizer = torch.optim.Adam(
        final_model.parameters(),
        lr=best_params['learning_rate'],
        weight_decay=best_params['weight_decay']
    )
    
    print(f"\nTraining final model for 200 epochs...")
    for epoch in range(200):
        train_loss = train_gcn(final_model, final_train_loader, final_optimizer, device)
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}: Loss={train_loss:.4f}")
    
    # Save final model
    torch.save(final_model.state_dict(), 'final_gcn_model_pcl5.pt')
    
    # Test evaluation
    print("\n" + "="*70)
    print("TEST SET EVALUATION")
    print("="*70)
    
    test_results = evaluate_gcn(final_model, test_loader, device)
    
    print(f"\nTest Results:")
    print(f"  Accuracy:  {test_results['accuracy']:.4f}")
    print(f"  Precision: {test_results['precision']:.4f}")
    print(f"  Recall:    {test_results['recall']:.4f}")
    print(f"  F1 Score:  {test_results['f1']:.4f}")
    print(f"  AUC:       {test_results['auc']:.4f}")
    
    cm_test = confusion_matrix(test_results['true_labels'], test_results['predictions'])
    print(f"\nConfusion Matrix:\n{cm_test}")
    
    # Plot
    severity_labels = ['Minimal\n(0-25)', 'Mild\n(26-50)', 'Moderate\n(51-75)', 'Severe\n(76-100)']
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_test, annot=True, fmt='d', cmap='Blues',
                xticklabels=severity_labels[:n_classes],
                yticklabels=severity_labels[:n_classes])
    plt.title('GCN Test Set Confusion Matrix\n(PCL5 Severity Categories)', fontsize=14)
    plt.ylabel('True Severity', fontsize=12)
    plt.xlabel('Predicted Severity', fontsize=12)
    plt.tight_layout()
    plt.savefig('confusion_matrix_gcn_pcl5_tuned.png', dpi=300)
    print("\nConfusion matrix saved")
    
    # PCL5 Analysis
    print("\n" + "="*70)
    print("PCL5 SCORE ANALYSIS")
    print("="*70)
    
    for cat in range(n_classes):
        cat_mask = test_labels == cat
        if np.sum(cat_mask) > 0:
            cat_scores = test_pcl5[cat_mask]
            cat_predictions = test_results['predictions'][cat_mask]
            correct = np.sum(cat_predictions == cat)
            
            print(f"\nCategory {cat} ({severity_labels[cat].strip()}):")
            print(f"  N: {np.sum(cat_mask)}")
            print(f"  Correct: {correct}/{np.sum(cat_mask)} ({100*correct/np.sum(cat_mask):.1f}%)")
            print(f"  PCL5: {cat_scores.mean():.1f} ± {cat_scores.std():.1f}")
    
    # Plot hyperparameter search results
    plot_hyperparameter_results(all_results)

    print("\n" + "="*70)
    print("GENERATING EXPLANATIONS FOR BEST GCN MODEL")
    print("="*70)
    
    gcn_explanations = explain_best_gcn_model(
        model=final_model,
        test_graphs=test_graphs,
        test_labels=test_labels,
        test_pcl5=test_pcl5,
        best_params=best_params,
        n_samples_per_class=3,
        save_dir='explanations_gcn'
    )
    return best_params, test_results, final_model, all_results, gcn_explanations

def plot_hyperparameter_results(all_results):
    """Plot hyperparameter search results"""
    
    # Extract data
    learning_rates = [r['params']['learning_rate'] for r in all_results]
    dropout_rates = [r['params']['dropout_rate'] for r in all_results]
    batch_sizes = [r['params']['batch_size'] for r in all_results]
    f1_scores = [r['val_f1'] for r in all_results]
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Learning rate vs F1
    axes[0].scatter(learning_rates, f1_scores, alpha=0.6, s=100)
    axes[0].set_xscale('log')
    axes[0].set_xlabel('Learning Rate', fontsize=12)
    axes[0].set_ylabel('Validation F1 Score', fontsize=12)
    axes[0].set_title('Learning Rate vs F1', fontsize=13)
    axes[0].grid(alpha=0.3)
    
    # Dropout vs F1
    axes[1].scatter(dropout_rates, f1_scores, alpha=0.6, s=100, color='orange')
    axes[1].set_xlabel('Dropout Rate', fontsize=12)
    axes[1].set_ylabel('Validation F1 Score', fontsize=12)
    axes[1].set_title('Dropout Rate vs F1', fontsize=13)
    axes[1].grid(alpha=0.3)
    
    # Batch size vs F1
    axes[2].scatter(batch_sizes, f1_scores, alpha=0.6, s=100, color='green')
    axes[2].set_xlabel('Batch Size', fontsize=12)
    axes[2].set_ylabel('Validation F1 Score', fontsize=12)
    axes[2].set_title('Batch Size vs F1', fontsize=13)
    axes[2].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('hyperparameter_search_results_gcn.png', dpi=300)
    print("Hyperparameter search plots saved")
    plt.show()

# Run
if __name__ == "__main__":
    best_params, test_results, final_model, search_results, gcn_explanations = main_ptsd_pcl5_tuning()

Using device: cpu

Using device: cpu

Loading PTSD dataset with PCL5 scores...
Behavioral data loaded: 174 subjects

PCL5 Category Distribution:
PCL5_Category
0    38
1    50
2    48
3    38
Name: count, dtype: int64

Final dataset: 174 subjects

Computing adjacency matrix...
Threshold: 0.1919, Edges: 1318

Data split:
  Train: 121 (69.5%)
  Val:   18 (10.3%)
  Test:  35 (20.1%)

Train dist: [26 35 33 27]
Val dist:   [4 5 5 4]
Test dist:  [ 8 10 10  7]

Preparing graph data...

HYPERPARAMETER TUNING

Searching over 2187 configurations...
This may take a while...

Using Random Search with 30 configurations

Trial 1/30
  Params: hidden=[128, 16, 32], dropout=0.7, lr=0.0001, bs=8, wd=0.001
  Best Val F1: 0.4686 (stopped at epoch 64)
  *** New best F1: 0.4686 ***

Trial 2/30
  Params: hidden=[64, 64, 32], dropout=0.7, lr=0.01, bs=8, wd=0.001
