## Section 1: Setup & Configuration

In [None]:
# ============================================================================
# DEPENDENCIES INSTALLATION
# ============================================================================
# IPython shell magic (`!`): installs the JAX/Flax/Optax stack plus the
# data-science packages used in later cells; `-q` keeps pip output quiet.
# NOTE(review): a plain `jax jaxlib` install is presumably the CPU build; if a
# GPU/TPU runtime is expected, the platform-specific JAX extra is needed —
# confirm against the target environment.
# NOTE(review): `tqdm` is imported in the next cell but not installed here; it
# is presumably preinstalled in the runtime — confirm.

!pip install jax jaxlib optax flax pandas numpy matplotlib seaborn scikit-learn -q

print("✓ JAX dependencies installed")

In [None]:
# ============================================================================
# IMPORTS
# ============================================================================

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import json
import os
from tqdm import tqdm

# JAX ecosystem
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import optax  # Optimisation library
from flax import linen as nn  # Neural network library
from flax.training import train_state

# Scikit-learn utilities
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix, 
    roc_curve, 
    auc, 
    roc_auc_score,
    classification_report
)
from sklearn.utils.class_weight import compute_class_weight

# --- Runtime report: which JAX build and accelerator backend are active ------
print("JAX version: " + str(jax.__version__))
print("Available devices: " + str(jax.devices()))
print("Default backend: " + str(jax.default_backend()))

# --- Experiment configuration -----------------------------------------------
RANDOM_SEED = 42
# Window applied to the 'pt' column when selecting tracks (half-open [min, max)).
MOMENTUM_RANGE = dict(min=1.0, max=2.0, name='1-2 GeV/c')
# Position in this list == integer class label used throughout the notebook.
PARTICLE_NAMES = ['Pion', 'Kaon', 'Proton', 'Electron']
NUM_CLASSES = len(PARTICLE_NAMES)

# Output location for trained models (created if missing).
BASE_DIR = '/kaggle/working'
SAVE_DIR = os.path.join(BASE_DIR, 'JAX_Models')
os.makedirs(SAVE_DIR, exist_ok=True)

print("✓ Configuration loaded")
for _summary_line in (
    f"  Momentum range: {MOMENTUM_RANGE['name']}",
    f"  Target classes: {NUM_CLASSES}",
    f"  Random seed: {RANDOM_SEED}",
):
    print(_summary_line)

## Section 2: Data Loading & Preprocessing

In [None]:
# ============================================================================
# LOAD DATA
# ============================================================================

CSV_PATH = '/kaggle/input/pid-features/pid_features_large.csv'

print("Loading data...")
# Parse the CSV in 500k-row float32 chunks and stitch them into one frame with
# a fresh 0..N-1 index.
df = pd.concat(
    pd.read_csv(
        CSV_PATH,
        dtype='float32',
        chunksize=500_000,
        low_memory=False,
    ),
    ignore_index=True,
)
print(f"✓ Loaded: {df.shape}")

# ============================================================================
# MOMENTUM RANGE SELECTION (1-2 GeV/c)
# ============================================================================

# Keep tracks whose pt lies in the half-open window [min, max).
print(f"\nFiltering momentum range: {MOMENTUM_RANGE['name']}")
in_window = df['pt'].ge(MOMENTUM_RANGE['min']) & df['pt'].lt(MOMENTUM_RANGE['max'])
df_range = df.loc[in_window].copy()
print(f"✓ Selected {len(df_range):,} tracks in range")

# ============================================================================
# HANDLE MISSING VALUES
# ============================================================================

# -999 marks an absent measurement in the input; turn it into a real NaN so
# every column can be filled uniformly below.
df_range.replace(-999, np.nan, inplace=True)


def _columns_matching(frame, token):
    """Return the names of *frame*'s columns whose lower-cased name contains *token*."""
    return [name for name in frame.columns if token in name.lower()]


tof_features = _columns_matching(df_range, 'tof')
tpc_features = _columns_matching(df_range, 'tpc')
bayes_features = _columns_matching(df_range, 'bayes_prob')

print("\nHandling missing values:")
for group_label, group_cols in (
    ("TOF", tof_features),
    ("TPC", tpc_features),
    ("Bayesian", bayes_features),
):
    print(f"  {group_label} features: {len(group_cols)}")

# A detector without a measurement contributes 0 to each of its columns.
for group_cols in (tof_features, tpc_features, bayes_features):
    df_range[group_cols] = df_range[group_cols].fillna(0)

# Flag a detector as present when any of its (filled) columns is non-zero.
df_range['has_tof'] = df_range[tof_features].abs().sum(axis=1).gt(0).astype(int)
df_range['has_tpc'] = df_range[tpc_features].abs().sum(axis=1).gt(0).astype(int)

print("  ✓ Missing values handled")
for flag_col, detector in (('has_tof', 'TOF'), ('has_tpc', 'TPC')):
    print(f"  Tracks with {detector}: {df_range[flag_col].sum():,} ({df_range[flag_col].mean()*100:.1f}%)")

# ============================================================================
# MAP PDG CODES TO PARTICLE SPECIES
# ============================================================================

# |PDG code| -> class index (position in PARTICLE_NAMES).
_PDG_TO_SPECIES = {
    211: 0,   # Pion
    321: 1,   # Kaon
    2212: 2,  # Proton
    11: 3,    # Electron
}


def pdg_to_species(pdg):
    """Map a Monte Carlo PDG code to a particle-species class index.

    The sign (particle vs. antiparticle) is ignored.

    Args:
        pdg: PDG code, possibly stored as a float (the CSV is read as float32).

    Returns:
        0 for pions, 1 for kaons, 2 for protons, 3 for electrons, and -1 for
        any other code. Missing or non-numeric codes (NaN, inf, None) also
        map to -1. NaN can reach this function because the -999 sentinel is
        replaced with NaN before labelling.
    """
    try:
        code = abs(int(pdg))
    except (TypeError, ValueError, OverflowError):
        # int(None) -> TypeError, int(nan) -> ValueError, int(inf) -> OverflowError
        return -1
    return _PDG_TO_SPECIES.get(code, -1)

# Attach the integer class label derived from the MC-truth PDG code.
df_range['particle_species'] = df_range['mc_pdg'].map(pdg_to_species)

# Discard tracks whose species is not one of the four targets (label -1).
is_known_species = df_range['particle_species'].ge(0)
df_range = df_range.loc[is_known_species].reset_index(drop=True)
print(f"✓ Valid particles: {len(df_range):,}")

# ============================================================================
# CLASS DISTRIBUTION
# ============================================================================

print("\nClass distribution:")
class_counts = df_range['particle_species'].value_counts().sort_index()
n_tracks_total = len(df_range)
for species_idx, n_species in class_counts.items():
    pct = n_species / n_tracks_total * 100
    print(f"  {PARTICLE_NAMES[species_idx]:10s}: {n_species:7,} ({pct:5.2f}%)")

# ============================================================================
# FEATURE ENGINEERING
# ============================================================================

# Compress long right tails with log1p(|x|). The sign is discarded, and the
# columns are overwritten in place, so any later cut on e.g. 'pt' must be
# expressed on the log1p scale.
# ('p' is transformed if present but is not among the training features.)
for feature in ['pt', 'p', 'tpc_signal', 'tof_beta']:
    if feature in df_range.columns:
        df_range[feature] = np.log1p(df_range[feature].abs())

# Define training features
training_features = [
    'pt', 'eta', 'phi', 'tpc_signal',
    'tpc_nsigma_pi', 'tpc_nsigma_ka', 'tpc_nsigma_pr', 'tpc_nsigma_el',
    'tof_beta',
    'tof_nsigma_pi', 'tof_nsigma_ka', 'tof_nsigma_pr', 'tof_nsigma_el',
    'bayes_prob_pi', 'bayes_prob_ka', 'bayes_prob_pr', 'bayes_prob_el',
    'dca_xy', 'dca_z',
    'has_tpc', 'has_tof'
]

available_features = [f for f in training_features if f in df_range.columns]
missing_features = [f for f in training_features if f not in df_range.columns]
print(f"\n✓ Available features: {len(available_features)}")
if missing_features:
    # Surface schema drift instead of silently training on fewer inputs.
    print(f"⚠ Missing expected features (skipped): {missing_features}")

# Drop rows with remaining NaNs in the model inputs. Only the TOF/TPC/Bayes
# column groups were zero-filled earlier, so NaNs can survive elsewhere
# (e.g. eta/phi/dca).
df_range = df_range.dropna(subset=available_features)
print(f"✓ Final dataset: {df_range.shape}")

## Section 3: Background Cleaning

In [None]:
# ============================================================================
# BACKGROUND CLEANING
# ============================================================================

print("\n" + "="*80)
print("BACKGROUND CLEANING")
print("="*80)

initial_count = len(df_range)


def _apply_cut(frame, keep_mask, label):
    """Keep the rows of *frame* selected by *keep_mask* and report how many rows this cut alone removed."""
    kept = frame[keep_mask].copy()
    print(f"✓ {label}: {len(frame) - len(kept):,} tracks removed")
    return kept


# 1. Remove tracks with unrealistic momenta.
# 'pt' was log1p-transformed during feature engineering, so the bounds are on
# that scale too. With the current MOMENTUM_RANGE this cut is only a safeguard.
df_range = _apply_cut(
    df_range,
    (df_range['pt'] > np.log1p(0.05)) & (df_range['pt'] < np.log1p(20)),
    "Momentum cut",
)

# 2. Remove tracks with poor DCA (distance of closest approach).
dca_cut = 3.0  # presumably cm — confirm the units of dca_xy / dca_z
df_range = _apply_cut(
    df_range,
    (df_range['dca_xy'].abs() < dca_cut) & (df_range['dca_z'].abs() < dca_cut),
    "DCA cut",
)

# 3. Remove tracks whose PID signals are incompatible with EVERY species
#    hypothesis: |n-sigma| above threshold in both TPC and TOF for all of
#    pi/K/p/e. Requiring this for just one hypothesis would discard
#    well-identified tracks, because any track is many sigma away from the
#    species it is not. Tracks whose TOF n-sigma was missing (zero-filled)
#    are never removed by this cut.
nsigma_threshold = 5.0
pid_pairs = [
    (f'tpc_nsigma_{particle}', f'tof_nsigma_{particle}')
    for particle in ['pi', 'ka', 'pr', 'el']
    if f'tpc_nsigma_{particle}' in df_range.columns
    and f'tof_nsigma_{particle}' in df_range.columns
]
if pid_pairs:
    incompatible_with_all = pd.concat(
        [
            (df_range[tpc_col].abs() > nsigma_threshold)
            & (df_range[tof_col].abs() > nsigma_threshold)
            for tpc_col, tof_col in pid_pairs
        ],
        axis=1,
    ).all(axis=1)
    df_range = _apply_cut(df_range, ~incompatible_with_all, "PID consistency cut")
else:
    print("✓ PID consistency cut: 0 tracks removed (n-sigma columns unavailable)")

retained_pct = len(df_range) / initial_count * 100 if initial_count else 0.0
print(f"✓ Total removed: {initial_count - len(df_range):,} tracks")
print(f"✓ Clean dataset: {len(df_range):,} tracks ({retained_pct:.1f}% retained)")

# Update class distribution after cleaning
print("\nClass distribution after cleaning:")
class_counts_clean = df_range['particle_species'].value_counts().sort_index()
for idx, count in class_counts_clean.items():
    print(f"  {PARTICLE_NAMES[idx]:10s}: {count:7,} ({count/len(df_range)*100:5.2f}%)")

## Section 4: Train/Test Split & Scaling

In [None]:
# ============================================================================
# PREPARE TRAINING DATA
# ============================================================================

X = df_range[available_features].astype('float32')
y = df_range['particle_species'].values.astype('int32')

# Stratified split keeps the per-species fractions equal in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=RANDOM_SEED, stratify=y
)

print(f"\n✓ Train samples: {len(X_train):,}")
print(f"✓ Test samples: {len(X_test):,}")

# Standardize features; statistics are fitted on the training split only.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.values)
X_test_scaled = scaler.transform(X_test.values)

print(f"✓ Features standardized (mean=0, std=1)")

# Compute class weights for imbalanced data. compute_class_weight returns one
# weight per entry of `classes`, so key the dict by the actual label rather
# than by position. Otherwise a species absent from y_train would shift the
# weights onto the wrong classes.
present_classes = np.unique(y_train)
class_weights = compute_class_weight(
    'balanced',
    classes=present_classes,
    y=y_train
)
class_weights_dict = {int(label): weight for label, weight in zip(present_classes, class_weights)}
print(f"\n✓ Class weights computed:")
for idx, weight in class_weights_dict.items():
    print(f"  {PARTICLE_NAMES[idx]:10s}: {weight:.4f}")

# Dense per-label lookup table (the loss indexes it as weights[label]).
# Labels absent from y_train keep weight 1.0; they never occur in training
# batches, so the value is unused.
weight_table = np.ones(NUM_CLASSES, dtype=np.float32)
for label, weight in class_weights_dict.items():
    weight_table[label] = weight

# Convert to JAX arrays
X_train_jax = jnp.array(X_train_scaled, dtype=jnp.float32)
X_test_jax = jnp.array(X_test_scaled, dtype=jnp.float32)
y_train_jax = jnp.array(y_train, dtype=jnp.int32)
y_test_jax = jnp.array(y_test, dtype=jnp.int32)
class_weights_jax = jnp.asarray(weight_table, dtype=jnp.float32)

print(f"\n✓ Data converted to JAX arrays")
print(f"  X_train shape: {X_train_jax.shape}")
print(f"  y_train shape: {y_train_jax.shape}")

## Section 5: JAX Neural Network Definition

In [None]:
# ============================================================================
# DEFINE NEURAL NETWORK WITH FLAX
# ============================================================================

class PIDNeuralNetwork(nn.Module):
    """Feed-forward MLP classifier for particle identification.

    Each hidden block is Dense -> BatchNorm -> ReLU -> Dropout, and a final
    Dense layer emits unnormalised logits of size ``num_classes``. The blocks
    are applied strictly in sequence; there are no skip/residual connections.

    Attributes:
        hidden_dims: Width of each hidden Dense layer, in order.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied after every hidden block.

    Calling convention:
        ``training=True``: BatchNorm normalises with batch statistics and
        updates its running averages, and dropout is active. The caller must
        therefore pass a ``'dropout'`` RNG and allow the ``'batch_stats'``
        collection to be mutated.

        ``training=False``: dropout is disabled and BatchNorm uses the stored
        running averages, which must be supplied in the variables.
    """
    hidden_dims: list
    num_classes: int
    dropout_rate: float = 0.3

    @nn.compact
    def __call__(self, x, training: bool = False):
        z = x

        for i, dim in enumerate(self.hidden_dims):
            z = nn.Dense(dim, name=f'dense_{i}')(z)
            # Batch statistics while training, running averages otherwise.
            z = nn.BatchNorm(use_running_average=not training, name=f'bn_{i}')(z)
            z = nn.relu(z)
            # Only active when training (needs the 'dropout' RNG then).
            z = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(z)

        # Raw logits; softmax is applied by the loss / at evaluation time.
        logits = nn.Dense(self.num_classes, name='output')(z)
        return logits

# Architecture hyper-parameters (single source for construction and reporting).
HIDDEN_DIMS = [256, 128, 64]
DROPOUT_RATE = 0.3

# Initialize model. Split the seed key so the key used for parameter init is
# not reused later (e.g. for shuffling).
key = random.PRNGKey(RANDOM_SEED)
key, init_key = random.split(key)
model = PIDNeuralNetwork(
    hidden_dims=HIDDEN_DIMS,
    num_classes=NUM_CLASSES,
    dropout_rate=DROPOUT_RATE
)

# Initialize parameters with a (1, n_features) dummy batch. The previous code
# passed the whole shape tuple as one dimension, which fails.
num_input_features = X_train_jax.shape[1]
dummy_input = jnp.ones((1, num_input_features), dtype=jnp.float32)
# training=False: dropout is deterministic, so only the params RNG is needed.
# The returned variables contain both the 'params' collection and the
# BatchNorm 'batch_stats' collection.
params = model.init(init_key, dummy_input, training=False)

print("✓ Model architecture:")
print(f"  Input features: {num_input_features}")
print(f"  Hidden layers: {HIDDEN_DIMS}")
print(f"  Output classes: {NUM_CLASSES}")
print(f"  Dropout rate: {DROPOUT_RATE}")

## Section 6: Training Setup

In [None]:
# ============================================================================
# DEFINE LOSS FUNCTION WITH CLASS WEIGHTS
# ============================================================================

def weighted_cross_entropy_loss(logits, labels, class_weights):
    """Class-weighted softmax cross-entropy, averaged over the batch.

    Args:
        logits: Unnormalised scores, shape (batch, num_classes).
        labels: Integer class labels, shape (batch,).
        class_weights: Per-class weights, shape (num_classes,), looked up by label.

    Returns:
        Scalar mean over the batch of ``class_weights[label] * -log p(label)``.
        This is a plain mean, not normalised by the sum of the weights.
    """
    # Take the class count from the logits instead of a global constant, so
    # the one-hot width always matches the model output.
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes)

    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_sample_nll = -jnp.sum(one_hot_labels * log_probs, axis=-1)

    sample_weights = class_weights[labels]
    return jnp.mean(per_sample_nll * sample_weights)

from typing import Any


class PIDTrainState(train_state.TrainState):
    """TrainState that also carries BatchNorm statistics and a dropout key.

    The stock TrainState only tracks ``params``. The BatchNorm layers of
    PIDNeuralNetwork also need their ``batch_stats`` collection passed into
    every apply, and dropout needs a PRNG key (not an int).
    """
    batch_stats: Any
    dropout_key: Any


@jit
def train_step(state, batch_x, batch_y, class_weights):
    """Run one optimisation step on a mini-batch.

    Returns:
        (new_state, loss). new_state has updated params, optimiser state, step
        counter and BatchNorm running statistics; loss is the scalar batch loss.
    """
    # Fresh but reproducible dropout mask each step: fold the step counter
    # into the base key.
    dropout_rng = jax.random.fold_in(state.dropout_key, state.step)

    def loss_fn(params):
        logits, updated_vars = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            batch_x,
            training=True,
            rngs={'dropout': dropout_rng},
            # BatchNorm updates its running averages in training mode.
            mutable=['batch_stats'],
        )
        loss = weighted_cross_entropy_loss(logits, batch_y, class_weights)
        return loss, updated_vars

    (loss, updated_vars), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updated_vars['batch_stats'])
    return state, loss


@jit
def eval_step(state, batch_x, batch_y):
    """Evaluate a batch in inference mode (dropout off, BatchNorm running averages).

    Returns:
        (accuracy, logits) for the batch.
    """
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch_x,
        training=False,
    )
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == batch_y)
    return accuracy, logits

# ============================================================================
# INITIALIZE TRAINING STATE
# ============================================================================

learning_rate = 1e-3
tx = optax.adam(learning_rate)

# `params` holds every variable collection returned by model.init; the
# trainable weights and the BatchNorm statistics are stored separately.
key, dropout_key = random.split(key)
state = PIDTrainState.create(
    apply_fn=model.apply,
    params=params['params'],
    batch_stats=params['batch_stats'],
    dropout_key=dropout_key,
    tx=tx,
)

print("✓ Training state initialized")
print(f"  Optimiser: Adam")
print(f"  Learning rate: {learning_rate}")

## Section 7: Training Loop

In [None]:
# ============================================================================
# TRAIN MODEL
# ============================================================================

print("\n" + "="*80)
print("TRAINING NEURAL NETWORK")
print("="*80)

BATCH_SIZE = 256
NUM_EPOCHS = 50
PATIENCE = 10
VAL_FRACTION = 0.1

# Hold out a stratified validation slice of the TRAINING split for early
# stopping and model selection. Selecting on the test split would bias the
# reported test accuracy upward.
fit_idx, val_idx = train_test_split(
    np.arange(len(y_train_jax)),
    test_size=VAL_FRACTION,
    random_state=RANDOM_SEED,
    stratify=np.asarray(y_train_jax),
)
X_fit_jax, y_fit_jax = X_train_jax[fit_idx], y_train_jax[fit_idx]
X_val_jax, y_val_jax = X_train_jax[val_idx], y_train_jax[val_idx]
print(f"✓ Fit samples: {len(fit_idx):,} | Validation samples: {len(val_idx):,}")

# Only full batches are used (the shuffled tail is dropped each epoch), so
# train_step keeps a single compiled input shape.
num_batches = max(1, len(X_fit_jax) // BATCH_SIZE)
best_val_acc = 0.0
# TrainState is immutable, so keeping a reference snapshots params, BatchNorm
# statistics (when present) and optimiser state together. Start from the
# initial state so a restore target always exists.
best_state = state
patience_counter = 0

train_losses = []
val_accuracies = []

for epoch in range(NUM_EPOCHS):
    # Shuffle training data
    key, subkey = random.split(key)
    perm = random.permutation(subkey, len(X_fit_jax))
    X_shuffled = X_fit_jax[perm]
    y_shuffled = y_fit_jax[perm]

    # Training
    epoch_losses = []
    for batch_idx in range(num_batches):
        batch = slice(batch_idx * BATCH_SIZE, (batch_idx + 1) * BATCH_SIZE)
        state, loss = train_step(state, X_shuffled[batch], y_shuffled[batch], class_weights_jax)
        epoch_losses.append(loss)

    avg_train_loss = float(np.mean(epoch_losses))
    train_losses.append(avg_train_loss)

    # Validation (held-out slice of the training split, not the test set)
    val_acc, _ = eval_step(state, X_val_jax, y_val_jax)
    val_acc = float(val_acc)
    val_accuracies.append(val_acc)

    # Print progress
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d}/{NUM_EPOCHS} | "
              f"Loss: {avg_train_loss:.4f} | "
              f"Val Acc: {val_acc:.4f}")

    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        best_state = state
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n✓ Early stopping at epoch {epoch+1}")
            break

# Restore the best snapshot (parameters and any BatchNorm statistics together)
state = best_state

print(f"\n✓ Training complete!")
print(f"  Best validation accuracy: {best_val_acc:.4f}")

## Section 8: Evaluation

In [None]:
# ============================================================================
# EVALUATE MODEL
# ============================================================================

print("\n" + "="*80)
print("MODEL EVALUATION")
print("="*80)

# Full-batch inference on both splits.
train_acc, train_logits = eval_step(state, X_train_jax, y_train_jax)
test_acc, test_logits = eval_step(state, X_test_jax, y_test_jax)
# Plain Python floats for printing and pickling.
train_acc = float(train_acc)
test_acc = float(test_acc)

print(f"Train Accuracy: {train_acc:.4f}")
print(f"Test Accuracy:  {test_acc:.4f}")

# Convert logits to probabilities
train_probs = jax.nn.softmax(train_logits, axis=-1)
test_probs = jax.nn.softmax(test_logits, axis=-1)

# Hard predictions, brought to host NumPy once for sklearn and plotting.
y_pred_test = np.asarray(jnp.argmax(test_logits, axis=-1))

# Classification report. Pin `labels` so the rows always line up with
# PARTICLE_NAMES, even if a species is absent from y_test and the predictions.
print("\nClassification Report:")
print(classification_report(
    y_test,
    y_pred_test,
    labels=list(range(NUM_CLASSES)),
    target_names=PARTICLE_NAMES,
    digits=4,
    zero_division=0,
))

# Save model and results
model_save_path = os.path.join(SAVE_DIR, 'pid_model_jax_1-2gev.pkl')
checkpoint = {
    # device_get: store host NumPy arrays rather than device buffers.
    'params': jax.device_get(state.params),
    # BatchNorm running statistics are needed to run the model in inference
    # mode. They are stored when the training state carries them (None otherwise).
    'batch_stats': jax.device_get(getattr(state, 'batch_stats', None)),
    'scaler': scaler,
    'features': available_features,
    'train_acc': train_acc,
    'test_acc': test_acc,
    'class_weights': class_weights_dict,
    'config': {
        # Read from the model instance so the record cannot drift from the
        # architecture that was actually trained.
        'hidden_dims': list(model.hidden_dims),
        'num_classes': model.num_classes,
        'dropout_rate': model.dropout_rate,
        'learning_rate': learning_rate,
        'num_epochs': epoch + 1,
    },
}
with open(model_save_path, 'wb') as f:
    pickle.dump(checkpoint, f)

print(f"\n✓ Model saved to: {model_save_path}")

## Section 9: Visualizations

In [None]:
# ============================================================================
# TRAINING HISTORY
# ============================================================================

# subplots(1, 2) returns an array of two Axes; plotting methods must be called
# on each Axes, not on the array.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, acc_ax = axes

# Loss curve
loss_ax.plot(train_losses, linewidth=2, color='#3B82F6', label='Training Loss')
loss_ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
loss_ax.set_ylabel('Loss', fontsize=12, fontweight='bold')
loss_ax.set_title('Training Loss', fontsize=14, fontweight='bold')
loss_ax.grid(alpha=0.3)
loss_ax.legend(fontsize=11)

# Validation accuracy
acc_ax.plot(val_accuracies, linewidth=2, color='#22C55E', label='Validation Accuracy')
acc_ax.axhline(y=best_val_acc, color='r', linestyle='--', alpha=0.7,
               label=f'Best: {best_val_acc:.4f}')
acc_ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
acc_ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
acc_ax.set_title('Validation Accuracy', fontsize=14, fontweight='bold')
acc_ax.grid(alpha=0.3)
acc_ax.legend(fontsize=11)

plt.suptitle(f'Training History - {MOMENTUM_RANGE["name"]}',
             fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# ============================================================================
# CONFUSION MATRIX
# ============================================================================

# Row-normalised (each true class sums to 1). Pin `labels` to every species so
# the matrix is always NUM_CLASSES x NUM_CLASSES and matches the tick labels,
# even when a species is absent from both y_test and the predictions.
cm = confusion_matrix(
    y_test,
    np.asarray(y_pred_test),
    labels=list(range(NUM_CLASSES)),
    normalize='true',
)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=PARTICLE_NAMES,
            yticklabels=PARTICLE_NAMES,
            cbar_kws={'shrink': 0.8})
plt.xlabel('Predicted', fontsize=13, fontweight='bold')
plt.ylabel('True', fontsize=13, fontweight='bold')
plt.title(f'Confusion Matrix - {MOMENTUM_RANGE["name"]}',
          fontsize=14, fontweight='bold', pad=15)
plt.tight_layout()
plt.show()

# ============================================================================
# ROC CURVES
# ============================================================================

fig, ax = plt.subplots(figsize=(10, 8))
colors = ['#3B82F6', '#F59E0B', '#22C55E', '#EF4444']

# One-vs-rest ROC per species, scored by that species' softmax probability.
test_probs_np = np.array(test_probs)
for i, (particle, color) in enumerate(zip(PARTICLE_NAMES, colors)):
    y_true_binary = (y_test == i).astype(int)
    fpr, tpr, _ = roc_curve(y_true_binary, test_probs_np[:, i])
    roc_auc = auc(fpr, tpr)
    curve_label = f'{particle} (AUC = {roc_auc:.3f})'
    ax.plot(fpr, tpr, color=color, lw=2.5, label=curve_label)

# Chance diagonal and axis cosmetics.
ax.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.5)
ax.set(xlim=[0.0, 1.0], ylim=[0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=13, fontweight='bold')
ax.set_ylabel('True Positive Rate', fontsize=13, fontweight='bold')
ax.set_title(f'ROC Curves - {MOMENTUM_RANGE["name"]}',
             fontsize=14, fontweight='bold', pad=15)
ax.legend(loc='lower right', fontsize=11)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

# ============================================================================
# PER-CLASS METRICS
# ============================================================================

# subplots(1, 2) returns an array of two Axes; draw on each one explicitly.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
precision_ax, recall_ax = axes

# One-vs-rest precision / recall per species from the hard predictions.
y_pred_np = np.asarray(y_pred_test)
precisions = []
recalls = []
for i in range(NUM_CLASSES):
    is_true = (y_test == i)
    is_pred = (y_pred_np == i)

    tp = np.sum(is_true & is_pred)
    fp = np.sum(~is_true & is_pred)
    fn = np.sum(is_true & ~is_pred)

    precisions.append(tp / (tp + fp) if (tp + fp) > 0 else 0)
    recalls.append(tp / (tp + fn) if (tp + fn) > 0 else 0)


def _plot_annotated_bars(ax, values, ylabel, title):
    """Draw one bar per species on *ax*, with its value printed above it, and return the bar container."""
    bars = ax.bar(PARTICLE_NAMES, values, color=colors, alpha=0.8,
                  edgecolor='black', linewidth=1.5)
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                f'{val:.3f}', ha='center', va='bottom',
                fontsize=11, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.set_ylim(0, 1.1)
    ax.grid(axis='y', alpha=0.3)
    return bars


bars1 = _plot_annotated_bars(precision_ax, precisions,
                             'Precision', 'Precision by Particle Type')
bars2 = _plot_annotated_bars(recall_ax, recalls,
                             'Recall (Efficiency)', 'Recall by Particle Type')

plt.suptitle(f'Per-Class Performance - {MOMENTUM_RANGE["name"]}',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

banner = "=" * 80
print("\n" + banner)
print("ANALYSIS COMPLETE")
print(banner)
# Assemble the summary first, then emit it as one block of lines.
summary_lines = [
    f"✓ Model trained on {len(X_train):,} samples",
    f"✓ Tested on {len(X_test):,} samples",
    f"✓ Best validation accuracy: {best_val_acc:.4f}",
    f"✓ Test accuracy: {test_acc:.4f}",
    f"✓ Momentum range: {MOMENTUM_RANGE['name']} (challenging region)",
]
print("\n".join(summary_lines))
print(banner)