# Adaptive Quantile Regression: Visualization

This notebook loads a trained `QuantileRegressor` checkpoint and visualizes the model's performance on the test set. It performs the following steps:

1.  **Load Configuration**: Reads the `config.yaml` file to ensure settings are consistent with the training run.
2.  **Load Model**: Loads the most recently saved model checkpoint that PyTorch Lightning wrote to the configured checkpoint directory.
3.  **Load Data**: Re-initializes the `PolynomialDataModule` to get access to the test dataset and the true polynomial coefficients.
4.  **Generate Predictions**: For each quantile in a fixed list (0.05, 0.25, 0.50, 0.75, 0.95), it uses the model to predict the corresponding `y` values at the sorted test `x` values.
5.  **Plot Results**: Creates a comprehensive plot showing:
    * The raw test data points (as a scatter plot).
    * The true underlying polynomial function (as a dashed line).
    * The predicted quantile curves (as solid lines).

In [None]:
import torch
import pytorch_lightning as pl

import numpy as np
import yaml
import matplotlib.pyplot as plt
import glob
import os

# Import our custom modules
from src.model import QuantileRegressor
from src.data import PolynomialDataModule

### 1. Load Configuration and Find Best Checkpoint

In [None]:
# Load the configuration file used for training so every setting below matches it.
config_path = 'config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Locate the checkpoint to visualize.
checkpoint_dir = config['training']['checkpoint_dir']
# Sorted so the selection below is deterministic when timestamps tie.
list_of_files = sorted(glob.glob(os.path.join(checkpoint_dir, '*.ckpt')))
if not list_of_files:
    raise FileNotFoundError(f"No checkpoint files found in '{checkpoint_dir}'. Please run train.py first.")


def _is_last_checkpoint(path):
    """Return True for Lightning's rolling 'last' checkpoint ('last.ckpt' / 'last-vN.ckpt')."""
    name = os.path.basename(path)
    return name == 'last.ckpt' or (name.startswith('last-v') and name.endswith('.ckpt'))


# ModelCheckpoint(save_last=True) rewrites 'last.ckpt' at the end of training,
# so it is always the newest file even though it is the final epoch rather than
# the best one. Only fall back to it when it is the sole checkpoint present.
candidate_files = [p for p in list_of_files if not _is_last_checkpoint(p)] or list_of_files

# getmtime (last content write) is portable; getctime is the inode-change time on Unix.
# NOTE(review): "most recent" equals "best" only if ModelCheckpoint keeps a single
# top-k checkpoint (save_top_k=1) — confirm against train.py.
latest_file = max(candidate_files, key=os.path.getmtime)
checkpoint_path = latest_file

print(f"Loaded configuration from: {config_path}")
print(f"Found best model checkpoint at: {checkpoint_path}")

### 2. Load Model and DataModule

In [None]:
# Restore the trained network from the selected checkpoint and switch it to
# inference mode (eval() disables training-only behaviour such as dropout).
model = QuantileRegressor.load_from_checkpoint(checkpoint_path)
model.eval()

# Re-seed with the training seed, then rebuild the DataModule from the same
# config values so the data distribution and true function match training.
pl.seed_everything(config['seed'])
data_cfg = config['data']
data_module = PolynomialDataModule(
    degree=data_cfg['degree'],
    noise_scale=data_cfg['noise_scale'],
    n_samples=data_cfg['n_samples'],
    batch_size=config['training']['batch_size'],
)
data_module.setup()  # generates the data and performs the train/val/test splits

# random_split yields a Subset, which stores only row indices into its parent
# dataset; gather the test rows of the first two parent tensors (x, y) explicitly.
test_subset = data_module.test_dataset
full_dataset = test_subset.dataset
test_indices = test_subset.indices
test_x, test_y = (column[test_indices] for column in full_dataset.tensors[:2])

### 3. Generate Predictions and Plot Results

In [None]:
# Set up the plot
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(12, 8))

# Explicit NumPy copies of the test tensors for matplotlib.
test_x_np = test_x.detach().cpu().numpy()
test_y_np = test_y.detach().cpu().numpy()

# 1. Plot the raw test data points
ax.scatter(test_x_np, test_y_np, alpha=0.15, label='Test Data Points', color='gray', s=20)

# Define the quantiles we want to visualize (one colour per quantile)
quantiles_to_plot = [0.05, 0.25, 0.50, 0.75, 0.95]
colors = plt.cm.viridis(np.linspace(0.1, 0.9, len(quantiles_to_plot)))

# For a clean line plot the x values must be sorted. Sort on the first feature
# column: unlike squeeze(), indexing [:, 0] keeps a 1-D sort key even when the
# test set holds a single point (squeeze would yield a 0-d tensor).
sort_key = test_x if test_x.dim() == 1 else test_x[:, 0]
sorted_indices = torch.argsort(sort_key)
sorted_x = test_x[sorted_indices]

# Move data to the same device as the model (e.g., 'cuda' or 'mps')
device = model.device
sorted_x = sorted_x.to(device)
# The x-coordinates are identical for every curve, so convert them once.
sorted_x_np = sorted_x.cpu().numpy()

# 2. Generate and plot the predicted quantile lines
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for color, q in zip(colors, quantiles_to_plot):
        # One tau value per input row, created directly on the model's device.
        tau = torch.full((sorted_x.shape[0], 1), q, device=device)

        # Get model predictions
        predictions = model(sorted_x, tau)

        label = f'Predicted q={q:.2f}'
        if q == 0.50:
            label += ' (Median)'
        # Move predictions to CPU before handing them to matplotlib.
        ax.plot(sorted_x_np, predictions.cpu().numpy(), label=label, color=color, linewidth=2.5)

# 3. Plot the true underlying function (without noise)
# NOTE(review): x_range comes from the config but is not passed to
# PolynomialDataModule; presumably the module uses the same range — confirm.
x_range = config['data']['x_range']
true_x_line = np.linspace(x_range[0], x_range[1], 400)
# np.polyval expects coefficients ordered from highest to lowest degree.
# NOTE(review): assumes data_module.coeffs is stored in that order — confirm in src/data.py.
true_y_line = np.polyval(data_module.coeffs, true_x_line)
ax.plot(true_x_line, true_y_line, color='red', linestyle='--', linewidth=3, label='True Function (Ground Truth)')

# Final plot styling
ax.set_title('Adaptive Quantile Regression Results', fontsize=18, fontweight='bold')
ax.set_xlabel('Input Feature (x)', fontsize=14)
ax.set_ylabel('Output (y)', fontsize=14)
ax.legend(fontsize=12)

# Pad the y-limits by a fraction of the observed spread rather than a fixed
# 5 units, so the margin scales with the data (a fixed offset squashes
# small-range targets and is negligible for large-range ones). Fall back to a
# unit pad when every target has the same value.
Y_MARGIN_FRACTION = 0.1
y_lo, y_hi = float(test_y_np.min()), float(test_y_np.max())
y_pad = Y_MARGIN_FRACTION * (y_hi - y_lo) or 1.0
ax.set_ylim(y_lo - y_pad, y_hi + y_pad)
plt.tight_layout()
plt.show()