In [32]:
import numpy as np
import torch
import ray
from tqdm import tqdm

import othello
from network import PVNet
from MCTS import MCTS
from buffer import ReplayBuffer

In [33]:
from dataclasses import dataclass

@ray.remote(num_cpus=1, num_gpus=0)
def selfplay(weights, num_mcts_simulations, dirichlet_alpha=0.35,
             num_exploration_moves=10):
    """Play one complete self-play game and return its training samples.

    Args:
        weights: Network weights, loaded into a fresh ``PVNet`` through
            ``set_weights``.
        num_mcts_simulations: Simulation count handed to ``MCTS.search`` for
            every move.
        dirichlet_alpha: Forwarded to the ``MCTS`` constructor as ``alpha``.
        num_exploration_moves: Number of opening plies for which the move is
            sampled from the MCTS policy. From that ply on, a move with the
            highest policy value is played (ties broken at random).

    Returns:
        list[Sample]: One entry per ply, holding the board before the move
        (as nested lists), the MCTS policy, the player to move, and the
        value ``othello.get_result`` reports for that player on the FINAL
        board of the game.
    """
    @dataclass
    class Sample:
        board: list        # board before the move, converted with .tolist()
        mcts_policy: list  # MCTS policy for that position, converted with .tolist()
        player: int        # player to move at that position
        reward: int        # None while the game is running; filled in once it ends

    record = []
    board = othello.get_initial_board()

    network = PVNet()
    # NOTE(review): one forward pass is run before set_weights, presumably so
    # that lazily created parameters exist -- confirm against PVNet.
    network(othello.encode_state(board, othello.BLACK))
    network.set_weights(weights)

    mcts = MCTS(network, alpha=dirichlet_alpha)

    current_player = othello.BLACK
    ply = 0
    while not othello.is_done(board):
        mcts_policy = mcts.search(board, current_player, num_mcts_simulations)
        if ply < num_exploration_moves:
            # Early game: sample from the policy for more varied openings.
            move = np.random.choice(range(othello.ACTION_SPACE), p=mcts_policy)
        else:
            # Later: play greedily, picking uniformly among tied best moves.
            best_moves = np.where(mcts_policy == mcts_policy.max())[0]
            move = np.random.choice(best_moves)

        record.append(Sample(board.tolist(), mcts_policy.tolist(), current_player, None))

        board = othello.step(board, move, current_player)
        # Switch to the other player (ids presumably 1 and 2 -- they sum to 3).
        current_player = 3 - current_player
        ply += 1

    # The value target is the outcome of the finished game, so it must be
    # evaluated on the terminal board, not on each sample's mid-game board.
    for sample in record:
        sample.reward = othello.get_result(board, sample.player)

    return record


In [34]:
# Number of CPUs that main() passes to ray.init.
num_cpus = 4
# Mini-batch size main() uses when sampling from the replay buffer.
batch_size = 128

In [None]:
def main(n_parallel_selfplay=20,
        num_mcts_simulations=50,
        max_games=10000,
        games_per_round=4,
        learning_rate=0.0005,
        buffer_size=40000):
    """Run the self-play / training loop.

    Self-play games run as Ray tasks while the network is trained in this
    process. Each round collects ``games_per_round`` finished games into the
    replay buffer, runs a number of gradient steps proportional to the buffer
    size, and publishes the updated weights for newly launched games.

    Relies on the notebook-level ``num_cpus`` and ``batch_size`` variables.

    Args:
        n_parallel_selfplay: Number of self-play tasks launched up front; one
            new task is started for every finished one.
        num_mcts_simulations: Forwarded to every ``selfplay`` task.
        max_games: Rounds continue while the number of collected games is
            ``<= max_games``.
        games_per_round: Finished games collected between training phases.
        learning_rate: Learning rate for the Adam optimizer.
        buffer_size: Capacity passed to ``ReplayBuffer``.
    """
    ray.shutdown()
    ray.init(num_cpus=num_cpus, num_gpus=1)

    try:
        network = PVNet()
        # NOTE(review): one forward pass before get_weights, mirroring the
        # warm-up done in selfplay -- presumably to materialise parameters.
        dummy_state = othello.encode_state(othello.get_initial_board(), othello.BLACK)
        network(dummy_state)

        current_weights = ray.put(network.get_weights())

        optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

        replay = ReplayBuffer(buffer_size=buffer_size)

        work_in_progresses = [
            selfplay.remote(current_weights, num_mcts_simulations)
            for _ in range(n_parallel_selfplay)]

        n = 0
        while n <= max_games:
            # Collect finished games, replacing each with a fresh task so the
            # number of in-flight games stays constant.
            for _ in tqdm(range(games_per_round)):
                finished, work_in_progresses = ray.wait(work_in_progresses, num_returns=1)
                replay.add_record(ray.get(finished[0]))
                work_in_progresses.extend([
                    selfplay.remote(current_weights, num_mcts_simulations)
                ])
                n += 1

            num_iters = 5 * (len(replay) // batch_size)
            network.train()
            for _ in range(num_iters):

                boards, mcts_policy, rewards = replay.get_minibatch(batch_size=batch_size)
                boards = torch.tensor(boards, dtype=torch.float32)
                mcts_policy = torch.tensor(mcts_policy, dtype=torch.float32)
                rewards = torch.tensor(rewards, dtype=torch.float32)

                optimizer.zero_grad()
                p_pred, v_pred = network(boards)

                # Flatten both sides so the subtraction is element-wise even if
                # one is (B,) and the other (B, 1); otherwise it would silently
                # broadcast to (B, B). Assumes one value per sample -- confirm.
                value_loss = (rewards.reshape(-1) - v_pred.reshape(-1)).pow(2)

                # Cross-entropy against the MCTS policy; the small constant
                # keeps log() finite for zero probabilities.
                policy_loss = -mcts_policy * torch.log(p_pred + 1e-4)
                policy_loss = torch.sum(policy_loss, dim=1)

                loss = torch.mean(value_loss + policy_loss)
                print(f"loss: {loss.item()}")

                loss.backward()
                optimizer.step()

            # Publish the new weights; only tasks launched from now on use them.
            current_weights = ray.put(network.get_weights())
    finally:
        # Stop in-flight self-play tasks and release Ray resources, including
        # when training raises.
        ray.shutdown()


In [42]:
main()

2025-03-11 19:11:03,515	INFO worker.py:1841 -- Started a local Ray instance.
100%|██████████| 5/5 [00:41<00:00,  8.24s/it]


AttributeError: 'numpy.ndarray' object has no attribute 'dim'