# py_visualize: Interactive Model Visualization

This notebook demonstrates the **py_visualize** package, which provides interactive Plotly-based visualizations for model evaluation, diagnostics, and hyperparameter tuning.

## Functions Covered

1. **plot_forecast()** - Time series forecast visualization
2. **plot_residuals()** - Diagnostic plots for model validation
3. **plot_model_comparison()** - Compare multiple models
4. **plot_tune_results()** - Hyperparameter tuning visualization

In [None]:
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# Import py-tidymodels packages
from py_parsnip import linear_reg, rand_forest, prophet_reg
from py_recipes import recipe
from py_workflows import workflow
from py_rsample import initial_time_split, time_series_cv
from py_tune import tune, tune_grid, grid_regular
from py_visualize import plot_forecast, plot_residuals, plot_model_comparison, plot_tune_results

print("✓ All packages imported successfully")

## Setup: Create Sample Time Series Data

We'll create 500 days of synthetic data that combine a linear trend, a 30-day sine-wave seasonal cycle, and Gaussian noise.

In [None]:
# Create time series data
np.random.seed(42)
dates = pd.date_range('2020-01-01', periods=500, freq='D')
time_index = np.arange(len(dates))

# Trend + seasonality + noise
trend = time_index * 0.5
seasonality = 10 * np.sin(2 * np.pi * time_index / 30)
noise = np.random.randn(len(dates)) * 5

y = trend + seasonality + noise + 100

data = pd.DataFrame({
    'date': dates,
    'value': y
})

print(f"Created time series with {len(data)} observations")
print(f"Date range: {data['date'].min()} to {data['date'].max()}")
data.head()

## 1. plot_forecast() - Time Series Forecasting

Create interactive forecast plots with prediction intervals.

In [None]:
# Split data
split = initial_time_split(data, prop=0.8)
train_data = split.training()
test_data = split.testing()

print(f"Training: {len(train_data)} observations")
print(f"Testing: {len(test_data)} observations")
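`initial_time_split()` splits chronologically rather than randomly, so the test set is the final 20% of the series. Below is a minimal pandas sketch of that behaviour. It assumes the split point is `int(len(data) * prop)`; py_rsample's exact rounding rule is not shown in this notebook, and `time_split_sketch` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def time_split_sketch(df: pd.DataFrame, prop: float = 0.8):
    """Hypothetical helper: first `prop` of rows for training, the rest for testing (no shuffling)."""
    n_train = int(len(df) * prop)
    return df.iloc[:n_train], df.iloc[n_train:]

demo = pd.DataFrame({
    "date": pd.date_range("2020-01-01", periods=500, freq="D"),
    "value": np.arange(500, dtype=float),
})
train_demo, test_demo = time_split_sketch(demo, prop=0.8)
print(len(train_demo), len(test_demo))  # 400 100

# Every training date precedes every test date, so no future data leaks into training
print(train_demo["date"].max() < test_demo["date"].min())  # True
```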

In [None]:
# Create and fit a linear regression model with lags
rec = (
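    # `value ~ date` is not valid Python syntax, so the formula is passed as a string
    # (assumes py_recipes accepts a formula string here)
    recipe("value ~ date", data=train_data)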
    .step_date('date', features=['month', 'week', 'doy'])
    .step_lag('value', lags=[1, 7, 30])
    .step_normalize(['value_lag_1', 'value_lag_7', 'value_lag_30'])
)

wf = (
    workflow()
    .add_recipe(rec)
    .add_model(linear_reg())
)

# Fit model
fit = wf.fit(train_data)

# Generate predictions
predictions = fit.predict(test_data)

print("✓ Model fitted and predictions generated")
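`step_lag()` adds lagged copies of `value` as predictors. The pandas sketch below illustrates the idea with `Series.shift()`, assuming the `value_lag_<k>` naming that the `step_normalize()` call above expects. The first `k` rows of each lag column have no earlier observation and are therefore missing; how py_recipes handles those rows is not shown in this notebook.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"value": np.arange(1.0, 41.0)})  # toy series 1..40

# One shifted column per lag, mirroring lags=[1, 7, 30]
for k in [1, 7, 30]:
    df[f"value_lag_{k}"] = df["value"].shift(k)

# Row 30 is the first row where all three lag columns are defined
print(df.loc[30, ["value", "value_lag_1", "value_lag_7", "value_lag_30"]].tolist())
# [31.0, 30.0, 24.0, 1.0]
print(df["value_lag_30"].isna().sum())  # 30
```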

In [None]:
# Plot forecast
fig = plot_forecast(
    fit,
    prediction_intervals=True,
    title="Linear Regression Forecast with Lags",
    height=500
)

fig.show()

print("\n📊 The plot shows:")
print("  • Blue line: Training data (actual values)")
print("  • Red line: Test data (actual values)")
print("  • Green line: Model predictions")
print("  • Shaded region: 95% prediction intervals")
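For an ordinary least-squares model, a common approximation of a 95% prediction interval is ŷ ± 1.96·σ̂, where σ̂ is the residual standard deviation. Whether `plot_forecast()` uses this approximation or the exact interval (which also accounts for coefficient uncertainty) is an assumption here; the NumPy sketch only illustrates the approximate version on a toy trend.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.arange(200, dtype=float)
y = 100 + 0.5 * x + rng.normal(0, 5, size=x.size)

# Fit y = b0 + b1 * x by least squares
X = np.column_stack([np.ones_like(x), x])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
fitted = X @ coef

# Residual standard deviation with n - p degrees of freedom
resid = y - fitted
sigma = np.sqrt(resid @ resid / (len(y) - X.shape[1]))

# Approximate 95% interval for new observations at future x values
x_new = np.arange(200, 250, dtype=float)
y_hat = coef[0] + coef[1] * x_new
lower, upper = y_hat - 1.96 * sigma, y_hat + 1.96 * sigma

# Share of training points that fall inside +/- 1.96 sigma of the fit
coverage = np.mean(np.abs(resid) <= 1.96 * sigma)
print(f"sigma≈{sigma:.2f}, in-sample coverage≈{coverage:.2%}")
```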

## 2. plot_residuals() - Diagnostic Plots

Check model assumptions with comprehensive diagnostic plots.

In [None]:
# Plot all diagnostics (2x2 grid)
fig = plot_residuals(
    fit,
    plot_type="all",
    title="Model Diagnostics: Linear Regression",
    height=700,
    width=900
)

fig.show()

print("\n📊 Diagnostic plots:")
print("  • Top-left: Residuals vs Fitted (check for patterns)")
print("  • Top-right: Q-Q plot (check normality)")
print("  • Bottom-left: Residuals vs Time (check for autocorrelation)")
print("  • Bottom-right: Histogram (check distribution)")
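All four panels are built from the same two arrays: the fitted values and the residuals (`actual - fitted`). The NumPy sketch below computes them for a deliberately under-specified fit (a straight line with no seasonal terms) on data shaped like this notebook's series, and adds the lag-1 autocorrelation of the residuals as a numeric counterpart to the Residuals vs Time panel. It is illustrative only and assumes nothing about py_visualize internals; `lag1_autocorr` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(42)
t = np.arange(500, dtype=float)
# Same structure as the notebook data: trend + 30-day seasonality + noise
y = 100 + 0.5 * t + 10 * np.sin(2 * np.pi * t / 30) + rng.normal(0, 5, size=t.size)

# Straight-line model that ignores the seasonal cycle
X = np.column_stack([np.ones_like(t), t])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
fitted = X @ coef
residuals = y - fitted

def lag1_autocorr(r: np.ndarray) -> float:
    """Correlation between each residual and the one before it."""
    r = r - r.mean()
    return float(np.sum(r[1:] * r[:-1]) / np.sum(r * r))

# The unmodelled sine wave leaves strongly autocorrelated residuals;
# white-noise residuals would give a value near 0
print(f"lag-1 autocorrelation: {lag1_autocorr(residuals):.2f}")
```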

### Individual Diagnostic Plots

You can also create individual diagnostic plots:

In [None]:
# Residuals vs Fitted only
fig_fitted = plot_residuals(
    fit,
    plot_type="fitted",
    title="Residuals vs Fitted Values"
)

fig_fitted.show()

In [None]:
# Q-Q plot only
fig_qq = plot_residuals(
    fit,
    plot_type="qq",
    title="Normal Q-Q Plot"
)

fig_qq.show()
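A normal Q-Q plot pairs the sorted residuals with the quantiles a standard normal distribution would place at the same ranks; points near a straight line suggest roughly normal residuals. The sketch below uses the standard library's `statistics.NormalDist` and `(i - 0.5) / n` plotting positions, which is one common convention; the positions `plot_residuals()` uses are an assumption, and `qq_points` is a hypothetical helper.

```python
from statistics import NormalDist
import numpy as np

def qq_points(residuals):
    """Return (theoretical, sample) quantile pairs for a normal Q-Q plot."""
    sample = np.sort(np.asarray(residuals, dtype=float))
    n = sample.size
    probs = (np.arange(1, n + 1) - 0.5) / n  # plotting positions in (0, 1)
    theoretical = np.array([NormalDist().inv_cdf(p) for p in probs])
    return theoretical, sample

rng = np.random.default_rng(1)
theo, samp = qq_points(rng.normal(0, 5, size=300))

# For normal residuals the two axes are almost perfectly correlated
print(round(float(np.corrcoef(theo, samp)[0, 1]), 3))
```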

## 3. plot_model_comparison() - Compare Multiple Models

Compare performance metrics across different models.

In [None]:
# Fit multiple models for comparison

# Model 1: Linear Regression with lags
wf_linear = (
    workflow()
    .add_recipe(rec)
    .add_model(linear_reg())
)
fit_linear = wf_linear.fit(train_data)
pred_linear = fit_linear.predict(test_data)

# Model 2: Random Forest
wf_rf = (
    workflow()
    .add_recipe(rec)
    .add_model(rand_forest(trees=50, mode='regression'))
)
fit_rf = wf_rf.fit(train_data)
pred_rf = fit_rf.predict(test_data)

# Model 3: Ridge Regression (linear_reg with penalty)
wf_ridge = (
    workflow()
    .add_recipe(rec)
    .add_model(linear_reg(penalty=0.1, mixture=0.0))  # Ridge
)
fit_ridge = wf_ridge.fit(train_data)
pred_ridge = fit_ridge.predict(test_data)

print("✓ Three models fitted successfully")
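With `mixture=0.0` the penalty is pure L2 (ridge), which shrinks coefficients toward zero without setting any of them exactly to zero. The textbook closed form on centred data is β = (XᵀX + λI)⁻¹Xᵀy. The sketch below uses it to show shrinkage as λ grows. The penalty scaling in whichever engine py_parsnip calls may differ, so `penalty=0.1` above is not directly comparable to `lam` here; `ridge_coefs` is a hypothetical helper.

```python
import numpy as np

def ridge_coefs(X: np.ndarray, y: np.ndarray, lam: float) -> np.ndarray:
    """Closed-form ridge coefficients on centred data (intercept left unpenalized)."""
    Xc = X - X.mean(axis=0)
    yc = y - y.mean()
    p = X.shape[1]
    return np.linalg.solve(Xc.T @ Xc + lam * np.eye(p), Xc.T @ yc)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
y = X @ np.array([3.0, -2.0, 0.5]) + rng.normal(0, 1, size=100)

# lam = 0 is plain least squares; larger lam pulls the coefficient vector toward zero
for lam in [0.0, 10.0, 1000.0]:
    beta = ridge_coefs(X, y, lam)
    print(lam, np.round(beta, 2), round(float(np.linalg.norm(beta)), 3))
```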

In [None]:
# Extract stats DataFrames
_, _, stats_linear = fit_linear.extract_outputs()
_, _, stats_rf = fit_rf.extract_outputs()
_, _, stats_ridge = fit_ridge.extract_outputs()

# Create bar chart comparison
fig = plot_model_comparison(
    stats_list=[stats_linear, stats_rf, stats_ridge],
    model_names=["Linear Regression", "Random Forest", "Ridge Regression"],
    metrics=["rmse", "mae", "r_squared"],
    split="test",
    plot_type="bar",
    title="Model Performance Comparison",
    height=500
)

fig.show()

print("\n📊 Lower is better for RMSE and MAE")
print("📊 Higher is better for R²")
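For reference, the three metrics compared above can be computed directly from actual and predicted values. This sketch uses the standard definitions; R² here is 1 − SSE/SST, which can be negative on a test set when a model does worse than predicting the mean. Whether `extract_outputs()` uses this R² variant or the squared correlation is not stated in this notebook, and `regression_metrics` is a hypothetical helper.

```python
import numpy as np

def regression_metrics(actual, predicted):
    """RMSE, MAE and R² (1 - SSE/SST) for one set of predictions."""
    actual = np.asarray(actual, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    err = actual - predicted
    rmse = np.sqrt(np.mean(err ** 2))
    mae = np.mean(np.abs(err))
    r_squared = 1 - np.sum(err ** 2) / np.sum((actual - actual.mean()) ** 2)
    return {"rmse": float(rmse), "mae": float(mae), "r_squared": float(r_squared)}

m = regression_metrics([3.0, 5.0, 7.0, 9.0], [2.0, 5.0, 8.0, 9.0])
print(m)  # rmse ≈ 0.707, mae = 0.5, r_squared ≈ 0.9
```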

### Heatmap Comparison

Useful when comparing many models across many metrics:

In [None]:
# Heatmap view
fig_heatmap = plot_model_comparison(
    stats_list=[stats_linear, stats_rf, stats_ridge],
    model_names=["Linear Regression", "Random Forest", "Ridge Regression"],
    plot_type="heatmap",
    title="Model Performance Heatmap",
    height=400
)

fig_heatmap.show()

### Radar Chart

Compare models across multiple metrics simultaneously:

In [None]:
# Radar chart (metrics are normalized)
fig_radar = plot_model_comparison(
    stats_list=[stats_linear, stats_rf, stats_ridge],
    model_names=["Linear Regression", "Random Forest", "Ridge Regression"],
    plot_type="radar",
    title="Model Performance Radar Chart",
    height=500,
    width=600
)

fig_radar.show()

print("\n📊 Metrics are normalized to 0-1 scale")
print("📊 Larger area = better overall performance")
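"Larger area = better" only holds if the lower-is-better metrics (RMSE, MAE) are flipped during normalization. One plausible scheme, sketched here as an assumption rather than py_visualize's documented behaviour, is min-max scaling each metric across models and inverting the error metrics so that 1 is always the best model. The metric values are placeholders, not results from this notebook, and `normalize_for_radar` is a hypothetical helper.

```python
import pandas as pd

# Placeholder metric values for three models (not real results)
stats = pd.DataFrame(
    {"rmse": [5.0, 4.0, 6.0], "mae": [4.0, 3.5, 4.5], "r_squared": [0.80, 0.85, 0.70]},
    index=["Linear Regression", "Random Forest", "Ridge Regression"],
)
lower_is_better = {"rmse", "mae"}

def normalize_for_radar(df: pd.DataFrame) -> pd.DataFrame:
    """Min-max scale each metric to [0, 1] so that 1 is always the best model."""
    out = pd.DataFrame(index=df.index)
    for col in df.columns:
        lo, hi = df[col].min(), df[col].max()
        if hi == lo:
            out[col] = 1.0  # all models tie on this metric
            continue
        scaled = (df[col] - lo) / (hi - lo)
        out[col] = 1 - scaled if col in lower_is_better else scaled
    return out

print(normalize_for_radar(stats))
```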

## 4. plot_tune_results() - Hyperparameter Tuning Visualization

Visualize how hyperparameters affect model performance.

In [None]:
# Create cross-validation splits
cv_splits = time_series_cv(
    train_data,
    initial=200,
    assess=50,
    skip=25,
    cumulative=False
)

print(f"Created {cv_splits.n_splits} CV splits")
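With `cumulative=False`, each split uses a fixed-size training window that slides forward through the series. The index arithmetic can be sketched as follows, assuming `skip` is the number of rows the window origin moves between splits and that windows start at the beginning of the series. py_rsample's exact conventions (for example, whether windows are anchored at the start or the end of the series) may differ, so the count below is not guaranteed to match `cv_splits.n_splits`. `sliding_windows` is a hypothetical helper.

```python
def sliding_windows(n: int, initial: int, assess: int, skip: int):
    """Hypothetical helper: (train_range, test_range) index pairs for sliding-window CV."""
    if skip <= 0:
        raise ValueError("skip must be positive")
    splits = []
    start = 0
    while start + initial + assess <= n:
        train = range(start, start + initial)
        test = range(start + initial, start + initial + assess)
        splits.append((train, test))
        start += skip
    return splits

# 400 training rows (80% of 500), matching the time_series_cv() call above
splits = sliding_windows(n=400, initial=200, assess=50, skip=25)
print(len(splits))  # 7
for train, test in splits[:2]:
    print(f"train rows {train.start}-{train.stop - 1}, test rows {test.start}-{test.stop - 1}")
```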

### Single Parameter Tuning (Line Plot)

In [None]:
# Tune penalty parameter for linear regression
wf_tune_1d = (
    workflow()
    .add_recipe(rec)
    .add_model(linear_reg(penalty=tune(), mixture=1.0))  # Lasso
)

# Create parameter grid
grid_1d = grid_regular(
    penalty={'range': (0.001, 1.0), 'trans': 'log'},
    levels=8
)

# Tune
results_1d = tune_grid(
    wf_tune_1d,
    resamples=cv_splits,
    grid=grid_1d
)

print(f"✓ Tuned {len(grid_1d)} penalty values")
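With `'trans': 'log'`, the eight penalty values are presumably spaced evenly on a logarithmic scale rather than a linear one, so small penalties get as much coverage as large ones. Assuming evenly spaced points on a log scale between the two range endpoints (the log base does not change the resulting points), NumPy's `logspace` reproduces such a grid:

```python
import numpy as np

penalties = np.logspace(np.log10(0.001), np.log10(1.0), num=8)
print(np.round(penalties, 4))

# Consecutive values differ by a constant ratio rather than a constant step
print(np.round(penalties[1:] / penalties[:-1], 4))
```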

In [None]:
# Plot tuning results (line plot for single parameter)
fig = plot_tune_results(
    results_1d,
    metric="rmse",
    plot_type="line",
    show_best=3,
    title="Penalty Parameter Tuning (Lasso Regression)"
)

fig.show()

print("\n📊 Line plot shows how RMSE changes with penalty")
print("📊 Top 3 best configurations are highlighted")
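Highlighting the best configurations amounts to averaging the metric over CV folds for each parameter value and sorting. The pandas sketch below does this with placeholder per-fold RMSE values and keeps the top 2 of 3; the column layout of `tune_grid()` results is an assumption here, not taken from py_tune.

```python
import pandas as pd

# Placeholder per-fold results: 3 penalty values × 2 folds (not real output)
fold_results = pd.DataFrame({
    "penalty": [0.001, 0.001, 0.01, 0.01, 0.1, 0.1],
    "fold":    [1, 2, 1, 2, 1, 2],
    "rmse":    [5.2, 5.6, 5.0, 5.4, 6.1, 6.3],
})

# Mean RMSE per penalty value, best (lowest) first
summary = (
    fold_results.groupby("penalty", as_index=False)["rmse"]
    .mean()
    .sort_values("rmse")
)
top_2 = summary.head(2)
print(top_2)
```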

### Two Parameters (Heatmap)

In [None]:
# Tune penalty and mixture for elastic net
wf_tune_2d = (
    workflow()
    .add_recipe(rec)
    .add_model(linear_reg(penalty=tune(), mixture=tune()))
)

# Create 2D grid
grid_2d = grid_regular(
    penalty={'range': (0.001, 0.5), 'trans': 'log'},
    mixture={'range': (0.0, 1.0)},
    levels=5
)

# Tune
results_2d = tune_grid(
    wf_tune_2d,
    resamples=cv_splits,
    grid=grid_2d
)

print(f"✓ Tuned {len(grid_2d)} penalty × mixture combinations")
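A regular grid over two parameters with `levels=5` is expected to be the full Cartesian product, 5 × 5 = 25 candidates, which is why grid size grows quickly as parameters are added. The sketch below builds such a grid with `itertools.product`, assuming log spacing for `penalty` and linear spacing for `mixture`; it is not py_tune's implementation.

```python
from itertools import product

import numpy as np
import pandas as pd

penalty_levels = np.logspace(np.log10(0.001), np.log10(0.5), num=5)
mixture_levels = np.linspace(0.0, 1.0, num=5)  # 0 = ridge, 1 = lasso

# Every penalty value paired with every mixture value
grid = pd.DataFrame(list(product(penalty_levels, mixture_levels)), columns=["penalty", "mixture"])
print(len(grid))  # 25
print(grid.head())
```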

In [None]:
# Plot as heatmap
fig_heatmap = plot_tune_results(
    results_2d,
    metric="rmse",
    plot_type="heatmap",
    show_best=5,
    title="Elastic Net Tuning: Penalty vs Mixture"
)

fig_heatmap.show()

print("\n📊 Heatmap shows RMSE across parameter combinations")
print("📊 Darker colors = lower RMSE (better)")
print("📊 Best configurations are marked with ⭐")
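Behind a tuning heatmap is a matrix with one parameter on each axis and the mean metric in each cell. The pandas sketch below performs that reshaping on placeholder RMSE values and looks up the best cell; the input layout is assumed, not taken from py_tune.

```python
import pandas as pd

# Placeholder mean RMSE per (penalty, mixture) pair (not real output)
results = pd.DataFrame({
    "penalty": [0.001, 0.001, 0.1, 0.1],
    "mixture": [0.0, 1.0, 0.0, 1.0],
    "rmse":    [5.3, 5.1, 5.8, 6.4],
})

# Rows = penalty, columns = mixture, cells = RMSE: the matrix a heatmap colours
matrix = results.pivot(index="penalty", columns="mixture", values="rmse")
print(matrix)

# The lowest-RMSE row is the cell a "best" marker would sit on
best = results.loc[results["rmse"].idxmin()]
print(f"best: penalty={best['penalty']}, mixture={best['mixture']}, rmse={best['rmse']}")
```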

### Scatter Plot Matrix

In [None]:
# Scatter plot view (useful for visualizing correlations)
fig_scatter = plot_tune_results(
    results_2d,
    metric="rmse",
    plot_type="scatter",
    title="Scatter Plot: Penalty vs Mixture"
)

fig_scatter.show()

print("\n📊 Color indicates RMSE value")
print("📊 Useful for seeing parameter interaction effects")

### Three+ Parameters (Parallel Coordinates)

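When tuning three or more parameters, a parallel coordinates plot is usually the most readable view. To keep runtime short, the example below tunes only two random forest parameters (`trees` and `min_n`); with more tuned parameters, each one simply becomes another vertical axis.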

In [None]:
# Tune random forest with multiple parameters
wf_tune_multi = (
    workflow()
    .add_recipe(rec)
    .add_model(
        rand_forest(
            trees=tune(),
            min_n=tune(),
            mode='regression'
        )
    )
)

# Create multi-parameter grid
grid_multi = grid_regular(
    trees={'range': (50, 200)},
    min_n={'range': (2, 20)},
    levels=4
)

# Tune (this may take a minute)
print("Tuning random forest... (this may take a moment)")
results_multi = tune_grid(
    wf_tune_multi,
    resamples=cv_splits,
    grid=grid_multi
)

print(f"✓ Tuned {len(grid_multi)} parameter combinations")

In [None]:
# Plot as parallel coordinates
fig_parallel = plot_tune_results(
    results_multi,
    metric="rmse",
    plot_type="parallel",
    title="Random Forest Tuning: Trees and Min_n",
    height=500
)

fig_parallel.show()

print("\n📊 Each line represents one parameter configuration")
print("📊 Color indicates RMSE (darker = better)")
print("📊 Trace lines from left to right to see parameter combinations")

### Auto Plot Type Selection

Use `plot_type="auto"` to automatically select the best visualization:

In [None]:
# Auto-selects best plot type based on number of parameters
fig_auto = plot_tune_results(
    results_2d,
    metric="rmse",
    plot_type="auto",  # Automatically chooses heatmap for 2 parameters
    title="Auto-Selected Plot Type"
)

fig_auto.show()

print("\n📊 Auto-selected plot type based on:")
print("  • 1 parameter → Line plot")
print("  • 2 parameters → Heatmap")
print("  • 3+ parameters → Parallel coordinates")
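The auto-selection rule printed above is simple enough to express as a function. This hypothetical `choose_plot_type` helper mirrors the mapping stated in this notebook; it is a sketch, not py_visualize's source code.

```python
def choose_plot_type(param_names: list[str]) -> str:
    """Map the number of tuned parameters to a plot type (hypothetical helper)."""
    n = len(param_names)
    if n == 0:
        raise ValueError("no tuned parameters found")
    if n == 1:
        return "line"
    if n == 2:
        return "heatmap"
    return "parallel"

print(choose_plot_type(["penalty"]))                 # line
print(choose_plot_type(["penalty", "mixture"]))      # heatmap
print(choose_plot_type(["trees", "min_n", "mtry"]))  # parallel
```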

## Summary

The **py_visualize** package provides four powerful visualization functions:

1. **plot_forecast()** - Interactive time series forecast plots
   - Train/test/forecast visualization
   - Prediction intervals
   - Support for grouped/nested models

2. **plot_residuals()** - Model diagnostic plots
   - Residuals vs Fitted
   - Q-Q plot
   - Residuals vs Time
   - Histogram

3. **plot_model_comparison()** - Multi-model comparison
   - Bar charts
   - Heatmaps
   - Radar charts

4. **plot_tune_results()** - Hyperparameter visualization
   - Line plots (1D)
   - Heatmaps (2D)
   - Parallel coordinates (3+D)
   - Scatter plot matrix
   - Auto plot type selection

All plots are interactive Plotly figures and can be customized further with standard Plotly methods such as `fig.update_layout()`.