# Model Interpretability: SHAP Analysis for MIT-BIH Deep Learning Models

This notebook applies SHAP (SHapley Additive exPlanations) analysis to understand which temporal features (time steps of the ECG signal) contribute to the CNN8 model's classification decisions.

**Analysis scope:**
- 200 test samples (40 per class) for explanation
- 100 background samples for reference distribution
- DeepExplainer for neural network interpretability

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.models import load_model, Model

In [None]:
# Load the MIT-BIH data.
# The original test CSV is read without a header row, so its columns are the
# integers 0..N; column 187 is used as the class label further down.
df_mitbih_test = pd.read_csv('data/original/mitbih_test.csv', header=None)

# Processed splits. Every label file stores its labels in a column named '187'.
X_train = pd.read_csv('data/processed/mitbih/X_train.csv')
y_train = pd.read_csv('data/processed/mitbih/y_train.csv')['187']

# SMOTE-resampled training split (per the "_sm" suffix and the print in the
# verification cell).
X_train_sm = pd.read_csv('data/processed/mitbih/X_train_sm.csv')
y_train_sm = pd.read_csv('data/processed/mitbih/y_train_sm.csv')['187']

X_val = pd.read_csv('data/processed/mitbih/X_val.csv')
y_val = pd.read_csv('data/processed/mitbih/y_val.csv')['187']

# Test split: everything except column 187 is the signal, column 187 is the label.
X_test = df_mitbih_test.drop(columns=187)
y_test = df_mitbih_test[187]

# The 1D CNN expects a trailing channel axis: (samples, steps) -> (samples, steps, 1).
X_train_sm_cnn = np.asarray(X_train_sm)[:, :, np.newaxis]
X_val_cnn = np.asarray(X_val)[:, :, np.newaxis]
X_test_cnn = np.asarray(X_test)[:, :, np.newaxis]


In [None]:
# --- SHAP run configuration ---
SAMPLES_PER_CLASS = 40   # explained test samples per class (40 x 5 classes = 200 in total)
N_BACKGROUND = 100       # number of background samples handed to the SHAP explainer
RANDOM_SEED = 42         # fixed seed so the random sampling is repeatable

# Seed NumPy's global RNG; the np.random.choice calls in later cells draw from it.
np.random.seed(RANDOM_SEED)

# Load the best CNN checkpoint.
MODEL_PATH = 'models/MIT_02_03_dl_models/CNN/cnn8_sm_BS512_best.keras'
model = load_model(MODEL_PATH)


In [None]:
# Sanity check: print array shapes and class balance before running SHAP.
print("\nTraining set (SMOTE):")
print(f"  X_train_sm_cnn shape: {X_train_sm_cnn.shape}")
print(f"  y_train_sm shape: {y_train_sm.shape}")
print(f"  Unique classes: {np.unique(y_train_sm)}")
# np.bincount only accepts integer input; cast explicitly so the cell also
# works when the labels were stored as floats (same cast as used for y_explain).
print(f"  Class distribution: {np.bincount(y_train_sm.astype(int))}")

print("\nValidation set:")
print(f"  X_val_cnn shape: {X_val_cnn.shape}")
print(f"  y_val shape: {y_val.shape}")

print("\nTest set:")
print(f"  X_test_cnn shape: {X_test_cnn.shape}")
print(f"  y_test shape: {y_test.shape}")

print(f"\nModel input shape: {model.input_shape}")
print(f"Model output shape: {model.output_shape}")

# Expected output:
# X_train_sm_cnn shape: (289885, 187, 1)
# y_train_sm shape: (289885,)
# Unique classes: [0 1 2 3 4]
# Class distribution: [57977 57977 57977 57977 57977]

# X_val_cnn shape: (17511, 187, 1)
# y_val shape: (17511,)

# X_test_cnn shape: (21892, 187, 1)
# y_test shape: (21892,)

# Model input shape: (None, 187, 1)
# Model output shape: (None, 5)

In [None]:
# Model performance check on the full test set.
# The numbers should match the results reported for the best model in Rendering2.
y_pred = model.predict(X_test_cnn)
y_pred_classes = y_pred.argmax(axis=1)  # most probable class per sample

# Per-class precision / recall / F1
print("\nClassification Report:")
report = classification_report(y_test, y_pred_classes, digits=4)
print(report)

# Rows = true class, columns = predicted class
cm = confusion_matrix(y_test, y_pred_classes)
print("\nConfusion Matrix:")
print(cm)

In [None]:
# Data preparation for SHAP: stratified sampling.

# Background data: a random sample (without replacement) from the SMOTE training set.
background_indices = np.random.choice(len(X_train_sm_cnn), N_BACKGROUND, replace=False)
background_data = X_train_sm_cnn[background_indices]

print(f"\nBackground data shape: {background_data.shape}")
print(f"Background data range: [{background_data.min():.3f}, {background_data.max():.3f}]")

# Test samples to explain: draw the same number from every class so the
# explained set is balanced (fewer if a class has less than SAMPLES_PER_CLASS).
test_indices = []
for class_idx in range(5):
    # np.where returns *positions* into y_test, not index labels
    class_samples = np.where(y_test == class_idx)[0]
    n_samples = min(SAMPLES_PER_CLASS, len(class_samples))
    selected = np.random.choice(class_samples, n_samples, replace=False)
    test_indices.extend(selected)

test_indices = np.array(test_indices)
X_explain = X_test_cnn[test_indices]
# test_indices are positions, so select positionally with .iloc; plain
# y_test[...] is label-based and only coincides with positions while the
# Series carries the default 0..n-1 index.
y_explain = y_test.iloc[test_indices]

print(f"\nTest samples to explain: {len(test_indices)}")
print(f"X_explain shape: {X_explain.shape}")
print(f"Class distribution: {np.bincount(y_explain.astype(int))}")


# Expected output:
# Background data shape: (100, 187, 1)
# Background data range: [0.000, 1.000]

# Test samples to explain: 200
# X_explain shape: (200, 187, 1)
# Class distribution: [40 40 40 40 40]

In [None]:
# Initialize the SHAP explainer -> DeepExplainer for the deep-learning model.

explainer = shap.DeepExplainer(model, background_data)

# expected_value can come back as a framework tensor or as an ndarray;
# normalise it to a flat NumPy array so formatting and summing behave the
# same either way.
expected_values = np.asarray(explainer.expected_value).ravel()

print("\nExpected value (baseline) for each class:")
for i, ev in enumerate(expected_values):
    print(f"  Class {i}: {ev:.4f}")

baseline_sum = expected_values.sum()
print(f"\nBaseline sum: {baseline_sum:.4f}")

# Expected output: baseline sum should be 1
# NOTE(review): this presumes the model outputs class probabilities
# (e.g. a softmax layer) - confirm against the model definition.

In [None]:
# Calculation of SHAP values for the selected test samples.

shap_values_raw = explainer.shap_values(X_explain)

# The result is either a single 4D array (samples, steps, channels, classes)
# or something else, e.g. a list of per-class arrays (the else branch below
# anticipates this). np.shape works for both, whereas `.shape` would raise on
# a list before the type check is reached.
print(f"\nRaw SHAP values shape: {np.shape(shap_values_raw)}")

# Reshape SHAP values to a list with one (samples, steps, channels) array per class
if isinstance(shap_values_raw, np.ndarray) and shap_values_raw.ndim == 4:
    n_classes = shap_values_raw.shape[-1]
    shap_values = [shap_values_raw[..., class_idx] for class_idx in range(n_classes)]
else:
    shap_values = shap_values_raw

print("\nFinal SHAP structure:")
print(f"  Type: {type(shap_values)}")
print(f"  Length (classes): {len(shap_values)}")
print(f"  Shape per class: {shap_values[0].shape}")

# Summary statistics of the SHAP values, per class
for i, class_shap in enumerate(shap_values):
    print(f"\nClass {i} SHAP statistics:")
    print(f"  Min: {class_shap.min():.4f}")
    print(f"  Max: {class_shap.max():.4f}")
    print(f"  Mean: {class_shap.mean():.4f}")
    print(f"  Std: {class_shap.std():.4f}")


# Expected output:
# Raw SHAP values shape: (200, 187, 1, 5)
# Final SHAP structure:
# Type: <class 'list'>
# Length (classes): 5
# Shape per class: (200, 187, 1)

In [None]:
# Preparations for plotting: drop the channel axis so every sample is one
# flat row of time steps.

# Let NumPy infer the row length (-1) instead of hard-coding 187, so the cell
# keeps working if the signal length changes.
n_explained = len(X_explain)
X_explain_2d = X_explain.reshape(n_explained, -1)
shap_values_2d = [sv.reshape(n_explained, -1) for sv in shap_values]

print(f"X_explain_2d shape: {X_explain_2d.shape}")
print(f"shap_values_2d[0] shape: {shap_values_2d[0].shape}")


# Expected output:
# X_explain_2d shape: (200, 187)
# shap_values_2d[0] shape: (200, 187)

In [None]:
# Plot: top 20 most important features (time steps) for every class.
# Importance = mean absolute SHAP value of a time step over all explained samples.

# savefig does not create missing directories, so make sure the target exists.
output_dir = 'reports/interpretability/SHAP_MIT'
os.makedirs(output_dir, exist_ok=True)

# One importance vector per class
class_importance = [np.mean(np.abs(class_shap), axis=0) for class_shap in shap_values_2d]

# Create a separate figure for each class
for class_idx, importance in enumerate(class_importance):

    fig, ax = plt.subplots(1, 1, figsize=(10, 8))

    # Indices of the 20 largest importances, most important first
    top_20 = np.argsort(importance)[-20:][::-1]
    top_values = importance[top_20]

    # Horizontal bars, one per feature
    y_pos = np.arange(len(top_20))
    ax.barh(y_pos, top_values,
            color=f'C{class_idx}', alpha=0.7, edgecolor='black', linewidth=0.5)

    # Labels
    ax.set_yticks(y_pos)
    ax.set_yticklabels([f'Feature {f}' for f in top_20], fontsize=10)
    ax.set_xlabel('Mean |SHAP Value|', fontsize=12, fontweight='bold')
    ax.set_title(f'Class {class_idx} - Top 20 Most Important Features',
                fontsize=13, fontweight='bold')
    ax.invert_yaxis()  # most important feature at the top
    ax.grid(True, alpha=0.3, axis='x')

    # Write the numeric value at the end of each bar
    for bar_pos, val in enumerate(top_values):
        ax.text(val, bar_pos, f' {val:.4f}', va='center', fontsize=8)

    plt.tight_layout()
    filename = f'{output_dir}/feature_importance_class_{class_idx}.png'
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()

    # Print the top 10 for this class
    print(f"\nClass {class_idx} - Top 10 features:")
    for rank, feat in enumerate(top_20[:10], 1):
        print(f"  {rank}. Feature {feat}: {importance[feat]:.4f}")

In [None]:
# Feature effects: one SHAP summary plot per class (classes 0-4).
# A single loop replaces five copy-pasted, otherwise identical stanzas.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('reports/interpretability/SHAP_MIT', exist_ok=True)

for class_to_analyze in range(5):
    print(f"\nAnalyzing Class {class_to_analyze}")

    # NOTE(review): shap.summary_plot may resize the current figure itself
    # (its plot_size argument) - confirm whether this figsize takes effect.
    plt.figure(figsize=(10, 8))
    # show=False keeps the figure open so it can be titled and saved below
    shap.summary_plot(shap_values_2d[class_to_analyze], X_explain_2d, show=False)
    plt.title(f'SHAP Summary Plot - Class {class_to_analyze}')
    plt.tight_layout()
    plt.savefig(f'reports/interpretability/SHAP_MIT/shap_summary_class_{class_to_analyze}.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Sample ECG signal over its SHAP values, for up to 3 correctly classified
# examples per class.

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('reports/interpretability/SHAP_MIT', exist_ok=True)

# Predict once for all explained samples and reuse the probabilities below,
# instead of calling model.predict again for every plotted sample.
explain_probs = model.predict(X_explain, verbose=0)
predicted_classes = np.argmax(explain_probs, axis=1)

# Plain array so the boolean mask below is purely positional
y_explain_arr = np.asarray(y_explain)

for target_class in range(5):
    # Positions (within the explained set) of samples of this class that
    # were predicted correctly
    correct_idx = np.where((y_explain_arr == target_class) &
                           (predicted_classes == target_class))[0]

    # Take up to 3 samples; nothing is plotted if there are none
    n_samples = min(3, len(correct_idx))

    for i, sample_idx in enumerate(correct_idx[:n_samples]):

        # Report the real number of examples rather than a fixed "/3"
        print(f"Analyzing Sample {sample_idx} (True Class: {target_class}, Example {i+1}/{n_samples})")

        # Class probabilities for this sample
        pred = explain_probs[sample_idx]
        print(f"Predicted class: {np.argmax(pred)}")
        print(f"Confidence: {pred[target_class]:.4f}")

        # Two stacked panels sharing the time axis: signal on top, SHAP below
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 8), sharex=True)

        # ECG signal
        ecg_signal = X_explain_2d[sample_idx]
        ax1.plot(ecg_signal, 'b-', linewidth=1)
        ax1.set_ylabel('Normalized ECG Signal', fontsize=12)
        ax1.set_title(f'Sample {sample_idx} - True Class: {target_class}, Predicted: {np.argmax(pred)}',
                     fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3)

        # SHAP values for the target class (equal to the predicted class
        # here, since only correctly classified samples are shown);
        # red = positive contribution, blue = zero or negative
        shap_vals = shap_values_2d[target_class][sample_idx]
        colors = ['red' if x > 0 else 'blue' for x in shap_vals]
        ax2.bar(range(len(shap_vals)), shap_vals, color=colors, alpha=0.6, width=1.0)
        ax2.set_xlabel('Time Step (Feature Index)', fontsize=12)
        ax2.set_ylabel('SHAP Value', fontsize=12)
        ax2.set_title(f'Feature Importance - Class {target_class}', fontsize=12)
        ax2.axhline(y=0, color='black', linestyle='-', linewidth=0.8)
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'reports/interpretability/SHAP_MIT/ecg_shap_sample_{sample_idx}_class_{target_class}_example_{i+1}.png',
                   dpi=150, bbox_inches='tight')
        plt.show()

        # Time steps with the largest absolute SHAP value, largest first
        top_n = 10
        top_features = np.argsort(np.abs(shap_vals))[-top_n:][::-1]
        print(f"\nTop {top_n} most important time steps:")
        for rank, feat_idx in enumerate(top_features, 1):
            print(f"  {rank}. Feature {feat_idx}: SHAP={shap_vals[feat_idx]:.4f}, "
                  f"Value={ecg_signal[feat_idx]:.4f}")

In [None]:
# Misclassification analysis: identify misclassified test samples and rank
# the (true class -> predicted class) confusions by frequency.

# Predictions on the full test set
y_pred_probs = model.predict(X_test_cnn, verbose=0)
y_pred = y_pred_probs.argmax(axis=1)

# Positions of wrongly / correctly classified samples
misclassified_idx = np.where(y_test != y_pred)[0]
correct_idx = np.where(y_test == y_pred)[0]

n_total = len(y_test)
n_correct = len(correct_idx)
n_wrong = len(misclassified_idx)
print(f"\nTotal test samples: {n_total}")
print(f"Correctly classified: {n_correct} ({100*n_correct/n_total:.2f}%)")
print(f"Misclassified: {n_wrong} ({100*n_wrong/n_total:.2f}%)")

# Confusion matrix (confusion_matrix is already imported at the top of the notebook)
cm = confusion_matrix(y_test, y_pred)
print("\nConfusion Matrix:")
print(cm)

# Every off-diagonal cell with at least one case is a misclassification pattern
print("\nMost Common Misclassification Patterns:")
misclassification_patterns = [
    {'true': row, 'pred': col, 'count': cm[row, col]}
    for row in range(5)
    for col in range(5)
    if row != col and cm[row, col] > 0
]

# Most frequent pattern first
misclassification_patterns.sort(key=lambda entry: entry['count'], reverse=True)

print("\nTop 10 confusion patterns:")
for rank, entry in enumerate(misclassification_patterns[:10], 1):
    print(f"  {rank}. True Class {entry['true']} → Predicted Class {entry['pred']}: "
          f"{entry['count']} cases")


In [None]:
# Sample misclassified cases (plus correctly classified ones for comparison)
# for SHAP analysis.

# misclassified_idx / correct_idx hold positions, so compare against a plain
# array; this also avoids masking an ndarray with a pandas Series.
y_test_arr = np.asarray(y_test)

# Select representative misclassified samples
samples_to_analyze = []
samples_info = []

for pattern in misclassification_patterns[:5]:  # Top 5 patterns
    true_class = pattern['true']
    pred_class = pattern['pred']

    # Misclassified samples matching this (true, predicted) pattern
    matches_pattern = ((y_test_arr[misclassified_idx] == true_class) &
                       (y_pred[misclassified_idx] == pred_class))
    pattern_samples = misclassified_idx[matches_pattern]

    # Select up to 3 samples from this pattern
    n_samples = min(3, len(pattern_samples))
    selected = np.random.choice(pattern_samples, n_samples, replace=False)

    samples_to_analyze.extend(selected)
    samples_info.extend({
        'idx': idx,
        'true': true_class,
        'pred': pred_class,
        'prob': y_pred_probs[idx]  # full probability vector of the sample
    } for idx in selected)

print(f"\nSelected {len(samples_to_analyze)} misclassified samples")
print(f"Patterns covered: {len({(s['true'], s['pred']) for s in samples_info})}")

# Also select correctly classified samples for comparison (up to 3 per class)
correct_samples_per_class = []
for class_idx in range(5):
    class_correct = correct_idx[y_test_arr[correct_idx] == class_idx]
    if len(class_correct) > 0:
        selected = np.random.choice(class_correct, min(3, len(class_correct)), replace=False)
        correct_samples_per_class.extend(selected)

print(f"Selected {len(correct_samples_per_class)} correctly classified samples for comparison")

# Combine for SHAP analysis: misclassified samples first, then the correct ones.
# dtype=int keeps the array usable as an index even when nothing was selected.
all_samples = np.array(samples_to_analyze + correct_samples_per_class, dtype=int)
X_analyze = X_test_cnn[all_samples]
# Positional selection, since all_samples holds positions rather than labels
y_analyze = y_test.iloc[all_samples]

print(f"\nTotal samples for SHAP analysis: {len(all_samples)}")

In [None]:
# Calculate SHAP values for the misclassified (and comparison) samples.

shap_values_misc_raw = explainer.shap_values(X_analyze)

# The result is either a single 4D array (samples, steps, channels, classes)
# or something else, e.g. a list of per-class arrays (the else branch below
# anticipates this). np.shape works for both, whereas `.shape` would raise on
# a list before the type check is reached.
print(f"\nRaw SHAP shape: {np.shape(shap_values_misc_raw)}")

# Reshape SHAP values to a list with one array per class
if isinstance(shap_values_misc_raw, np.ndarray) and shap_values_misc_raw.ndim == 4:
    n_classes = shap_values_misc_raw.shape[-1]
    shap_values_misc = [shap_values_misc_raw[..., class_idx] for class_idx in range(n_classes)]
else:
    shap_values_misc = shap_values_misc_raw

# Drop the channel axis so every sample is one flat row of time steps;
# -1 lets NumPy infer the row length instead of hard-coding 187.
n_analyzed = len(all_samples)
X_analyze_2d = X_analyze.reshape(n_analyzed, -1)
shap_values_misc_2d = [sv.reshape(n_analyzed, -1) for sv in shap_values_misc]

print(f"Final SHAP structure: {len(shap_values_misc_2d)} classes")
print(f"Shape per class: {shap_values_misc_2d[0].shape}")

In [None]:
# Visualize misclassified cases: for each sample show the ECG signal, the SHAP
# values for the predicted and for the true class, the predicted
# probabilities, and the top contributing time steps.


def _plot_shap_bars(ax, shap_row, title):
    """Draw one sample's per-time-step SHAP values on *ax* as a bar chart.

    Positive values are drawn red, zero or negative values blue. The x-axis
    spans exactly the time steps present in *shap_row*.
    """
    n_steps = len(shap_row)
    bar_colors = ['red' if value > 0 else 'blue' for value in shap_row]
    ax.bar(range(n_steps), shap_row, color=bar_colors, alpha=0.6, width=1.0)
    ax.set_ylabel('SHAP Value', fontsize=11)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.8)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, n_steps - 1)


# savefig does not create missing directories, so make sure the target exists.
os.makedirs('reports/interpretability/SHAP_MIT', exist_ok=True)

# Analyze each misclassified sample
for i, info in enumerate(samples_info[:10]):  # Top 10
    sample_idx = info['idx']
    true_class = info['true']
    pred_class = info['pred']
    probs = info['prob']

    # Row of this test sample within the analyzed subset
    analyze_idx = np.where(all_samples == sample_idx)[0][0]

    print(f"Misclassified Sample {i+1}: Index {sample_idx}")
    print(f"  True Class: {true_class}")
    print(f"  Predicted Class: {pred_class}")
    print(f"  Prediction confidence: {probs[pred_class]:.4f}")
    print(f"  True class probability: {probs[true_class]:.4f}")

    # Figure layout: three full-width rows plus a split bottom row
    fig = plt.figure(figsize=(16, 12))
    gs = fig.add_gridspec(4, 2, hspace=0.3, wspace=0.3)

    # 1. ECG signal
    ax1 = fig.add_subplot(gs[0, :])
    ecg = X_analyze_2d[analyze_idx]
    ax1.plot(ecg, 'b-', linewidth=1.2)
    ax1.set_ylabel('ECG Signal', fontsize=11)
    ax1.set_title(f'Sample {sample_idx}: True={true_class}, Predicted={pred_class} '
                  f'(Confidence={probs[pred_class]:.3f})',
                  fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_xlim(0, len(ecg) - 1)

    # 2. SHAP for the PREDICTED class
    ax2 = fig.add_subplot(gs[1, :])
    shap_pred = shap_values_misc_2d[pred_class][analyze_idx]
    _plot_shap_bars(ax2, shap_pred, f'Why Model Predicted Class {pred_class}')

    # 3. SHAP for the TRUE class
    ax3 = fig.add_subplot(gs[2, :])
    shap_true = shap_values_misc_2d[true_class][analyze_idx]
    _plot_shap_bars(ax3, shap_true, f'Evidence for True Class {true_class} (Missed)')
    ax3.set_xlabel('Time Step (Feature Index)', fontsize=11)

    # 4. Prediction probabilities: true class green, predicted class red
    ax4 = fig.add_subplot(gs[3, 0])
    colors_bar = ['green' if idx == true_class else ('red' if idx == pred_class else 'gray')
                  for idx in range(5)]
    ax4.bar(range(5), probs, color=colors_bar, alpha=0.7)
    ax4.set_xlabel('Class', fontsize=11)
    ax4.set_ylabel('Probability', fontsize=11)
    ax4.set_title('Prediction Probabilities', fontsize=11, fontweight='bold')
    ax4.set_xticks(range(5))
    ax4.grid(True, alpha=0.3, axis='y')

    # 5. Key features comparison (time steps ranked by absolute SHAP value)
    ax5 = fig.add_subplot(gs[3, 1])
    top_n = 10
    top_pred = np.argsort(np.abs(shap_pred))[-top_n:][::-1]
    top_true = np.argsort(np.abs(shap_true))[-top_n:][::-1]

    text_lines = ["Top Features:\n\n", f"Pred Class {pred_class}:\n"]
    text_lines += [f"  {rank}. F{feat}: {shap_pred[feat]:+.3f}\n"
                   for rank, feat in enumerate(top_pred[:5], 1)]
    text_lines.append(f"\nTrue Class {true_class}:\n")
    text_lines += [f"  {rank}. F{feat}: {shap_true[feat]:+.3f}\n"
                   for rank, feat in enumerate(top_true[:5], 1)]
    comparison_text = "".join(text_lines)

    ax5.text(0.1, 0.5, comparison_text, fontsize=9, family='monospace',
             verticalalignment='center', transform=ax5.transAxes)
    ax5.axis('off')  # text-only panel

    plt.savefig(f'reports/interpretability/SHAP_MIT/misclassified_sample_{i+1}_idx_{sample_idx}.png',
                dpi=150, bbox_inches='tight')
    plt.show()

    # Print the top features for both classes
    print("\nTop 3 features:")
    print(f"  Predicted Class {pred_class}:")
    for feat in top_pred[:3]:
        print(f"    Feature {feat}: SHAP={shap_pred[feat]:+.4f}, Value={ecg[feat]:.4f}")
    print(f"  True Class {true_class}:")
    for feat in top_true[:3]:
        print(f"    Feature {feat}: SHAP={shap_true[feat]:+.4f}, Value={ecg[feat]:.4f}")