# Model Training Notebook for Schizophrenia Detection

This notebook is optimized for Google Colab and provides tools for training the SSPNet 3D CNN model for schizophrenia detection.

## Features:
- GPU configuration and memory management
- Google Drive integration for model checkpoints
- Progress tracking for long-running operations
- Hyperparameter tuning interface
- Real-time training visualization
- Memory optimization for large 3D models

## Setup and Configuration

### 1. Environment Setup

In [None]:
# Detect the runtime: the google.colab package is importable only inside Colab.
try:
    import google.colab
except ImportError:
    IN_COLAB = False
    print("Not running in Google Colab")
else:
    IN_COLAB = True
    print("Running in Google Colab")

In [None]:
# Mount Google Drive for data storage and model checkpoints.
# Defines `project_path`: a Drive folder in Colab, the parent directory otherwise.
import os

if IN_COLAB:
    from google.colab import drive

    # Test for the MyDrive folder rather than the bare mount point: a leftover
    # /content/drive directory may exist without Drive being mounted, in which
    # case the project directory below would be created on the ephemeral VM disk.
    if not os.path.isdir('/content/drive/MyDrive'):
        print("Mounting Google Drive...")
        drive.mount('/content/drive')
    else:
        print("Google Drive already mounted")

    # Set up project directory
    project_path = '/content/drive/MyDrive/schizophrenia_detection'
    os.makedirs(project_path, exist_ok=True)
    print(f"Project directory: {project_path}")
else:
    # Local run: the project root is taken to be the parent of the working directory.
    project_path = os.path.abspath('../')
    print(f"Local project directory: {project_path}")

In [None]:
# Install required packages specific to Colab environment.
# The `!pip` lines are IPython shell escapes, so this cell only runs inside a notebook kernel.
if IN_COLAB:
    print("Installing required packages...")
    
    # Core packages
    # NOTE(review): pinning tensorflow==2.12.0 may replace the TensorFlow build that is
    # preinstalled on the runtime; presumably a runtime restart is needed afterwards — confirm.
    !pip install tensorflow==2.12.0 nibabel nilearn scikit-learn matplotlib seaborn tqdm -q
    
    # Interactive visualization packages
    !pip install plotly ipywidgets -q
    
    # Memory management and optimization packages
    !pip install psutil -q
    
    # Advanced packages for hyperparameter tuning
    !pip install optuna -q
    
    # TensorFlow Addons (extension package for TensorFlow).
    # NOTE(review): the mixed precision setup in this notebook goes through
    # tf.keras.mixed_precision, which ships with TensorFlow itself, and no visible cell
    # imports tensorflow_addons — confirm whether this install is still needed.
    !pip install tensorflow-addons -q
    
    print("Packages installed successfully!")
else:
    # Local environments are expected to have their dependencies installed already.
    print("Skipping package installation in local environment")

### 2. GPU Configuration and Memory Management

In [None]:
# Check GPU availability and configure memory.
# Defines AUTO_BATCH_SIZE for the rest of the notebook in both environments.
if IN_COLAB:
    import tensorflow as tf
    from psutil import virtual_memory

    # Enumerate GPUs through tf.config: unlike the deprecated
    # tf.test.is_gpu_available() (and tf.test.gpu_device_name()), listing physical
    # devices does not initialize them, so memory growth can still be configured.
    gpus = tf.config.list_physical_devices('GPU')
    gpu_available = bool(gpus)
    print(f"TensorFlow version: {tf.__version__}")
    print(f"GPU available: {gpu_available}")

    if gpu_available:
        # Memory growth has to be requested before the GPUs are initialized;
        # TensorFlow raises RuntimeError otherwise (e.g. when this cell is re-run).
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            print("GPU memory growth enabled")
        except RuntimeError as e:
            print(f"Error setting GPU memory growth: {e}")

        # Querying the device name initializes the GPU, so do it only after the
        # memory growth request above.
        gpu_name = tf.test.gpu_device_name()
        print(f"GPU device: {gpu_name}")

        # Enable mixed precision training for better performance. This is kept
        # outside the try block: it does not depend on memory growth succeeding.
        tf.keras.mixed_precision.set_global_policy('mixed_float16')
        print("Mixed precision training enabled")

    # Check RAM availability (this is the machine's total RAM, not the free amount)
    ram_gb = virtual_memory().total / 1e9
    print(f"Available RAM: {ram_gb:.2f} GB")

    if ram_gb < 12:
        print("WARNING: Low RAM detected. Consider reducing batch sizes.")
        # Automatically adjust batch size for low RAM
        AUTO_BATCH_SIZE = 2
    else:
        AUTO_BATCH_SIZE = 4

    print(f"Auto-configured batch size: {AUTO_BATCH_SIZE}")
else:
    import tensorflow as tf
    AUTO_BATCH_SIZE = 4
    print("Skipping GPU configuration in local environment")

In [None]:
# Memory management utilities for training
import gc
import psutil
import time
import json
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

def check_memory_usage():
    """Print and return this process's resident set size in MB (1 MB = 1e6 bytes)."""
    rss_mb = psutil.Process().memory_info().rss / 1e6
    print(f"Memory usage: {rss_mb:.2f} MB")
    return rss_mb

def clear_memory():
    """Run the garbage collector and, in Colab only, reset the Keras session."""
    gc.collect()
    # The Keras session reset is applied only when running in Colab.
    reset_session = tf.keras.backend.clear_session if IN_COLAB else None
    if reset_session is not None:
        reset_session()
    print("Memory cleared")

def monitor_memory(func):
    """Decorator that reports the process memory change caused by a call.

    Memory usage is sampled with ``check_memory_usage()`` before and after the
    wrapped call, and the difference (in MB) is printed.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same signature that returns ``func``'s result
        unchanged. The wrapper keeps ``func``'s name and docstring.
    """
    import functools

    @functools.wraps(func)  # keep __name__/__doc__ of the wrapped function
    def wrapper(*args, **kwargs):
        start_mem = check_memory_usage()
        result = func(*args, **kwargs)
        end_mem = check_memory_usage()
        print(f"Memory change: {end_mem - start_mem:.2f} MB")
        return result
    return wrapper

# Progress tracking for long operations
class TrainingProgress:
    """Progress bar for epoch-based training with a running time estimate.

    Attributes:
        total_epochs: Number of epochs the bar represents.
        current_epoch: Index of the epoch most recently passed to ``update``.
        start_time: ``time.time()`` value captured when the tracker was created.
        eta_seconds: Estimated seconds remaining after the last ``update`` call,
            or ``None`` before the first call.
        pbar: The underlying tqdm progress bar.
    """

    def __init__(self, total_epochs):
        self.total_epochs = total_epochs
        self.current_epoch = 0
        self.start_time = time.time()
        self.eta_seconds = None
        self.pbar = tqdm(total=total_epochs, desc="Training Progress")

    def update(self, epoch, metrics):
        """Advance the bar by one epoch and show the latest metrics and ETA.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            metrics: Mapping of metric name to a numeric value; each value is
                formatted with four decimals.
        """
        self.current_epoch = epoch
        elapsed = time.time() - self.start_time
        # Average time per finished epoch, multiplied by the epochs still to run.
        self.eta_seconds = elapsed / (epoch + 1) * (self.total_epochs - epoch - 1)

        # Update progress bar with metrics
        metric_str = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
        self.pbar.set_description(f"Epoch {epoch+1}/{self.total_epochs} - {metric_str}")
        # Surface the estimate instead of discarding it.
        self.pbar.set_postfix_str(f"ETA: {self.eta_seconds:.0f}s")
        self.pbar.update(1)

    def close(self):
        """Close the underlying progress bar."""
        self.pbar.close()

print("Memory management utilities loaded")

### 3. Import Libraries and Configuration

In [None]:
# Import core libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight

# Import visualization libraries
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display, HTML, Image

# Configure warnings and display
# Suppress all Python warnings for the rest of the session. `warnings` is not imported in
# this cell; it comes from the memory-utilities cell earlier in the notebook.
warnings.filterwarnings('ignore')
# Default size (in inches) and resolution for every matplotlib figure created afterwards.
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['figure.dpi'] = 100
# IPython magic: draw matplotlib figures inline in the cell output.
%matplotlib inline

print("Libraries imported successfully")

In [None]:
# Change to project directory and import project modules.
# After the chdir below, every relative path used later ('./data', './checkpoints', ...)
# resolves against project_path.
sys.path.append(project_path)
os.chdir(project_path)
print(f"Current working directory: {os.getcwd()}")

# Import project modules
# If any single import fails, the except branch replaces default_config with a
# self-contained stand-in. Names imported before the failing line stay bound; names
# after it (possibly including create_sspnet_model) remain undefined.
try:
    from config import default_config
    from utils.file_utils import list_files, load_json, save_json
    from utils.data_utils import normalize_data, resize_data
    from data_processing.data_loader import create_data_generator
    from data_processing.fmri_preprocessing import preprocess_fmri
    from models.sspnet_3d_cnn import create_sspnet_model
    from models.model_utils import count_parameters
    from training.trainer import ModelTrainer
    from training.evaluator import ModelEvaluator
    from visualization.model_visualization import plot_training_history, ModelVisualizer
    from visualization.result_plots import plot_confusion_matrix, plot_roc_curve
    print("Project modules imported successfully")
except ImportError as e:
    print(f"Warning: Could not import some project modules: {e}")
    print("Using minimal configuration for training")
    
    # Create minimal configuration for testing
    class MinimalConfig:
        # Stand-in for the project's config object. Each section is an instance of an
        # ad-hoc class built with type(name, bases, namespace)(), so its settings are
        # read as attributes (e.g. config.data.batch_size).
        def __init__(self):
            # Data locations, array shapes and split ratios.
            self.data = type('DataConfig', (), {
                'data_root': './data',
                'fmri_data_dir': './data/fmri',
                'meg_data_dir': './data/meg',
                'fmri_shape': (96, 96, 96, 4),
                'meg_shape': (306, 100, 1000),
                'batch_size': AUTO_BATCH_SIZE,
                'train_ratio': 0.7,
                'val_ratio': 0.15,
                'test_ratio': 0.15
            })()
            # Model hyperparameters; input_shape equals fmri_shape above.
            self.model = type('ModelConfig', (), {
                'input_shape': (96, 96, 96, 4),
                'num_classes': 2,
                'dropout_rate': 0.5
            })()
            # Optimisation settings and checkpoint location.
            self.training = type('TrainingConfig', (), {
                'epochs': 50,
                'learning_rate': 0.001,
                'optimizer': 'adam',
                'loss_function': 'categorical_crossentropy',
                'early_stopping': True,
                'patience': 15,
                'checkpoint_dir': './checkpoints',
                'mixed_precision': True
            })()
            # Output directory for evaluation artefacts.
            self.evaluation = type('EvaluationConfig', (), {
                'results_dir': './results'
            })()
            # Output directory for figures.
            self.visualization = type('VisualizationConfig', (), {
                'output_dir': './visualizations'
            })()
    
    default_config = MinimalConfig()

### 4. Training Configuration

In [None]:
# Central place to adjust training parameters; sections are added in display order.
TRAINING_CONFIG = {}

# Model parameters
TRAINING_CONFIG.update(
    input_shape=default_config.model.input_shape,
    num_classes=default_config.model.num_classes,
    dropout_rate=default_config.model.dropout_rate,
)

# Training parameters (batch size is auto-configured based on available RAM)
TRAINING_CONFIG.update(
    epochs=default_config.training.epochs,
    batch_size=AUTO_BATCH_SIZE,
    learning_rate=default_config.training.learning_rate,
    optimizer=default_config.training.optimizer,
    loss_function=default_config.training.loss_function,
)

# Data parameters
TRAINING_CONFIG.update(
    train_ratio=default_config.data.train_ratio,
    val_ratio=default_config.data.val_ratio,
    test_ratio=default_config.data.test_ratio,
)

# Training control
TRAINING_CONFIG.update(
    early_stopping=default_config.training.early_stopping,
    patience=default_config.training.patience,
    save_best_only=True,
)

# Memory management
TRAINING_CONFIG.update(
    use_mixed_precision=default_config.training.mixed_precision,
    clear_memory_between_epochs=True,
    use_data_augmentation=True,
)

# Checkpointing
TRAINING_CONFIG.update(
    checkpoint_dir=default_config.training.checkpoint_dir,
    save_to_drive=IN_COLAB,
)

# Hyperparameter tuning
TRAINING_CONFIG.update(
    enable_hyperparameter_tuning=False,
    tuning_trials=20,
)

# Visualization (plot every `plot_frequency` epochs)
TRAINING_CONFIG.update(
    real_time_plotting=True,
    plot_frequency=5,
)

# Colab-specific caps: smaller batches, and fewer epochs for a faster demonstration
if IN_COLAB and TRAINING_CONFIG['batch_size'] > 2:
    TRAINING_CONFIG['batch_size'] = 2
    print("Reduced batch size for Colab environment")

if IN_COLAB and TRAINING_CONFIG['epochs'] > 30:
    TRAINING_CONFIG['epochs'] = 30
    print("Reduced epochs for faster demonstration")

# Make sure every output directory exists before training starts
for output_dir in (
    TRAINING_CONFIG['checkpoint_dir'],
    default_config.evaluation.results_dir,
    default_config.visualization.output_dir,
):
    os.makedirs(output_dir, exist_ok=True)

print("Training configuration set:")
for setting in TRAINING_CONFIG:
    print(f"  {setting}: {TRAINING_CONFIG[setting]}")

## Data Preparation

### 1. Load and Prepare Metadata

In [None]:
# Load metadata, or synthesize a demonstration dataset when none exists.
metadata_path = os.path.join(default_config.data.data_root, 'metadata.csv')

if os.path.exists(metadata_path):
    print(f"Loading metadata from: {metadata_path}")
    metadata = pd.read_csv(metadata_path)
    print(f"Metadata shape: {metadata.shape}")
    display(metadata.head())

    # Check class distribution
    if 'diagnosis' in metadata.columns:
        print("\nDiagnosis distribution:")
        diagnosis_counts = metadata['diagnosis'].value_counts()
        print(diagnosis_counts)

        # Plot class distribution
        plt.figure(figsize=(8, 6))
        diagnosis_counts.plot(kind='bar', color=['skyblue', 'lightcoral'])
        plt.title('Class Distribution')
        plt.xlabel('Diagnosis')
        plt.ylabel('Count')
        plt.xticks(rotation=45)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
else:
    print("No metadata file found. Creating sample metadata for demonstration.")

    # Prefer the project's file lister, but fall back to a glob so this demo path
    # also works when the project modules could not be imported (the situation the
    # MinimalConfig fallback exists for).
    try:
        from utils.file_utils import list_files
    except ImportError:
        import glob

        def list_files(directory, extension):
            """Fallback lister: sorted paths in `directory` ending with `extension`."""
            return sorted(glob.glob(os.path.join(directory, f'*{extension}')))

    fmri_files = list_files(default_config.data.fmri_data_dir, '.nii.gz')

    if not fmri_files:
        print("No fMRI files found. Creating sample data for demonstration.")
        os.makedirs(default_config.data.fmri_data_dir, exist_ok=True)

        # Create sample fMRI files
        import nibabel
        sample_shape = default_config.data.fmri_shape
        for i in range(20):  # Create 20 sample subjects
            # float32 noise: half the memory and disk footprint of the float64 default.
            sample_data = np.random.randn(*sample_shape).astype(np.float32)
            sample_img = nibabel.Nifti1Image(sample_data, affine=np.eye(4))

            sample_path = os.path.join(default_config.data.fmri_data_dir, f'sub-{i:03d}_fmri.nii.gz')
            nibabel.save(sample_img, sample_path)

        fmri_files = list_files(default_config.data.fmri_data_dir, '.nii.gz')

    # Create sample labels (0 for control, 1 for schizophrenia).
    # Labels are balanced and then shuffled: independent random draws can leave a
    # class with too few members for the stratified train/val/test split.
    labels = np.random.permutation(np.arange(len(fmri_files)) % 2)

    # Create metadata DataFrame (subject ids are assigned by position in fmri_files)
    metadata = pd.DataFrame({
        'subject_id': [f"sub-{i:03d}" for i in range(len(fmri_files))],
        'file_path': fmri_files,
        'diagnosis': ['control' if label == 0 else 'schizophrenia' for label in labels],
        'label': labels,
        'age': np.random.randint(18, 65, len(fmri_files)),
        'gender': np.random.choice(['M', 'F'], len(fmri_files))
    })

    # Save sample metadata
    os.makedirs(default_config.data.data_root, exist_ok=True)
    metadata.to_csv(metadata_path, index=False)
    print(f"Created sample metadata with {len(metadata)} subjects: {metadata_path}")

    display(metadata.head())
    print("\nDiagnosis distribution:")
    print(metadata['diagnosis'].value_counts())

### 2. Data Splitting

In [None]:
# Convert diagnosis to numeric labels if needed
if 'label' not in metadata.columns:
    label_map = {'control': 0, 'schizophrenia': 1}
    metadata['label'] = metadata['diagnosis'].map(label_map)

    # Series.map yields NaN for values missing from label_map; fail loudly here
    # instead of letting NaN labels reach the split and the one-hot encoding.
    unmapped = metadata['label'].isna()
    if unmapped.any():
        unknown = sorted(metadata.loc[unmapped, 'diagnosis'].astype(str).unique())
        raise ValueError(f"Unrecognized diagnosis values: {unknown}")
    metadata['label'] = metadata['label'].astype(int)

# Split data into train, validation, and test sets
print("Splitting data into train, validation, and test sets...")

train_df, temp_df = train_test_split(
    metadata,
    test_size=(1 - TRAINING_CONFIG['train_ratio']),
    random_state=42,
    stratify=metadata['label']
)

# Share of the held-out part that goes to validation; the remainder is the test set.
val_ratio = TRAINING_CONFIG['val_ratio'] / (1 - TRAINING_CONFIG['train_ratio'])
val_df, test_df = train_test_split(
    temp_df,
    test_size=(1 - val_ratio),
    random_state=42,
    stratify=temp_df['label']
)

print(f"Training set: {len(train_df)} subjects")
print(f"Validation set: {len(val_df)} subjects")
print(f"Test set: {len(test_df)} subjects")

print("\nTraining set distribution:")
print(train_df['diagnosis'].value_counts())
print("\nValidation set distribution:")
print(val_df['diagnosis'].value_counts())
print("\nTest set distribution:")
print(test_df['diagnosis'].value_counts())

# Calculate class weights for imbalanced datasets
classes = np.unique(train_df['label'])
class_weights = compute_class_weight(
    'balanced',
    classes=classes,
    y=train_df['label']
)
# Key each weight by its actual class label (not by position in the array), so the
# mapping stays correct even when the labels are not exactly 0..n-1.
class_weight_dict = dict(zip(classes.tolist(), class_weights.tolist()))
print(f"\nClass weights: {class_weight_dict}")

### 3. Data Loading and Preprocessing

In [None]:
# Import nibabel for fMRI data loading
try:
    import nibabel as nib
except ImportError:
    print("Installing nibabel...")
    !pip install nibabel -q
    import nibabel as nib

# Memory-efficient data loading function
@monitor_memory
def load_and_preprocess_fmri(file_path, target_shape=None, normalize=True):
    """Load one fMRI volume and bring it to the shape the model expects.

    Args:
        file_path: Path to an image file readable by nibabel.
        target_shape: Desired output shape with the number of time points in
            position 3, or None to skip resizing (4 time points are then used
            when truncating 4D data or expanding 3D data).
        normalize: When true, standardize the data. The project's
            ``normalize_data`` is used if it can be imported and runs; otherwise
            a global z-score is applied.

    Returns:
        The preprocessed array, or None if loading or preprocessing failed
        (the error is printed).
    """
    try:
        # Read-only memory map. nibabel only accepts True, False, 'c' or 'r' for
        # `mmap`; any other value is rejected with ValueError.
        img = nib.load(file_path, mmap='r')
        data = img.get_fdata()

        if data.ndim in (3, 4):
            n_timepoints = target_shape[3] if target_shape else 4
            if data.ndim == 4:
                # Keep only the leading time points when there are too many
                if data.shape[3] > n_timepoints:
                    data = data[..., :n_timepoints]
            else:
                # 3D volume: add a time axis and repeat to the required length
                data = np.repeat(data[..., np.newaxis], n_timepoints, axis=-1)

        # Resize if needed (tuple() so that a list-valued target_shape compares equal)
        if target_shape and data.shape != tuple(target_shape):
            try:
                from utils.data_utils import resize_data
                data = resize_data(data, target_shape)
            except Exception as e:
                # Best effort: keep the unresized data and let the caller decide
                print(f"Warning: Could not resize data from {data.shape} to {target_shape}")
                print(f"  Reason: {type(e).__name__}: {e}")

        # Normalize data
        if normalize:
            try:
                from utils.data_utils import normalize_data
                data = normalize_data(data, method='standard')
            except Exception:
                # Simple z-score normalization; the epsilon avoids division by zero
                data = (data - np.mean(data)) / (np.std(data) + 1e-8)

        return data

    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        return None

print("Data loading functions defined")

In [None]:
# Create data generators with memory management
def create_memory_efficient_generator(file_paths, labels, batch_size, training=True, shuffle=True):
    """Yield (batch_x, batch_y) batches of 3D neuroimaging data indefinitely.

    Files are loaded one by one with ``load_and_preprocess_fmri``. Samples that
    fail to load are skipped, so a batch may be smaller than ``batch_size``.
    Array shapes come from ``TRAINING_CONFIG['input_shape']`` and
    ``TRAINING_CONFIG['num_classes']``.

    Args:
        file_paths: Sequence of image paths.
        labels: Sequence of integer class indices, aligned with ``file_paths``.
        batch_size: Maximum number of samples per batch (at least 1).
        training: Sample order is shuffled only when this and ``shuffle`` are true.
        shuffle: See ``training``.

    Yields:
        Tuples ``(batch_x, batch_y)`` of float32 arrays; ``batch_y`` is one-hot.

    Raises:
        ValueError: If the inputs are empty or inconsistent, or ``batch_size`` < 1.
        RuntimeError: If a complete pass over the data produces no usable sample.
            Without this check the generator would loop forever without yielding.
        As with any generator, these are raised on first iteration, not at creation.
    """
    num_samples = len(file_paths)
    if num_samples == 0:
        raise ValueError("file_paths is empty; cannot create a data generator")
    if len(labels) != num_samples:
        raise ValueError(
            f"file_paths and labels differ in length ({num_samples} vs {len(labels)})"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    while True:
        # Create indices for batching
        indices = np.random.permutation(num_samples) if shuffle and training else np.arange(num_samples)
        yielded_in_pass = False

        for i in range(0, num_samples, batch_size):
            batch_indices = indices[i:i+batch_size]
            batch_size_actual = len(batch_indices)

            # Initialize batch arrays
            batch_x = np.zeros((batch_size_actual,) + TRAINING_CONFIG['input_shape'], dtype=np.float32)
            batch_y = np.zeros((batch_size_actual, TRAINING_CONFIG['num_classes']), dtype=np.float32)

            # Valid samples are packed at the front; the unused tail is sliced off below
            valid_samples = 0
            for idx in batch_indices:
                try:
                    # Load and preprocess data
                    data = load_and_preprocess_fmri(
                        file_paths[idx],
                        target_shape=TRAINING_CONFIG['input_shape'],
                        normalize=True
                    )

                    if data is not None:
                        batch_x[valid_samples] = data

                        # One-hot encode label
                        batch_y[valid_samples, labels[idx]] = 1
                        valid_samples += 1

                except Exception as e:
                    # e.g. a shape mismatch on assignment; skip this sample
                    print(f"Error processing {file_paths[idx]}: {e}")

            # Only yield if we have valid samples
            if valid_samples > 0:
                yielded_in_pass = True
                yield batch_x[:valid_samples], batch_y[:valid_samples]

            # Clear memory between batches if enabled
            if TRAINING_CONFIG['clear_memory_between_epochs']:
                if i % (batch_size * 5) == 0:  # Clear every 5 batches
                    gc.collect()

        if not yielded_in_pass:
            raise RuntimeError(
                "No sample could be loaded in a full pass over the data; "
                "check the file paths and the loader's error output"
            )

# Create data generators
print("Creating data generators...")

train_generator = create_memory_efficient_generator(
    train_df['file_path'].tolist(),
    train_df['label'].tolist(),
    TRAINING_CONFIG['batch_size'],
    training=True,
    shuffle=True
)

# Evaluation generators keep a fixed order (no shuffling).
val_generator = create_memory_efficient_generator(
    val_df['file_path'].tolist(),
    val_df['label'].tolist(),
    TRAINING_CONFIG['batch_size'],
    training=False,
    shuffle=False
)

test_generator = create_memory_efficient_generator(
    test_df['file_path'].tolist(),
    test_df['label'].tolist(),
    TRAINING_CONFIG['batch_size'],
    training=False,
    shuffle=False
)

print("Data generators created successfully")

# Calculate steps per epoch.
# The generators also emit the trailing partial batch, so one pass over a split takes
# ceil(n / batch_size) steps. Floor division would drop that batch, could produce 0
# steps for a split smaller than the batch size, and would let the endless evaluation
# generators drift to a different subset on every epoch. -(-a // b) is ceiling
# division; max(1, ...) guards against an empty split.
steps_per_epoch = max(1, -(-len(train_df) // TRAINING_CONFIG['batch_size']))
validation_steps = max(1, -(-len(val_df) // TRAINING_CONFIG['batch_size']))
test_steps = max(1, -(-len(test_df) // TRAINING_CONFIG['batch_size']))

print(f"Steps per epoch: {steps_per_epoch}")
print(f"Validation steps: {validation_steps}")
print(f"Test steps: {test_steps}")

## Model Building

### 1. Model Architecture

In [None]:
# Build the SSPNet 3D CNN model
print("Building SSPNet 3D CNN model...")

try:
    # Try to use the project's model creation function
    model = create_sspnet_model(
        input_shape=TRAINING_CONFIG['input_shape'],
        num_classes=TRAINING_CONFIG['num_classes'],
        dropout_rate=TRAINING_CONFIG['dropout_rate']
    )
    print("Model created using project's SSPNet architecture")
except Exception as e:
    # Covers a NameError when the project import failed earlier as well as errors
    # raised by the factory itself. Unlike a bare except, KeyboardInterrupt passes.
    print("Could not load project model. Creating a simplified 3D CNN model...")
    print(f"  Reason: {type(e).__name__}: {e}")

    # Create a simplified 3D CNN model for demonstration
    from tensorflow.keras import layers, models, regularizers

    def create_simple_3d_cnn(input_shape, num_classes, dropout_rate=0.5):
        """Build a small 3D CNN: three conv blocks followed by a dense classifier.

        Args:
            input_shape: Shape of one sample, without the batch axis.
            num_classes: Number of softmax outputs.
            dropout_rate: Dropout after the third conv block and the dense layer;
                the first two conv blocks use half of this rate.

        Returns:
            An uncompiled Keras model.
        """
        inputs = layers.Input(shape=input_shape)

        # First 3D convolution block
        x = layers.Conv3D(32, (3, 3, 3), activation='relu', padding='same')(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling3D((2, 2, 2))(x)
        x = layers.Dropout(dropout_rate * 0.5)(x)

        # Second 3D convolution block
        x = layers.Conv3D(64, (3, 3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling3D((2, 2, 2))(x)
        x = layers.Dropout(dropout_rate * 0.5)(x)

        # Third 3D convolution block
        x = layers.Conv3D(128, (3, 3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling3D((2, 2, 2))(x)
        x = layers.Dropout(dropout_rate)(x)

        # Flatten and dense layers
        x = layers.Flatten()(x)
        x = layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.01))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Dropout(dropout_rate)(x)

        # Output layer. The dtype is pinned to float32 so the softmax stays
        # numerically stable when the global policy is mixed_float16; under the
        # default float32 policy this changes nothing.
        outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)

        # Create model
        return models.Model(inputs=inputs, outputs=outputs)

    model = create_simple_3d_cnn(
        input_shape=TRAINING_CONFIG['input_shape'],
        num_classes=TRAINING_CONFIG['num_classes'],
        dropout_rate=TRAINING_CONFIG['dropout_rate']
    )

# Print model summary
print("Model Architecture:")
model.summary()

# Count parameters: use the project helper when it is available and works,
# otherwise fall back to Keras' own counter.
try:
    total_params = count_parameters(model)
except Exception:
    total_params = model.count_params()
print(f"\nTotal parameters: {total_params:,}")

# Calculate model memory usage (weights only, excluding activations and optimizer state)
param_memory = total_params * 4  # 4 bytes per float32
print(f"Estimated model memory: {param_memory / 1e6:.2f} MB")

### 2. Model Compilation

In [None]:
# Attach optimizer, loss and metrics to the model
print("Compiling model...")

# Start from plain Adam and add loss scaling only when mixed precision is requested
optimizer = tf.keras.optimizers.Adam(learning_rate=TRAINING_CONFIG['learning_rate'])
if TRAINING_CONFIG['use_mixed_precision']:
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
    print("Using mixed precision optimizer")
else:
    print("Using standard precision optimizer")

# Compile model
model.compile(
    optimizer=optimizer,
    loss=TRAINING_CONFIG['loss_function'],
    metrics=['accuracy', 'AUC'],
)

# Summarize the compilation settings
print("Model compiled successfully")
print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Learning rate: {TRAINING_CONFIG['learning_rate']}")
print(f"Loss function: {TRAINING_CONFIG['loss_function']}")

## Training Setup

### 1. Callbacks Configuration

In [None]:
# Setup training callbacks
from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, 
    TensorBoard, CSVLogger, Callback
)

# Custom callback for real-time plotting
class RealTimePlotCallback(Callback):
    """Keras callback that redraws loss/accuracy curves every few epochs.

    Args:
        plot_frequency: Plot after every ``plot_frequency``-th epoch.

    Attributes:
        history: Per-epoch values for 'loss', 'val_loss', 'accuracy' and
            'val_accuracy'; a metric absent from the epoch logs is not appended.
    """

    def __init__(self, plot_frequency=5):
        super().__init__()
        self.plot_frequency = plot_frequency
        self.history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

    def on_epoch_end(self, epoch, logs=None):
        """Record the tracked metrics found in ``logs`` and plot when due."""
        # logs defaults to None; treat that as "no metrics reported"
        logs = logs or {}

        # Update history
        for key, values in self.history.items():
            if key in logs:
                values.append(logs[key])

        # Plot every N epochs
        if (epoch + 1) % self.plot_frequency == 0:
            self.plot_training_progress()

    def plot_training_progress(self):
        """Draw training (and, if recorded, validation) loss and accuracy curves."""
        # A single point does not make a curve
        if len(self.history['loss']) < 2:
            return

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Plot loss
        axes[0].plot(self.history['loss'], label='Training Loss', color='blue')
        if self.history['val_loss']:
            axes[0].plot(self.history['val_loss'], label='Validation Loss', color='red')
        axes[0].set_title('Model Loss')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot accuracy
        axes[1].plot(self.history['accuracy'], label='Training Accuracy', color='blue')
        if self.history['val_accuracy']:
            axes[1].plot(self.history['val_accuracy'], label='Validation Accuracy', color='red')
        axes[1].set_title('Model Accuracy')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
        # Release the figure so repeated plotting does not accumulate open figures
        plt.close(fig)

# Create callbacks list
# Callbacks are appended in this order: checkpoint, early stopping (optional),
# LR reduction, CSV logger, real-time plot (optional), TensorBoard (Colab only).
callbacks = []

# Model checkpoint callback
# Saves the full model (not only the weights) in HDF5 format, judged by validation accuracy.
checkpoint_path = os.path.join(
    TRAINING_CONFIG['checkpoint_dir'], 
    'sspnet_best_model.h5'
)
checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',
    save_best_only=TRAINING_CONFIG['save_best_only'],
    save_weights_only=False,
    verbose=1
)
callbacks.append(checkpoint_callback)

# Early stopping callback
# Monitors validation accuracy with the configured patience; restore_best_weights=True
# is passed so the best weights are restored.
if TRAINING_CONFIG['early_stopping']:
    early_stopping_callback = EarlyStopping(
        monitor='val_accuracy',
        patience=TRAINING_CONFIG['patience'],
        restore_best_weights=True,
        verbose=1
    )
    callbacks.append(early_stopping_callback)

# Learning rate reduction callback
# Halves the learning rate (factor=0.5) when validation loss has not improved for
# 10 epochs, with a floor of 1e-7.
# Note that this watches val_loss, while checkpointing and early stopping watch val_accuracy.
reduce_lr_callback = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=10,
    min_lr=1e-7,
    verbose=1
)
callbacks.append(reduce_lr_callback)

# CSV logger callback
# Writes the training log to training_log.csv in the results directory.
csv_logger_path = os.path.join(
    default_config.evaluation.results_dir, 
    'training_log.csv'
)
csv_logger_callback = CSVLogger(csv_logger_path)
callbacks.append(csv_logger_callback)

# Real-time plotting callback
if TRAINING_CONFIG['real_time_plotting']:
    plot_callback = RealTimePlotCallback(plot_frequency=TRAINING_CONFIG['plot_frequency'])
    callbacks.append(plot_callback)

# TensorBoard callback (for Colab)
# Logs go to a tensorboard_logs folder inside the checkpoint directory.
if IN_COLAB:
    tensorboard_callback = TensorBoard(
        log_dir=os.path.join(default_config.training.checkpoint_dir, 'tensorboard_logs'),
        histogram_freq=1,
        write_graph=True,
        update_freq='epoch'
    )
    callbacks.append(tensorboard_callback)
    
    # Load TensorBoard extension
    # IPython magics: these two lines only work inside a notebook kernel. The braces
    # are expanded by IPython, so the dashboard points at the same log directory.
    %load_ext tensorboard
    %tensorboard --logdir {os.path.join(default_config.training.checkpoint_dir, 'tensorboard_logs')}

# Report what was configured
print(f"Callbacks configured: {len(callbacks)} callbacks")
for callback in callbacks:
    print(f"  - {callback.__class__.__name__}")

### 2. Training Progress Tracking

In [None]:
# Initialize progress tracking.
# Constructing TrainingProgress creates the tqdm bar and records its start time right
# here, so the elapsed time used in update() is measured from this cell's execution.
progress_tracker = TrainingProgress(TRAINING_CONFIG['epochs'])

# Custom training loop with memory management
class MemoryEfficientTrainer:
    """Train a compiled Keras-style model one batch at a time.

    Batches are pulled from Python iterators with ``next()`` and fed to
    ``model.train_on_batch`` / ``model.test_on_batch``. Per-epoch averages
    are stored in ``self.history`` under the keys ``loss``, ``val_loss``,
    ``accuracy`` and ``val_accuracy``.

    Args:
        model: Object exposing ``train_on_batch(x, y)`` and
            ``test_on_batch(x, y)``.
        train_gen: Iterator yielding ``(batch_x, batch_y)`` training pairs.
        val_gen: Iterator yielding ``(batch_x, batch_y)`` validation pairs.
        steps_per_epoch: Number of training batches drawn per epoch.
        val_steps: Number of validation batches drawn per epoch.
        config: Mapping read for ``'clear_memory_between_epochs'``,
            ``'patience'`` and, optionally, ``'early_stopping'``
            (treated as enabled when the key is absent).
    """

    def __init__(self, model, train_gen, val_gen, steps_per_epoch, val_steps, config):
        self.model = model
        self.train_gen = train_gen
        self.val_gen = val_gen
        self.steps_per_epoch = steps_per_epoch
        self.val_steps = val_steps
        self.config = config
        self.history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': []}

    @staticmethod
    def _split_metrics(metrics):
        """Return ``(loss, accuracy)`` from a per-batch result.

        A list/tuple is read as ``[loss, accuracy, ...]``; a missing
        accuracy entry counts as 0. Anything else is treated as a bare loss.
        """
        if isinstance(metrics, (list, tuple)):
            accuracy = metrics[1] if len(metrics) > 1 else 0.0
            return metrics[0], accuracy
        return metrics, 0.0

    @staticmethod
    def _notify(callbacks, hook, *args, **kwargs):
        """Call ``hook`` on every callback that implements it."""
        for callback in callbacks:
            handler = getattr(callback, hook, None)
            if callable(handler):
                handler(*args, **kwargs)

    def _run_batches(self, generator, steps, batch_fn, collect_garbage=False):
        """Average loss/accuracy of ``batch_fn`` over ``steps`` batches.

        Raises:
            ValueError: If ``steps`` is not positive, because the averages
                would be undefined.
        """
        if steps <= 0:
            raise ValueError(f"Number of steps must be positive, got {steps}")

        total_loss = 0.0
        total_accuracy = 0.0
        for step in range(steps):
            batch_x, batch_y = next(generator)
            loss, accuracy = self._split_metrics(batch_fn(batch_x, batch_y))
            total_loss += loss
            total_accuracy += accuracy

            # Run the garbage collector every 10th batch to limit peak memory.
            if collect_garbage and step % 10 == 0:
                gc.collect()

        return total_loss / steps, total_accuracy / steps

    def train_epoch(self, epoch):
        """Train for one epoch and return ``(avg_loss, avg_accuracy)``.

        ``epoch`` is accepted for interface symmetry and is not used.
        """
        return self._run_batches(
            self.train_gen,
            self.steps_per_epoch,
            self.model.train_on_batch,
            collect_garbage=self.config['clear_memory_between_epochs'],
        )

    def validate_epoch(self):
        """Validate for one epoch and return ``(avg_val_loss, avg_val_accuracy)``."""
        return self._run_batches(self.val_gen, self.val_steps, self.model.test_on_batch)

    def _should_stop_early(self):
        """Return True when validation accuracy has stalled for ``patience`` epochs.

        The comparison is against the best value seen so far, so a run that
        peaked early and never recovered is stopped even if recent epochs are
        slowly improving from a lower level. Ties do not count as improvement.
        A patience below 1 is treated as 1.
        """
        if not self.config.get('early_stopping', True):
            return False

        patience = max(1, int(self.config['patience']))
        val_accuracy = self.history['val_accuracy']
        if not val_accuracy:
            return False

        best_index = val_accuracy.index(max(val_accuracy))
        epochs_since_best = len(val_accuracy) - 1 - best_index
        return epochs_since_best >= patience

    def train(self, epochs, callbacks=None):
        """Train the model with the custom loop and return the history dict.

        Args:
            epochs: Maximum number of epochs to run.
            callbacks: Optional iterable of Keras-style callbacks. Their
                ``set_model``, ``set_params``, ``on_train_begin``,
                ``on_epoch_begin``, ``on_epoch_end`` and ``on_train_end``
                hooks are invoked when present, and a callback may end the
                run by setting ``model.stop_training = True``.

        Progress is also reported to the notebook-level ``progress_tracker``.
        """
        callbacks = list(callbacks or [])
        print(f"Starting training for {epochs} epochs...")

        # Mirror Keras ``fit``: clear a stale stop request, then wire callbacks.
        self.model.stop_training = False
        self._notify(callbacks, 'set_model', self.model)
        self._notify(callbacks, 'set_params', {
            'epochs': epochs,
            'steps': self.steps_per_epoch,
            'verbose': 0,
        })
        self._notify(callbacks, 'on_train_begin')

        for epoch in range(epochs):
            print(f"\nEpoch {epoch + 1}/{epochs}")
            self._notify(callbacks, 'on_epoch_begin', epoch)

            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_acc = self.validate_epoch()

            self.history['loss'].append(train_loss)
            self.history['accuracy'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_accuracy'].append(val_acc)

            metrics = {
                'loss': train_loss,
                'accuracy': train_acc,
                'val_loss': val_loss,
                'val_accuracy': val_acc
            }
            progress_tracker.update(epoch, metrics)

            print(
                f"loss: {train_loss:.4f} - accuracy: {train_acc:.4f} - "
                f"val_loss: {val_loss:.4f} - val_accuracy: {val_acc:.4f}"
            )

            # Callbacks get a copy because some of them add keys to the logs.
            self._notify(callbacks, 'on_epoch_end', epoch, logs=dict(metrics))

            if getattr(self.model, 'stop_training', False):
                print(f"Training stopped by callback at epoch {epoch + 1}")
                break
            if self._should_stop_early():
                print(f"Early stopping at epoch {epoch + 1}")
                break

        self._notify(callbacks, 'on_train_end')
        progress_tracker.close()
        print("\nTraining completed!")
        return self.history

print("Custom training utilities defined")

## Model Training

### 1. Standard Training

In [None]:
# Pick the training path: Keras `fit` (default) or the batch-wise custom trainer
use_custom_training = False  # Set to True for custom training loop

if not use_custom_training:
    print("Using standard Keras fit method...")

    # Echo the run parameters so they are visible in the cell output
    print("\nStarting model training...")
    print(f"Training for {TRAINING_CONFIG['epochs']} epochs")
    print(f"Batch size: {TRAINING_CONFIG['batch_size']}")
    print(f"Steps per epoch: {steps_per_epoch}")

    # Memory snapshot taken right before fitting
    print("\nMemory usage before training:")
    check_memory_usage()

    # `history` is a Keras History object on this path
    history = model.fit(
        train_generator,
        epochs=TRAINING_CONFIG['epochs'],
        steps_per_epoch=steps_per_epoch,
        validation_data=val_generator,
        validation_steps=validation_steps,
        class_weight=class_weight_dict,
        callbacks=callbacks,
        verbose=1,
    )

    print("\nTraining completed!")

    # Memory snapshot taken right after fitting
    print("\nMemory usage after training:")
    check_memory_usage()

else:
    print("Using custom training loop with memory management...")

    # Arguments follow the constructor order:
    # model, train_gen, val_gen, steps_per_epoch, val_steps, config
    trainer = MemoryEfficientTrainer(
        model,
        train_generator,
        val_generator,
        steps_per_epoch,
        validation_steps,
        TRAINING_CONFIG,
    )

    # `history` is a plain dict of metric lists on this path
    history = trainer.train(TRAINING_CONFIG['epochs'], callbacks=callbacks)

### 2. Training Visualization

In [None]:
# Plot training history
# The custom trainer returns a plain dict of metric lists; Keras returns a
# History object whose `.history` attribute holds the equivalent dict.
training_history = history if use_custom_training else history.history

# 2x2 dashboard: loss, accuracy, learning curves with best epoch, text summary
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot training loss
axes[0, 0].plot(training_history['loss'], label='Training Loss', color='blue', linewidth=2)
if 'val_loss' in training_history:
    axes[0, 0].plot(training_history['val_loss'], label='Validation Loss', color='red', linewidth=2)
axes[0, 0].set_title('Model Loss', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot training accuracy
axes[0, 1].plot(training_history['accuracy'], label='Training Accuracy', color='blue', linewidth=2)
if 'val_accuracy' in training_history:
    axes[0, 1].plot(training_history['val_accuracy'], label='Validation Accuracy', color='red', linewidth=2)
axes[0, 1].set_title('Model Accuracy', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Plot learning curves (loss comparison)
if 'val_loss' in training_history:
    axes[1, 0].plot(training_history['loss'], label='Training Loss', color='blue', linewidth=2)
    axes[1, 0].plot(training_history['val_loss'], label='Validation Loss', color='red', linewidth=2)
    axes[1, 0].set_title('Learning Curves', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Highlight best epoch (lowest validation loss; index is 0-based, label is 1-based)
    best_val_loss_epoch = np.argmin(training_history['val_loss'])
    best_val_loss = training_history['val_loss'][best_val_loss_epoch]
    axes[1, 0].scatter(best_val_loss_epoch, best_val_loss, color='green', s=100, zorder=5)
    axes[1, 0].annotate(f'Best: Epoch {best_val_loss_epoch+1}',
                        xy=(best_val_loss_epoch, best_val_loss),
                        xytext=(best_val_loss_epoch+1, best_val_loss*1.1),
                        arrowprops=dict(arrowstyle='->', color='green'))
else:
    axes[1, 0].text(0.5, 0.5, 'No validation data available',
                   ha='center', va='center', transform=axes[1, 0].transAxes)
    axes[1, 0].set_title('Learning Curves')

# Plot training progress summary (text panel with the last-epoch metrics)
if len(training_history['loss']) > 0:
    final_train_loss = training_history['loss'][-1]
    final_train_acc = training_history['accuracy'][-1]

    summary_text = f"Training Summary:\n"
    summary_text += f"Final Loss: {final_train_loss:.4f}\n"
    summary_text += f"Final Accuracy: {final_train_acc:.4f}\n"

    if 'val_loss' in training_history:
        final_val_loss = training_history['val_loss'][-1]
        final_val_acc = training_history['val_accuracy'][-1]
        summary_text += f"Final Val Loss: {final_val_loss:.4f}\n"
        summary_text += f"Final Val Accuracy: {final_val_acc:.4f}\n"

        # Calculate overfitting metric: train accuracy minus validation accuracy
        overfitting = final_train_acc - final_val_acc
        summary_text += f"Overfitting: {overfitting:.4f}"

    axes[1, 1].text(0.1, 0.5, summary_text, transform=axes[1, 1].transAxes,
                   fontsize=12, verticalalignment='center',
                   bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
    axes[1, 1].set_title('Training Summary', fontsize=14, fontweight='bold')
    axes[1, 1].axis('off')

plt.tight_layout()

# Save training plots BEFORE showing them: with notebook backends plt.show()
# typically finalises the current figure, so a later plt.savefig() would write
# an empty image. Saving through `fig` also avoids relying on "current figure".
training_plot_path = os.path.join(default_config.visualization.output_dir, 'training_history.png')
fig.savefig(training_plot_path, dpi=300, bbox_inches='tight')
print(f"Training plots saved to {training_plot_path}")

plt.show()

## Model Evaluation

### 1. Test Set Evaluation

In [None]:
# Score the trained network on the held-out test split
print("Evaluating model on test set...")

# Prefer the best checkpoint written during training, when one exists
if not os.path.exists(checkpoint_path):
    print("No checkpoint found. Using current model weights.")
else:
    print(f"Loading best model from {checkpoint_path}")
    model.load_weights(checkpoint_path)

# Accumulators: true class ids, predicted class ids, positive-class scores
y_true, y_pred, y_pred_proba = [], [], []

print("\nProcessing test data...")
for i in tqdm(range(test_steps), desc="Evaluating"):
    batch_x, batch_y = next(test_generator)
    batch_pred = model.predict(batch_x, verbose=0)

    # argmax over axis 1 turns per-class rows into class indices
    # (assumes batch_y is one-hot encoded - TODO confirm against the generator)
    y_true.extend(np.argmax(batch_y, axis=1))
    y_pred.extend(np.argmax(batch_pred, axis=1))
    # Column 1 is kept as the score of the positive class
    y_pred_proba.extend(batch_pred[:, 1])

y_true, y_pred, y_pred_proba = (
    np.array(collected) for collected in (y_true, y_pred, y_pred_proba)
)

print(f"\nEvaluated {len(y_true)} samples")

# Calculate metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

accuracy = accuracy_score(y_true, y_pred)
# Precision, recall and F1 all use support-weighted averaging across classes
precision, recall, f1 = (
    scorer(y_true, y_pred, average='weighted')
    for scorer in (precision_score, recall_score, f1_score)
)
# NOTE: from here on the name `auc` holds this scalar score; later cells
# read it as a float when saving and printing results.
auc = roc_auc_score(y_true, y_pred_proba)

print("\n=== TEST RESULTS ===")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"AUC: {auc:.4f}")

# Per-class breakdown; label order follows class indices 0 and 1
print("\nClassification Report:")
target_names = ['Control', 'Schizophrenia']
print(classification_report(y_true, y_pred, target_names=target_names))

### 2. Visualization of Results

In [None]:
# Plot confusion matrix
# Rows are true labels and columns are predicted labels, labelled with `target_names`.
cm_figure = plt.figure(figsize=(8, 6))
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')

# Save confusion matrix BEFORE showing it: with notebook backends plt.show()
# typically finalises the current figure, so saving afterwards would write an
# empty image.
cm_path = os.path.join(default_config.visualization.output_dir, 'confusion_matrix.png')
cm_figure.savefig(cm_path, dpi=300, bbox_inches='tight')
print(f"Confusion matrix saved to {cm_path}")

plt.show()

In [None]:
# Plot ROC curve
# The evaluation cell binds the name `auc` to the scalar ROC-AUC score, so the
# sklearn area-under-curve function is imported under an alias here. Calling
# `auc(fpr, tpr)` would try to call a float, and importing it unaliased would
# overwrite the score that later cells read.
from sklearn.metrics import roc_curve, auc as area_under_curve

fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
roc_auc = area_under_curve(fpr, tpr)

roc_figure = plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
# Diagonal reference line: performance of a random classifier
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)

# Save ROC curve BEFORE showing it: with notebook backends plt.show()
# typically finalises the current figure, so saving afterwards would write an
# empty image.
roc_path = os.path.join(default_config.visualization.output_dir, 'roc_curve.png')
roc_figure.savefig(roc_path, dpi=300, bbox_inches='tight')
print(f"ROC curve saved to {roc_path}")

plt.show()

## Model Saving and Export

In [None]:
def _json_default(obj):
    """Convert NumPy scalars and arrays to built-in types for ``json.dump``.

    Training histories can contain NumPy values (for example the learning
    rate recorded by ReduceLROnPlateau), which the json module rejects.
    Anything else still raises TypeError, as json would by default.
    """
    if isinstance(obj, (np.generic, np.ndarray)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save the final model (full model in HDF5 format, next to the checkpoints)
final_model_path = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'sspnet_final_model.h5')
model.save(final_model_path)
print(f"Final model saved to {final_model_path}")

# Save model architecture as JSON (structure only, no weights)
model_json_path = os.path.join(TRAINING_CONFIG['checkpoint_dir'], 'sspnet_model_architecture.json')
with open(model_json_path, 'w') as f:
    f.write(model.to_json())
print(f"Model architecture saved to {model_json_path}")

# Save training history
history_path = os.path.join(default_config.evaluation.results_dir, 'training_history.json')
with open(history_path, 'w') as f:
    json.dump(training_history, f, indent=2, default=_json_default)
print(f"Training history saved to {history_path}")

# Save test results; scalar metrics are cast to float and the confusion
# matrix to nested lists so the payload is plain JSON.
test_results = {
    'accuracy': float(accuracy),
    'precision': float(precision),
    'recall': float(recall),
    'f1_score': float(f1),
    'auc': float(auc),
    'confusion_matrix': cm.tolist(),
    'classification_report': classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
}

test_results_path = os.path.join(default_config.evaluation.results_dir, 'test_results.json')
with open(test_results_path, 'w') as f:
    json.dump(test_results, f, indent=2, default=_json_default)
print(f"Test results saved to {test_results_path}")

## Export to Google Drive

In [None]:
# Export results to Google Drive
# Runs only in Colab with TRAINING_CONFIG['save_to_drive'] enabled. Every
# artefact is copied into a fresh timestamped folder under `project_path`.
if IN_COLAB and TRAINING_CONFIG['save_to_drive']:
    print("\nExporting results to Google Drive...")

    import shutil
    from datetime import datetime

    # Create timestamped directory in Google Drive
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    drive_export_dir = os.path.join(project_path, f'training_results_{timestamp}')
    os.makedirs(drive_export_dir, exist_ok=True)

    viz_export_dir = os.path.join(drive_export_dir, 'visualizations')
    os.makedirs(viz_export_dir, exist_ok=True)

    # Destination folder -> files to copy there. All copies are guarded the
    # same way, so one missing artefact (e.g. the CSV log when the custom
    # training loop was used) no longer aborts the export halfway through.
    export_plan = [
        (drive_export_dir, [final_model_path, model_json_path,
                            history_path, test_results_path, csv_logger_path]),
        (viz_export_dir, [training_plot_path, cm_path, roc_path]),
    ]
    for destination_dir, source_files in export_plan:
        for source_file in source_files:
            if os.path.exists(source_file):
                shutil.copy2(source_file, destination_dir)
            else:
                print(f"  Skipping missing file: {source_file}")

    print(f"Results exported to Google Drive: {drive_export_dir}")

    # Create summary file describing the run and the exported file names
    summary = {
        'training_completed': datetime.now().isoformat(),
        'model_architecture': 'SSPNet 3D CNN',
        'training_config': TRAINING_CONFIG,
        'final_metrics': test_results,
        'files_exported': {
            'model': 'sspnet_final_model.h5',
            'architecture': 'sspnet_model_architecture.json',
            'training_history': 'training_history.json',
            'test_results': 'test_results.json',
            'training_log': 'training_log.csv'
        }
    }

    summary_path = os.path.join(drive_export_dir, 'training_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"Training summary saved to {summary_path}")
else:
    print("\nResults saved locally. Google Drive export skipped.")

## Hyperparameter Tuning (Optional)

In [None]:
# Optional: Hyperparameter tuning with Optuna
if TRAINING_CONFIG['enable_hyperparameter_tuning']:
    print("\nStarting hyperparameter tuning...")
    
    try:
        import optuna
    except ImportError:
        print("Installing Optuna for hyperparameter tuning...")
        !pip install optuna -q
        import optuna
    
    def objective(trial):
        """Objective function for hyperparameter optimization"""
        # Define hyperparameters to tune
        hp = {
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True),
            'dropout_rate': trial.suggest_float('dropout_rate', 0.2, 0.7),
            'batch_size': trial.suggest_categorical('batch_size', [2, 4]),
            'epochs': trial.suggest_int('epochs', 10, 30)
        }
        
        print(f"\nTrial {trial.number}: {hp}")
        
        # Create model with hyperparameters
        try:
            tuned_model = create_sspnet_model(
                input_shape=TRAINING_CONFIG['input_shape'],
                num_classes=TRAINING_CONFIG['num_classes'],
                dropout_rate=hp['dropout_rate']
            )
        except:
            tuned_model = create_simple_3d_cnn(
                input_shape=TRAINING_CONFIG['input_shape'],
                num_classes=TRAINING_CONFIG['num_classes'],
                dropout_rate=hp['dropout_rate']
            )
        
        # Compile model
        tuned_model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=hp['learning_rate']),
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        
        # Create data generators with new batch size
        tuned_train_gen = create_memory_efficient_generator(
            train_df['file_path'].tolist(),
            train_df['label'].tolist(),
            hp['batch_size'],
            training=True,
            shuffle=True
        )
        
        tuned_val_gen = create_memory_efficient_generator(
            val_df['file_path'].tolist(),
            val_df['label'].tolist(),
            hp['batch_size'],
            training=False,
            shuffle=False
        )
        
        # Calculate steps
        tuned_steps = len(train_df) // hp['batch_size']
        tuned_val_steps = len(val_df) // hp['batch_size']
        
        # Train model
        tuned_history = tuned_model.fit(
            tuned_train_gen,
            steps_per_epoch=tuned_steps,
            epochs=hp['epochs'],
            validation_data=tuned_val_gen,
            validation_steps=tuned_val_steps,
            verbose=0
        )
        
        # Get best validation accuracy
        best_val_acc = max(tuned_history.history['val_accuracy'])
        
        print(f"Trial {trial.number} - Best validation accuracy: {best_val_acc:.4f}")
        
        # Clear memory
        clear_memory()
        
        return best_val_acc
    
    # Create study
    study = optuna.create_study(direction='maximize')
    
    # Optimize
    study.optimize(objective, n_trials=TRAINING_CONFIG['tuning_trials'])
    
    # Print results
    print("\n=== HYPERPARAMETER TUNING RESULTS ===")
    print(f"Best trial: {study.best_trial.number}")
    print(f"Best value: {study.best_value:.4f}")
    print("Best parameters:")
    for key, value in study.best_params.items():
        print(f"  {key}: {value}")
    
    # Save tuning results
    tuning_results = {
        'best_trial': study.best_trial.number,
        'best_value': study.best_value,
        'best_params': study.best_params,
        'all_trials': [
            {
                'number': trial.number,
                'value': trial.value,
                'params': trial.params
            } for trial in study.trials
        ]
    }
    
    tuning_results_path = os.path.join(default_config.evaluation.results_dir, 'hyperparameter_tuning_results.json')
    with open(tuning_results_path, 'w') as f:
        json.dump(tuning_results, f, indent=2)
    
    print(f"\nTuning results saved to {tuning_results_path}")
    
    # Plot optimization history
    try:
        optuna.visualization.plot_optimization_history(study).show()
        optuna.visualization.plot_param_importances(study).show()
    except:
        print("Could not display Optuna visualizations")
    
else:
    print("\nHyperparameter tuning disabled. Set 'enable_hyperparameter_tuning': True to enable.")

## Summary and Conclusion

In [None]:
# Print the end-of-run report: configuration, metrics, files and assessment
banner_rule = "="*80

print("\n" + banner_rule)
print("SCHIZOPHRENIA DETECTION MODEL TRAINING SUMMARY")
print(banner_rule)

# Configuration actually used; epochs trained is the length of the loss history
print(f"\n🔧 Training Configuration:")
print(f"  - Model Architecture: SSPNet 3D CNN")
print(f"  - Input Shape: {TRAINING_CONFIG['input_shape']}")
print(f"  - Batch Size: {TRAINING_CONFIG['batch_size']}")
print(f"  - Epochs Trained: {len(training_history['loss'])}")
print(f"  - Learning Rate: {TRAINING_CONFIG['learning_rate']}")
print(f"  - Optimizer: {TRAINING_CONFIG['optimizer']}")
print(f"  - Mixed Precision: {TRAINING_CONFIG['use_mixed_precision']}")

# Last-epoch training metrics, plus validation metrics when recorded
print(f"\n📊 Training Results:")
print(f"  - Final Training Loss: {training_history['loss'][-1]:.4f}")
print(f"  - Final Training Accuracy: {training_history['accuracy'][-1]:.4f}")
if 'val_loss' in training_history:
    print(f"  - Final Validation Loss: {training_history['val_loss'][-1]:.4f}")
    print(f"  - Final Validation Accuracy: {training_history['val_accuracy'][-1]:.4f}")

# Metrics computed by the test-set evaluation cell
print(f"\n🧪 Test Set Performance:")
print(f"  - Accuracy: {accuracy:.4f}")
print(f"  - Precision: {precision:.4f}")
print(f"  - Recall: {recall:.4f}")
print(f"  - F1 Score: {f1:.4f}")
print(f"  - AUC: {auc:.4f}")

# Paths written by the saving and export cells
print(f"\n💾 Generated Files:")
print(f"  - Model: {final_model_path}")
print(f"  - Architecture: {model_json_path}")
print(f"  - Training History: {history_path}")
print(f"  - Test Results: {test_results_path}")
print(f"  - Visualizations: {default_config.visualization.output_dir}")

if IN_COLAB and TRAINING_CONFIG['save_to_drive']:
    print(f"  - Google Drive Export: {drive_export_dir}")

# Map test accuracy to a verdict: the first tier whose threshold is exceeded
# wins; anything at or below the lowest threshold is rated poor.
print(f"\n📈 Model Performance Assessment:")
performance_tiers = (
    (0.8, "  ✅ Excellent performance (>80% accuracy)"),
    (0.7, "  ✅ Good performance (>70% accuracy)"),
    (0.6, "  ⚠️ Moderate performance (>60% accuracy)"),
)
print(next(
    (verdict for threshold, verdict in performance_tiers if accuracy > threshold),
    "  ❌ Poor performance (<60% accuracy)",
))

# Overfitting check: a train/validation accuracy gap above 0.1 is flagged
if 'val_accuracy' in training_history:
    overfitting = training_history['accuracy'][-1] - training_history['val_accuracy'][-1]
    if not overfitting > 0.1:
        print(f"  ✅ Good generalization (gap: {overfitting:.3f})")
    else:
        print(f"  ⚠️ Potential overfitting detected (gap: {overfitting:.3f})")

print(f"\n🚀 Next Steps:")
print(f"  1. Run the results_analysis.ipynb notebook for detailed evaluation")
print(f"  2. Use the trained model for inference on new data")
print(f"  3. Consider hyperparameter tuning for better performance")
print(f"  4. Implement cross-validation for more robust evaluation")

print("\n" + banner_rule)
print("TRAINING COMPLETED SUCCESSFULLY!")
print(banner_rule)

In [None]:
# Release memory held by the session, then report the resulting usage
print("\n🧹 Cleaning up memory...")
clear_memory()
check_memory_usage()

# Closing messages, printed one per call so each keeps its leading blank line
for closing_line in (
    "\n✅ Model training notebook completed successfully!",
    "\nModel is ready for use in schizophrenia detection tasks.",
):
    print(closing_line)