In [50]:
from hcve_lib.evaluation_functions import c_index
from pandas import DataFrame
from functools import partial
from typing import Dict
from hcve_lib.custom_types import Result, Target
from run_training_curves import get_training_curve_file
import pickle
from hcve_lib.functional import t, pipe
from hcve_lib.visualisation import display_html
from hcve_lib.functional import lagged, subtract
from plotly import express as px

from functools import partial
from typing import Dict
from hcve_lib.custom_types import Result, Target
from run_training_curves import get_training_curve_file
import pickle
from hcve_lib.functional import t, pipe
from hcve_lib.visualisation import display_html
from hcve_lib.functional import lagged, subtract

def get_training_curves_data(method_name: str):
    with open(get_training_curve_file(method_name), 'rb') as f:
        return dict(pickle.load(
            f,
        ))


def get_training_curves_c_indexes(results: Result, X: DataFrame, y: Target):
    return {
        split: {n_features: c_index(result, X, y) for n_features, result in results.items()}
        for split, results in results.items()
    }


def plot_training_curves(c_index_per_n_features: Dict[str, float]):
    fig = px.line(DataFrame(c_index_per_n_features).sort_index().interpolate())

    fig.update_layout(
        showlegend=False,
        xaxis_title="n features",
        yaxis_title='c-statistic',
        yaxis_range=[0.5, 0.9],
    )

    fig.show()


def run_training_curves(method_name: str):
    return pipe(
        method_name,
        get_training_curves_data,
        partial(get_training_curves_c_indexes, X, y),
        plot_training_curves,
    )

In [50]:


def get_training_curves_data(method_name: str):
    with open(get_training_curve_file(method_name), 'rb') as f:
        return dict(pickle.load(
            f,
        ))


def get_training_curves_c_indexes(results: Result, X: DataFrame, y: Target):
    return {
        split: {n_features: c_index(t(result, lambda t: t.keys()), X, y) for n_features, result in results.items()}
        for split, results in results.items()
    }


def plot_training_curves(c_index_per_n_features: Dict[str, float]):
    fig = px.line(DataFrame(c_index_per_n_features).sort_index().interpolate())
    # fig.for_each_trace(lambda line: fig.add_annotation(
    #     x=line.x[0], y=line.y[0], text=line.name,
    #     font_color=line.line.color, ax=0, ay=0, xanchor="left", showarrow=False
    # ))
    fig.update_layout(
        showlegend=False,
        xaxis_title="n features",
        yaxis_title='c-statistic',
        yaxis_range=[0.5, 0.9],
    )

    fig.show()


def run_training_curves(method_name: str):
    return pipe(
        method_name,
        get_training_curves_data,
        partial(get_training_curves_c_indexes, X, y),
        plot_training_curves,
    )

In [1]:
from deps.common import get_data_cached

data, metadata, X, y = get_data_cached()

ModuleNotFoundError: No module named 'deps'

In [46]:
import plotly.io as pio

pio.templates.default = "simple_white"

In [124]:
from deps.constants import RANDOM_STATE
from ipywidgets import HBox
from plotly.subplots import make_subplots
from hcve_lib.visualisation import b, h3, h4
from hcve_lib.tracking import load_subrun_results, display_run_info
import plotly.graph_objs as go
import numpy
from hcve_lib.feature_importance import get_permutation_importance, plot_permutation_importance

PERMUTATION_IMPORTANCE_REPEATS = 5


def remove_xaxis(obj):
    obj.data[0].xaxis = None
    return obj


def feature_importance_iterations(results):
    for study_name, results_inner in results:
        h3(study_name)

        feature_order = None
        frames = []

        for fold_number, result in results_inner.items():
            def _get_permutation_importance(_result, train_importance):
                return get_permutation_importance(
                    _result,
                    X,
                    y,
                    train_importance=train_importance,
                    random_state=RANDOM_STATE,
                    n_repeats=PERMUTATION_IMPORTANCE_REPEATS,
                )

            if feature_order is not None:
                importance_train = _get_permutation_importance(result, train_importance=True)
                importance_train = importance_train.loc[feature_order]
            else:
                importance_train = _get_permutation_importance(result, train_importance=True)
                feature_order = importance_train.index

            importance_test = _get_permutation_importance(result, train_importance=False)
            importance_test = importance_test.loc[feature_order]

            frames.append(go.Frame(
                name=fold_number,
                data=[
                    remove_xaxis(plot_permutation_importance(importance_train, metadata)).data[0],
                    remove_xaxis(plot_permutation_importance(importance_test, metadata)).data[0],
                ],
                traces=[0, 1]
            ))

        fig = make_subplots(
            rows=1,
            cols=2,
            shared_yaxes=True,
            shared_xaxes=True,
            subplot_titles=('Train', 'Test'),
        )

        fig_train = frames[0]['data'][0]
        fig.append_trace(fig_train, row=1, col=1)

        fig_test = frames[0]['data'][1]
        fig.append_trace(fig_test, row=1, col=2)

        fig.update_yaxes(tickmode='linear')
        fig.update_layout(dict(margin=go.layout.Margin(
            l=20,
            r=20,
            b=20,
            t=20,
        )))
        fig.add_vline(x=0, line_width=2, opacity=0.3)

        steps = []
        for frame in frames:
            slider_step = {
                "args": [
                    [frame.name],
                    {"frame": {"duration": 300},
                     "mode": "immediate",
                     "transition": {"duration": 300}}
                ],
                "label": frame.name,
                "method": "animate"
            }
            steps.append(slider_step)

        sliders = [dict(
            active=0,
            steps=steps,
            yanchor="top",
            xanchor="left",
            currentvalue={
                "prefix": "Split:",
                "visible": True,
                "xanchor": "right"
            },
            transition={"duration": 300, "easing": "cubic-in-out"},
            pad={"b": 10, "t": 50},
            len=0.9,
            x=0.1,
            y=0,
        )]

        fig.update_layout(
            sliders=sliders,
            updatemenus=[
                {
                    "buttons": [
                        {
                            "args": [None, {"frame": {"duration": 500, "redraw": True},
                                            "fromcurrent": True, "transition": {"duration": 300,
                                                                                "easing": "quadratic-in-out"}}],
                            "label": "Play",
                            "method": "animate"
                        },
                        {
                            "args": [[None], {"frame": {"duration": 0, "redraw": True},
                                              "mode": "immediate",
                                              "transition": {"duration": 0}}],
                            "label": "Pause",
                            "method": "animate"
                        }
                    ],
                    "direction": "left",
                    "pad": {"r": 10, "t": 87},
                    "showactive": False,
                    "type": "buttons",
                    "x": 0.1,
                    "xanchor": "right",
                    "y": 0,
                    "yanchor": "top"
                }
            ]

        )

        fig.frames = frames
        fig.show()



0,1
Name,gb
Experiment,optimized_10_fold_per_study


In [125]:
GB_10_FOLD = '867bf5c253924241a23bb14c84e27a73'
display_run_info(GB_10_FOLD)
feature_importance_iterations(load_subrun_results(GB_10_FOLD).items())

0,1
Name,gb
Experiment,optimized_10_fold_per_study


In [126]:
COXNET_10_FOLD = '27253d3ffacc4a2a9d38e51c7628e730'
display_run_info(COXNET_10_FOLD)
feature_importance_iterations(load_subrun_results(COXNET_10_FOLD).items())

0,1
Name,coxnet
Experiment,optimized_10_fold_per_study


KeyError: "['LDL'] not in index"