In [1]:
from dataclasses import dataclass

import numpy as np
import torch
import matplotlib.pyplot as plt

from c_tictactoe_py import TicTacToeEnvPy, Settings

In [2]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is filled with a (semi-)orthogonal matrix scaled by ``std`` and
    the bias, when the layer has one, is set to ``bias_const``.

    Args:
        layer: Module exposing ``weight`` and ``bias`` attributes
            (e.g. ``torch.nn.Linear``).
        std: Gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: Constant written to every bias element.

    Returns:
        The same ``layer`` object, so the call can be used inline while
        building a network.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` carry ``bias = None``; skip them
    # instead of failing inside ``constant_``.
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class MLP(torch.nn.Module):
    """Fully connected network mapping 18 input features to 9 outputs.

    Two hidden stages (widths 120 and 84), each a Linear layer followed by
    LayerNorm and ReLU, feed a final Linear output layer. Every Linear layer
    is passed through ``layer_init``.
    """

    def __init__(self):
        super().__init__()

        # Layers are created front to back, so construction order and the
        # positions inside the Sequential container match the forward order.
        widths = (18, 120, 84)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(layer_init(torch.nn.Linear(fan_in, fan_out)))
            stages.append(torch.nn.LayerNorm(fan_out))
            stages.append(torch.nn.ReLU())
        stages.append(layer_init(torch.nn.Linear(widths[-1], 9)))

        self.network = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.network(x)

In [3]:
@dataclass
class Params:
    """Hyper-parameters and run settings read by the evaluation and training code."""
    # Number of outer iterations (evaluate, roll out, optimise) of the training loop.
    num_iterations: int
    # Learning rate handed to the optimizer.
    lr: float
    # Discount factor applied to bootstrapped values in the return targets.
    gamma: float
    # Mixing weight between the recursive return and the one-step value in the return targets.
    q_lambda: float
    # Number of environments stepped in parallel (the env's batch size).
    num_envs: int
    # Number of environment steps collected per rollout.
    num_steps: int
    # Torch device string on which tensors and the model are placed.
    device: str
    # Number of passes over the shuffled indices in each optimisation phase.
    num_opt_steps: int
    # Number of indices taken per gradient step.
    batch_size: int
    # Number of environment steps taken by one evaluation run.
    num_eval_steps: int

In [4]:
def eval(model: MLP, env: TicTacToeEnvPy, params: Params):
    """Run the greedy policy of ``model`` on ``env`` and return its mean reward.

    The environment is reset, then stepped ``params.num_eval_steps`` times,
    each time with the highest-scoring action of the network for every state
    in the batch. The result is the mean of all rewards collected across
    every environment and every step.
    """
    reward_log = np.zeros((params.num_envs, params.num_eval_steps))
    with torch.no_grad():
        states, _ = env.reset_all()
        for step in range(params.num_eval_steps):
            batch = torch.tensor(states, device=params.device, dtype=torch.float32)
            q_values = model(batch)
            greedy = q_values.max(1).indices.flatten()
            # The environment is stepped with int16 action indices.
            chosen = greedy.numpy(force=True).astype(np.int16)
            states, step_reward, _, _ = env.step(chosen)
            reward_log[:, step] = step_reward
    return reward_log.mean()

In [None]:
params = Params(
    num_iterations=200,
    lr=2e-4,
    gamma=0.99,
    q_lambda=0.65,
    num_envs=24,
    num_steps=128,
    device="cpu",
    num_opt_steps=4,
    batch_size=32,
    num_eval_steps=100,
)

# Settings that are not part of Params.
SEED = 0
EXPLORE_EPS = 0.25  # probability of replacing the greedy action with a uniform random one
OBS_DIM = 18        # input width of the MLP
NUM_ACTIONS = 9     # output width of the MLP

# Seed the torch and numpy generators used below (network init, exploration,
# minibatch shuffling). NOTE(review): any randomness inside the environment
# itself is not controlled here.
torch.manual_seed(SEED)
np.random.seed(SEED)

vec_env = TicTacToeEnvPy(Settings(batch_size=params.num_envs))
observations, _ = vec_env.reset_all()
q_network = MLP().to(params.device)
optimizer = torch.optim.AdamW(q_network.parameters(), lr=params.lr, amsgrad=True)
steps_done = 0
losses = []
evals = []

# Rollout buffers, indexed [time step, environment]. `obs` and `dones` hold
# one extra row for the state reached after the last step.
obs = torch.zeros((params.num_steps + 1, params.num_envs, OBS_DIM), device=params.device)
actions = torch.zeros((params.num_steps, params.num_envs, 1), device=params.device)
rewards = torch.zeros((params.num_steps, params.num_envs), device=params.device)
dones = torch.zeros((params.num_steps + 1, params.num_envs), device=params.device)
values = torch.zeros((params.num_steps, params.num_envs), device=params.device)
returns = torch.zeros_like(rewards)

# Total number of (state, action, return) rows produced by one rollout.
num_transitions = params.num_steps * params.num_envs

for iteration in range(params.num_iterations):
    # Eval
    q_network.eval()
    evals.append(eval(q_network, vec_env, params))
    # eval() resets and steps the same vec_env, so the observations kept from
    # the previous rollout no longer describe the environment's state. Start
    # the rollout from a fresh reset instead of feeding stale observations.
    observations, _ = vec_env.reset_all()

    # Rollout
    done = torch.zeros(params.num_envs, dtype=torch.float32, device=params.device)
    for step in range(params.num_steps):
        steps_done += 1
        state_tensor = torch.tensor(observations.copy(), dtype=torch.float32, device=params.device)
        obs[step] = state_tensor
        dones[step] = done

        with torch.no_grad():
            q_values = q_network(state_tensor)
        # Greedy value and action per environment; the value is stored for
        # bootstrapping in the return computation below.
        greedy_values, policy_actions = q_values.max(dim=1)
        values[step] = greedy_values

        # Epsilon-greedy exploration. Both random tensors are created on
        # params.device so torch.where never mixes devices.
        random_actions = torch.randint(0, NUM_ACTIONS, (params.num_envs,), device=params.device)
        explore_mask = torch.rand(params.num_envs, device=params.device) < EXPLORE_EPS
        action = torch.where(explore_mask, random_actions, policy_actions)
        actions[step] = action.unsqueeze(1)

        observations, reward, done, _ = vec_env.step(action.numpy(force=True).astype(np.int16))
        rewards[step] = torch.tensor(reward, dtype=torch.float32, device=params.device)
        done = torch.tensor(done, dtype=torch.float32, device=params.device)
    # add final observation
    next_obs = torch.tensor(observations, dtype=torch.float32, device=params.device)
    obs[-1] = next_obs
    dones[-1] = done

    # Train
    q_network.train()
    # Q(lambda) targets, computed backwards in time.
    with torch.no_grad():
        bootstrap_value = q_network(next_obs).max(dim=-1).values
        for t in reversed(range(params.num_steps)):
            # dones[t + 1] is 1 where the step taken at time t ended the episode.
            nextnonterminal = 1.0 - dones[t + 1]
            if t == params.num_steps - 1:
                returns[t] = rewards[t] + params.gamma * bootstrap_value * nextnonterminal
            else:
                # The whole bootstrap term is masked: when the episode ended at
                # step t, returns[t + 1] belongs to the next episode and must
                # not leak into this target.
                returns[t] = rewards[t] + params.gamma * nextnonterminal * (
                    params.q_lambda * returns[t + 1] + (1 - params.q_lambda) * values[t + 1]
                )

    # Flatten the batch. The extra final observation has no action/return, so
    # it is dropped to keep all three tensors row-aligned.
    b_obs = obs[:-1].reshape((-1, OBS_DIM))
    b_actions = actions.reshape((-1, 1))
    b_returns = returns.reshape(-1)

    # Optimize. Indices span every collected transition (num_steps * num_envs),
    # not just num_steps, so the whole rollout is used for training. This gives
    # ceil(num_transitions / batch_size) gradient steps per epoch.
    inds = np.arange(num_transitions)
    for epoch in range(params.num_opt_steps):
        np.random.shuffle(inds)
        for batch_start in range(0, num_transitions, params.batch_size):
            batch_inds = inds[batch_start : batch_start + params.batch_size]

            # squeeze(1) keeps a 1-D result even for a minibatch of size one.
            old_val = q_network(b_obs[batch_inds]).gather(1, b_actions[batch_inds].long()).squeeze(1)
            loss = torch.nn.functional.mse_loss(old_val, b_returns[batch_inds])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(q_network.parameters(), max_norm=10.0)
            optimizer.step()
            losses.append(loss.item())

# Print a short summary rather than the full list of per-update losses.
if losses:
    print(f"updates: {len(losses)}, mean loss over last 100 updates: {np.mean(losses[-100:]):.6f}")

fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="gradient step", ylabel="MSE loss")
plt.show()

fig, ax = plt.subplots()
ax.plot(evals)
ax.set(title="Greedy-policy evaluation", xlabel="iteration", ylabel="mean reward per env step")
plt.show()

In [None]:
# Sanity check: show the Q-values the network assigns to an all-zero observation.
# The probe is created on params.device so the call also works when the model
# is not on the CPU.
q_network(torch.zeros(1, 18, device=params.device))