In [62]:
from contextlib import contextmanager
from dataclasses import dataclass
import time

import numpy as np
import torch
import matplotlib.pyplot as plt

from c_tictactoe_pvp_py import TicTacToeEnvPy, Settings

In [63]:
def random_move(state: np.ndarray) -> int:
    # Reshape the first 18 columns into (N, 9, 2)
    squares = state[:, :18].reshape(-1, 9, 2)
    # A square is free if neither player 0 nor player 1 occupies it
    # i.e. both bits == 0
    free_squares = ~squares.any(axis=2)  # shape: (N, 9), True where square is (0,0)

    # For each row, pick one random free square
    moves = np.array([
        np.random.choice(np.where(row)[0])    # row has True/False for squares 0..8
        for row in free_squares
    ], dtype=np.int16)

    return moves

In [36]:
s = np.zeros((3, 19), dtype=np.uint8)
s[0, 0] = 1
s[1, 2] = 1
s[2, 0] = 1
s[2, 2] = 1
s

array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=uint8)

In [37]:
random_move(s)

array([3, 2, 5], dtype=int16)

In [38]:
# stepping stuff

In [64]:
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.network = torch.nn.Sequential(
            layer_init(torch.nn.Linear(19, 120)),
            torch.nn.LayerNorm(120),
            torch.nn.ReLU(),
            layer_init(torch.nn.Linear(120, 84)),
            torch.nn.LayerNorm(84),
            torch.nn.ReLU(),
            layer_init(torch.nn.Linear(84, 9)),
        )

    def forward(self, x):
        return self.network(x)

In [40]:
env = TicTacToeEnvPy(Settings(batch_size=3))
x_net = MLP()
o_net = MLP()

In [41]:
n_step = 100
n_env = 3
# buffers
x_states = np.zeros((n_step, n_env, 19), dtype=np.float16)
x_actions = np.zeros((n_step, n_env), dtype=np.int32)
x_rewards = np.zeros((n_step, n_env), dtype=np.float32)
# x_next_states = np.zeros((n_step, n_env, 19), dtype=np.float16)
x_dones = np.zeros((n_step + 1, n_env), dtype=np.int16)

x_ptr = np.zeros(n_env, dtype=np.int32)
x_size = 0

o_states = np.zeros((n_step, n_env, 19), dtype=np.float16)
o_actions = np.zeros((n_step, n_env, 1), dtype=np.int32)
o_rewards = np.zeros((n_step, n_env, 1), dtype=np.float32)
# o_next_states = np.zeros((n_step, n_env, 19), dtype=np.float16)
o_dones = np.zeros((n_step + 1, n_env, 1), dtype=np.int16)

o_ptr = np.zeros(n_env, dtype=np.int32)
o_size = 0

In [42]:
state, info = env.reset_all()
reward = np.zeros((n_env, 2), dtype=np.float32)
done = np.zeros(n_env, dtype=np.int16)

for _ in range(100):
    player_turns = state[:, -1].astype(int)

    x_envs = np.where(player_turns == 0)[0] 
    o_envs = np.where(player_turns == 1)[0]

    # Sample action
    actions = np.random.randint(9, size=n_env).astype(np.int16)

    # Record transitions before we step, because we will maybe not get the same players reward, observations etc.
    if len(x_envs) > 0:
        for e in x_envs:
            t = x_ptr[e]  # the time‐index for X transitions in environment e

            # Fill the buffers at [t, e]
            x_states[t, e] = state[e]      
            x_actions[t, e] = actions[e] 
            x_rewards[t, e] = reward[e, 0]  
            x_dones[t, e] = done[e] 
            # Move the pointer
            x_ptr[e] += 1

    # Record transitions for O
    if len(o_envs) > 0:
        for e in o_envs:
            t = o_ptr[e]
            o_states[t, e] = state[e]
            o_actions[t, e] = actions[e]
            o_rewards[t, e] = reward[e, 1] 
            o_dones[t, e] = done[e]

            o_ptr[e] += 1

    next_state, reward, done, info = env.step(actions)

    state = next_state

# add final transitions (or maybe cut last?)
# TODO

# Main

In [65]:
@contextmanager
def time_block(label: str, timing_dict: dict[str, float]):
    start_time = time.time()
    yield
    end_time = time.time()
    timing_dict[label] = end_time - start_time

In [66]:
@dataclass
class Params:
    num_iterations: int
    lr: float
    gamma: float
    q_lambda: float
    num_envs: int
    num_steps: int
    model_device: str
    training_device: str
    update_epochs: int
    num_minibatches: int
    num_eval_steps: int
    start_epsilon: float
    end_epsilon: float
    epsilon_decay_fraction: float

In [67]:
def evaluate(env: TicTacToeEnvPy, x_net: MLP, o_net: MLP, n_eval_steps=30) -> tuple[float, float]:
    """
    Quick example of evaluating: we let each agent pick greedy moves
    for `n_eval_steps` steps in parallel, ignoring exploration.
    Then we measure average reward for X and O.
    """
    s, _ = env.reset_all()
    total_rX = 0.0
    total_rO = 0.0
    total_count = 0
    with torch.no_grad():
        for _ in range(n_eval_steps):
            player_turns = s[:, 18].astype(int)
            x_envs = np.where(player_turns == 0)[0]
            o_envs = np.where(player_turns == 1)[0]

            actions = np.zeros(env._batch_size, dtype=np.int16)

            if len(x_envs) > 0:
                x_tensor = torch.tensor(s[x_envs], dtype=torch.float32)
                x_q = x_net(x_tensor)
                x_actions = x_q.argmax(dim=1).cpu().numpy().astype(np.int16)
                actions[x_envs] = x_actions

            if len(o_envs) > 0:
                o_tensor = torch.tensor(s[o_envs], dtype=torch.float32)
                o_q = o_net(o_tensor)
                o_actions = o_q.argmax(dim=1).cpu().numpy().astype(np.int16)
                actions[o_envs] = o_actions

            s, r, d, _ = env.step(actions)
            total_rX += r[:,0].sum()  # summing X's rewards
            total_rO += r[:,1].sum()
            total_count += env._batch_size
        avg_rX = total_rX / total_count
        avg_rO = total_rO / total_count
    return avg_rX, avg_rO

In [68]:
def get_epsilon(params: Params, iteration: int):
    slope = (params.end_epsilon - params.start_epsilon) / (params.epsilon_decay_fraction * params.num_iterations)
    return max(slope * iteration + params.start_epsilon, params.end_epsilon)

In [158]:
params = Params(
    num_iterations=4,
    lr=6e-4,
    gamma=0.99,
    q_lambda=0.65,
    num_envs=6,
    num_steps=26,
    model_device="cpu",
    training_device="cpu",
    update_epochs=4,
    num_minibatches=4,
    num_eval_steps=12,
    start_epsilon=1.0,
    end_epsilon=0.1,
    epsilon_decay_fraction=0.5,
)
def train_two_agent(params: Params):
    vec_env = TicTacToeEnvPy(Settings(batch_size=params.num_envs))

    # networks and optimizers
    x_network = MLP().to(params.model_device)
    o_network = MLP().to(params.model_device)
    x_optimizer = torch.optim.RMSprop(x_network.parameters(), lr=params.lr)
    o_optimizer = torch.optim.RMSprop(x_network.parameters(), lr=params.lr)

    # X Buffers
    # No need to pad because we know the max size would be -1 from num steps, but probably way lower
    x_states = np.zeros((params.num_steps, params.num_envs, 19), dtype=np.float16)
    x_actions = np.zeros((params.num_steps, params.num_envs), dtype=np.int32)
    x_rewards = np.zeros((params.num_steps, params.num_envs), dtype=np.float32)
    x_dones = np.zeros((params.num_steps, params.num_envs), dtype=np.int16)
    
    x_ptr = np.zeros(params.num_envs, dtype=np.int32)

    x_values = torch.zeros((params.num_steps, params.num_envs)).to(params.training_device)
    x_returns = torch.zeros((params.num_steps, params.num_envs)).to(params.training_device)
    
    # O Buffers
    o_states = np.zeros((params.num_steps, params.num_envs, 19), dtype=np.float16)
    o_actions = np.zeros((params.num_steps, params.num_envs), dtype=np.int32)
    o_rewards = np.zeros((params.num_steps, params.num_envs), dtype=np.float32)
    o_dones = np.zeros((params.num_steps, params.num_envs), dtype=np.int16)

    o_ptr = np.zeros(params.num_envs, dtype=np.int32)

    o_values = torch.zeros((params.num_steps, params.num_envs)).to(params.training_device)
    o_returns = torch.zeros((params.num_steps, params.num_envs)).to(params.training_device)

    global_step = 0
    timing_data = {}
    losses = []
    x_evals = []
    o_evals = []

    # init state tensor prior to loop
    state_tensor = torch.zeros((params.num_envs, 19), device=params.model_device)
    # mini-batch size
    mb_size = params.num_steps * params.num_envs // params.num_minibatches

    for iteration in range(params.num_iterations):
        print(f"Iteration {iteration}")
        with time_block("device_switch", timing_data):
            x_network.eval()
            x_network.to(params.model_device)
            o_network.eval()
            o_network.to(params.model_device)
        # Eval
        with time_block("evaluation", timing_data):
            x_rew, o_rew = evaluate(vec_env, x_network, o_network, params.num_eval_steps)
            x_evals.append(x_rew)
            o_evals.append(o_rew)

        epsilon = get_epsilon(params, iteration)
        
        # Rollout
        with time_block("rollout", timing_data):
            
            # Reset the environments
            state, _ = vec_env.reset_all()
            reward = np.zeros((params.num_envs, 2), dtype=np.float32)
            done = np.zeros(params.num_envs, dtype=np.int16)

            for step in range(params.num_steps):
                state_tensor[:] = torch.from_numpy(state.astype(np.float32)).to(device=params.model_device)

                player_turns = state[:, 18].astype(int)
                x_envs = np.where(player_turns == 0)[0]
                o_envs = np.where(player_turns == 1)[0]

                actions = np.zeros(params.num_envs, dtype=np.int16)
                # random_actions = torch.randint(0, 9, (params.num_envs,), device=params.model_device)
            
                # X chooses actions for x_envs
                if len(x_envs) > 0:
                    # gather states
                    x_s = torch.tensor(state[x_envs], dtype=torch.float32)
                    with torch.no_grad():
                        q_vals = x_network(x_s)
                        best_acts = q_vals.argmax(dim=1).numpy().astype(np.int16)
                        x_values[x_ptr[x_envs], x_envs] = q_vals[torch.arange(len(x_envs)), best_acts]

                    # Epsilon-greedy
                    rand_acts = np.random.randint(0,9,size=len(x_envs)) # or random_move(s[x_envs])
                    mask = (np.random.rand(len(x_envs)) < epsilon)
                    final_acts = np.where(mask, rand_acts, best_acts)

                    actions[x_envs] = final_acts

                # O chooses actions for o_envs
                if len(o_envs) > 0:
                    o_s = torch.tensor(state[o_envs], dtype=torch.float32)
                    with torch.no_grad():
                        q_vals = o_network(o_s)
                        best_acts = q_vals.argmax(dim=1).numpy().astype(np.int16)
                        o_values[o_ptr[o_envs], o_envs] = q_vals[torch.arange(len(o_envs)), best_acts]

                    # Epsilon-greedy
                    rand_acts = random_move(state[o_envs])  
                    mask = (np.random.rand(len(o_envs)) < epsilon)
                    final_acts = np.where(mask, rand_acts, best_acts)

                    actions[o_envs] = final_acts

                # Record transitions for X
                if len(x_envs) > 0:
                    for e in x_envs:
                        t = x_ptr[e]  # the time‐index for X transitions in environment e

                        # Fill the buffers at [t, e]
                        # print(t, e)
                        # print(x_ptr)
                        # print(x_states.shape)
                        # print(state)
                        x_states[t, e] = state[e]
                        x_actions[t, e] = actions[e]
                        x_rewards[t, e] = reward[e, 0]
                        x_dones[t + 1, e] = done[e]

                        # Move the pointer
                        x_ptr[e] += 1

                # Record transitions for O
                if len(o_envs) > 0:
                    for e in o_envs:
                        t = o_ptr[e]
                        o_states[t, e] = state[e]
                        o_actions[t, e] = actions[e]
                        o_rewards[t, e] = reward[e, 1]
                        o_dones[t + 1, e] = done[e]

                        o_ptr[e] += 1
                
                state_next, reward, done, _ = vec_env.step(actions) 

                state = state_next
                global_step += 1

            # add final observation
            next_obs = torch.tensor(s, dtype=torch.float32, device=params.model_device)
            player_turns = s[:, 18].astype(int)
            x_envs = np.where(player_turns == 0)[0]
            o_envs = np.where(player_turns == 1)[0]
            
            if len(x_envs) > 0:
                x_states[x_ptr[x_envs], x_envs] = next_obs[x_envs]

            if len(o_envs) > 0:
                o_states[o_ptr[o_envs], o_envs] = next_obs[o_envs]
        
        # Train
        with time_block("device_switch", timing_data):
            x_network.train()
            x_network.to(params.training_device)
        
        x_next_obs = x_states[x_ptr - 1]
        
        # Q(lambda)
        with time_block("x_q_lambda", timing_data):
            # create transition masks based on ptr to avoid calculating values for padding
            print(x_ptr)
            row = np.arange(params.num_steps)
            arr = np.tile(row, (params.num_envs, 1)).T
            print(arr)
            x_mask = arr < x_ptr
            print(x_mask)

            with torch.no_grad():
                for t in reversed(range(params.num_steps)):
                    if t == params.num_steps - 1:
                        next_value, _ = torch.max(x_network(x_next_obs.to(params.training_device)), dim=-1)
                        nextnonterminal = 1.0 - done
                        x_returns[t] = x_rewards[t] + params.gamma * next_value * nextnonterminal
                    else:
                        nextnonterminal = 1.0 - x_dones[t + 1]
                        next_value = x_values[t + 1]
                        x_returns[t] = x_rewards[t] + params.gamma * (
                            params.q_lambda * x_returns[t + 1] + (1 - params.q_lambda) * next_value * nextnonterminal
                        )

        with time_block("train_X", timing_data):
            if x_ptr > 0:
                # flatten the batch
                # b_obs = obs.reshape((-1, 18))
                # b_actions = actions.reshape((-1, 1))
                # b_returns = returns.reshape(-1)

                # Optimize
                inds = np.arange(x_ptr)
                for epoch in range(params.update_epochs):
                    np.random.shuffle(inds)
                    for start in range(0, x_ptr, mb_size):
                        end = min(start + mb_size, x_ptr)
                        mb_inds = inds[start : start + mb_size]

                        mb_s = torch.tensor(x_states[mb_inds], dtype=torch.float32)
                        mb_a = torch.tensor(x_actions[mb_inds], dtype=torch.long)
                        mb_r = torch.tensor(x_rewards[mb_inds], dtype=torch.float32)
                        mb_s_next = torch.tensor(x_states[mb_inds + 1], dtype=torch.float32)
                        mb_done = torch.tensor(x_dones[mb_inds], dtype=torch.float32)

                        # 1-step TD target
                        with torch.no_grad():
                            q_next = o_network(mb_s_next)
                            q_next_max = q_next.max(dim=1).values
                            target = mb_r + params.gamma * (1 - mb_done) * q_next_max

                        old_val = x_network(b_obs[batch_inds]).gather(1, b_actions[batch_inds].long()).squeeze()
                        loss = torch.nn.functional.mse_loss(b_returns[batch_inds], old_val)

                        optimizer.zero_grad()
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(x_network.parameters(), max_norm=10.0)
                        optimizer.step()
                        losses.append(loss.detach().cpu().numpy())

    plt.plot(losses)
    plt.show()
    plt.plot(evals)
    plt.show()
    print(timing_data)
train_two_agent(params)

Iteration 0
[16 22 22 16 18 16]
[[ 0  0  0  0  0  0]
 [ 1  1  1  1  1  1]
 [ 2  2  2  2  2  2]
 [ 3  3  3  3  3  3]
 [ 4  4  4  4  4  4]
 [ 5  5  5  5  5  5]
 [ 6  6  6  6  6  6]
 [ 7  7  7  7  7  7]
 [ 8  8  8  8  8  8]
 [ 9  9  9  9  9  9]
 [10 10 10 10 10 10]
 [11 11 11 11 11 11]
 [12 12 12 12 12 12]
 [13 13 13 13 13 13]
 [14 14 14 14 14 14]
 [15 15 15 15 15 15]
 [16 16 16 16 16 16]
 [17 17 17 17 17 17]
 [18 18 18 18 18 18]
 [19 19 19 19 19 19]
 [20 20 20 20 20 20]
 [21 21 21 21 21 21]
 [22 22 22 22 22 22]
 [23 23 23 23 23 23]
 [24 24 24 24 24 24]
 [25 25 25 25 25 25]]
[[ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True  True]
 [ True  True  True  True  True

AttributeError: 'numpy.ndarray' object has no attribute 'to'

In [120]:
a = np.zeros((26, 4))
a[]

array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [None]:
q_network(torch.zeros(1, 18, device=params.training_device))

In [None]:
# cake diagram of timing data
plt.figure(figsize=(10, 10))
plt.pie(timing_data.values(), labels=timing_data.keys(), autopct="%1.1f%%")
plt.show()