In [None]:
# Import all modules 
import numpy as np
from multiprocessing import Pool
import matplotlib.pyplot as plt
import seaborn as sns
!pip3 install networkx
import networkx as nx
import copy
!pip3 install pydot
np.seterr(divide = 'ignore') 
from scipy.stats import hypergeom
import numpy as np

In [None]:
# Constants
MAX_POPULATION = 10

# State space: every (S, I, V) triple with S, I in [0, MAX_POPULATION] and
# V in [0, 2 * MAX_POPULATION], enumerated with S outermost and V innermost,
# so a state's index is S * 11 * 21 + I * 21 + V for MAX_POPULATION = 10.
state_space = [
    (s, i, v)
    for s in range(MAX_POPULATION + 1)
    for i in range(MAX_POPULATION + 1)
    for v in range(2 * MAX_POPULATION + 1)
]
num_states = len(state_space)
state_from_index = dict(enumerate(state_space))
state_index = {state: idx for idx, state in state_from_index.items()}

# Action Space: position in `actions` is the action index.
actions = ["NIL", "V_I", "V_S"]
num_actions = len(actions)
action_from_index = dict(enumerate(actions))
action_index = {name: idx for idx, name in action_from_index.items()}

# Dense model matrices, filled in below:
#   transition_matrix[a, s, s'] and reward_matrix[s, a].
transition_matrix = np.zeros((num_actions, num_states, num_states))
reward_matrix = np.zeros((num_states, num_actions))

# Function to compute transition probabilities
def compute_transitions(S, I, V, action, max_population=None):
    """Return the one-step transition distribution from state ``(S, I, V)``.

    The action is applied first: ``V_I`` removes one infected individual and
    ``V_S`` one susceptible individual, each consuming one unit of ``V``; ``NIL``
    changes nothing. With ``s0`` susceptible and ``i0`` infected remaining, the
    number ``k`` of new infections follows
    ``hypergeom(M=s0 + i0, n=min(s0, i0), N=s0)`` and the next state is
    ``(s0 - k, i0 + k, V')``.

    Args:
        S, I, V: current state counts.
        action: one of ``"NIL"``, ``"V_I"``, ``"V_S"``.
        max_population: upper bound on the next infected count; defaults to the
            notebook-level ``MAX_POPULATION``.

    Returns:
        dict mapping next-state tuples to probabilities. Empty when the action is
        infeasible (``V_I`` with ``I == 0`` or ``V == 0``, ``V_S`` with
        ``S == 0`` or ``V == 0``) or the action name is unknown.

    NOTE(review): outcomes with ``I' > max_population`` are dropped, so for
    large ``I`` the returned probabilities can sum to less than 1.
    """
    if max_population is None:
        max_population = MAX_POPULATION

    if action == "NIL":
        return _spread_transitions(S, I, V, max_population)
    if action == "V_I" and I > 0 and V > 0:
        return _spread_transitions(S, I - 1, V - 1, max_population)
    if action == "V_S" and S > 0 and V > 0:
        return _spread_transitions(S - 1, I, V - 1, max_population)
    return {}


def _spread_transitions(s0, i0, v_next, max_population):
    """Hypergeometric infection step from ``s0`` susceptible and ``i0`` infected."""
    M = s0 + i0
    if M == 0:
        # Nobody left to infect. scipy's hypergeom requires M > 0 and would
        # return nan here, which later became an all-zero transition row.
        return {(0, 0, v_next): 1.0}

    # One frozen distribution, evaluated on every possible k at once.
    probs = hypergeom(M, min(s0, i0), s0).pmf(np.arange(s0 + 1))

    transitions = {}
    for k, prob in enumerate(probs):
        i_next = i0 + k
        if i_next <= max_population:
            transitions[(s0 - k, i_next, v_next)] = prob
    return transitions

# Compute Transition and Reward Matrices
for action_idx, action in enumerate(actions):
    for state in state_space:
        state_idx = state_index[state]
        S, I, V = state

        # Reward for being in `state`: minus the number of infected
        # individuals (the same value for every action).
        reward_matrix[state_idx, action_idx] = -I

        # Scatter the sparse next-state distribution into the dense matrix.
        transitions = compute_transitions(S, I, V, action)
        for next_state, prob in transitions.items():
            transition_matrix[action_idx, state_idx, state_index[next_state]] = prob

# Replace any NaN probabilities with 0.
transition_matrix = np.nan_to_num(transition_matrix)

# Output the size of the matrices for verification
for matrix in (transition_matrix, reward_matrix):
    print(matrix.shape)

In [None]:
class MDP:
    """Finite-horizon MDP for the epidemic model.

    Args:
        states: sequence of ``(S, I, V)`` state tuples; a state's position in the
            sequence is its index.
        state_index: dict mapping a state tuple to its index.
        actions: sequence of action names; an action's position is its index.
        action_index: dict mapping an action name to its index.
        transition_probabilities: array indexed ``[a, s, s']``.
        rewards: array indexed ``[s, a]``.
        horizon: number of decision steps ``T`` (defaults to 7, the value that
            used to be hard-coded).
    """

    def __init__(self, states, state_index, actions, action_index, transition_probabilities, rewards, horizon=7):
        self.states = states
        self.n_states = len(states)
        self.state_index = state_index
        self.actions = actions
        self.n_actions = len(actions)
        self.action_index = action_index
        self.transition_probabilities = transition_probabilities
        self.rewards = rewards
        self.initial_state = self.get_initial_state()
        self.T = horizon

    def get_initial_state(self):
        """Return the fixed initial state ``(S, I, V) = (9, 1, 20)``."""
        return (9, 1, 20)

    def get_optimal_policy(self, threshold=10):
        """Threshold heuristic policy (not computed by any optimisation).

        Chooses ``NIL`` when ``V <= threshold`` or when ``I == 0`` (nobody to
        vaccinate), and ``V_I`` (vaccinate an infected individual) otherwise.

        NOTE(review): the original comment called 10 "20% of the population";
        V ranges up to 2 * MAX_POPULATION, so confirm what the threshold is
        meant to represent.

        Returns:
            float array with one action index per state index.
        """
        nil = self.action_index["NIL"]
        vaccinate_infected = self.action_index["V_I"]
        pi = np.zeros(len(self.states))

        for state in self.states:
            S, I, V = state
            use_nil = V <= threshold or I == 0
            pi[self.state_index[state]] = nil if use_nil else vaccinate_infected

        return pi

    def get_suboptimal_policy(self):
        """Policy that never vaccinates: ``NIL`` in every state (float array)."""
        pi = np.zeros(len(self.states))
        nil = self.action_index["NIL"]

        for state in self.states:
            pi[self.state_index[state]] = nil

        return pi

    def generate_suboptimal_path(self, policy):
        """Sample one trajectory of ``policy`` from ``self.initial_state``.

        Draws ``T + 1`` transitions (t = 0..T) and prints each step.

        NOTE(review): the hard-coded ``MDP_samp`` used later has only ``T``
        steps; confirm whether ``T + 1`` here is intended.

        Args:
            policy: array mapping state index -> action index.

        Returns:
            ``(path_to_print, samples)`` where ``path_to_print`` holds
            ``[t, state, action_name, next_state]`` rows and ``samples`` is an
            array of shape ``(1, T + 1, 4)`` with rows ``[t, s, a, s']``.

        Raises:
            ValueError: if the policy picks an action with no feasible
                transition from the current state.
        """
        print(f"Initial State = {self.initial_state}")
        path_to_print = []
        samp = []
        rng = np.random.default_rng()

        s = self.state_index[self.initial_state]

        for t in range(self.T + 1):
            a = int(policy[s])
            probs = self.transition_probabilities[a, s]
            if probs.sum() <= 0:
                raise ValueError(
                    f"Action {self.actions[a]} has no feasible transition from state {self.states[s]} at t={t}"
                )

            s_prime = rng.choice(a=range(len(self.states)), size=1, p=probs)[0]
            samp.append([t, s, a, s_prime])
            path_to_print.append([t, self.states[s], self.actions[a], self.states[s_prime]])
            print(path_to_print[-1])
            s = s_prime

        return path_to_print, np.array([samp])

# Build the epidemic MDP from the matrices computed above.
epidemic_mdp = MDP(
    state_space,
    state_index,
    actions,
    action_index,
    transition_matrix,
    reward_matrix,
)
# Baseline policy that never vaccinates.
policy = epidemic_mdp.get_suboptimal_policy()

# Observed trajectory used for the counterfactual analysis: one trajectory of
# rows [t, s, a, s'] (state/action indices), in the same layout that
# MDP.generate_suboptimal_path returns.
# NOTE(review): it has 7 steps while generate_suboptimal_path draws T + 1 = 8.
MDP_samp = np.array([
    [
        [0, 2120, 0, 1910],
        [1, 1910, 0, 1700],
        [2, 1700, 0, 1070],
        [3, 1070, 0, 650],
        [4, 650, 0, 440],
        [5, 440, 0, 440],
        [6, 440, 0, 440],
    ]
])

print(MDP_samp)

In [None]:
def truncated_gumbel(logit, truncation):
    """Sample ``Gumbel(logit)`` variables truncated above at ``truncation``.

    Draws one standard Gumbel per entry of ``truncation``, shifts it by
    ``logit`` to get ``g``, and returns ``-log(exp(-g) + exp(-truncation))``,
    which is always below ``truncation``.

    Args:
        logit: finite location of the Gumbel distribution.
        truncation: 1-D array (or sequence) of upper bounds, one per sample.

    Returns:
        array of ``len(truncation)`` truncated samples.
    """
    assert not np.isneginf(logit)

    truncation = np.asarray(truncation, dtype=float)
    gumbel = np.random.gumbel(size=truncation.shape[0]) + logit
    # Same value as -log(exp(-gumbel) + exp(-truncation)), but logaddexp does
    # not overflow when gumbel or truncation is very negative.
    return -np.logaddexp(-gumbel, -truncation)

def topdown_tracking_influenced_states(obs_logits, obs_state, nsamp=1):
    """Sample Gumbel noise consistent with an observed categorical outcome.

    This is the top-down Gumbel-max posterior, vectorised over all outcomes.
    - A shared maximum ``topgumbel ~ Gumbel(0)`` is drawn per sample.
    - The observed outcome's noise is set so that its perturbed logit equals
      that maximum.
    - Every other feasible outcome gets a ``Gumbel(logit)`` draw truncated
      above at the maximum (same formula as ``truncated_gumbel``).
    - Outcomes with ``-inf`` logit are unconstrained and keep plain
      ``Gumbel(0)`` noise.

    NOTE(review): ``topgumbel ~ Gumbel(0)`` is the correct law of the maximum
    only when ``obs_logits`` are normalised log-probabilities.

    Args:
        obs_logits: 1-D array of log-probabilities of the observed transition
            (``-inf`` for impossible outcomes).
        obs_state: index of the observed outcome.
        nsamp: number of independent noise samples.

    Returns:
        ``(gumbels, influenced_states)``:
        - ``gumbels`` has shape ``(nsamp, n)`` and holds the noise (perturbed
          logit minus logit) per outcome.
        - ``influenced_states`` is a float 0/1 vector marking outcomes with a
          finite logit.
    """
    obs_logits = np.asarray(obs_logits, dtype=float)
    n_outcomes = obs_logits.shape[0]
    feasible = ~np.isneginf(obs_logits)

    # Sample top gumbels (one per Monte-Carlo sample).
    topgumbel = np.random.gumbel(size=nsamp)

    # Plain Gumbel(0) noise; kept as-is for zero-probability outcomes.
    gumbels = np.random.gumbel(size=(nsamp, n_outcomes))

    # Feasible outcomes: truncate Gumbel(logit) at the sample's top gumbel.
    feasible_logits = obs_logits[feasible]
    perturbed = gumbels[:, feasible] + feasible_logits
    truncated = -np.logaddexp(-perturbed, -topgumbel[:, np.newaxis])
    gumbels[:, feasible] = truncated - feasible_logits

    # The observed outcome attains the maximum exactly.
    if 0 <= obs_state < n_outcomes and feasible[obs_state]:
        gumbels[:, obs_state] = topgumbel - obs_logits[obs_state]

    influenced_states = feasible.astype(float)
    return gumbels, influenced_states

In [None]:
class CounterfactualSampler(object):
    """Monte-Carlo estimator of counterfactual transition probabilities.

    For each observed transition, noise consistent with the observation is
    sampled via ``topdown_tracking_influenced_states``. That noise is then
    re-used under an intervened (state, action) pair, and the resulting
    next-state argmax frequencies form the counterfactual distribution.
    """

    def __init__(self, mdp):
        self.mdp = mdp
        # NOTE(review): the sprtb_* settings are not read anywhere in this notebook.
        self.sprtb_theta = 0.9
        self.sprtb_delta = 0.05
        self.sprtb_r = 0.9

    def cf_posterior_tracking_influenced_states(self, obs_prob, intrv_prob, state, n_mc):
        """Estimate the counterfactual next-state distribution for one transition.

        Args:
            obs_prob: next-state probabilities of the observed (state, action).
            intrv_prob: next-state probabilities under the intervention.
            state: index of the next state that was actually observed.
            n_mc: number of Monte-Carlo noise samples.

        Returns:
            ``(posterior_prob, intrv_posterior, influenced_states)``:
            - ``posterior_prob``: empirical counterfactual distribution over
              next states. All zeros when ``intrv_prob`` has no positive entry
              (infeasible intervention).
            - ``intrv_posterior``: per-sample argmax next state.
            - ``influenced_states``: 0/1 mask of next states whose noise the
              observation constrained.
        """
        obs_logits = np.log(obs_prob)
        intrv_logits = np.log(intrv_prob)
        gumbels, influenced_states = topdown_tracking_influenced_states(obs_logits, state, n_mc)
        intrv_posterior = np.argmax(intrv_logits + gumbels, axis=1)

        n_next = np.size(intrv_prob, 0)
        if np.any(np.asarray(intrv_prob) > 0):
            # Empirical frequency of each next state among the samples.
            posterior_prob = np.bincount(intrv_posterior, minlength=n_next) / n_mc
        else:
            # Every intervened logit is -inf, so argmax would silently report
            # state 0 with probability 1; an infeasible action has no successor.
            posterior_prob = np.zeros(n_next)

        return posterior_prob, intrv_posterior, influenced_states

    def cf_sample_prob_tracking_influenced_transitions(self, trajectories, T, influenced_transitions, n_cf_samps=1, n_mc=1000):
        """Counterfactual probabilities ``P_cf[s, a, s', t]`` for every (s, a, t).

        Args:
            trajectories: observed trajectories indexed ``[obs, t, field]`` with
                fields ``[t, s, a, s']``.
            T: number of steps to process.
            influenced_transitions: array ``[s, a, s', t]`` updated in place
                with the influence mask of step ``t``.
            n_cf_samps: repetitions per observed trajectory.
            n_mc: Monte-Carlo samples per posterior estimate.

        Returns:
            ``(P_cf, influenced_transitions)``.

        NOTE(review): with several observed trajectories or repetitions, later
        estimates overwrite earlier ones instead of being averaged.
        """
        n_obs = trajectories.shape[0]

        P_cf = np.zeros(shape=(self.mdp.n_states, self.mdp.n_actions, self.mdp.n_states, T))

        for a in range(self.mdp.n_actions):
            for t in range(T):
                for obs_idx in range(n_obs):
                    obs_step = trajectories[obs_idx, t, :]
                    obs_current_state = int(obs_step[1])
                    obs_action = int(obs_step[2])
                    obs_next_state = int(obs_step[3])
                    # Independent of the intervened state, so fetched once.
                    obs_row = self.mdp.transition_probabilities[obs_action, obs_current_state, :]

                    for _repeat in range(n_cf_samps):
                        for s in range(self.mdp.n_states):
                            cf_row = self.mdp.transition_probabilities[a, s, :]
                            cf_prob, _, influenced_states = self.cf_posterior_tracking_influenced_states(
                                obs_row, cf_row, obs_next_state, n_mc
                            )
                            P_cf[s, a, :, t] = cf_prob
                            influenced_transitions[s, a, :, t] = influenced_states

        return P_cf, influenced_transitions

    def run_parallel_sampling_tracking_influenced_transitions(self, trajectories, influenced_transitions):
        """Run the sampler over every step of ``trajectories`` (sequentially, despite the name)."""
        n_steps = trajectories.shape[1]

        P_cf, influenced_transitions = self.cf_sample_prob_tracking_influenced_transitions(trajectories, n_steps, influenced_transitions)

        return P_cf, influenced_transitions

In [None]:
from collections import defaultdict, deque

# NOTE(review): NUM_ITERATIONS is not referenced anywhere else in this notebook.
NUM_ITERATIONS = 10_000

class InfluenceMDPPruner:
    """Build and solve a pruned counterfactual MDP around one observed trajectory.

    Pipeline:
      1. ``prune_mdp``:
         - estimates counterfactual transition probabilities
           ``P_cf[s, a, s', t]`` with a ``CounterfactualSampler``, recording
           which transitions the observation influenced;
         - builds the time-expanded graph of the original MDP;
         - for each look-ahead ``k``, keeps the part within ``k`` backward steps
           of an influenced transition;
         - zeroes every counterfactual row that leaves that subgraph.
      2. ``generate_policies`` solves a finite-horizon DP with a budget on how
         many actions may differ from the observed ones.
      3. ``generate_cf_trajectories`` samples trajectories from those policies.

    Args:
        mdp: MDP exposing ``n_states``, ``n_actions``,
            ``transition_probabilities[a, s, s']`` and ``rewards[s, a]``.
        mdp_sample: observed trajectories indexed ``[trajectory, t, field]`` with
            fields ``[t, s, a, s']``. The observed actions/states are read from
            trajectory 0.
        look_ahead_k: largest look-ahead (and largest action-change budget)
            evaluated; every value in ``1..look_ahead_k`` is used.

    NOTE(review): in this notebook ``all_states``/``all_actions`` arguments are
    always ``self.states``/``self.actions`` (``range`` objects), and array
    positions are taken to equal state/action ids.
    """

    def __init__(self, mdp, mdp_sample, look_ahead_k = 1):
        self.mdp = mdp
        # Rewards indexed [s, a], taken from the MDP rather than a notebook global.
        self.rewards_pi = mdp.rewards
        self.sampler = CounterfactualSampler(self.mdp)
        self.mdp_sample = mdp_sample
        # Observed initial state: field 1 (s) of step 0 of trajectory 0.
        # (``mdp_sample[0][1]`` used to select the whole second step instead.)
        self.initial_state = mdp_sample[0][0][1]
        self.look_ahead_k = look_ahead_k
        self.T = len(self.mdp_sample[0])
        self.states = range(mdp.n_states)
        self.actions = range(mdp.n_actions)

    @staticmethod
    def _remove_unreachable(G):
        """Repeatedly delete nodes at t > 0 that have no incoming edges (in place)."""
        unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0] > 0}
        while unreachable_nodes:
            G.remove_nodes_from(unreachable_nodes)
            unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0] > 0}

    @staticmethod
    def _support_in_graph(G, node, probs, a):
        """True if every s' with ``probs[s'] > 0`` has an edge ``node -> (t+1, s')`` keyed ``a``."""
        t = node[0]
        return all(G.has_edge(node, (t + 1, int(s_prime)), key=a) for s_prime in np.flatnonzero(probs > 0.0))

    def build_graph(self, transition_probs):
        """Build the time-expanded graph of the original MDP.

        Nodes are ``(t, s)`` for ``t`` in ``0..T-1`` plus every ``(T, s')``
        reachable in one step. An edge ``(t, s) -> (t+1, s')`` with key ``a``
        and label ``"(a, p)"`` exists whenever ``transition_probs[a, s, s'] > 0``
        (the same support at every ``t``).
        """
        G = nx.MultiDiGraph()
        n_states, n_actions = self.mdp.n_states, self.mdp.n_actions

        # The support does not depend on t: collect each state's positive
        # (action, next state, probability) triples once, in a-major order.
        supports = []
        for s in range(n_states):
            block = transition_probs[:n_actions, s, :n_states]
            supports.append([
                (int(a), int(s_prime), block[a, s_prime])
                for a, s_prime in zip(*np.nonzero(block > 0))
            ])

        for t in range(self.T):
            for s in range(n_states):
                G.add_node((t, s))
                for a, s_prime, prob in supports[s]:
                    G.add_edge((t, s), (t + 1, s_prime), key=a, label=f"({a}, {prob})")

        return G

    def build_cf_graph(self, transition_probs, T, all_states, all_actions, k):
        """Build the time-expanded graph of a pruned counterfactual MDP.

        Adds an edge ``(t, s) -> (t+1, s')`` with key ``a`` wherever
        ``transition_probs[s, a, s', t] > 0``. Nodes at ``t < T-1`` that have no
        outgoing edge when they are processed are dropped. Afterwards, nodes at
        ``t > 0`` without incoming edges are removed repeatedly. ``k`` is unused.
        """
        G = nx.MultiDiGraph()
        state_ids = np.fromiter(all_states, dtype=int)
        action_ids = np.fromiter(all_actions, dtype=int)

        for t in range(T):
            for s in all_states:
                G.add_node((t, s))

                # Positive (action, next state) pairs, in a-major order.
                block = transition_probs[s, :, :, t][np.ix_(action_ids, state_ids)]
                for a_pos, sp_pos in zip(*np.nonzero(block > 0)):
                    a, s_prime = int(action_ids[a_pos]), int(state_ids[sp_pos])
                    G.add_edge((t, s), (t + 1, s_prime), key=a, label=f"({a}, {block[a_pos, sp_pos]})")

                # If node has no outgoing edges, remove it from the graph
                if G.out_degree((t, s)) == 0 and t < T - 1:
                    G.remove_node((t, s))

        # Remove unreachable nodes at t>0 with in-degree = 0
        self._remove_unreachable(G)
        return G

    def get_counterfactual_transition_probabilities(self, P_cf, original_G, new_mdp_G, all_states, all_actions, A_real, S_real, T, k):
        """Restrict ``P_cf`` to the pruned graph and mark which actions stay valid.

        Pass 1 (graph), layer by layer from ``t = T-1`` down to ``0``:
        - For every node ``(t, s)`` still in ``new_mdp_G`` and every action
          ``a``: if some ``s'`` with ``P_cf[s, a, s', t] > 0`` has no edge
          ``(t, s) -> (t+1, s')`` keyed ``a``, all of that action's edges out of
          ``(t, s)`` are removed.
        - Dead-end nodes (``t < T-1``) and unreachable nodes are then pruned.

        Pass 2 (probabilities):
        - Every row ``P_cf[s, a, :, t]`` whose support is not fully present in
          the pruned graph is zeroed, so surviving rows still sum to 1.
        - ``valid_action[t, s, a]`` is True exactly for rows that sum to 1
          (rounded to 10 decimals).

        Both ``P_cf`` and ``new_mdp_G`` are modified in place. ``original_G``,
        ``A_real`` and ``S_real`` are unused; ``k`` is only printed.

        Returns:
            ``(P_cf, valid_action)`` with ``valid_action`` of shape
            ``(T, n_states, n_actions)``.
        """
        print(f"Calculating counterfactual transition probabilities for k={k}")

        valid_action = np.full((T, len(self.states), len(self.actions)), False)

        for t in range(T-1, -1, -1):
            for s in self.states:
                node = (t, s)
                if not new_mdp_G.has_node(node):
                    continue

                for a in self.actions:
                    if not self._support_in_graph(new_mdp_G, node, P_cf[s, a, :, t], a):
                        doomed = [(u, v, key) for u, v, key in new_mdp_G.out_edges(node, keys=True) if key == a]
                        new_mdp_G.remove_edges_from(doomed)

                # If node has no outgoing edges, remove it from the graph
                if new_mdp_G.out_degree(node) == 0 and t < T-1:
                    new_mdp_G.remove_node(node)

            # Remove unreachable nodes at t>0 with in-degree = 0
            self._remove_unreachable(new_mdp_G)

        for t in range(T-1, -1, -1):
            for s in all_states:
                for a in all_actions:
                    row = P_cf[s, a, :, t]
                    # Probabilities must be non-negative numbers (this also rejects NaN).
                    assert np.all((row > 0.0) | (row == 0.0))

                    if not self._support_in_graph(new_mdp_G, (t, s), row, a):
                        P_cf[s, a, :, t] = 0.0

                    if round(float(row.sum()), 10) == 1.0:
                        valid_action[t, s, a] = True

        return P_cf, valid_action

    def get_influence_graph(self, G, k, influenced_transitions):
        """Restrict ``G`` to nodes that reach an influenced transition within ``k`` steps.

        Args:
            G: time-expanded graph from ``build_graph``; only read, never modified.
            k: look-ahead.
            influenced_transitions: array ``[s, a, s', t]``; non-zero where the
                observed transition at step ``t`` constrained the noise of next
                state ``s'``.

        Returns:
            A new ``MultiDiGraph`` built in three steps:
            1. every node of ``G`` within ``k`` predecessor hops of a directly
               influenced node ``(t+1, s')``;
            2. plus all of ``G``'s edges out of layers ``T-k+1 .. T-1``;
            3. then pruned to nodes that have an incoming edge (``t > 0``) and
               an outgoing edge (``t < T``).
        """
        print(f"Generating influence graph for k={k}")

        def reverse_bfs(G, start_nodes, k):
            """Nodes reachable backwards from ``start_nodes`` in at most ``k`` hops."""
            distance = defaultdict(lambda: float('inf'))
            nodes_to_visit = deque([(node, 0) for node in start_nodes])
            within_k_steps = set()

            while nodes_to_visit:
                curr_node, curr_dist = nodes_to_visit.popleft()

                if curr_dist <= k:
                    within_k_steps.add(curr_node)

                    if distance[curr_node] > curr_dist:
                        distance[curr_node] = curr_dist

                        for predecessor in G.predecessors(curr_node):
                            nodes_to_visit.append((predecessor, curr_dist+1))

            return within_k_steps

        # Next-state nodes (t+1, s') touched by any influenced transition at step t.
        n_states, n_actions = len(self.states), len(self.actions)
        directly_influenced_nodes = set()
        for t in range(self.T):
            touched = np.asarray(influenced_transitions[:n_states, :n_actions, :n_states, t]).any(axis=(0, 1))
            directly_influenced_nodes.update((t + 1, int(s_prime)) for s_prime in np.flatnonzero(touched))

        reachable_nodes = reverse_bfs(G, directly_influenced_nodes, k)
        influence_graph = G.subgraph(reachable_nodes).copy()

        # Between T-k+1 and T every transition is treated as influenced, so all
        # of G's edges leaving those layers are copied in.
        first_layer = self.T - k + 1
        for u, v, key in G.edges(keys=True):
            if first_layer <= u[0] < self.T and not influence_graph.has_edge(u, v, key=key):
                influence_graph.add_edge(u, v, key=key)

        # Remove nodes with in-degree = 0 or out-degree = 0, until none remain.
        def dangling(graph):
            return {
                n for n in graph.nodes
                if (graph.in_degree(n) == 0 and n[0] > 0) or (graph.out_degree(n) == 0 and n[0] < self.T)
            }

        unreachable_nodes = dangling(influence_graph)
        while unreachable_nodes:
            influence_graph.remove_nodes_from(unreachable_nodes)
            unreachable_nodes = dangling(influence_graph)

        return influence_graph

    def prune_mdp(self):
        """Run the counterfactual sampling and pruning pipeline.

        Returns:
            ``(cf_transition_probs, valid_actions, cf_graphs)``: lists indexed by
            ``k - 1`` for ``k`` in ``1..look_ahead_k``. They hold, respectively,
            the pruned ``P_cf[s, a, s', t]``, the ``valid_action[t, s, a]`` mask,
            and the graph returned by ``build_cf_graph``.
        """
        # Which transitions' probabilities are directly influenced by the observed trajectory.
        influenced_transitions = np.zeros(shape=(len(self.states), len(self.actions), len(self.states), self.T+1))

        P_cf, influenced_transitions = self.sampler.run_parallel_sampling_tracking_influenced_transitions(self.mdp_sample, influenced_transitions)

        # Build graph using the original MDP transition probabilities.
        G = self.build_graph(self.mdp.transition_probabilities)

        k_vals = range(1, self.look_ahead_k + 1)

        # get_influence_graph only reads G (it returns a copied subgraph), so G
        # does not need to be deep-copied for every k.
        influence_graphs = [self.get_influence_graph(G, k, influenced_transitions) for k in k_vals]

        A_real = self.mdp_sample[0, :, 2]
        S_real = self.mdp_sample[0, :, 1]

        cf_transition_probs = []
        valid_actions = []
        for k in k_vals:
            # Pruning zeroes rows of P_cf in place, so each k gets its own copy.
            new_P_cf, valid_action = self.get_counterfactual_transition_probabilities(
                P_cf.copy(), G, influence_graphs[k-1], self.states, self.actions, A_real, S_real, self.T, k
            )
            cf_transition_probs.append(new_P_cf)
            valid_actions.append(valid_action)

        # Generate graphs for the pruned counterfactual MDP.
        cf_graphs = [self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k) for k in k_vals]

        return cf_transition_probs, valid_actions, cf_graphs

    def get_optimal_policy(self, max_num_actions_changed, P_cf, valid_action, new_mdp_G, all_states, all_actions, S_real, A_real, T, rewards_pi):
        """Finite-horizon DP with a budget on actions that differ from the observed ones.

        ``h_fun[s, r, c]`` is the best expected total reward over the last ``r``
        steps, starting in ``s`` with ``c`` changes still allowed. With
        ``c = 0`` the observed actions ``A_real`` are followed.

        ``pi[s, l, t]`` is the action to take at step ``t`` in ``s`` after ``l``
        changes have been used, i.e. ``l = max_num_actions_changed - c``.

        With budget left:
        - only actions with ``valid_action[t, s, a]`` are considered;
        - if none yields a finite value, the observed action is stored and the
          value is ``-inf``;
        - ties go to the lowest action id.

        Zero-probability successors are skipped, so ``0 * -inf`` never produces NaN.

        Args:
            max_num_actions_changed: change budget.
            P_cf: counterfactual probabilities ``[s, a, s', t]`` (non-negative).
            valid_action: boolean mask ``[t, s, a]``.
            new_mdp_G, S_real: unused.
            all_states, all_actions: state / action ids (``range``-like).
            A_real: observed action per step.
            T: horizon.
            rewards_pi: rewards ``[t, s, a]``.

        Returns:
            ``(pi, h_fun)`` with shapes ``(n, budget + 1, T + 1)`` and
            ``(n, T + 1, budget + 1)``, or ``None`` when ``all_states`` is empty.
        """
        if len(all_states) == 0:
            return None

        states = np.fromiter(all_states, dtype=int)
        actions = [int(a) for a in all_actions]
        action_ids = np.asarray(actions, dtype=int)
        n_states = len(states)
        c_max = max_num_actions_changed

        # Sanity check: valid rows sum to 1, invalid rows to 0.
        row_sums = np.round(P_cf.sum(axis=2), 10)  # indexed [s, a, t]
        for t in range(T):
            for a in actions:
                sums = row_sums[states, a, t]
                valid = valid_action[t, states, a]
                assert np.all(sums[valid] == 1.0)
                assert np.all(sums[~valid] == 0.0)

        full_range = (
            P_cf.shape[0] == n_states
            and P_cf.shape[2] == n_states
            and np.array_equal(states, np.arange(n_states))
        )

        def cf_matrix(a, t):
            """(n, n) matrix ``P_cf[s, a, s', t]`` over ``all_states`` (a view when possible)."""
            if full_range:
                return P_cf[:, a, :, t]
            return P_cf[np.ix_(states, [a], states, [t])][:, 0, :, 0]

        def expected(P, values):
            """``P @ values`` skipping zero-probability entries (no ``0 * -inf``)."""
            finite = np.isfinite(values)
            out = P @ np.where(finite, values, 0.0)
            if not finite.all():
                # P is non-negative, so any mass on a -inf state gives -inf.
                out[(P @ (~finite).astype(float)) > 0] = -np.inf
            return out

        h_fun = np.zeros((n_states, T+1, c_max+1))
        pi = np.zeros((n_states, c_max+1, T+1), dtype=int)

        # r = number of remaining steps; layer r depends only on layer r-1, so
        # all budgets can be filled in for one r before moving on.
        for r in range(1, T+1):
            t = T - r
            a_obs = int(A_real[t])
            rewards_t = np.asarray(rewards_pi[t])  # indexed [s, a]
            P_by_action = {a: cf_matrix(a, t) for a in set(actions) | {a_obs}}

            # No changes left: follow the observed action.
            h_fun[:, r, 0] = rewards_t[states, a_obs] + P_by_action[a_obs] @ h_fun[:, r-1, 0]
            pi[:, c_max, t] = a_obs

            for c in range(1, c_max+1):
                vals = np.full((n_states, len(actions)), -np.inf)
                for j, a in enumerate(actions):
                    valid = valid_action[t, states, a]
                    if not valid.any():
                        continue
                    # Deviating from the observed action uses up one change.
                    nxt = h_fun[:, r-1, c] if a == a_obs else h_fun[:, r-1, c-1]
                    cand = rewards_t[states, a] + expected(P_by_action[a], nxt)
                    vals[valid, j] = cand[valid]

                if actions:
                    best_val = vals.max(axis=1)
                    best_act = action_ids[vals.argmax(axis=1)]
                else:
                    best_val = np.full(n_states, -np.inf)
                    best_act = np.full(n_states, a_obs)

                h_fun[:, r, c] = best_val
                pi[:, c_max-c, t] = np.where(np.isneginf(best_val), a_obs, best_act)

        return pi, h_fun

    def generate_policies(self, cf_transition_probs, valid_actions, cf_graphs):
        """Solve the budgeted DP for every (look-ahead k, change budget) pair.

        Returns:
            ``(policies, new_all_rewards, h_funs)``:
            - ``policies[k-1][c-1]`` and ``h_funs[k-1][c-1]`` are the outputs of
              ``get_optimal_policy`` for look-ahead ``k`` and budget ``c``;
            - ``new_all_rewards[t, s, a]`` is ``rewards_pi[s, a]`` repeated over
              the ``T`` steps.
        """
        policies = []
        h_funs = []
        k_vals = range(1, self.look_ahead_k+1)
        S_real = self.mdp_sample[0, :, 1]
        A_real = self.mdp_sample[0, :, 2]

        n_states, n_actions = len(self.states), len(self.actions)
        base_rewards = np.asarray(self.rewards_pi, dtype=float)[:n_states, :n_actions]
        new_all_rewards = np.repeat(base_rewards[np.newaxis, :, :], self.T, axis=0)

        for look_ahead_k in k_vals:
            print(f"Estimating policy with k={look_ahead_k}")
            policies_k = []
            h_funs_k = []

            for max_num_actions_changed in k_vals:
                pi, h_fun = self.get_optimal_policy(
                    max_num_actions_changed,
                    cf_transition_probs[look_ahead_k-1],
                    valid_actions[look_ahead_k-1],
                    cf_graphs[look_ahead_k-1],
                    self.states,
                    self.actions,
                    S_real,
                    A_real,
                    self.T,
                    new_all_rewards
                )

                policies_k.append(pi)
                h_funs_k.append(h_fun)

            policies.append(policies_k)
            h_funs.append(h_funs_k)

        return policies, new_all_rewards, h_funs

    def generate_random_trajectory(self, MDP_samp, P_cf, pi, s_0, A_real, all_states, rewards_pi, T):
        """Sample one counterfactual trajectory of policy ``pi`` from ``s_0``.

        ``pi`` is indexed ``[s, changes_used, t]``; a step counts as a change
        when its action differs from ``A_real[t]``.

        Returns:
            array of shape ``(n_obs, T, n_fields)`` whose row ``[0, t]`` is
            ``[s_t, s_{t+1}, a_t, rewards_pi[t, s_t, a_t]]``. Note that this
            layout differs from ``MDP_samp``'s ``[t, s, a, s']``.

        Raises:
            ValueError: if the policy reaches a state/action with no
                counterfactual successor.
        """
        n_obs = MDP_samp.shape[0]
        n_state = MDP_samp.shape[2]
        CF_trajectory = np.zeros((n_obs, T, n_state))

        rng = np.random.default_rng()
        s = np.zeros(T+1, dtype=int)
        s[0] = s_0   # Initial state the same
        l = np.zeros(T+1, dtype=int)   # Start with 0 changes
        a = np.zeros(T, dtype=int)

        for t in range(T):
            a[t] = pi[s[t], l[t], t]

            probs = P_cf[s[t], a[t], :, t]
            if probs.sum() <= 0:
                raise ValueError(f"No counterfactual transition for state {s[t]}, action {a[t]} at t={t}")
            s[t+1] = rng.choice(a=self.states, size=1, p=probs)[0]

            # Count the change if the action deviates from the observed one.
            l[t+1] = l[t] + 1 if a[t] != A_real[t] else l[t]

            CF_trajectory[0, t, :] = np.array([s[t], s[t+1], a[t], rewards_pi[t, s[t], a[t]]])

        return CF_trajectory

    def generate_cf_trajectories(self, cf_transition_probs, policies, new_all_rewards, n_runs=1000):
        """Average final-step rewards of observed vs. counterfactual paths.

        Args:
            cf_transition_probs, policies, new_all_rewards: outputs of
                ``prune_mdp`` / ``generate_policies``.
            n_runs: number of sampled trajectories per (k, budget) pair.

        Returns:
            ``(mean_obs, mean_cf, k_vals)``:
            - ``mean_obs[k-1][c-1]`` is the observed path's immediate reward at
              ``t = T-1``. It is the same value for every entry.
            - ``mean_cf[k-1][c-1]`` is the mean counterfactual immediate reward
              at ``t = T-1``.
        """
        print("Generating CF trajectories")
        all_obs = []
        all_cf = []
        k_vals = range(1, self.look_ahead_k+1)
        A_real = self.mdp_sample[0, :, 2]
        s_0 = self.mdp_sample[0, 0, 1]

        # Immediate reward of the observed path at its last step. (Column 3 of
        # mdp_sample is the next-state index, not a reward.)
        last = self.T - 1
        obs_reward = new_all_rewards[last, int(self.mdp_sample[0, last, 1]), int(A_real[last])]

        for _ in range(n_runs):
            obs = np.full((self.look_ahead_k, self.look_ahead_k), obs_reward, dtype=float)
            cf = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))

            for look_ahead_k in k_vals:
                for max_num_actions_changed in k_vals:
                    CF_trajectory = self.generate_random_trajectory(
                        self.mdp_sample,
                        cf_transition_probs[look_ahead_k-1],
                        policies[look_ahead_k-1][max_num_actions_changed-1],
                        s_0,
                        A_real,
                        self.states,
                        new_all_rewards,
                        self.T
                    )
                    # Immediate reward for the cf path at its last step.
                    cf[look_ahead_k-1][max_num_actions_changed-1] = CF_trajectory[0, last, 3]

            all_obs.append(obs)
            all_cf.append(cf)

        mean_obs = np.array(all_obs).mean(axis=0)
        mean_cf = np.array(all_cf).mean(axis=0)

        return mean_obs, mean_cf, k_vals

In [None]:
# Evaluate every look-ahead k from 1 up to T + 1 (T = epidemic_mdp.T).
influence_pruner = InfluenceMDPPruner(
    mdp=epidemic_mdp,
    mdp_sample=MDP_samp,
    look_ahead_k=epidemic_mdp.T + 1,
)

In [None]:
# Sample counterfactual probabilities and prune them for every look-ahead k.
(cf_transition_probs,
 valid_actions,
 cf_graphs) = influence_pruner.prune_mdp()

In [None]:
# Budgeted DP for every (look-ahead k, action-change budget) pair.
policies, new_all_rewards, h_funs = influence_pruner.generate_policies(
    cf_transition_probs=cf_transition_probs,
    valid_actions=valid_actions,
    cf_graphs=cf_graphs,
)

In [None]:
# Monte-Carlo comparison of observed vs. counterfactual final-step rewards.
mean_obs, mean_cf, k_vals = influence_pruner.generate_cf_trajectories(
    cf_transition_probs=cf_transition_probs,
    policies=policies,
    new_all_rewards=new_all_rewards,
)

In [None]:
# Value of the observed initial state in the counterfactual MDP, for every
# look-ahead k and every budget of changed actions.
values = []
T = len(MDP_samp[0])
obs_values = None
# Observed initial state index, read from the sample instead of hard-coding 2120.
s_0 = int(MDP_samp[0, 0, 1])

for look_ahead_k in k_vals:
    k_values = []
    # Reset per k: after the loop this holds the budget-0 values of the last k.
    obs_values = []

    for max_num_actions_changed in k_vals:
        h_fun = h_funs[look_ahead_k-1][max_num_actions_changed-1]
        # h_fun[s, r, c]: value with r steps remaining and c changes left;
        # index -1 is the full horizon, c = 0 follows the observed actions.
        obs_values.append(h_fun[s_0, -1, 0])
        k_values.append(h_fun[s_0, -1, max_num_actions_changed])

    values.append(k_values)

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot()

ax.set_xlabel('Maximum Number of Actions Changed', fontsize=14)
ax.set_ylabel('Value of Initial State', fontsize=14)
ax.grid(which='both')

ax.scatter(k_vals, obs_values, color='lightpink', label='Observed Path', marker="o", s=100)
colors = ['darkblue', 'darkblue', 'darkviolet', 'red', 'green', 'orange', 'blue', 'grey']
markers = ['x', 'x', 'd', '+', '.', '^', 's', 'x']

# k = 1 is not drawn on its own; its legend entry is merged with k = 2
# (presumably the two coincide — confirm).
for look_ahead_k in k_vals[1:]:
    if look_ahead_k == 2:
        label = "K=1 to K=2"
    elif look_ahead_k == T + 1:
        label = "K=T+1"
    else:
        label = f"K={look_ahead_k}"
    style_idx = (look_ahead_k - 1) % len(colors)
    ax.scatter(k_vals, values[look_ahead_k-1], color=colors[style_idx], label=label, marker=markers[style_idx], s=50)

ax.legend(loc=0, frameon=True, fontsize=12)
plt.show()

In [None]:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot()

ax.set_xlabel('Maximum Number of Actions Changed')
ax.set_ylabel('Final State Reward')
ax.grid(which='both')

# generate_cf_trajectories fills mean_obs with the same observed value for
# every (k, budget) pair, so its last row stands for all of them.
ax.scatter(k_vals, mean_obs[-1], color='lightpink', label='Observed Path', marker="o", s=100)
colors = ['orange', 'red', 'aqua', 'yellow', 'darkblue', 'deeppink', 'darkviolet', 'silver', 'teal', 'blue']
horizon = len(MDP_samp[0])

for look_ahead_k in k_vals:
    label = "Look-Ahead K=T+1" if look_ahead_k == horizon + 1 else f"Look-Ahead K={look_ahead_k}"
    color = colors[(look_ahead_k - 1) % len(colors)]
    ax.scatter(k_vals, mean_cf[look_ahead_k-1], color=color, label=label, marker="d", s=50)

ax.legend(loc=0, frameon=True)
plt.show()