# Model Feature Importance Analysis
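This notebook loads a saved PyTorch `.pth` model and computes feature importance with three methods: permutation importance, SHAP, and Integrated Gradients (via Captum). Run it from the workspace root.

Inputs:
- **Features:** read from `data/features.csv`.
- **Model:** read from `models/model.pth`; override with the `MODEL_PATH` environment variable.
- **Label column:** chosen with `LABEL_COLUMN`.
- **Task mode:** `TASK_TYPE` selects `regression` or `classification` scoring.

Outputs:
- Figures are saved to `notebooks/figures/`.
- Result tables and run configuration are saved to `notebooks/results/`.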


In [None]:
# Load Dependencies and Configure Environment
# Optional heavy dependencies (PyTorch, scikit-learn, SHAP, Captum) are imported
# defensively: on failure the reason is printed and the names are bound to None,
# so later cells can test availability with `is None` instead of hitting NameError.
import os
import json
import math
import random
import sys
import pathlib
from datetime import datetime
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

SEED = 42  # single seed shared by numpy, `random`, and torch below
np.random.seed(SEED)
random.seed(SEED)

try:
    import torch
    from torch import nn
    torch.manual_seed(SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
except Exception as e:
    print('PyTorch unavailable:', e)
    torch = None
    nn = None
    device = 'cpu'

try:
    from sklearn.metrics import accuracy_score, roc_auc_score, mean_squared_error, mean_absolute_error
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
except Exception as e:
    print('scikit-learn unavailable:', e)
    # Same None-sentinel convention as the other optional dependencies.
    accuracy_score = roc_auc_score = mean_squared_error = mean_absolute_error = None
    train_test_split = None
    StandardScaler = None

try:
    import shap
except Exception as e:
    print('SHAP unavailable:', e)
    shap = None

try:
    from captum.attr import IntegratedGradients
except Exception as e:
    print('Captum unavailable:', e)
    IntegratedGradients = None

# Output location for every figure produced below (relative to the working directory).
fig_dir = pathlib.Path('notebooks/figures')
fig_dir.mkdir(parents=True, exist_ok=True)
print(f"Using device: {device}")
print("Figures will be saved to:", fig_dir)
print("Workspace root:", pathlib.Path('.').resolve())
print("Variance-based importance idea: I_j = E[(ŷ(X) - ŷ(X^(j perm)))^2]")

In [None]:
# Load Dataset and Define Feature Columns
# Reads the feature table, optionally splits off the label column, makes a
# 60/20/20 train/val/test split, and standardizes features with statistics fit
# on the training split only.
data_path = pathlib.Path('data/features.csv')
label_column = os.environ.get('LABEL_COLUMN', 'label')
feature_columns: Optional[List[str]] = None
# Bind every split up front so later cells can check `is None` on a fresh kernel
# even when the dataset is missing (otherwise they raise NameError).
X_train = X_val = X_test = None
y_train = y_val = y_test = None
X_train_scaled = X_val_scaled = X_test_scaled = None
if not data_path.exists():
    print(f"Dataset not found at {data_path}. Please place your features CSV there.")
else:
    df = pd.read_csv(data_path)
    print('Dataset shape:', df.shape)
    if label_column in df.columns:
        y = df[label_column].values
        X = df.drop(columns=[label_column])
    else:
        print(f"Label column '{label_column}' not found. Treating as unlabeled features.")
        X = df
        y = None
    feature_columns = list(X.columns)
    print('Feature columns:', len(feature_columns))
    if y is not None:
        # 60% train, then the remaining 40% halved into val and test.
        X_train, X_tmp, y_train, y_tmp = train_test_split(X.values, y, test_size=0.4, random_state=42)
        X_val, X_test, y_val, y_test = train_test_split(X_tmp, y_tmp, test_size=0.5, random_state=42)
    else:
        # Without labels there is nothing to evaluate against, so all rows are
        # used both to fit the scaler and to compute attributions.
        X_train, X_val, X_test = X.values, None, X.values
        y_train, y_val, y_test = None, None, None
    # Fit scaling on the training split only to avoid leaking test statistics.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val) if X_val is not None else None
    X_test_scaled = scaler.transform(X_test) if X_test is not None else None

In [None]:
# Load .pth Model and Prepare for Inference
model_path = pathlib.Path(os.environ.get('MODEL_PATH', 'models/model.pth'))

if nn is not None:
    class SimpleTabularNN(nn.Module):
        """Three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) for tabular inputs.

        The layer sizes must match the network that produced the checkpoint,
        otherwise `load_state_dict` rejects the weights.
        """

        def __init__(self, in_dim: int, hidden: int = 64, out_dim: int = 1):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(in_dim, hidden), nn.ReLU(),
                nn.Linear(hidden, hidden), nn.ReLU(),
                nn.Linear(hidden, out_dim),
            )

        def forward(self, x):
            return self.net(x)
else:
    # Without PyTorch `nn` is None and subclassing nn.Module would raise here,
    # aborting the whole notebook; load_model() reports the missing dependency instead.
    SimpleTabularNN = None


def load_model(in_dim: int, out_dim: int = 1, hidden: int = 64,
               path: Optional[pathlib.Path] = None):
    """Load trained weights into a SimpleTabularNN (or load a whole pickled model).

    Accepted checkpoint formats:
      * a raw state_dict,
      * a dict wrapping one under 'state_dict' or 'model_state_dict',
      * a complete nn.Module saved with torch.save(model).

    Args:
        in_dim: number of input features.
        out_dim: number of model outputs.
        hidden: hidden width; must match the checkpoint's architecture.
        path: checkpoint file; defaults to the module-level `model_path`.

    Returns:
        The model on `device` in eval mode, or None when PyTorch is unavailable
        or trained weights could not be loaded. Returning None makes the
        downstream cells skip instead of explaining a randomly initialised net.
    """
    if torch is None:
        print('PyTorch not available; skipping model load.')
        return None
    ckpt_path = pathlib.Path(path) if path is not None else model_path
    if not ckpt_path.exists():
        print(f"Model file not found: {ckpt_path}")
        return None
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=device)
    except Exception as e:
        print('Checkpoint load failed:', e)
        return None

    if isinstance(checkpoint, nn.Module):
        model = checkpoint
    elif isinstance(checkpoint, dict):
        state_dict = checkpoint.get('state_dict', checkpoint.get('model_state_dict', checkpoint))
        model = SimpleTabularNN(in_dim, hidden=hidden, out_dim=out_dim)
        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print('State dict load failed:', e)
            return None
    else:
        print('Unknown checkpoint format; expected a state_dict dict or an nn.Module.')
        return None
    model.to(device)
    model.eval()
    return model


def infer(model, X_np: np.ndarray) -> Optional[np.ndarray]:
    """Run a gradient-free forward pass on a NumPy batch.

    Returns predictions as a NumPy array, or None if no model / PyTorch is
    available. A single-output model yields shape (n,); multi-output models
    keep shape (n, out_dim).
    """
    if model is None or torch is None:
        return None
    with torch.no_grad():
        x = torch.tensor(X_np, dtype=torch.float32, device=device)
        out = model(x).cpu().numpy()
    # Drop only the trailing singleton output axis: (n, 1) -> (n,). A blanket
    # squeeze() would also collapse a 1-row batch into a 0-d scalar.
    if out.ndim == 2 and out.shape[1] == 1:
        out = out[:, 0]
    return out

In [None]:
# Baseline Evaluation Metrics
# TASK_TYPE selects how raw model outputs are scored: 'classification' treats them
# as binary logits, 'regression' as direct predictions.
task_type = os.environ.get('TASK_TYPE', 'regression')  # 'classification' or 'regression'
metrics = {}
# Bind model / predictions up front so later cells can test `is None` on a fresh
# kernel even when this cell skips evaluation.
model = None
y_pred = None
baseline_X = globals().get('X_test_scaled')
if baseline_X is None:
    print('Test set unavailable; skipping baseline metrics.')
else:
    model = load_model(in_dim=baseline_X.shape[1], out_dim=1)
    y_pred = infer(model, baseline_X)
    if y_pred is None:
        print('Inference failed (no predictions).')
    else:
        if y_test is None:
            print('No labels available; baseline metrics cannot be computed.')
        elif task_type == 'classification':
            # Assumes binary classification with a single logit output; sigmoid -> probability.
            y_prob = 1 / (1 + np.exp(-y_pred))
            y_hat = (y_prob >= 0.5).astype(int)
            metrics['accuracy'] = accuracy_score(y_test, y_hat)
            try:
                metrics['roc_auc'] = roc_auc_score(y_test, y_prob)
            except Exception as e:
                # e.g. only one class present in y_test
                print('ROC AUC unavailable:', e)
        elif task_type == 'regression':
            metrics['mse'] = mean_squared_error(y_test, y_pred)
            metrics['mae'] = mean_absolute_error(y_test, y_pred)
        else:
            print(f"Unknown TASK_TYPE '{task_type}'; expected 'classification' or 'regression'.")
        print('Baseline metrics:', json.dumps(metrics, indent=2))

In [None]:
# Permutation Feature Importance
def metric_func(y_true, y_pred):
    """Score predictions so that larger is better.

    Classification: ROC AUC of sigmoid(y_pred) (y_pred assumed to be logits).
    Regression: negative mean squared error.
    Returns NaN when either argument is None, or when ROC AUC is undefined
    (roc_auc_score raises ValueError, e.g. for a single-class y_true).
    """
    if y_true is None or y_pred is None:
        return np.nan
    if task_type == 'classification':
        y_prob = 1 / (1 + np.exp(-y_pred))
        try:
            return roc_auc_score(y_true, y_prob)
        except ValueError:
            return np.nan
    return -mean_squared_error(y_true, y_pred)


def compute_permutation_importance(predict_fn, X: np.ndarray, y, columns: List[str],
                                   n_repeats: int = 5) -> Optional[pd.DataFrame]:
    """Permutation importance of each column of X.

    With labels: importance_j = mean over repeats of
        metric_func(y, f(X)) - metric_func(y, f(X with column j shuffled)),
    so a positive value means shuffling the feature hurts the model.
    Without labels: the variance-based variant
        I_j = mean over repeats of mean((f(X) - f(X with column j shuffled))^2).

    Args:
        predict_fn: maps an (n, p) array to predictions, or returns None.
        X: (n, p) input matrix; it is not modified.
        y: labels or None.
        columns: feature names, one per column of X.
        n_repeats: number of shuffles averaged per feature.

    Returns:
        DataFrame with 'feature' and 'importance', sorted by descending
        importance, or None if predict_fn yields no baseline predictions.
    """
    base_pred = predict_fn(X)
    if base_pred is None:
        return None
    base_pred = np.asarray(base_pred, dtype=float)
    base_metric = metric_func(y, base_pred) if y is not None else np.nan
    # One working copy; each column is restored after its repeats instead of
    # re-copying the whole matrix for every shuffle.
    X_perm = np.array(X, copy=True)
    records = []
    for j, col in enumerate(columns):
        original_col = X_perm[:, j].copy()
        scores = []
        for _ in range(n_repeats):
            X_perm[:, j] = original_col[np.random.permutation(original_col.shape[0])]
            perm_pred = predict_fn(X_perm)
            if perm_pred is None:
                scores.append(np.nan)
            elif y is not None:
                scores.append(base_metric - metric_func(y, perm_pred))
            else:
                scores.append(float(np.mean((base_pred - np.asarray(perm_pred, dtype=float)) ** 2)))
        X_perm[:, j] = original_col
        scores_arr = np.asarray(scores, dtype=float)
        # Avoid nanmean's "mean of empty slice" warning when every score is NaN.
        importance = float(np.nanmean(scores_arr)) if (~np.isnan(scores_arr)).any() else np.nan
        records.append({'feature': col, 'importance': importance})
    return pd.DataFrame(records).sort_values('importance', ascending=False)


perm_results = None
if globals().get('X_test_scaled') is not None and feature_columns is not None:
    perm_results = compute_permutation_importance(
        lambda X_batch: infer(model, X_batch), X_test_scaled, y_test, feature_columns, n_repeats=5)
if perm_results is not None:
    print('Permutation importance top 10:')
    print(perm_results.head(10))
else:
    print('Permutation importance skipped (missing test set, feature names, or predictions).')

In [None]:
# SHAP Feature Importance (Kernel/Deep Explainer)
# Explains up to 256 test rows against a background of up to 128 training rows.
# DeepExplainer (gradient-based, fast) is tried first; KernelExplainer
# (model-agnostic, slower) is the fallback.
shap_values = None
shap_summary = None
# globals().get: on a fresh kernel these are unbound when the dataset is missing.
shap_model = globals().get('model')
shap_X_test = globals().get('X_test_scaled')
if shap is not None and shap_model is not None and shap_X_test is not None:
    try:
        sample_idx = np.random.choice(np.arange(shap_X_test.shape[0]), size=min(256, shap_X_test.shape[0]), replace=False)
        X_bg = X_train_scaled[np.random.choice(np.arange(X_train_scaled.shape[0]), size=min(128, X_train_scaled.shape[0]), replace=False)]
        X_sample = shap_X_test[sample_idx]

        def f_predict(Xnp):
            return infer(shap_model, Xnp)

        try:
            explainer = shap.DeepExplainer(shap_model, torch.tensor(X_bg, dtype=torch.float32, device=device))
            shap_values = explainer.shap_values(torch.tensor(X_sample, dtype=torch.float32, device=device))
        except Exception as deep_err:
            print('DeepExplainer failed; falling back to KernelExplainer:', deep_err)
            explainer = shap.KernelExplainer(f_predict, X_bg)
            shap_values = explainer.shap_values(X_sample, nsamples=200)
        # Older SHAP versions return a list with one array per model output.
        if isinstance(shap_values, list):
            shap_values = shap_values[0]
        shap_values = np.asarray(shap_values)
        # Newer SHAP versions return (n_samples, n_features, n_outputs); keep the single output.
        if shap_values.ndim == 3:
            shap_values = shap_values[..., 0]
        mean_abs_shap = np.abs(shap_values).mean(axis=0)
        shap_summary = pd.DataFrame({
            'feature': feature_columns,
            'mean_abs_shap': mean_abs_shap[:len(feature_columns)]
        }).sort_values('mean_abs_shap', ascending=False)
        print('SHAP top 10:')
        print(shap_summary.head(10))
    except Exception as e:
        print('SHAP failed:', e)
else:
    print('SHAP not run (missing SHAP, model, or data).')

In [None]:
# Integrated Gradients with Captum
IG_MAX_ROWS = 256  # cap on explained test rows; IG evaluates the model at many interpolation steps per row
ig_summary = None
# globals().get: on a fresh kernel these are unbound when the dataset is missing.
ig_model = globals().get('model')
ig_X_test = globals().get('X_test_scaled')
if IntegratedGradients is not None and ig_model is not None and ig_X_test is not None:
    try:
        ig = IntegratedGradients(ig_model)
        batch = torch.tensor(ig_X_test[:IG_MAX_ROWS], dtype=torch.float32, device=device)
        # All-zero baseline: in StandardScaler space this is the training-set mean of every feature.
        baseline = torch.zeros_like(batch)
        attributions, delta = ig.attribute(inputs=batch, baselines=baseline, target=None, return_convergence_delta=True)
        # Completeness check: per-row attributions should sum to f(x) - f(baseline);
        # a large |delta| suggests increasing n_steps.
        print('IG mean |convergence delta|:', float(delta.abs().mean()))
        ig_scores = np.abs(attributions.detach().cpu().numpy()).mean(axis=0)
        ig_summary = pd.DataFrame({
            'feature': feature_columns,
            'mean_abs_ig': ig_scores[:len(feature_columns)]
        }).sort_values('mean_abs_ig', ascending=False)
        print('Integrated Gradients top 10:')
        print(ig_summary.head(10))
    except Exception as e:
        print('Captum IntegratedGradients failed:', e)
else:
    print('Captum IG not run (missing Captum, model, or data).')

In [None]:
# Aggregate, Rank, and Visualize Importances
# Each available method table is renamed to a common '<method>_importance' column,
# outer-joined on 'feature', scaled per method, and averaged into 'avg_importance'.
method_sources = [
    (perm_results, 'importance', 'perm_importance'),
    (shap_summary, 'mean_abs_shap', 'shap_importance'),
    (ig_summary, 'mean_abs_ig', 'ig_importance'),
]
method_tables = [table.rename(columns={old: new}) for table, old, new in method_sources if table is not None]
combined = None
if not method_tables:
    print('No importance tables to aggregate.')
else:
    combined = method_tables[0]
    for extra_table in method_tables[1:]:
        combined = combined.merge(extra_table, on='feature', how='outer')
    method_names = [name for name in ('perm_importance', 'shap_importance', 'ig_importance') if name in combined]
    # Divide each method by its own maximum. Non-negative scores land in [0, 1];
    # negative permutation deltas stay negative, and a method whose maximum is
    # NaN or <= 0 is left unscaled.
    for name in method_names:
        peak = combined[name].max()
        if pd.notnull(peak) and peak > 0:
            combined[name] = combined[name] / peak
    combined['avg_importance'] = combined[method_names].mean(axis=1)
    combined = combined.sort_values('avg_importance', ascending=False)
    print('Combined importance (top 20):')
    print(combined.head(20))

    # Bar chart of the 25 highest-ranked features.
    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    sns.barplot(data=combined.head(25), x='avg_importance', y='feature', orient='h', ax=ax_bar)
    ax_bar.set_title('Feature Importance (Aggregated)')
    fig_bar.tight_layout()
    out_bar = fig_dir / 'importance_aggregated_bar.png'
    fig_bar.savefig(out_bar)
    print('Saved:', out_bar)

    # Agreement between methods: only meaningful with two or more of them.
    if len(method_names) >= 2:
        fig_heat, ax_heat = plt.subplots(figsize=(6, 5))
        sns.heatmap(combined[method_names].corr(), annot=True, vmin=-1, vmax=1, cmap='coolwarm', ax=ax_heat)
        ax_heat.set_title('Importance Methods Correlation')
        fig_heat.tight_layout()
        out_heat = fig_dir / 'importance_methods_corr.png'
        fig_heat.savefig(out_heat)
        print('Saved:', out_heat)

In [None]:
# Save Artifacts and Results to Disk
# Writes the run configuration as JSON plus one CSV per importance table that exists.
results_dir = pathlib.Path('notebooks/results')
results_dir.mkdir(parents=True, exist_ok=True)
config = dict(
    timestamp=datetime.now().isoformat(),
    device=str(device),
    model_path=str(model_path),
    data_path=str(data_path),
    task_type=task_type,
    label_column=label_column,
    feature_count=len(feature_columns) if feature_columns else 0,
)
(results_dir / 'config.json').write_text(json.dumps(config, indent=2))
csv_exports = (
    (perm_results, 'importance_permutation.csv'),
    (shap_summary, 'importance_shap.csv'),
    (ig_summary, 'importance_integrated_gradients.csv'),
    (globals().get('combined'), 'importance_combined.csv'),
)
for table, file_name in csv_exports:
    if table is not None:
        table.to_csv(results_dir / file_name, index=False)
print('Saved artifacts to:', results_dir)