In [3]:
%matplotlib inline

import gym
import matplotlib
import numpy as np
import sys

from collections import defaultdict
from blackjack import BlackjackEnv
import plotting

matplotlib.style.use('ggplot')

# Seed NumPy's global RNG so the behaviour policy's action sampling
# (np.random.choice inside mc_control_importance_sampling) is reproducible
# across Restart & Run All.
# NOTE(review): BlackjackEnv may draw cards from its own RNG instead of
# np.random -- confirm, and seed the environment as well if it does.
SEED = 0
np.random.seed(SEED)

env = BlackjackEnv()

def create_random_policy(nA):
    """Build a policy that chooses every action with equal probability.

    Args:
        nA: Number of actions available in the environment.

    Returns:
        A function ``policy_fn(state)`` that ignores ``state`` and returns a
        length-``nA`` float array whose entries all equal ``1 / nA``. The same
        array object is returned on every call, so callers must not modify it.
    """
    uniform = np.full(nA, 1.0) / nA

    def policy_fn(state):
        # The distribution does not depend on the state.
        return uniform

    return policy_fn
    
    
def create_greedy_policy(Q):
    """Build a deterministic policy that acts greedily with respect to ``Q``.

    Args:
        Q: Mapping from state to an array of action values. It is looked up on
            every call, so in-place updates to ``Q`` are reflected immediately.
            If ``Q`` is a ``defaultdict``, looking up an unseen state inserts
            its default entry.

    Returns:
        A function ``policy_fn(state)`` returning a one-hot float array with
        1.0 at the first index holding the maximum of ``Q[state]``.
    """
    def policy_fn(state):
        action_values = Q[state]
        one_hot = np.zeros_like(action_values, dtype=float)
        # np.argmax breaks ties by taking the lowest index.
        one_hot[np.argmax(action_values)] = 1.0
        return one_hot

    return policy_fn

def mc_control_importance_sampling(env, num_episodes, behavior_policy, discount_factor=1.0, max_steps=100):
    """Off-policy Monte Carlo control using weighted importance sampling.

    Episodes are generated by following ``behavior_policy``. The action-value
    function ``Q`` of a greedy target policy is updated from them with the
    incremental weighted-importance-sampling rule (Sutton & Barto, Sec. 5.7).

    Args:
        env: Environment exposing ``action_space.n``, ``reset()`` returning the
            initial state, and ``step(action)`` returning
            ``(next_state, reward, done, info)``.
        num_episodes: Number of episodes to sample.
        behavior_policy: Function ``state -> array of action probabilities``
            used to generate episodes. It should give every action non-zero
            probability (coverage) so that the greedy target can be learned.
        discount_factor: Discount applied to future rewards (gamma).
        max_steps: Maximum number of steps per episode. An episode that has not
            terminated after this many steps is truncated and its partial
            return is used. Defaults to 100, the previously hard-coded cap.

    Returns:
        A tuple ``(Q, target_policy)``. ``Q`` is a ``defaultdict`` mapping
        state -> array of action values. ``target_policy`` is the greedy
        policy with respect to ``Q``.
    """
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    # Running sum of importance weights per (state, action), accumulated over
    # all episodes; it is the denominator of the weighted-IS average.
    C = defaultdict(lambda: np.zeros(env.action_space.n))

    # Greedy w.r.t. Q. It reads Q on every call, so it tracks the updates below.
    target_policy = create_greedy_policy(Q)

    for _episode_idx in range(num_episodes):
        # Generate one episode with the behaviour policy.
        episode = []
        state = env.reset()
        for _step in range(max_steps):
            probs = behavior_policy(state)
            action = np.random.choice(np.arange(len(probs)), p=probs)
            next_state, reward, done, _info = env.step(action)
            episode.append((state, action, reward))
            if done:
                break
            state = next_state

        # Walk the episode backwards, accumulating the discounted return G and
        # the importance-sampling weight W.
        G = 0.0
        W = 1.0
        for state, action, reward in reversed(episode):
            G = discount_factor * G + reward
            C[state][action] += W
            Q[state][action] += (W / C[state][action]) * (G - Q[state][action])

            # If the behaviour policy took an action the greedy target would
            # not, the target assigns probability 0 to this trajectory, so the
            # importance ratio for every earlier step is 0: stop updating.
            if action != np.argmax(target_policy(state)):
                break
            # pi(a|s) = 1 for the greedy action, so the ratio is 1 / b(a|s).
            W = W / behavior_policy(state)[action]

    return Q, target_policy

In [None]:
# Number of episodes sampled from the behaviour policy. Lower it for a quick
# iteration run; the full value makes this cell slow.
NUM_EPISODES = 500000

# Uniform-random behaviour policy: every action has non-zero probability, so
# the greedy target policy can be learned off-policy.
random_policy = create_random_policy(env.action_space.n)
Q, policy = mc_control_importance_sampling(
    env,
    num_episodes=NUM_EPISODES,
    behavior_policy=random_policy,
)

In [None]:
# State-value estimate of the learned greedy policy: V(s) = max_a Q(s, a).
# Built from a generator expression so that no loop variables (state, actions,
# ...) leak into the notebook's global namespace. V stays a defaultdict(float)
# so lookups of unvisited states still return 0.0.
V = defaultdict(
    float,
    ((state, np.max(action_values)) for state, action_values in Q.items()),
)
plotting.plot_value_function(V, title="Optimal Value Function")