In [4]:
import gym
from tqdm import trange
from collections import defaultdict
import numpy as np
from matplotlib import pyplot as plt
hit = 1
stick = 0
gamma = 1.0
num_episodes = 10000


class Policy:

    def get_action(self, state):
        if state[0] < 20:
            return hit
        return stick

In [3]:

def generate_episode(env, policy):
    episode = []
    state = env.reset()
    while True:
        action = policy.get_action(state)
        next_state, reward, done, info = env.step(action)
        episode.append((state, action, reward))
        state = next_state
        if done:
            break
    return episode


def policy_evaluation(env, policy):
    V = defaultdict(lambda: np.zeros(env.action_space.n))
    N = defaultdict(lambda: np.zeros(env.action_space.n))

    for n in trange(num_episodes + 1):
        episode = generate_episode(env, policy)
        states, actions, rewards = zip(*episode)
        G = 0
        first_occurence = {state[0]: i for i, state in enumerate(states)}
        for i in range(len(states) - 1, -1, -1):
            G = gamma * G + rewards[i]
            if first_occurence[states[i][0]] == i:# first-visit MC
                vs = V[states[i]][actions[i]]
                N[states[i]][actions[i]] += 1.0
                V[states[i]][actions[i]] +=  (G - vs) / N[states[i]][actions[i]]
    return V

blackjack = gym.make('Blackjack-v0')
stick_policy = Policy()
V = policy_evaluation(blackjack, stick_policy)

100%|██████████| 11/11 [00:00<00:00, 7349.05it/s]


In [5]:
def plot_v(V):
    def get_V(x,y, usable_ace):
        if (x,y,usable_ace) in V:
            return V[(x,y, usable_ace)]
        return 0

    X, Y = np.meshgrid(np.arange(1, 11), np.arange(11, 21))
    z1 = np.array(get_V(x,y, True) for x,y in zip(np.ravel(X), np.ravel(Y))).reshape(X.shape)
    z2 = np.array(get_V(x,y, False) for x,y in zip(np.ravel(X), np.ravel(Y))).reshape(X.shape)
    fig = plt.figure()
    ax1 = plt.axes(projection='3d')
    ax1.plot_surface(X,Y, Z=z1)
    ax1.view_init(elev=60)
    plt.savefig('Usable_ace.png')
    ax2= plt.axes(projection='3d')
    ax2.plot_surface(X,Y, Z=z2)
    ax2.view_init(elev=60)
    plt.savefig('Non_usable_ace.png')
    plt.show()
plot_v(V)

ValueError: cannot reshape array of size 1 into shape (10,10)