# LoS Evaluation

Evaluate the length of stay (LoS) prediction model using mean absolute error (MAE), median absolute error, and limits of agreement.

In [None]:
import os
import pickle
from matplotlib import pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
from scipy import stats

import sys
sys.path.append('..')

from ltss import los_model, risk_model, vectorise
from ltss.utils import reshape_vector, vector_to_dict
from training.loader import DataHandler

## Generate results using LOS and CDF Models

In [None]:
# Load both trained models; predictions from each are combined per validation entry below
LOS_MODEL = los_model.init_model(model_file='../training/mod_ep_110')
CDF_MODEL = risk_model.init_model(model_file='../training/trained_shuffle_100_v2.pkl')
# Single definition of the cache location so the load and save paths cannot drift apart
CHECKPOINT_FILE = 'checkpoint.pkl'
if os.path.exists(CHECKPOINT_FILE):
    print('Loading Checkpoint')
    # NOTE: unpickling can execute arbitrary code - only load checkpoints written by this notebook
    with open(CHECKPOINT_FILE, 'rb') as f:
        d = pickle.load(f)
    print(d.keys(), 'containing', len(d['dataset']), 'entries')
    dataset = d['dataset']
    true_los = d['true_los']
    results = d['results']
else:
    # Get a DataHandler with the same shuffle and seed settings as used in training
    loader = DataHandler('../NHSX Polygeist data 1617 to 2021 v2.csv', shuffle=True, fixed_seed=100, reshape=False)
    print(f'Loaded {loader}')
    # Use the validation cut of the data only
    data, true_los = loader.get_validation()
    true_los = true_los.cpu().detach().numpy()
    # Flush streams to sync before using tqdm again
    sys.stdout.flush()
    sys.stderr.flush()
    # Get classification for validation data through both models at once
    results = []
    for vector in tqdm(data):
        # Convert to the dict representation that both models take as input
        vector_dict = vector_to_dict(vector)
        forecast = los_model.get_prediction(LOS_MODEL, vector_dict)
        # The LoS forecast is fed into the risk model as its day prediction
        risk_predictions = risk_model.get_prediction(CDF_MODEL, vector_dict,
                                                     ai_day_prediction=forecast.get('PREDICTED_LOS'))
        # Store the merged output of both models as one record per entry
        results.append(dict(forecast, **risk_predictions))
    dataset = data.cpu().detach().numpy()
    # Save checkpoint for fast reuse
    with open(CHECKPOINT_FILE, 'wb') as f:
        pickle.dump(dict(dataset=dataset, true_los=true_los, results=results), f)

In [None]:
# Pull the LoS forecast out of every per-entry result record, then pack into one array
los_forecasts = [entry['PREDICTED_LOS'] for entry in results]
predicted_los = np.asarray(los_forecasts)

# Report both counts so a mismatch between predictions and ground truth is visible
print(f'Generated {len(predicted_los)} LoS predictions from {len(true_los)} ground truth entries')

## Evaluate Error Rate

In [None]:
# Absolute per-entry difference between the ground truth and the forecast
abs_error = np.abs(true_los - predicted_los)
print('Mean absolute error:', np.mean(abs_error))
# Note: this is the median of the absolute error
print('Median error:', np.median(abs_error))

In [None]:
# Split entries on a true LoS threshold of 15 (presumably days - confirm units against the loader)
short_stays = true_los < 15
long_stays = ~short_stays

# Absolute error within each subset
short_abs_error = np.abs(true_los[short_stays] - predicted_los[short_stays])
long_abs_error = np.abs(true_los[long_stays] - predicted_los[long_stays])

print('Mean absolute error (short):', np.mean(short_abs_error))
print('Median error (short):', np.median(short_abs_error))
print('Mean absolute error (long):', np.mean(long_abs_error))
print('Median error (long):', np.median(long_abs_error))

## Evaluate Limits of Agreement

In [None]:
%matplotlib notebook

def evaluate(true_los, predicted_los, title, alpha=0.2, cats=None):
    """Print limits of agreement and draw an agreement plot for LoS predictions.

    The difference ``true_los - predicted_los`` is scattered against
    ``true_los``, with a solid line at the mean difference and dashed lines at
    the mean plus and minus 1.96 standard deviations.

    Args:
        true_los: Ground-truth lengths of stay; must support element-wise
            subtraction with ``predicted_los``.
        predicted_los: Predicted lengths of stay, aligned with ``true_los``.
        title: Figure title, also used as the prefix of every printed line.
        alpha: Marker opacity forwarded to ``plt.scatter``.
        cats: Optional per-point colour values forwarded to ``plt.scatter``
            as ``c``.

    Returns:
        None. Results are printed and shown as a new matplotlib figure.
    """
    irr_x = true_los
    # Positive values mean the prediction was shorter than the true stay
    irr_y = true_los - predicted_los
    mean = np.mean(irr_y)
    stdev = np.std(irr_y)
    limits = [mean + 1.96 * stdev, mean - 1.96 * stdev]
    # Printed values are the half-width of the band only; they are not offset by the mean
    for confidence in [0.5, 0.75, 0.9, 0.95]:
        # Two-sided normal quantile: number of standard deviations covering `confidence`
        sds = stats.norm.ppf(1 - (1 - confidence) / 2)
        print(f'{title} {int(100 * confidence)}% limits of agreement: ±{sds*stdev:0.2f} days')
    plt.figure()
    plt.title(title)
    plt.scatter(irr_x, irr_y, alpha=alpha, c=cats)
    # Horizontal guides span the full observed range of the x axis
    x_min, x_max = np.min(irr_x), np.max(irr_x)
    # Raw strings keep the mathtext backslashes literal (a plain '\m' or '\s' is an invalid escape)
    plt.hlines(mean, x_min, x_max, linestyles='solid', color='#B22330',
               label=r'$\mu$ ({:0.2f} days)'.format(mean))
    plt.hlines(limits, x_min, x_max, linestyles='dashed', color='#03716E',
               label=r'95% LoA ($\sigma$ {:0.2f} days gives {:0.2f})'.format(stdev, 1.96*stdev))
    plt.xlabel('True Length of Stay')
    plt.ylabel('Agreement')
    plt.legend(loc='upper left')
    plt.show()

In [None]:
# Agreement analysis over every entry (no subsetting), using a lower marker opacity than the 0.2 default
evaluate(true_los, predicted_los, 'All Validation Data', alpha=0.05)