# SHAP plots for Telecom Churn

This notebook creates three modern, well-labeled SHAP visualizations: a summary (beeswarm-like), a feature importance bar chart, and an interactive dependence plot.

The cells include a small fallback: if you don't already have a trained `model` and a feature matrix `X` in the session, the notebook trains a quick RandomForest on `data/churndata.csv` (or on a synthetic classification dataset when that file is missing) to demonstrate the plots.

Requirements (install once):
```
pip install shap scikit-learn pandas matplotlib seaborn plotly
```
: 
,
: {
: 

: [
,
,
,
,
,
,
,
,
,
,
,
,
,
whitegrid", context="notebook")
# Matplotlib defaults applied through rcParams: figure size plus the
# font sizes used for axes titles and axis labels.
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'axes.titlesize': 14,
    'axes.labelsize': 12
})

In [None]:
# Helper: load dataset or create demo model if needed
def prepare_demo_model(data_path='data/churndata.csv', target_col=None, sample_frac=0.5, random_state=42):
    """Load (or synthesize) a dataset and fit a quick RandomForest for the demo plots.

    Args:
        data_path: CSV file to read. When the path does not exist, a synthetic
            12-feature classification dataset with a ``target`` column is
            generated instead.
        target_col: Name of the label column. When ``None`` it is guessed: a
            column named exactly ``churn`` (case-insensitive) wins, then the
            first column whose name contains ``churn``, then the last column.
        sample_frac: Forwarded to ``train_test_split`` as ``train_size`` to
            subsample rows for speed. ``None`` or ``1`` keeps every row.
        random_state: Seed used for data synthesis, subsampling and the forest.

    Returns:
        ``(model, X, y)``: the fitted ``RandomForestClassifier`` together with
        the (possibly subsampled) features and labels it was trained on.

    Raises:
        ValueError: If no rows remain after dropping rows with missing values.
    """
    # Load CSV if available
    if os.path.exists(data_path):
        df = pd.read_csv(data_path)
    else:
        # Synthetic fallback if dataset not present
        from sklearn.datasets import make_classification
        Xs, ys = make_classification(n_samples=2000, n_features=12, n_informative=6, random_state=random_state)
        df = pd.DataFrame(Xs, columns=[f'f{i}' for i in range(Xs.shape[1])])
        df['target'] = ys

    # Guess target column if not provided. An exact "churn" name beats a
    # substring hit so that e.g. "churn_score" cannot shadow "Churn".
    if target_col is None:
        lowered = [str(c).lower() for c in df.columns]
        exact = [c for c, low in zip(df.columns, lowered) if low == 'churn']
        partial = [c for c, low in zip(df.columns, lowered) if 'churn' in low]
        target_col = (exact or partial or [df.columns[-1]])[0]

    # Basic preprocessing: drop NAs, encode categoricals simply
    df = df.dropna().copy()
    if df.empty:
        raise ValueError("No rows left after dropping missing values; cannot train a demo model.")
    X = df.drop(columns=[target_col])
    y = df[target_col]
    for col in X.select_dtypes(include=['object', 'category']).columns:
        X[col] = LabelEncoder().fit_transform(X[col].astype(str))

    # Subsample for speed unless the caller asked for the full dataset.
    if sample_frac is not None and sample_frac != 1:
        # Stratified splitting needs at least two rows per class; fall back to
        # a plain random split when a class is too rare.
        stratify = y if y.value_counts().min() >= 2 else None
        X, _, y, _ = train_test_split(
            X, y, train_size=sample_frac, stratify=stratify, random_state=random_state
        )

    # Train a quick tree model
    model = RandomForestClassifier(n_estimators=100, random_state=random_state, n_jobs=-1)
    model.fit(X, y)
    return model, X, y

# Reuse `model` and `X` when they are already defined in the session;
# otherwise fall back to training the demo model.
try:
    # Bare name lookups: raises NameError as soon as either one is undefined.
    model, X  # noqa: F821
except NameError:
    print("No existing model/X found — preparing a demo model. This may take ~30s.")
    model, X, y = prepare_demo_model()
else:
    print("Using existing `model` and `X` from the environment.")

In [None]:
# Helper: compute shap values with robust explainer selection
def compute_shap_values(model, X, explainer_type_preference=('Tree', 'Kernel', 'Linear'), y=None):
    """Compute SHAP values for ``model`` on ``X``, trying explainers in order.

    Args:
        model: Fitted estimator to explain.
        X: Feature matrix to explain (a DataFrame is expected by the plotting
            helpers that consume the result).
        explainer_type_preference: Explainer families to try, in order.
            Recognised names (case-insensitive, optional ``Explainer``
            suffix): ``Tree``, ``Kernel``, ``Linear``. Unknown names are
            reported and skipped.
        y: Optional true labels, used only by the permutation-importance
            fallback. When omitted, the model's own predictions on ``X`` are
            used as the reference, so the importances measure how strongly
            each feature drives the model output.

    Returns:
        ``(explainer, shap_values)`` from the first explainer that succeeds,
        or ``(None, pandas.Series)`` holding mean permutation importances
        when every explainer fails.
    """

    def _tree():
        explainer = shap.TreeExplainer(model)
        return explainer, explainer.shap_values(X)

    def _kernel():
        # KernelExplainer is slow, so summarise the background with a small sample.
        background = shap.sample(X, nsamples=min(50, X.shape[0]))
        predict_fn = model.predict_proba if hasattr(model, 'predict_proba') else model.predict
        explainer = shap.KernelExplainer(predict_fn, background)
        return explainer, explainer.shap_values(X, nsamples=100)

    def _linear():
        explainer = shap.LinearExplainer(model, X)
        return explainer, explainer.shap_values(X)

    builders = {'Tree': _tree, 'Kernel': _kernel, 'Linear': _linear}

    for requested in explainer_type_preference or ():
        family = str(requested).strip()
        if family.lower().endswith('explainer'):
            family = family[:-len('explainer')]
        family = family.capitalize()
        build = builders.get(family)
        if build is None:
            print("Unknown explainer type %r - skipping" % (requested,))
            continue
        if not hasattr(shap, family + 'Explainer'):
            continue
        try:
            return build()
        except Exception as e:
            # Deliberately broad: any failure just moves on to the next candidate.
            print("%sExplainer failed: %s" % (family, e))

    # Last resort: approximate with permutation importance to show bar chart.
    # permutation_importance needs target values as its third argument.
    from sklearn.inspection import permutation_importance
    reference = model.predict(X) if y is None else y
    r = permutation_importance(model, X, reference, n_repeats=10, random_state=0, n_jobs=-1)
    perm_importance = pd.Series(r.importances_mean, index=getattr(X, 'columns', None))
    return None, perm_importance

# Compute SHAP values once; every plot below reuses these two names.
explainer, shap_values = compute_shap_values(model, X)
print(
    "Computed SHAP values (explainer: %s)"
    % ('None - permutation importance' if explainer is None else type(explainer).__name__)
)

## Plot 1 — SHAP summary (beeswarm-style)

A compact, color-coded summary showing feature impact, ordered by importance. For large datasets we sample for the plot to keep it readable.

In [None]:
# SHAP summary plot (matplotlib + seaborn style)
def plot_shap_summary(explainer, shap_values, X, max_display=20, sample_size=1000, title='SHAP summary plot', random_state=None):
    """Draw a beeswarm-style SHAP summary, or a bar chart for the fallback case.

    Args:
        explainer: Unused here; accepted so all plot helpers share a signature.
        shap_values: One of
            * a ``pandas.Series`` of permutation importances (bar chart),
            * a list of ``(n_samples, n_features)`` arrays, one per class,
            * a 2-D array ``(n_samples, n_features)``,
            * a 3-D array ``(n_samples, n_features, n_classes)``.
            For multi-output input the class with the largest mean |SHAP| is
            plotted; on a tie (typical for binary models) the last class wins,
            so the positive class is explained.
        X: DataFrame whose rows line up positionally with the SHAP rows.
        max_display: Maximum number of features to show.
        sample_size: Rows are randomly subsampled down to this many for the plot.
        title: Figure title.
        random_state: Optional seed for the row subsample. ``None`` uses
            NumPy's global random state.

    Returns:
        None. The figure is drawn on the current matplotlib state.
    """
    # If shap_values is a permutation importance Series, draw bar plot instead
    if isinstance(shap_values, pd.Series):
        vals = shap_values.sort_values(ascending=False).head(max_display)
        plt.figure(figsize=(10,6))
        sns.barplot(x=vals.values, y=vals.index, palette='viridis')
        plt.title(title + ' (permutation importance)')
        plt.xlabel('Importance (mean decrease)')
        plt.tight_layout()
        return

    # Reduce multi-class output to a single (n_samples, n_features) matrix.
    if isinstance(shap_values, list):
        mags = np.array([np.abs(s).mean() for s in shap_values])
        idx = int(np.flatnonzero(np.isclose(mags, mags.max()))[-1])
        vals = np.asarray(shap_values[idx])
    else:
        vals = np.asarray(shap_values)
        if vals.ndim == 3:
            mags = np.abs(vals).mean(axis=(0, 1))
            idx = int(np.flatnonzero(np.isclose(mags, mags.max()))[-1])
            vals = vals[:, :, idx]

    # Sample rows for plotting clarity. Positions (not index labels) are drawn
    # because the SHAP matrix is a plain array aligned with X by row order.
    n_rows = X.shape[0]
    if n_rows > sample_size:
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        positions = rng.choice(n_rows, size=sample_size, replace=False)
        Xs = X.iloc[positions]
        vals_sample = vals[positions]
    else:
        Xs = X
        vals_sample = vals

    # Use SHAP's plotting if available for a clean beeswarm-style output but style it
    plt.figure(figsize=(10,7))
    try:
        shap.summary_plot(vals_sample, Xs, show=False, max_display=max_display)
        plt.title(title)
    except Exception as e:
        print("shap.summary_plot failed, falling back to bar chart: %s" % e)
        # fallback: mean(|SHAP|) bar plot
        means = np.abs(vals).mean(axis=0)
        order = np.argsort(means)[::-1][:max_display]
        feats = X.columns[order]
        sns.barplot(x=means[order], y=feats, palette='viridis')
        plt.title(title + ' (mean |SHAP|)')
        plt.xlabel('mean |SHAP value|')
    plt.tight_layout()

# Draw the summary plot, showing at most 15 features.
plot_shap_summary(explainer, shap_values, X, max_display=15)

## Plot 2 — SHAP feature importance (horizontal bar)

A clear horizontal bar chart using average absolute SHAP values to show global importance. When only permutation importances are available (the fallback path), they are shown with the same chart.

In [None]:
def plot_shap_importance(explainer, shap_values, X, top_n=20, title='SHAP feature importance'):
    """Draw a horizontal bar chart of global feature importance.

    Args:
        explainer: Unused here; accepted so all plot helpers share a signature.
        shap_values: One of
            * a ``pandas.Series`` of permutation importances,
            * a list of ``(n_samples, n_features)`` arrays, one per class,
            * a 2-D array ``(n_samples, n_features)``,
            * a 3-D array ``(n_samples, n_features, n_classes)``,
            * a 1-D array of per-feature values.
            SHAP input is reduced to mean |SHAP| per feature, averaged over
            samples and (when present) classes.
        X: DataFrame whose columns name the features.
        top_n: Number of highest-ranked features to show.
        title: Figure title.

    Returns:
        None. The figure is drawn on the current matplotlib state.
    """
    if isinstance(shap_values, pd.Series):
        vals = shap_values.sort_values(ascending=False).head(top_n)
        plt.figure(figsize=(10,6))
        sns.barplot(x=vals.values, y=vals.index, palette='plasma')
        plt.title(title + ' (permutation)')
        plt.xlabel('Importance')
        plt.tight_layout()
        return

    if isinstance(shap_values, list):
        # Average |SHAP| over samples within each class first, then over
        # classes, so the result is one value per feature.
        per_class = [np.atleast_2d(np.abs(np.asarray(s))).mean(axis=0) for s in shap_values]
        arr = np.mean(per_class, axis=0)
    else:
        arr = np.abs(np.asarray(shap_values))
        if arr.ndim == 3:
            # (n_samples, n_features, n_classes): collapse samples and classes.
            arr = arr.mean(axis=(0, 2))
        elif arr.ndim == 2:
            arr = arr.mean(axis=0)

    means = pd.Series(arr, index=X.columns).sort_values(ascending=False).head(top_n)
    # Grow the figure with the number of bars so labels stay legible.
    plt.figure(figsize=(10, max(4, 0.35*len(means))))
    sns.barplot(x=means.values, y=means.index, palette='plasma')
    plt.title(title)
    plt.xlabel('mean |SHAP value|')
    plt.tight_layout()

# Draw the importance bar chart for the 15 highest-ranked features.
plot_shap_importance(explainer, shap_values, X, top_n=15)

## Plot 3 — SHAP dependence plot (interactive with Plotly)

Shows how a single feature affects predictions — we include color-coding by another feature and an interactive scatter for exploration.

In [None]:
def plot_shap_dependence_interactive(explainer, shap_values, X, feature=None, color_feature=None, title='SHAP dependence (interactive)'):
    """Build an interactive Plotly scatter of one feature's value vs. its SHAP value.

    Args:
        explainer: Unused here; accepted so all plot helpers share a signature.
        shap_values: A list of per-class ``(n_samples, n_features)`` arrays, a
            2-D array, or a 3-D ``(n_samples, n_features, n_classes)`` array.
            For multi-output input the class with the largest mean |SHAP| is
            used; on a tie (typical for binary models) the last class wins.
            A ``pandas.Series`` (permutation importances) cannot be plotted.
        X: DataFrame whose rows and columns line up with the SHAP matrix.
        feature: Column plotted on the x-axis. Defaults to the feature with
            the largest mean |SHAP|.
        color_feature: Column used for point colour. Defaults to the
            highest-ranked feature other than ``feature``.
        title: Prefix of the figure title.

    Returns:
        The Plotly figure, or ``None`` when only permutation importances are
        available.
    """
    # If we only have permutation importance, fall back to static message
    if isinstance(shap_values, pd.Series):
        print('No shap values available for dependence plot (permutation importance only).')
        return None

    # Reduce multi-class output to a single (n_samples, n_features) matrix.
    if isinstance(shap_values, list):
        mags = np.array([np.abs(s).mean() for s in shap_values])
        idx = int(np.flatnonzero(np.isclose(mags, mags.max()))[-1])
        vals = np.asarray(shap_values[idx])
    else:
        vals = np.asarray(shap_values)
        if vals.ndim == 3:
            mags = np.abs(vals).mean(axis=(0, 1))
            idx = int(np.flatnonzero(np.isclose(mags, mags.max()))[-1])
            vals = vals[:, :, idx]

    # Rank features by mean |SHAP|, most important first.
    means = np.abs(vals).mean(axis=0)
    ranked = [X.columns[int(i)] for i in np.argsort(means)[::-1]]
    if feature is None:
        feature = ranked[0]
    if color_feature is None:
        # Colour by the best-ranked *other* feature so the colour axis adds
        # information even when the caller picked `feature` explicitly.
        others = [c for c in ranked if c != feature]
        color_feature = others[0] if others else feature

    # build a DataFrame for plotting
    shap_df = pd.DataFrame(vals, columns=X.columns, index=X.index)
    plot_df = pd.DataFrame({
        'feature_value': X[feature],
        'shap_value': shap_df[feature],
        'color_by': X[color_feature]
    })
    # Create interactive scatter colored by color_feature
    fig = px.scatter(plot_df, x='feature_value', y='shap_value', color='color_by',
                     color_continuous_scale='Turbo',
                     title=f'{title}: {feature} (colored by {color_feature})',
                     labels={'feature_value': f'{feature} value', 'shap_value': 'SHAP value', 'color_by': color_feature})
    fig.update_traces(marker=dict(size=6, opacity=0.75))
    fig.update_layout(height=500)
    return fig

fig = plot_shap_dependence_interactive(explainer, shap_values, X)
# The helper returns None when only permutation importances are available,
# so only call fig.show() when a figure was actually built.
if fig is not None:
    fig.show()

## Notes and next steps

- To improve visuals further, you can export the Plotly figure to HTML with `fig.write_html('dependence.html')`.
- For very large datasets, compute SHAP on a sample or use model-approximation explainers.
- If you have a trained model elsewhere in your repo (e.g., in `src/model_training.py`), import and pass it into the notebook environment to reuse the heavy work.