In [None]:
import os
import pickle

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import shap
from sklearn.preprocessing import StandardScaler


# Fixed x-axis half-width per dataset. It depends only on the dataset, so the
# plots of all analyses for one dataset share the same scale.
DATASET_XLIM = {
    'brca': 40,
    'gbsg2': 75,
    'whas500': 40,
    'microbiome': 4,
}

# Plot title label for each analysis directory name.
ANALYSIS_TITLES = {
    'central': 'Centralized',
    'federated-analysis': 'Federated',
    'smpc-analysis': 'SMPC',
}

SUPPORTED_PLOT_TYPES = ('dot', 'bar')


def shap_plot_title(analysis: str, dataset: str) -> str:
    """Return the figure title for ``analysis`` on ``dataset``.

    Raises:
        ValueError: if ``analysis`` is not a key of ``ANALYSIS_TITLES``.
    """
    try:
        label = ANALYSIS_TITLES[analysis]
    except KeyError:
        raise ValueError('Invalid analysis') from None
    return f'{label} ({dataset})'


def shap_axis_limits(dataset: str, plot_type: str):
    """Return the ``(left, right)`` x-limits for a SHAP summary plot, or ``None``.

    ``dot`` plots get a range symmetric around zero. ``bar`` plots start at 0
    and use half of that width. ``None`` means the dataset has no fixed range
    and matplotlib's autoscaling should be kept.

    Raises:
        ValueError: if ``plot_type`` is not one of ``SUPPORTED_PLOT_TYPES``.
    """
    if plot_type not in SUPPORTED_PLOT_TYPES:
        raise ValueError('Invalid plot_type')
    half_width = DATASET_XLIM.get(dataset)
    if half_width is None:
        return None
    if plot_type == 'dot':
        return (-half_width, half_width)
    return (0, half_width / 2)


def plot_shaps(shap_values, data, dataset: str, analysis: str, plot_type: str = 'dot', feature_names=None):
    """Draw a SHAP summary plot and save it as ``<dataset>/<analysis>_<plot_type>.png``.

    Args:
        shap_values: SHAP values to plot, as accepted by ``shap.summary_plot``.
        data: Feature matrix the SHAP values were computed on.
        dataset: Dataset name. It selects the x-axis range (``DATASET_XLIM``)
            and the output directory, which is not created here.
        analysis: One of the keys of ``ANALYSIS_TITLES``.
        plot_type: ``'dot'`` or ``'bar'``.
        feature_names: Feature labels. Defaults to ``shap_values.feature_names``
            (e.g. set when the explainer was built with ``feature_names=...``),
            so no notebook-global variable is needed.

    Returns:
        The path of the saved PNG. The figure is closed after saving to avoid
        accumulating open figures when called in a loop.

    Raises:
        ValueError: for an unknown ``analysis`` or ``plot_type``.
    """
    # Validate before any figure is created so bad arguments don't leak figures.
    title = shap_plot_title(analysis, dataset)
    xlim = shap_axis_limits(dataset, plot_type)
    if feature_names is None:
        feature_names = getattr(shap_values, 'feature_names', None)

    # NOTE(review): font_scale=0.1 is unusually small; confirm it is intended.
    sns.set_theme(context='paper', style='whitegrid', font_scale=0.1)
    fig, ax = plt.subplots(1, 1, figsize=(10, 5), dpi=150)
    try:
        # summary_plot draws on pyplot's current axes when plot_size=None.
        plt.sca(ax)
        shap.summary_plot(
            shap_values, data, show=False, max_display=20, plot_size=None,
            feature_names=feature_names, plot_type=plot_type,
        )
        ax.set_title(title)
        if xlim is not None:
            ax.set_xlim(*xlim)
        out_path = f'{dataset}/{analysis}_{plot_type}.png'
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        plt.close(fig)
    return out_path

In [None]:
# Settings shared by every dataset/analysis combination.
DATASETS = ['brca', 'gbsg2', 'whas500', 'microbiome']
ANALYSES = ['central', 'federated-analysis', 'smpc-analysis']
# Row/column labels of the correlation matrix, in the same order as ANALYSES.
ANALYSIS_LABELS = ['central', 'federated', 'federated + secure aggregation']
CORRELATION_METHODS = ['pearson']
SEED = 231211
# NOTE(review): only clients 1-3 of the 5-client SMPC run are loaded; confirm
# whether clients 4 and 5 should be included as well.
SMPC_CLIENT_IDS = [1, 2, 3]


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_analysis(dataset, analysis):
    """Return ``(data, columns, model)`` for one dataset/analysis pair.

    ``data`` is the feature matrix with the ``event``/``tte`` columns dropped,
    ``columns`` its feature names, and ``model`` the unpickled model.

    Raises:
        ValueError: for an unknown ``analysis``.
    """
    if analysis == 'central':
        frame = pd.read_csv(f'../central/{dataset}/data.csv').drop(columns=['event', 'tte'])
        columns = frame.columns.tolist()
        # Central data is standardized here; the other analyses read the
        # ``*_norm.csv`` exports as-is.
        data = pd.DataFrame(StandardScaler().fit_transform(frame))
        model = _load_pickle(f'../central/{dataset}/model.pkl')
        return data, columns, model

    if analysis == 'federated-analysis':
        # NOTE(review): features come from the 1-client export while the model
        # comes from the 3-client run; confirm this pairing is intended.
        split_dir = f'../{analysis}/{dataset}/1_clients/fc_normalization/client_1/data/split_1'
        parts = [pd.read_csv(f'{split_dir}/train_norm.csv'),
                 pd.read_csv(f'{split_dir}/test_norm.csv')]
        model_path = f'../{analysis}/{dataset}/3_clients/fc_normalization/fc_survival_svm/client_1/model.pickle'
    elif analysis == 'smpc-analysis':
        base_dir = f'../{analysis}/{dataset}/5_clients/basic_normalization'
        parts = []
        for client_id in SMPC_CLIENT_IDS:
            split_dir = f'{base_dir}/client_{client_id}/data/split_1'
            parts.append(pd.read_csv(f'{split_dir}/train_norm.csv'))
            parts.append(pd.read_csv(f'{split_dir}/test_norm.csv'))
        model_path = f'{base_dir}/fc_survival_svm/client_1/model.pickle'
    else:
        raise ValueError('Invalid analysis')

    frame = pd.concat(parts).drop(columns=['event', 'tte'])
    return frame.to_numpy(), frame.columns.tolist(), _load_pickle(model_path)


def _correlation_matrix(shap_series, method):
    """Return the pairwise correlation of per-analysis SHAP series as a labelled DataFrame.

    ``shap_series`` is ordered like ``ANALYSIS_LABELS``. The diagonal is fixed
    at 1. Each off-diagonal pair is computed once with ``Series.corr`` (which
    aligns on the index) and mirrored, so the matrix is exactly symmetric.
    """
    n = len(ANALYSIS_LABELS)
    corr_df = pd.DataFrame(np.nan, index=ANALYSIS_LABELS, columns=ANALYSIS_LABELS)
    for i in range(n):
        corr_df.iloc[i, i] = 1
        for j in range(i + 1, n):
            value = shap_series[i].corr(shap_series[j], method=method)
            corr_df.iloc[i, j] = value
            corr_df.iloc[j, i] = value
    return corr_df


for dataset in DATASETS:
    print(dataset)
    os.makedirs(f'{dataset}', exist_ok=True)
    # The correlation CSVs are written to results/<dataset>/, so create it too.
    os.makedirs(f'results/{dataset}', exist_ok=True)
    shaps = []
    for analysis in ANALYSES:
        print(analysis)
        data, columns, model = _load_analysis(dataset, analysis)

        # np.random.seed returns None, so it must not be used as the `seed`
        # argument: seed the global RNG and pass the seed value explicitly.
        np.random.seed(SEED)
        explainer = shap.Explainer(model.predict, data, feature_names=columns, seed=SEED,
                                   algorithm='auto')
        shap_values = explainer(data, max_evals=5000)

        plot_shaps(shap_values, data, dataset, analysis, plot_type='dot')
        plot_shaps(shap_values, data, dataset, analysis, plot_type='bar')

        # Mean absolute SHAP value per feature, sorted by feature name.
        shap_series = pd.DataFrame(shap_values.values, columns=columns).abs().mean(axis=0).sort_index()
        shap_series.to_csv(f'{dataset}/{analysis}_shap.csv')
        shaps.append(shap_series)

    for corr_type in CORRELATION_METHODS:
        corr_df = _correlation_matrix(shaps, corr_type)
        corr_df.to_csv(f'results/{dataset}/shap_{corr_type}.csv')

In [None]:
# Stack the per-dataset SHAP correlation matrices into one long table. Each
# row is tagged with its dataset, and the unnamed index column is named 'type'.
dfs = [
    pd.read_csv(f'results/{name}/shap_pearson.csv')
    .assign(dataset=name)
    .rename(columns={'Unnamed: 0': 'type'})
    for name in ('brca', 'gbsg2', 'microbiome', 'whas500')
]

df = pd.concat(dfs)
df.to_csv('xai.csv')