In [181]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import math
from collections import namedtuple, deque
import random
import matplotlib.pyplot as plt
import numpy as np
import gym
from gym import spaces
import random


torch.manual_seed(42)

<torch._C.Generator at 0x2a1a3062ed0>

In [182]:
Action = namedtuple('Action', ("player", "move"))
Delta = namedtuple('Delta', ("summary", "reward"))

class TicTacToe(gym.Env):
    _no_player = 0
    _player_1 = 1
    _player_2 = 2
    
    reward_range = (-np.inf, np.inf)
    observation_space = spaces.MultiDiscrete([2 for _ in range(0, 9 * 3)])
    action_space = spaces.Discrete(9)

    """
    Board looks like:
    [0, 1, 2,
     3, 4, 5,
     6, 7, 8]
    """
    winning_states = [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
        [0, 3, 6],
        [1, 4, 7],
        [2, 5, 8],
        [0, 4, 8],
        [2, 4, 6],
    ]

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        self.board = torch.zeros(9, 3)
        self.board[:,0] = 1

        self.summary = {
            "done": False,
            "moves":0,
            "player_turn":self._player_1,
            "winner":self._no_player,
            "board":self.board
        }

        return self._one_hot_summary()

    def isDone(self):
        return self.summary["done"]

    def _one_hot_summary(self):
        return torch.cat((self._one_hot_board(), self._one_hot_turn()), dim=0)

    def _one_hot_board(self):
        return self.board.flatten()
    
    def _one_hot_turn(self):
        move = torch.zeros(1,2)
        move[0][self.summary["player_turn"] - 1] = 1
        return move.flatten()
    
    def checkForWin(self):
        for state in self.winning_states:
            if all(self.board[:,self._player_1][index] == 1 for index in state):
                return self._player_1
            elif all(self.board[:,self._player_2][index] == 1 for index in state):
                return self._player_2
            else:
                return self._no_player
    
    def checkForTie(self):
        if self.checkForWin() == self._no_player and self.board[:,0].sum() == 0:
            return True
        
        return False

    def step(self, action):
        #make sure that if the game is over no one plays
        if self.summary["winner"] != 0:
            return Delta(self._one_hot_summary(), -100)
        
        #make sure that the player whos turn it is is the one making the move
        if self.summary["player_turn"] != action.player:
            return Delta(self._one_hot_summary(), -100)

        self.summary["moves"] += 1
        self.summary["player_turn"] = self._player_1 if self.summary["player_turn"] == self._player_2 else self._player_2        
        
        if self.board[:,self._no_player][action.move] != 1:
            return Delta(self._one_hot_summary(), -10)
        
        self.board[:,action.player][action.move] = 1
        self.board[:,self._no_player][action.move] = 0

        if self.checkForTie():
            self.summary["winner"] = self._no_player
            self.summary["player_turn"] = self._no_player
            self.summary["done"] = True
            return Delta(self._one_hot_summary(), 0)

        winner = self.checkForWin()
        if winner != self._no_player:
            self.summary["winner"] = winner
            self.summary["player_turn"] = self._no_player
            self.summary["done"] = True
            return Delta(self._one_hot_summary(), 2)
        
        #maybe add stuff to look ahead for reward?
        return Delta(self._one_hot_summary(), 0)

    def render(self):
        board = [" "] * 9
        
        for i, value in enumerate(self.board):
            if torch.argmax(value) == 1:
                icon = "X"
            elif torch.argmax(value) == 2:
                icon = "O"
            else:
                icon = " "

            board[i] = icon

        # Print the tic-tac-toe grid
        print("{}|{}|{}\n-----\n{}|{}|{}\n-----\n{}|{}|{}".format(*board))


# One Hot
one hot is a method of displaying catigorical data in a numerical fashion. In this case, the board is expressed like so:

| none | x | o |
|------|---|---|
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |
| 1    | 0 | 0 |

## DQN Agent

In [183]:
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))

class ReplayMemory(object):
    """Copied verbatim from the PyTorch DQN tutorial.

    During training, observations from the replay memory are
    sampled for policy learning.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQNAgent(nn.Module):
    """Policy model. Consists of a fully connected feedforward
    NN with 3 hidden layers.
    """

    def __init__(self):
        super(DQNAgent, self).__init__()
        self.fc1 = nn.Linear(3*9 + 2, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 9)

    def forward(self, x):
        """Forward pass for the model.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return x

    def act(self, state):
        with torch.no_grad():
            return  self.forward(state).max(-1).indices.view(1,1)


## User Agent

In [184]:
class UserAgent:
    def __init__(self):
        pass
    
    def act(self, game):
        index = input("Enter position on board (0-8): ")
        action = torch.zeros(9)
        action[int(index)] = 1
        return action

## Random Agent

In [185]:
class RandomAgent:
    def __init__(self):
        pass
    
    def act(self, game):
        return torch.rand(9) * game.board[0]

## Train

In [186]:
## TODO read this closer

# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
N_STEPS = 5000
BATCH_SIZE = 100
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

env = TicTacToe()
state = env.reset()

policy_net = DQNAgent()
target_net = DQNAgent()
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(2000)

def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1).values
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

def select_dummy_action(state: np.array) -> int:
    # Select a random (valid) move, given a board state.
    state = state.reshape(3, 3, 3)
    open_spots = state[:, :, 0].reshape(-1)
    p = open_spots / open_spots.sum()
    return np.random.choice(np.arange(9), p=p)

def select_model_action(
    model: DQNAgent, state: torch.tensor, eps: float
):
    sample = random.random()
    if sample > eps:
        return model.act(state)
    else:
        return torch.tensor([[random.randrange(0, 9)]])
    

for step in range(N_STEPS):
    t = np.clip(step / EPS_DECAY, 0, 1)
    eps = (1-t) * EPS_START + t * EPS_END
    
    #player 1
    action = select_model_action(policy_net, state, eps)
    next_state, reward = env.step(Action(env._player_1, action.item()))

    #player 2
    if not env.isDone():

    print(action.item(), next_state, reward)

3 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) 0
8 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) -100
1 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) -100
8 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) -100
2 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) -100
8 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) -100
2 tensor([1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0.,
        1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1.]) -100
2 tensor([1., 0., 0., 1

## Play

In [187]:
game = TicTacToe()
user = UserAgent()
random = RandomAgent()
ai = DQNAgent()

agents = [random, random]

for i in range(0,9):
    # game.render()
    if i % 2 == 0:
        move = agents[0](game)
    else:
        move = agents[1](game)



    

        


TypeError: 'RandomAgent' object is not callable