<a href="https://colab.research.google.com/github/dsaragih/RL-games/blob/main/RoGoU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
pip install tensordict==0.2.0
pip install torchrl==0.2.0

In [2]:
# @title Imports
import gym
from gym import spaces
import pygame
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os

from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

In [17]:
# @title Env
class Rogou(gym.Env):
    """Two-player race game on a 3 x 8 board (presumably the Royal Game of Ur).

    Board encoding (``self.board``, indexed ``[y, x]``):

    * ``0`` is an empty square, ``1..7`` are cross pieces (player ``1``) and
      ``-1..-7`` are circle pieces (player ``-1``).
    * Cell ``[0, 4]`` is not a playable square; it stores the current dice
      roll multiplied by 10 (``0, 10, 20, 30, 40``).
    * Squares ``x in {4, 5}`` on rows 0 and 2 are not playable.

    Path of a piece, as implemented by ``_get_new_pos``: it enters its own
    start row (row 0 for crosses, row 2 for circles) at ``x = 4 - roll``,
    walks towards ``x = 0``, turns onto the shared middle row, walks towards
    ``x = 7``, turns back onto its start row and is borne off by landing
    exactly on ``x = 5``.

    An action is the index (``0..6``) of the piece to move. If the chosen
    piece cannot move, a random movable piece is moved instead and the step
    is penalised.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, starting_player=1):
        """Create the environment.

        Args:
            render_mode: ``None``, ``"human"`` or ``"rgb_array"``.
            starting_player: ``1`` (crosses) or ``-1`` (circles); rewards are
                reported from this player's point of view.
        """
        self.window_size = 512  # The size of the PyGame window

        # Piece ids lie in [-7, 7]; the dice cell [0, 4] holds 10 * roll, so
        # the upper bound has to be 40 for observations to be inside the box.
        self.observation_space = spaces.Box(-7, 40, shape=(3, 8), dtype=int)

        # The agent has no control over the number of steps. It can only
        # choose which of its seven pieces to move.
        self.action_space = spaces.Discrete(7)
        self.starting_player = starting_player
        self.player = starting_player  # 1 for crosses, -1 for circles
        self.prev_step = 0
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        # If human-rendering is used, `self.window` will be a reference to the
        # window that we draw to and `self.clock` a clock used to keep the
        # framerate stable. Both stay `None` until human-mode is first used.
        self.window = None
        self.clock = None

    def _piece_board(self):
        """Return a copy of the board with the dice cell ``[0, 4]`` cleared."""
        pieces = self.board.copy()
        pieces[0, 4] = 0
        return pieces

    def _get_obs(self):
        """Return a copy of the raw board (including the dice cell)."""
        return self.board.copy()

    def _get_info(self):
        """Return auxiliary info: side to move, piece counts and home lists."""
        pieces = self._piece_board()
        return {
            "player": self.player,
            # Counted on the piece-only board so the dice value stored in
            # [0, 4] is not mistaken for a cross.
            "n_crosses": np.count_nonzero(pieces > 0),
            "n_circles": np.count_nonzero(pieces < 0),
            "crosses_home": self.crosses_home,
            "circles_home": self.circles_home,
        }

    def _check_win(self):
        """Return True if every piece of the current player has been borne off."""
        home_list = self.crosses_home if self.player == 1 else self.circles_home
        return bool(home_list.all())

    def _get_step_size(self):
        """Decode the current dice roll from the dice cell."""
        return self.board[0, 4] // 10

    def _set_step_size(self):
        """Roll the dice and store ``10 * roll`` in the dice cell.

        The roll follows the distribution of the sum of four fair binary
        dice (0..4). NOTE(review): this draws from the global ``np.random``
        state, so ``reset(seed=...)`` does not make the rolls reproducible.
        """
        step_size = np.random.choice(5, 1, p=[0.0625, 0.25, 0.375, 0.25, 0.0625])[0]
        self.board[0, 4] = int(step_size) * 10

    def reset(self, seed=None, options=None):
        """Clear the board, reset both home lists and roll the first dice.

        Returns:
            ``(observation, info)``.
        """
        # We need the following line to seed self.np_random
        super().reset(seed=seed)

        # Reset the crosses and circles locations
        self.board = np.zeros((3, 8), dtype=int)
        self.crosses_home = np.zeros(7, dtype=int)
        self.circles_home = np.zeros(7, dtype=int)
        self.player = self.starting_player
        self._set_step_size()
        self.prev_step = 0

        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, info

    def _is_rosette(self, x, y):
        """Return True if ``(x, y)`` is one of the five rosette squares."""
        return (x, y) in ((0, 0), (0, 2), (6, 0), (6, 2), (3, 1))

    def _check_valid_square(self, x, y):
        """Return True if the current player may land on ``(x, y)``.

        A square is valid when it is playable and either empty or occupied
        by an opposing piece that is not protected by a rosette.
        """
        if (x == 4 or x == 5) and (y == 0 or y == 2):
            # Unplayable squares (one of them doubles as the dice cell).
            return False
        elif x < 0 or x > 7 or y < 0 or y > 2:
            return False
        elif self.board[y, x] == 0:
            return True
        # same sign as self.player
        elif self.board[y, x] * self.player > 0:
            return False
        # opposite sign as self.player
        elif self.board[y, x] * self.player < 0:
            return not self._is_rosette(x, y)

    def _is_home(self, x, y):
        """Return True if ``(x, y)`` is the current player's bear-off square."""
        if self.player == 1:
            return x == 5 and y == 0
        else:
            return x == 5 and y == 2

    def _get_new_pos(self, curr_x, curr_y, step_size):
        """Compute where a piece of the current player would land.

        Args:
            curr_x, curr_y: current square, or ``None`` when the piece is
                not on the board.
            step_size: the dice roll.

        Returns:
            ``(new_x, new_y)`` for a legal move, ``(-1, -1)`` when the piece
            is borne off, or ``(None, None)`` when the move is illegal.
        """
        start_row = 0 if self.player == 1 else 2
        if curr_x is None:
            # Piece is not on the board: it enters on the start row.
            new_x = 4 - step_size
            new_y = start_row
            if self._check_valid_square(new_x, new_y):
                return new_x, new_y

        elif curr_x > 5 and curr_y == start_row:
            # Final stretch on the start row, heading towards x == 5.
            new_x = curr_x - step_size
            new_y = start_row
            if new_x > 5 and self._check_valid_square(new_x, new_y):
                return new_x, new_y
            elif self._is_home(curr_x - step_size, curr_y):
                # Piece is home
                return -1, -1

        elif curr_x < 4 and curr_y == start_row:
            # Opening stretch on the start row, heading towards x == 0.
            if curr_x - step_size >= 0:
                new_x = curr_x - step_size
                new_y = start_row
            else:
                # Piece makes it to the next row
                new_x = step_size - curr_x - 1
                new_y = 1
            if self._check_valid_square(new_x, new_y):
                return new_x, new_y

        else:
            # Middle row, heading towards x == 7 and then back onto the
            # start row.
            new_x = curr_x + step_size
            if new_x > 7 or new_x < 0:
                new_x = 8 - (new_x - 7)
                new_y = start_row
            else:
                new_y = 1

            if self._check_valid_square(new_x, new_y):
                return new_x, new_y
            elif self._is_home(new_x, new_y):
                # Piece is home
                return -1, -1

        return None, None

    def _check_possible_move(self, id, step_size):
        """Locate piece ``id`` and evaluate moving it by ``step_size``.

        Returns:
            ``(is_possible, curr_x, curr_y, new_x, new_y)``; the current
            coordinates are ``None`` when the piece is not on the board.
        """
        curr_y, curr_x = np.where(self.board == id)
        if curr_x.size == 0:
            curr_x, curr_y = None, None
        else:
            curr_x, curr_y = curr_x[0], curr_y[0]
        new_x, new_y = self._get_new_pos(curr_x, curr_y, step_size)
        return new_x is not None, curr_x, curr_y, new_x, new_y

    def step(self, action):
        """Move one piece of the current player by the current dice roll.

        Args:
            action: index (``0..6``) of the piece the player wants to move.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            reward is signed from the starting player's point of view:
            ``10`` for bearing off the last piece, ``1`` for bearing off the
            chosen piece, ``-0.1`` when the chosen piece could not move (or
            was already home) and ``0`` otherwise.
        """
        assert self.action_space.contains(action)
        # Note that a player must move a piece if possible

        home_list = self.crosses_home if self.player == 1 else self.circles_home
        chose_impossible = bool(home_list[action])
        step_size = self._get_step_size()

        # Index of the piece that actually moves; differs from `action` when
        # a forced random move replaces an impossible choice.
        piece = action
        if step_size == 0:
            # Rolled a zero: nothing moves, the turn simply passes.
            is_possible, curr_x, curr_y, new_x, new_y = True, None, None, None, None
        elif not chose_impossible:
            is_possible, curr_x, curr_y, new_x, new_y = self._check_possible_move(
                self.player * (action + 1), step_size
            )
        else:
            is_possible, curr_x, curr_y, new_x, new_y = False, None, None, None, None

        if not is_possible:
            # The chosen piece cannot move: try the player's remaining
            # pieces in random order and move the first one that can.
            candidates = np.where(home_list == 0)[0]
            np.random.shuffle(candidates)
            for candidate in candidates:
                is_possible, curr_x, curr_y, new_x, new_y = self._check_possible_move(
                    self.player * (candidate + 1), step_size
                )
                if is_possible:
                    piece = candidate
                    break
            if not is_possible:
                # No possible moves at all: pass the turn with a fresh roll.
                self.prev_step = 0
                self._set_step_size()
                self.player = -self.player
                return self._get_obs(), 0, False, False, self._get_info()
            chose_impossible = True

        if curr_x is not None and curr_y is not None:
            self.board[curr_y, curr_x] = 0

        if new_x == -1 and new_y == -1:
            # Piece is home. Mark the piece that actually moved, not the
            # originally requested one.
            home_list[piece] = 1
        elif new_x is not None and new_y is not None:
            # Piece moved (overwriting, i.e. capturing, any opposing piece).
            self.board[new_y, new_x] = self.player * (piece + 1)

        # An episode is done once the current player has all pieces home.
        terminated = bool(home_list.sum() == 7)

        # Rewards are signed from the starting player's perspective.
        s = 1 if self.player == self.starting_player else -1
        if terminated:
            reward = 10 * s
        elif not chose_impossible and home_list[action] == 1:
            # Got the chosen piece home
            reward = 1 * s
        elif chose_impossible:
            reward = -0.1 * s
        else:
            reward = 0

        # Landing on a rosette grants another turn; otherwise switch sides.
        if not self._is_rosette(new_x, new_y):
            self.player = -self.player

        self.prev_step = step_size
        self._set_step_size()
        observation = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return observation, reward, terminated, False, info

    def render(self):
        """Return an RGB frame in ``"rgb_array"`` mode, otherwise ``None``."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _draw_cross(self, canvas, center_x, center_y, size, color, line_width=3):
        """Draw a cross centered at (center_x, center_y) with a given size and color."""
        pad = size / 10
        half = size / 2 - pad
        pygame.draw.line(
            canvas,
            color,
            (center_x - half, center_y - half),
            (center_x + half, center_y + half),
            width=line_width,
        )
        pygame.draw.line(
            canvas,
            color,
            (center_x + half, center_y - half),
            (center_x - half, center_y + half),
            width=line_width,
        )

    def _draw_circle(self, canvas, center_x, center_y, size, color, line_width=3):
        """Draw a circle centered at (center_x, center_y) with a given size and color."""
        pad = size / 10
        pygame.draw.circle(
            canvas,
            color,
            (center_x, center_y),
            int(size / 2 - pad),
            width=line_width,
        )

    def _render_frame(self):
        """Draw the current position.

        In human mode the frame is shown in the PyGame window and ``None``
        is returned; otherwise the frame is returned as an ``(H, W, 3)``
        array.
        """
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode((self.window_size, self.window_size))
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        pix_square_size = (
            self.window_size / 8
        )  # The size of a single grid square in pixels

        pad = pix_square_size / 10
        width = 3

        # Pieces only: the dice cell must not be drawn or counted as a cross.
        pieces = self._piece_board()

        # Paint black rectangles over the unplayable squares (x == 4, 5 on
        # board rows 0 and 2; the board occupies screen rows 3..5).
        pygame.draw.rect(
            canvas,
            (0, 0, 0),
            (pix_square_size * 4, pix_square_size * 3, pix_square_size * 2, pix_square_size),
            width=0,
        )
        pygame.draw.rect(
            canvas,
            (0, 0, 0),
            (pix_square_size * 4, pix_square_size * 5, pix_square_size * 2, pix_square_size),
            width=0,
        )

        rosette_coords = [(0, 0), (0, 2), (6, 0), (6, 2), (3, 1)]
        for x, y in rosette_coords:
            # Paint flower if a rosette square
            pygame.draw.circle(
                canvas,
                (255, 0, 0),
                (
                    int(pix_square_size * (x + 0.5)),
                    int(pix_square_size * (y + 3 + 0.5)),
                ),
                int(pix_square_size / 4 - pad),
                width=width,
            )

        # Draw the crosses and circles that are on the board.
        for row in range(3, 6):
            for col in range(8):
                cell = pieces[row - 3, col]
                if cell > 0:
                    self._draw_cross(
                        canvas,
                        int(pix_square_size * (col + 0.5)),
                        int(pix_square_size * (row + 0.5)),
                        int(pix_square_size),
                        (0, 0, 0),
                    )
                elif cell < 0:
                    self._draw_circle(
                        canvas,
                        int(pix_square_size * (col + 0.5)),
                        int(pix_square_size * (row + 0.5)),
                        int(pix_square_size),
                        (0, 0, 0),
                    )

        # Finally, add some gridlines
        # Note x: 0 is the leftmost, y: 0 is the topmost
        for y in range(3, 7):
            pygame.draw.line(
                canvas,
                0,
                (0, pix_square_size * y),
                (self.window_size, pix_square_size * y),
                width=3,
            )
        for x in range(9):
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, pix_square_size * 3),
                (pix_square_size * x, pix_square_size * 6),
                width=3,
            )

        # Pieces that are neither home nor on the board are shown in a tray
        # above the board.
        crosses_left = 7 - self.crosses_home.sum() - np.count_nonzero(pieces > 0)
        circles_left = 7 - self.circles_home.sum() - np.count_nonzero(pieces < 0)
        for i in range(crosses_left):
            self._draw_cross(
                canvas,
                int(0.3 * pix_square_size * (i + 0.5) + 10),
                int(2 * pix_square_size * (0.5)),
                int(pix_square_size / 4),
                (0, 0, 0),
            )
        for i in range(circles_left):
            self._draw_circle(
                canvas,
                int(0.25 * pix_square_size * (i + 0.5) + 340),
                int(2 * pix_square_size * (0.5)),
                int(pix_square_size / 4),
                (0, 0, 0),
            )

        if self.render_mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # We need to ensure that human-rendering occurs at the predefined framerate.
            # The following line will automatically add a delay to keep the framerate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut down the PyGame window, if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop the stale handles so a later render re-creates them.
            self.window = None
            self.clock = None

In [4]:
# @title Neural Net
class PlayerNet(nn.Module):
  """Q-network pair (online + frozen target) over the 3 x 8 board.

  The dice roll stored in cell ``[0, 4]`` (as ``10 * roll``) is split off
  the board, embedded by ``step_encoder`` and concatenated with the board
  features after the first dense layer. ``step_encoder`` is shared by the
  online and the target branch.
  """

  def __init__(self, input_dim, output_dim):
    """Build both branches.

    Args:
      input_dim: ``(h, w)`` of the board.
      output_dim: number of actions (Q-values per state).
    """
    super().__init__()
    h, w = input_dim

    self.step_encoder = nn.Linear(1, 64)
    self.online = self._build_dense(h, w, output_dim)
    self.target = self._build_dense(h, w, output_dim)

    self.target.load_state_dict(self.online.state_dict())

    # The target branch is only ever updated by copying the online weights.
    for p in self.target.parameters():
      p.requires_grad = False

  def forward(self, input, model):
    """Return Q-values of shape ``(batch, output_dim)``.

    Args:
      input: batch of boards, shape ``(batch, h, w)``; not modified.
      model: ``"online"`` selects the online branch, anything else the
        target branch.
    """
    # Work on a private copy: the dice cell is zeroed below, and callers
    # (e.g. the double-DQN target) pass the same tensor to several forward
    # calls and must see the roll every time.
    board = input.float().clone()
    nn_model = self.online if model == "online" else self.target

    step_batch = (board[:, 0, 4] / 10).unsqueeze(1)
    board[:, 0, 4] = 0.
    step_emb = self.step_encoder(step_batch)

    features = board
    for i, layer in enumerate(nn_model):
      if i == 3:
        # Inject the roll embedding right before the second Linear layer,
        # whose input width (128) is 64 board features + 64 roll features.
        features = torch.cat([features, step_emb], dim=1)
      features = layer(features)
    return features

  def _build_cnn(self, h, w, output_dim):
    """Alternative convolutional body; not used by ``__init__``.

    NOTE(review): the flattened size ``(h-4)**2 * 64`` looks tailored to a
    different (square) input than the 3 x 8 board — verify before use.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, stride=1),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, stride=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear((h-4)** 2 * 64, 512),
        nn.ReLU(),
        nn.Linear(512, output_dim),
    )

  def _build_dense(self, h, w, output_dim):
    """Dense body as a ModuleList so ``forward`` can splice in the roll."""
    return nn.ModuleList([
        nn.Flatten(),
        nn.Linear(h * w, 64),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, output_dim)]
    )

In [14]:
# @title Model class
class Player:
  """Double-DQN agent: epsilon-greedy acting, replay buffer and learning."""

  def __init__(self, player, state_dim, action_dim, save_dir):
    """Set up the network, replay buffer and hyper-parameters.

    Args:
      player: 1 (X) or -1 (O); only used to name checkpoints.
      state_dim: board shape, e.g. (3, 8).
      action_dim: number of actions, e.g. 7 (one per piece).
      save_dir: ``pathlib.Path`` directory for checkpoints.
    """
    self.state_dim = state_dim
    self.action_dim = action_dim  # index of the piece to move
    self.player = player
    self.save_dir = save_dir

    self.exp_rate = 1
    self.exp_rate_decay = 0.9999975
    self.exp_rate_min = 0.2

    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.net = PlayerNet(self.state_dim, self.action_dim).float()
    self.net = self.net.to(device=self.device)

    self.curr_step = 0
    self.save_every = 5e4

    self.memory = TensorDictReplayBuffer(storage=LazyMemmapStorage(100000,
                                         device=torch.device('cpu')))
    self.batch_size = 32

    self.gamma = 0.9
    self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.001)
    self.loss_fn = torch.nn.SmoothL1Loss()

    self.burnin = 1e3  # min. experiences before training
    self.learn_every = 1  # no. of experiences between updates to Q_online
    self.sync_every = 1e3  # no. of experiences between Q_target & Q_online sync

  def act(self, state):
    """Pick an action epsilon-greedily and decay the exploration rate.

    Args:
      state: board array of shape ``state_dim``.

    Returns:
      ``(action_idx, from_model)`` where ``from_model`` tells whether the
      action came from the network rather than from random exploration.
    """
    if np.random.rand() < self.exp_rate:
      # Explore. With how the env is set up, choosing an invalid piece
      # makes it force a random move if one is possible.
      action_idx = np.random.randint(self.action_dim)
      from_model = False
    else:
      # Exploit. The copy makes the array contiguous (swapped states are
      # flipped views) before it is turned into a tensor.
      state = state.__array__().copy()
      state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
      with torch.no_grad():
        action_values = self.net(state, "online")
      action_idx = torch.argmax(action_values, dim=1).item()
      from_model = True

    self.exp_rate *= self.exp_rate_decay
    self.exp_rate = max(self.exp_rate_min, self.exp_rate)

    self.curr_step += 1
    return action_idx, from_model

  def cache(self, state, next_state, action, reward, done):
    """Store one transition in the replay buffer."""
    def first_if_tuple(x):
      return x[0] if isinstance(x, tuple) else x

    state = first_if_tuple(state).__array__()
    next_state = first_if_tuple(next_state).__array__()

    state = torch.tensor(state.copy())
    next_state = torch.tensor(next_state.copy())
    action = torch.tensor([action])
    # Fixed float dtype: rewards arrive as a mix of ints (0, 10) and floats
    # (-0.1), and the storage dtype should not depend on the first one seen.
    reward = torch.tensor([reward], dtype=torch.float32)
    done = torch.tensor([bool(done)])

    # Leave batch size unspecified (later fixed during sampling)
    self.memory.add(TensorDict({
        'state': state,
        'next_state': next_state,
        'action': action,
        'reward': reward,
        'done': done
    }, batch_size=[]))

  def recall(self):
    """Sample a batch; action, reward and done are returned as 1-D tensors."""
    batch = self.memory.sample(self.batch_size).to(self.device)
    state, next_state, action, reward, done = \
     (batch.get(key) for key in ("state", "next_state", "action", "reward", "done"))
    # Stored with a trailing unit dimension; flatten to (batch,) so they do
    # not broadcast to (batch, batch) against the per-sample Q-values.
    return state, next_state, action.reshape(-1), reward.reshape(-1), done.reshape(-1)

  def td_estimate(self, state, action):
    """Return Q_online(state, action) as a tensor of shape ``(batch,)``."""
    current_Q = self.net(state, model="online")
    rows = torch.arange(current_Q.shape[0], device=current_Q.device)
    return current_Q[rows, action.reshape(-1).long()]

  @torch.no_grad()
  def td_target(self, reward, next_state, done):
    """Return the double-DQN target, shape ``(batch,)``.

    The online branch selects the best next action and the target branch
    evaluates it.
    """
    next_state_Q = self.net(next_state, model="online")
    max_action = torch.argmax(next_state_Q, dim=1)
    target_Q = self.net(next_state, model="target")
    rows = torch.arange(target_Q.shape[0], device=target_Q.device)
    next_Q = target_Q[rows, max_action]

    reward = reward.reshape(-1).float()
    not_done = 1 - done.reshape(-1).float()
    return (reward + not_done * (self.gamma * next_Q)).float()

  def update_Q_online(self, td_e, td_t):
    """Take one optimiser step on the TD error and return the loss value."""
    loss = self.loss_fn(td_e, td_t)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    return loss.item()

  def sync_Q_target(self):
    """Copy the online weights into the target branch."""
    self.net.target.load_state_dict(self.net.online.state_dict())

  def learn(self):
    """Run the periodic sync/save and, if due, one learning update.

    Returns:
      ``(mean Q estimate, loss)``, or ``(None, None)`` when no update was
      performed on this call.
    """
    if self.curr_step % self.sync_every == 0:
      self.sync_Q_target()
    if self.curr_step % self.save_every == 0:
      self.save()
    if self.curr_step % self.learn_every != 0:
      return None, None
    if self.curr_step < self.burnin:
      return None, None

    # Sample
    state, next_state, action, reward, done = self.recall()

    # Get td_e, td_t
    td_e = self.td_estimate(state, action)
    td_t = self.td_target(reward, next_state, done)

    # Update params
    loss = self.update_Q_online(td_e, td_t)

    return (td_e.mean().item()), loss

  def save(self):
    """Write the network weights and exploration rate to ``save_dir``."""
    pre = "X" if self.player == 1 else "O"
    save_path = (
        self.save_dir / f"{pre}_player_net_{int(self.curr_step // self.save_every)}.chkpt"
    )
    torch.save(
        dict(model=self.net.state_dict(), exploration_rate=self.exp_rate),
        save_path,
    )
    print(f"PlayerNet saved to {save_path} at step {self.curr_step}")

In [6]:
# @title Logger
import numpy as np
import time, datetime
import matplotlib.pyplot as plt


class MetricLogger:
    """Accumulate per-step metrics, aggregate them per episode and report.

    ``record`` prints a summary, appends a row to ``<save_dir>/log`` and
    rewrites four moving-average plots in ``save_dir``.
    """

    def __init__(self, save_dir, pre="W"):
        """Create the log file (overwriting any existing one).

        Args:
            save_dir: ``pathlib.Path`` of an existing output directory.
            pre: file-name prefix for the plot images.
        """
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / f"{pre}_reward_plot.jpg"
        self.ep_lengths_plot = save_dir / f"{pre}_length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / f"{pre}_loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / f"{pre}_q_plot.jpg"

        # History metrics, one entry per finished episode
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, one entry per call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Add one environment step to the running episode totals.

        ``loss`` and ``q`` are ``None`` on steps without a learning update.
        """
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        # Compare against None rather than truthiness so that a genuine
        # loss of exactly 0.0 still counts as a learning update.
        if loss is not None:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        """Mark end of episode: store its totals and reset the accumulators."""
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)

        if self.curr_ep_loss_length == 0:
            # No learning update happened during this episode.
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        """Reset the accumulators of the current episode."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Report means over the last 50 episodes to stdout, log and plots."""
        look_behind = 50
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-look_behind:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-look_behind:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-look_behind:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-look_behind:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        # Each metric name maps to a `moving_avg_<metric>` history list and
        # a `<metric>_plot` output path set up in __init__.
        for metric in ["ep_lengths", "ep_avg_losses", "ep_avg_qs", "ep_rewards"]:
            plt.clf()
            plt.plot(getattr(self, f"moving_avg_{metric}"), label=f"moving_avg_{metric}")
            plt.legend()
            plt.savefig(getattr(self, f'{metric}_plot'))

In [None]:
# Report whether PyTorch can see a CUDA device.
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()
# Module-level environment; train_loop below reads this global.
env = Rogou(render_mode='rgb_array')

# Small technicality
def preprocess(state):
  """Return `state` as float32 with a new leading axis of length 1."""
  with_leading_axis = np.asarray(state)[np.newaxis, ...]
  return with_leading_axis.astype(np.float32)

def swap(state):
  """Return the board as seen from the other side.

  Piece signs are negated and the rows are flipped top-to-bottom. The value
  in cell [0, 4] is carried over unchanged to cell [0, 4] of the result, and
  cell [2, 4] of the result is cleared. The input array is not modified.
  """
  roll_cell = state[0, 4]
  mirrored = np.flipud(-1 * state)
  # After the flip the negated roll sits in [2, 4]; clear it and restore
  # the original value in [0, 4].
  mirrored[2, 4] = 0
  mirrored[0, 4] = roll_cell
  return mirrored

# Every run writes into its own timestamped directory; mkdir without
# exist_ok raises if the directory already exists.
save_dir = Path("tcheckpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)
# Side the single shared model is trained as; train_loop multiplies it with
# the side to move to decide when the board has to be swapped.
player = 1
model = Player(player=player, state_dim=(3, 8), action_dim=env.action_space.n, save_dir=save_dir)

# Number of self-play games passed to train_loop.
episodes = 2000
def train_loop(episodes):
    """Train the global `model` by self-play on the global `env`.

    The same model plays both sides. Whenever the side to move is not the
    model's own side (`player`), the board is swapped before it is shown to
    the model and the reward sign is flipped, so the model always sees
    itself as playing 'X'. Metrics are recorded every 20 episodes and after
    the final one.
    """
    logger = MetricLogger(save_dir)
    final_episode = episodes - 1

    for e in range(episodes):
        state, info = env.reset()
        done = False

        # Self-play: one iteration per environment step.
        while not done:
            # -1 when the side to move is the opponent of `player`.
            flip = info["player"] * player
            needs_swap = flip == -1

            if needs_swap:
                state = swap(state)

            # Run agent on the (possibly swapped) state
            action, _ = model.act(state)
            next_state, reward, done, _, info = env.step(action)

            if needs_swap:
                next_state = swap(next_state)
            # +1 reward whenever the model wins (either X or O)
            reward *= flip

            # Remember, learn, log
            model.cache(state, next_state, action, reward, done)
            q, loss = model.learn()
            logger.log_step(reward, loss, q)

            # Undo the swap so `state` is the env's own view again.
            state = swap(next_state) if needs_swap else next_state

        logger.log_episode()

        if e % 20 == 0 or e == final_episode:
            logger.record(episode=e, epsilon=model.exp_rate, step=model.curr_step)

# Kick off self-play training for the configured number of episodes.
train_loop(episodes)
# NOTE: stale call kept for reference; train_loop takes a single argument.
# train_loop(episodes, "O")

In [27]:
from IPython import display as ipythondisplay

def test_env(k: int=1, max_steps: int=18):
  """Play the global `model` against itself and collect rendered frames.

  Note that `model.act` is used as-is, so moves are still epsilon-greedy
  and the model's step counter and exploration rate advance.

  Args:
    k: render a frame every `k` steps.
    max_steps: maximum number of environment steps to play.

  Returns:
    List of PIL images: the initial position followed by the rendered steps.
  """
  env = Rogou(render_mode='rgb_array')
  state, info = env.reset()
  images = [Image.fromarray(env.render())]

  for i in range(max_steps):
      # The model always plays from the cross side; show it the swapped
      # board when circles are to move.
      if info["player"] == -1:
        state = swap(state)
      action, _ = model.act(state)
      state, reward, terminated, truncated, info = env.step(action)

      # Render screen every k steps
      if i % k == 0:
        images.append(Image.fromarray(env.render()))

      if terminated or truncated:
        print("DONE ", i)
        print("reward: ", reward)
        break

  env.close()

  return images

# Save GIF image: play a short game and write all frames to one animated GIF.
images = test_env()
image_file = 'short.gif'
# loop=0: loop forever, duration=3: display each frame for 3 ms
images[0].save(
    image_file, save_all=True, append_images=images[1:],  loop=0, duration=3)

  and should_run_async(code)


In [25]:
# prompt: save the current model
# Writes the network weights and the current exploration rate to a fixed
# file name in the working directory, using the same dict layout as
# Player.save().

torch.save(
        dict(model=model.net.state_dict(), exploration_rate=model.exp_rate),
        "model_2000.chkpt",
)