In [None]:
from random import randint, shuffle, uniform
import numpy as np

class StickGame:
    """Stick-pile environment: players remove sticks, and a move that empties
    the pile earns the mover a reward of -1.

    Args:
        initial_sticks: size of the pile at the start of every game.
    """

    def __init__(self, initial_sticks):
        super().__init__()
        # Kept separately so reset() can restore the starting pile.
        self.original_sticks = initial_sticks
        self.sticks = initial_sticks

    def is_finished(self):
        """Return True when no sticks remain (the pile may be over-drawn)."""
        return bool(self.sticks <= 0)

    def reset(self):
        """Restore the starting pile and return the initial state."""
        self.sticks = self.original_sticks
        return self.sticks

    def display(self):
        """Print the pile as one ``| `` per remaining stick."""
        print(self.sticks * "| ")

    def step(self, action):
        """Remove ``action`` sticks.

        Returns:
            ``(None, -1)`` if this move emptied the pile (the mover loses),
            otherwise ``(remaining_sticks, 0)``.
        """
        self.sticks -= action
        if self.sticks > 0:
            return self.sticks, 0
        return None, -1

class StickPlayer:
    """A player for the stick game: a human at the keyboard or a learning AI.

    The AI keeps ``values``, a table mapping each stick count 1..size to a
    learned estimate for the player who is about to move from that count.
    It chooses moves epsilon-greedily and updates ``values`` in ``train()``
    from the transitions recorded during play.

    Args:
        is_human: if True, moves are read from standard input.
        size: largest stick count covered by the value table.
        trainable: if False, ``train()`` never changes ``values``.
    """

    _HUMAN_PROMPT_HELP = "Please enter 1, 2 or 3."

    def __init__(self, is_human, size, trainable=True):
        super(StickPlayer, self).__init__()
        self.is_human = is_human
        # (state, action, reward, next_state) tuples recorded since the last train().
        self.history = []
        self.values = {s: 0. for s in range(1, size + 1)}
        self.win_count = 0
        self.lose_count = 0
        # Reward of every recorded move; only used for reporting statistics.
        self.rewards = []
        # Probability of choosing a uniformly random move instead of the greedy one.
        self.epsilon = 0.99
        self.trainable = trainable

    def resetStat(self):
        """Clear the win/loss counters and the reward log (values and history are kept)."""
        self.win_count = 0
        self.lose_count = 0
        self.rewards = []

    def greedyAction(self, state):
        """Return the move (1-3) leaving the opponent the lowest-valued state.

        Moves that would empty the pile are never chosen; if every move would
        (state <= 1), fall back to taking 1 stick. Ties keep the smaller move.
        """
        min_value = None
        best_action = None
        for action in (1, 2, 3):
            remaining = state - action
            if remaining <= 0:
                continue
            value = self.values[remaining]
            if min_value is None or value < min_value:
                min_value = value
                best_action = action
        return best_action if best_action is not None else 1

    def play(self, state):
        """Return the number of sticks to take from a pile of ``state`` sticks.

        Humans are prompted until they type a legal move (1, 2 or 3); the AI
        plays a random move with probability ``epsilon``, otherwise greedily.
        """
        if self.is_human:
            return self._ask_human_action()
        if uniform(0, 1) < self.epsilon:
            return randint(1, 3)
        return self.greedyAction(state)

    @classmethod
    def _ask_human_action(cls, prompt="$>"):
        """Read moves from stdin until one is an integer in {1, 2, 3}."""
        while True:
            raw = input(prompt)
            try:
                action = int(raw)
            except ValueError:
                print(cls._HUMAN_PROMPT_HELP)
                continue
            if action in (1, 2, 3):
                return action
            # 0 / negative / >3 would otherwise stall or grow the pile.
            print(cls._HUMAN_PROMPT_HELP)

    def addTransition(self, transition):
        """Record a ``(state, action, reward, next_state)`` tuple and its reward."""
        self.history.append(transition)
        _, _, reward, _ = transition
        self.rewards.append(reward)

    def train(self, learning_rate=0.001):
        """Update ``values`` from the recorded history, then clear it.

        Transitions are processed newest-first. A move with a non-zero reward
        (game over) pulls ``values[s]`` toward that reward; any other move
        bootstraps toward ``values[next_state]`` (a temporal-difference style
        update). Non-trainable and human players learn nothing, but their
        history is still cleared so it does not grow without bound.

        Args:
            learning_rate: step size of each update.
        """
        if not self.trainable or self.is_human:
            self.history = []
            return

        for s, _, r, sp in reversed(self.history):
            target = self.values[sp] if r == 0 else r
            self.values[s] += learning_rate * (target - self.values[s])

        self.history = []

def playGame(game, p1, p2, train=True):
    """Play one full game between ``p1`` and ``p2`` in random turn order.

    Each move is recorded on the mover via ``addTransition`` with
    ``next_state=None``. Once the opponent replies, the mover's last
    transition is completed with the negated outcome of that reply and the
    state the mover faces next. The same reward is mirrored into the mover's
    ``rewards`` log, so wins (+1) are counted in the statistics and not just
    losses. Win/loss counters are updated when a move ends the game.

    Args:
        game: environment exposing reset / is_finished / step / display.
        p1, p2: the two players.
        train: if True, call ``train()`` on both players after the game.
    """
    state = game.reset()
    players = [p1, p2]
    shuffle(players)
    turn = 0
    while not game.is_finished():
        mover = players[turn % 2]
        other = players[(turn + 1) % 2]

        if mover.is_human:
            game.display()

        action = mover.play(state)
        next_state, reward = game.step(action)

        if reward == -1:
            mover.lose_count += 1
            other.win_count += 1
        elif reward == 1:
            mover.win_count += 1
            other.lose_count += 1

        if turn != 0:
            # Complete the opponent's pending transition: its reward is the
            # negation of ours, and its next state is the pile it now faces.
            s, a, _, _ = other.history[-1]
            other.history[-1] = (s, a, -reward, next_state)
            # rewards[-1] was appended together with history[-1] as 0; without
            # this, a win would never show up in the reward statistics.
            if other.rewards:
                other.rewards[-1] = -reward

        mover.addTransition((state, action, reward, None))

        state = next_state
        turn += 1

    if train:
        p1.train()
        p2.train()

if __name__ == '__main__':
    game = StickGame(12)

    # Two learning agents trained by self-play.
    p1 = StickPlayer(is_human=False, size=12, trainable=True)
    p2 = StickPlayer(is_human=False, size=12, trainable=True)

    human = StickPlayer(is_human=True, size=12, trainable=False)
    random_player = StickPlayer(is_human=False, size=12, trainable=False)

    for i in range(10000):
        # Decay exploration every 10 games, never below 5%.
        if i % 10 == 0:
            p1.epsilon = max(p1.epsilon * 0.996, 0.05)
            p2.epsilon = max(p2.epsilon * 0.996, 0.05)
        playGame(game, p1, p2)

    p1.resetStat()

    for key in p1.values:
        print(key, p1.values[key])
    print("--------------------------")

    # Evaluate the learned greedy policy: with exploration left on, p1 would
    # still play random moves, so the win rate would not reflect what it learned.
    p1.epsilon = 0.0

    for _ in range(1000):
        playGame(game, p1, random_player, train=False)
    print("AI player 1 win rate", p1.win_count / (p1.win_count + p1.lose_count))
    print("AI player 1 win mean", np.mean(p1.rewards))

    # Interactive play until the user presses Ctrl+C or closes stdin.
    try:
        while True:
            wins_before = human.win_count
            playGame(game, p1, human, train=False)
            print("You win!" if human.win_count > wins_before else "You lose!")
    except (KeyboardInterrupt, EOFError):
        print()


1 -0.9877373959319322
2 0.24474183195089377
3 0.8842927107897013
4 0.3942470652444863
5 -0.8469282840205946
6 0.0665458822707696
7 0.6948590837600174
8 0.0780831849301231
9 -0.6461839292725629
10 0.02949009554564702
11 0.029884115824033327
12 0.47135025470180714
--------------------------
AI player 1 win rate 0.953
AI player 1 win mean -0.016336461591936045
| | | | | | | | | 
