# Final ROC Plotter

## Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.optimizers import legacy
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Dropout
from tensorflow.keras import regularizers
import random
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import awkward as ak

In [None]:

# Seed every random-number source used below (stdlib `random`, NumPy and
# TensorFlow) with the same value so splits and training are repeatable.
seed_value = 42
for _seeder in (random.seed, np.random.seed, tf.random.set_seed):
    _seeder(seed_value)
del _seeder

# Scalar summary columns used as features by the BDT and the random forest.
_BDT_FEATURE_COLUMNS = ['area', 'max_pulse_height', 'r', 'S2_width', 'x', 'y']


def _normalise_ragged(values):
    """Min-max scale every sub-array of a ragged awkward array to [0, 1].

    Sub-arrays whose values are all equal have zero range; they are mapped to
    0 instead of being divided by zero, which would produce NaNs that then
    propagate into CNN training.
    """
    # mask_identity=False: empty sub-arrays reduce to the min/max identity
    # instead of None, so the result never gains a missing-value type.
    min_val = ak.min(values, axis=-1, mask_identity=False)
    max_val = ak.max(values, axis=-1, mask_identity=False)
    value_range = max_val - min_val
    safe_range = ak.where(value_range > 0, value_range, 1.0)
    return (values - min_val) / safe_range


def _pad_ragged(values, length):
    """Right-pad every sub-array with zeros to ``length``; return a 2D NumPy array."""
    # pad_none inserts None placeholders up to `length` (clip=True makes the
    # result regular); fill_none then turns those placeholders into zeros.
    padded = ak.pad_none(values, length, axis=1, clip=True)
    return ak.to_numpy(ak.fill_none(padded, 0.0))


def load_and_prepare_data(data_path='padded_waveforms.parquet', *, electron_size=58.5,
                          padding_length=500, test_size=0.25):
    """
    Load waveforms from a parquet file and build train/test sets for all models.

    Each event's ``times`` and ``samples`` are min-max normalised per event and
    zero-padded to the longest event. They are then concatenated into one CNN
    input vector, which gets ``padding_length`` zeros on both sides. The scalar
    columns in ``_BDT_FEATURE_COLUMNS`` are kept as a DataFrame for the BDT and
    the random forest.

    Args:
        data_path: Parquet file providing ``times``, ``samples``, ``label``,
            ``area`` and the columns in ``_BDT_FEATURE_COLUMNS``.
        electron_size: Divisor applied to ``area`` to produce ``area_test``
            (presumably the single-electron area -- confirm units).
        padding_length: Number of zeros added before and after each CNN input.
        test_size: Fraction of events held out for testing.

    Returns:
        dict with keys ``X_cnn_train`` / ``X_cnn_test`` (shape ``(n, L, 1)``),
        ``X_bdt_train`` / ``X_bdt_test`` (DataFrames), ``y_train`` / ``y_test``,
        ``weights_train`` / ``weights_test`` (all ones) and ``area_test``.
    """
    print("Loading data...")
    arr = ak.from_parquet(data_path)
    # Only the scalar feature columns are needed from pandas; selecting them
    # avoids loading the ragged waveform columns a second time.
    bdt_features = pd.read_parquet(data_path, columns=_BDT_FEATURE_COLUMNS)

    normalised_times = _normalise_ragged(arr['times'])
    normalised_samples = _normalise_ragged(arr['samples'])

    # Times and samples are each padded to their own longest event.
    max_time_length = int(ak.max(ak.num(normalised_times, axis=1)))
    max_samples_length = int(ak.max(ak.num(normalised_samples, axis=1)))
    padded_times = _pad_ragged(normalised_times, max_time_length)
    padded_samples = _pad_ragged(normalised_samples, max_samples_length)

    # One vector per event ([times | samples]) with zero margins on both sides.
    X_combined = np.concatenate([padded_times, padded_samples], axis=1)
    X_padded = np.pad(X_combined, ((0, 0), (padding_length, padding_length)),
                      mode='constant', constant_values=0)

    y = np.array(arr['label'])
    normalized_area = np.array(arr['area'] / electron_size)
    # Uniform weights: every event counts equally. Replace this array to apply
    # a weighting scheme; all three models consume it.
    weights = np.ones(len(y))

    (X_cnn_train, X_cnn_test, X_bdt_train, X_bdt_test, y_train, y_test,
     _area_train, area_test, weights_train, weights_test) = train_test_split(
        X_padded, bdt_features, y, normalized_area, weights,
        test_size=test_size, random_state=seed_value,
    )

    print(f"Data prepared. Training samples: {len(X_cnn_train)}, Test samples: {len(X_cnn_test)}")

    return {
        # Trailing channel axis for Conv1D: (n_events, length, 1).
        'X_cnn_train': X_cnn_train[..., np.newaxis],
        'X_cnn_test': X_cnn_test[..., np.newaxis],
        'X_bdt_train': X_bdt_train,
        'X_bdt_test': X_bdt_test,
        'y_train': y_train,
        'y_test': y_test,
        'weights_train': weights_train,
        'weights_test': weights_test,
        'area_test': area_test,
    }

def _build_cnn(input_length):
    """Build and compile the 3-class 1D CNN applied to the padded waveform vectors.

    Args:
        input_length: Length of each CNN input (second axis of ``X_cnn_train``).

    Returns:
        A compiled ``keras.Sequential`` model ending in a 3-way softmax.
    """
    model = keras.Sequential([
        # First convolution block: wide kernel, L2-regularised, then dropout.
        keras.layers.Conv1D(
            filters=64,
            kernel_size=100,
            activation='relu',
            padding='same',
            kernel_regularizer=regularizers.l2(0.001),
            input_shape=(input_length, 1),
        ),
        keras.layers.MaxPooling1D(pool_size=2),
        Dropout(0.2),
        # Second convolution block (the reference configuration has no dropout here).
        keras.layers.Conv1D(filters=154, kernel_size=60, padding='same', activation='relu'),
        keras.layers.MaxPooling1D(pool_size=2),
        # Dense classifier head.
        keras.layers.Flatten(),
        keras.layers.Dense(96, activation='relu', kernel_regularizer=regularizers.l2(0.001)),
        keras.layers.Dense(3, activation='softmax'),
    ])
    # Labels are integer class indices, hence the sparse categorical loss.
    model.compile(
        optimizer=legacy.Adam(learning_rate=5.762e-4),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def _plot_training_history(history, filename='training_validation_metrics.png'):
    """Plot per-epoch train/validation accuracy and loss side by side, save and show them."""
    plt.figure(figsize=(12, 6))

    # Left panel: accuracy.
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], color='#E63946', linewidth=2.5, label='Train')
    plt.plot(history.history['val_accuracy'], color='#457B9D', linewidth=2.5, label='Validation')
    plt.ylabel('Accuracy', fontsize=18)
    plt.xlabel('Epoch', fontsize=18)
    plt.legend(loc='lower right', fontsize=18)

    # Right panel: loss, limited to the first 10 epochs when training ran longer.
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], color='#E63946', linewidth=2.5, label='Train')
    plt.plot(history.history['val_loss'], color='#457B9D', linewidth=2.5, label='Validation')
    plt.ylabel('Loss', fontsize=18)
    plt.xlabel('Epoch', fontsize=18)
    if len(history.history['loss']) > 10:
        plt.xlim(0, 10)
    plt.legend(loc='upper right', fontsize=18)

    plt.tight_layout()
    plt.savefig(filename, dpi=1200)
    plt.show()
    # Free the figure; pyplot keeps figures alive until they are closed.
    plt.close()


def build_and_train_models(data):
    """
    Train the CNN, the XGBoost BDT and the random forest on the prepared data.

    All three models are trained with ``data['weights_train']`` as sample
    weights so they optimise the same weighted objective.

    Args:
        data: Dict as returned by ``load_and_prepare_data``.

    Returns:
        Tuple ``(models, predictions, y_test)``. ``models`` and ``predictions``
        are keyed by ``'cnn'``, ``'bdt'`` and ``'rf'``. Each prediction is a
        class-probability array for the test set with one column per class.
    """
    print("Building and training models...")
    models = {}
    predictions = {}

    # 1. CNN
    print("Training CNN model...")
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            min_delta=0.005,
            patience=3,
            verbose=1,
            restore_best_weights=True,
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=2,
            min_lr=1e-6,
            verbose=1,
        ),
    ]
    convoNN = _build_cnn(data['X_cnn_train'].shape[1])
    history = convoNN.fit(
        data['X_cnn_train'],
        data['y_train'],
        sample_weight=data['weights_train'],
        epochs=15,
        batch_size=323,
        validation_split=0.2,
        callbacks=callbacks,
        verbose=1,
    )
    predictions['cnn'] = convoNN.predict(data['X_cnn_test'], verbose=0)
    models['cnn'] = convoNN
    _plot_training_history(history)

    # 2. XGBoost (BDT)
    print("Training BDT model...")
    dtrain = xgb.DMatrix(data['X_bdt_train'], label=data['y_train'], weight=data['weights_train'])
    dtest = xgb.DMatrix(data['X_bdt_test'], label=data['y_test'])
    params = {
        "objective": "multi:softprob",  # per-class probabilities, like the CNN softmax
        "num_class": 3,
        "eval_metric": "mlogloss",
        "eta": 0.02,
        "max_depth": 3,
        "seed": seed_value,
    }
    bst = xgb.train(
        params=params,
        dtrain=dtrain,
        num_boost_round=100,
        verbose_eval=False,
    )
    predictions['bdt'] = bst.predict(dtest)
    models['bdt'] = bst

    # 3. Random forest
    print("Training RF model...")
    rf_model = RandomForestClassifier(
        n_estimators=100,
        max_depth=20,
        min_samples_split=10,
        max_features='sqrt',
        random_state=seed_value,
    )
    # Same sample weights as the CNN and BDT, so any non-uniform weighting
    # scheme applies to all three models consistently.
    rf_model.fit(data['X_bdt_train'], data['y_train'], sample_weight=data['weights_train'])
    predictions['rf'] = rf_model.predict_proba(data['X_bdt_test'])
    models['rf'] = rf_model

    print("All models trained.")
    return models, predictions, data['y_test']

# Per-event-type settings for plot_roc_curves:
# - event_idx: label of the background class compared against tritium (label 2)
# - xlabel: x-axis label
# - filename: file name for the combined plot
_ROC_EVENT_SPECS = {
    'gate': {
        'event_idx': 1,
        'xlabel': 'Gate Leakage (False Positive Rate)',
        'filename': 'ROC_curves_combined_gate.png',
    },
    'cathode': {
        'event_idx': 0,
        'xlabel': 'Cathode Leakage (False Positive Rate)',
        'filename': 'ROC_curves_combined_cathode.png',
    },
}

_ROC_YLABEL = 'Tritium Acceptance (True Positive Rate)'


def _tritium_roc(binary_labels, probs, mask):
    """Return ``(fpr, tpr, auc)``, scoring masked events by their tritium-class (column 2) probability."""
    fpr, tpr, _ = roc_curve(binary_labels, np.asarray(probs)[mask, 2])
    return fpr, tpr, auc(fpr, tpr)


def _draw_roc(curve, model_name, style):
    """Plot one ``(fpr, tpr, auc)`` curve on the current axes, labelled with its AUC."""
    fpr, tpr, roc_auc = curve
    plt.plot(
        fpr, tpr,
        label=f'{model_name} (AUC = {roc_auc:.2f})',
        color=style['color'],
        linestyle=style['linestyle'],
        linewidth=style['linewidth'],
    )


def _finish_roc_axes(xlabel):
    """Apply the shared axis labels, legend and grid setting to the current axes."""
    plt.xlabel(xlabel)
    plt.ylabel(_ROC_YLABEL)
    plt.legend(loc='lower right')
    plt.grid(False)


def plot_roc_curves(y_test, cnn_probs, bdt_probs, rf_probs, event_types=('gate', 'cathode'), save_plots=True):
    """
    Plot ROC curves comparing CNN, BDT and RF models for tritium-vs-background discrimination.

    For each event type, only tritium (label 2) and that background class are
    kept. The tritium-class probability (column 2) is used as the score.

    Args:
        y_test: Integer test labels (array-like).
        cnn_probs, bdt_probs, rf_probs: Class-probability arrays with one row
            per entry of ``y_test``.
        event_types: Any of ``'gate'`` (background label 1) and ``'cathode'``
            (background label 0).
        save_plots: If true, save the combined figure and one figure per model.

    Raises:
        ValueError: If ``event_types`` contains an unsupported event type.
    """
    unknown = [event_type for event_type in event_types if event_type not in _ROC_EVENT_SPECS]
    if unknown:
        raise ValueError(
            f"Unknown event type(s) {unknown}; expected any of {sorted(_ROC_EVENT_SPECS)}"
        )

    plt.rcParams['figure.figsize'] = [12, 8]
    plt.rcParams['font.size'] = 14

    y_test = np.asarray(y_test)
    models = {
        'CNN': {'probs': cnn_probs, 'color': 'black', 'linestyle': '-', 'linewidth': 2},
        'BDT': {'probs': bdt_probs, 'color': 'darkblue', 'linestyle': '-', 'linewidth': 2},
        'RF': {'probs': rf_probs, 'color': 'cornflowerblue', 'linestyle': '-', 'linewidth': 2},
    }

    for event_type in event_types:
        spec = _ROC_EVENT_SPECS[event_type]

        # Keep only tritium and the chosen background class.
        mask = (y_test == 2) | (y_test == spec['event_idx'])
        binary_labels = np.where(y_test[mask] == 2, 1, 0)  # Tritium = 1, Other = 0

        # Compute each curve once; reused by the combined and per-model plots.
        curves = {
            model_name: _tritium_roc(binary_labels, model_info['probs'], mask)
            for model_name, model_info in models.items()
        }

        # Combined plot with every model.
        plt.figure(figsize=(12, 8))
        for model_name, model_info in models.items():
            _draw_roc(curves[model_name], model_name, model_info)
        _finish_roc_axes(spec['xlabel'])
        if save_plots:
            plt.savefig(spec['filename'], dpi=1500, bbox_inches='tight')
        plt.show()
        # Free the figure; pyplot keeps figures alive until they are closed.
        plt.close()

        # Individual per-model plots (saved only, never shown).
        if save_plots:
            for model_name, model_info in models.items():
                plt.figure(figsize=(12, 8))
                _draw_roc(curves[model_name], model_name, model_info)
                _finish_roc_axes(spec['xlabel'])
                plt.savefig(f'ROC_curve{model_name}_{event_type}.png', dpi=1500, bbox_inches='tight')
                plt.close()

def main(data_path='padded_waveforms.parquet'):
    """
    Run the whole pipeline end to end.

    The data is prepared from ``data_path``, the CNN, BDT and RF models are
    trained, and their gate and cathode ROC curves are drawn and saved.
    """
    prepared = load_and_prepare_data(data_path)
    _, preds, labels = build_and_train_models(prepared)
    plot_roc_curves(
        labels,
        preds['cnn'],
        preds['bdt'],
        preds['rf'],
        event_types=['gate', 'cathode'],
        save_plots=True,
    )

if __name__ == "__main__":
    # Point this at another parquet file to run the pipeline on different data.
    DATA_PATH = 'padded_waveforms.parquet'
    main(data_path=DATA_PATH)