In [372]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import math
from collections import namedtuple, deque
import random
import matplotlib.pyplot as plt
import numpy as np
import gym
from gym import spaces
import random

In [373]:
# A move request: the acting player's id and the board cell (0-8) they choose.
Action = namedtuple("Action", ["player", "move"])
# Result of `TicTacToe.step`: the encoded observation and the reward earned.
Delta = namedtuple("Delta", ["summary", "reward"])

# Reward shaping table used by `TicTacToe.step`. The three -100 entries are
# penalties for illegal requests; the positive entries reward progress.
REWARD = dict(
    GAME_IS_OVER=-100,  # tried to move after the game finished
    NOT_TURN=-100,      # moved out of turn
    TAKEN_SPACE=-100,   # picked an occupied cell
    TIE=5,              # board filled with no winner
    WIN=100,            # completed a winning line
    BLOCKED_WIN=50,     # removed one of the opponent's two-in-a-row lines
    TWO_IN_ROW=10,      # per open line where the mover holds two cells
)

class TicTacToe(gym.Env):
    """Two-player tic-tac-toe environment with shaped rewards.

    The board is a ``(9, 3)`` tensor with one one-hot row per cell; the column
    index is the occupant (0 = empty, 1 = player 1, 2 = player 2). Cells are
    numbered row by row:

        [0, 1, 2,
         3, 4, 5,
         6, 7, 8]

    Observations are ``(1, 29)`` tensors: the 27 flattened board flags followed
    by a two-slot one-hot of the player to move (all zeros once the game is
    over).
    """
    _no_player = 0
    _player_1 = 1
    _player_2 = 2

    reward_range = (-np.inf, np.inf)
    # 9 cells x 3 occupancy flags + 2 turn flags, every entry binary.
    observation_space = spaces.MultiDiscrete([2] * (9 * 3 + 2))
    action_space = spaces.Discrete(9)

    # Every triple of cell indices that forms a line: rows, columns, diagonals.
    winning_states = [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8],
        [0, 3, 6],
        [1, 4, 7],
        [2, 5, 8],
        [0, 4, 8],
        [2, 4, 6],
    ]

    def __init__(self):
        super().__init__()
        self.reset()

    def reset(self):
        """Clear the board, give the turn to player 1 and return the observation."""
        self.board = torch.zeros(9, 3)
        self.board[:, self._no_player] = 1  # every cell starts out empty

        self.summary = {
            "done": False,
            "moves": 0,
            "player_turn": self._player_1,
            "winner": self._no_player,
            "board": self.board,
        }

        return self._one_hot_summary()

    def isDone(self):
        """Return True once the game has ended in a win or a tie."""
        return self.summary["done"]

    def _one_hot_summary(self):
        """Return the ``(1, 29)`` observation: board flags then turn flags."""
        summary = torch.cat((self._one_hot_board(), self._one_hot_turn()), dim=0)
        return summary.unsqueeze(0)

    def _one_hot_board(self):
        """Return the board flattened to 27 values (three flags per cell)."""
        return self.board.flatten()

    def _one_hot_turn(self):
        """Return a two-slot one-hot of the player to move.

        When nobody is to move (the game is over) both slots stay 0; indexing
        with ``player_turn - 1`` would otherwise wrap to -1 and wrongly mark
        player 2.
        """
        turn = torch.zeros(2)
        player = self.summary["player_turn"]
        if player != self._no_player:
            turn[player - 1] = 1
        return turn

    def checkForWin(self):
        """Return the id of the player owning a complete line, else ``_no_player``."""
        for state in self.winning_states:
            for player in (self._player_1, self._player_2):
                if all(self.board[index, player] == 1 for index in state):
                    return player

        return self._no_player

    def hasTwoInRow(self, player):
        """Count lines where one side holds exactly two cells and the other none.

        Both players' counts are computed and the one at position
        ``player - 1`` is returned.
        """
        counts = [0, 0]
        for state in self.winning_states:
            row_player_1 = int((self.board[state, self._player_1] == 1).sum())
            row_player_2 = int((self.board[state, self._player_2] == 1).sum())
            if row_player_1 == 2 and row_player_2 == 0:
                counts[0] += 1
            elif row_player_2 == 2 and row_player_1 == 0:
                counts[1] += 1

        return counts[player - 1]

    def checkForTie(self):
        """Return True when no cell is empty and nobody has a complete line."""
        board_full = bool(self.board[:, self._no_player].sum() == 0)
        return self.checkForWin() == self._no_player and board_full

    def step(self, action):
        """Apply ``action`` and return ``Delta(observation, reward)``.

        Arguments:
            action {Action} -- ``player`` is the acting player's id, ``move``
                the target cell (an int or a single-element tensor).

        Rules, as implemented:
            * After the game is over (win or tie) nothing changes and the
              reward is ``GAME_IS_OVER``.
            * Moving out of turn costs ``NOT_TURN``; the move is still applied
              with ``action.player``'s mark.
            * Choosing an occupied cell costs ``TAKEN_SPACE`` and forfeits the
              turn; the board is left untouched.
            * Every step counts as a move and passes the turn to the other
              player.
        """
        # A finished game accepts no further moves. Checking "done" rather than
        # "winner" also covers ties, where the winner stays _no_player.
        if self.summary["done"]:
            return Delta(self._one_hot_summary(), REWARD["GAME_IS_OVER"])

        reward = 0
        move = int(action.move)

        # Penalise a player acting out of turn.
        if self.summary["player_turn"] != action.player:
            reward += REWARD["NOT_TURN"]

        current_player = self.summary["player_turn"]
        other_player = (
            self._player_1 if current_player == self._player_2 else self._player_2
        )
        self.summary["moves"] += 1
        self.summary["player_turn"] = other_player

        if self.board[move, self._no_player] != 1:
            # The cell is already occupied.
            reward += REWARD["TAKEN_SPACE"]
        else:
            threats_before = self.hasTwoInRow(other_player)

            self.board[move, action.player] = 1
            self.board[move, self._no_player] = 0

            threats_after = self.hasTwoInRow(other_player)

            # Fewer open two-in-a-row lines for the opponent means a block.
            if threats_after < threats_before:
                reward += REWARD["BLOCKED_WIN"]

        if self.checkForTie():
            self.summary["winner"] = self._no_player
            self.summary["player_turn"] = self._no_player
            self.summary["done"] = True
            reward += REWARD["TIE"]

        winner = self.checkForWin()
        if winner != self._no_player:
            self.summary["winner"] = winner
            self.summary["player_turn"] = self._no_player
            self.summary["done"] = True
            reward += REWARD["WIN"]

        # Bonus for each open line where the mover already holds two cells.
        reward += self.hasTwoInRow(current_player) * REWARD["TWO_IN_ROW"]

        return Delta(self._one_hot_summary(), reward)

    def render(self):
        """Print the board as a 3x3 grid with X for player 1 and O for player 2."""
        icons = {self._player_1: "X", self._player_2: "O"}
        board = [icons.get(int(torch.argmax(cell)), " ") for cell in self.board]

        # Print the tic-tac-toe grid
        print("{}|{}|{}\n-----\n{}|{}|{}\n-----\n{}|{}|{}".format(*board))


## DQN Agent

In [374]:
# One stored experience: the observation, the action taken, the observation
# that followed (None when the game ended) and the reward received.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of transitions.

    Adapted from the PyTorch DQN tutorial. During training, observations from
    the replay memory are sampled for policy learning. Once ``capacity``
    entries are stored, each new push overwrites the oldest one.
    """
    def __init__(self, capacity):
        """Create an empty buffer.

        Arguments:
            capacity {int} -- Maximum number of transitions kept

        Raises:
            ValueError -- If ``capacity`` is not positive
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Saves a transition built from ``args`` (the four Transition fields).

        The transition is constructed before the buffer is touched, so a call
        with the wrong number of fields raises TypeError and stores nothing.
        """
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQNAgent(nn.Module):
    """Policy model. Consists of a fully connected feedforward
    NN with 3 hidden layers.

    The input is a 29-value observation (27 board flags + 2 turn flags); the
    output is one Q-value per board cell.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3*9 + 2, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 9)

    def forward(self, x):
        """Forward pass for the model.

        Returns raw, unbounded Q-values. The output layer is deliberately
        linear: a ReLU here would clamp every estimate to >= 0, so negative
        returns could never be represented and any output that went negative
        would stop receiving gradient.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)

    def act(self, state):
        """Return the greedy action for a single observation as a ``(1, 1)`` tensor.

        No legality masking is applied; the highest-valued cell is returned
        even if it is already occupied.
        """
        with torch.no_grad():
            return self.forward(state).max(-1).indices.view(1, 1)


## User Agent

In [375]:
class UserAgent:
    """Agent that asks a human for each move on standard input."""

    def __init__(self):
        pass

    def act(self, state):
        """Prompt until a cell index in 0-8 is entered; return it as a ``(1, 1)`` tensor.

        ``state`` is accepted for interface parity with the other agents and is
        not inspected, so an occupied cell can still be chosen. Input that is
        not an integer, or lies outside the board, is rejected with a message
        and the prompt is repeated instead of raising.
        """
        while True:
            raw = input("Enter position on board (0-8): ")
            try:
                index = int(raw)
            except ValueError:
                print(f"'{raw}' is not a whole number, try again.")
                continue
            if 0 <= index <= 8:
                return torch.tensor([[index]])
            print(f"{index} is outside the board, try again.")

## Random Agent

In [376]:
class RandomAgent:
    """Agent that plays a uniformly random empty cell."""

    def __init__(self):
        pass

    def act(self, state):
        """Pick a random open cell from a ``(1, 29)`` observation.

        The trailing two turn flags are dropped and the rest is viewed as nine
        cells of three flags each; flag 0 equal to 1 marks an empty cell.

        Returns:
            A ``(1, 1)`` integer tensor holding the chosen cell index.

        Raises:
            ValueError -- If the board has no empty cell left
        """
        cells = state[0][:-2].reshape(9, 3)
        open_cells = torch.nonzero(cells[:, 0] == 1).flatten().tolist()
        if not open_cells:
            raise ValueError("no open cells left to play")
        return torch.tensor([[random.choice(open_cells)]])

## Train

In [377]:
# Hyperparameters
n_steps = 5000      # environment steps taken by the learning agent
batch_size = 128    # transitions sampled per optimisation step
gamma = 0.99        # reward discount factor
eps_start = 1.0     # exploration rate at step 0
eps_end = 0.1       # exploration rate once annealing has finished
eps_steps = 2000    # steps over which epsilon is annealed linearly

# Seed every RNG in play (action sampling, the random opponent, network
# initialisation) so that a fresh Restart & Run All reproduces the same run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = TicTacToe()
state = env.reset()

# The target network starts as a frozen copy of the policy network.
policy_net = DQNAgent()
target_net = DQNAgent()
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.AdamW(policy_net.parameters(), lr=1e-3)
memory = ReplayMemory(1000)

def optimize_model(
    optimizer: optim.Optimizer,
    policy: "DQNAgent",
    target: "DQNAgent",
    memory: "ReplayMemory",
    batch_size: int,
    gamma: float,
):
    """Model optimization step, adapted from the Torch DQN tutorial.

    Samples a batch from the replay memory and takes one gradient step on the
    Huber loss between Q(s_t, a_t) and r + gamma * max_a Q_target(s_{t+1}, a).
    Does nothing until the memory holds at least ``batch_size`` transitions.

    Arguments:
        optimizer {torch.optim.Optimizer} -- Optimizer
        policy {Policy} -- Policy net
        target {Policy} -- Target net
        memory {ReplayMemory} -- Replay memory
        batch_size {int} -- Number of observations to use per batch step
        gamma {float} -- Reward discount factor
    """
    if len(memory) < batch_size:
        return
    transitions = memory.sample(batch_size)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Mask of transitions that have a successor state (a final state is one
    # after which the game ended and is stored as None).
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state],
        dtype=torch.bool,
    )
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a): the model computes Q(s_t) for every action, then we
    # select the column of the action that was actually taken.
    state_action_values = policy(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) with the "older" target net. Final states keep the
    # value 0. The guard matters: a batch made up solely of final states has
    # nothing to concatenate and torch.cat would raise on the empty list.
    next_state_values = torch.zeros(batch_size)
    if non_final_mask.any():
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]
        )
        with torch.no_grad():
            next_state_values[non_final_mask] = target(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(
        state_action_values, expected_state_action_values.unsqueeze(1)
    )

    # Optimize the model, clamping every gradient element to [-1, 1]
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(policy.parameters(), 1)
    optimizer.step()

def select_dummy_action(state: torch.Tensor) -> torch.Tensor:
    """Return the opponent's move for ``state``, as chosen by a fresh RandomAgent.

    The result is whatever ``RandomAgent.act`` returns (a tensor holding the
    cell index), not a plain int.
    """
    return RandomAgent().act(state)

def select_model_action(
    model: "DQNAgent", state: torch.Tensor, eps: float
) -> torch.Tensor:
    """Epsilon-greedy action selection.

    With probability ``eps`` a cell index in 0-8 is drawn uniformly at random
    (occupied cells included); otherwise the model's own choice is used.

    Arguments:
        model {DQNAgent} -- Policy net providing ``act(state)``
        state {torch.Tensor} -- Current observation
        eps {float} -- Exploration probability

    Returns:
        A ``(1, 1)`` tensor holding the chosen cell index.
    """
    if random.random() > eps:
        return model.act(state)
    return torch.tensor([[random.randrange(0, 9)]])
    

# Number of environment steps between refreshes of the frozen target network.
# Without a periodic sync the target net keeps its initial random weights for
# the whole run and the bootstrapped Q targets never improve.
target_update_steps = 100

for step in range(n_steps):
    # Linearly anneal epsilon from eps_start to eps_end over eps_steps steps.
    t = np.clip(step / eps_steps, 0, 1)
    eps = (1-t) * eps_start + t * eps_end

    action = select_model_action(policy_net, state, eps)
    next_state, reward = env.step(Action(env._player_1, action.item()))
    reward = torch.tensor([reward])

    # A random opponent answers as player 2, unless the agent's own move
    # already ended the game.
    if not env.isDone():
        reply = select_dummy_action(next_state)
        next_state, _ = env.step(Action(env._player_2, reply.item()))

    # A transition is terminal whenever the game is over afterwards, whether
    # the agent's move or the opponent's reply ended it. Terminal transitions
    # are stored with next_state=None so their target is not bootstrapped.
    if env.isDone():
        next_state = None

    memory.push(state, action, next_state, reward)

    # Go through memory and make updates based on that
    optimize_model(
        optimizer=optimizer,
        policy=policy_net,
        target=target_net,
        memory=memory,
        batch_size=batch_size,
        gamma=gamma,
    )

    if (step + 1) % target_update_steps == 0:
        target_net.load_state_dict(policy_net.state_dict())

    # Start a new game after a terminal step, otherwise continue from here.
    state = env.reset() if next_state is None else next_state


## Play

In [378]:
game = TicTacToe()

user = UserAgent()
rand = RandomAgent()
ai = policy_net

# Seat order: index 0 moves as player 1, index 1 as player 2.
agents = [ai, user]

game.render()
print("\n")

while not game.isDone():
    # The move counter's parity decides whose turn it is.
    i = game.summary["moves"]
    print(f"Player {i%2+1}s turn:")

    move = agents[i % 2].act(game._one_hot_summary())
    action = Action(i%2+1, move.item())

    game.step(action)
    game.render()
    print("\n")

print(f"Winner: {game.summary['winner']}")


 | | 
-----
 | | 
-----
 | | 


Player 1s turn:
X| | 
-----
 | | 
-----
 | | 


Player 2s turn:
X| | 
-----
 | | 
-----
 | | 


Player 1s turn:
X| | 
-----
X| | 
-----
 | | 


Player 2s turn:
X|O| 
-----
X| | 
-----
 | | 


Player 1s turn:
X|O| 
-----
X|X| 
-----
 | | 


Player 2s turn:
X|O| 
-----
X|X|O
-----
 | | 


Player 1s turn:
X|O|X
-----
X|X|O
-----
 | | 


Player 2s turn:
X|O|X
-----
X|X|O
-----
O| | 


Player 1s turn:
X|O|X
-----
X|X|O
-----
O|X| 


Player 2s turn:
X|O|X
-----
X|X|O
-----
O|X|O


Winner: 0
