In [233]:
from hcve_lib.utils import notebook_init

notebook_init()

%autoreload 2

from deps.common import get_data_cached
from plotly import express as px
from pandas import DataFrame
from hcve_lib.evaluation_functions import c_index


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [103]:
# Load the cached dataset. `y` is read as a module-level global by
# run_training_curves / run_average_training_curves further down, so this
# cell has to run before either of them is called.
data, metadata, X, y = get_data_cached()

In [298]:
from toolz import valmap
from tensorflow.python.ops.sparse_ops import map_values
from matplotlib.figure import Figure
from functools import partial
from typing import Dict, Iterable
from hcve_lib.custom_types import Result, Target
from run_training_curves import get_training_curve_file
import pickle
from hcve_lib.functional import t, pipe, flatten
from hcve_lib.visualisation import display_html
from hcve_lib.functional import lagged, subtract


def get_training_curves_data(method_name: str):
    """Load the pickled training-curve results stored for ``method_name``.

    The file location is resolved with ``get_training_curve_file(method_name)``.
    Whatever object the pickle holds is converted to a plain ``dict`` before
    being returned.
    """
    curve_path = get_training_curve_file(method_name)
    # NOTE(review): unpickling can execute arbitrary code - only point this at
    # files written by our own training runs.
    with open(curve_path, 'rb') as curve_file:
        unpickled = pickle.load(curve_file)
    return dict(unpickled)

def get_training_curves_c_indexes(results: Result, y: Target):
    """Score every entry of a two-level results mapping with ``c_index``.

    ``results`` is walked as ``{split: {n_features: result}}``. Each ``result``
    is replaced by ``c_index(result, y=y)`` and the nesting is preserved, so
    the return value is ``{split: {n_features: c_index_value}}``.
    """
    c_indexes_per_split = {}
    for split, results_per_n_features in results.items():
        scores = {}
        for n_features, result in results_per_n_features.items():
            scores[n_features] = c_index(result, y=y)
        c_indexes_per_split[split] = scores
    return c_indexes_per_split

def interpolate_metrics(metrics: DataFrame) -> DataFrame:
    """Return ``metrics`` as a frame sorted by index with gaps interpolated.

    The input is passed through ``DataFrame(...)`` first, so anything that
    constructor accepts (for example a dict of dicts) works as well as a frame.
    Missing values are filled with ``DataFrame.interpolate()`` using its
    default settings.
    """
    frame = DataFrame(metrics)
    ordered = frame.sort_index()
    filled = ordered.interpolate()
    return filled

def run_training_curves(method_names: Dict[str, str]) -> Figure:
    """Plot the per-split c-index training curves of one or more methods.

    Args:
        method_names: mapping whose values are the method names handed to
            ``get_training_curves_data``. ``valmap`` is applied to it, so it
            must be a mapping rather than a plain string or list; the keys are
            carried through every stage unchanged.

    Returns:
        Whatever ``plot_training_curves`` returns (it builds the figure with
        ``plotly.express``, so the ``Figure`` annotation is nominal).

    Reads the module-level ``y`` as the target for the c-index computation.

    NOTE(review): each value reaching ``plot_training_curves`` here is a whole
    DataFrame (one column per split), and that function wraps its input in
    ``DataFrame(...)`` again. ``run_average_training_curves`` reduces each
    frame to one averaged series first - confirm this un-averaged hand-off
    still works as intended.
    """
    return pipe(
        method_names,
        partial(valmap, get_training_curves_data),
        partial(valmap, partial(get_training_curves_c_indexes, y=y)),
        # interpolate_metrics builds the DataFrame itself, so no separate
        # DataFrame stage is needed (matches run_average_training_curves).
        partial(valmap, interpolate_metrics),
        partial(plot_training_curves, show_annotations=False),
    )

def run_average_training_curves(method_names: Dict[str, str]) -> Figure:
    """Plot one training curve per method, averaged over all splits.

    Args:
        method_names: mapping whose values are the method names handed to
            ``get_training_curves_data``. ``valmap`` is applied to it, so it
            must be a mapping; the keys are carried through every stage and
            end up as the column names of the plotted frame.

    Returns:
        Whatever ``plot_training_curves`` returns (it builds the figure with
        ``plotly.express``, so the ``Figure`` annotation is nominal).

    Reads the module-level ``y`` as the target for the c-index computation.
    """
    return pipe(
        method_names,
        partial(valmap, get_training_curves_data),
        partial(valmap, partial(get_training_curves_c_indexes, y=y)),
        # Sort by index and fill gaps before averaging.
        partial(valmap, interpolate_metrics),
        # axis=1: average across the columns (one per split) for each row.
        partial(valmap, partial(DataFrame.mean, axis=1)),
        partial(plot_training_curves, show_annotations=False),
    )

def plot_training_curves(
    metrics_per_features: Dict[str, DataFrame],
    show_annotations: bool = True,
    size: int = 250,
):
    """Draw c-statistic curves as a square plotly line chart.

    Args:
        metrics_per_features: mapping of curve label to metric values. It is
            wrapped in ``DataFrame(...)``, sorted by index and interpolated
            via ``interpolate_metrics`` before plotting; the frame's index is
            used as the x axis (labelled "n features").
        show_annotations: when true, write each trace's name next to its last
            point, in the trace's line colour.
        size: width and height of the figure.

    Returns:
        The figure created by ``plotly.express.line``, with the legend hidden,
        the ``simple_white`` template, zero margins and an x tick every 5
        starting at 1.
    """
    # Imported here because, at notebook level, Set1 is only imported by a
    # later cell; a top-to-bottom run would otherwise raise NameError when this
    # function is called before that cell has executed.
    from _plotly_utils.colors.colorbrewer import Set1

    metrics_per_features_ = interpolate_metrics(DataFrame(metrics_per_features))

    fig = px.line(
        metrics_per_features_,
        width=size,
        height=size,
        # Only two colours are listed; further traces reuse them.
        color_discrete_sequence=[Set1[1], Set1[4]],
    )

    if show_annotations:
        fig.for_each_trace(lambda line: fig.add_annotation(
            x=line.x[-1], y=line.y[-1], text=line.name,
            font_color=line.line.color, ay=10, xanchor="left", showarrow=False
        ))

    # Largest index value; the stop of range() below is exclusive, so a tick
    # landing exactly on this value is not drawn.
    max_n_features = max(metrics_per_features_.index)

    fig.update_layout(
        showlegend=False,
        xaxis_title="n features",
        yaxis_title='c-statistic',
        xaxis_tickvals=list(range(1, max_n_features, 5)),
        template='simple_white',
        margin=dict(l=0, r=0, t=0, b=0),
    )

    return fig

In [154]:
 # Inspect the raw c-index table for the 'gb' method: load its pickled
 # results, score each entry, and show them as a frame (one column per split,
 # one row per number of features). Gaps are not interpolated here.
 pipe(
     'gb',
     get_training_curves_data,
     partial(get_training_curves_c_indexes, y=y),
     DataFrame,
 )

Unnamed: 0,ASCOT,FLEMENGHO,HEALTHABC,HVC,PREDICTOR,PROSPER
33,0.716796,0.819762,0.6644,0.720895,0.699522,0.669449
28,0.735778,0.818734,,,,
23,0.739241,0.839926,,,,0.667844
18,0.735445,0.859215,,,,0.678791
13,0.732423,0.848979,,,,0.67802
8,0.723771,0.846459,0.66468,0.745537,0.697512,0.674704
7,0.709684,0.833393,0.660649,0.693236,0.708136,0.674769
6,0.716754,0.848002,0.661461,0.684184,0.704801,0.668533
5,0.716249,0.837045,0.653569,0.64873,0.69607,0.659372
4,0.702595,0.825266,0.649535,0.663188,0.688236,0.662958


In [152]:
# NOTE(review): run_training_curves applies valmap to its argument, which
# expects a mapping; this cell passes a str and its execution count predates
# the definition cell - presumably written against an older signature. Confirm.
run_training_curves('gb')

In [127]:
# NOTE(review): run_training_curves applies valmap to its argument, which
# expects a mapping; this cell passes a str and its execution count predates
# the definition cell - presumably written against an older signature. Confirm.
run_training_curves('coxnet')

In [300]:

from matplotlib.figure import Figure
from _plotly_utils.colors.colorbrewer import Set1

def update_style_training_curve(fig: Figure):
    """Apply the presentation styling for training-curve figures.

    Sets the font size to 20, fixes the y axis range to 0.6-0.75, turns on
    grid lines for both axes, and draws every trace as a width-3 spline.
    ``fig`` is updated through its own ``update_layout`` / ``update_traces``
    methods and the same object is returned.
    """
    layout_overrides = {
        'font': {'size': 20},
        'yaxis_range': [0.6, 0.75],
        'xaxis': {'showgrid': True},
        'yaxis': {'showgrid': True},
    }
    trace_overrides = {
        'line': {'width': 3, 'shape': 'spline'},
    }

    fig.update_layout(layout_overrides)
    fig.update_traces(trace_overrides)
    return fig


# Build the split-averaged curves. The dict values are the method names passed
# to get_training_curves_data; the keys are carried through as curve labels.
fig = run_average_training_curves({'CoxNet': 'coxnet', 'Gradient Boosting': 'gb'})
# The return value is ignored here, so the existing `fig` object is presumably
# styled in place.
update_style_training_curve(fig)
# Export the styled figure as SVG, then render it inline.
# NOTE(review): assumes ./data exists and that plotly's static image export
# backend is installed - confirm in the target environment.
fig.write_image('./data/training_curves.svg')
fig.show()