# Importance sampling — infinite variance example

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numba import njit, prange

In [2]:
%config InlineBackend.figure_format = "retina"

In [63]:
# Both tables below are indexed as [action, next_state]:
#   action      0 = left, 1 = right
#   next_state  0 = the non-terminal state, 1 = the end state
#
# Row `a` of transition_matrix is the probability distribution over the next
# state after taking action `a`.
transition_matrix = np.array([[0.9, 0.1], [0.0, 1.0]])

# Reward attached to each (action, next_state) transition. The -inf entry sits
# on the right -> non-terminal transition, which has probability 0.0 in
# transition_matrix; presumably it is a sentinel for "cannot happen".
reward_matrix = np.array([[0.0, 1.0], [-np.inf, 0.0]])

In [172]:
@njit
def step(state, policy):
    """Sample one transition of the process from `state`.

    Parameters
    ----------
    state : int
        Index of the current state; used to select a row of `policy`.
    policy : array
        `policy[state]` holds the action probabilities for `state`.

    Returns
    -------
    (action, next_state, reward)
        The sampled action index, the sampled next-state index, and the entry
        of `reward_matrix` for that (action, next_state) pair.

    Notes
    -----
    `transition_matrix` and `reward_matrix` are read as module-level globals.
    Numba treats global arrays as compile-time constants, so this cell has to
    be re-run after either table is changed.
    """
    # A single-trial multinomial draw is a one-hot vector; argmax gives the
    # index that was drawn.
    action_draw = np.random.multinomial(1, pvals=policy[state])
    action = action_draw.argmax()

    state_draw = np.random.multinomial(1, pvals=transition_matrix[action])
    next_state = state_draw.argmax()

    return action, next_state, reward_matrix[action, next_state]

@njit
def episode(state, policy):
    """Simulate from `state` until state 1 is reached, printing every step.

    Each printed line is `next_state action reward` for one call to `step`.
    Nothing is returned; if `state` is already 1, nothing is printed.
    """
    while True:
        if state == 1:
            # State 1 ends the episode.
            break
        taken, state, gained = step(state, policy)
        # `state` now holds the state reached by this step, not the one the
        # action was taken from.
        print(state, taken, gained)

In [173]:
# Action probabilities indexed as [state, action]: a single row for the one
# non-terminal state, and one column each for left and right. Both actions are
# equally likely.
policy = np.full((1, 2), 1/2)

In [175]:
# Roll out one episode starting in state 0. Each printed row is
# (next_state, action, reward) for one step; the run stops once state 1 is reached.
episode(0, policy)

0 0 0.0
1 1 0.0
