# Refactored notebook for modelling

## imports

In [None]:
import sys
import os

sys.path.append("..")

import numpy as np
import pandas as pd
import pickle
import warnings
import re
import plotly.express as px
import plotly.graph_objects as go

from NHS_PROMs.settings import config
from NHS_PROMs.load_data import load_proms, structure_name
from NHS_PROMs.preprocess import filter_in_range, filter_in_labels, method_delta
from NHS_PROMs.utils import (
    most_recent_file,
    downcast,
    map_labels,
    fillna_categories,
    pd_fit_resample,
    infer_categories_fit,
    KindSelector,
    get_feature_names,
    remove_categories,
)
from NHS_PROMs.data_dictionary import meta_dict, methods

import shap
shap.initjs()

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.compose import make_column_selector

# from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.compose import (
    ColumnTransformer,
    make_column_transformer,
    make_column_selector,
)
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (
    RandomForestClassifier,
    AdaBoostClassifier,
    GradientBoostingRegressor,
    BaggingClassifier,
)
from sklearn.metrics import classification_report, balanced_accuracy_score
from sklearn.inspection import permutation_importance
from sklearn import set_config
from sklearn.utils.validation import check_is_fitted

set_config(display="diagram")

from imblearn.ensemble import BalancedBaggingClassifier, BalancedRandomForestClassifier, EasyEnsembleClassifier
from imblearn.over_sampling import SMOTENC
from imblearn.pipeline import Pipeline, make_pipeline
from imblearn.under_sampling import RandomUnderSampler 

# --- monkey patches -------------------------------------------------------
# The assignments below replace attributes on third-party classes with the
# helpers imported from NHS_PROMs.utils, so they affect every later use of
# these classes in this session.

# use adjusted fillna which can cope with non-existing categories for CategoricalDtype
pd.core.frame.DataFrame.fillna = fillna_categories
# add a remove_categories method that is reachable directly on a Series
pd.core.frame.Series.remove_categories = remove_categories
# enable autodetect of categories from CategoricalDtype by using "infer" for SMOTENC
# (wraps the original fit_resample)
SMOTENC.fit_resample = pd_fit_resample(SMOTENC.fit_resample)
# enable inference of categories for encoders from CategoricalDtype
# (wraps the original fit of both encoders)
OneHotEncoder.fit = infer_categories_fit(OneHotEncoder.fit)
OrdinalEncoder.fit = infer_categories_fit(OrdinalEncoder.fit)

## load data
The general approach is deliberately not DRY: this keeps the knee and hip DataFrames always at hand and keeps the code readable (script-wise).

In [None]:
from NHS_PROMs.model import pl, param_grid

class PROMsModel():
    """Train, persist, evaluate and explain one model per configured output.

    The outputs to model are read from ``config["outputs"][kind]``; one
    ``GridSearchCV`` over the imported pipeline ``pl`` is fitted per output
    column and stored in ``self.models`` (keyed by output column name).
    """

    def __init__(self, kind="hip"):
        """Store the kind of data to model and look up its configured outputs.

        Parameters
        ----------
        kind : str
            Key into ``config["outputs"]``; also passed as ``part`` to
            ``load_proms`` and used as prefix for model file names.
        """
        self.kind = kind
        self.outputs = config["outputs"][kind]

    def load_data(self, mode="train"):
        """Load, rename and preprocess the data, then keep the rows for `mode`.

        Parameters
        ----------
        mode : str
            ``"train"`` keeps every row whose ``t0_year`` is not
            'April 2019 - April 2020'; ``"test"`` keeps only those rows.

        Returns
        -------
        pandas.DataFrame
            The preprocessed rows for the requested mode, without ``t0_year``.

        Raises
        ------
        ValueError
            If `mode` is neither ``"train"`` nor ``"test"``.
        """
        # validate first, so a typo fails before the expensive load/preprocess
        if mode == "train":
            comparison = "!="
        elif mode == "test":
            comparison = "=="
        else:
            raise ValueError(f"No valid mode: '{mode}'")

        df = (
            load_proms(part=self.kind)
            .apply(downcast)
            .rename(structure_name, axis=1)
        )

        self.load_meta(df.columns)

        df = self.preprocess(df)

        return (
            df.query(f"t0_year {comparison} 'April 2019 - April 2020'")
            .drop(columns="t0_year")
        )

    def load_meta(self, columns):
        """Set ``self.meta`` to the meta data of the given columns.

        Every entry of ``meta_dict`` is made available under a ``t0_`` and a
        ``t1_`` prefixed key; only keys present in `columns` are kept.
        """
        # get meta data
        full_meta = {t + k: v for k, v in meta_dict.items() for t in ["t0_", "t1_"]}
        self.meta = {k: v for k, v in full_meta.items() if k in columns}

    def preprocess(self, df):
        """Filter, relabel and clean `df` according to ``self.meta`` and config.

        Requires ``self.meta`` to be set (see ``load_meta``), with an entry
        for every column of `df`.

        Returns
        -------
        pandas.DataFrame
            The cleaned frame with a fresh, unique index.
        """
        # remove certain columns
        endings = config["preprocessing"]["remove_columns_ending_with"]
        # str.endswith only accepts a str or a tuple of str; a list coming
        # from the config would raise a TypeError
        if not isinstance(endings, str):
            endings = tuple(endings)
        cols2drop = [c for c in df.columns if c.endswith(endings)]

        df = (
            df.apply(lambda s: filter_in_range(s, **self.meta[s.name])) # filter in range numeric features
            .apply(lambda s: filter_in_labels(s, **self.meta[s.name])) # filter in labels categorical features + ordinal if ordered
            .apply(lambda s: map_labels(s, **self.meta[s.name])) # map the labels as values for readability
            .query("t0_revision_flag == 'no revision'") # drop revision cases
            .drop(columns=cols2drop) # drop not needed columns
        )

        # remove low info values from columns (almost redundant) values
        for col, value in config["preprocessing"]["remove_low_info_categories"].items():
            df[col] = df[col].remove_categories(value)

        # remove NaNs/missing/unknown from numerical and ordinal features
        # NOTE(review): pd.Series.remove_categories is presumably the
        # monkey-patched helper from NHS_PROMs.utils - confirm it is applied
        # before this method runs.
        df = (
            df.apply(pd.Series.remove_categories, args=(["missing", "not known"],))
            .dropna(subset= KindSelector(kind="numerical")(df) + KindSelector(kind="ordinal")(df))
            .reset_index(drop=True) # make index unique (prevent blow ups when joining)
        )

        return df

    def split_XY(self, df):
        """Split `df` into inputs X and (binned) outputs Y.

        X holds every column matching the regex ``t0``; Y holds the columns
        named in ``self.outputs``. Numeric output columns are binned with
        ``pd.cut`` using the keyword arguments configured for that column.

        Returns
        -------
        tuple of pandas.DataFrame
            ``(X, Y)``, both copies independent of `df`.
        """
        # define inputs and outputs
        X = df.filter(regex="t0").copy()
        # select by an explicit list of names: recent pandas versions reject
        # a dict as an indexer
        Y = df[list(self.outputs)].copy()

        # get cut from settings
        for col in Y.columns:
            if pd.api.types.is_numeric_dtype(Y[col]):
                Y[col] = pd.cut(
                    Y[col],
                    include_lowest=True,
                    **self.outputs[col],
                )

        return X, Y

    def train_models(self):
        """Fit one model per output column on the training data.

        Returns
        -------
        PROMsModel
            ``self``, with ``self.models`` mapping output name -> fitted
            ``GridSearchCV``.
        """
        X, Y = (
            self.load_data(mode="train")
            .pipe(self.split_XY)
        )
        # DataFrame.items() replaces iteritems(), which pandas 2.0 removed
        self.models = {col: self.train_model(X, y) for col, y in Y.items()}
        return self

    def train_model(self, X, y):
        """Run the grid search over pipeline ``pl`` for one output and return it.

        All warnings raised while fitting are suppressed.
        """
        GS = GridSearchCV(
            estimator=pl,
            param_grid=param_grid,
            scoring=config["score"]
        )
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            GS.fit(X, y)
        return GS

    def save_models(self, filename=None):
        """Pickle ``self.models`` into the configured models directory.

        Parameters
        ----------
        filename : str, optional
            File name inside the models directory. When omitted a name of
            the form ``{kind}_{suffix}.mdl`` is generated.
        """
        # resolve the directory up front: it is needed for explicit
        # filenames as well as for generated ones
        path = os.path.join("..", config["models"]["path"])
        if filename is None:
            # NOTE(review): this uses the built-in hash of the model objects,
            # which is presumably identity based and thus not reproducible
            # across sessions - confirm that is acceptable.
            hashable = frozenset(self.models.items())
            sha = hex(hash(hashable))[-5:]
            filename = f"{self.kind}_{sha}.mdl"
        with open(os.path.join(path, filename), 'wb') as file:
            pickle.dump(self.models, file)

    def load_models(self, filename=None):
        """Load pickled models from the configured models directory.

        Parameters
        ----------
        filename : str, optional
            File to load. When omitted, the most recent ``.mdl`` file whose
            name is prefixed with ``self.kind`` is used.

        Returns
        -------
        PROMsModel
            ``self``, with ``self.models`` set to the unpickled object.

        Raises
        ------
        ValueError
            If no filename is given and no matching file is found.
        Warning
            If the given filename does not start with ``"{kind}_"``.
        """
        path = os.path.join("..", config["models"]["path"])
        if filename is None:
            filename = most_recent_file(path, ext=".mdl", prefix=self.kind)
            if filename is None:
                raise ValueError("No correct models found!")
        elif not filename.startswith(f"{self.kind}_"):
            raise Warning(f"File '{filename}' does not seem to be having models for {self.kind}")
        # SECURITY: unpickling executes arbitrary code; only load model files
        # from a trusted source.
        with open(os.path.join(path, filename), 'rb') as file:
            self.models = pickle.load(file)
        return self

    def _fitted_models(self):
        """Yield ``(name, model)`` pairs, verifying each model is fitted first."""
        for name, model in self.models.items():
            check_is_fitted(model)
            yield name, model

    def predict(self, X):
        """Return a dict mapping output name -> predictions for `X`."""
        return {name: model.predict(X) for name, model in self._fitted_models()}

    def predict_proba(self, X):
        """Return a dict mapping output name -> class probabilities for `X`."""
        return {name: model.predict_proba(X) for name, model in self._fitted_models()}

    def labels_encoded(self):
        """Return a dict mapping output name -> ``classes_`` of its model."""
        return {name: model.classes_ for name, model in self._fitted_models()}

    def classification_reports(self):
        """Print a classification report per output, evaluated on the test data."""
        data = self.load_data(mode="test")
        X, Y = self.split_XY(data)
        for name, model in self._fitted_models():
            y_hat = model.predict(X)
            print(f"\nClassification report for {name}:\n")
            print(classification_report(Y[name], y_hat))

    def get_explainer(self, name):
        """Return the (cached) SHAP TreeExplainer for the model of output `name`.

        Explainers are created on first request and kept in
        ``self.explainers``.
        """
        if not hasattr(self, "explainers"):
            self.explainers = dict()

        if self.explainers.get(name) is None:
            model = self.models[name]
            check_is_fitted(model)
            # explain only the final "model" step of the best pipeline
            self.explainers[name] = shap.TreeExplainer(
                model.best_estimator_.named_steps["model"],
#                 feature_perturbation='interventional',
#                 model_output="probability",
#                 data=self.load_data("train"),
            )
        return self.explainers[name]

    def force_plot(self, X, name):
        """Build a SHAP force plot for a single case and a single output.

        The plot explains the class with the highest predicted probability.

        Parameters
        ----------
        X : pandas.DataFrame
            Exactly one row of input features.
        name : str
            Output (key of ``self.models``) to explain.

        Raises
        ------
        ValueError
            If `X` does not contain exactly one row.
        """
        if X.shape[0] != 1:
            raise ValueError("First dimension should be 1. Expected a single case for force plot!")

        model = self.models[name]
        check_is_fitted(model)
        explainer = self.get_explainer(name)

        # rescaleing base values for multiclass https://evgenypogorelov.com/multiclass-xgb-shap.html
        def logodds_to_proba(logodds):
            return np.exp(logodds)/np.exp(logodds).sum()

        predict_proba = model.predict_proba(X)
        # X has a single row, so the flat argmax is the index of the winning class
        i_max = np.argmax(predict_proba)
        end_value = predict_proba[0, i_max]
        # run every pipeline step except the final estimator
        X_pre = model.best_estimator_[:-1].transform(X)
        # NOTE(review): assumes shap_values() returns one entry per class - confirm for the installed shap version
        shap_values = explainer.shap_values(X_pre)[i_max]
        base_value = logodds_to_proba(explainer.expected_value)[i_max]
        feature_names = [re.sub("(t0_|gender_|_yes|_no|)", "", n).replace("_", " ") for n in get_feature_names(model)]
        out_names = f"{name} = {self.labels_encoded()[name][i_max]}"

        # rescaling according to https://github.com/slundberg/shap/issues/29
        shap_values = shap_values / shap_values.sum() * (end_value - base_value)

        fp = shap.force_plot(
            base_value=base_value,
            shap_values=shap_values,
#             features=X_pre,
            feature_names=feature_names,
            out_names=out_names,
#             link="logit",
        )
        return fp

    def force_plots(self, X=None):
        """Display a force plot for every configured output.

        Parameters
        ----------
        X : pandas.DataFrame, optional
            A single case. When omitted, one random row of the test data is
            used.
        """
        if X is None:
            # use this instance rather than a notebook-level global
            df_data = self.load_data("test").sample()
            X, Y = self.split_XY(df_data)

        for name in self.outputs:
            display(self.force_plot(X, name))

In [None]:
# build the hip model object and fit one model per configured output
# (train_models returns the instance itself)
PM = PROMsModel(kind="hip").train_models()

In [None]:
# no X given: display force plots for one randomly sampled test case
PM.force_plots()