In [1]:
import copy
from collections.abc import Iterable
import functools
import itertools
import operator
from matplotlib import pyplot as plt
import matplotlib as mpl

mpl.use('pgf')
plt.rcParams.update({
    "font.family": "serif",  # use serif/main font for text elements
    "text.usetex": True,     # use inline math for ticks
    "pgf.rcfonts": False,    # don't setup fonts from rc parameters
    "text.latex.preamble":  [r"""\usepackage{amssymb}""", r'\usepackage{amsmath}'],
    })
# mpl.verbose.level = 'debug-annoying'


import pandas as pd
from pandas.api.types import is_numeric_dtype
import numpy as np
import numpy_ext as npe
import math
import random
from pprint import pprint
from scipy.optimize import curve_fit
from scipy.stats import poisson
from scipy.sparse import hstack, vstack, csr_matrix
import scipy

from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from sklearn.impute import KNNImputer
from sklearn.preprocessing import Normalizer, StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn import metrics
import joblib

import seaborn as sns
import utils
import sys

from config import demographics, vital_sign_vars, lab_vars, treatment_vars, vent_vars, guideline_vars, ffill_windows_clinical, SAMPLE_TIME_H
from config import fio2_bins, peep_bins, tv_bins

In [2]:
# TEXTWIDTH=390.0 # AI in Medicine
TEXTWIDTH=341.43289 # Dissertation
inches_per_pt = 1 / 72.27
MAX_FIGWIDTH = TEXTWIDTH * inches_per_pt

greedy_policy_file = 'models2/mcp_greedy_policy_{}_{}_{}_{}.bin'
sm_policy_file = 'models2/mcp_softmax_policy_{}_{}_{}_{}.bin'
behavior_policy_train_file = 'models2/clinicians_policy_train_{}{}.bin'
behavior_policy_test_file = 'models2/clinicians_policy_test_{}{}.bin'
behavior_policy_file = 'models2/clinicians_policy_train_test_{}{}.bin'

test_set_file = 'data/test_unshaped_traj_{}.csv'
train_set_file = 'data/train_unshaped_traj_{}.csv'

In [3]:
sm_safe_policy = joblib.load(sm_policy_file.format(0, 'none', 0.0, 0.0))
sm_unsafe_policy = joblib.load(sm_policy_file.format(0, 'none', 0.0, 1.0))
behavior_policy = joblib.load(behavior_policy_file.format(0, ''))
greedy_safe = joblib.load(greedy_policy_file.format(0, 'none', 0.0, 0.0))
greedy_unsafe = joblib.load(greedy_policy_file.format(0, 'none', 0.0, 1.0))

test_set = pd.read_csv(test_set_file.format(0), low_memory=False)

In [4]:
actions_greedy_safe = test_set.state.apply(lambda x: greedy_safe[x])
actions_greedy_unsafe = test_set.state.apply(lambda x: greedy_unsafe[x])
observed_action_counts = test_set.action_discrete.value_counts()
actions_observed = [observed_action_counts[i] if i in observed_action_counts else 0 for i in range(343)]

In [5]:
actions_greedy_safe = pd.DataFrame({'sum': np.nansum(np.array(actions_greedy_safe.to_list()), axis=0),})
actions_greedy_unsafe = pd.DataFrame({'sum': np.nansum(np.array(actions_greedy_unsafe.to_list()), axis=0),})
actions_behavior = pd.DataFrame({'sum': actions_observed,})

In [6]:
actions_greedy_safe.loc[:, ['tv_bin', 'fio2_bin', 'peep_bin']] = list(map(utils.to_discrete_action_bins, actions_greedy_safe.index))
actions_greedy_unsafe.loc[:, ['tv_bin', 'fio2_bin', 'peep_bin']] = list(map(utils.to_discrete_action_bins, actions_greedy_unsafe.index))
actions_behavior.loc[:, ['tv_bin', 'fio2_bin', 'peep_bin']] = list(map(utils.to_discrete_action_bins, actions_behavior.index))

  actions_greedy_safe.loc[:, ['tv_bin', 'fio2_bin', 'peep_bin']] = list(map(utils.to_discrete_action_bins, actions_greedy_safe.index))
  actions_greedy_unsafe.loc[:, ['tv_bin', 'fio2_bin', 'peep_bin']] = list(map(utils.to_discrete_action_bins, actions_greedy_unsafe.index))
  actions_behavior.loc[:, ['tv_bin', 'fio2_bin', 'peep_bin']] = list(map(utils.to_discrete_action_bins, actions_behavior.index))


In [7]:
fmt=".1f"
greedy_fmt = "0.0f"

def paired_heatmap(to_plot, title="", vs=(None, None)):
    fig, axs = plt.subplots(figsize=(MAX_FIGWIDTH, 2.5), ncols=3, gridspec_kw=dict(width_ratios=[20,20,1],hspace=0,wspace=.05))
    hm1 = to_plot.pivot_table(values='sum', aggfunc='sum', columns=['peep_bin',], index='fio2_bin')
    hm2 = to_plot.pivot_table(values='sum', aggfunc='sum', columns=['tv_bin',], index='fio2_bin')
    if vs == (None, None):
        vmin = min(hm1.min().min(), hm2.min().min())
        vmax = max(hm1.max().max(), hm2.max().max())
    else:
        vmin, vmax = vs
    g1 = sns.heatmap(hm1,
                annot=True,
                fmt=greedy_fmt,
                annot_kws={'fontsize': 7},
                vmin=vmin,
                vmax=vmax,
                cbar=False,
               ax=axs[0])
    g1.set_xlabel('PEEP')
    g1.set_ylabel('FiO$_2$')
    g2 = sns.heatmap(hm2,
                annot=True,
                fmt=greedy_fmt,
                vmin=vmin,
                vmax=vmax,
                annot_kws={'fontsize': 7},
               ax=axs[1],
               cbar=False)
    g2.set_xlabel('Vt$_{set}$')
    g2.set(yticklabels=[])
    g2.set_ylabel('')
    fig.colorbar(axs[1].collections[0], cax=axs[2])
    plt.show()
    plt.suptitle(title, y= 0.95)
    plt.subplots_adjust(bottom=0.18)
#     fig.tight_layout()
    return fig, (vmin, vmax)

fig, vs = paired_heatmap(actions_behavior, 'Observed')
plt.savefig('/tmp/actions_observed.pdf')
plt.savefig('/tmp/actions_observed.png', dpi=1200)


fig, _ = paired_heatmap(actions_greedy_unsafe, 'QL$_D$ unconstrained', vs=vs)
plt.savefig('/tmp/actions_unsafe.pdf')
plt.savefig('/tmp/actions_unsafe.png', dpi=1200)

fig, _ = paired_heatmap(actions_greedy_safe, 'QL$_D$ compliant/Q-function', vs=vs)
plt.savefig('/tmp/actions_safe.pdf')
plt.savefig('/tmp/actions_safe.png', dpi=1200)


    

# fig, axs = plt.subplots(figsize=(MAX_FIGWIDTH, 2.5), ncols=3, gridspec_kw=dict(width_ratios=[20,20,1],wspace=.05))
# hm1 = actions_greedy_safe.pivot_table(values='sum', aggfunc='sum', columns=['peep_bin',], index='fio2_bin')
# hm2 = actions_greedy_safe.pivot_table(values='sum', aggfunc='sum', columns=['tv_bin',], index='fio2_bin')
# vmin = min(hm1.min().min(), hm2.min().min())
# vmax = max(hm1.max().max(), hm2.max().max())
# sns.heatmap(hm1,
#             annot=True,
#             fmt=greedy_fmt,
#             annot_kws={'fontsize': 7},
#             vmin=vmin,
#             vmax=vmax,
#             cbar=False,
#            ax=axs[0])
# sns.heatmap(hm2,
#             annot=True,
#             fmt=greedy_fmt,
#             annot_kws={'fontsize': 7},
#            ax=axs[1],
#            cbar=False)
# axs[1].set_ylabel('')
# fig.colorbar(axs[1].collections[0], cax=axs[2])
# fig.tight_layout()
# plt.show()

# sns.heatmap(actions_greedy_unsafe.pivot_table(values='sum', aggfunc='sum', columns=['fio2_bin',], index='peep_bin'), annot=True, fmt=greedy_fmt)
# plt.show()
# sns.heatmap(actions_behavior.pivot_table(values='sum', aggfunc='sum', columns=['fio2_bin',], index='peep_bin'), annot=True, fmt=greedy_fmt)
# plt.show()

  plt.show()
  plt.show()
  plt.show()


In [8]:
hm1.max(axis=)

SyntaxError: invalid syntax (3103198376.py, line 1)

In [None]:
actions_greedy_safe

In [None]:
sm_unsafe_policy