In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Use standard TensorDataset now
from torch.utils.data import Dataset, DataLoader, random_split, Subset, TensorDataset
import h5py
import matplotlib.pyplot as plt
import random
import os
import time
from torch.amp import autocast, GradScaler

# --- Configuration ---
PHOTON_FILE = 'photons.hdf5'
ELECTRON_FILE = 'electrons.hdf5'
MODEL_SAVE_PATH = 'electron_photon_resnet15_v4_ram.pth' # Suffix v4_ram
OPTIMIZER_SAVE_PATH = MODEL_SAVE_PATH + ".opt"  # optimizer/scheduler checkpoint next to the weights

# --- Hyperparameters ---
SEED = 42
BATCH_SIZE = 512      # Can potentially increase further with RAM loading
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-4
NUM_EPOCHS = 30
# Keep at 0: when `device` is CUDA the datasets hold CUDA tensors, which cannot
# be used inside forked DataLoader worker processes (and with data already on
# the GPU, workers would have nothing to speed up).
NUM_WORKERS = 0
TRAIN_SPLIT = 0.7  # informational: the training set is whatever remains after the val/test slices
VAL_SPLIT = 0.1
TEST_SPLIT = 0.2
assert abs(TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT - 1.0) < 1e-9, "TRAIN/VAL/TEST splits must sum to 1"

# --- Feature Flags ---
USE_LOG_TRANSFORM = True # Apply log(1+E) during preprocessing
USE_AUGMENTATION = True # Apply light augmentations during training
USE_MIXED_PRECISION = torch.cuda.is_available()

# --- Reproducibility ---
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
# Optional: torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False

# --- Device Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"Mixed Precision Training Enabled: {USE_MIXED_PRECISION}")

# --- File Checks ---
if not os.path.exists(PHOTON_FILE): raise FileNotFoundError(f"Photon file not found: {PHOTON_FILE}")
if not os.path.exists(ELECTRON_FILE): raise FileNotFoundError(f"Electron file not found: {ELECTRON_FILE}")

Using device: cuda
Mixed Precision Training Enabled: True


In [2]:
def load_data_to_device(photon_file, electron_file, device):
    """Loads both HDF5 files, concatenates photons then electrons, and moves X and y to ``device``.

    Each file must contain datasets 'X' and 'y'. X rows are concatenated along
    axis 0 and cast to float32; labels are cast to float32 and given a trailing
    axis so they have shape (N, 1), matching the model's (B, 1) logits.

    Args:
        photon_file: path of the photon HDF5 file.
        electron_file: path of the electron HDF5 file.
        device: torch.device the tensors are moved to.

    Returns:
        (all_X_tensor, all_y_tensor, num_photons, num_electrons). Photons occupy
        rows [0, num_photons), electrons the rows after them.

    Raises:
        ValueError: if a file's 'X' and 'y' datasets have different lengths.
    """
    print("Loading data from HDF5 files to RAM...")
    start_time = time.time()

    def _read(path):
        # Read both datasets fully into memory and validate that they line up
        with h5py.File(path, 'r') as f:
            x_np = f['X'][:]  # (N, H, W, C) per the printed shapes
            y_np = f['y'][:]  # (N,)
        if len(x_np) != len(y_np):
            raise ValueError(f"{path}: 'X' has {len(x_np)} rows but 'y' has {len(y_np)}")
        return x_np, y_np

    photon_X_np, photon_y_np = _read(photon_file)
    electron_X_np, electron_y_np = _read(electron_file)
    num_photons, num_electrons = len(photon_y_np), len(electron_y_np)

    print(f"Loaded {num_photons} photons and {num_electrons} electrons.")

    # copy=False: concatenate already produced a new array; avoid a second full
    # copy when it is float32 already.
    all_X_np = np.concatenate([photon_X_np, electron_X_np], axis=0).astype(np.float32, copy=False)
    del photon_X_np, electron_X_np  # release the per-file arrays before the device copy
    all_y_np = np.concatenate([photon_y_np, electron_y_np], axis=0).astype(np.float32, copy=False)

    # Add channel dimension for labels (N, 1)
    all_y_np = np.expand_dims(all_y_np, axis=-1)

    ram_load_time = time.time() - start_time
    print(f"Data loaded into RAM in {ram_load_time:.2f}s.")
    print(f"Combined X shape: {all_X_np.shape}, dtype: {all_X_np.dtype}")
    print(f"Combined y shape: {all_y_np.shape}, dtype: {all_y_np.dtype}")

    print(f"Moving data to device: {device}...")
    start_time = time.time()
    # Keep X in (N, H, W, C) format; preprocessing handles the permute
    all_X_tensor = torch.from_numpy(all_X_np).to(device)
    all_y_tensor = torch.from_numpy(all_y_np).to(device)
    gpu_load_time = time.time() - start_time
    print(f"Data moved to {device} in {gpu_load_time:.2f}s.")

    if device.type == 'cuda':
        x_gb = all_X_tensor.element_size() * all_X_tensor.nelement() / (1024**3)
        y_gb = all_y_tensor.element_size() * all_y_tensor.nelement() / (1024**3)
        print(f"GPU Memory Used (X): {x_gb:.2f} GB")
        print(f"GPU Memory Used (y): {y_gb:.2f} GB")
        print(f"Total GPU Memory Used: {x_gb + y_gb:.2f} GB")

    return all_X_tensor, all_y_tensor, num_photons, num_electrons

# Load the whole dataset once, straight onto the target device.
# Rows are photons first, then electrons.
ALL_X_DEVICE, ALL_Y_DEVICE, num_photons, num_electrons = load_data_to_device(
    PHOTON_FILE, ELECTRON_FILE, device
)
total_samples = num_photons + num_electrons

# Label values the rest of the notebook compares against (photon = 0, electron = 1)
PHOTON_LABEL, ELECTRON_LABEL = 0.0, 1.0

Loading data from HDF5 files to RAM...
Loaded 249000 photons and 249000 electrons.
Data loaded into RAM in 26.76s.
Combined X shape: (498000, 32, 32, 2), dtype: float32
Combined y shape: (498000, 1), dtype: float32
Moving data to device: cuda...
Data moved to cuda in 1.08s.
GPU Memory Used (X): 3.80 GB
GPU Memory Used (y): 0.00 GB
Total GPU Memory Used: 3.80 GB


In [3]:
@torch.no_grad()
def preprocess_full_tensor(x_tensor_nhwc, use_log, keep_prenorm=True): # Input is NHWC
    """
    Applies log transform, calculates per-channel stats, normalizes, and permutes NHWC -> NCHW.

    The input tensor is never modified. To keep peak memory low, the log
    transform and the division are done in place on fresh copies. The earlier
    clone-based version held the input, a clone, a pre-norm copy and two
    arithmetic temporaries at once, and hit CUDA OOM on ~500k samples. At most
    two full-size tensors are allocated here: the pre-normalization NCHW copy
    and the normalized NCHW result. Only one is allocated when
    ``keep_prenorm`` is False.

    Args:
        x_tensor_nhwc: (N, H, W, C) float tensor; channel 0 is the energy channel.
        use_log: if True, channel 0 becomes log1p(relu(x)) before statistics are computed.
        keep_prenorm: if False, the pre-normalization copy is normalized in place
            and ``None`` is returned in its slot (saves one full-size tensor).

    Returns:
        x_tensor_processed (NCHW, normalized),
        mean_gpu (C,): per-channel mean after the optional log transform,
        std_gpu (C,): per-channel std after the optional log transform, clamped to >= 1e-6,
        x_tensor_prenorm_nchw (NCHW, log-transformed but not normalized) or None
    """
    print("\nPreprocessing full dataset tensor on GPU...")
    start_time = time.time()

    # Always a fresh contiguous NCHW copy. Plain .contiguous() would return a view
    # of the input when the permute happens to be contiguous (e.g. C == 1), and
    # the in-place ops below would then modify the caller's tensor.
    x_prenorm_nchw = x_tensor_nhwc.permute(0, 3, 1, 2).clone(memory_format=torch.contiguous_format)

    # 1. Log Transform (Energy channel: index 0), in place on a view of the copy
    if use_log:
        print("Applying log1p transform to energy channel...")
        x_prenorm_nchw[:, 0].relu_().log1p_()

    print(f"Shape before normalization (NCHW): {x_prenorm_nchw.shape}")

    # 2. Per-channel statistics over N, H, W (dims 0, 2, 3 in NCHW)
    print("Calculating mean and std dev on GPU (post-log transform)...")
    mean_gpu = torch.mean(x_prenorm_nchw, dim=(0, 2, 3))
    std_gpu = torch.clamp(torch.std(x_prenorm_nchw, dim=(0, 2, 3)), min=1e-6)
    print(f"Calculated Mean (GPU): {mean_gpu.cpu().numpy()}")
    print(f"Calculated Std Dev (GPU): {std_gpu.cpu().numpy()}")

    # 3. Normalize; broadcast stats as (1, C, 1, 1)
    print("Applying normalization...")
    mean_b = mean_gpu.view(1, -1, 1, 1)
    std_b = std_gpu.view(1, -1, 1, 1)
    if keep_prenorm:
        # Single full-size allocation for the result; the division is in place
        x_processed_nchw = (x_prenorm_nchw - mean_b).div_(std_b)
    else:
        # Reuse the pre-norm buffer: no further full-size allocation
        x_processed_nchw = x_prenorm_nchw.sub_(mean_b).div_(std_b)
        x_prenorm_nchw = None

    print(f"Preprocessing completed in {time.time() - start_time:.2f}s.")
    print(f"Final processed X tensor shape (normalized NCHW): {x_processed_nchw.shape}")

    if x_processed_nchw.is_cuda:
        torch.cuda.empty_cache()

    return x_processed_nchw, mean_gpu, std_gpu, x_prenorm_nchw

# Preprocess the data ONCE (optional log transform, normalization, NHWC -> NCHW)
ALL_X_PROC_DEVICE, GLOBAL_MEAN_GPU, GLOBAL_STD_GPU, ALL_X_PRENORM_DEVICE = preprocess_full_tensor(ALL_X_DEVICE, USE_LOG_TRANSFORM)

# Release the raw NHWC tensor. None of the cells below read ALL_X_DEVICE, and
# it takes as much device memory as the processed copy (3.80 GB in the recorded
# run). NOTE: re-running this cell now requires re-running the data-loading cell first.
del ALL_X_DEVICE
if device.type == 'cuda':
    torch.cuda.empty_cache()


Preprocessing full dataset tensor on GPU...
Applying log1p transform to energy channel...
Shape before normalization (NCHW): torch.Size([498000, 2, 32, 32])
Calculating mean and std dev on GPU (post-log transform)...
Calculated Mean (GPU): [ 0.00102406 -0.00026181]
Calculated Std Dev (GPU): [0.01807128 0.06738362]
Applying normalization...


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.80 GiB. GPU 

In [None]:
def plot_sample_from_tensor(x_tensor_nchw, y_tensor, index, particle_type_str, title_suffix=""):
    """Plots the Energy (channel 0) and Time (channel 1) maps of one sample side by side.

    Args:
        x_tensor_nchw: (N, C, H, W) tensor on any device; only row ``index`` is copied to the CPU.
        y_tensor: label tensor; ``y_tensor[index]`` must hold a single value.
        index: row to plot.
        particle_type_str: text for the figure title (e.g. "Photon").
        title_suffix: extra text appended to the figure and subplot titles.
    """
    sample_chw = x_tensor_nchw[index].detach().cpu().numpy()
    label_val = y_tensor[index].item()

    # Named *_map so the locals don't shadow the imported `time` module
    energy_map, time_map = sample_chw[0], sample_chw[1]

    fig, (ax_energy, ax_time) = plt.subplots(1, 2, figsize=(10, 4))
    fig.suptitle(f"{particle_type_str} Sample (Index: {index}, Label: {label_val}) {title_suffix}", fontsize=14)

    for ax, channel_map, cmap, channel_name in (
        (ax_energy, energy_map, 'viridis', 'Energy'),
        (ax_time, time_map, 'plasma', 'Time'),
    ):
        image = ax.imshow(channel_map, cmap=cmap, aspect='auto')
        ax.set_title(f'{channel_name} Channel {title_suffix}')
        fig.colorbar(image, ax=ax)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def _log_bins(photons, electrons, num_bins=100):
    """Log-spaced bin edges covering the positive values of both arrays.

    Non-positive values cannot appear on a log axis and are ignored. Falls back
    to [1e-5, 1] when neither array contains a positive value.
    """
    positives = np.concatenate([photons[photons > 0], electrons[electrons > 0]])
    low = max(1e-5, float(positives.min())) if positives.size else 1e-5
    high = float(positives.max()) if positives.size else 1.0
    if high <= low:
        high = low * 10.0  # degenerate range: widen so the edges are increasing
    return np.logspace(np.log10(low), np.log10(high), num_bins)


def _hist_pair(ax, photons, electrons, title, xlabel, ylabel=None, **hist_kwargs):
    """Overlays photon and electron density histograms on ``ax`` with shared styling."""
    ax.hist(photons, alpha=0.7, label='Photons', density=True, **hist_kwargs)
    ax.hist(electrons, alpha=0.7, label='Electrons', density=True, **hist_kwargs)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_histograms_comparison(x_prenorm_nchw, x_postnorm_nchw, y_tensor, num_samples_hist=50000, title_prefix=""):
    """Plots per-class histograms of channel values before and after normalization.

    One random subset of rows (without replacement) is drawn from both tensors.
    Only those rows are copied to the CPU, and they are split into photons and
    electrons by comparing labels with PHOTON_LABEL / ELECTRON_LABEL. The 2x4
    grid has rows Energy (channel 0) and Time (channel 1). Its columns are:
    pre-norm linear, pre-norm log-x, post-norm linear, post-norm symlog-x.

    Args:
        x_prenorm_nchw: (N, C, H, W) tensor before normalization.
        x_postnorm_nchw: (N, C, H, W) tensor after normalization, same row order.
        y_tensor: labels of shape (N,) or (N, 1).
        num_samples_hist: number of rows to draw; capped at N.
        title_prefix: prefix for the figure title and the progress message.
    """
    num_total = x_prenorm_nchw.shape[0]
    num_plot = min(num_samples_hist, num_total)
    print(f"\nPlotting {title_prefix} histograms using {num_plot} random samples...")
    if num_plot == 0:
        print("No samples to plot.")
        return
    start_time = time.time()

    plot_indices = np.random.choice(num_total, num_plot, replace=False)

    def _gather(t):
        # Index on the tensor's own device, then copy only the subset to the CPU.
        # (Moving the full tensors with .to(device) first could copy the whole dataset.)
        idx = torch.as_tensor(plot_indices, device=t.device)
        return t[idx].cpu().numpy()

    pre = _gather(x_prenorm_nchw)
    post = _gather(x_postnorm_nchw)
    # reshape instead of squeeze(): the mask stays 1-D even when num_plot == 1
    labels = _gather(y_tensor).reshape(num_plot, -1)[:, 0]
    photon_mask = labels == PHOTON_LABEL
    electron_mask = labels == ELECTRON_LABEL

    def _finite(arr):
        # Post-norm data: keep finite values (zeros are meaningful)
        return arr[np.isfinite(arr)]

    def _finite_nonzero(arr):
        # Pre-norm data: drop non-finite and near-zero values (mostly empty pixels)
        return arr[np.isfinite(arr) & (np.abs(arr) > 1e-6)]

    # (channel index, name, pre-norm x label, post-norm x label, post-norm linear range)
    channel_specs = (
        (0, 'Energy', 'Energy (Log-Transformed)', 'Processed Energy (Normalized)', (-5, 50)),
        (1, 'Time', 'Time', 'Processed Time (Normalized)', (-10, 10)),
    )
    prepared = []
    for ch, name, pre_xlabel, post_xlabel, post_range in channel_specs:
        prepared.append((
            name, pre_xlabel, post_xlabel, post_range,
            _finite_nonzero(pre[photon_mask, ch].ravel()),
            _finite_nonzero(pre[electron_mask, ch].ravel()),
            _finite(post[photon_mask, ch].ravel()),
            _finite(post[electron_mask, ch].ravel()),
        ))
    del pre, post

    print(f"Histogram data preparation took {time.time() - start_time:.2f} seconds.")

    fig, axes = plt.subplots(2, 4, figsize=(24, 10))  # 2 rows, 4 columns
    fig.suptitle(f"{title_prefix} Channel Value Distributions Comparison", fontsize=16)

    for row, (name, pre_xlabel, post_xlabel, post_range, ph_pre, el_pre, ph_post, el_post) in enumerate(prepared):
        ax_lin_pre, ax_log_pre, ax_lin_post, ax_sym_post = axes[row]

        _hist_pair(ax_lin_pre, ph_pre, el_pre, f'{name} Pre-Norm (Linear)', pre_xlabel,
                   ylabel='Density', bins=100)

        _hist_pair(ax_log_pre, ph_pre, el_pre, f'{name} Pre-Norm (Log Scale)', pre_xlabel,
                   bins=_log_bins(ph_pre, el_pre))
        ax_log_pre.set_xscale('log')

        _hist_pair(ax_lin_post, ph_post, el_post, f'{name} Post-Norm (Linear)', post_xlabel,
                   bins=100, range=post_range)

        # Symlog shows values near zero better for data centered around zero
        _hist_pair(ax_sym_post, ph_post, el_post, f'{name} Post-Norm (Symlog Scale)', post_xlabel,
                   bins=100)
        ax_sym_post.set_xscale('symlog', linthresh=0.1)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.show()


# --- Perform Visualizations ---
print("Plotting sample images from processed tensor...")
# Rows were concatenated photons-first, so row 0 is the first photon and row
# `num_photons` the first electron.
for sample_row, particle_name in ((0, "Photon"), (num_photons, "Electron")):
    plot_sample_from_tensor(ALL_X_PROC_DEVICE, ALL_Y_DEVICE, sample_row, particle_name, title_suffix="(Processed)")

# Compare channel value distributions before and after normalization
plot_histograms_comparison(
    x_prenorm_nchw=ALL_X_PRENORM_DEVICE,
    x_postnorm_nchw=ALL_X_PROC_DEVICE,
    y_tensor=ALL_Y_DEVICE,
    title_prefix="Pre- vs. Post-Normalization",
)

In [5]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a skip connection.

    The first conv uses ``stride``. When the block changes the spatial size or
    channel count, the caller passes ``downsample`` to project the skip path to
    the residual branch's shape.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class ResNet15(nn.Module):
    """ResNet-style classifier for 32x32 multi-channel images.

    Layout:
      - 3x3 stem conv to 64 channels (stride 1, no max-pool)
      - three residual stages of ``layers[i]`` blocks with 64/128/256 planes and strides 1/2/2
      - global average pool -> dropout -> linear head

    Args:
        block: residual block class with an ``expansion`` attribute.
        layers: number of blocks in each of the three stages (a tuple, so the
            default is not a shared mutable object).
        num_classes: number of output logits (1 for the binary BCEWithLogitsLoss setup).
        in_channels: input channels (2 here: energy and time).
        dropout_prob: dropout probability before the linear head.

    forward(x) expects (B, in_channels, H, W) and returns raw logits of shape (B, num_classes).
    """
    def __init__(self, block=BasicBlock, layers=(3, 2, 2), num_classes=1, in_channels=2, dropout_prob=0.3):
        super(ResNet15, self).__init__()
        self.in_planes = 64
        # Input: (B, C, H, W) = (B, 2, 32, 32)
        self.conv1 = nn.Conv2d(in_channels, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) # Output: (B, 64, 32, 32)
        self.bn1 = nn.BatchNorm2d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        # No MaxPool: keep full resolution for the small inputs

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)  # Output: (B, 64, 32, 32)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2) # Output: (B, 128, 16, 16)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # Output: (B, 256, 8, 8)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Output: (B, 256, 1, 1)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(256 * block.expansion, num_classes) # Output: (B, num_classes)

        # Weight initialization: He-normal for convs, identity affine for norm layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Builds one stage. Only the first block may change stride/channels; it
        gets a 1x1 conv + BN projection on its skip path when it does."""
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Expects input x: (B, C, H, W)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = torch.flatten(self.avgpool(x), 1)
        x = self.dropout(x)
        # Return raw logits (sigmoid is applied by BCEWithLogitsLoss / at inference)
        return self.fc(x)

# Build the model on the target device and report its size
model = ResNet15().to(device)
print(f"Model Architecture:\n{model}")

total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"\nTotal Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

Model Architecture:
ResNet15(
  (conv1): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), pad

In [6]:
# --- Data Splitting (Indices for pre-loaded tensors) ---
# A dedicated RNG makes the split independent of how many np.random calls
# earlier cells happened to make (e.g. np.random.choice in the histogram cell).
# With the global RNG, a resumed checkpoint could end up validated or tested on
# samples it was trained on. RandomState(SEED).permutation(n) yields the same
# order as np.random.seed(SEED) followed by np.random.shuffle(np.arange(n)).
# That is the split a fresh kernel produced when no other np.random call came
# first.
split_rng = np.random.RandomState(SEED)
indices = split_rng.permutation(total_samples)
test_split_idx = int(np.floor(TEST_SPLIT * total_samples))
val_split_idx = int(np.floor((VAL_SPLIT + TEST_SPLIT) * total_samples))
test_indices = indices[:test_split_idx]
val_indices = indices[test_split_idx:val_split_idx]
train_indices = indices[val_split_idx:]  # the remainder (~TRAIN_SPLIT of the data)
print(f"\nDataset Split Indices:"); print(f"  Training: {len(train_indices)} indices"); print(f"  Validation: {len(val_indices)} indices"); print(f"  Test: {len(test_indices)} indices")

# --- Create TensorDatasets for each split ---
# Integer-array indexing copies the selected rows, so the three splits together
# take as much device memory as ALL_X_PROC_DEVICE itself.
train_dataset = TensorDataset(ALL_X_PROC_DEVICE[train_indices], ALL_Y_DEVICE[train_indices])
val_dataset = TensorDataset(ALL_X_PROC_DEVICE[val_indices], ALL_Y_DEVICE[val_indices])
test_dataset = TensorDataset(ALL_X_PROC_DEVICE[test_indices], ALL_Y_DEVICE[test_indices])

# --- Create Standard DataLoaders ---
# CUDA tensors cannot be used in forked DataLoader worker processes, so the
# worker count is forced to 0 when the data lives on the GPU.
num_workers = NUM_WORKERS
if ALL_X_PROC_DEVICE.is_cuda and num_workers > 0:
    print(f"WARNING: NUM_WORKERS={NUM_WORKERS} is not supported with CUDA-resident data; using 0.")
    num_workers = 0
persist_workers = num_workers > 0

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=num_workers,
                          persistent_workers=persist_workers, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers,
                        persistent_workers=persist_workers)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=num_workers,
                         persistent_workers=persist_workers)

print(f"\nDataLoaders created: Batch Size={BATCH_SIZE}, Num Workers={num_workers}, Persistent={persist_workers}")

# --- Loss, Optimizer, Scheduler, Scaler ---
def _build_optim_state(net):
    """Returns a fresh (optimizer, LR scheduler, AMP GradScaler) for ``net``."""
    opt = optim.AdamW(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # `verbose` is not passed: it is deprecated for LR schedulers in recent
    # PyTorch, and the training loop prints the current LR every epoch anyway.
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.2, patience=3)
    return opt, sched, GradScaler(enabled=USE_MIXED_PRECISION)

criterion = nn.BCEWithLogitsLoss()  # the model returns raw logits
optimizer, scheduler, scaler = _build_optim_state(model)

# --- Load Checkpoint Logic ---
# MODEL_SAVE_PATH holds the model weights. OPTIMIZER_SAVE_PATH holds the
# optimizer/scheduler state (and a GradScaler state, if one was saved) plus the
# epoch at which they were saved.
start_epoch = 0
if os.path.exists(MODEL_SAVE_PATH):
    print(f"\nLoading model weights from {MODEL_SAVE_PATH}...")
    # load_state_dict can copy some tensors before raising, so keep the fresh
    # initialization to roll back to if loading fails.
    initial_model_state = {k: v.clone() for k, v in model.state_dict().items()}
    try:
        # weights_only=True: unpickle only tensors and plain containers
        model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device, weights_only=True))
        print("Model weights loaded.")
    except Exception as e:
        print(f"WARNING: Could not load checkpoint: {e}. Training from scratch.")
        model.load_state_dict(initial_model_state)
    else:
        if os.path.exists(OPTIMIZER_SAVE_PATH):
            print(f"Loading optimizer/scheduler state from {OPTIMIZER_SAVE_PATH}...")
            try:
                checkpoint = torch.load(OPTIMIZER_SAVE_PATH, map_location='cpu', weights_only=True)
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                # An empty dict means it was saved from a disabled scaler; nothing to restore
                if checkpoint.get('scaler'):
                    scaler.load_state_dict(checkpoint['scaler'])
                start_epoch = checkpoint.get('epoch', 0) + 1
                print(f"Optimizer and scheduler state loaded. Resuming from epoch {start_epoch}")
            except Exception as e:
                print(f"WARNING: Could not load optimizer/scheduler state: {e}. "
                      "Keeping the loaded weights; initializing optimizer from scratch.")
                # Rebuild so no partially restored state leaks into training
                optimizer, scheduler, scaler = _build_optim_state(model)
                start_epoch = 0
        else:
            print("Optimizer state file not found, initializing optimizer from scratch.")
    finally:
        del initial_model_state
else:
    print("\nNo checkpoint found. Starting training from scratch.")


Dataset Split Indices:
  Training: 348600 indices
  Validation: 49800 indices
  Test: 99600 indices

DataLoaders created: Batch Size=512, Num Workers=0, Persistent=False

Loading model weights from electron_photon_resnet15_v4_ram.pth...
Model weights loaded.
Loading optimizer/scheduler state from electron_photon_resnet15_v4_ram.pth.opt...
Optimizer and scheduler state loaded. Resuming from epoch 3




In [7]:
# --- GPU Augmentation Function ---
@torch.no_grad()
def augment_batch_gpu(batch_x, use_augment, noise_std=0.1, noise_prob=0.5):
    """Applies light augmentation to a whole batch on the batch's own device.

    With probability ``noise_prob`` (one coin flip per batch, drawn with
    Python's ``random`` module), i.i.d. Gaussian noise with standard deviation
    ``noise_std`` is added to every element. The caller's tensor is never
    modified: a new tensor is returned when noise is added, otherwise
    ``batch_x`` itself.

    Args:
        batch_x: input batch tensor (any shape).
        use_augment: if False, ``batch_x`` is returned unchanged.
        noise_std: standard deviation of the additive noise (default 0.1).
        noise_prob: probability that this batch gets noise (default 0.5).

    Returns:
        The (possibly) augmented batch.
    """
    if not use_augment:
        return batch_x

    if random.random() < noise_prob:
        batch_x = batch_x + torch.randn_like(batch_x) * noise_std  # out-of-place

    # TODO: spatial augmentations (e.g. torch.roll by +/-1 pixel along dims (2, 3),
    # or kornia shifts/rotations) could be added here.
    return batch_x

# --- Training and Validation Functions ---
def train_one_epoch(model, loader, criterion, optimizer, scaler, device, use_amp, use_augment):
    model.train(); total_loss = 0.0; total_correct = 0; total_samples = 0; loop_start_time = time.time()

    for i, (batch_x, batch_y) in enumerate(loader): # Data is already on device from TensorDataset
        batch_start_time = time.time()
        # Data batch_x, batch_y are already on the correct device

        # --- Apply Augmentation on GPU ---
        batch_x = augment_batch_gpu(batch_x, use_augment)
        # --------------------------------

        optimizer.zero_grad(set_to_none=True)
        with autocast(device_type=device.type, enabled=use_amp):
            outputs = model(batch_x) # Logits
            loss = criterion(outputs, batch_y)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()

        predicted = (torch.sigmoid(outputs) > 0.5).float(); correct = (predicted == batch_y).float().sum().item()
        total_correct += correct; batch_size = batch_y.size(0); total_samples += batch_size; total_loss += loss.item() * batch_size
        batch_end_time = time.time()
        if i % 100 == 0: print(f"  Batch {i:>4}/{len(loader)}, Loss: {loss.item():.4f}, Acc: {correct/batch_size:.4f}, Batch Time: {(batch_end_time - batch_start_time)*1000:.1f}ms")

    avg_loss = total_loss / total_samples; avg_acc = total_correct / total_samples; print(f"  Epoch Training Time: {time.time() - loop_start_time:.2f}s")
    return avg_loss, avg_acc

def validate_one_epoch(model, loader, criterion, device, use_amp):
    """Evaluates ``model`` over ``loader`` without gradients; returns sample-weighted averages.

    Args:
        model: module returning raw logits with the same shape as the labels.
        loader: iterable of (batch_x, batch_y). Batches are moved to ``device``;
            this is a no-op when they already live there.
        criterion: loss over (logits, targets). It is assumed to return the
            batch mean, since ``loss.item()`` is weighted by batch size below.
        device: torch.device; also selects the autocast device type.
        use_amp: enable autocast mixed precision.

    Returns:
        (avg_loss, avg_acc); predictions threshold sigmoid(logits) at 0.5.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    val_start_time = time.time()
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            # No augmentation during validation
            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            batch_size = batch_y.size(0)
            total_correct += (predicted == batch_y).float().sum().item()
            total_samples += batch_size
            total_loss += loss.item() * batch_size
    print(f"  Validation Time: {time.time() - val_start_time:.2f}s")
    if total_samples == 0:
        raise ValueError("validate_one_epoch: loader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples


# --- Training Loop ---
def _atomic_torch_save(obj, path):
    """Saves ``obj`` with torch.save to a temporary file, then os.replace()s it
    into ``path``. An interrupted save (e.g. a KeyboardInterrupt mid-write)
    therefore cannot leave a truncated checkpoint behind."""
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')
if start_epoch > 0 and 'best' in scheduler.state_dict():
    # Resume the improvement threshold from the restored ReduceLROnPlateau state
    scheduler_best = scheduler.state_dict()['best']
    best_val_loss = scheduler_best if scheduler_best is not None else float('inf')
    print(f"Resuming with best_val_loss = {best_val_loss:.4f}")

print("\n--- Starting Training (RAM Loaded Data) ---")
epochs_to_run = NUM_EPOCHS - start_epoch
if epochs_to_run <= 0:
    print("Training already completed.")
else:
    print(f"Running for {epochs_to_run} epochs (from epoch {start_epoch} to {NUM_EPOCHS-1})")

for epoch in range(start_epoch, NUM_EPOCHS):
    epoch_start_time = time.time()
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}, Current LR: {optimizer.param_groups[0]['lr']:.6f}")

    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, USE_MIXED_PRECISION, USE_AUGMENTATION)
    history['train_loss'].append(train_loss); history['train_acc'].append(train_acc)

    val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device, USE_MIXED_PRECISION)
    history['val_loss'].append(val_loss); history['val_acc'].append(val_acc)

    epoch_duration = time.time() - epoch_start_time
    print("-" * 60)
    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    print(f"  Duration: {epoch_duration:.2f}s")
    print("-" * 60)

    scheduler.step(val_loss)
    # Checkpoints are written only on improvement, so a resumed run restarts
    # after the best epoch, not the last completed one.
    if val_loss < best_val_loss:
        print(f"Validation loss improved ({best_val_loss:.4f} -> {val_loss:.4f}). Saving model...")
        best_val_loss = val_loss
        # The two files are written one after the other; an interruption between
        # them can leave the weights one improvement ahead of the optimizer state.
        _atomic_torch_save(model.state_dict(), MODEL_SAVE_PATH)
        checkpoint = {
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'scaler': scaler.state_dict(),  # AMP loss-scale state
            'epoch': epoch,
            'best_val_loss': best_val_loss,
        }
        _atomic_torch_save(checkpoint, OPTIMIZER_SAVE_PATH)

print("\n--- Training Finished ---")

Resuming with best_val_loss = 0.5433

--- Starting Training (RAM Loaded Data) ---
Running for 27 epochs (from epoch 3 to 29)

Epoch 4/30, Current LR: 0.000500
  Batch    0/680, Loss: 0.5278, Acc: 0.7520, Batch Time: 597.0ms
  Batch  100/680, Loss: 0.5718, Acc: 0.7188, Batch Time: 148.2ms
  Batch  200/680, Loss: 0.5072, Acc: 0.7402, Batch Time: 147.4ms
  Batch  300/680, Loss: 0.5145, Acc: 0.7578, Batch Time: 147.4ms
  Batch  400/680, Loss: 0.5327, Acc: 0.7344, Batch Time: 148.3ms
  Batch  500/680, Loss: 0.5226, Acc: 0.7383, Batch Time: 147.8ms
  Batch  600/680, Loss: 0.5795, Acc: 0.7051, Batch Time: 148.5ms
  Epoch Training Time: 104.22s
  Validation Time: 5.01s
------------------------------------------------------------
Epoch 4 Summary:
  Train Loss: 0.5501, Train Acc: 0.7270
  Val Loss:   0.5500, Val Acc:   0.7246
  Duration: 109.23s
------------------------------------------------------------

Epoch 5/30, Current LR: 0.000500
  Batch    0/680, Loss: 0.5410, Acc: 0.7461, Batch Time: 

KeyboardInterrupt: 