In [None]:
import pandas as pd
import numpy as np
import os 
import copy
import seaborn as sns
import joblib

from collections import Counter

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt


from crepes import WrapClassifier
from nonconformist.icp import IcpClassifier
from nonconformist.nc import ClassifierNc, MarginErrFunc
from nonconformist.icp import IcpClassifier
from nonconformist.base import ClassifierAdapter

import shap

# Reading data files

In [None]:
# Reading input: one CSV file per pre-computed data split.
DATA_SPLITS_DIR = "data_splits"
train_df, valid_df, test_df, calibration_df = (
    pd.read_csv(f"{DATA_SPLITS_DIR}/{split_name}.csv")
    for split_name in ("train", "valid", "test", "calibration")
)

In [None]:
# Selecting required features: the column names used to slice every data split.
# NOTE: the 'catagory' spelling is intentional -- these strings must match the
# column headers of the CSV files exactly.
features = [
    'edss value/score', 'age_at_visit', 'sex_label',
    'no_treatment', 'first_line_DMT', 'second_line_DMT', 'other_drugs',
    'relapse_treatment_drugs', 'stem_cell_treatment',
    'eq5d_score', 'age_at_eq5d',
    'sdmt_score', 'age_at_sdmt',
    'mono_on_sum', 'monofocal_sum', 'multi_focal_sum', 'afferent_non_on_sum',
    'steroid_treatment_sum', 'is_last_relapse_steroid_treated',
    'is_last_relapse_completely_remitted', 'age_at_relapse',
    'revised_debut_age', 'age_at_debut_relapse',
    't2_lesion_catagory', 'brain_barrier_lesion_catagory',
    'spinal_barrier_lesion_catagory', 'age_at_mri',
]

# Target column, kept as a one-element list so it can be concatenated with
# `features` when selecting columns (features + y_label).
y_label = ["y_label"]

In [None]:
# Concatenating the train and validation splits into a single training set.
# pd.concat keeps the original row labels, so index values may repeat.
train_valid_df = pd.concat([train_df, valid_df])


def _select_model_columns(split_df):
    """Return (deep copy of feature+label columns, feature array, label array).

    The label array keeps a 2-D (n_samples, 1) shape because `y_label` is a
    one-element list of column names.
    """
    subset = copy.deepcopy(split_df[features + y_label])
    return subset, subset[features].values, subset[y_label].values


# Extracting required features from the data splits
sub_train_df, X_train, y_train = _select_model_columns(train_valid_df)
sub_calibration_df, X_cal, y_cal = _select_model_columns(calibration_df)
sub_test_df, X_test, y_test = _select_model_columns(test_df)

In [None]:
# Feature-only DataFrames that keep the column names (x_test_df is used for the
# SHAP analysis further down).
x_train_df, x_calibration_df, x_test_df = (
    frame[features] for frame in (sub_train_df, sub_calibration_df, sub_test_df)
)

# Training and saving the model

In [None]:
# Random-forest hyperparameters. The original comment listed 100 next to
# n_estimators, presumably an earlier setting.
n_estimators, min_samples_leaf = 150, 5
criterion, class_weight, max_depth = "gini", "balanced", None

def lambda_fuction(x):
    """Return the second element of ``x``.

    Passed as ``condition`` to IcpClassifier below. nonconformist presumably
    calls it with an ``(features, label)`` pair, so this selects the label and
    makes the calibration class-conditional. NOTE(review): confirm against the
    installed nonconformist version.

    The (misspelled) name is kept on purpose. pickle stores functions by name,
    so the saved ICP needs a function called ``lambda_fuction`` to exist when
    it is loaded again.
    """
    _, second = x[0], x[1]
    return second
    
# Random forest wrapped by the conformal predictor. A fixed random_state pins
# the bootstrap and feature sampling, so Restart & Run All reproduces the model.
rf_random_state = 42
clf = RandomForestClassifier(n_estimators=n_estimators,
                             min_samples_leaf=min_samples_leaf,
                             criterion=criterion,
                             class_weight=class_weight,
                             max_depth=max_depth,
                             random_state=rf_random_state,
                             n_jobs=-1)

# Training the conformal model: an inductive conformal predictor with a margin
# nonconformity function, with calibration conditioned via lambda_fuction.
icp = IcpClassifier(ClassifierNc(ClassifierAdapter(clf), MarginErrFunc()),
                    condition=lambda_fuction)
# y_train / y_cal are (n, 1) columns. ravel() gives the 1-D target that
# scikit-learn expects; a column vector triggers a DataConversionWarning in fit.
icp.fit(X_train, y_train.ravel())
icp.calibrate(X_cal, y_cal.ravel())
# The fitted random forest, unwrapped from ClassifierNc -> ClassifierAdapter.
rf_model = icp.nc_function.model.model

# Removing cal data from the model before saving. Presumably this avoids
# persisting the raw calibration samples; the calibration scores computed above
# are kept. NOTE(review): confirm predict() does not need cal_x / cal_y in the
# installed nonconformist version.
icp.cal_x = []
icp.cal_y = []

# Saving the model. makedirs(exist_ok=True) is idempotent, unlike
# `mkdir models`, which errors when the directory already exists.
# NOTE: pickle stores `lambda_fuction` by name, so it must be defined before
# models/icp.joblib is loaded in a new session.
os.makedirs("models", exist_ok=True)
joblib.dump(rf_model, "models/rf.joblib")
joblib.dump(icp, "models/icp.joblib")

# Loading the model

In [None]:
# Reload the persisted models. joblib uses pickle, so only load trusted files.
# The ICP also needs `lambda_fuction` (defined in the training cell) to exist.
rf_model, icp = (joblib.load(f"models/{name}.joblib") for name in ("rf", "icp"))

# Getting test metrics

In [None]:
from sklearn.metrics import roc_auc_score
# ROC AUC of the random forest on the test split. The score is the probability
# column at index 1, i.e. the class rf_model.classes_[1].
test_positive_proba = rf_model.predict_proba(X_test)[:, 1]
roc_auc_score(y_test, test_positive_proba)

In [None]:
from sklearn.metrics import RocCurveDisplay
# ROC curve of the random forest on the test split. An explicit figure/axes
# pair is used instead of the implicit pyplot "current axes" state.
roc_fig, ax = plt.subplots()
rfc_disp = RocCurveDisplay.from_estimator(rf_model, X_test, y_test, ax=ax, alpha=0.8)
ax.set_title("ROC curve (test set)")
plt.show()

In [None]:
# Hard test-set predictions of the random forest. predict() returns class
# labels from rf_model.classes_. argmax over predict_proba returns column
# *positions*, which only coincide with the labels when classes_ == [0, 1].
y_test_preds = rf_model.predict(X_test)
cls_rprt = classification_report(y_test, y_test_preds)
# print() renders the multi-line report; a bare expression would show its repr.
print(cls_rprt)

# Conformal analysis

## Downloading plotting scripts

In [None]:
# Clone the plotting utilities once and pin them to a fixed commit for
# reproducibility. Unlike the `! git clone ...` escape, this cell can be re-run:
# the clone is skipped when the directory already exists, and a failing git
# command raises instead of being silently ignored.
import subprocess

PLOT_UTILS_REPO = "https://github.com/pharmbio/plot_utils.git"
PLOT_UTILS_DIR = "plot_utils"
PLOT_UTILS_COMMIT = "491d1f9"

if not os.path.isdir(PLOT_UTILS_DIR):
    subprocess.run(["git", "clone", PLOT_UTILS_REPO, PLOT_UTILS_DIR], check=True)
subprocess.run(["git", "reset", "--hard", PLOT_UTILS_COMMIT], cwd=PLOT_UTILS_DIR, check=True)

## Conformal prediction (CP) plots

In [None]:
import sys

# Make the pinned plot_utils checkout importable as the `pharmbio` package.
sys.path.append("plot_utils/python/src/")
from pharmbio.cp import metrics, plotting

# Print numpy floats in fixed-point notation and enlarge the default
# matplotlib font for all following figures.
np.set_printoptions(suppress=True)
plt.rcParams['font.size'] = 20

### Confusion matrix

In [None]:
# Conformal p-values for every test sample (no significance level is passed).
test_pval = icp.predict(X_test)

# Significance levels used for the summaries below.
error_sign_vals = [.15, .25]
cm_significance = 0.08

# Fraction of errors at each significance level. The original cell computed
# this but discarded it, because only a cell's last expression is displayed,
# so it is now printed explicitly.
print("Fraction of errors at", error_sign_vals, "=",
      metrics.frac_errors(y_test, test_pval, sign_vals=error_sign_vals))
print("Observed fuzziness = ", metrics.obs_fuzziness(y_test, test_pval))
CM = metrics.confusion_matrix(y_test, test_pval, sign=cm_significance, labels=[0, 1])
print(CM)

### Calibration plot

In [None]:
# Calibration plot of the conformal predictor, evaluated on a fine grid of
# significance levels and saved to calibration_plot.png.
line_args = {'alpha': 0.6, 'marker': "*", 'linestyle': '-.', 'linewidth': 2.5}
significance_grid = np.arange(0.0, 1, 0.001)
the_fig = plotting.plot_calibration_clf(
    y_test,
    test_pval,
    sign_vals=significance_grid,
    chart_padding=0.025,
    labels=["RRMS", "SPMS"],
    **line_args
)

# Customise the first axes of the returned figure: custom title and a new legend.
axes = the_fig.axes[0]
axes.set_title('Calibration plot', fontsize=22)
axes.legend(shadow=True, title='Prediction type')
the_fig.savefig("calibration_plot.png", dpi=300)

### P0-P1 plot

In [None]:
# P0-P1 plot of the test-set p-values, one marker style and size per true class.
# NOTE(review): labels here are 'RR'/'SP' while the calibration plot uses
# "RRMS"/"SPMS"; consider unifying.
kwargs = {'alpha': 0.75}
font_args = None
marks = ['o', 'x']
s = [100, 200]
p0p1 = plotting.plot_pvalues(
    y_test, test_pval,
    title='P0-P1 plot',
    sizes=s,
    markers=marks,
    labels=['RR', 'SP'],
    fontargs=font_args,
    **kwargs
)

### Label distribution plot

In [None]:
# Label distribution plot drawn onto a single wide axes that fills a 15x4 figure,
# then saved to label_distribution_plot.png.
plt.rcParams.update({'font.size': 20})
my_fig = plt.figure(figsize=(15, 4))
ax = my_fig.add_axes([0, 0, 1, 1])
custom_args = {'alpha': 1}
fig = plotting.plot_label_distribution(
    y_true=y_test,
    p_values=test_pval,
    tight_layout=True,
    ax=ax,
    display_incorrect=True,
    title="Label distribution plot",
    **custom_args
)
fig.savefig("label_distribution_plot.png", dpi=300, bbox_inches='tight')

# SHAP
## Calculating SHAP values

In [None]:
import shap
# Human-readable feature names for the SHAP plots: underscores become spaces,
# and 'revised_debut_age' is shown as "debut age". The names are derived from
# `features` rather than from x_test_df.columns, and applied with set_axis. That
# makes the cell idempotent: re-running it no longer fails in .index() after the
# columns have already been renamed.
feature_names_for_shap = [
    "debut age" if name == "revised_debut_age" else name.replace("_", " ")
    for name in features
]
x_test_df = x_test_df.set_axis(feature_names_for_shap, axis=1)

In [None]:
# Build a SHAP explainer for the fitted random forest and explain every test
# sample. x_test_df is used both as the background (masker) data and as the
# data being explained.
# NOTE(review): SHAP background data is usually drawn from the training set;
# confirm that using the test set here is intended.
explainer = shap.Explainer(rf_model, x_test_df)
# check_additivity=False skips SHAP's check that attributions sum to the model
# output. NOTE(review): document why this was needed.
shap_values = explainer(x_test_df,check_additivity=False)

## Bar plot (relative importance)

In [None]:
# Bar plot of the SHAP values for output/class index 1. show=False keeps the
# figure open so it can be saved to shap_barplot.png.
class1_shap_values = shap_values[:, :, 1]
plt.clf()
shap.plots.bar(class1_shap_values, show=False)
plt.savefig("shap_barplot.png", dpi=300, bbox_inches='tight')

## Beeswarm plot (global importance)

In [None]:
# Beeswarm plot of the SHAP values for output/class index 1. show=False keeps
# the figure open so it can be saved to shap_beeswarm_plot.png.
class1_shap_values = shap_values[:, :, 1]
plt.clf()
shap.plots.beeswarm(class1_shap_values, show=False)
plt.savefig("shap_beeswarm_plot.png", dpi=300, bbox_inches='tight')

## Violin plot

In [None]:
# Violin plot of the SHAP values for output/class index 1. It is now saved to
# disk like the bar and beeswarm figures (originally it was only displayed).
# show=False keeps the figure open for savefig.
plt.clf()
shap.plots.violin(shap_values[:, :, 1], show=False)
plt.savefig("shap_violin_plot.png", dpi=300, bbox_inches='tight')