In [3]:
import utils
import os
import pandas as pd

# Annual Charge Analysis

In [5]:
DS = 'm24p8'
chg = pd.Series(utils.pkl_load(f'data/{DS}/annual_charges_label_dict.pkl'))
chg

0_0        0
0_1        0
0_2        0
0_3        0
0_4        0
          ..
3140_54    0
3140_55    0
3140_56    0
3140_57    0
3140_58    0
Length: 159967, dtype: int64

In [None]:
DS = 'm24p8'
chg = pd.Series(utils.pkl_load(f'data/{DS}/annual_charges_label_dict.pkl'))

print(f"Mean: {chg.mean()}. Mean is heavily skewed by disproportionate expenses.\nMedian:{chg.quantile(0.5)}")
print("Distribution of chg between the 10th and 95th %ile:")
sns.distplot(chg[chg.between(chg.quantile(0.1), chg.quantile(0.95))])
plt.show()
print("Distribution of chg between the $1 and $20000:")
sns.distplot(chg[chg.between(1, 20000)])
plt.show()
print("Distribution of chg between the $1 and $5000:")
sns.distplot(chg[chg.between(1, 5000)])
plt.show()


In [None]:
print(f" $ value at 80%ile {chg.quantile(0.8)}")
print([f"{x}%ile: {chg.quantile(x/10)}" for x in range(0,10)])
print([f"{x/10}%ile: {chg.quantile(x/100)}" for x in range(80,100,2)])


# Spike analysis
- Spikes are underrepresented in the current m36p12 data. This analysis is not relevant

In [None]:
du.IBDSequenceGenerator(24,8).write_to_disk(self, 'm24p8', tgt_types='spikes', labels_only=True, seq_only=False)


# Plot spikes/cumulative charges

In [None]:
arr = [1,4,9,10,16,24,27,28,51]
# arr = chk[chk==3].sample(9).index

fig, axes = plt.subplots(3,3, figsize = (12,12))
arr = np.array(arr).reshape(3,-1)
for r in range(3):
    for c in range(3):
        df=lib[lib.AUTO_ID==arr[r][c]]
        xax = range(len(df))
        y = df['charges'].cumsum()
        titl = len(df)
        axes[r,c].plot(xax, y)
        axes[r,c].set_title(f"AID: {df.AUTO_ID.values[0]}, Months: {titl}")
fig.tight_layout()
plt.show()

# Train the model

In [None]:
python train.py -b128 -Tannual_charges -Aabnormal_labs,diagnostics,surgery -a0.9 -t0.6 -dm24p8_ds -m2y8m_4 -e300

In [None]:
nohup unbuffer python train.py -b128 -Tannual_charges -Aabnormal_labs,diagnostics,surgery -a0.2 -t0.2 -dm36p12_ds -m3y1y -e300 > models/3y1y.log &

In [None]:
nohup unbuffer python train.py -m 3yAA -dm36p12_AnnualAvg -b128 -e400 > models/3yAA.log &
nohup unbuffer python train.py -m 3yAA_Xian -dm36p12_XIAN_AnnualAvg -b192 -e350 > models/3yAA_XIAN.log &
nohup unbuffer python train.py -m 2yAA -d m24p8_AnnualAvg -b128 -e400 > models/2yAA.log &
nohup unbuffer python train.py -m 2yAA_XIAN -dm24p8_XIAN_AnnualAvg -b192 -e350 > models/2yAA_XIAN.log &

# Evaluate

## Get sequences and their predictions

In [168]:
import utils 
import numpy as np
import pandas as pd


_seq/{MODEL}/{i}.csv')

229200
tn     12939
tp       941
fp       393
fn        76
Name: tag, dtype: int64


In [None]:
# Get TP/TN/FN/FP sequences
dfs = [data(ix) for ix in test_results[test_results.tag=='fn'].index]
pd.concat(dfs, axis=0).to_csv(f'models/patient_seq/{MODEL}/FalseNegatives.csv')

## Classification Report

In [165]:
from sklearn.metrics import classification_report, roc_curve, auc
from pprint import pprint
from scipy.special import softmax
import torch

for MODEL in ['2y8m_3', '2y8m_4']:
    ds = utils.pkl_load('datasets/m24p8_ds/test.dataset')
    test_ids = pd.Series(ds.list_IDs)
    results = pd.DataFrame(np.array(utils.pkl_load(f'models/{MODEL}/test_eval.pkl')))
    results.columns = ['y', 'yhat', 'logit0', 'logit1', 'logit2']
    results['seq_id'] = test_ids

    report = classification_report(results.y, results.yhat, labels=[0,1,2], target_names=['lt_7k', 'lt_53k', 'gt_54k'], output_dict=True)

    y_dum = pd.get_dummies(results.y).values
    y_hats = softmax(results[['logit0', 'logit1', 'logit2']].values, axis=0)
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(y_dum[:, i], y_hats[:, i])
        roc_auc[i] = round(auc(fpr[i], tpr[i]), 4)

    print(f"Model: {MODEL}")
    print(f"Accuracy:{report['accuracy']}\nMultilclass ROC:{roc_auc}\n")
    del report['accuracy']
    print(pd.DataFrame.from_dict(report))
    print("\n=====================================================================================\n")


Model: 2y8m_3
Accuracy:0.9173595049071701
Multilclass ROC:{0: 0.8892, 1: 0.7774, 2: 0.9364}

                  lt_7k       lt_53k       gt_54k     macro avg  weighted avg
precision      0.981789     0.566239     0.665488      0.737839      0.933643
recall         0.931668     0.765159     0.879439      0.858755      0.917360
f1-score       0.956072     0.650839     0.757649      0.788187      0.922975
support    13888.000000  1039.000000  1070.000000  15997.000000  15997.000000


Model: 2y8m_4
Accuracy:0.9173595049071701
Multilclass ROC:{0: 0.8892, 1: 0.7774, 2: 0.9364}

                  lt_7k       lt_53k       gt_54k     macro avg  weighted avg
precision      0.981789     0.566239     0.665488      0.737839      0.933643
recall         0.931668     0.765159     0.879439      0.858755      0.917360
f1-score       0.956072     0.650839     0.757649      0.788187      0.922975
support    13888.000000  1039.000000  1070.000000  15997.000000  15997.000000


