In [56]:
from hcve_lib.utils import notebook_init, get_first_entry

notebook_init()

%autoreload 2

from deps.common import get_data_cached
from plotly import express as px
from pandas import DataFrame
from hcve_lib.evaluation_functions import c_index, merge_predictions


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Load the dataset through the project's cached loader (see the [Memory] log
# in this cell's output). Of the four values, the cells visible below use
# `data` (for the STUDY column) and `y` (as the c-index target).
data, metadata, X, y = get_data_cached()

[Memory]19.2s, 0.3min   : Loading get_data...
____________________________________________get_data cache loaded - 0.0s, 0.0min


In [101]:
import pandas
from hcve_lib.evaluation_functions import map_inverse_weight
from hcve_lib.metrics import WeightedCIndex
from hcve_lib.utils import transpose_dict
from toolz import valmap
from tensorflow.python.ops.sparse_ops import map_values
from matplotlib.figure import Figure
from functools import partial
from typing import Dict, Iterable
from hcve_lib.custom_types import Result, Target, Prediction
from run_training_curves import get_training_curve_file
import pickle
from hcve_lib.functional import t, pipe, flatten
from hcve_lib.visualisation import display_html, setup_plotly_style
from hcve_lib.functional import lagged, subtract


def get_training_curves_data(method_name: str):
    """Load the pickled training-curve results stored for ``method_name``.

    The file location is resolved by ``get_training_curve_file``; whatever
    the pickle contains is converted to a plain ``dict`` before returning.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only use this on result files produced by trusted training runs.
    """
    curve_path = get_training_curve_file(method_name)
    with open(curve_path, 'rb') as curve_file:
        unpickled = pickle.load(curve_file)
    return dict(unpickled)

def get_training_curves_c_indexes(results: Result, y: Target):
    """Compute a weighted c-index for every (split, feature count) pair.

    Args:
        results: Nested mapping ``split -> {feature count -> result}``.
        y: Target forwarded to ``WeightedCIndex.get_values``.

    Returns:
        A nested dict with the same two levels of keys as ``results`` whose
        leaves are the first element returned by
        ``WeightedCIndex(weight=inverse_weight).get_values(result, y=y)``.

    Relies on the notebook-level ``inverse_weight`` being defined before
    the function is called.
    """
    c_indexes = {}
    for split, results_by_n in results.items():
        per_feature_count = {}
        for n_selected, result in results_by_n.items():
            metric = WeightedCIndex(weight=inverse_weight)
            per_feature_count[n_selected] = metric.get_values(result, y=y)[0]
        c_indexes[split] = per_feature_count
    return c_indexes


# Weights computed by map_inverse_weight from the STUDY column of `data`.
# Both c-index helpers in this cell pass this module-level value to
# WeightedCIndex, so the data-loading cell must have run first.
inverse_weight = map_inverse_weight(data['STUDY'])

def get_training_curves_c_indexes2(curve, y: Target):
    """Compute a weighted c-index for each feature count of a single curve.

    Args:
        curve: Mapping ``feature count -> prediction`` (e.g. the output of
            ``merge_curve``).
        y: Target forwarded to ``WeightedCIndex.get_values``.

    Returns:
        Dict keyed like ``curve`` holding the first element returned by
        ``WeightedCIndex(weight=inverse_weight).get_values(prediction, y=y)``.

    Relies on the notebook-level ``inverse_weight``.
    """
    c_index_by_n = {}
    for n_selected, prediction in curve.items():
        metric = WeightedCIndex(weight=inverse_weight)
        c_index_by_n[n_selected] = metric.get_values(prediction, y=y)[0]
    return c_index_by_n

def interpolate_metrics(metrics: DataFrame) -> DataFrame:
    """Return ``metrics`` as a frame sorted by index with gaps interpolated.

    The input is wrapped in ``DataFrame`` first, so dict-like input is
    accepted as well. Rows are ordered by index and missing values are then
    filled with pandas' default ``interpolate()`` settings.
    """
    frame = DataFrame(metrics)
    ordered = frame.sort_index()
    return ordered.interpolate()

def run_training_curves(method_names: Iterable[str]) -> Figure:
    """Plot per-split training curves for the given methods.

    Every stage is applied with ``valmap``, so ``method_names`` has to be a
    mapping (values are passed to ``get_training_curves_data``) even though
    it is annotated as an iterable of strings.

    Stages, in order: load the stored results, compute weighted c-indexes
    against the notebook-level ``y``, wrap them in a ``DataFrame``,
    interpolate, and hand the result to ``plot_training_curves`` with
    annotations switched off.
    """
    stages = (
        partial(valmap, get_training_curves_data),
        partial(valmap, partial(get_training_curves_c_indexes, y=y)),
        partial(valmap, DataFrame),
        partial(valmap, interpolate_metrics),
        partial(plot_training_curves, show_annotations=False),
    )
    return pipe(method_names, *stages)

def run_average_training_curves(method_names: Iterable[str]) -> Figure:
    """Plot training curves averaged over splits for the given methods.

    Every stage is applied with ``valmap``, so ``method_names`` has to be a
    mapping (values are passed to ``get_training_curves_data``) even though
    it is annotated as an iterable of strings.

    Stages, in order: load the stored results, compute weighted c-indexes
    against the notebook-level ``y``, interpolate, take the row-wise mean
    (``DataFrame.mean(axis=1)``), and hand the result to
    ``plot_training_curves`` with annotations switched off.
    """
    stages = (
        partial(valmap, get_training_curves_data),
        partial(valmap, partial(get_training_curves_c_indexes, y=y)),
        partial(valmap, interpolate_metrics),
        partial(valmap, partial(DataFrame.mean, axis=1)),
        partial(plot_training_curves, show_annotations=False),
    )
    return pipe(method_names, *stages)

def merge_curve(cr, feature_counts=None):
    """Merge per-split predictions into one merged prediction per feature count.

    Args:
        cr: Nested mapping ``split -> {feature count -> prediction}``.
        feature_counts: Feature counts to build merged predictions for.
            Defaults to the notebook-level ``n_features`` list, which must
            then already be defined when this function is called.

    Returns:
        Dict ``feature count -> merge_predictions(...)``. For each requested
        count ``n`` every split contributes the prediction stored under its
        largest feature count that does not exceed ``n``; a split with no
        such count contributes nothing for that ``n``.
    """
    if feature_counts is None:
        # Backward-compatible default: fall back to the notebook global.
        feature_counts = n_features

    merged_result = {}
    for n in feature_counts:
        merged_for_n = []
        for by_n in cr.values():
            # Choose the closest available count <= n regardless of the
            # order in which the keys are stored. Taking the first match in
            # iteration order only gives this when keys are descending.
            eligible = [n_proposed for n_proposed in by_n if n_proposed <= n]
            if eligible:
                merged_for_n.append(by_n[max(eligible)])

        merged_result[n] = merge_predictions(dict(enumerate(merged_for_n)))

    return merged_result


def run_merged_training_curves(method_names: Iterable[str]) -> Figure:
    """Plot training curves computed on predictions merged across splits.

    Every stage is applied with ``valmap``, so ``method_names`` has to be a
    mapping (values are passed to ``get_training_curves_data``) even though
    it is annotated as an iterable of strings.

    Stages, in order: load the stored results, merge the splits with
    ``merge_curve``, compute weighted c-indexes against the notebook-level
    ``y``, and plot without annotations.
    """
    return pipe(
        method_names,
        partial(valmap, get_training_curves_data),
        partial(valmap, merge_curve),
        partial(valmap, partial(get_training_curves_c_indexes2, y=y)),
        partial(plot_training_curves, show_annotations=False),
    )


def plot_training_curves(
    metrics_per_features: Dict[str, DataFrame],
    show_annotations: bool = True,
    tick_step: int = 5,
):
    """Draw one line per column of metric values against the feature count.

    Args:
        metrics_per_features: Mapping of trace name to metric values indexed
            by number of features. It is wrapped in a ``DataFrame``, sorted
            by index and interpolated before plotting.
        show_annotations: When true, label each line at its last point with
            the trace name, using the line's own colour.
        tick_step: Spacing of the x-axis ticks, which start at 1.

    Returns:
        The plotly figure produced by ``px.line`` after styling.

    Raises:
        ValueError: If the metrics frame is empty (no index to take the
            maximum of).
    """
    metrics_per_features_ = interpolate_metrics(DataFrame(metrics_per_features))

    fig = px.line(
        metrics_per_features_,
        # NOTE: `Set1` is imported in a later cell of this notebook; that
        # cell has to run before this function is called.
        color_discrete_sequence=[Set1[1], Set1[4]],
    )

    setup_plotly_style(fig)

    if show_annotations:
        fig.for_each_trace(lambda line: fig.add_annotation(
            x=line.x[-1], y=line.y[-1], text=line.name,
            font_color=line.line.color, ay=10, xanchor="left", showarrow=False
        ))

    max_n_features = max(metrics_per_features_.index)

    fig.update_layout(
        showlegend=False,
        xaxis_title="n features",
        yaxis_title='c-statistic',
        # Upper bound is inclusive so a tick that falls exactly on the
        # largest feature count is not dropped.
        xaxis_tickvals=list(range(1, max_n_features + 1, tick_step)),
        template='simple_white',
        margin=dict(l=0, r=0, t=0, b=0),
    )

    return fig


In [46]:
# Load the stored training-curve results for the 'coxnet' method; `cr` is
# used by the next cell to derive the list of feature counts.
cr = get_training_curves_data('coxnet')

In [61]:
# Feature counts = keys of the entry returned by get_first_entry(cr).
# merge_curve reads this module-level name, so this cell must run before
# run_merged_training_curves / merge_curve are called.
n_features = list(get_first_entry(cr).keys())

In [74]:
# NOTE(review): `merged_result` is not assigned at notebook level in any
# visible cell (it only exists as a local inside merge_curve), so this
# inspection cell appears to depend on leftover kernel state and would fail
# on Restart & Run All — confirm and remove or recreate the variable.
merged_result[33][0].keys()

dict_keys(['split', 'X_columns', 'y_column', 'y_score', 'y_proba', 'model', 'random_state', 'method'])

In [102]:

from matplotlib.figure import Figure
from _plotly_utils.colors.colorbrewer import Set1

def update_style_training_curve(fig: Figure):
    """Apply the training-curve look to ``fig`` in place and return it.

    Sets the y-axis range to [0.6, 0.75], turns on grid lines for both axes,
    and draws every trace as a spline line of width 3.
    """
    layout_changes = {
        'yaxis_range': [0.6, 0.75],
        'xaxis': {'showgrid': True},
        'yaxis': {'showgrid': True},
    }
    trace_changes = {
        'line': {'width': 3, 'shape': 'spline'},
    }

    fig.update_layout(layout_changes)
    fig.update_traces(trace_changes)

    return fig


# Split-averaged curves for two methods; the dict maps the label shown in
# the plot to the method name used to locate the stored results.
# NOTE(review): update_style_training_curve defined above is not applied to
# this figure — confirm whether that is intended.
fig = run_average_training_curves({'CoxNet': 'coxnet', 'Gradient Boosting': 'gb'})
# Export a static SVG next to the notebook's data, then render inline.
fig.write_image('./data/training_curves.svg')
fig.show()
