## Code for qualitative analysis - Aggressive treatments

Code is based on: https://github.com/clinicalml/trajectory-inspection/tree/main/notebooks

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm

import matplotlib.pyplot as plt

import cvxpy as cp
import time

from utils import reward_direct_policy_evaluation, cost_direct_policy_evaluation

Default settings for this analysis are defined in the next two cells.

In [3]:
# Seed for NumPy's global RNG; also selects the per-seed input/output folders below.
SEED = 0
# NOTE(review): LAMBDA_R and LAMBDA_C are not referenced in the visible cells — confirm they are still needed.
LAMBDA_R = 1.0
LAMBDA_C = 1.0

# Pair interpolated into the solution file name read by load_soln_policy
# (`cpi_{coeffs[0]}_{coeffs[1]}_{eps}.p`).
COEFFS = (0.0, 1.0)

# Last component of that same solution file name.
EPS = 10.0 

In [4]:
# ----------- Defaults MDP parameters -------------
# The analysis loops over states 0..nS-1 and samples among nA actions.
# Indices 750 and 751 are named as death/survival states; the loaded policies
# have 752 rows in the run recorded below.
nS, nA = 750, 25
DEATH_STATE = 750
SURVIVAL_STATE = 751

# NOTE(review): gamma is not referenced in the visible cells — presumably the discount factor; confirm.
gamma = 0.99

# ----- User args 
# NOTE(review): SEED_LIST is not referenced in the visible cells; the paths below use SEED.
SEED_LIST = 0
# In this notebook these two values are only used to build OUTPUT_PATH,
# i.e. they select which results folder is read.
FREQUENCY_THRESHOLD = 10.0
COST_FOR_RARE_DECISION = 10.0



# -------- Folder Paths -------------
# Placeholder — replace with the real root folder before running.
basepath = 'enter/path/here'

# Path variables
# IMPORT_PATH: per-seed MDP pickles and trajectories.
# OUTPUT_PATH: per-seed / per-setting cost matrix and saved policies.
IMPORT_PATH = f'{basepath}/m_hat/{SEED}'
OUTPUT_PATH = f'{basepath}/output/{SEED}/freq_{FREQUENCY_THRESHOLD}_cost_{COST_FOR_RARE_DECISION}'

In [5]:
# Seed NumPy's global RNG: select_action(mode='sample') draws from it via np.random.choice.
np.random.seed(SEED)

### Load the MDP stats

In [6]:
# NOTE: pickle.load can execute arbitrary code embedded in the file — only
# load artifacts produced by a trusted run of this project.
# Each file is opened in a `with` block so its handle is closed right after
# loading instead of being left for the garbage collector.
with open(f"{IMPORT_PATH}/MDP_mat.p", "rb") as fh:
    P_mat, R_mat = pickle.load(fh)
with open(f"{IMPORT_PATH}/MDP_counts.p", "rb") as fh:
    counts_mat = pickle.load(fh)

# load the defaults for this seed
with open(f"{OUTPUT_PATH}/C_mat.p", "rb") as fh:
    C_mat = pickle.load(fh)
with open(f"{OUTPUT_PATH}/pi_baseline.p", "rb") as fh:
    pi_baseline = pickle.load(fh)

In [7]:
def load_soln_policy(coeffs, eps):
    """Load a saved solution policy from ``OUTPUT_PATH``.

    The file read is ``{OUTPUT_PATH}/cpi_{coeffs[0]}_{coeffs[1]}_{eps}.p``.

    Args:
        coeffs: Indexable with at least two entries; both are interpolated
            into the file name.
        eps: Value interpolated as the last component of the file name.

    Returns:
        The object unpickled from that file.

    Raises:
        FileNotFoundError: If no solution file exists for this
            ``(coeffs, eps)`` combination.
    """
    sol_name = f'cpi_{coeffs[0]}_{coeffs[1]}_{eps}.p'

    # `with` closes the handle deterministically (the previous bare
    # `open(...)` was never closed).
    # NOTE: pickle.load can execute arbitrary code — only load trusted files.
    with open(f'{OUTPUT_PATH}/{sol_name}', 'rb') as fh:
        pi_solution = pickle.load(fh)

    return pi_solution

In [8]:
# Policy saved as `regular_PI_solution.p`; reported below as the
# "regular policy (unconstrained)". Opened via `with` so the handle is closed.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(f'{OUTPUT_PATH}/regular_PI_solution.p', 'rb') as fh:
    pi_regular = pickle.load(fh)

In [9]:
# Load the saved solution policy for the default (COEFFS, EPS) pair.
pi_sopt = load_soln_policy(COEFFS, EPS)

# Show the policy's shape as the cell output.
pi_sopt.shape

(752, 25)

Load the training trajectories corresponding to the seed

In [10]:
# Opened via `with` so the file handle is closed once the trajectories are loaded.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(f'{IMPORT_PATH}/trajDr_tr.pkl', 'rb') as fh:
    traj_tr = pickle.load(fh)
print('Effective sample size of train set', len(traj_tr))

Effective sample size of train set 14667


In [11]:
# counts_mat summed over its last axis; the result is indexed below as
# orig_count_sa[state, action].
orig_count_sa = counts_mat.sum(axis=-1)

In [12]:
# Per-state totals (orig_count_sa summed over its last axis); used below as the
# denominator when turning a state-action count into a proportion.
orig_state_count = orig_count_sa.sum(-1)

## Qualitative analysis starts here

In [13]:
def fix_p(p):
    """Rescale ``p`` so that its entries sum to one.

    When the sum is already exactly 1.0 the input object is returned as is;
    otherwise a new array ``p * (1 / p.sum())`` is returned.
    """
    total = p.sum()
    if total == 1.0:
        return p
    return p * (1. / total)

In [14]:
def select_action(pi_s, mode='sample'):
    """Pick one action index from the action distribution of a single state.

    mode: sample/argmax
        -> sample: draw an action from the probability distribution, after
           clipping it to [0, 1] and renormalising to absorb round-off error
        -> argmax: return the index of the largest probability

    Any other mode raises NotImplementedError.
    """
    if mode == 'argmax':
        return np.argmax(pi_s)

    if mode != 'sample':
        raise NotImplementedError("No other mode implemented yet!")

    # Clip round-off noise, renormalise, then draw from NumPy's global RNG.
    probs = fix_p(np.clip(pi_s, a_min=0, a_max=1.0))
    return np.random.choice(np.arange(nA), p=probs)

In [15]:
def get_percentages(pi, mode='sample'):
    """Return, per state, how often the policy's chosen action was observed.

    For every state ``s`` in ``range(nS)`` an action ``a`` is chosen with
    ``select_action(pi[s], mode)`` and the entry for that state is
    ``orig_count_sa[s, a] / orig_state_count[s]`` (a fraction, not a value
    scaled to 0-100).

    Returns a list of length ``nS`` in state order.
    """
    return [
        orig_count_sa[s, select_action(pi[s], mode)] / float(orig_state_count[s])
        for s in range(nS)
    ]

### Calculate the state occurrence proportion corresponding to top-K

In [16]:
def calc_state_occurrence_proportion(top_num, pi, mode='sample'):
    """Share of all observed state visits that fall in the selected states.

    States are ranked in ascending order of the fraction returned by
    ``get_percentages(pi, mode)``; the first ``top_num`` of them are selected
    and their ``orig_state_count`` entries are summed and divided by the
    total of ``orig_state_count``.
    """
    ranking = np.argsort(np.array(get_percentages(pi, mode)))
    selected_visits = sum(orig_state_count[state] for state in ranking[:top_num])
    return selected_visits / float(np.sum(orig_state_count))

## Calculate RL and Common avg counts

In [17]:
def calc_rl_common_avg_count(pi, top_num, mode='sample'):
    """Print average observed counts for the policy's action vs. the most common action.

    States are ranked in ascending order of ``get_percentages(pi, mode)`` and
    the first ``top_num`` are selected. For each selected state the policy's
    action is chosen with ``select_action(pi[idx], mode)``; its observed count
    ``orig_count_sa[idx, a]`` and the largest count in ``orig_count_sa[idx]``
    are accumulated, and both averages are printed.

    Args:
        pi: Policy indexed as ``pi[state]`` -> action distribution.
        top_num: Number of lowest-ranked states to analyse.
        mode: 'sample' or 'argmax'; used for BOTH the ranking and the
            per-state action choice.

    Returns:
        ``np.unique(counts, return_counts=True)`` over the policy-action
        counts of the selected states.
    """
    # Rank with the caller's mode. Previously `mode` was not forwarded, so the
    # ranking always used sampled actions even when mode='argmax'.
    rl_percentages = get_percentages(pi, mode)

    # sort ascending and keep the first top_num states
    sorted_idxs = np.argsort(np.array(rl_percentages))
    top_idxs = sorted_idxs[0:top_num]

    # Slicing clamps to the number of ranked states, so size buffers and
    # averages by what was actually selected rather than by top_num.
    n_selected = len(top_idxs)

    total_rl_count = 0
    total_common_count = 0
    counts = np.empty(n_selected)

    for j, idx in enumerate(top_idxs):
        # NOTE: in 'sample' mode this is a fresh draw, independent of the draw
        # used for the ranking above.
        a = select_action(pi[idx], mode)

        counts[j] = orig_count_sa[idx, a]
        total_rl_count += orig_count_sa[idx, a]
        total_common_count += np.max(orig_count_sa[idx])

    print(f'RL-Alg avg count: {total_rl_count/float(n_selected)}')
    print(f'Most common avg count: {total_common_count/float(n_selected)}')

    return np.unique(counts,return_counts=True)

## Get RL and Common Percentage

In [18]:
def calc_rl_common_avg_perc(pi, top_num, mode='sample'):
    """Print average observed fractions for the policy's action vs. the most common action.

    States are ranked in ascending order of ``get_percentages(pi, mode)`` and
    the first ``top_num`` are selected. For each selected state two fractions
    of ``orig_state_count[idx]`` are computed: the count of the policy's
    action and the largest action count. The mean of each is printed (values
    are fractions in [0, 1], not scaled to 0-100).

    Args:
        pi: Policy indexed as ``pi[state]`` -> action distribution.
        top_num: Number of lowest-ranked states to analyse.
        mode: 'sample' or 'argmax'; used for BOTH the ranking and the
            per-state action choice.

    Returns:
        None; results are printed.
    """
    # Rank with the caller's mode. Previously `mode` was not forwarded, so the
    # ranking always used sampled actions even when mode='argmax'.
    rl_percentages = get_percentages(pi, mode)

    sorted_idxs = np.argsort(np.array(rl_percentages))
    top_idxs = sorted_idxs[0:top_num]

    # Size by the number of states actually selected: with np.empty(top_num)
    # and top_num larger than the number of states, uninitialised slots would
    # leak into the averages.
    n_selected = len(top_idxs)
    rl_percs = np.empty(n_selected)
    common_percs = np.empty(n_selected)

    for j, idx in enumerate(top_idxs):
        # NOTE: in 'sample' mode this is a fresh draw, independent of the draw
        # used for the ranking above.
        a = select_action(pi[idx], mode)

        state_total = float(orig_state_count[idx])
        rl_percs[j] = orig_count_sa[idx, a]/state_total
        common_percs[j] = np.max(orig_count_sa[idx])/state_total

    print(f'Avg RL percentage: {np.average(rl_percs)}')
    print(f'Avg Common percentage: {np.average(common_percs)}')

## Top-K Quantify

In [19]:
# Values of top_num to report on; 750 equals nS, so every state is included.
TOP_K_LIST  = [750]

In [20]:
# For each top-K value and each action-selection mode, print the same three
# statistics for every policy, in the fixed order regular -> SOPT -> baseline.
for top_num in tqdm(TOP_K_LIST):

    print('==================================')
    print(f'Analysing for top-{top_num} states now')
    print('==================================')

    for mode in ['sample', 'argmax']:

        print(f'----    Mode:{mode}    ----')

        # (header line, policy) pairs in reporting order
        policy_reports = (
            ('For regular policy (unconstrained) iteration:', pi_regular),
            ('For SOPT iteration:', pi_sopt),
            ('For Baseline iteration:', pi_baseline),
        )

        for position, (header, policy) in enumerate(policy_reports):
            # a separator goes between consecutive policy reports only
            if position > 0:
                print('---------')

            print(header)
            result = calc_state_occurrence_proportion(top_num, policy, mode)
            print(f'Top {top_num} states with smallest proportions make up {result*100}% of transitions.')
            print('Count stats:')
            calc_rl_common_avg_count(policy, top_num, mode)
            print('Percentage stats:')
            calc_rl_common_avg_perc(policy, top_num, mode)

    # short pause at the end of each top-K iteration (presumably to let the
    # tqdm bar render cleanly after the printed output)
    time.sleep(1)

  0%|          | 0/1 [00:00<?, ?it/s]

Analysing for top-750 states now
----    Mode:sample    ----
For regular policy (unconstrained) iteration:
Top 750 states with smallest proportions make up 100.0% of transitions.
Count stats:
RL-Alg avg count: 5.348
Most common avg count: 138.268
Percentage stats:
Avg RL percentage: 0.0300642940017241
Avg Common percentage: 0.5147115615063391
---------
For SOPT iteration:
Top 750 states with smallest proportions make up 100.0% of transitions.
Count stats:
RL-Alg avg count: 54.812
Most common avg count: 138.268
Percentage stats:
Avg RL percentage: 0.22320083497254017
Avg Common percentage: 0.5147115615063391
---------
For Baseline iteration:
Top 750 states with smallest proportions make up 100.0% of transitions.
Count stats:
RL-Alg avg count: 94.516
Most common avg count: 138.268
Percentage stats:
Avg RL percentage: 0.36112408192784856
Avg Common percentage: 0.5147115615063391
----    Mode:argmax    ----
For regular policy (unconstrained) iteration:
Top 750 states with smallest proporti

100%|██████████| 1/1 [00:02<00:00,  2.39s/it]


## Vasopressor claim

In [21]:
def calc_vaso_stats(pi, top_num=750, mode='sample'):
    """Print how the policy's vasopressor level compares with common practice.

    States are ranked in ascending order of ``get_percentages(pi, mode)`` and
    the first ``top_num`` are selected. For each selected state the policy's
    action (``select_action(pi[idx], mode)``) and the most frequently observed
    action (``argmax`` of ``orig_count_sa[idx]``) are reduced to a vasopressor
    level via ``action % 5``. Three 5-bin histograms are printed:

    1. vasopressor level of the most common action,
    2. vasopressor level of the policy's action, restricted to states whose
       most common action has level 0,
    3. vasopressor level of the policy's action over all selected states,

    followed by a one-sentence summary.

    Args:
        pi: Policy indexed as ``pi[state]`` -> action distribution.
        top_num: Number of lowest-ranked states to analyse.
        mode: 'sample' or 'argmax', forwarded to ranking and action choice.

    Returns:
        None; results are printed.
    """
    rl_percentages = get_percentages(pi, mode)
    # sort
    sorted_idxs = np.argsort(np.array(rl_percentages))

    top_idxs = sorted_idxs[0:top_num]

    common_vasopressor_counts = np.zeros(5,dtype=np.int32)
    rl_vasopressor_counts = np.zeros(5,dtype=np.int32)
    total_rl_vasopressor_counts = np.zeros(5,dtype=np.int32)

    for idx in top_idxs:
        rl_action = select_action(pi[idx], mode)
        common_action = np.argmax(orig_count_sa[idx])

        # An action index a is read as (iv, vaso) = (a // 5, a % 5); only the
        # vasopressor component is needed here.
        common_vaso = int(common_action%5)
        rl_vaso = int(rl_action%5)

        common_vasopressor_counts[common_vaso] += 1
        if common_vaso == 0:
            rl_vasopressor_counts[rl_vaso] += 1
        total_rl_vasopressor_counts[rl_vaso] += 1

    print(common_vasopressor_counts)
    print(rl_vasopressor_counts)
    print(total_rl_vasopressor_counts)

    n_zero_common = common_vasopressor_counts[0]
    n_rl_gives_vaso = n_zero_common - rl_vasopressor_counts[0]
    # Guard the percentage: with no zero-vasopressor states the unguarded
    # division yields NaN/inf and a RuntimeWarning.
    if n_zero_common > 0:
        rl_gives_vaso_pct = (n_rl_gives_vaso/float(n_zero_common))*100
    else:
        rl_gives_vaso_pct = 0.0

    msg = (f'For {n_zero_common} of these {top_num} states, common practice involves zero vasopressors.'
           f' Yet, the RL policy recommends vasopressors in {n_rl_gives_vaso} of those states'
           f' ({rl_gives_vaso_pct:0.2f} %)'
           f', with {sum(rl_vasopressor_counts[-2:])} of those recommendations being large doses, which we define as those in the upper 50th percentile of nonzero amounts.'
        )
    print(msg)

In [22]:
# Vasopressor statistics for the regular (unconstrained) policy, using its
# argmax action in each of the 750 states.
calc_vaso_stats(pi_regular, top_num=750, mode='argmax')

[722   1   1  14  12]
[160 137 130 139 156]
[175 142 134 142 157]
For 722 of these 750 states, common practice involves zero vasopressors. Yet, the RL policy recommends vasopressors in 562 of those states (77.84 %), with 295 of those recommendations being large doses, which we define as those in the upper 50th percentile of nonzero amounts.


In [23]:
# Same statistics for the SOPT solution policy (argmax action, 750 states).
calc_vaso_stats(pi_sopt,  top_num=750, mode='argmax')

[722   1   1  14  12]
[588  36  28  46  24]
[595  37  31  53  34]
For 722 of these 750 states, common practice involves zero vasopressors. Yet, the RL policy recommends vasopressors in 134 of those states (18.56 %), with 70 of those recommendations being large doses, which we define as those in the upper 50th percentile of nonzero amounts.
