In [5]:
import torch
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
from src.models.models import NeuralNetwork
from src.data.data_loader import load_data
from src.config.config import MODEL_PATHS, DEVICE
import joblib

In [6]:
class ModelWrapper:
    """Wrapper class to make PyTorch model compatible with SHAP.

    Calling the wrapper runs the wrapped model in evaluation mode with
    gradients disabled and returns the predictions as a NumPy array, so it can
    be used wherever a plain ``f(X) -> np.ndarray`` callable is expected.
    """
    def __init__(self, model):
        # The wrapped torch model; it is switched to eval mode on every call.
        self.model = model

    def __call__(self, X):
        """Return model predictions for ``X`` as a NumPy array.

        Args:
            X: Input features as a tensor or anything ``np.asarray`` accepts
                (ndarray, nested list, DataFrame).

        Returns:
            np.ndarray: Model output moved to the CPU.
        """
        if not isinstance(X, torch.Tensor):
            # Go through NumPy first so array-likes such as DataFrames work.
            X = np.asarray(X, dtype=np.float32)
        # Tensor inputs are normalised too: a float64 or off-device tensor
        # would otherwise reach the model unchanged and fail inside it.
        X = torch.as_tensor(X, dtype=torch.float32, device=DEVICE)
        self.model.eval()
        with torch.no_grad():
            return self.model(X).cpu().numpy()

In [7]:
def analyze_model(model, X, feature_names=None, sample_size=10, random_state=None):
    """
    Analyze model using SHAP values.

    A random subset of ``X`` is used as the DeepExplainer background and the
    first ``sample_size`` rows of ``X`` are explained. Two summary plots and
    dependence plots for the five most important features are shown.

    Args:
        model: Neural network model (a ``torch.nn.Module``)
        X: Input features, array-like of shape (n_samples, n_features)
        feature_names: List of feature names
        sample_size: Number of samples to use for SHAP analysis
        random_state: Optional seed for the background sample. When None the
            global NumPy random state is used.

    Returns:
        tuple: ``(shap_values, mean_abs_shap)`` where ``shap_values`` has shape
        (n_explained, n_features) and ``mean_abs_shap`` has shape (n_features,).
        If the model has several outputs only the first one is analysed.

    Raises:
        ValueError: If ``X`` is not a non-empty 2-D array or ``sample_size`` < 1.
    """
    if sample_size < 1:
        raise ValueError("sample_size must be at least 1")

    # Plain float32 array: positional row indexing then behaves the same for
    # DataFrames, lists and ndarrays.
    X = np.asarray(X, dtype=np.float32)
    if X.ndim != 2 or X.shape[0] == 0:
        raise ValueError("X must be a non-empty 2-D array of shape (n_samples, n_features)")

    # Select random samples for background
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    if len(X) > sample_size:
        background_inds = rng.choice(len(X), sample_size, replace=False)
        background_data = X[background_inds]
    else:
        background_data = X
    explain_data = X[:sample_size]

    # DeepExplainer's PyTorch backend needs the nn.Module itself and tensor
    # inputs; a NumPy-returning callable is not recognised as a PyTorch model.
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    background_tensor = torch.as_tensor(background_data, dtype=torch.float32, device=device)
    explain_tensor = torch.as_tensor(explain_data, dtype=torch.float32, device=device)

    # Create SHAP explainer
    print("\nCreating SHAP explainer...")
    explainer = shap.DeepExplainer(model, background_tensor)

    # Calculate SHAP values
    print("Calculating SHAP values...")
    shap_values = explainer.shap_values(explain_tensor)

    # Depending on the SHAP version the result may be a list with one array
    # per output or an array with a trailing output axis; reduce both to a
    # single (n_explained, n_features) array for the first output.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        shap_values = shap_values[..., 0]

    # Create plots
    plt.figure(figsize=(12, 6))

    # Summary plot
    print("Creating summary plot...")
    plt.subplot(1, 2, 1)
    # plot_size=None keeps summary_plot from resizing the shared figure.
    shap.summary_plot(shap_values, explain_data,
                      feature_names=feature_names,
                      plot_type="bar",
                      plot_size=None,
                      show=False)
    plt.title("Feature Importance (SHAP Values)")

    # Detailed SHAP values plot
    print("Creating detailed SHAP plot...")
    plt.subplot(1, 2, 2)
    shap.summary_plot(shap_values, explain_data,
                      feature_names=feature_names,
                      plot_size=None,
                      show=False)
    plt.title("Feature Impact on Predictions")

    plt.tight_layout()
    plt.show()

    # Create dependence plots for top features
    print("\nCreating dependence plots for top features...")
    mean_abs_shap = np.abs(shap_values).mean(0)
    top_features = np.argsort(-mean_abs_shap)[:5]  # Top 5 features

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.ravel()
    for ax, feature_idx in zip(axes, top_features):
        feature_idx = int(feature_idx)
        feature_name = feature_names[feature_idx] if feature_names is not None else f"Feature {feature_idx}"
        # Without ax= dependence_plot opens its own figure and the grid stays empty.
        shap.dependence_plot(feature_idx, shap_values, explain_data,
                             feature_names=feature_names,
                             ax=ax,
                             show=False)
        ax.set_title(f"Dependence Plot: {feature_name}")

    # Hide grid cells that received no plot (the sixth cell, or more when
    # there are fewer than five features).
    for ax in axes[len(top_features):]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.show()

    return shap_values, mean_abs_shap

In [8]:


def preprocess_data(X):
    """
    Scale input features with the scaler stored at MODEL_PATHS['SCALER'].

    Args:
        X (pd.DataFrame): Input features

    Returns:
        np.ndarray: Result of the stored scaler's ``transform`` applied to X
    """
    print("Loading scaler...")
    fitted_scaler = joblib.load(MODEL_PATHS['SCALER'])

    print("Normalizing input data...")
    return fitted_scaler.transform(X)


In [9]:
import torch
from src.models.models import NeuralNetwork
from src.config.config import DEVICE

def load_neural_network(model_path, input_dim=None):
    """
    Load a saved model with its hyperparameters for prediction.

    The file must contain a dict with a 'state_dict' entry and, optionally, a
    'hyperparameters' entry. When the hyperparameters carry 'hidden_layer_0',
    the four linear layers are rebuilt from 'hidden_layer_0' .. 'hidden_layer_2'
    before the weights are loaded; 'dropout_rate' replaces the dropout module.

    Args:
        model_path (str): Path to the saved model file
        input_dim (int, optional): Input dimension. If None, it is read from
            the second dimension of 'layer1.weight' in the state dict

    Returns:
        model: Loaded neural network model on DEVICE, in evaluation mode
        dict: Model hyperparameters (empty dict if none were stored)

    Raises:
        RuntimeError: If the file cannot be read, has an unexpected format or
            does not match the rebuilt architecture. The original error is
            attached as ``__cause__``.
    """
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source. map_location lets a checkpoint saved on a GPU be
        # restored on a machine whose DEVICE is the CPU.
        model_info = torch.load(model_path, map_location=DEVICE)

        if not isinstance(model_info, dict) or 'state_dict' not in model_info:
            raise ValueError("Model file does not contain the expected format")

        state_dict = model_info['state_dict']

        # Get hyperparameters; a stored None is treated like a missing entry
        hyperparameters = model_info.get('hyperparameters') or {}

        if input_dim is None:
            # Infer input_dim from the first layer's weight matrix
            input_dim = state_dict['layer1.weight'].size(1)

        # Initialize model with the default architecture
        model = NeuralNetwork(input_dim).to(DEVICE)

        # Update layer sizes if they were optimized
        if 'hidden_layer_0' in hyperparameters:
            layer_sizes = [
                input_dim,
                hyperparameters['hidden_layer_0'],
                hyperparameters['hidden_layer_1'],
                hyperparameters['hidden_layer_2'],
                1,
            ]
            for position in range(4):
                replacement = torch.nn.Linear(layer_sizes[position], layer_sizes[position + 1])
                setattr(model, f'layer{position + 1}', replacement.to(DEVICE))

        # Update dropout if it was optimized
        if 'dropout_rate' in hyperparameters:
            model.dropout = torch.nn.Dropout(hyperparameters['dropout_rate'])

        # Load the state dict
        model.load_state_dict(state_dict)

        # Set model to evaluation mode
        model.eval()

        return model, hyperparameters

    except Exception as e:
        # Chain explicitly so the underlying traceback stays attached.
        raise RuntimeError(f"Error loading model from {model_path}: {str(e)}") from e


In [10]:
# Read x_test.csv and keep its contents as a NumPy array.
# NOTE(review): `X` is rebound by `X, y = load_data()` in the next cell, so
# this array is not what gets analysed there — confirm which input is intended.
X = np.array(pd.read_csv("x_test.csv"))

In [None]:
# Load the dataset and name each feature column by its position.
X, y = load_data()
feature_names = [f"Wavelength_{i}" for i in range(X.shape[1])]

# Analyze regular neural network
print("\nAnalyzing Regular Neural Network...")
# load_neural_network returns (model, hyperparameters); unpack the pair so
# that the model itself, not the tuple, is handed to analyze_model.
regular_model, regular_hyperparameters = load_neural_network(
    MODEL_PATHS['NEURAL_NET'], X.shape[1]
)

regular_shap_values, regular_importance = analyze_model(
    regular_model, X, feature_names
)