# Model Evaluation on the Test Set

This notebook evaluates the performance of the trained CPT Foundation Model on the held-out test set. The evaluation process consists of the following steps:

1.  **Load Configuration and Model**: Load the same configuration file used for training, re-instantiate the model architecture, and load the trained weights from the saved `.pth` file.
2.  **Load Test Data**: Use the `CPTDataModule` to get the `DataLoader` for the test split. This ensures we use the exact same data preprocessing and splits.
3.  **Perform Inference**:
    *   Iterate through the test set.
    *   For each CPT profile, apply the same masking strategy used during training.
    *   Feed the corrupted (masked) data to the model to get the reconstructions.
4.  **Calculate Loss**: Compute the Mean Squared Error (MSE) between the model's predictions and the true values **only for the masked tokens**. This tells us how well the model can "fill in the blanks."
5.  **Visualize Results**: Plot a few examples from the test set, showing:
    *   The original, complete CPT data.
    *   The corrupted data with masked portions that were fed to the model.
    *   The model's reconstructed output.

This provides both a quantitative (MSE) and qualitative (visualization) assessment of the model's pre-training performance.
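
Steps 3 and 4 can be sketched on toy data with NumPy. This is a minimal illustration, not the notebook's implementation: the array shapes, the `mask_ratio` value, and the stand-in "model" (which simply predicts zeros) are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy batch: 2 profiles, 6 depth steps, 3 features (shapes are illustrative).
batch = rng.normal(size=(2, 6, 3))
# 1 marks real data, 0 marks padding; the second profile has 2 padded steps.
attention_mask = np.array([[1, 1, 1, 1, 1, 1],
                           [1, 1, 1, 1, 0, 0]])

mask_ratio = 0.5  # assumed value, chosen so the toy example masks some tokens
masking_condition = (rng.random(batch.shape[:2]) < mask_ratio) & (attention_mask == 1)

# Corrupt the input by zeroing every feature of each masked token.
corrupted = batch.copy()
corrupted[masking_condition] = 0.0

# Stand-in for the model: predict zeros everywhere (hypothetical placeholder).
predictions = np.zeros_like(batch)

# Score only the masked tokens; unmasked and padded positions are ignored.
squared_error = (predictions[masking_condition] - batch[masking_condition]) ** 2
masked_mse = squared_error.mean() if masking_condition.any() else float("nan")
print(f"masked tokens: {int(masking_condition.sum())}, masked MSE: {masked_mse:.4f}")
```

Combining the random mask with `attention_mask == 1` ensures padding is never selected, so padded positions cannot dilute or inflate the score.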

In [1]:
import os
import yaml
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Change working directory to project root to handle relative paths correctly
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

# Make sure the script can find the src modules
import sys
sys.path.append(os.path.abspath('src'))

from data_utils import CPTDataModule
from model import CPTFoundationModel

In [2]:
# Set a random seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using random seed: {SEED}")

Using random seed: 42


### 1. Load Configuration and Set Up Device
We'll load the `PG_dataset.yaml` to ensure all our parameters (model dimensions, paths, etc.) are consistent with the training setup.
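
Because later cells index into the configuration directly, it can help to check up front that the expected keys exist. The sketch below is a hypothetical helper, not part of the project; the key names are the ones this notebook reads, and the example values are placeholders.

```python
# Keys this notebook reads from the config, grouped by section.
REQUIRED_KEYS = {
    "model_params": ["num_features", "model_dim", "num_heads", "num_layers"],
    "data_paths": ["model_save_path"],
}

def find_missing_keys(config, required=REQUIRED_KEYS):
    """Return dotted names of required keys that are absent from `config`."""
    missing = []
    for section, keys in required.items():
        section_values = config.get(section) or {}
        for key in keys:
            if key not in section_values:
                missing.append(f"{section}.{key}")
    return missing

# Placeholder values for illustration only.
example_config = {
    "model_params": {"num_features": 4, "model_dim": 128, "num_heads": 8},
    "data_paths": {"model_save_path": "models/example.pth"},
}
print(find_missing_keys(example_config))  # ['model_params.num_layers']
```

An empty result means every key the notebook depends on is present; anything else can be reported before the model is built.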

In [3]:
CONFIG_PATH = 'configs/PG_dataset.yaml'

# Load the YAML configuration file
try:
    with open(CONFIG_PATH, 'r') as f:
        config = yaml.safe_load(f)
    print("Configuration file loaded successfully.")
except FileNotFoundError:
    print(f"Error: Configuration file not found at '{CONFIG_PATH}'")
    config = None
except Exception as e:
    print(f"Error loading configuration file: {e}")
    config = None

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Configuration file loaded successfully.
Using device: cuda


### 2. Load the Trained Model
Instantiate the model with the parameters from the config file, then load the weights saved during training. The model must be switched to evaluation mode with `.eval()` so that layers such as Dropout behave deterministically instead of randomly dropping activations.

In [4]:
model = None  # Defined up front so later `if config and model` checks work even if loading fails
if config:
    model_params = config['model_params']
    paths = config['data_paths']
    
    # Initialize the model
    model = CPTFoundationModel(
        num_features=model_params['num_features'],
        model_dim=model_params['model_dim'],
        num_heads=model_params['num_heads'],
        num_layers=model_params['num_layers']
    ).to(device)

    # Load the saved model checkpoint
    model_path = paths['model_save_path']
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval() # Set the model to evaluation mode
        print(f"Model loaded successfully from '{model_path}'")
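        epoch = checkpoint.get('epoch', 'N/A')
        final_loss = checkpoint.get('loss')
        # Format the loss only when the checkpoint stores one; 'N/A' cannot take a float format spec
        loss_text = f"{final_loss:.6f}" if final_loss is not None else "N/A"
        print(f"Trained for {epoch} epochs with a final loss of {loss_text}")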
    else:
        print(f"Error: Model file not found at '{model_path}'")
        model = None

Model loaded successfully from 'models/foundation_model_PG.pth'
Trained for 199 epochs with a final loss of 0.039218


In [5]:
# Load the scaler used during preprocessing
import joblib

scaler = None
if config:
    scaler_path = config['data_paths'].get('scaler_path') # Use .get for safety
    if scaler_path and os.path.exists(scaler_path):
        scaler = joblib.load(scaler_path)
        print(f"Scaler loaded from '{scaler_path}'")
    else:
        print("Warning: Scaler file not found. Visualization will show scaled values.")

Scaler loaded from 'data/processed/PG/scaler.joblib'


### 3. Load the Test Dataset
We use our `CPTDataModule` to handle the data setup. It will automatically find the processed data and load the correct test set based on the `test_ids.txt` file.

In [6]:
if config:
    print("Setting up data module for the test set...")
    data_module = CPTDataModule(config)
    data_module.setup() # This will set up train, val, and test datasets
    
    # Get the DataLoader for the test set
    test_loader = data_module.get_dataloader(stage='test', shuffle=False)
    print("Test data loaded successfully.")

Setting up data module for the test set...
Found existing processed data in 'data/processed/PG'. Delete to reprocess.
Processing 1071 files with max_len=1028 and overlap=256...


Loading and Chunking Data: 100%|██████████| 1071/1071 [00:22<00:00, 46.92it/s]



Processing 133 files with max_len=1028 and overlap=256...


Loading and Chunking Data: 100%|██████████| 133/133 [00:02<00:00, 46.41it/s]



Processing 135 files with max_len=1028 and overlap=256...


Loading and Chunking Data: 100%|██████████| 135/135 [00:02<00:00, 46.17it/s]

Train dataset size: 2813
Validation dataset size: 353
Test dataset size: 362
Test data loaded successfully.





### 4. Run Evaluation and Calculate Loss
Now we'll loop through the test set. Inside a `torch.no_grad()` context, we run the forward pass to get the model's reconstructions and accumulate the squared error on the masked values only. We'll also store the first few batches for visualization later.
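
Accumulating the summed squared error and dividing once at the end gives a global average that weights every masked value equally, whereas averaging per-batch means would over-weight batches with few masked tokens. The NumPy sketch below uses made-up error values to illustrate the difference, and also shows the per-token and per-element normalizations side by side.

```python
import numpy as np

# Made-up squared errors for two batches, shaped (masked_tokens, features).
batch_errors = [
    np.array([[1.0, 3.0], [1.0, 3.0], [1.0, 3.0]]),  # 3 masked tokens
    np.array([[5.0, 7.0]]),                          # 1 masked token
]

total_squared_error = 0.0
total_masked_tokens = 0
for errors in batch_errors:
    total_squared_error += errors.sum()   # same role as reduction='sum'
    total_masked_tokens += errors.shape[0]

num_features = batch_errors[0].shape[1]
per_token = total_squared_error / total_masked_tokens  # summed over features
per_element = per_token / num_features                 # averaged over features
mean_of_batch_means = np.mean([e.mean() for e in batch_errors])

print(per_token)            # 6.0
print(per_element)          # 3.0
print(mean_of_batch_means)  # 4.0
```

Which normalization to report depends on how the training loss was defined; the two only coincide when there is a single feature.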

In [7]:
# Store results for visualization
visualization_results = []

if config and model:
    total_mse = 0
    total_masked_tokens = 0
    loss_fn = torch.nn.MSELoss(reduction='sum')
    
    # Get mask ratio from config
    mask_ratio = config.get('training_params', {}).get('mask_ratio', 0.15)

    with torch.no_grad():
        pbar = tqdm(test_loader, desc="Evaluating on Test Set")
        for i, (batch, attention_mask) in enumerate(pbar):
            batch = batch.to(device)
            attention_mask = attention_mask.to(device)
            
            # --- Create a corrupted version of the input batch ---
            corrupted_batch = batch.clone()
            prob_mask = torch.rand(batch.shape[:2], device=device)
            masking_condition = (prob_mask < mask_ratio) & (attention_mask == 1)
            
            num_masked = masking_condition.sum().item()
            if num_masked == 0:
                continue # Skip batches where no tokens are masked
            
            corrupted_batch[masking_condition] = 0.0

            # --- Forward Pass ---
            # 1. Get contextual embeddings from the encoder
            contextual_embeddings = model(corrupted_batch, attention_mask)
            
            # 2. Select only the embeddings for masked tokens
            masked_embeddings = contextual_embeddings[masking_condition]
            
            # 3. Get predictions ONLY for the masked tokens
            masked_predictions = model.output_projection(masked_embeddings)

            # --- Calculate Loss on Masked Tokens ---
            loss = loss_fn(masked_predictions, batch[masking_condition])
            total_mse += loss.item()
            total_masked_tokens += num_masked
            
            # --- Store a few examples for visualization ---
            if i < 5: # Store first 5 batches for potential visualization
                # For visualization, we need to reconstruct the full sequence
                # Start with the original batch and fill in the model's predictions
                reconstructed_batch = batch.clone()
                reconstructed_batch[masking_condition] = masked_predictions
                
                visualization_results.append({
                    'original': batch.cpu().numpy(),
                    'masked': corrupted_batch.cpu().numpy(),
                    'predicted': reconstructed_batch.cpu().numpy(), # Use the reconstructed full sequence
                    'mask': masking_condition.cpu().numpy(),
                    'attention_mask': attention_mask.cpu().numpy()
                })

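    # Average the summed squared error over all masked tokens.
    # Note: the divisor counts tokens, not individual feature values, so this is the
    # per-token error summed across features; divide by the number of features as well
    # to obtain a per-element MSE.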
    average_mse = total_mse / total_masked_tokens if total_masked_tokens > 0 else 0
    print(f"\nEvaluation Complete.")
    print(f"Average MSE on masked tokens in the test set: {average_mse:.6f}")

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
Evaluating on Test Set: 100%|██████████| 46/46 [00:02<00:00, 17.48it/s]


Evaluation Complete.
Average MSE on masked tokens in the test set: 0.012343





### 5. Visualize Reconstruction Results
A quantitative metric like MSE is useful, but a qualitative visualization can provide deeper insight into the model's behavior.

The following function plots a single CPT profile, overlaying the model's reconstruction on the ground truth. We focus on the first two numerical features, which are typically `qc` and `fs`. Red markers show the ground-truth values at the masked positions, which is exactly where the model had to predict; everywhere else the reconstruction is copied from the original.
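
Before plotting, padded rows are trimmed and values are mapped back to physical units. For a scikit-learn `StandardScaler` with default settings, `inverse_transform` computes `x * scale_ + mean_` per feature. The sketch below reproduces that arithmetic in NumPy with hypothetical statistics so it runs without the saved scaler; it also assumes, as the plotting function does, that padding sits at the end of each sequence.

```python
import numpy as np

# Hypothetical per-feature statistics standing in for scaler.mean_ and scaler.scale_.
mean_ = np.array([10.0, 0.5])
scale_ = np.array([4.0, 0.1])

# One padded profile in scaled units: 3 real depth steps followed by 2 padded ones.
profile_scaled = np.array([[0.0, 0.0], [1.0, -1.0], [-0.5, 2.0], [0.0, 0.0], [0.0, 0.0]])
attention = np.array([1, 1, 1, 0, 0])

# With trailing padding, the mask sum is the true sequence length.
seq_len = int(attention.sum())
trimmed = profile_scaled[:seq_len]

# Same arithmetic as StandardScaler.inverse_transform with default settings.
profile_unscaled = trimmed * scale_ + mean_
print(profile_unscaled)
```

Trimming first keeps padded rows from showing up as a flat tail at the scaler's mean value in the plot.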

In [8]:
def visualize_reconstruction(original, masked, predicted, actual_mask, attention, scaler, feature_names=("qc", "fs")):
    """
    Plots a comparison of original and reconstructed CPT profiles for key features.
    
    Args:
        original (np.array): The original data (single CPT).
        masked (np.array): The data after masking (single CPT). Currently not plotted.
        predicted (np.array): The model's output reconstruction (single CPT).
        actual_mask (np.array): Boolean mask showing which tokens were masked.
        attention (np.array): Attention mask showing real vs. padding data.
        scaler (StandardScaler): The fitted scaler object for inverse transformation, or None.
        feature_names (sequence of str): Names of the leading feature columns to plot, in column order.
    """
    # Find the actual length of the sequence before padding
    seq_len = int(attention.sum())
    
    # Trim all data to the actual sequence length
    original = original[:seq_len]
    predicted = predicted[:seq_len]
    actual_mask = actual_mask[:seq_len]

    # --- Inverse transform the data to its original scale for visualization ---
    if scaler:
        # The scaler was likely fit on only the numerical features.
        # We assume the number of features in the scaler matches what we want to plot.
        num_numerical_features = scaler.n_features_in_
        
        original_unscaled = scaler.inverse_transform(original[:, :num_numerical_features])
        predicted_unscaled = scaler.inverse_transform(predicted[:, :num_numerical_features])
    else:
        # If no scaler, plot the raw (scaled) data
        original_unscaled = original
        predicted_unscaled = predicted

    # One subplot per feature; feature_names[i] labels column i of the data
    num_features = len(feature_names)
    
    fig = make_subplots(rows=1, cols=num_features, subplot_titles=[f'Feature: {name} (Unscaled)' for name in feature_names])
    
    depth = np.arange(seq_len)

    for i, name in enumerate(feature_names):  # i is the column index for this feature
        # Plot Original Data
        fig.add_trace(go.Scatter(x=original_unscaled[:, i], y=depth, mode='lines', name='Original', line=dict(color='blue')), row=1, col=i+1)
        
        # Plot Model's Reconstruction
        fig.add_trace(go.Scatter(x=predicted_unscaled[:, i], y=depth, mode='lines', name='Reconstructed', line=dict(color='green')), row=1, col=i+1)
        
        # Highlight the ORIGINAL values at the masked locations
        masked_indices = np.where(actual_mask)[0]
        fig.add_trace(go.Scatter(
            x=original_unscaled[masked_indices, i], 
            y=depth[masked_indices], 
            mode='markers', 
            name='Ground Truth (Masked)', 
            marker=dict(color='red', size=6, symbol='x')
        ), row=1, col=i+1)

    fig.update_yaxes(autorange="reversed", title_text="Depth Index")
    fig.update_layout(
        title_text='Model Reconstruction vs. Original Data (Unscaled)',
        height=600,
        width=900
    )
    fig.show()

# Let's visualize the first CPT profile from the first stored batch
if visualization_results:
    first_batch = visualization_results[0]
    # Get the first CPT from the batch
    cpt_index = 0 
    
    original_cpt = first_batch['original'][cpt_index]
    masked_cpt = first_batch['masked'][cpt_index]
    predicted_cpt = first_batch['predicted'][cpt_index]
    mask_cpt = first_batch['mask'][cpt_index]
    attention_cpt = first_batch['attention_mask'][cpt_index]
    
    # Pass the loaded scaler and the feature names to the visualization function
    visualize_reconstruction(original_cpt, masked_cpt, predicted_cpt, mask_cpt, attention_cpt, scaler, feature_names=["qc", "fs"])
else:
    print("No results available for visualization.")

You can re-run the cell above and change `cpt_index` to see other examples from the first batch, or change `visualization_results[0]` to `visualization_results[1]` to inspect results from a different batch.