In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
np.seterr(divide = 'ignore') 
from scipy.linalg import block_diag
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import cf.counterfactual as cf
import networkx as nx
import copy
import pickle

In [None]:
class Action(object):
    """A joint treatment decision over three binary treatments.

    The treatments are antibiotics (A), ventilation (E) and vasopressors (V).
    Each one is either given (1) or not given (0), so there are
    2**3 = NUM_ACTIONS_TOTAL = 8 possible actions.

    An action is identified by an integer index whose binary digits are the
    three treatment flags, antibiotic being the most significant bit:

        index = 4*antibiotic + 2*ventilation + 1*vasopressors

        0 (000) -> no treatment      4 (100) -> A
        1 (001) -> V                 5 (101) -> A, V
        2 (010) -> E                 6 (110) -> A, E
        3 (011) -> E, V              7 (111) -> A, E, V

    Instances are hashable and compare equal when all three flags match, so
    they can be used as dictionary keys or set members.
    """

    NUM_ACTIONS_TOTAL = 8
    ANTIBIOTIC_STRING = "antibiotic"
    VENT_STRING = "ventilation"
    VASO_STRING = "vasopressors"
    ACTION_VEC_SIZE = 3

    def __init__(self, selected_actions = None, action_idx = None):
        """Build an action from treatment names or from an action index.

        Exactly one of the two arguments must be given.

        :param selected_actions: container of treatment strings
            (ANTIBIOTIC_STRING, VENT_STRING, VASO_STRING); a treatment is
            switched on iff its string is contained in it.
        :param action_idx: integer-valued index in [0, NUM_ACTIONS_TOTAL).
        :raises AssertionError: if both or neither argument is given, or if
            action_idx is outside the valid range.
        """
        assert (selected_actions is not None and action_idx is None) \
            or (selected_actions is None and action_idx is not None), \
            "must specify either set of action strings or action index"

        if selected_actions is not None:
            self.antibiotic = int(Action.ANTIBIOTIC_STRING in selected_actions)
            self.ventilation = int(Action.VENT_STRING in selected_actions)
            self.vasopressors = int(Action.VASO_STRING in selected_actions)
        else:
            # Coerce to a plain Python int so the flags (and therefore the
            # hash) are never NumPy scalars, whatever integer type was passed.
            idx = int(action_idx)
            assert 0 <= idx < Action.NUM_ACTIONS_TOTAL, \
                "action index must be in [0, {}), got {}".format(
                    Action.NUM_ACTIONS_TOTAL, action_idx)
            self.antibiotic = (idx >> 2) & 1
            self.ventilation = (idx >> 1) & 1
            self.vasopressors = idx & 1

    def __eq__(self, other):
        """Two actions are equal iff all three treatment flags match."""
        return isinstance(other, self.__class__) and \
            self.antibiotic == other.antibiotic and \
            self.ventilation == other.ventilation and \
            self.vasopressors == other.vasopressors

    def __ne__(self, other):
        return not self.__eq__(other)

    def get_action_idx(self):
        """Return the action index 4*antibiotic + 2*ventilation + vasopressors.

        :raises AssertionError: if any flag is not 0 or 1.
        """
        assert self.antibiotic in (0, 1)
        assert self.ventilation in (0, 1)
        assert self.vasopressors in (0, 1)
        return 4*self.antibiotic + 2*self.ventilation + self.vasopressors

    def __hash__(self):
        # __hash__ must return a built-in int; a NumPy integer (which the
        # flags may be if they were assigned from outside) raises TypeError.
        return int(self.get_action_idx())

    def get_selected_actions(self):
        """Return the set of treatment strings that are switched on."""
        selected_actions = set()
        if self.antibiotic == 1:
            selected_actions.add(Action.ANTIBIOTIC_STRING)
        if self.ventilation == 1:
            selected_actions.add(Action.VENT_STRING)
        if self.vasopressors == 1:
            selected_actions.add(Action.VASO_STRING)
        return selected_actions

    def get_abbrev_string(self):
        """Return the active treatments as a string over 'A', 'E', 'V'.

        A = antibiotics, E = ventilation, V = vasopressors, always in that
        order; the empty string means no treatment.
        """
        output_str = ''
        if self.antibiotic == 1:
            output_str += 'A'
        if self.ventilation == 1:
            output_str += 'E'
        if self.vasopressors == 1:
            output_str += 'V'
        return output_str

    def get_action_vec(self):
        """Return the flags as a (3, 1) array: antibiotic, ventilation, vasopressors."""
        return np.array([[self.antibiotic], [self.ventilation], [self.vasopressors]])

In [None]:
class State(object):
    """Categorical patient state of the simulator.

    A state consists of seven observed variables plus one hidden variable:

        hr_state          heart rate             3 categories (1 = normal)
        sysbp_state       systolic blood press.  3 categories (1 = normal)
        percoxyg_state    percent oxygen         2 categories (1 = normal)
        glucose_state     glucose                5 categories (2 = normal)
        antibiotic_state  antibiotics active     2 categories
        vaso_state        vasopressors active    2 categories
        vent_state        ventilation active     2 categories
        diabetic_idx      diabetes (hidden)      2 categories

    A state can be addressed by a mixed-radix integer index, most significant
    variable first, in one of three index spaces:

        'obs'       720 states:  the seven observed variables, in the order above
        'proj_obs'  144 states:  'obs' without glucose
        'full'      1440 states: diabetic_idx followed by the 'obs' variables

    Equality and hashing use the seven observed variables only; diabetic_idx
    is not part of either.
    """

    NUM_OBS_STATES = 720
    NUM_HID_STATES = 2  # Binary value of diabetes
    NUM_PROJ_OBS_STATES = int(720 / 5)  # Marginalizing over glucose
    NUM_FULL_STATES = int(NUM_OBS_STATES * NUM_HID_STATES)

    # (attribute name, number of categories) per index space, most
    # significant digit first. Shared by set_state_by_idx and get_state_idx
    # so that encoding and decoding cannot drift apart.
    _OBS_FIELDS = (
        ('hr_state', 3),
        ('sysbp_state', 3),
        ('percoxyg_state', 2),
        ('glucose_state', 5),
        ('antibiotic_state', 2),
        ('vaso_state', 2),
        ('vent_state', 2),
    )
    _PROJ_OBS_FIELDS = _OBS_FIELDS[:3] + _OBS_FIELDS[4:]
    _FULL_FIELDS = (('diabetic_idx', 2),) + _OBS_FIELDS
    _FIELDS_BY_TYPE = {
        'obs': _OBS_FIELDS,
        'proj_obs': _PROJ_OBS_FIELDS,
        'full': _FULL_FIELDS,
    }

    def __init__(self, state_idx = None, idx_type = 'obs', diabetic_idx = None, state_categs = None):
        """Initialize the state from an index or from explicit categories.

        :param state_idx: state index, interpreted according to idx_type.
            Takes precedence over state_categs when both are given.
        :param idx_type: 'obs', 'proj_obs' or 'full'.
        :param diabetic_idx: 0 or 1; required unless a 'full' index is given
            (in which case it is decoded from the index and this is ignored).
        :param state_categs: sequence of the seven observed categories in the
            order hr, sysbp, percoxyg, glucose, antibiotic, vaso, vent.
        :raises AssertionError: on an inconsistent specification.
        """
        assert state_idx is not None or state_categs is not None
        assert ((diabetic_idx is not None and diabetic_idx in [0, 1]) or
                (state_idx is not None and idx_type == 'full'))

        assert idx_type in ['obs', 'full', 'proj_obs']

        if state_idx is not None:
            self.set_state_by_idx(
                    state_idx, idx_type=idx_type, diabetic_idx=diabetic_idx)
        elif state_categs is not None:
            assert len(state_categs) == 7, "must specify 7 state variables"
            self.hr_state = state_categs[0]
            self.sysbp_state = state_categs[1]
            self.percoxyg_state = state_categs[2]
            self.glucose_state = state_categs[3]
            self.antibiotic_state = state_categs[4]
            self.vaso_state = state_categs[5]
            self.vent_state = state_categs[6]
            self.diabetic_idx = diabetic_idx

    def check_absorbing_state(self):
        """Return True iff the state is absorbing.

        A state is absorbing when at least three variables are abnormal, or
        when no variable is abnormal and no treatment is active.
        """
        num_abnormal = self.get_num_abnormal()
        if num_abnormal >= 3:
            return True
        elif num_abnormal == 0 and not self.on_treatment():
            return True
        return False

    def state_rewards(self):
        """Return the reward attached to this state.

            >= 3 abnormal variables             -> -1000
            2 abnormal variables                -> -50
            1 abnormal variable                 -> +50
            0 abnormal, some treatment active   -> +70
            0 abnormal, no treatment active     -> +1000
        """
        num_abnormal = self.get_num_abnormal()
        if num_abnormal >= 3:
            return (-1000)
        elif num_abnormal == 2:
            return (-50)
        elif num_abnormal == 1:
            return (+50)
        elif num_abnormal == 0 and self.on_treatment():
            return (+70)
        elif num_abnormal == 0 and not self.on_treatment():
            return (+1000)

    def set_state_by_idx(self, state_idx, idx_type, diabetic_idx=None):
        """Decode a state index into the individual state variables.

        The index is a mixed-radix number ("bit" arithmetic, except that not
        every variable is binary), most significant variable first.

        :param state_idx: integer-valued index within the chosen index space.
        :param idx_type: 'obs' (720), 'proj_obs' (144) or 'full' (1440).
        :param diabetic_idx: required unless idx_type is 'full'; for 'full'
            the diabetes flag is decoded from the index instead.
        :raises AssertionError: on an unknown idx_type, an out-of-range
            index, or a missing diabetic_idx.
        """
        assert idx_type in State._FIELDS_BY_TYPE, \
            "unknown idx_type: {}".format(idx_type)
        fields = State._FIELDS_BY_TYPE[idx_type]

        # Size of the index space = product of all category counts.
        base = 1
        for _, num_categs in fields:
            base *= num_categs

        # Plain-int arithmetic keeps the decoded categories built-in ints.
        remainder = int(state_idx)
        assert 0 <= remainder < base, \
            "state index {} out of range for idx_type '{}' ({} states)".format(
                state_idx, idx_type, base)

        if idx_type != 'full':
            assert diabetic_idx is not None
            self.diabetic_idx = diabetic_idx

        if idx_type == 'proj_obs':
            # Glucose is marginalized out of the projected index; the state
            # is given the normal glucose category.
            self.glucose_state = 2

        for attr_name, num_categs in fields:
            base //= num_categs
            digit, remainder = divmod(remainder, base)
            setattr(self, attr_name, digit)

    def get_state_idx(self, idx_type='obs'):
        """Return the integer index of this state (inverse of set_state_by_idx).

        :param idx_type: 'obs', 'proj_obs' or 'full'.
        :raises AssertionError: on an unknown idx_type.
        """
        assert idx_type in State._FIELDS_BY_TYPE, \
            "unknown idx_type: {}".format(idx_type)

        # Horner evaluation of the mixed-radix number.
        sum_idx = 0
        for attr_name, num_categs in State._FIELDS_BY_TYPE[idx_type]:
            sum_idx = sum_idx * num_categs + getattr(self, attr_name)
        return sum_idx

    def __eq__(self, other):
        """Two states are equal iff all seven observed variables match.

        diabetic_idx is not compared.
        """
        return isinstance(other, self.__class__) and \
            self.hr_state == other.hr_state and \
            self.sysbp_state == other.sysbp_state and \
            self.percoxyg_state == other.percoxyg_state and \
            self.glucose_state == other.glucose_state and \
            self.antibiotic_state == other.antibiotic_state and \
            self.vaso_state == other.vaso_state and \
            self.vent_state == other.vent_state

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # __hash__ must return a built-in int; the categories may be NumPy
        # integers (e.g. when assigned from sampled values), which would
        # otherwise raise TypeError. Uses the 'obs' index, matching __eq__.
        return int(self.get_state_idx())

    def get_num_abnormal(self):
        """Return how many of hr, sysbp, percoxyg and glucose are abnormal."""
        num_abnormal = 0
        if self.hr_state != 1:
            num_abnormal += 1
        if self.sysbp_state != 1:
            num_abnormal += 1
        if self.percoxyg_state != 1:
            num_abnormal += 1
        if self.glucose_state != 2:
            num_abnormal += 1
        return num_abnormal

    def on_treatment(self):
        """Return True iff any of the three treatments is active."""
        if self.antibiotic_state == 0 and \
            self.vaso_state == 0 and self.vent_state == 0:
            return False
        return True

    def on_antibiotics(self):
        """Return True iff antibiotics are active."""
        return self.antibiotic_state == 1

    def on_vasopressors(self):
        """Return True iff vasopressors are active."""
        return self.vaso_state == 1

    def on_ventilation(self):
        """Return True iff ventilation is active."""
        return self.vent_state == 1

    def copy_state(self):
        """Return a new State with the same variables, including diabetic_idx."""
        return State(state_categs = [
            self.hr_state,
            self.sysbp_state,
            self.percoxyg_state,
            self.glucose_state,
            self.antibiotic_state,
            self.vaso_state,
            self.vent_state],
            diabetic_idx=self.diabetic_idx)

    def get_state_vector(self):
        """Return the seven observed variables as an integer array.

        Order: hr, sysbp, percoxyg, glucose, antibiotic, vaso, vent.
        """
        return np.array([self.hr_state,
            self.sysbp_state,
            self.percoxyg_state,
            self.glucose_state,
            self.antibiotic_state,
            self.vaso_state,
            self.vent_state]).astype(int)

In [None]:
class MDP(object):
    """Stochastic patient simulator built on State and Action.

    Holds the current State in self.state, advances it with transition(),
    and can sample actions from an optional stochastic policy given as a
    (states x actions) probability array.

    All randomness comes from the global np.random generator, so results
    are reproducible only if the caller seeds it.
    """

    def __init__(self, init_state_idx=None, init_state_idx_type='obs', policy_array=None, policy_idx_type='obs', p_diabetes=0.2):
        """Create the simulator and set its initial state.

        :param init_state_idx: index of the initial state; if None a random
            non-absorbing state is drawn (see get_new_state).
        :param init_state_idx_type: index space of init_state_idx
            ('obs', 'full' or 'proj_obs'); only relevant when an index is given.
        :param policy_array: optional array with one row per state and
            Action.NUM_ACTIONS_TOTAL columns; each row is used as the
            probability vector for sampling an action.
        :param policy_idx_type: index space of the rows of policy_array.
        :param p_diabetes: probability of diabetes, used whenever the
            diabetes flag has to be sampled.
        :raises AssertionError: on an invalid p_diabetes, policy_idx_type or
            policy_array shape.
        """

        assert p_diabetes >= 0 and p_diabetes <= 1, \
                "Invalid p_diabetes: {}".format(p_diabetes)
        assert policy_idx_type in ['obs', 'full', 'proj_obs']

        # Check the policy dimensions (states x actions)
        if policy_array is not None:
            assert policy_array.shape[1] == Action.NUM_ACTIONS_TOTAL
            if policy_idx_type == 'obs':
                assert policy_array.shape[0] == State.NUM_OBS_STATES
            elif policy_idx_type == 'full':
                assert policy_array.shape[0] == \
                        State.NUM_HID_STATES * State.NUM_OBS_STATES
            elif policy_idx_type == 'proj_obs':
                assert policy_array.shape[0] == State.NUM_PROJ_OBS_STATES

        # p_diabetes must be set before get_new_state is called, because
        # random state generation reads it.
        self.p_diabetes = p_diabetes
        self.state = None

        # init_state_idx_type only matters when init_state_idx is provided.
        self.state = self.get_new_state(init_state_idx, init_state_idx_type)

        self.policy_array = policy_array
        self.policy_idx_type = policy_idx_type  # Index space used to look up policy rows
        

    def get_new_state(self, state_idx = None, idx_type = 'obs', diabetic_idx = None):
        """Return a new State, either as specified or drawn at random.

        Supported combinations:
          * state_idx with idx_type 'obs': diabetic_idx is used if given,
            otherwise sampled with probability self.p_diabetes.
          * state_idx with idx_type 'full': diabetic_idx is ignored.
          * state_idx with idx_type 'proj_obs': diabetic_idx is required.
          * no state_idx: a random state is generated (conditioned on
            diabetic_idx if given) and redrawn until it is not absorbing.

        :raises AssertionError: on an unknown idx_type or an unsupported
            combination (e.g. 'proj_obs' index without diabetic_idx).
        """

        assert idx_type in ['obs', 'full', 'proj_obs']
        option = None
        if state_idx is not None:
            if idx_type == 'obs' and diabetic_idx is not None:
                option = 'spec_obs'
            elif idx_type == 'obs' and diabetic_idx is None:
                option = 'spec_obs_no_diab'
                diabetic_idx = np.random.binomial(1, self.p_diabetes)
            elif idx_type == 'full':
                option = 'spec_full'
            elif idx_type == 'proj_obs' and diabetic_idx is not None:
                option = 'spec_proj_obs'
        elif state_idx is None and diabetic_idx is None:
            option = 'random'
        elif state_idx is None and diabetic_idx is not None:
            option = 'random_cond_diab'

        assert option is not None, "Invalid specification of new state"

        if option in ['random', 'random_cond_diab'] :
            init_state = self.generate_random_state(diabetic_idx)
            # Rejection sampling: never start in an absorbing state.
            while init_state.check_absorbing_state():
                init_state = self.generate_random_state(diabetic_idx)
        else:
            # Note that diabetic_idx will be ignored if idx_type = 'full'
            init_state = State(
                    state_idx=state_idx, idx_type=idx_type,
                    diabetic_idx=diabetic_idx)

        return init_state

    def generate_random_state(self, diabetic_idx=None):
        """Draw a random State with all treatments off.

        If diabetic_idx is None it is sampled with probability
        self.p_diabetes; otherwise the draw is conditioned on it. The
        glucose distribution depends on the diabetes flag. The returned
        state may be absorbing.
        """
        
        # Note that we will condition on diabetic idx if provided
        if diabetic_idx is None:
            diabetic_idx = np.random.binomial(1, self.p_diabetes)

        # hr and sys_bp w.p. [.25, .5, .25]
        hr_state = np.random.choice(np.arange(3), p=np.array([.25, .5, .25]))
        sysbp_state = np.random.choice(np.arange(3), p=np.array([.25, .5, .25]))
        # percoxyg w.p. [.2, .8]
        percoxyg_state = np.random.choice(np.arange(2), p=np.array([.2, .8]))

        # Non-diabetics are centred on glucose category 2, diabetics on 3.
        if diabetic_idx == 0:
            glucose_state = np.random.choice(np.arange(5), \
                p=np.array([.05, .15, .6, .15, .05]))
        else:
            glucose_state = np.random.choice(np.arange(5), \
                p=np.array([.01, .05, .15, .6, .19]))
        antibiotic_state = 0
        vaso_state = 0
        vent_state = 0

        state_categs = [hr_state, sysbp_state, percoxyg_state,
                glucose_state, antibiotic_state, vaso_state, vent_state]

        return State(state_categs=state_categs, diabetic_idx=diabetic_idx)

    # Antibiotics on/off transitions.
    
    def transition_antibiotics_on(self):
        
        '''
        Switch antibiotics on (mutates self.state).

        heart rate and sys bp each move hi (2) -> normal (1) w.p. .5,
        independently; a random number is drawn only for a variable that
        is currently hi.
        '''
        self.state.antibiotic_state = 1
        if self.state.hr_state == 2 and np.random.uniform(0,1) < 0.5:
            self.state.hr_state = 1
        if self.state.sysbp_state == 2 and np.random.uniform(0,1) < 0.5:
            self.state.sysbp_state = 1

    def transition_antibiotics_off(self):
        '''
        Switch antibiotics off (mutates self.state).

        Only has an effect if antibiotics were on: heart rate and sys bp
        each move normal (1) -> hi (2) w.p. .1, independently.
        '''
        if self.state.antibiotic_state == 1:
            if self.state.hr_state == 1 and np.random.uniform(0,1) < 0.1:
                self.state.hr_state = 2
            if self.state.sysbp_state == 1 and np.random.uniform(0,1) < 0.1:
                self.state.sysbp_state = 2
            self.state.antibiotic_state = 0

    # Ventilation on/off transitions.

    def transition_vent_on(self):
        '''
        Switch ventilation on (mutates self.state).

        percent oxygen: low (0) -> normal (1) w.p. .7
        '''
        self.state.vent_state = 1
        if self.state.percoxyg_state == 0 and np.random.uniform(0,1) < 0.7:
            self.state.percoxyg_state = 1

    def transition_vent_off(self):
        '''
        Switch ventilation off (mutates self.state).

        Only has an effect if ventilation was on:
        percent oxygen: normal (1) -> low (0) w.p. .1
        '''
        if self.state.vent_state == 1:
            if self.state.percoxyg_state == 1 and np.random.uniform(0,1) < 0.1:
                self.state.percoxyg_state = 0
            self.state.vent_state = 0
    
    # Vasopressor on/off transitions; the effect depends on the diabetes flag.

    def transition_vaso_on(self):
        '''
        Switch vasopressors on (mutates self.state).

        for non-diabetic:
            sys bp: low -> normal, normal -> hi w.p. .7
        for diabetic:
            raise blood pressure: normal -> hi w.p. .9,
                lo -> normal w.p. .5, lo -> hi w.p. .4
            raise blood glucose by 1 (capped at 4) w.p. .5
        '''
        self.state.vaso_state = 1
        if self.state.diabetic_idx == 0:
            if np.random.uniform(0,1) < 0.7:
                if self.state.sysbp_state == 0:
                    self.state.sysbp_state = 1
                elif self.state.sysbp_state == 1:
                    self.state.sysbp_state = 2
        else:
            if self.state.sysbp_state == 1:
                if np.random.uniform(0,1) < 0.9:
                    self.state.sysbp_state = 2
            elif self.state.sysbp_state == 0:
                # One draw decides between lo -> normal ([0, .5)),
                # lo -> hi ([.5, .9)) and no change ([.9, 1)).
                up_prob = np.random.uniform(0,1)
                if up_prob < 0.5:
                    self.state.sysbp_state = 1
                elif up_prob < 0.9:
                    self.state.sysbp_state = 2
            if np.random.uniform(0,1) < 0.5:
                self.state.glucose_state = min(4, self.state.glucose_state + 1)

    def transition_vaso_off(self):
        '''
        Switch vasopressors off (mutates self.state).

        Only has an effect if vasopressors were on:
            for non-diabetics, sys bp falls by 1 (floored at 0) w.p. .1
            for diabetics, sys bp falls by 1 (floored at 0) w.p. .05
        '''
        if self.state.vaso_state == 1:
            if self.state.diabetic_idx == 0:
                if np.random.uniform(0,1) < 0.1:
                    self.state.sysbp_state = max(0, self.state.sysbp_state - 1)
            else:
                if np.random.uniform(0,1) < 0.05:
                    self.state.sysbp_state = max(0, self.state.sysbp_state - 1)
            self.state.vaso_state = 0

    def transition_fluctuate(self, hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, glucose_fluctuate):
        
        # Random fluctuation of the non-treatment variables (mutates self.state).
        
        '''
        Each variable whose flag is true draws one uniform number and moves
        -1 w.p. .1 or +1 w.p. .1, clamped to its valid range.
        Exception: glucose moves -1 w.p. .3 or +1 w.p. .3 if diabetic.

        :param hr_fluctuate: whether heart rate may fluctuate
        :param sysbp_fluctuate: whether sys bp may fluctuate
        :param percoxyg_fluctuate: whether percent oxygen may fluctuate
        :param glucose_fluctuate: whether glucose may fluctuate
        '''
        if hr_fluctuate:
            hr_prob = np.random.uniform(0,1)
            if hr_prob < 0.1:
                self.state.hr_state = max(0, self.state.hr_state - 1)
            elif hr_prob < 0.2:
                self.state.hr_state = min(2, self.state.hr_state + 1)
        if sysbp_fluctuate:
            sysbp_prob = np.random.uniform(0,1)
            if sysbp_prob < 0.1:
                self.state.sysbp_state = max(0, self.state.sysbp_state - 1)
            elif sysbp_prob < 0.2:
                self.state.sysbp_state = min(2, self.state.sysbp_state + 1)
        if percoxyg_fluctuate:
            percoxyg_prob = np.random.uniform(0,1)
            if percoxyg_prob < 0.1:
                self.state.percoxyg_state = max(0, self.state.percoxyg_state - 1)
            elif percoxyg_prob < 0.2:
                self.state.percoxyg_state = min(1, self.state.percoxyg_state + 1)
        if glucose_fluctuate:
            glucose_prob = np.random.uniform(0,1)
            if self.state.diabetic_idx == 0:
                if glucose_prob < 0.1:
                    self.state.glucose_state = max(0, self.state.glucose_state - 1)
                elif glucose_prob < 0.2:
                    self.state.glucose_state = min(4, self.state.glucose_state + 1)
            else:
                if glucose_prob < 0.3:
                    self.state.glucose_state = max(0, self.state.glucose_state - 1)
                elif glucose_prob < 0.6:
                    self.state.glucose_state = min(4, self.state.glucose_state + 1)

    def calculateReward(self):
        """Return the reward of the current state.

        -1 if at least three variables are abnormal, +1 if none is abnormal
        and no treatment is active, 0 otherwise.
        """
        num_abnormal = self.state.get_num_abnormal()
        if num_abnormal >= 3:
            return -1
        elif num_abnormal == 0 and not self.state.on_treatment():
            return 1
        return 0

    def transition(self, action):
        """Apply an action to the current state and return the reward.

        The state is copied first, so State objects handed out earlier are
        not mutated. Treatment effects are applied in the order antibiotics,
        ventilation, vasopressors, followed by random fluctuation. A
        variable only fluctuates if no treatment affecting it was switched
        on or off in this step; glucose fluctuation is additionally disabled
        whenever vasopressors are given.

        :param action: object with 0/1 attributes antibiotic, ventilation
            and vasopressors (e.g. an Action).
        :return: calculateReward() of the resulting state.
        """
        self.state = self.state.copy_state()

        if action.antibiotic == 1:
            self.transition_antibiotics_on()
            hr_fluctuate = False
            sysbp_fluctuate = False
        elif self.state.antibiotic_state == 1:
            # Antibiotics were on and are now withdrawn.
            self.transition_antibiotics_off()
            hr_fluctuate = False
            sysbp_fluctuate = False
        else:
            hr_fluctuate = True
            sysbp_fluctuate = True

        if action.ventilation == 1:
            self.transition_vent_on()
            percoxyg_fluctuate = False
        elif self.state.vent_state == 1:
            # Ventilation was on and is now withdrawn.
            self.transition_vent_off()
            percoxyg_fluctuate = False
        else:
            percoxyg_fluctuate = True

        glucose_fluctuate = True

        if action.vasopressors == 1:
            self.transition_vaso_on()
            sysbp_fluctuate = False
            glucose_fluctuate = False
        elif self.state.vaso_state == 1:
            # Vasopressors were on and are now withdrawn.
            self.transition_vaso_off()
            sysbp_fluctuate = False

        self.transition_fluctuate(hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, \
            glucose_fluctuate)

        return self.calculateReward()

    def select_actions(self):
        """Sample an Action for the current state from the policy.

        The policy row is looked up with the current state's index in the
        self.policy_idx_type index space and used as probability vector.

        :raises AssertionError: if no policy_array was provided.
        """
        assert self.policy_array is not None
        probs = self.policy_array[
                    self.state.get_state_idx(self.policy_idx_type)
                ]
        aev_idx = np.random.choice(np.arange(Action.NUM_ACTIONS_TOTAL), p=probs)
        return Action(action_idx = aev_idx)

    def action_idx(self, state_idx):
        """Sample an action index for an explicitly given state index.

        Unlike select_actions this does not look at self.state; state_idx
        is used directly as row index into self.policy_array, so it must be
        expressed in the same index space as the policy rows.

        :return: the sampled action index (not an Action object).
        :raises AssertionError: if no policy_array was provided.
        """
        assert self.policy_array is not None
        probs = self.policy_array[state_idx]
        aev_idx = np.random.choice(np.arange(Action.NUM_ACTIONS_TOTAL), p=probs)
        return aev_idx
        

In [None]:
# --- Experiment configuration -------------------------------------------
NSIMSAMPS = 1  # Number of samples to draw from the simulator
NSTEPS = 10  # Max length of each trajectory
NCFSAMPS = 5  # Counterfactual samples per observed sample (TODO: possibly unused)
DISCOUNT_Pol = 0.99 # Used for computing optimal policies
DISCOUNT = 1 # Used for computing actual reward
PHYS_EPSILON = 0.05 # Used for sampling using physician pol as eps greedy
PROB_DIAB = 0.2
n_actions = Action.NUM_ACTIONS_TOTAL

# --- Load the per-component transition / reward matrices -----------------
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file. Only load this file if it comes from a trusted source.
with open("../data/diab_txr_mats-replication.pkl", "rb") as f:
    mdict = pickle.load(f)

tx_mat = mdict["tx_mat"]
r_mat = mdict["r_mat"]
p_mixture = np.array([1 - PROB_DIAB, PROB_DIAB])

# The matrices are indexed as [diabetes component, action, state, next state].
# Check the layout up front so a wrong file fails with a clear message
# instead of a broadcasting error inside the loop below.
expected_shape = (State.NUM_HID_STATES, n_actions,
                  State.NUM_OBS_STATES, State.NUM_OBS_STATES)
assert np.shape(tx_mat) == expected_shape, \
    "tx_mat has shape {}, expected {}".format(np.shape(tx_mat), expected_shape)
assert np.shape(r_mat) == expected_shape, \
    "r_mat has shape {}, expected {}".format(np.shape(r_mat), expected_shape)

# --- Assemble the matrices over the full (diabetes x observed) state space
# tx_mat_full / r_mat_full have shape (actions, state, next state). The
# block-diagonal layout puts the non-diabetic component in the first 720
# states and the diabetic component in the last 720, with no transitions
# between the two blocks.
tx_mat_full = np.zeros((n_actions, State.NUM_FULL_STATES, State.NUM_FULL_STATES))
r_mat_full = np.zeros((n_actions, State.NUM_FULL_STATES, State.NUM_FULL_STATES))
for a in range(n_actions):
    tx_mat_full[a, ...] = block_diag(tx_mat[0, a, ...], tx_mat[1, a,...])
    r_mat_full[a, ...] = block_diag(r_mat[0, a, ...], r_mat[1, a, ...])

# Report shapes rather than dumping two 8 x 1440 x 1440 arrays.
print(f"tx_mat_full shape: {tx_mat_full.shape}")
print(f"r_mat_full shape: {r_mat_full.shape}")

In [None]:
# --- Classify every full state and collect its reward --------------------
all_absorbing_states = []    # full-state indices of absorbing states
all_absorbing_rewards = []   # reward of each absorbing state (same order)
non_absorbing_states = []    # full-state indices of all other states
all_rewards = []             # reward of every full state, indexed by state

for s in range(State.NUM_FULL_STATES):
    full_state = State(state_idx=s, idx_type='full')
    state_reward = full_state.state_rewards()
    all_rewards.append(state_reward)
    # Named is_absorbing (not `abs`) so the builtin abs() is not shadowed
    # for the rest of the kernel session.
    is_absorbing = full_state.check_absorbing_state()
    if is_absorbing:
        all_absorbing_states.append(s)
        all_absorbing_rewards.append(state_reward)
    else:
        non_absorbing_states.append(s)

# The "winning" states are the absorbing states with the +1000 reward
# (nothing abnormal and no treatment active).
winning_states = [
    s for s, r in zip(all_absorbing_states, all_absorbing_rewards) if r == 1000
]
print(f"winning states are {' and '.join(map(str, winning_states))}")

# --- Make absorbing states self-looping and rewrite the rewards ----------
# tx_mat_full / r_mat_full have shape (actions, state, next state); the
# updates below are applied to all actions at once.
state_reward_vec = np.asarray(all_rewards, dtype=r_mat_full.dtype)

for s, absorbing_reward in zip(all_absorbing_states, all_absorbing_rewards):
    # An absorbing state transitions to itself with probability 1 ...
    tx_mat_full[:, s, :] = 0
    tx_mat_full[:, s, s] = 1
    # ... and yields its own reward regardless of the (formal) next state.
    r_mat_full[:, s, :] = absorbing_reward

for s in non_absorbing_states:
    # For every reachable next state the reward is the reward of that next
    # state; entries for unreachable next states are left untouched.
    reachable = tx_mat_full[:, s, :] != 0
    r_mat_full[:, s, :] = np.where(reachable, state_reward_vec, r_mat_full[:, s, :])

# --- Reward of the most likely next state, per (state, action) -----------
# argmax returns the first maximum, i.e. ties are broken towards the lowest
# state index.
most_likely_next = tx_mat_full.argmax(axis=2)  # shape (actions, states)
rewards_pi = np.take_along_axis(
    r_mat_full, most_likely_next[..., np.newaxis], axis=2
)[..., 0].T.copy()  # shape (states, actions)

In [None]:
# Wrap the full transition / reward matrices (shape: actions x state x next
# state) in the project's MatrixMDP and compute a policy by policy iteration
# with discount DISCOUNT_Pol.
# NOTE(review): the meaning of eval_type=1 is defined in cf.counterfactual
# and is not visible here; confirm against that module.
# fullPol is later passed to MDP(policy_array=..., policy_idx_type='full'),
# which asserts the shape (State.NUM_FULL_STATES, Action.NUM_ACTIONS_TOTAL)
# and uses each row as action probabilities.
fullMDP = cf.MatrixMDP(tx_mat_full, r_mat_full)
fullPol = fullMDP.policyIteration(discount=DISCOUNT_Pol, eval_type=1)

In [None]:
class DataGen(object):
    """Samples observed trajectories from the full transition / reward tensors.

    Actions are chosen by an ``MDP`` instance built from ``fullPol``; next
    states are drawn from ``tx_mat_full`` and rewards read from ``r_mat_full``.
    """

    def __init__(self):
        self.mdp = MDP(init_state_idx=None, policy_array=fullPol, policy_idx_type='full', p_diabetes=PROB_DIAB)

    def mdp_sample(self, policy=fullPol, n_obs=1, n_steps=NSTEPS): # trajectory sample for a given policy and MDP
        """Draw ``n_obs`` trajectories of ``n_steps`` transitions each.

        Args:
            policy: accepted for interface compatibility; this method does not
                read it (actions come from ``self.mdp.action_idx``).
            n_obs: number of trajectories to generate.
            n_steps: number of transitions per trajectory.

        Returns:
            Array of shape ``(n_obs, n_steps, 4)`` whose last axis holds
            ``[current state, next state, action, reward]``.
        """
        record_width = 4  # [current state, next state, action, reward]
        trajectories = np.zeros((n_obs, n_steps, record_width))

        for obs_idx in range(n_obs):
            # Trajectories always start from a non-absorbing state.
            state = np.random.choice(non_absorbing_states)

            for time_idx in range(n_steps):
                action = self.mdp.action_idx(state_idx=state)

                # tx_mat_full is indexed (action, state, next state)
                successor = np.random.choice(
                    1440, size=1, p=tx_mat_full[action, state, :])[0]

                # The reward depends on the state that was actually reached.
                reward = r_mat_full[action, state, successor]

                trajectories[obs_idx, time_idx, :] = np.array([state, successor, action, reward])
                state = successor

        return trajectories

In [None]:
# Build the trajectory generator and draw one random trajectory.
# NOTE: the random draw below is overwritten by the hard-coded path further
# down; it still advances NumPy's global random state, and `dgen` is reused
# later when the CounterfactualSampler is constructed.
dgen = DataGen()
MDP_samp = dgen.mdp_sample().astype(int)

# Suboptimal Path (alternative hard-coded trajectory, currently disabled)
# MDP_samp = np.array([[[777., 939.,   3., -50.],
#   [939., 869.,   6., -50.],
#   [869., 869.,   6., -50.],
#   [869., 861.,   6.,  50.],
#   [861., 861.,   6.,  50.],
#   [861., 869.,   6., -50.],
#   [869., 861.,   6.,  50.],
#   [861., 861.,   6.,  50.],
#   [861., 853.,   6., -50.],
#   [853., 853.,   6., -50.]]]).astype(int)


# Catastrophic Path
# Each row is [current state, next state, action, reward]; shape (1, 10, 4).
MDP_samp = np.array([[[  777,   939,     3,   -50],
  [  939,   941,     6,   -50],
  [  941,   949,     6, -1000],
  [  949,   949,     0, -1000],
  [  949,   949,     0, -1000],
  [  949,   949,     0, -1000],
  [  949,   949,     0, -1000],
  [  949,   949,     0, -1000],
  [  949,   949,     0, -1000],
  [  949,   949,     0, -1000]]])

print(MDP_samp)

In [None]:
def truncated_gumbel(logit, truncation):
    """Sample Gumbel(logit) noise truncated from above at ``truncation``.

    Args:
        logit: finite-or-+inf location of the Gumbel; must not be ``-inf``.
        truncation: 1-D array of upper bounds; one sample is drawn per entry.

    Returns:
        Array with the same length as ``truncation`` where every entry is
        ``-log(exp(-g) + exp(-truncation))`` for a fresh ``g ~ Gumbel(logit)``,
        hence never larger than the corresponding bound.

    Raises:
        AssertionError: if ``logit`` is ``-inf``.
    """
    assert not np.isneginf(logit)

    gumbel = np.random.gumbel(size=(truncation.shape[0])) + logit
    # logaddexp evaluates log(exp(a) + exp(b)) without forming the
    # exponentials, so a very negative `gumbel` no longer overflows exp() and
    # collapses the result to -inf.
    trunc_g = -np.logaddexp(-gumbel, -truncation)
    return trunc_g

def topdown_tracking_influenced_states(obs_logits, obs_state, nsamp=1):
    """Sample Gumbel noise consistent with an observed outcome (top-down scheme).

    Args:
        obs_logits: 1-D array of log-probabilities of each possible next state
            under the observed transition (``-inf`` marks impossible states).
        obs_state: index of the next state that was actually observed.
        nsamp: number of independent noise vectors to draw.

    Returns:
        ``(gumbels, influenced_states)`` where ``gumbels`` has shape
        ``(nsamp, len(obs_logits))`` and ``influenced_states`` is a boolean
        vector that is True for every state whose noise was constrained by the
        observation (i.e. every state with a logit other than ``-inf``).
    """
    n_outcomes = obs_logits.shape[0]
    noise = np.zeros((nsamp, n_outcomes))
    influenced_states = np.full(shape=n_outcomes, fill_value=False)

    # The maximum of the shifted Gumbels, one per sample.
    topgumbel = np.random.gumbel(size=(nsamp))

    for idx in range(n_outcomes):
        is_observed = (idx == obs_state)
        logit = obs_logits[idx]

        if is_observed and not np.isneginf(logit):
            # The observed outcome attains the maximum.
            noise[:, obs_state] = topgumbel - logit
            influenced_states[obs_state] = True
        elif not np.isneginf(logit):
            # Feasible but unobserved: its shifted Gumbel is capped by the maximum.
            capped = truncated_gumbel(logit, topgumbel)
            noise[:, idx] = capped - logit
            influenced_states[idx] = True
        else:
            # Zero-probability outcome: the observation says nothing about it.
            noise[:, idx] = np.random.gumbel(size=nsamp)

    return noise, influenced_states

In [None]:
from multiprocessing import Process, Manager

class CounterfactualSampler(object):
    """Estimates counterfactual transition probabilities with Gumbel-max sampling.

    For every (action, time step) pair the sampler conditions on the observed
    transition at that time step and estimates, by Monte Carlo, the
    counterfactual next-state distribution of every state under that action.
    It also records which transitions had their noise constrained by the
    observation.
    """

    def __init__(self, mdp):
        self.mdp = mdp
        self.sprtb_theta = 0.9
        self.sprtb_delta = 0.05
        self.sprtb_r = 0.9

    @staticmethod
    def _empirical_distribution(samples, n_outcomes, n_draws):
        """Return count(samples == i) / n_draws for i in range(n_outcomes)."""
        return np.bincount(samples, minlength=n_outcomes) / n_draws

    def cf_posterior_tracking_influenced_states(self, obs_prob, intrv_prob, state, n_mc):
        """Monte Carlo estimate of the counterfactual next-state distribution.

        Args:
            obs_prob: next-state probabilities of the observed (state, action).
            intrv_prob: next-state probabilities of the counterfactual (state, action).
            state: next state that was actually observed.
            n_mc: number of Monte Carlo draws.

        Returns:
            ``(posterior_prob, intrv_posterior, influenced_states)``: the
            empirical distribution over next states, the raw sampled next
            states (length ``n_mc``), and the boolean influence mask returned
            by ``topdown_tracking_influenced_states``.
        """
        # Zero probabilities become -inf logits (divide warnings are silenced
        # at the top of the notebook).
        obs_logits = np.log(obs_prob)
        intrv_logits = np.log(intrv_prob)

        # Sample from the gumbel posterior
        gumbels, influenced_states = topdown_tracking_influenced_states(obs_logits, state, n_mc)

        # Gumbel-max: the counterfactual outcome is the argmax of logit + noise.
        posterior = intrv_logits + gumbels
        intrv_posterior = np.argmax(posterior, axis=1)

        # Counterfactual transition probabilities as relative frequencies.
        posterior_prob = self._empirical_distribution(intrv_posterior, np.size(intrv_prob, 0), n_mc)

        return posterior_prob, intrv_posterior, influenced_states

##########################################################################################################################################################

    def cf_sample_prob_tracking_influenced_transitions(self, trajectories, a, time_idx, P_cf_save, influenced_transitions_save, n_cf_samps=1):
        """Fill the counterfactual transition matrix for one (action, time step).

        The result is written into ``P_cf_save[(a, time_idx)]`` and
        ``influenced_transitions_save[(a, time_idx)]`` as dicts keyed by
        ``(a, time_idx)``. When ``trajectories`` holds several observations
        only the last one's matrices are kept, because they are rebuilt for
        every observation and stored once after the loop.
        """
        n_obs = trajectories.shape[0]
        n_mc = 1000
        n_states = tx_mat_full.shape[1]

        # Defined up front so an empty `trajectories` stores empty dicts
        # instead of failing on an unbound name.
        P_cf = {}
        influenced_transitions = {}

        for obs_idx in range(n_obs): # for each given "observed" trajectory
            P_cf = {}
            influenced_transitions = {}

            obs_state = trajectories[obs_idx, time_idx, :]
            obs_current_state = int(obs_state[0]) # same as s_real
            obs_next_state = int(obs_state[1]) # same as s_p_real
            obs_action = int(obs_state[2]) # same as a_real

            # The observed transition row does not depend on the counterfactual state.
            obs_intrv = tx_mat_full[obs_action, obs_current_state, :]

            for _ in range(n_cf_samps): # get the desired number of CF trajectories for each given "observed" trajectory
                P_cf[a, time_idx] = np.zeros((n_states, n_states))
                influenced_transitions[a, time_idx] = np.full(shape=(n_states, n_states), fill_value=False)

                for s in range(n_states):
                    # Transition probabilities for the counterfactual state and action
                    cf_intrv = tx_mat_full[a, s, :]

                    cf_prob, _samples, influenced_states = self.cf_posterior_tracking_influenced_states(obs_intrv, cf_intrv, obs_next_state, n_mc)

                    P_cf[a, time_idx][s, :] = cf_prob
                    influenced_transitions[a, time_idx][s, :] = influenced_states

        P_cf_save[(a,time_idx)] = P_cf
        influenced_transitions_save[(a, time_idx)] = influenced_transitions

    def run_sample_tracking_influenced_transitions(self, inp, trajectories, P_cf, influenced_transitions):
        """Worker entry point: process the (action, time step) pairs in ``inp``.

        Results are merged into the shared ``P_cf`` and
        ``influenced_transitions`` mappings once the whole chunk is done.
        """
        P_cf_save = {}
        influenced_transitions_save = {}

        for i in inp:
            self.cf_sample_prob_tracking_influenced_transitions(trajectories, i[0], i[1], P_cf_save, influenced_transitions_save)

        for i in inp:
            P_cf.update(P_cf_save[i])
            influenced_transitions.update(influenced_transitions_save[i])

    def run_parallel_sampling_tracking_influenced_transitions(self, trajectories, n_workers=32):
        """Compute all (action, time step) matrices using worker processes.

        Args:
            trajectories: array of observed trajectories; axis 1 is time.
            n_workers: maximum number of worker processes (default 32).

        Returns:
            ``(P_cf, influenced_transitions)`` as plain dicts keyed by
            ``(action, time step)``.

        Raises:
            RuntimeError: if any worker process exits with a non-zero code.
        """
        n_steps = trajectories.shape[1]
        n_actions = 8

        inp = [(a, time_idx) for time_idx in range(n_steps) for a in range(n_actions)]

        def split(a, n):
            # Split `a` into n contiguous chunks whose sizes differ by at most one.
            k, m = divmod(len(a), n)
            return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))

        split_work = split(inp, n_workers)
        processes = []

        with Manager() as manager:
            P_cf = manager.dict()
            influenced_transitions = manager.dict()

            for chunk in split_work:
                if not chunk:
                    # More workers than work items: nothing to do for this slot.
                    continue
                process = Process(target=self.run_sample_tracking_influenced_transitions, args=(chunk, trajectories, P_cf, influenced_transitions))
                processes.append(process)
                process.start()

            for process in processes:
                process.join()

            # A crashed worker would otherwise only show up later as missing keys.
            failed = [process.exitcode for process in processes if process.exitcode != 0]
            if failed:
                raise RuntimeError(f"{len(failed)} sampling worker(s) failed with exit codes {failed}")

            return P_cf.copy(), influenced_transitions.copy()

In [None]:
# NOTE(review): NUM_ITERATIONS is not read in this cell and is reassigned to
# 1000 in the final plotting cell - confirm whether this value is still needed.
NUM_ITERATIONS = 10000
# deque / defaultdict are used by InfluenceMDPPruner.get_influence_graph below.
from collections import deque, defaultdict

print(MDP_samp)

# Estimate the counterfactual transition matrices (and the influence masks)
# for every (action, time step) of the observed trajectory.
sampler = CounterfactualSampler(dgen)
P_cf, influenced_transitions = sampler.run_parallel_sampling_tracking_influenced_transitions(MDP_samp)

In [None]:
class InfluenceMDPPruner:
    """Prunes the counterfactual MDP by look-ahead influence and plans on it.

    The pruner works on time-unrolled graphs whose nodes are ``(t, state)``
    tuples and whose edges are keyed by action. For each look-ahead ``k`` it
    keeps the part of the MDP within ``k`` steps of a transition influenced by
    the observed trajectory, restricts the counterfactual transition
    probabilities to that part, and computes policies that may deviate from
    the observed actions a bounded number of times.
    """

    def __init__(self, mdp, look_ahead_k=11):
        self.mdp = mdp
        # The remaining inputs are read from notebook-level globals.
        self.optimal_policy = fullPol
        self.rewards_pi = rewards_pi
        self.sampler = sampler
        self.mdp_sample = MDP_samp
        self.initial_state = self.mdp_sample[0, 0, 0]
        self.states = range(1440)
        self.actions = range(8)
        self.look_ahead_k = look_ahead_k
        self.T = len(self.mdp_sample[0])

        # Counterfactual transition probabilities and the record of which
        # transitions were influenced by the observed path.
        self.P_cf = P_cf
        self.influenced_transitions = influenced_transitions

    @staticmethod
    def _positive_successors(row, all_states):
        """Return the members of ``all_states`` (in order) whose entry in ``row`` is > 0."""
        positive = set(np.flatnonzero(np.asarray(row) > 0).tolist())
        return [s_prime for s_prime in all_states if s_prime in positive]

    @staticmethod
    def _prune_until_stable(G, is_dangling):
        """Repeatedly remove every node for which ``is_dangling(node)`` holds."""
        doomed = {n for n in G if is_dangling(n)}
        while len(doomed) > 0:
            G.remove_nodes_from(doomed)
            doomed = {n for n in G if is_dangling(n)}

    @staticmethod
    def _accumulate_expectation(val, row, h_next):
        """Add ``row[s'] * h_next[s']`` to ``val`` for every non-zero ``row[s']``, in index order."""
        for s_p in np.flatnonzero(row):
            val += row[s_p] * h_next[s_p]
        return val

    def build_graph(self, transition_probs, T, all_states, all_actions):
        """Unroll a stationary MDP over ``T`` steps.

        ``transition_probs`` is indexed ``[action, state, next_state]``. An
        edge ``(t, s) -> (t+1, s')`` keyed by the action is added for every
        strictly positive probability.
        """
        G = nx.MultiDiGraph()

        for t in range(T):
            for s in all_states:
                G.add_node((t, s))

                for a in all_actions:
                    for s_prime in self._positive_successors(transition_probs[a, s], all_states):
                        G.add_node(((t+1), s_prime))
                        G.add_edge((t, s), ((t+1), s_prime), key=a, label=f"({a}, {transition_probs[a, s, s_prime]})")

        return G

    def build_cf_graph(self, transition_probs, T, all_states, all_actions, k):
        """Unroll time-dependent counterfactual probabilities into a graph.

        ``transition_probs[(a, t)]`` is a ``[state, next_state]`` matrix. Nodes
        before the last layer without outgoing edges, and nodes after the first
        layer that become unreachable, are removed. ``k`` is accepted for
        interface compatibility and is not read.
        """
        G = nx.MultiDiGraph()

        for t in range(T):
            for s in all_states:
                G.add_node((t, s))

                for a in all_actions:
                    for s_prime in self._positive_successors(transition_probs[a, t][s], all_states):
                        G.add_node(((t+1), s_prime))
                        G.add_edge((t, s), ((t+1), s_prime), key=a, label=f"({a}, {transition_probs[a, t][s, s_prime]})")

                # If node has no outgoing edges, remove it from the graph
                if G.has_node((t, s)) and G.out_degree((t, s)) == 0 and t < T-1:
                    G.remove_node((t, s))

        # Remove unreachable nodes at t>0 with in-degree = 0
        self._prune_until_stable(G, lambda n: G.in_degree(n) == 0 and n[0] > 0)

        return G

    def get_counterfactual_transition_probabilities(self, P_cf, original_G, new_mdp_G, all_states, all_actions, A_real, S_real, T, k):
        """Restrict ``P_cf`` to the pruned graph ``new_mdp_G``.

        An action is dropped from a state entirely as soon as one of its
        positive-probability successors is missing from the pruned graph, so
        that every remaining (state, action) row still sums to 1. Both
        ``P_cf`` and ``new_mdp_G`` are modified in place.

        Args:
            P_cf: dict ``(a, t) -> [state, next_state]`` probability matrix.
            original_G: unused; kept for interface compatibility.
            new_mdp_G: pruned time-unrolled graph.
            all_states, all_actions: state / action collections; states and
                actions are addressed as ``0..len-1``.
            A_real, S_real: observed actions and states per time step.
            T: number of time steps.
            k: look-ahead, used only in the progress message.

        Returns:
            ``(P_cf, valid_action)`` where ``valid_action[t, s, a]`` is True
            when action ``a`` is still available in state ``s`` at time ``t``.

        Raises:
            AssertionError: if a probability is negative/NaN, or if the
                observed (state, action) pair ends up invalid.
        """
        print(f"Calculating counterfactual transition probabilities for k={k}")
        n_states = len(all_states)
        n_actions = len(all_actions)
        valid_action = np.full((T, n_states, n_actions), False)

        def has_missing_successor(row, t, s, a):
            # True when some positive-probability successor has no edge keyed `a`.
            return any(
                not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a)
                for s_prime in np.flatnonzero(row > 0.0).tolist()
            )

        for t in range(T-1, -1, -1):
            for s in range(n_states):
                node = (t, s)

                if new_mdp_G.has_node(node):
                    for a in range(n_actions):
                        if has_missing_successor(P_cf[a, t][s], t, s, a):
                            # Drop every edge of this action out of the node.
                            for imm_descendant in nx.descendants_at_distance(new_mdp_G, node, 1):
                                if new_mdp_G.has_edge(node, imm_descendant, key=a):
                                    new_mdp_G.remove_edge(node, imm_descendant, key=a)

                # If node has no outgoing edges, remove it from the graph
                if new_mdp_G.has_node(node) and new_mdp_G.out_degree(node) == 0 and t < T-1:
                    new_mdp_G.remove_node(node)

            # Remove unreachable nodes at t>0 with in-degree = 0
            self._prune_until_stable(new_mdp_G, lambda n: new_mdp_G.in_degree(n) == 0 and n[0] > 0)

        for t in range(T-1, -1, -1):
            for s in range(n_states):
                for a in range(n_actions):
                    row = P_cf[a, t][s]  # view: edits below write through to P_cf
                    assert np.all(row >= 0.0)

                    if has_missing_successor(row, t, s, a):
                        row[:] = 0.0

                    # Rows are sums of Monte Carlo frequencies, so compare with a
                    # tolerance: an exact `== 1.0` can fail on rounding alone.
                    if np.isclose(row.sum(), 1.0):
                        valid_action[t, s, a] = True

                    if a == A_real[t] and s == S_real[t]:
                        assert(valid_action[t, s, a])

        return P_cf, valid_action

    def find_wcc(self, G):
        """Return a copy of the weakly connected component holding ``(0, initial_state)``.

        Raises:
            ValueError: if that node is not in ``G``.
        """
        s_0 = self.initial_state
        t_0 = 0
        target_node = (t_0, s_0)
        print(f'the initial state is {s_0}')

        # Create a subgraph containing only the weakly connected component of the target node
        for component in nx.weakly_connected_components(G):
            if target_node in component:
                return G.subgraph(component).copy()

        raise ValueError(f"initial node {target_node} is not present in the graph")

    def get_influence_graph(self, G, k, influenced_transitions):
        """Return the part of ``G`` within ``k`` backward steps of an influenced transition.

        ``G`` itself is not modified; the result is an independent copy.
        """
        print(f"Generating influence graph for k={k}")

        def reverse_bfs(G, start_nodes, k):
            # Breadth-first search along predecessors, collecting every node at
            # distance <= k from any start node.
            distance = defaultdict(lambda: float('inf'))
            nodes_to_visit = deque([(node, 0) for node in start_nodes])
            within_k_steps = set()

            while nodes_to_visit:
                curr_node, curr_dist = nodes_to_visit.popleft()

                if curr_dist <= k:
                    within_k_steps.add(curr_node)

                    if distance[curr_node] > curr_dist:
                        distance[curr_node] = curr_dist

                        for predecessor in G.predecessors(curr_node):
                            nodes_to_visit.append((predecessor, curr_dist+1))

            return within_k_steps

        n_states = len(self.states)
        directly_influenced_nodes = set()

        # (t+1, s') is directly influenced when any state has an influenced
        # transition into s' at time t under any action.
        for a in self.actions:
            for t in range(self.T):
                mask = np.asarray(influenced_transitions[a, t])[:n_states, :n_states]
                for s_prime in np.flatnonzero(mask.any(axis=0)).tolist():
                    directly_influenced_nodes.add((t+1, s_prime))

        reachable_nodes = reverse_bfs(G, directly_influenced_nodes, k)

        influence_graph = G.subgraph(reachable_nodes).copy()

        # If we are between T-k+1 and T, then we want to add all the paths between these layers, as they are all treated as influenced.
        for timestep in range(self.T-k+1, self.T):
            for s in self.states:
                node = (timestep, s)
                if not G.has_node(node):
                    continue

                targets = sorted(
                    v[1] for v in G.successors(node)
                    if v[0] == timestep+1 and v[1] in self.states
                )

                for a in self.actions:
                    for s_prime in targets:
                        nxt = (timestep+1, s_prime)
                        if G.has_edge(node, nxt, key=a) and not influence_graph.has_edge(node, nxt, key=a):
                            influence_graph.add_edge(node, nxt, key=a)

        # Remove nodes with in-degree = 0 or out-degree = 0
        self._prune_until_stable(
            influence_graph,
            lambda n: (influence_graph.in_degree(n) == 0 and n[0] > 0)
            or (influence_graph.out_degree(n) == 0 and n[0] < self.T),
        )

        return influence_graph

    def prune_mdp(self):
        """Run the full pruning pipeline for every look-ahead ``1..look_ahead_k``.

        Returns:
            ``(cf_transition_probs, valid_actions, cf_graphs)``, one entry per
            look-ahead value.
        """
        # Build graph using the original MDP transition probabilities.
        G = self.build_graph(tx_mat_full, self.T, self.states, self.actions)

        # get_influence_graph only reads G and returns an independent copy, so
        # the same graph can be shared across all look-ahead values.
        influence_graphs = [
            self.get_influence_graph(G, k, self.influenced_transitions)
            for k in range(1, self.look_ahead_k+1)
        ]

        cf_transition_probs = []
        valid_actions = []

        print(self.mdp_sample)

        A_real = self.mdp_sample[0, :, 2]
        S_real = self.mdp_sample[0, :, 0]

        for look_ahead_k in range(1, self.look_ahead_k+1):
            # P_cf is edited in place, so each look-ahead works on its own copy.
            new_P_cf, valid_action = self.get_counterfactual_transition_probabilities(copy.deepcopy(self.P_cf), G, influence_graphs[look_ahead_k-1], self.states, self.actions, A_real, S_real, self.T, look_ahead_k)
            cf_transition_probs.append(new_P_cf)
            valid_actions.append(valid_action)

        # Generate graphs for the pruned counterfactual MDP.
        cf_graphs = self.build_graphs(cf_transition_probs)

        return cf_transition_probs, valid_actions, cf_graphs

    def build_graphs(self, cf_transition_probs):
        """Build one counterfactual graph per look-ahead from ``cf_transition_probs``."""
        cf_graphs = []

        for k in range(1, self.look_ahead_k+1):
            G = self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k)
            cf_graphs.append(G)

        return cf_graphs

    def get_optimal_policy(self, max_num_actions_changed, transition_probs, valid_action, new_mdp_G, all_states, all_actions, A_real, T, rewards_pi):
        """Dynamic programme for the best policy with a bounded number of deviations.

        Args:
            max_num_actions_changed: maximum number of steps at which the
                policy may differ from ``A_real``.
            transition_probs: dict ``(a, t) -> [state, next_state]`` matrix.
            valid_action: boolean array ``[t, state, action]``.
            new_mdp_G: unused; kept for interface compatibility.
            all_states, all_actions: state / action collections; their lengths
                set the table sizes and entries are addressed as ``0..len-1``.
            A_real: observed action per time step.
            T: number of time steps.
            rewards_pi: reward table indexed ``[state][action]``.

        Returns:
            ``(pi, h_fun)``. ``h_fun[s, r, c]`` is the value of state ``s``
            with ``r`` steps remaining and ``c`` deviations still allowed;
            ``pi[s, l, t]`` is the action at time ``t`` in state ``s`` after
            ``l`` deviations have been used.
        """
        n_states = len(all_states)
        n_actions = len(all_actions)
        h_fun = np.zeros((n_states, T+1, max_num_actions_changed+1))
        pi = np.zeros((n_states, max_num_actions_changed+1, T+1), dtype=int)

        # No deviations left: follow the observed actions.
        for r in range(1, T+1):
            t = T-r
            observed = A_real[t]
            for s in range(n_states):
                h_fun[s, r, 0] = self._accumulate_expectation(
                    rewards_pi[s][observed], transition_probs[observed, t][s], h_fun[:, r-1, 0])
                pi[s, max_num_actions_changed, t] = observed

        for c in range(1, max_num_actions_changed+1):
            for r in range(1, T+1):
                t = T-r
                observed = A_real[t]
                for s in range(n_states):
                    best_act = observed
                    max_val = -np.inf

                    for a in range(n_actions):
                        if not valid_action[t, s, a]:
                            continue

                        # Deviating from the observed action spends one change.
                        budget = c if a == observed else c-1
                        val = self._accumulate_expectation(
                            rewards_pi[s][a], transition_probs[a, t][s], h_fun[:, r-1, budget])

                        if val > max_val:
                            max_val = val
                            best_act = a

                    h_fun[s, r, c] = max_val
                    pi[s, max_num_actions_changed-c, t] = best_act

        return pi, h_fun

    def generate_policies(self, cf_transition_probs, valid_actions, cf_graphs):
        """Compute a policy for every (look-ahead, max deviations) combination.

        Returns:
            ``(policies, new_all_rewards, h_funs)`` where ``policies[k-1][m-1]``
            and ``h_funs[k-1][m-1]`` belong to look-ahead ``k`` with at most
            ``m`` changed actions, and ``new_all_rewards[t, s, a]`` repeats
            ``self.rewards_pi[s, a]`` for every time step.
        """
        policies = []
        h_funs = []
        k_vals = range(1, self.look_ahead_k+1)
        A_real = self.mdp_sample[0, :, 2]

        n_states = len(self.states)
        n_actions = len(self.actions)

        # For now the per-time-step rewards are just the original rewards.
        new_all_rewards = np.zeros((self.T, n_states, n_actions))
        new_all_rewards[:] = np.asarray(self.rewards_pi)[:n_states, :n_actions]

        for look_ahead_k in k_vals:
            print(f"Estimating policy with k={look_ahead_k}")
            policies_k = []
            h_funs_k = []

            for max_num_actions_changed in k_vals:
                pi, h_fun = self.get_optimal_policy(
                    max_num_actions_changed,
                    cf_transition_probs[look_ahead_k-1],
                    valid_actions[look_ahead_k-1],
                    cf_graphs[look_ahead_k-1],
                    self.states,
                    self.actions,
                    A_real,
                    self.T,
                    self.rewards_pi
                )

                policies_k.append(pi)
                h_funs_k.append(h_fun)

            policies.append(policies_k)
            h_funs.append(h_funs_k)

        return policies, new_all_rewards, h_funs

    def generate_random_trajectory(self, MDP_samp, transition_probs, pi, s_0, A_real, all_states, rewards_pi, T):
        """Roll out one counterfactual trajectory under policy ``pi``.

        ``rewards_pi`` is indexed ``[t, state, action]`` here. Only row 0 of
        the returned ``(n_obs, T, n_fields)`` array is filled, with
        ``[state, next state, action, reward]`` per step.
        """
        n_obs=MDP_samp.shape[0]
        n_state=MDP_samp.shape[2]
        CF_trajectory = np.zeros((n_obs, T, n_state))

        rng = np.random.default_rng()
        s = np.zeros(T+1, dtype=int)
        s[0] = s_0   # Initial state the same
        l = np.zeros(T+1, dtype=int)
        l[0] = 0    # Start with 0 changes
        a = np.zeros(T, dtype=int)

        for t in range(T):
            # Pick actions according to the given policy
            a[t] = pi[s[t], l[t], t]

            # Sample the next state
            s[t+1] = (rng.choice(a=all_states, size=1,  p=transition_probs[a[t], t][s[t]]))[0]

            # Adjust the number of changes so far
            if a[t] != A_real[t]:
                l[t+1] = l[t] + 1
            else:
                l[t+1] = l[t]

            CF_trajectory[0, t, :] = np.array([s[t], s[t+1], a[t], rewards_pi[t, s[t], a[t]]])

        return CF_trajectory

    def generate_cf_trajectories(self, cf_transition_probs, policies, new_all_rewards, n_iterations=1000):
        """Average final-step rewards of observed vs counterfactual rollouts.

        Args:
            cf_transition_probs: per-look-ahead counterfactual probabilities.
            policies: ``policies[k-1][m-1]`` as returned by ``generate_policies``.
            new_all_rewards: reward table indexed ``[t, state, action]``.
            n_iterations: number of rollouts per (k, m) cell (default 1000).

        Returns:
            ``(mean_obs, mean_cf, k_vals)``; both means have shape
            ``(look_ahead_k, look_ahead_k)`` indexed ``[k-1][m-1]``.
        """
        print(f"Generating CF trajectories")
        all_obs = []
        all_cf = []
        k_vals = range(1, self.look_ahead_k+1)
        A_real = self.mdp_sample[0, :, 2]
        s_0 = self.mdp_sample[0, 0, 0]

        for _ in range(n_iterations):
            # Named *_rewards so the module alias `cf` is not shadowed.
            obs_rewards = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))
            cf_rewards = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))

            for look_ahead_k in k_vals:
                for max_num_actions_changed in k_vals:
                    CF_trajectory = self.generate_random_trajectory(
                        self.mdp_sample,
                        cf_transition_probs[look_ahead_k-1],
                        policies[look_ahead_k-1][max_num_actions_changed-1],
                        s_0,
                        A_real,
                        self.states,
                        new_all_rewards,
                        self.T
                    )

                    obs_rewards[look_ahead_k-1][max_num_actions_changed-1] = self.mdp_sample[0, self.T-1, 3] # Immediate reward for obs path at time T
                    cf_rewards[look_ahead_k-1][max_num_actions_changed-1] = CF_trajectory[0, self.T-1, 3] # Immediate reward for cf path at time T

            all_obs.append(obs_rewards)
            all_cf.append(cf_rewards)

        all_obs = np.array(all_obs)
        all_cf = np.array(all_cf)

        mean_obs = all_obs.mean(axis=0)
        mean_cf = all_cf.mean(axis=0)

        return mean_obs, mean_cf, k_vals

# Prune MDP

In [None]:
# Build the MDP wrapper around the policy-iteration result and hand it to the
# pruner. look_ahead_k=11 means look-ahead values 1..11 are evaluated.
mdp = MDP(init_state_idx=None, policy_array=fullPol, policy_idx_type='full', p_diabetes=PROB_DIAB)

influence_pruner = InfluenceMDPPruner(mdp, look_ahead_k=11)

In [None]:
# One entry per look-ahead value in each of the three returned lists.
cf_transition_probs, valid_actions, cf_graphs = influence_pruner.prune_mdp()

# Generating Policies

In [None]:
# policies[k-1][m-1] / h_funs[k-1][m-1]: look-ahead k, at most m changed actions.
policies, new_all_rewards, h_funs = influence_pruner.generate_policies(cf_transition_probs, valid_actions, cf_graphs)

## Print Value Function of S_0

In [None]:
print(MDP_samp)

# Initial state of the observed trajectory. Read from the sample instead of
# hard-coding its index, so the plot stays correct if the trajectory changes.
s_0 = int(MDP_samp[0, 0, 0])

values = []
# One look-ahead value per entry of h_funs (1..11 with look_ahead_k=11).
k_vals = range(1, len(h_funs) + 1)
obs_values = None

for look_ahead_k in k_vals:
    k_values = []
    obs_values = []

    for max_num_actions_changed in k_vals:
        h_fun = h_funs[look_ahead_k-1][max_num_actions_changed-1]

        # Index -1 on the second axis selects the full horizon (all steps remaining).
        k_values.append(h_fun[s_0, -1, max_num_actions_changed])
        # Budget 0 means no deviation: the value of following the observed actions.
        obs_values.append(h_fun[s_0, -1, 0])

    values.append(k_values)

fig = plt.figure(figsize=(4, 6));
ax = fig.add_subplot()

plt.title(f'Value of Initial State Given Influence')
plt.xlabel('Maximum Number of Actions Changed')
plt.ylabel('V(S0)'); 
plt.grid(which='both')

ax.scatter(k_vals, obs_values, color='lightpink', label='Observed reward', marker="o", s=100);
colors = ['orange', 'red', 'aqua', 'yellow', 'darkblue', 'deeppink', 'darkviolet', 'silver', 'teal', 'blue']

for look_ahead_k in range(1, 11):
    print(values[look_ahead_k-1])

# Only the three largest look-ahead values are drawn; the legend labels below
# are matched to these scatter calls by position.
ax.scatter(k_vals, values[-3], color='blue', label='CF reward', marker="d", s=50)
ax.scatter(k_vals, values[-2], color='deeppink', label='CF reward', marker="d", s=50)
ax.scatter(k_vals, values[-1], color = 'green', label='CF reward', marker="d", s=50)


plt.legend(["Observed Path", "Look-Ahead K=1 to 9", "Look-Ahead K=10 (T)", "Look-Ahead K=∞"], loc=0, frameon=True)
plt.show()

# Generating CF Trajectories

In [None]:
# mean_obs / mean_cf are indexed [look-ahead - 1][max changed actions - 1];
# k_vals is the range of look-ahead values that was evaluated.
mean_obs, mean_cf, k_vals = influence_pruner.generate_cf_trajectories(cf_transition_probs, policies, new_all_rewards)

In [None]:
# Compare the mean final-step reward of the observed path with the
# counterfactual paths, one scatter series per look-ahead value.
fig = plt.figure(figsize=(20, 20))
# Used only in the plot title; the number of rollouts actually averaged is
# fixed inside InfluenceMDPPruner.generate_cf_trajectories (1000).
NUM_ITERATIONS = 1000
ax = fig.add_subplot()

plt.title(f'Final State Reward of Observed vs Counterfactual Paths After Pruning MDP, Averaged Over {NUM_ITERATIONS} Iterations')
plt.xlabel('Maximum Number of Actions Changed');
plt.ylabel('Final State Reward'); 
plt.grid(which='both')

print(mean_obs)
print(mean_cf)

# The observed reward is identical in every row of mean_obs, so row 0 is plotted.
ax.scatter(k_vals, mean_obs[0], color='lightpink', label='Observed reward', marker="o", s=100);
colors = ['orange', 'red', 'aqua', 'yellow', 'darkblue', 'deeppink', 'darkviolet', 'silver', 'teal', 'blue']

# All look-ahead values except the last; `colors` holds exactly ten entries.
for look_ahead_k in k_vals[:-1]:
    ax.scatter(k_vals, mean_cf[look_ahead_k-1], color=colors[look_ahead_k-1], label='CF reward', marker="d", s=50)

# The largest look-ahead value is drawn separately with a smaller marker.
ax.scatter(k_vals, mean_cf[-1], color = 'yellow', label='CF reward', marker="d", s=30)

# NOTE(review): with look_ahead_k=11 there are 12 scatter series (observed + 11
# look-aheads) but only 11 legend labels, so the last series appears to be left
# without a label - confirm whether a "Look-Ahead K=11" entry is intended.
plt.legend(["Observed Path", "Look-Ahead K=1", "Look-Ahead K=2", "Look-Ahead K=3", "Look-Ahead K=4", "Look-Ahead K=5", "Look-Ahead K=6", "Look-Ahead K=7", "Look-Ahead K=8", "Look-Ahead K=9", "Look-Ahead K=10"], loc=0, frameon=True)
plt.show()