In [None]:
from dataclasses import dataclass, field
from enum import Enum 
from copy import copy

import abc

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import cm

In [None]:
# for plotting
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# visualize plots in the jupyter notebook
# check more https://goo.gl/U3Ai8R
%matplotlib inline

In [None]:
def plot_value_function(V, title='Value Function', generate_gif=False, train_steps=None):
    """
    Plots a value function as a 3D surface, like in: https://goo.gl/aF2doj

    The surface is read directly from ``V`` by index: the x axis covers the
    indices ``1 .. V.shape[0] - 1`` and the y axis the indices
    ``1 .. V.shape[1] - 1``, and the plotted height at ``(x, y)`` is
    ``V[x, y]``. Row 0 and column 0 of ``V`` are therefore never drawn.

    Args:
        V (np.ndarray): 2D array of state values, indexed as
            ``V[dealer_showing, player_sum]``.
        title (string): Plot title; also used as the matplotlib figure name.
        generate_gif (boolean): Currently unused; kept so existing calls
            that pass it keep working.
        train_steps: Currently unused; kept so existing calls that pass it
            keep working.
    """
    # you can change these values to change the size of the graph
    fig = plt.figure(title, figsize=(10, 5))

    # explanation about this line: https://goo.gl/LH5E7i
    ax = fig.add_subplot(111, projection='3d')

    def plot_frame(ax):
        # both axes start at index 1; the upper bounds come from V's dimensions
        min_x = 1
        max_x = V.shape[0]
        min_y = 1
        max_y = V.shape[1]

        # creates a sequence from min to max (max itself is excluded)
        x_range = np.arange(min_x, max_x)
        y_range = np.arange(min_y, max_y)

        # creates a grid representation of x_range and y_range
        X, Y = np.meshgrid(x_range, y_range)

        # look up the value for every (x, y) grid point via array indexing
        Z = V[X, Y]

        # creates a surface to be plotted
        # check documentation for details: https://goo.gl/etEhPP
        ax.set_xlabel('Dealer Showing')
        ax.set_ylabel('Player Sum')
        ax.set_zlabel('Value')
        return ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm,
                               linewidth=0, antialiased=False)

    plot_frame(ax)
    ax.set_title(title)
    fig.canvas.draw()
    plt.show()

## Environment

In [None]:
class Action(Enum):
    """Moves available to the player.

    Iterating over ``Action`` yields ``stick`` first and ``hit`` second.
    """
    # Player stops drawing; the dealer then plays out its hand.
    stick = 0
    # Player draws another card.
    hit = 1
    
class Mode(Enum):
    """Action-selection mode for one decision step."""
    # Pick the action with the highest estimated value.
    exploit = 0
    # Pick an action at random.
    explore = 1
    
class CardColor(Enum):
    """Color of a card."""
    red = 0
    black = 1

@dataclass
class Card:
    """A single card; any field not given explicitly is drawn at random."""

    # Face value; by default drawn uniformly from 1..10 inclusive.
    value: int = field(default_factory=lambda: np.random.randint(1, 11))
    # Color; by default red with probability 1/3 and black with probability 2/3.
    # Annotated as CardColor (not str) because the values are CardColor members.
    color: CardColor = field(default_factory=lambda: np.random.choice([CardColor.red, CardColor.black], p=[1/3, 2/3]))

    @property
    def key(self):
        """String identifier made of the color followed by the value."""
        return f"{self.color}{self.value}"
        
@dataclass
class State:
    """Observable game state: the dealer's first card, the player's sum, and a terminal flag."""

    # Both default draws force a black card.
    dealer_first_card: Card = field(default_factory=lambda: Card(color=CardColor.black))
    player_sum: int = field(default_factory=lambda: Card(color=CardColor.black).value)
    # True once the episode has ended.
    terminal: bool = False

    @property
    def key(self):
        """Dash-separated identifier: dealer card value, player sum, terminal flag."""
        parts = (self.dealer_first_card.value, self.player_sum, self.terminal)
        return "-".join(str(part) for part in parts)
    

class Policy:
    """Base class for policies; subclasses override `step` to choose an action."""

    def __init__(self):
        # The base class carries no state.
        return None

    def step(self, state: State) -> Action:
        """Return the action to take in `state`. Not implemented in the base class."""
        raise NotImplementedError


class DealerPolicy(Policy):
    """Fixed dealer rule: keep hitting while the dealer's sum is below 17."""

    def step(self, state: State, dealer_sum: int):
        """Return `Action.hit` if `dealer_sum` < 17, otherwise `Action.stick`.

        `state` is accepted but not consulted by this rule.
        """
        return Action.hit if dealer_sum < 17 else Action.stick


class Environment:
    """Card-game environment tracking the player's and the dealer's running sums.

    Black cards add to a sum and red cards subtract from it; a sum above 21
    or below 1 is a bust. The player acts through `step`; once the player
    sticks, the dealer policy plays the dealer's hand to completion.

    Args:
        dealer_policy: Object whose `step(state, dealer_sum)` returns an
            `Action`. Defaults to a fresh `DealerPolicy()`.
    """

    def __init__(self, dealer_policy=None):
        self.player_sum = 0
        self.dealer_sum = 0
        # Keep the dealer policy on the instance so `step` does not depend on
        # a module-level `dealer_policy` variable existing at call time.
        self.dealer_policy = DealerPolicy() if dealer_policy is None else dealer_policy

    def initialize_state(self):
        """Start a new episode and return a snapshot of its initial state."""
        self.state = State()
        self.player_sum = self.state.player_sum
        self.dealer_sum = self.state.dealer_first_card.value

        # Return a copy, not `self.state` itself: `update_player_sum` mutates
        # `self.state` in place, which would otherwise silently rewrite the
        # initial state a caller has stored.
        return self.get_state(set_terminal=False)

    def get_state(self, set_terminal=True):
        """Return a shallow copy of the current state with `terminal` set to `set_terminal`."""
        state = copy(self.state)
        state.terminal = set_terminal
        return state

    @staticmethod
    def _update_sum(value: int, card: Card) -> int:
        """Return `value` plus the card's value (black) or minus it (red).

        Raises:
            Exception: If the card color is neither black nor red.
        """
        if card.color is CardColor.black:
            value += card.value
        elif card.color is CardColor.red:
            value -= card.value
        else:
            raise Exception(f"Card color {card.color} not known")

        return value

    def update_player_sum(self, card: Card) -> None:
        """Apply `card` to the player's sum and mirror it into the internal state."""
        self.player_sum = self._update_sum(self.player_sum, card)
        self.state.player_sum = self.player_sum

    def update_dealer_sum(self, card: Card) -> None:
        """Apply `card` to the dealer's sum."""
        self.dealer_sum = self._update_sum(self.dealer_sum, card)

    @staticmethod
    def _is_bust(value: int):
        """Return True when `value` lies outside the range 1..21."""
        return value > 21 or value < 1

    def dealer_is_bust(self) -> bool:
        """Return True if the dealer's sum is a bust."""
        return self._is_bust(self.dealer_sum)

    def player_is_bust(self) -> bool:
        """Return True if the player's sum is a bust."""
        return self._is_bust(self.player_sum)

    def step(self, action: Action):
        """Advance the game by one player action.

        Args:
            action: `Action.hit` to draw a card, `Action.stick` to let the
                dealer play.

        Returns:
            A `(state, reward)` tuple. `state` is a snapshot whose `terminal`
            flag is False only after a non-busting hit. `reward` is -1 for a
            player bust, 0 for a non-busting hit, +1 for a dealer bust, and
            otherwise the sign of `player_sum - dealer_sum`.

        Raises:
            ValueError: If `action` is neither `Action.hit` nor `Action.stick`.
        """
        if action is Action.hit:
            self.update_player_sum(Card())

            if self.player_is_bust():
                # Player loses: reward -1
                return self.get_state(), -1

            # Game continues: reward 0 (intermediate step)
            return self.get_state(set_terminal=False), 0

        # Player sticks, dealer (environment) policy runs
        if action is Action.stick:

            while self.dealer_policy.step(self.get_state(), self.dealer_sum) is Action.hit:

                self.update_dealer_sum(Card())

                if self.dealer_is_bust():
                    # Player wins: reward +1
                    return self.get_state(), 1

            # Dealer won't draw more cards - determine reward
            r = np.sign(self.player_sum - self.dealer_sum)

            return self.get_state(), r

        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an unpacking error in the caller.
        raise ValueError(f"Unknown action: {action!r}")

## Monte Carlo control

In [None]:
class StateActionPair:
    """Bundles a state with the action taken in it."""

    def __init__(self, state: State, action: Action):
        self.state, self.action = state, action

    @property
    def key(self):
        """Identifier formed from the state's key and the action's name, joined by a dash."""
        return "-".join((self.state.key, self.action.name))
    
def gen_state_action_pairs(state: State):
    """Return a generator of one `StateActionPair` per `Action`, in `Action` definition order."""
    pairs = (StateActionPair(state, candidate) for candidate in Action)
    return pairs

class Registry(abc.ABC):
    """Dictionary-backed lookup table; calling it with an unknown key yields 0."""

    def __init__(self):
        # Backing store, exposed as `r`.
        self.r = {}

    def __call__(self, key):
        """Return the value stored under `key`, or 0 when nothing is stored."""
        try:
            return self.r[key]
        except KeyError:
            return 0

class CountRegistry(Registry):
    """Registry that counts occurrences per key."""

    def increment(self, key):
        """Add one to the count stored under `key` (starting from 0)."""
        self.r[key] = self(key) + 1

class ValueRegistry(Registry):
    """Registry that stores an arbitrary value per key."""

    def store(self, key, value):
        """Set the entry for `key` to `value`, replacing any previous value."""
        self.r.update({key: value})

In [None]:
# Number of episodes to simulate
episodes = 1000000

# Exploration constant: eps = Nzero / (Nzero + N(s))
Nzero=100

# Seed NumPy's global RNG so a fresh run reproduces the same results
np.random.seed(0)

Q = ValueRegistry()      # action-value estimates, keyed by state-action key
Ns = CountRegistry()     # visit counts per state key
Nsa = CountRegistry()    # visit counts per state-action key

rewards = []

env = Environment()

dealer_policy = DealerPolicy()

# Report progress roughly 20 times; clamp to 1 so a small `episodes` value
# cannot produce a modulo-by-zero.
progress_interval = max(1, round(episodes / 20))

for i in range(episodes):
    
    if i % progress_interval == 0:
        print(f"Episode: {i:>10}/{episodes} --- {i/episodes*100:>5.1f}%")

    # Reset states list
    state_action_pairs = []
    
    # Snapshot the initial state: the environment keeps mutating its own
    # state object on every hit, so storing that object directly would
    # rewrite the first recorded state-action pair after the fact.
    s = copy(env.initialize_state())

    while not s.terminal:

        # Randomly decide whether to exploit or explore
        eps = Nzero / (Nzero + Ns(s.key))
        mode = np.random.choice([Mode.exploit, Mode.explore], p=[1-eps, eps])

        # Exploit: take best action given q
        if mode is Mode.exploit:
            
            # Evaluate Q function for both state, action pairs
            # (pairs come back in Action definition order: stick, then hit)
            sa_stick, sa_hit = gen_state_action_pairs(s)
            Q_stick, Q_hit = Q(sa_stick.key), Q(sa_hit.key)
            
            # Greedy policy: take the one with the highest q value, if draw then 50/50
            if Q_stick > Q_hit:
                a = Action.stick
            elif Q_stick < Q_hit:
                a = Action.hit
            else:
                a = np.random.choice([Action.hit, Action.stick])

        # Explore: randomly decide between actions
        elif mode is Mode.explore:
            a = np.random.choice([Action.hit, Action.stick])
                       
        # Append state-action pair to list
        sa_next = StateActionPair(s, a)
        state_action_pairs.append(sa_next)
            
        # Take step
        s, G = env.step(a)
        rewards.append(G)
    
    # Episode has ended, perform optimization step using G
    # (G is the reward of the final step; intermediate steps return 0)
    for sa in state_action_pairs:
        
        # Update counters
        Ns.increment(sa.state.key)
        Nsa.increment(sa.key)
        
        # Update Qsa: incremental mean with step size 1 / N(s, a)
        q_current = Q(sa.key)
        q_new = q_current + 1 / Nsa(sa.key) * (G - q_current)
        Q.store(sa.key, q_new)

In [None]:
# One row per learned state-action value.
df = pd.DataFrame({"key": list(Q.r.keys()), "q": list(Q.r.values())})

# Keys have the form "<dealer card>-<player sum>-<terminal>-<action>".
# The player sum can be negative, so a plain split on "-" would shift the
# columns; a regex with an optional minus sign keeps the fields aligned.
key_parts = df["key"].str.extract(r"^(\d+)-(-?\d+)-(True|False)-(\w+)$")

# Store typed columns so later filters and sorts behave numerically/logically
# instead of operating on strings.
df["Dealer showing"] = key_parts[0].astype(int)
df["Player sum"] = key_parts[1].astype(int)
df["terminated"] = key_parts[2] == "True"
df["action"] = key_parts[3]
df

In [None]:
# Compare against the text "True" rather than casting with astype(bool):
# casting a non-empty string to bool is always True, even for "False".
# Going through str makes this work for both string and boolean columns.
df['terminated'].astype(str) == "True"

In [None]:
# Rows flagged as terminal. astype(bool) on the strings "True"/"False" would
# select every row, so compare the textual value instead.
df[df['terminated'].astype(str) == "True"]

In [None]:

# Keep non-terminal rows. Comparing the textual value works whether the column
# holds strings or booleans; comparing strings to the bool False never matches
# and would leave the frame empty.
non_terminal = df['terminated'].astype(str) == "False"
df_v = df.loc[non_terminal, ["Dealer showing", "Player sum", "q"]]

# Group on integers so the result is ordered numerically (1, 2, ..., 10)
# rather than lexicographically ("1", "10", "2", ...).
df_v = df_v.astype({"Dealer showing": int, "Player sum": int})

# State value under the greedy policy: V(s) = max over actions of Q(s, a).
df_v = df_v.groupby(["Dealer showing", "Player sum"]).max()
df_v

In [None]:
# Pivot to a grid: one row per player sum, one column per dealer card.
v = df_v.unstack(level=0)

# Keep the player-sum labels as the (integer, sorted) row index instead of
# replacing them with 0-based positions, so plots built from `v.index` show
# the real player sums.
v.index = v.index.astype(int)
v = v.sort_index()
v

In [None]:
# plot_value_function reads V[i, j] for i, j >= 1 and labels the axes with
# those indices, while `v` is 0-based. Prepend one zero row and one zero
# column so position i lines up with the i-th dealer card / player sum and
# the first row and column of `v` are not dropped.
# NOTE(review): assumes the rows/columns of `v` are contiguous and start at 1.
plot_value_function(np.pad(v.values.transpose(), ((1, 0), (1, 0))))

In [None]:
# Column labels of `v` as a list of tuples
v_columns = v.columns.to_list()

# Regroup the tuples by position; the entries at position 1 become the x coordinates
column_levels = list(zip(*v_columns))
x_linspace = np.array(column_levels[1], dtype=int)

# Row labels become the y coordinates
y_linspace = v.index.to_numpy()

# Coordinate grids: one row per y entry, one column per x entry
grids = np.meshgrid(x_linspace, y_linspace)
x, y = grids

In [None]:
# Draw the values in `v` as a 3D surface over the (x, y) grids
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection='3d')

# Shared rendering options for the surface
surface_style = dict(cmap=cm.coolwarm,
                     rstride=1, cstride=1,
                     linewidth=0, antialiased=False)
surf = ax.plot_surface(x, y, v.values, **surface_style)

ax.set_xlabel("Dealer showing")
ax.set_ylabel("Player sum")

plt.show()

In [None]:
# Show the shape of each axis vector, then display the vectors themselves
print(*(axis.shape for axis in (x_linspace, y_linspace)))
x_linspace, y_linspace

In [None]:
# Rebuild the coordinate grids and compare their shapes with the transposed value table
grids = np.meshgrid(x_linspace, y_linspace)
x, y = grids
print(x.shape, y.shape, v.T.to_numpy().shape)

In [None]:
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

# Plot the surface.
# x and y come from np.meshgrid(x_linspace, y_linspace), so they have one row
# per y entry and one column per x entry -- the same layout as `v.values`.
# Passing the transposed table would give Z the opposite shape, which
# plot_surface rejects unless the table happens to be square.
surf = ax.plot_surface(x, y, v.values,
                       cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)

# Label the axes so the figure can be read on its own.
ax.set_xlabel("Dealer showing")
ax.set_ylabel("Player sum")
ax.set_zlabel("Value")

# Add a color bar which maps values to colors.
fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()