In [1]:
from dataclasses import dataclass
import random

import numpy as np
import torch

from c_tictactoe_py import TicTacToeEnvPy, Settings

In [2]:
class MLP(torch.nn.Module):
    """Three-layer fully connected network with ReLU activations.

    Maps a batch of feature vectors to one output value per action. The
    defaults (18 inputs, 128 hidden units, 18 outputs) reproduce the sizes
    that were previously hard-coded.

    Args:
        in_features: Width of each input vector.
        hidden_features: Width of both hidden layers.
        out_features: Number of output values per input vector.
    """

    def __init__(self, in_features: int = 18, hidden_features: int = 128, out_features: int = 18):
        super().__init__()

        # Layer names and creation order are kept so existing state dicts and
        # seeded initialisation stay compatible.
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.fc2 = torch.nn.Linear(hidden_features, hidden_features)
        self.fc3 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return the raw (un-activated) outputs for ``x`` of shape ``(..., in_features)``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the last layer: the outputs are unbounded values.
        return self.fc3(x)

In [21]:
@dataclass
class Params:
    """Hyper-parameters and run settings shared by the training code in this notebook."""
    # Learning rate handed to the Adam optimizer.
    lr: float
    # Discount factor applied to the next-state value when building Q targets.
    gamma: float
    # Number of buffer time steps drawn per optimisation step.
    batch_size: int
    # Number of time steps the replay buffer holds before it wraps around.
    buffer_size: int
    # Number of parallel environments; passed to both the environment settings
    # and the replay buffer.
    n_envs: int
    # Number of iterations of the main interaction loop.
    num_steps: int
    # NOTE(review): not read by any code visible in this notebook — confirm
    # whether it is still needed.
    num_episodes: int
    # Probability of taking a random action instead of the greedy one.
    epsilon: float
    # Torch device string the model and state tensors are placed on.
    device: str

@dataclass
class Batch:
    """A mini-batch drawn from ``Buffer`` in (S, A, R, S', D) order.

    As produced by ``Buffer.sample`` every array has a leading ``batch_size``
    axis followed by an ``n_envs`` axis; the two state arrays carry one extra
    trailing feature axis.
    """

    # State-buffer entries at the sampled time steps.
    states: np.ndarray
    # Actions stored at the sampled time steps.
    actions: np.ndarray
    # Rewards stored at the sampled time steps.
    rewards: np.ndarray
    # State-buffer entries one time step after the sampled ones.
    next_states: np.ndarray
    # Done flags stored at the sampled time steps.
    dones: np.ndarray


class Buffer:
    """Fixed-size ring buffer of per-step data for ``n_envs`` parallel environments.

    Only one state is kept per time step; ``sample`` uses the state stored at
    the *following* step as the successor state. For the sampled pairs to be
    real transitions, ``store`` must therefore be called once per environment
    step with the state the action was taken in.

    All arrays are ``int16``; values written through ``store`` are cast to
    that dtype, so fractional rewards would be truncated.

    Args:
        buffer_size: Number of time steps kept before the oldest is overwritten.
        n_envs: Number of parallel environments recorded per time step.
        state_dim: Length of one environment's state vector (default 18).
    """

    def __init__(self, buffer_size: int, n_envs: int, state_dim: int = 18):
        self.buffer_idx = 0  # slot the next call to store() writes to
        self.buffer_size = buffer_size
        self.full = False  # True once the write index has wrapped at least once
        self.n_envs = n_envs
        self.state_dim = state_dim
        self.state_buffer = np.zeros((self.buffer_size, self.n_envs, self.state_dim), dtype=np.int16)
        self.reward_buffer = np.zeros((self.buffer_size, self.n_envs), dtype=np.int16)
        self.done_buffer = np.zeros((self.buffer_size, self.n_envs), dtype=np.int16)
        self.action_buffer = np.zeros((self.buffer_size, self.n_envs), dtype=np.int16)

    def store(self, state, action, reward, done):
        """Write one time step for all environments and advance the write index.

        Args:
            state: Values assignable to a ``(n_envs, state_dim)`` slot.
            action: Values assignable to a ``(n_envs,)`` slot.
            reward: Values assignable to a ``(n_envs,)`` slot.
            done: Values assignable to a ``(n_envs,)`` slot.
        """
        self.state_buffer[self.buffer_idx] = state
        self.action_buffer[self.buffer_idx] = action
        self.reward_buffer[self.buffer_idx] = reward
        self.done_buffer[self.buffer_idx] = done
        self.buffer_idx += 1
        if self.buffer_idx >= self.buffer_size:
            # Wrap around: from now on the oldest step is overwritten first.
            self.buffer_idx = 0
            self.full = True

    def sample(self, batch_size) -> Batch:
        """Draw ``batch_size`` time steps uniformly, with replacement (SARS'D).

        Only time steps whose successor is also stored are eligible. The most
        recently written step is excluded: the slot after it is either still
        unwritten or, once the buffer has wrapped, holds the *oldest* step.

        Args:
            batch_size: Number of time steps to draw.

        Returns:
            A ``Batch`` whose arrays have leading shape ``(batch_size, n_envs)``.

        Raises:
            ValueError: If fewer than ``batch_size`` steps have been stored
                (before the first wrap-around) or if no stored step has a
                successor yet.
        """
        if not self.full and self.buffer_idx < batch_size:
            raise ValueError("Not enough samples in buffer")

        stored = self.buffer_size if self.full else self.buffer_idx
        eligible = stored - 1  # the newest step has no successor yet
        if eligible < 1:
            raise ValueError("Not enough samples in buffer")

        # Count offsets from the oldest stored step so the excluded newest
        # step is always the last one, wherever the write index currently is.
        oldest = self.buffer_idx if self.full else 0
        offsets = np.random.randint(0, eligible, batch_size)
        indices = (oldest + offsets) % self.buffer_size
        next_indices = (indices + 1) % self.buffer_size

        # SARSD
        return Batch(
            self.state_buffer[indices],
            self.action_buffer[indices],
            self.reward_buffer[indices],
            self.state_buffer[next_indices],
            self.done_buffer[indices]
        )

In [22]:
def optimize_model(model: torch.nn.Module, optimizer: torch.optim.Optimizer, buffer: "Buffer", params: "Params") -> float:
    """Run one Q-learning gradient step on a mini-batch drawn from ``buffer``.

    The sampled ``(batch_size, n_envs)`` axes are flattened into a single
    batch of ``N = batch_size * n_envs`` transitions. The regression target
    for each transition is ``reward + gamma * max_a Q(next_state, a)``, with
    the bootstrap term dropped for transitions flagged as done.

    Args:
        model: Network mapping ``(N, state_dim)`` float states to
            ``(N, n_actions)`` Q-values.
        optimizer: Optimizer holding ``model``'s parameters.
        buffer: Replay buffer; ``buffer.sample(params.batch_size)`` must
            return an object with ``states``, ``actions``, ``rewards``,
            ``next_states`` and ``dones`` arrays.
        params: Settings object; ``batch_size``, ``gamma`` and ``device``
            are read.

    Returns:
        The Huber (smooth L1) loss of this step as a Python float.

    Raises:
        ValueError: Propagated from ``buffer.sample`` when the buffer does
            not yet hold enough data.
    """
    batch = buffer.sample(params.batch_size)
    device = params.device
    state_dim = batch.states.shape[-1]

    # The buffer stores integers; the network needs floats and ``gather``
    # needs an int64 index, so every array is converted explicitly.
    states = torch.as_tensor(batch.states, dtype=torch.float32, device=device).reshape(-1, state_dim)
    actions = torch.as_tensor(batch.actions, dtype=torch.int64, device=device).reshape(-1, 1)
    rewards = torch.as_tensor(batch.rewards, dtype=torch.float32, device=device).reshape(-1)
    next_states = torch.as_tensor(batch.next_states, dtype=torch.float32, device=device).reshape(-1, state_dim)
    # A transition is non-final when its done flag is 0.
    non_final_mask = torch.as_tensor(batch.dones, device=device).reshape(-1) == 0

    # Q(s, a) for the action that was actually taken: shape (N,).
    state_action_values = model(states).gather(1, actions).squeeze(1)

    # max_a Q(s', a), forced to zero where the episode ended. Targets are
    # treated as constants, hence no gradient tracking.
    with torch.no_grad():
        best_next_values = model(next_states).max(1).values
        next_state_values = torch.where(
            non_final_mask, best_next_values, torch.zeros_like(best_next_values)
        )
    # Compute the expected Q values; both operands are (N,), so nothing
    # broadcasts to (N, N).
    expected_state_action_values = rewards + params.gamma * next_state_values

    # Compute Huber loss
    criterion = torch.nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(model.parameters(), 100)
    optimizer.step()

    return loss.item()

In [None]:
# Seed every RNG this cell (and the replay buffer's sampling) draws from so
# that a fresh-kernel run is repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Number of discrete actions; matches the network's default output width.
N_ACTIONS = 18

params = Params(
    lr=1e-3,
    gamma=0.99,
    batch_size=32,
    buffer_size=10_000,
    n_envs=1,
    num_episodes=1000,
    # Fall back to the CPU so the notebook also runs on machines without a GPU.
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_steps=500,
    epsilon=0.1
)
vec_env = TicTacToeEnvPy(Settings(batch_size=params.n_envs))
vec_env.reset_all()
model = MLP().to(params.device)
buffer = Buffer(params.buffer_size, params.n_envs)
optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
steps_done = 0

for step in range(params.num_steps):
    # Snapshot the state the action is chosen in. torch.tensor() copies, so
    # the snapshot is unaffected by the environment step below.
    # NOTE(review): assumes vec_env.game_states holds the current pre-step
    # state of every environment — confirm against the environment bindings.
    state = torch.tensor(vec_env.game_states, dtype=torch.float32, device=params.device)

    # Epsilon-greedy: one coin flip per step decides for all environments.
    if random.random() < params.epsilon:
        action = torch.randint(0, N_ACTIONS, (params.n_envs,), dtype=torch.int16)
    else:
        # Action selection needs no gradients.
        with torch.no_grad():
            action = model(state).argmax(1)
    action = action.numpy(force=True).astype(np.int16)
    observations, rewards, dones, _ = vec_env.step(action)

    # Store the transition in memory. The buffer treats the state stored at
    # step t + 1 as the successor of step t, so the *pre-step* state has to be
    # stored with the action, not the post-step observations.
    buffer.store(state.numpy(force=True), action, rewards, dones)

    # TODO: the optimisation step is still disabled. Once the buffer holds at
    # least params.batch_size steps, call
    # optimize_model(model, optimizer, buffer, params) here.

print('Complete')