# Miniproject 1: Tic Tac Toe
This notebook contains the code for Miniproject 1.

## Utilities
Here we define some utility functions.

In [None]:
import numpy as np
from tic_env import TictactoeEnv, OptimalPlayer
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid

def play(player1, player2, episodes=5, debug=False, first_player="alternate", disable_tqdm=False, seed=None):
    """Play Tic Tac Toe between two players.

    Args:
        player1: Player 1. Must provide ``set_player``, ``act`` and a ``player`` attribute;
            optional ``end`` and ``finish_run`` callbacks are called when present.
        player2: Player 2, with the same interface as ``player1``.
        episodes (int, optional): Number of games. Defaults to 5.
        debug (bool, optional): Whether to print debug messages. Defaults to False.
        first_player (str, optional): Strategy to determine the first player. Defaults to "alternate".
            "alternate" swaps the symbols of player1 and player2 every game. Otherwise the
            symbols are assigned by a random permutation.
        disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
        seed (int, optional): Seed for both ``numpy`` and ``random``. Defaults to None (no seeding).

    Returns:
        (dict, dict): Player 1 stats and Player 2 stats with keys 'wins', 'losses' and
        'M' = (wins - losses) / episodes (0 when ``episodes`` is 0). A game ended by an
        invalid move counts as a loss for the player who made it and a win for the other.
    """
    env = TictactoeEnv()
    turns = np.array(['X', 'O'])
    player1_stats = {'wins': 0, 'losses': 0, 'M': 0}
    player2_stats = {'wins': 0, 'losses': 0, 'M': 0}

    if seed is not None:
        # Seed both RNGs for reproducibility: players may draw from either one.
        np.random.seed(seed)
        random.seed(seed)

    for i in tqdm(range(episodes), disable=disable_tqdm):
        env.reset()
        grid, _, _ = env.observe()
        invalid_player = None

        if first_player == "alternate":
            turns = np.flip(turns)
        else:
            turns = turns[np.random.permutation(2)]

        # Keep this game's symbols locally: the `end` callbacks may run nested evaluation
        # games that reassign `player.player` before the result below is recorded.
        symbol1, symbol2 = turns[0], turns[1]
        player1.set_player(symbol1)
        player2.set_player(symbol2)

        if debug:
            print('-------------------------------------------')
            print(f"Game {i}, Player 1 = {symbol1}, Player 2 = {symbol2}")

        for _ in range(9):
            # Capture the mover before stepping so an invalid move is attributed correctly.
            current = env.current_player
            mover = player1 if current == symbol1 else player2
            move = mover.act(grid)

            try:
                grid, end, winner = env.step(move, print_grid=False)
            except ValueError:
                # The environment rejects illegal moves: end the game and penalize the mover.
                end = True
                invalid_player = current
                winner = None

            if not end:
                continue

            if hasattr(player1, 'end'):
                player1.end(grid, winner, invalid_move=(invalid_player == symbol1))

            if hasattr(player2, 'end'):
                player2.end(grid, winner, invalid_move=(invalid_player == symbol2))

            # An invalid move forfeits the game, so it must count in the statistics too.
            if invalid_player == symbol1:
                result = symbol2
            elif invalid_player == symbol2:
                result = symbol1
            else:
                result = winner

            if result == symbol1:
                player1_stats['wins'] += 1
                player2_stats['losses'] += 1
            elif result == symbol2:
                player1_stats['losses'] += 1
                player2_stats['wins'] += 1

            if debug:
                print('-------------------------------------------')
                print('Game end, winner is player ' + str(winner))
                print('Player 1 = ' + symbol1)
                print('Player 2 = ' + symbol2)
                env.render()
                print('-------------------------------------------')

            break

    if episodes > 0:
        player1_stats['M'] = (player1_stats['wins'] - player1_stats['losses']) / episodes
        player2_stats['M'] = (player2_stats['wins'] - player2_stats['losses']) / episodes

    if hasattr(player1, 'finish_run'):
        player1.finish_run()

    if hasattr(player2, 'finish_run'):
        player2.finish_run()

    return player1_stats, player2_stats


def _m_value_against(q_player, opponent, episodes):
    """Return ``q_player``'s M value over ``episodes`` games against ``opponent``, without training it.

    The player is put in evaluation mode (when supported) for the duration of the games.
    Its previous mode and symbol are restored afterwards, even if a game raises, because
    ``play`` reassigns the symbol via ``set_player``.
    """
    was_eval = getattr(q_player, 'eval_mode', False)
    previous_symbol = getattr(q_player, 'player', None)

    if hasattr(q_player, 'eval'):
        q_player.eval()

    try:
        player1_stats, _ = play(q_player, opponent, episodes=episodes, debug=False,
                                first_player='alternate', disable_tqdm=True)
    finally:
        # Only switch back to training if the player was training before the evaluation.
        if not was_eval and hasattr(q_player, 'train'):
            q_player.train()
        if previous_symbol is not None:
            q_player.player = previous_symbol

    return player1_stats['M']


def calculate_m_opt(q_player, episodes=500):
    """Calculate M_opt for a given player.

    M_opt is the player's M value against ``OptimalPlayer(epsilon=0.0)``.

    Args:
        q_player: Player
        episodes (int, optional): Number of episodes. Defaults to 500.

    Returns:
        float: M_opt value
    """
    return _m_value_against(q_player, OptimalPlayer(epsilon=0.0), episodes)


def calculate_m_rand(q_player, episodes=500):
    """Calculate M_rand for a given player.

    M_rand is the player's M value against ``OptimalPlayer(epsilon=1.0)``.

    Args:
        q_player: Player
        episodes (int, optional): Number of episodes. Defaults to 500.

    Returns:
        float: M_rand value
    """
    return _m_value_against(q_player, OptimalPlayer(epsilon=1.0), episodes)

def save_stats(players, path):
    """Write one summary dict per player to ``path`` with ``np.save`` (pickled object array).

    Each dict holds 'loss' (the player's ``avg_losses``, or None if it has none),
    'reward' (``avg_rewards``), 'm_opt' and 'm_rand' (from ``m_values``).
    """
    def summarize(player):
        metrics = player.m_values
        return {
            'loss': getattr(player, 'avg_losses', None),
            'reward': player.avg_rewards,
            'm_opt': metrics['m_opt'],
            'm_rand': metrics['m_rand'],
        }

    summaries = [summarize(player) for player in players]

    with open(path, 'wb') as out_file:
        np.save(out_file, summaries)

def _load_stats(stats_path):
    """Load the per-player stat dicts written by ``save_stats``.

    ``allow_pickle`` is required because the file holds Python dicts; only load files
    produced by this notebook.
    """
    try:
        with open(stats_path, 'rb') as npy:
            return np.load(npy, allow_pickle=True)
    except FileNotFoundError:
        # Only a missing file is reported as such; other errors propagate unchanged.
        print('File not found!')
        raise


def plot_average_rewards(stats_path, labels, log_every=250, save_path=None):
    """Plot each player's average reward per ``log_every`` games.

    Args:
        stats_path: ``.npy`` file written by ``save_stats``.
        labels: One legend label per player, in file order.
        log_every (int, optional): Games per averaging window. Defaults to 250.
        save_path (str, optional): If given, the figure is also saved there as PDF.
    """
    player_stats = _load_stats(stats_path)

    fig, ax = plt.subplots(1, 1, figsize=(8, 5))

    for i, player in enumerate(player_stats):
        rewards = player['reward']
        # Entry k averages the window ending at game (k + 1) * log_every; plot it there.
        game_ids = [(k + 1) * log_every for k in range(len(rewards))]
        ax.plot(game_ids, rewards, label=labels[i])

    n_points = max((len(player['reward']) for player in player_stats), default=0)

    ax.set_title(f'Average reward per {log_every} games', fontsize=20, fontweight='bold')
    ax.set_xlabel('Game', fontsize=16)
    ax.set_ylabel('Reward', fontsize=16)
    ax.set_xlim([0, n_points * log_every])
    ax.legend(loc='lower right')
    ax.grid()

    # Save before showing: some backends release the figure once it has been shown.
    if save_path is not None:
        fig.savefig(save_path, format='pdf')

    plt.show()


def plot_m_values(stats_path, labels, test_every=250, save_path=None):
    """Plot each player's M_opt (left) and M_rand (right) curves.

    Args:
        stats_path: ``.npy`` file written by ``save_stats``.
        labels: One legend label per player, in file order.
        test_every (int, optional): Games between two evaluations. Defaults to 250.
        save_path (str, optional): If given, the figure is also saved there as PDF.
    """
    player_stats = _load_stats(stats_path)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    panels = (("m_opt", "M_opt"), ("m_rand", "M_rand"))

    for i, player in enumerate(player_stats):
        for ax, (key, _) in zip(axes, panels):
            values = player[key]
            # Size the x-axis from the plotted series itself, not from the reward log.
            game_ids = [(k + 1) * test_every for k in range(len(values))]
            ax.plot(game_ids, values, label=labels[i])

    for ax, (key, name) in zip(axes, panels):
        n_points = max((len(player[key]) for player in player_stats), default=0)
        ax.set_title(f'{name} per {test_every} games', fontsize=20, fontweight='bold')
        ax.set_xlabel('Game', fontsize=16)
        ax.set_ylabel(name, fontsize=16)
        ax.set_xlim([0, n_points * test_every])
        ax.legend()
        ax.grid()

    # Save before showing: some backends release the figure once it has been shown.
    if save_path is not None:
        fig.savefig(save_path, format='pdf')

    plt.show()

def read_grid(grid):
    """Split a board into the coordinates of its empty and occupied cells.

    Returns:
        (list, list): ``(empty, filled)`` lists of index tuples in row-major order,
        where a cell is empty when its value is 0.
    """
    empty_cells = [tuple(index) for index in np.argwhere(grid == 0)]
    filled_cells = [tuple(index) for index in np.argwhere(grid != 0)]
    return empty_cells, filled_cells

def plot_heatmaps(states_list, qvalues_list, titles, cmap=None, save_path=None):
    """Draw side-by-side heatmaps of Q-values for the empty cells of board states.

    Args:
        states_list: Board arrays; cells equal to 1 are labelled 'X', other non-zero cells 'O'.
            Boards are assumed to be 3x3 (the heatmap grid is fixed to that size).
        qvalues_list: Q-value mappings keyed by ``QStateAction``; missing entries count as 0.
        titles: One subplot title per state.
        cmap: Matplotlib colormap. Defaults to a 10-level 'Blues' map.
        save_path (str, optional): If given, the figure is also saved there as PDF.
    """
    fig = plt.figure(figsize=(15, 5))
    axes = AxesGrid(fig, 111,
                    nrows_ncols=(1, 3),
                    axes_pad=0.2,
                    cbar_mode='single',
                    cbar_location='right',
                    cbar_pad=0.1
                    )

    if cmap is None:
        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap takes the same (name, lut).
        cmap = plt.get_cmap('Blues', 10)

    img = None
    for ax, state, qvalues, title in zip(axes, states_list, qvalues_list, titles):
        empty, filled = read_grid(state)

        # Occupied cells keep -1 (bottom of the colour scale); they are labelled X/O instead.
        qvalue_grid = -1 * np.ones((3, 3))
        for action in empty:
            qvalue_grid[action] = qvalues.get(QStateAction(state, action), 0)

        img = ax.imshow(qvalue_grid, cmap=cmap, vmin=-1, vmax=1)
        ax.set_axis_off()

        for i, j in empty:
            ax.text(j, i, f'{qvalue_grid[i, j]:.3f}', ha='center', va='center', color='k', fontsize=14)

        for i, j in filled:
            player = 'X' if state[i, j] == 1 else 'O'
            ax.text(j, i, player, ha='center', va='center', color='k', fontsize=16, fontweight='bold')

        ax.set_title(title, fontsize=14, fontweight='bold')

    # All panels share vmin/vmax, so one colorbar describes every heatmap (drawn once).
    if img is not None:
        axes.cbar_axes[0].colorbar(img)

    if save_path is not None:
        fig.savefig(save_path, format='pdf')

## Q-Learning

In [None]:
import random
from collections import defaultdict
import numpy as np


class QStateAction:
    """
    Hashable (board state, action) key for Q-value tables.

    The board is snapshotted into a tuple of its flattened values when the key is
    built, so two keys are equal exactly when those tuples and the actions are equal.
    """
    def __init__(self, grid, action):
        self.grid = grid
        flat_values = grid.ravel().tolist()
        self.state = tuple(flat_values)
        self.action = action

    def __hash__(self) -> int:
        key = (self.state, self.action)
        return hash(key)

    def __eq__(self, other) -> bool:
        if not isinstance(other, QStateAction):
            return False
        return self.state == other.state and self.action == other.action

    def __ne__(self, other) -> bool:
        return not self.__eq__(other)

    def __repr__(self) -> str:
        return f"state={self.state}, action={self.action}"


class QPlayer:
    """Tabular Q-learning Tic Tac Toe player with epsilon-greedy exploration.

    Q-values live in a dict keyed by ``QStateAction``; two players may share the same
    dict to learn a single table through self-play.
    """

    def __init__(self, epsilon=0.01, alpha=0.05, gamma=0.99, player='X', log_every=250, test_every=None, qvalues=None, wandb_name=None, *args, **kwargs):
        """Initialize a Q-learning player

        Args:
            epsilon (float or callable, optional): Exploration rate, or a function of the number
                of finished training games returning it. Defaults to 0.01.
            alpha (float, optional): Learning rate. Defaults to 0.05.
            gamma (float, optional): Discount factor. Defaults to 0.99.
            player (str, optional): Player symbol, 'X' or 'O'. Defaults to 'X'.
            log_every (int, optional): Games per average-reward window; None or <= 0 disables
                reward logging. Defaults to 250.
            test_every (int, optional): Games between M_opt / M_rand evaluations; None or <= 0
                disables them. Defaults to None.
            qvalues (dict, optional): Q-value table (e.g. shared with another player).
                Defaults to a new ``defaultdict(int)``.
            wandb_name (str, optional): Wandb run name; metrics are also sent to wandb when set.
                Defaults to None.
        """
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.player = player  # 'X' or 'O'
        self.qvalues = defaultdict(int) if qvalues is None else qvalues
        self.last_qstate = None
        self.last_reward = 0
        self.num_games = 0
        self.running_reward = 0
        self.avg_rewards = []
        self.m_values = {"m_opt": [], "m_rand": []}
        self.log_every = log_every
        self.log = log_every is not None and log_every > 0
        self.test_every = test_every
        # Testing is controlled by test_every alone (a non-positive value would divide by zero).
        self.test = test_every is not None and test_every > 0
        self.eval_mode = False
        self.wandb_name = wandb_name
        self.wandb_run = None

        if (self.log or self.test) and self.wandb_name is not None:
            import wandb
            self.wandb_run = wandb.init(project="ann-project", name=wandb_name, reinit=True,
                                        config={"epsilon": epsilon, "gamma": gamma, "player": player,
                                                "log_every": log_every, "test_every": test_every, "alpha": alpha})

    def set_player(self, player='X', j=-1):
        """Set the symbol for a new game and forget the previous move."""
        self.player = player
        self.last_qstate = None
        self.last_reward = 0
        if j != -1:
            self.player = 'X' if j % 2 == 0 else 'O'

    def eval(self):
        """
        Put player in evaluation mode. In this mode, player does not train
        and does not log metrics.
        """
        self.eval_mode = True

    def train(self):
        """
        Put player in training mode. In this mode, player trains and logs metrics.
        """
        self.eval_mode = False

    def finish_run(self):
        """Finish wandb run if present and if in training mode"""
        if not self.eval_mode and self.wandb_run:
            self.wandb_run.finish()

    def random(self, grid):
        """Choose a uniformly random action among the empty cells."""
        avail = self.empty(grid)
        return QStateAction(grid, avail[random.randint(0, len(avail) - 1)])

    def empty(self, grid):
        """Return all empty cells from the grid."""
        avail = np.where(grid == 0)
        return list(zip(*avail))

    def greedy(self, grid):
        """Return a best action under the current Q-values, breaking ties uniformly at random."""
        best_qstates = []
        max_qvalue = None

        for action in self.empty(grid):
            qstate = QStateAction(grid, action)
            qvalue = self.qvalues.get(qstate, 0)

            if max_qvalue is None or qvalue > max_qvalue:
                max_qvalue = qvalue
                best_qstates = [qstate]
            elif qvalue == max_qvalue:
                best_qstates.append(qstate)

        return np.random.choice(best_qstates)

    def _max_qvalue(self, grid):
        """Return max over empty cells a of Q(grid, a), or 0 when no cell is empty."""
        return max((self.qvalues.get(QStateAction(grid, action), 0) for action in self.empty(grid)),
                   default=0)

    def opponent(self):
        """Get the opponent player symbol."""
        return 'X' if self.player == 'O' else 'O'

    def decide(self, grid):
        """Pick the next action: greedy in eval mode, epsilon-greedy otherwise."""
        epsilon = self.epsilon(self.num_games) if callable(self.epsilon) else self.epsilon
        if self.eval_mode or random.random() > epsilon:
            return self.greedy(grid)
        return self.random(grid)

    def update(self, grid, reward=0, end=False):
        """Apply one Q-learning update to the previously taken action.

        Q(s, a) += alpha * (reward + gamma * max_a' Q(s', a') - Q(s, a)), with s' = ``grid``
        and the bootstrap term set to 0 when ``end`` is True.
        """
        if self.last_qstate is None:
            return

        # The max is computed directly (no random tie-breaking) and reads use .get so that
        # plain dicts work as Q-tables too.
        next_value = 0 if end else self._max_qvalue(grid)
        current = self.qvalues.get(self.last_qstate, 0)
        self.qvalues[self.last_qstate] = current + self.alpha * (reward + self.gamma * next_value - current)

    def end(self, grid, winner, *args, **kwargs):
        """End-of-game callback: final Q-update, then reward logging and periodic evaluation."""
        if self.eval_mode:
            return

        self.num_games += 1

        if winner == self.player:
            reward = 1
        elif winner == self.opponent():
            reward = -1
        else:
            reward = 0

        self.update(grid, reward=reward, end=True)
        self.last_qstate = None

        if self.log:
            self.running_reward += reward

            # num_games has already been incremented, so each window holds exactly log_every games.
            if self.num_games % self.log_every == 0:
                avg_reward = self.running_reward / self.log_every
                self.avg_rewards.append(avg_reward)
                self.running_reward = 0

                if self.wandb_run is not None:
                    self.wandb_run.log({"avg_reward": avg_reward})

        if self.test and self.num_games % self.test_every == 0:
            m_opt = calculate_m_opt(self)
            m_rand = calculate_m_rand(self)

            self.m_values["m_opt"].append(m_opt)
            self.m_values["m_rand"].append(m_rand)

            if self.wandb_run is not None:
                self.wandb_run.log({"m_opt": m_opt, "m_rand": m_rand})

    def act(self, grid):
        """Choose a move, learning from the previous one first when training."""
        qstate = self.decide(grid)

        if not self.eval_mode:
            self.update(grid)

        self.last_qstate = qstate
        return qstate.action

### Questions

#### Question 1

In [None]:
from tic_env import OptimalPlayer

# Q1: train one Q-learning agent per fixed exploration rate against an
# OptimalPlayer(epsilon=0.5) opponent, then save the learning curves.
epsilons = [0.01, 0.1, 0.2, 0.5, 0.75]
eps_q_players = []

for eps in epsilons:
    suboptimal_player = OptimalPlayer(epsilon=0.5)
    q_player = QPlayer(epsilon=eps)
    # The Q-player is player 2; `play` alternates the two players' symbols every game by default.
    play(suboptimal_player, q_player, episodes=20000)
    eps_q_players.append(q_player)

save_stats(players=eps_q_players, path='answers/Q1.npy')

In [None]:
# Raw f-string: '\e' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
labels = [rf'$\epsilon$={eps}' for eps in epsilons]
plot_average_rewards('answers/Q1.npy', labels=labels, save_path='artifacts/fig_Q1.pdf')

#### Question 2

In [None]:
# Q2: decreasing exploration epsilon(n) = max(EPS_MIN, EPS_MAX * (1 - n / n*)), where n is the
# number of training games the Q-player has finished; one agent per n*.
EPS_MIN = 0.1
EPS_MAX = 0.8
n_stars = [1, 100, 1000, 10000, 40000]
n_star_players = []


for n_star in n_stars:
    # n_star is bound as a default argument so each schedule keeps its own value.
    get_epsilon = lambda n, n_star=n_star: max(EPS_MIN, EPS_MAX * (1 - n / n_star))
    suboptimal_player = OptimalPlayer(epsilon=0.5)
    q_player = QPlayer(epsilon=get_epsilon)
    play(suboptimal_player, q_player, episodes=20000)
    n_star_players.append(q_player)
    
save_stats(players=n_star_players, path='answers/Q2.npy')

In [None]:
# Average-reward curves for Q2, one per n*.
labels = [f'n*={n_star}' for n_star in n_stars]
plot_average_rewards(stats_path='answers/Q2.npy', labels=labels, save_path='artifacts/fig_Q2.pdf')

#### Question 3

In [None]:
# Q3: same decreasing-epsilon agents as Q2, but also evaluated (M_opt / M_rand) every 250 games.
EPS_MIN = 0.1
EPS_MAX = 0.8
n_stars = [1, 100, 1000, 10000, 40000]
n_star_players = []

for n_star in n_stars:
    # n_star is bound as a default argument so each schedule keeps its own value.
    get_epsilon = lambda n, n_star=n_star: max(EPS_MIN, EPS_MAX * (1 - n / n_star))
    suboptimal_player = OptimalPlayer(epsilon=0.5)
    q_player = QPlayer(epsilon=get_epsilon, test_every=250)
    play(suboptimal_player, q_player, episodes=20000)
    n_star_players.append(q_player)
    
save_stats(players=n_star_players, path='answers/Q3.npy')

In [None]:
# M_opt / M_rand curves for Q3, one per n*.
labels = [f"n*={n_star}" for n_star in n_stars]
plot_m_values(stats_path='answers/Q3.npy', labels=labels, save_path='artifacts/fig_Q3.pdf')

#### Question 4

In [None]:
# Q4: fixed schedule n* = 100; vary the epsilon of the OptimalPlayer opponent instead.
N_STAR = 100
EPS_MIN = 0.1
EPS_MAX = 0.8
# n is the Q-player's number of finished training games (N_STAR is read when called).
get_epsilon = lambda n: max(EPS_MIN, EPS_MAX * (1 - n / N_STAR))

# Opponent exploration rates to train against.
epsilons = [0, 0.01, 0.1, 0.25, 0.5, 0.75, 1]

eps_players = []

for eps in epsilons:
    other_player = OptimalPlayer(epsilon=eps)
    q_player = QPlayer(epsilon=get_epsilon, test_every=250)
    play(other_player, q_player, episodes=20000)
    eps_players.append(q_player)
    
save_stats(players=eps_players, path='answers/Q4.npy')

In [None]:
# Raw f-string: '\e' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
labels = [rf'$\epsilon$={eps}' for eps in epsilons]
plot_m_values(stats_path='answers/Q4.npy', labels=labels, save_path='artifacts/fig_Q4.pdf')

#### Question 7

In [None]:
# Q7: self-play with a fixed exploration rate; both players learn one shared Q-table.
epsilons = [0.01, 0.1, 0.2, 0.5]
self_q_players = []

for eps in epsilons:
    qvalues = defaultdict(int)
    q_player1 = QPlayer(epsilon=eps, qvalues=qvalues, test_every=250)
    # Only q_player1's metrics are saved. q_player2 shares the same Q-table, so evaluating
    # it too would just double the cost of the periodic M_opt / M_rand games.
    q_player2 = QPlayer(epsilon=eps, qvalues=qvalues)
    play(q_player1, q_player2, episodes=20000)
    self_q_players.append(q_player1)

save_stats(players=self_q_players, path='answers/Q7.npy')

In [None]:
# Raw f-string: '\e' is not a valid escape sequence (SyntaxWarning on Python 3.12+).
labels = [rf'$\epsilon$={eps}' for eps in epsilons]
plot_m_values(stats_path='answers/Q7.npy', labels=labels, save_path='artifacts/fig_Q7.pdf')

#### Question 8

In [None]:
# Q8: self-play with a decreasing exploration schedule; both players learn one shared Q-table.
EPS_MIN = 0.1
EPS_MAX = 0.8
n_stars = [1, 100, 1000, 10000, 40000]
self_n_star_players = []

for n_star in n_stars:
    # n_star is bound as a default argument so each schedule keeps its own value.
    get_epsilon = lambda n, n_star=n_star: max(EPS_MIN, EPS_MAX * (1 - n / n_star))
    qvalues = defaultdict(int)
    q_player1 = QPlayer(epsilon=get_epsilon, qvalues=qvalues, test_every=250)
    # Only q_player1's metrics are saved. q_player2 shares the same Q-table, so evaluating
    # it too would just double the cost of the periodic M_opt / M_rand games.
    q_player2 = QPlayer(epsilon=get_epsilon, qvalues=qvalues)
    play(q_player1, q_player2, episodes=20000)
    self_n_star_players.append(q_player1)

save_stats(players=self_n_star_players, path='answers/Q8.npy')

# Keep the learned Q-tables for the Q10 heatmaps (same order as n_stars).
qvalues = [player.qvalues for player in self_n_star_players]
with open('answers/Q10.npy', 'wb') as npy:
    np.save(npy, qvalues)

In [None]:
# M_opt / M_rand curves for the Q8 self-play agents, one per n*.
labels = [f'n*={n_star}' for n_star in n_stars]
plot_m_values(stats_path='answers/Q8.npy', labels=labels, save_path='artifacts/fig_Q8.pdf')

#### Question 10

In [None]:
# Q10: Q-value heatmaps of the self-play agents on three hand-picked boards (1 = X, -1 = O).
state1 = np.array([
    [1.0, 0.0, 0.0],
    [0.0, -1.0, 0.0],
    [1.0, 0.0, 0.0],
])

state2 = np.array([
    [-1.0, 0.0, 1.0],
    [0.0, 1.0, 0.0],
    [-1.0, 0.0, 0.0],
])

state3 = np.array([
    [0.0, 0.0, -1.0],
    [0.0, 1.0, -1.0],
    [0.0, 1.0, 0.0],
])

# Trusted local file written by the Q8 cell; allow_pickle is needed for the stored dicts.
with open('answers/Q10.npy', 'rb') as npy:
    qvalues = np.load(npy, allow_pickle=True)

states_list = [state1, state2, state3]
# Q-tables are stored in the order of Q8's n_stars = [1, 100, 1000, 10000, 40000].
qvalues_list = [qvalues[4], qvalues[3], qvalues[2]]
titles = [f'Self-Play with n*={n_star}' for n_star in [40000, 10000, 1000]]
# plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap takes the same (name, lut).
cmap = plt.get_cmap('Blues', 100)

plot_heatmaps(states_list, qvalues_list, titles=titles, cmap=cmap, save_path='artifacts/fig_Q10.pdf')

## Deep Q-Learning

In [None]:
from collections import namedtuple, deque
import torch
import random
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import json

# Board cell value -> player symbol (0 marks an empty cell).
VALUE_TO_PLAYER = {
    -1: 'O',
    1: 'X',
    0: None,
}

# One replay-memory entry; next_state is None for terminal transitions.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])

# Run networks on the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")


class ReplayMemory:
    """Bounded FIFO buffer of ``Transition`` tuples used for experience replay."""

    def __init__(self, capacity=10000):
        """
        Args:
            capacity (int, optional): Maximum number of stored transitions; once full, each
                push drops the oldest one. Defaults to 10000, the same as DeepQPlayer's
                ``memory_capacity`` default, so ``ReplayMemory()`` works.
        """
        self.memory = deque([], maxlen=capacity)

    def push(self, state, action, next_state=None, reward=0):
        """Store a copy of one transition; ``next_state=None`` marks a terminal transition.

        Tensors are cloned so later in-place changes by the caller do not alter the memory;
        non-tensor actions and rewards are wrapped with ``torch.tensor``.
        """
        self.memory.append(Transition(
            state=state.clone(),
            action=action.clone() if isinstance(action, torch.Tensor) else torch.tensor(action),
            next_state=next_state.clone() if next_state is not None else None,
            reward=reward.clone() if isinstance(reward, torch.Tensor) else torch.tensor(reward),
        ))

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DeepQNetwork(nn.Module):
    """Two-hidden-layer MLP mapping a board encoding to one Q-value per action."""

    def __init__(self, in_dim=18, out_dim=9, hidden_dim=128) -> None:
        """
        Args:
            in_dim (int, optional): Number of input features after flattening. Defaults to 18.
            out_dim (int, optional): Number of actions. Defaults to 9.
            hidden_dim (int, optional): Width of both hidden layers. Defaults to 128.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, out_dim); all dims after the first are flattened."""
        return self.model(x.flatten(start_dim=1))

    def predict(self, x):
        """Return the index of the best action for each input, shape (batch,).

        The argmax is taken per row; for a single-board batch ``.item()`` gives the action.
        """
        return self.forward(x).argmax(dim=1)


class DeepQPlayer:
    """Deep Q-learning Tic Tac Toe player with experience replay and a target network."""

    def __init__(self, epsilon=0.1, gamma=0.99, player='X', memory_capacity=10000, target_update=500,
                 batch_size=64, learning_rate=5e-4, log_every=250, debug=False,
                 policy_net=None, target_net=None, memory=None, swap_state=False, log=True, wandb_name=None, do_optimize=True, *args, **kwargs):
        """Initialize a Deep Q-learning player

        Args:
            epsilon (float or callable, optional): Exploration rate, or a function of the number
                of finished training games returning it. Defaults to 0.1.
            gamma (float, optional): Discount factor. Defaults to 0.99.
            player (str, optional): Player symbol. Defaults to 'X'.
            memory_capacity (int, optional): Replay memory capacity. Defaults to 10000.
            target_update (int, optional): Games between target-network syncs. Defaults to 500.
            batch_size (int, optional): Batch size. Defaults to 64.
            learning_rate (float, optional): Learning rate. Defaults to 5e-4.
            log_every (int, optional): Games between logged metrics (reward, loss, M values). Defaults to 250.
            debug (bool, optional): Whether to print debug messages. Defaults to False.
            policy_net (nn.Module, optional): Shared policy network. Defaults to None.
            target_net (nn.Module, optional): Shared target network. Defaults to None.
            memory (ReplayMemory, optional): Shared replay memory. Defaults to None.
            swap_state (bool, optional): Whether to swap state channels before saving to memory. Defaults to False.
            log (bool, optional): Whether to log metrics. Defaults to True.
            wandb_name (str, optional): Wandb run name. Defaults to None.
            do_optimize (bool, optional): Whether this player trains (and syncs) the networks. Defaults to True.
        """
        self.epsilon = epsilon
        self.gamma = gamma
        self.player = player
        self.memory = ReplayMemory(memory_capacity) if memory is None else memory
        self.policy_net = DeepQNetwork().to(DEVICE) if policy_net is None else policy_net.to(DEVICE)
        self.target_net = DeepQNetwork().to(DEVICE) if target_net is None else target_net.to(DEVICE)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.last_state = None
        self.last_action = None
        self.last_reward = 0
        self.num_games = 0
        self.target_update = target_update
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.criterion = nn.HuberLoss()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
        self.running_reward = 0
        self.running_loss = 0
        self.avg_rewards = []
        self.avg_losses = []
        self.log_every = log_every
        self.debug = debug
        self.m_values = {"m_opt": [], "m_rand": []}
        self.eval_mode = False
        self.swap_state = swap_state
        self.log = log
        self.wandb_name = wandb_name
        self.wandb_run = None
        self.do_optimize = do_optimize

        if self.log and self.wandb_name is not None:
            import wandb
            self.wandb_run = wandb.init(project="ann-project", name=wandb_name, reinit=True,
                                        config={"epsilon": epsilon, "gamma": gamma, "player": player, "memory_capacity": memory_capacity,
                                                "target_update": target_update, "batch_size": batch_size, "learning_rate": learning_rate,
                                                "log_every": log_every, "debug": debug, "swap_state": swap_state, "log": log})

    def eval(self):
        """
        Set the player to evaluation mode.
        In this mode, the player will not explore and will only use the policy network to make decisions.
        """
        self.eval_mode = True

    def train(self):
        """ Set the player to training mode. """
        self.eval_mode = False

    def finish_run(self):
        """
        Finish wandb run if present and in training mode.
        """
        if not self.eval_mode and self.wandb_run:
            self.wandb_run.finish()

    def save_pretrained(self, save_path):
        """Save both networks and a JSON config (hyper-parameters and training curves).

        A callable epsilon schedule cannot be serialized and is stored as ``None``.
        """
        Path(save_path).mkdir(parents=True, exist_ok=True)
        config = dict(epsilon=None if callable(self.epsilon) else self.epsilon, gamma=self.gamma, player=self.player,
                      # Store the buffer's capacity, not the number of transitions it currently holds.
                      memory_capacity=self.memory.memory.maxlen,
                      target_update=self.target_update, learning_rate=self.learning_rate, batch_size=self.batch_size,
                      log_every=self.log_every, debug=self.debug, avg_losses=self.avg_losses, avg_rewards=self.avg_rewards, m_values=self.m_values)
        Path(save_path, "config.json").write_text(json.dumps(config))
        torch.save(self.policy_net.state_dict(), Path(save_path, "policy_net.pt"))
        torch.save(self.target_net.state_dict(), Path(save_path, "target_net.pt"))

    @classmethod
    def from_pretrained(cls, load_path):
        """Load a player saved with ``save_pretrained``.

        NOTE(review): a callable epsilon is restored as ``None``; set ``player.epsilon``
        before training the loaded player.
        """
        config = json.loads(Path(load_path, "config.json").read_text())
        # map_location allows weights saved on a GPU to be loaded on a CPU-only machine.
        policy_state = torch.load(Path(load_path, "policy_net.pt"), map_location=DEVICE)
        target_state = torch.load(Path(load_path, "target_net.pt"), map_location=DEVICE)
        player = cls(**config)
        player.policy_net.load_state_dict(policy_state)
        player.target_net.load_state_dict(target_state)
        player.avg_losses = config["avg_losses"]
        player.avg_rewards = config["avg_rewards"]
        player.m_values = config["m_values"]
        return player

    def set_player(self, player='X', j=-1):
        """Set the symbol for a new game and forget the previous move."""
        self.player = player
        self.last_state = None
        self.last_action = None
        self.last_reward = 0
        if j != -1:
            self.player = 'X' if j % 2 == 0 else 'O'

    def empty(self, grid):
        """Return the flat row-major indices (0-8) of the empty cells of the 3x3 grid."""
        return [i for i in range(9) if grid[divmod(i, 3)] == 0]

    def random(self, grid):
        """ Chose a random action from the available options. """
        avail = self.empty(grid)
        return avail[random.randint(0, len(avail) - 1)]

    def opponent(self):
        """ Return the opponent of the player. """
        return 'X' if self.player == 'O' else 'O'

    def grid_to_state(self, grid):
        """Encode the board as a float tensor of shape grid.shape + (2,) from this player's view.

        Channel 0 marks this player's cells and channel 1 the opponent's; empty cells are 0
        in both. Cell values follow VALUE_TO_PLAYER: 1 is 'X' and -1 is 'O'.
        """
        board = torch.as_tensor(grid)
        own = 1 if self.player == 'X' else -1
        return torch.stack((board == own, board == -own), dim=-1).float()

    def maybe_swap_state(self, state):
        """ Swap the state if necessary. """
        if self.swap_state:
            return state.flip(dims=[2])
        return state

    def greedy(self, grid):
        """ Return the best action according to the current policy network. """
        with torch.no_grad():
            state = self.grid_to_state(grid).unsqueeze(0).to(DEVICE)
            return self.policy_net.predict(state).item()

    def decide(self, grid):
        """Pick an action: greedy in eval mode, epsilon-greedy otherwise (may pick an occupied cell)."""
        epsilon = self.epsilon(self.num_games) if callable(self.epsilon) else self.epsilon
        if self.eval_mode or random.random() > epsilon:
            return self.greedy(grid)
        return self.random(grid)

    def act(self, grid):
        """Choose a move; when training, store the previous transition and run an optimization step."""
        state = self.grid_to_state(grid)
        action = self.decide(grid)

        if not self.eval_mode:
            if self.last_state is not None:
                self.memory.push(self.maybe_swap_state(self.last_state), self.last_action, self.maybe_swap_state(state), self.last_reward)

            if self.do_optimize:
                self.optimize()

        self.last_state = state
        self.last_action = action
        self.last_reward = 0

        return action

    def end(self, grid, winner, invalid_move=False):
        """End-of-game callback: store the terminal transition, train, sync the target net, log."""
        if self.eval_mode:
            return

        self.num_games += 1

        if winner == self.player:
            reward = 1
        elif winner == self.opponent() or invalid_move:
            reward = -1
        else:
            reward = 0

        # Guard against a game that ended before this player acted (nothing to store).
        if self.last_state is not None:
            self.memory.push(self.maybe_swap_state(self.last_state), self.last_action, None, reward)

        loss = self.optimize() if self.do_optimize else None

        # Sync once every target_update games. optimize() runs on every move, so checking
        # there would re-sync on every step of a whole game.
        if self.do_optimize and self.num_games % self.target_update == 0:
            if self.debug:
                print(f"num_games={self.num_games}, loss={loss}")
            self.target_net.load_state_dict(self.policy_net.state_dict())

        self.last_state = None
        self.last_action = None

        if not self.log:
            return

        self.running_reward += reward
        if loss is not None:
            self.running_loss += loss

        # num_games has already been incremented, so each window holds exactly log_every games.
        if self.num_games % self.log_every == 0:
            avg_reward = self.running_reward / self.log_every
            self.avg_rewards.append(avg_reward)
            self.running_reward = 0

            avg_loss = self.running_loss / self.log_every
            self.avg_losses.append(avg_loss)
            self.running_loss = 0

            m_opt = calculate_m_opt(self)
            m_rand = calculate_m_rand(self)

            self.m_values["m_opt"].append(m_opt)
            self.m_values["m_rand"].append(m_rand)

            if self.wandb_run is not None:
                self.wandb_run.log({"avg_reward": avg_reward, "avg_loss": avg_loss, "m_opt": m_opt, "m_rand": m_rand})

    def optimize(self):
        """Run one gradient step on a random minibatch from replay memory.

        Returns:
            float or None: The Huber loss, or None if memory holds fewer than batch_size transitions.
        """
        if len(self.memory) < self.batch_size:
            return None

        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        # Terminal transitions were stored with next_state=None and bootstrap from 0.
        non_final_mask = torch.tensor([s is not None for s in batch.next_state], dtype=torch.bool, device=DEVICE)
        non_final_next_states = [s for s in batch.next_state if s is not None]

        state_batch = torch.stack(batch.state).to(DEVICE)
        action_batch = torch.stack(batch.action).view(-1, 1).to(DEVICE)
        reward_batch = torch.stack(batch.reward).view(-1, 1).to(DEVICE)

        # Q(s_t, a_t) for the actions that were actually taken.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # max_a Q_target(s_{t+1}, a), or 0 for terminal transitions. The target is a constant
        # (no_grad), so no gradients are accumulated into the target network's parameters.
        next_state_values = torch.zeros((self.batch_size, 1), device=DEVICE)
        if non_final_next_states:
            with torch.no_grad():
                next_states = torch.stack(non_final_next_states).to(DEVICE)
                next_state_values[non_final_mask] = self.target_net(next_states).max(1)[0].view(-1, 1)

        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        loss = self.criterion(state_action_values, expected_state_action_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

### Questions

#### Question 11

In [None]:
from tic_env import OptimalPlayer

# Fixed exploration rates to compare for the DQN learner.
epsilons = [0.001, 0.01, 0.1, 0.2]

for eps in epsilons:
    # A fresh opponent OptimalPlayer(epsilon=0.5) and a fresh DQN learner per run.
    suboptimal_player = OptimalPlayer(epsilon=0.5)
    q_player = DeepQPlayer(epsilon=eps)
    # The returned per-player stats are not kept here; any logging is
    # presumably done inside DeepQPlayer — TODO confirm.
    play(player1=suboptimal_player, player2=q_player, episodes=20000)

#### Question 12

In [None]:
from tic_env import OptimalPlayer

# Same exploration rates, but the learner's replay memory and batch both hold a
# single transition (presumably equivalent to training without experience
# replay — confirm against DeepQPlayer).
epsilons = [0.001, 0.01, 0.1, 0.2]

for eps in epsilons:
    suboptimal_player = OptimalPlayer(epsilon=0.5)
    q_player = DeepQPlayer(epsilon=eps, memory_capacity=1, batch_size=1)
    play(player1=suboptimal_player, player2=q_player, episodes=20000)

#### Question 13

In [None]:
from tic_env import OptimalPlayer

EPS_MIN = 0.1
EPS_MAX = 0.8
n_stars = [1, 50, 100, 500, 1000, 5000, 10000, 20000, 40000]

for n_star in n_stars:
    def decaying_epsilon(n, n_star=n_star):
        """Exploration rate EPS_MAX * (1 - n / n_star), floored at EPS_MIN.

        ``n_star`` is captured as a default so each learner keeps the value
        from its own loop iteration; ``n`` is whatever counter DeepQPlayer
        passes in (presumably the number of games played — confirm).
        """
        return max(EPS_MIN, EPS_MAX * (1 - n / n_star))

    suboptimal_player = OptimalPlayer(epsilon=0.5)
    q_player = DeepQPlayer(epsilon=decaying_epsilon)
    play(player1=suboptimal_player, player2=q_player, episodes=20000)

#### Question 14

In [None]:
from tic_env import OptimalPlayer

EPS_MIN = 0.1
EPS_MAX = 0.8
N_STAR = 100
# Exploration rates of the OptimalPlayer opponent to train against.
epsilons = [0, 0.01, 0.1, 0.5, 1]


def epsilon_schedule(n, eps_min=EPS_MIN, eps_max=EPS_MAX, n_star=N_STAR):
    """Linearly decaying exploration rate: eps_max * (1 - n / n_star), floored at eps_min.

    The schedule constants are bound as defaults when the function is defined.
    The Q13/Q17 cells bind ``n_star`` the same way. A later reassignment of the
    module-level EPS_MIN / EPS_MAX / N_STAR therefore cannot silently change the
    schedule of an already-constructed player.

    Args:
        n: Counter supplied by DeepQPlayer (presumably the number of games
            played so far — confirm).
    """
    return max(eps_min, eps_max * (1 - n / n_star))


# The schedule is identical for every opponent, so it is built once.
for eps in epsilons:
    other_player = OptimalPlayer(epsilon=eps)
    q_player = DeepQPlayer(epsilon=epsilon_schedule)
    play(other_player, q_player, episodes=20000)

#### Question 16

In [None]:
epsilons = [0.001, 0.01, 0.1, 0.2, 0.3, 0.5]

for eps in epsilons:
    # Self-practice: both players share one replay memory and one pair of
    # networks. q_player2 is created with log=False and do_optimize=False.
    memory = ReplayMemory()
    policy_net = DeepQNetwork()
    target_net = DeepQNetwork()
    # The two networks are initialised independently. Without this copy the
    # target network holds random weights until its first periodic sync from
    # the policy network, which yields noisy TD targets early in training.
    # Standard DQN starts the target as an exact copy. This is harmless if
    # DeepQPlayer already syncs them on construction.
    target_net.load_state_dict(policy_net.state_dict())
    q_player1 = DeepQPlayer(epsilon=eps, policy_net=policy_net, target_net=target_net, memory=memory)
    q_player2 = DeepQPlayer(epsilon=eps, policy_net=policy_net, target_net=target_net, memory=memory, log=False, do_optimize=False)
    play(q_player1, q_player2, episodes=20000)

#### Question 17

In [None]:
EPS_MIN = 0.1
EPS_MAX = 0.8
n_stars = [1, 50, 100, 500, 1000, 5000, 10000, 20000, 40000]

for n_star in n_stars:
    # Self-practice with a decaying exploration rate. Both players share one
    # replay memory and one pair of networks; q_player2 is created with
    # log=False and do_optimize=False.
    memory = ReplayMemory()
    policy_net = DeepQNetwork()
    target_net = DeepQNetwork()
    # Start the target network as an exact copy of the policy network (standard
    # DQN initialisation). Otherwise it keeps random weights until the first
    # periodic sync. This is harmless if DeepQPlayer already syncs them.
    target_net.load_state_dict(policy_net.state_dict())
    # n_star is frozen as a default argument so the schedule keeps this
    # iteration's value.
    epsilon = lambda n, n_star=n_star: max(EPS_MIN, EPS_MAX * (1 - n / n_star))
    q_player1 = DeepQPlayer(epsilon=epsilon, policy_net=policy_net, target_net=target_net, memory=memory)
    q_player2 = DeepQPlayer(epsilon=epsilon, policy_net=policy_net, target_net=target_net, memory=memory, log=False, do_optimize=False)
    play(q_player1, q_player2, episodes=20000)