In [1]:
from contextlib import contextmanager
from dataclasses import dataclass
import time

import numpy as np
import torch
import matplotlib.pyplot as plt

from c_tictactoe_pvp_py import TicTacToeEnvPy, Settings

In [2]:
# random move stuff

In [3]:
def random_move(state: np.ndarray) -> int:
    # Reshape the first 18 columns into (N, 9, 2)
    squares = state[:, :18].reshape(-1, 9, 2)
    # A square is free if neither player 0 nor player 1 occupies it
    # i.e. both bits == 0
    free_squares = ~squares.any(axis=2)  # shape: (N, 9), True where square is (0,0)

    # For each row, pick one random free square
    moves = np.array([
        np.random.choice(np.where(row)[0])    # row has True/False for squares 0..8
        for row in free_squares
    ], dtype=np.int16)

    return moves

In [4]:
s = np.zeros((3, 19), dtype=np.uint8)
s[0, 0] = 1
s[1, 2] = 1
s[2, 0] = 1
s[2, 2] = 1
s

array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=uint8)

In [5]:
random_move(s)

array([2, 0, 2], dtype=int64)

In [6]:
# stepping stuff

In [7]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.network = torch.nn.Sequential(
            layer_init(torch.nn.Linear(19, 120)),
            torch.nn.LayerNorm(120),
            torch.nn.ReLU(),
            layer_init(torch.nn.Linear(120, 84)),
            torch.nn.LayerNorm(84),
            torch.nn.ReLU(),
            layer_init(torch.nn.Linear(84, 9)),
        )

    def forward(self, x):
        return self.network(x)

In [8]:
env = TicTacToeEnvPy(Settings(batch_size=3))
x_net = MLP()
o_net = MLP()
state, info = env.reset_all()
state

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int16)

In [20]:
n_step = 100
n_env = 3
state_space = 19
# buffers
x_observations = np.zeros((n_step + 1, n_env, state_space), dtype=np.float16)
x_rewards = np.zeros((n_step, n_env, 1), dtype=np.float32)
x_dones = np.zeros((n_step + 1, n_env, 1), dtype=np.int16)
x_actions = np.zeros((n_step, n_env, 1), dtype=np.int32)

o_observations = np.zeros((n_step + 1, n_env, state_space), dtype=np.float16)
o_rewards = np.zeros((n_step, n_env, 1), dtype=np.float32)
o_dones = np.zeros((n_step + 1, n_env, 1), dtype=np.int16)
o_actions = np.zeros((n_step, n_env, 1), dtype=np.int32)

x_ptr = np.zeros(n_env, dtype=np.int32)
o_ptr = np.zeros(n_env, dtype=np.int32)

In [21]:
state, info = env.reset_all()

In [23]:
# fill first state and done
# player_turns = state[:, -1]
# x_observations[0] = state

for _ in range(100):
    player_turns = state[:, -1].astype(int)

    x_envs = np.where(player_turns == 0)[0]  # e.g. [0, 2, 5, ...]
    o_envs = np.where(player_turns == 1)[0]

    # ----

    # actions alternative 1
    # x_actions = random_move(state).astype(np.int16)
    # o_actions = random_move(state).astype(np.int16)
    # actions = np.where(player_turns == 0, x_actions, o_actions)

    # acitons alternative 2
    x_act = random_move(state[x_envs]) if len(x_envs) > 0 else []
    o_act = random_move(state[o_envs]) if len(o_envs) > 0 else []

    # Build the "actions" array for all env
    actions = np.zeros(n_env, dtype=np.int16)
    if len(x_envs) > 0:
        actions[x_envs] = x_act
    if len(o_envs) > 0:
        actions[o_envs] = o_act

    # ----

    next_state, reward, done, info = env.step(actions)

    # Record transitions for X
    if len(x_envs) > 0:
        # We do advanced indexing: 
        x_observations[x_ptr[x_envs], x_envs, :] = state[x_envs]
        x_actions[x_ptr[x_envs], x_envs] = actions[x_envs].reshape(-1, 1) # because actions storage has shape (n_step, n_env, 1)
        x_rewards[x_ptr[x_envs], x_envs] = reward[x_envs, 0].reshape(-1, 1)  # X's reward
        x_dones[x_ptr[x_envs], x_envs] = done[x_envs].reshape(-1, 1)
        # Increment the pointer for those envs
        x_ptr[x_envs] += 1

    # Record transitions for O
    if len(o_envs) > 0:
        o_observations[o_ptr[o_envs], o_envs, :] = state[o_envs]
        o_actions[o_ptr[o_envs], o_envs] = actions[o_envs].reshape(-1, 1)
        o_rewards[o_ptr[o_envs], o_envs] = reward[o_envs, 1].reshape(-1, 1)  # O's reward
        o_dones[o_ptr[o_envs], o_envs] = done[o_envs].reshape(-1, 1)
        o_ptr[o_envs] += 1

    state = next_state

# Main

In [25]:
@contextmanager
def time_block(label: str, timing_dict: dict[str, float]):
    start_time = time.time()
    yield
    end_time = time.time()
    timing_dict[label] = end_time - start_time

In [27]:
@dataclass
class Params:
    num_iterations: int
    lr: float
    gamma: float
    q_lambda: float
    num_envs: int
    num_steps: int
    # model_device: str
    # training_device: str
    update_epochs: int
    num_minibatches: int
    num_eval_steps: int
    start_epsilon: float
    end_epsilon: float
    epsilon_decay_fraction: float

In [28]:
def evaluate(env: TicTacToeEnvPy, x_net: MLP, o_net: MLP, n_eval_steps=30) -> tuple[float, float]:
    """
    Quick example of evaluating: we let each agent pick greedy moves
    for `n_eval_steps` steps in parallel, ignoring exploration.
    Then we measure average reward for X and O.
    """
    s, _ = env.reset_all()
    total_rX = 0.0
    total_rO = 0.0
    total_count = 0
    with torch.no_grad():
        for _ in range(n_eval_steps):
            player_turns = s[:, 18].astype(int)
            x_envs = np.where(player_turns == 0)[0]
            o_envs = np.where(player_turns == 1)[0]

            actions = np.zeros(env._batch_size, dtype=np.int16)

            if len(x_envs) > 0:
                x_tensor = torch.tensor(s[x_envs], dtype=torch.float32)
                x_q = x_net(x_tensor)
                x_actions = x_q.argmax(dim=1).cpu().numpy().astype(np.int16)
                actions[x_envs] = x_actions

            if len(o_envs) > 0:
                o_tensor = torch.tensor(s[o_envs], dtype=torch.float32)
                o_q = o_net(o_tensor)
                o_actions = o_q.argmax(dim=1).cpu().numpy().astype(np.int16)
                actions[o_envs] = o_actions

            s, r, d, _ = env.step(actions)
            total_rX += r[:,0].sum()  # summing X's rewards
            total_rO += r[:,1].sum()
            total_count += env._batch_size
        avg_rX = total_rX / total_count
        avg_rO = total_rO / total_count
    return avg_rX, avg_rO

In [29]:
def get_epsilon(params: Params, iteration: int):
    slope = (params.end_epsilon - params.start_epsilon) / (params.epsilon_decay_fraction * params.num_iterations)
    return max(slope * iteration + params.start_epsilon, params.end_epsilon)

In [30]:
params = Params(
    num_iterations=6,
    lr=6e-4,
    gamma=0.99,
    q_lambda=0.65,
    num_envs=2048,
    num_steps=26,
    model_device="cuda",
    training_device="cuda",
    update_epochs=4,
    num_minibatches=4,
    num_eval_steps=12,
    start_epsilon=1.0,
    end_epsilon=0.1,
    epsilon_end_fraction=0.5,
)
def train_two_agent(params: Params):
    vec_env = TicTacToeEnvPy(Settings(batch_size=params.num_envs))
    observations, _ = vec_env.reset_all()
    x_network = MLP().to(params.model_device)
    o_network = MLP().to(params.model_device)
    # optimizer = torch.optim.AdamW(x_network.parameters(), lr=params.lr, amsgrad=True)
    x_optimizer = torch.optim.RMSprop(x_network.parameters(), lr=params.lr)
    o_optimizer = torch.optim.RMSprop(x_network.parameters(), lr=params.lr)
    steps_done = 0
    losses = []
    evals = []
    timing_data = {}

    x_observations = np.zeros((params.num_steps + 1, params.num_envs, state_space), dtype=np.float16)
    x_rewards = np.zeros((params.num_steps, params.num_envs, 1), dtype=np.float32)
    x_dones = np.zeros((params.num_steps + 1, params.num_envs, 1), dtype=np.int16)
    x_actions = np.zeros((params.num_steps, params.num_envs, 1), dtype=np.int32)

    o_observations = np.zeros((params.num_steps + 1, params.num_envs, state_space), dtype=np.float16)
    o_rewards = np.zeros((params.num_steps, params.num_envs, 1), dtype=np.float32)
    o_dones = np.zeros((params.num_steps + 1, params.num_envs, 1), dtype=np.int16)
    o_actions = np.zeros((params.num_steps, params.num_envs, 1), dtype=np.int32)

    x_ptr = np.zeros(params.num_envs, dtype=np.int32)
    o_ptr = np.zeros(params.num_envs, dtype=np.int32)

    x_values = torch.zeros((params.num_steps, params.num_envs)).to(params.training_device)
    x_returns = torch.zeros_like(x_rewards).to(params.training_device)
    o_values = torch.zeros((params.num_steps, params.num_envs)).to(params.training_device)
    o_returns = torch.zeros_like(o_rewards).to(params.training_device)

    state_tensor = torch.zeros((params.num_envs, 18), device=params.model_device)

    batch_size = params.num_steps * params.num_envs // params.num_minibatches

    for iteration in range(params.num_iterations):
        with time_block("device_switch", timing_data):
            x_network.eval()
            x_network.to(params.model_device)
        # Eval
        with time_block("evaluation", timing_data):
            evals.append(eval(x_network, vec_env, params))

        epsilon = get_epsilon(params, iteration)
        
        # Rollout
        with time_block("rollout", timing_data):
            for step in range(params.num_steps):
                steps_done += 1
                state_tensor[:] = torch.from_numpy(observations.astype(np.float32)).to(device=params.model_device)

                random_actions = torch.randint(0, 9, (params.num_envs,), device=params.model_device)
                with torch.inference_mode():
                    x_q_values = x_network(state_tensor)
                    x_policy_actions = torch.argmax(x_q_values, dim=1)
                    x_values[step] = x_q_values[torch.arange(params.num_envs), x_policy_actions].flatten()

                    o_q_values = o_network(state_tensor)
                    o_policy_actions = torch.argmax(o_q_values, dim=1)
                    o_values[step] = o_q_values[torch.arange(params.num_envs), o_policy_actions].flatten()

                policy_actions = np.where(player_turns == 0, x_policy_actions, o_policy_actions)

                explore_mask = torch.rand(params.num_envs, device=params.model_device) < epsilon
                action = torch.where(explore_mask, random_actions, policy_actions)
                actions = action.unsqueeze(1)

                observations, reward, done, _ = vec_env.step(action.numpy(force=True).astype(np.int16)) 

                if len(x_envs) > 0:
                    # We do advanced indexing: 
                    x_observations[x_ptr[x_envs], x_envs, :] = state[x_envs]
                    x_actions[x_ptr[x_envs], x_envs] = actions[x_envs].reshape(-1, 1) # because actions storage has shape (n_step, n_env, 1)
                    x_rewards[x_ptr[x_envs], x_envs] = reward[x_envs, 0].reshape(-1, 1)  # X's reward
                    x_dones[x_ptr[x_envs], x_envs] = done[x_envs].reshape(-1, 1)
                    # Increment the pointer for those envs
                    x_ptr[x_envs] += 1

                # Record transitions for O
                if len(o_envs) > 0:
                    o_observations[o_ptr[o_envs], o_envs, :] = state[o_envs]
                    o_actions[o_ptr[o_envs], o_envs] = actions[o_envs].reshape(-1, 1)
                    o_rewards[o_ptr[o_envs], o_envs] = reward[o_envs, 1].reshape(-1, 1)  # O's reward
                    o_dones[o_ptr[o_envs], o_envs] = done[o_envs].reshape(-1, 1)
                    o_ptr[o_envs] += 1

            # add final observation
            # next_obs = torch.tensor(observations, dtype=torch.float32, device=params.model_device)
            # obs[-1] = next_obs
            # dones[-1] = done
        
        # Train
        with time_block("device_switch", timing_data):
            x_network.train()
            x_network.to(params.training_device)
        
        x_next_obs = x_observations[x_ptr - 1]
        
        # Q(lambda)
        with time_block("q_lambda", timing_data):
            with torch.no_grad():
                for t in reversed(range(params.num_steps)):
                    if t == params.num_steps - 1:
                        next_value, _ = torch.max(x_network(x_next_obs.to(params.training_device)), dim=-1)
                        nextnonterminal = 1.0 - done
                        returns[t] = rewards[t] + params.gamma * next_value * nextnonterminal
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_value = values[t + 1]
                        returns[t] = rewards[t] + params.gamma * (
                            params.q_lambda * returns[t + 1] + (1 - params.q_lambda) * next_value * nextnonterminal
                        )

        with time_block("optimize", timing_data):
            # flatten the batch
            b_obs = obs.reshape((-1, 18))
            b_actions = actions.reshape((-1, 1))
            b_returns = returns.reshape(-1)

            # Optimize
            inds = np.arange(params.num_steps * params.num_envs)
            for epoch in range(params.update_epochs):
                np.random.shuffle(inds)
                for batch_start in range(0, params.num_steps * params.num_envs, batch_size):
                    batch_inds = inds[batch_start : batch_start + batch_size]

                    old_val = x_network(b_obs[batch_inds]).gather(1, b_actions[batch_inds].long()).squeeze()
                    loss = torch.nn.functional.mse_loss(b_returns[batch_inds], old_val)

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(x_network.parameters(), max_norm=10.0)
                    optimizer.step()
                    losses.append(loss.detach().cpu().numpy())

    plt.plot(losses)
    plt.show()
    plt.plot(evals)
    plt.show()
    print(timing_data)

TypeError: zeros_like(): argument 'input' (position 1) must be Tensor, not numpy.ndarray

In [None]:
q_network(torch.zeros(1, 18, device=params.training_device))

In [None]:
# cake diagram of timing data
plt.figure(figsize=(10, 10))
plt.pie(timing_data.values(), labels=timing_data.keys(), autopct="%1.1f%%")
plt.show()