# Sample code to precompute `shap_summary_dfs` and `shap_sample_dfs`

The following sample code shows how to precompute `shap_summary_dfs` and `shap_sample_dfs` for each of these cases:
- Regression
- Binary classification
- Multiclass classification

Finally, save `shap_summary_dfs` and `shap_sample_dfs` as pickle files.

In [None]:
from toolkit_spark import get_explainer, compute_shap
from constants import FEATURES

In [None]:
# Helpers
def get_corrs(df, features):
    """Compute correlation between prediction and features for each class (if any).

    The module-level ``udf_predict_score`` UDF is applied to an array built
    from ``features`` to obtain a ``prediction`` column, which is then
    correlated with each feature column via ``DataFrame.corr``.

    Args:
        df: Spark DataFrame containing every column named in ``features``.
        features: Names of the feature columns, in the order the model expects.

    Returns:
        A list with one inner list per model output. When the prediction is a
        list (e.g. per-class scores) there is one inner list per element of
        that list; otherwise there is a single inner list. Each inner list
        holds one correlation per entry of ``features``, in the same order.
    """
    scored = (
        df
        .withColumn("features", F.array(*features))
        .withColumn("prediction", udf_predict_score("features"))
        .drop("features")
    )
    # Cache and materialize the predictions once: every correlation below is
    # a separate Spark job that would otherwise re-run the scoring UDF.
    scored.cache()
    try:
        scored.count()

        # Inspect one prediction to tell scalar outputs from per-class lists.
        sample_pred = scored.select("prediction").take(1)[0][0]
        if isinstance(sample_pred, list):
            num_classes = len(sample_pred)
            all_corrs = [
                [
                    scored
                    .select(c, F.col("prediction").getItem(i).alias("pred"))
                    .corr(c, "pred")
                    for c in features
                ]
                for i in range(num_classes)
            ]
        else:
            all_corrs = [[scored.corr(c, "prediction") for c in features]]
    finally:
        # The results are plain Python floats, so the cached data is no
        # longer needed; release it instead of leaking executor storage.
        scored.unpersist()
    return all_corrs


def get_mashap(x_df):
    """Compute mean(|SHAP|) for each feature and for each class (if any).

    ``get_shap_partition`` is mapped over the partitions of ``x_df``. The
    resulting DataFrame has the raw feature vector in its first column and
    one SHAP-vector column per model output after it.

    Args:
        x_df: Spark DataFrame holding only the feature columns.

    Returns:
        A tuple ``(all_mashap, shap_df)``. ``all_mashap`` has one inner list
        per SHAP column, holding the mean absolute SHAP value of each feature
        in feature order. ``shap_df`` is the cached DataFrame of features and
        SHAP vectors; it is left cached because callers reuse it.
    """
    shap_df = x_df.rdd.mapPartitions(get_shap_partition).toDF()
    shap_df.cache()
    shap_df.count()  # materialize the cache before running the aggregations

    # The first column is the feature vector; its length is the number of
    # entries to average in every SHAP vector.
    num_features = len(shap_df.first()[0])

    all_mashap = []
    for shap_col in shap_df.columns[1:]:
        # `getItem` on an array column is zero-based, so the valid positions
        # are 0 .. num_features - 1. One aggregation per SHAP column replaces
        # a separate Spark job for every (output, feature) pair.
        mean_abs = [
            F.mean(F.abs(F.col(shap_col).getItem(j)))
            for j in range(num_features)
        ]
        all_mashap.append(list(shap_df.agg(*mean_abs).first()))
    return all_mashap, shap_df


def get_shap_summary_dfs(all_corrs, all_mashap, features):
    """Build one summary DataFrame per model output.

    Args:
        all_corrs: One list of prediction/feature correlations per output.
        all_mashap: One list of mean absolute SHAP values per output.
        features: Feature names shared by every output.

    Returns:
        A list of pandas DataFrames with the columns ``feature``,
        ``mas_value`` and ``corrcoef``, one DataFrame per pair taken in
        lockstep from ``all_mashap`` and ``all_corrs``.
    """
    summaries = []
    for mean_abs_shap, feature_corrs in zip(all_mashap, all_corrs):
        columns = {
            "feature": features,
            "mas_value": mean_abs_shap,
            "corrcoef": feature_corrs,
        }
        summaries.append(pd.DataFrame(columns))
    return summaries


def sample_shap_df(shap_df, features, num_rows=3000, seed=42):
    """Collect a sample of ``shap_df`` into pandas/NumPy structures.

    Args:
        shap_df: DataFrame whose first column is the feature vector and whose
            remaining columns are SHAP vectors.
        features: Column names for the returned feature DataFrame.
        num_rows: Target sample size. Every row is collected when the
            DataFrame has at most this many rows; otherwise rows are sampled
            without replacement at the fraction ``num_rows / count``, so the
            collected size is only approximately ``num_rows``.
        seed: Seed passed to ``shap_df.sample``.

    Returns:
        A list whose first element is a pandas DataFrame of the sampled
        features, followed by one NumPy array of SHAP values per remaining
        column. If no rows are collected, the DataFrame is empty and each
        array has shape ``(0, len(features))``.
    """
    total = shap_df.count()
    if num_rows >= total:
        collected = shap_df.collect()
    else:
        collected = shap_df.sample(False, num_rows / total, seed=seed).collect()

    if not collected:
        # An empty input or an unlucky random sample leaves no first row to
        # inspect; take the column count from the schema instead of failing.
        num_shap_cols = max(len(shap_df.columns) - 1, 0)
        empty_shap = [np.empty((0, len(features))) for _ in range(num_shap_cols)]
        return [pd.DataFrame(columns=features)] + empty_shap

    # Transpose row-major records into one sequence per column.
    feature_col, *shap_cols = zip(*collected)
    sample_dfs = [pd.DataFrame([list(v) for v in feature_col], columns=features)]
    sample_dfs.extend(np.array([list(v) for v in col]) for col in shap_cols)
    return sample_dfs


## Regression example

In [None]:
@F.udf(returnType=FloatType())
def udf_predict_score(row):
    """Return the model's prediction for one feature vector as a scalar."""
    #################################################
    # Amend accordingly
    prediction = model_bc.value.predict([row])
    #################################################
    # `.item()` turns the one-element result into a plain Python scalar.
    return prediction.item()


def get_shap_partition(rows):
    """
    Compute Shapley values in the shape of (prediction size x num_rows x len(features)).
    Also outputs the raw features, but no base_value

    Args:
        rows: Iterable of feature rows belonging to one partition.

    Returns:
        A list with one ``Row`` per input row: the feature values first,
        followed by one list of SHAP values per entry of ``base_value``.
        An empty partition yields an empty list.
    """
    rows = np.asarray(list(rows))
    if rows.size == 0:
        # Nothing to explain; skip building the explainer altogether.
        return []
    # Force a 2-D (num_rows x num_features) layout. Unlike `np.squeeze`, this
    # keeps the row axis when the partition holds a single row or the data
    # has a single feature.
    rows = rows.reshape(len(rows), -1)

    #################################################
    # Amend accordingly
    model = model_bc.value
    expl = get_explainer(...)
    #################################################

    shap_values, base_value = compute_shap(expl, rows)

    # NOTE(review): assumes `shap_values[j][i]` is the SHAP vector of row `i`
    # for output `j` and that `base_value` has one entry per output, as the
    # shape in the docstring implies - confirm against `compute_shap`.
    results = []
    for i, feature_row in enumerate(rows):
        result = [feature_row.tolist()]  # features
        for j in range(len(base_value)):  # for each, corresponding shap
            result.append(shap_values[j][i].tolist())
        results.append(Row(*result))
    return results


# Correlation of every feature with the model prediction (one list per output)
all_corrs = get_corrs(df=valid, features=FEATURES)

# Mean absolute SHAP value per feature, computed over two partitions
x_df = valid.select(FEATURES).repartition(2)
all_mashap, shap_df = get_mashap(x_df=x_df)

# Combine both into one summary DataFrame per model output
shap_summary_dfs = get_shap_summary_dfs(
    all_corrs=all_corrs,
    all_mashap=all_mashap,
    features=FEATURES,
)


## Binary classification example

In [None]:
@F.udf(returnType=FloatType())
def udf_predict_score(row):
    """Return the model's score at column index 1 for one feature vector."""
    #################################################
    # Amend accordingly
    probabilities = model_bc.value.predict_proba([row])
    #################################################
    # Row 0 is the only sample; `.item()` turns the value at column index 1
    # into a plain Python scalar.
    return probabilities[0, 1].item()


def get_shap_partition(rows):
    """
    Compute Shapley values in the shape of (prediction size x num_rows x len(features)).
    Also outputs the raw features, but no base_value

    Args:
        rows: Iterable of feature rows belonging to one partition.

    Returns:
        A list with one ``Row`` per input row: the feature values first,
        followed by one list of SHAP values per entry of ``base_value``.
        An empty partition yields an empty list.
    """
    rows = np.asarray(list(rows))
    if rows.size == 0:
        # Nothing to explain; skip building the explainer altogether.
        return []
    # Force a 2-D (num_rows x num_features) layout. Unlike `np.squeeze`, this
    # keeps the row axis when the partition holds a single row or the data
    # has a single feature.
    rows = rows.reshape(len(rows), -1)

    #################################################
    # Amend accordingly
    model = model_bc.value
    expl = get_explainer(...)
    #################################################

    shap_values, base_value = compute_shap(expl, rows)

    # NOTE(review): assumes `shap_values[j][i]` is the SHAP vector of row `i`
    # for output `j` and that `base_value` has one entry per output, as the
    # shape in the docstring implies - confirm against `compute_shap`.
    results = []
    for i, feature_row in enumerate(rows):
        result = [feature_row.tolist()]  # features
        for j in range(len(base_value)):  # for each, corresponding shap
            result.append(shap_values[j][i].tolist())
        results.append(Row(*result))
    return results


# Correlation of every feature with the model prediction (one list per output)
all_corrs = get_corrs(df=valid, features=FEATURES)

# Mean absolute SHAP value per feature, computed over two partitions
x_df = valid.select(FEATURES).repartition(2)
all_mashap, shap_df = get_mashap(x_df=x_df)

# Combine both into one summary DataFrame per model output
shap_summary_dfs = get_shap_summary_dfs(
    all_corrs=all_corrs,
    all_mashap=all_mashap,
    features=FEATURES,
)


## Multiclass classification example

In [8]:
@F.udf(returnType=ArrayType(FloatType()))
def udf_predict_score(row):
    """Return the model's full score vector for one feature vector."""
    #################################################
    # Amend accordingly
    probabilities = model_bc.value.predict_proba([row])
    #################################################
    # Row 0 is the only sample; `.tolist()` turns it into a plain Python list.
    return probabilities[0].tolist()


def get_shap_partition(rows):
    """
    Compute Shapley values in the shape of (prediction size x num_rows x len(features)).
    Also outputs the raw features, but no base_value

    Args:
        rows: Iterable of feature rows belonging to one partition.

    Returns:
        A list with one ``Row`` per input row: the feature values first,
        followed by one list of SHAP values per entry of ``base_value``.
        An empty partition yields an empty list.
    """
    rows = np.asarray(list(rows))
    if rows.size == 0:
        # Nothing to explain; skip building the explainer altogether.
        return []
    # Force a 2-D (num_rows x num_features) layout. Unlike `np.squeeze`, this
    # keeps the row axis when the partition holds a single row or the data
    # has a single feature.
    rows = rows.reshape(len(rows), -1)

    #################################################
    # Amend accordingly
    model = model_bc.value
    expl = get_explainer(...)
    #################################################

    shap_values, base_value = compute_shap(expl, rows)

    # NOTE(review): assumes `shap_values[j][i]` is the SHAP vector of row `i`
    # for output `j` and that `base_value` has one entry per output, as the
    # shape in the docstring implies - confirm against `compute_shap`.
    results = []
    for i, feature_row in enumerate(rows):
        result = [feature_row.tolist()]  # features
        for j in range(len(base_value)):  # for each, corresponding shap
            result.append(shap_values[j][i].tolist())
        results.append(Row(*result))
    return results


# Correlation of every feature with the model prediction (one list per output)
all_corrs = get_corrs(df=valid, features=FEATURES)

# Mean absolute SHAP value per feature, computed over two partitions
x_df = valid.select(FEATURES).repartition(2)
all_mashap, shap_df = get_mashap(x_df=x_df)

# Combine both into one summary DataFrame per model output
shap_summary_dfs = get_shap_summary_dfs(
    all_corrs=all_corrs,
    all_mashap=all_mashap,
    features=FEATURES,
)


## Save `shap_summary_dfs` and `shap_sample_dfs` as pickle

In [None]:
# Persist the per-output summary DataFrames for later use
with open("shap_summary_dfs.pkl", mode="wb") as summary_file:
    pickle.dump(shap_summary_dfs, summary_file)

In [None]:
# Collect a sample of features and SHAP values, then persist it
shap_sample_dfs = sample_shap_df(shap_df=shap_df, features=FEATURES)

with open("shap_sample_dfs.pkl", mode="wb") as sample_file:
    pickle.dump(shap_sample_dfs, sample_file)