<a href="https://colab.research.google.com/github/bgalbraith/accelerated-othello/blob/main/CUDA_enhanced_AlphaZero_Othello.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This is an initial pass at implementing massively parallel support for MCTS experiments using the game Othello.

It is based on custom CUDA kernels implemented via Numba's CUDA JIT capabilities.

Currently, the code does the following:
Given a number of games to evaluate in parallel, a starting board configuration, and the player whose turn it is, play all the games out to completion by taking random actions for each player. The final output is a vector indicating which player won each game or whether it was a draw.

In [4]:
import math
import os
import time

from numba import jit, cuda, float32, int32
import numpy as np
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F

## Setting up Numba for CUDA

In order for Numba to JIT-compile CUDA kernels, it needs to know where the appropriate system libraries are. This is done using environment variables. The values below should be safe for Google Colab, but if an error occurs when attempting to compile, running the following in a cell will give you an idea of which paths to use:
```
!find / -iname 'libdevice'
!find / -iname 'libnvvm.so'
```

In [5]:
# Point Numba at the CUDA libdevice bitcode and the NVVM shared library (Colab paths).
os.environ.update({
    'NUMBAPRO_LIBDEVICE': "/usr/local/cuda-10.0/nvvm/libdevice",
    'NUMBAPRO_NVVM': "/usr/local/cuda-10.0/nvvm/lib64/libnvvm.so",
})

We set the number of parallel games, the number of MCTS simulations per move, and the resulting per-game table size as constants here. We also use a ray-casting strategy for finding valid moves in Othello, so we establish the eight possible ray directions here as well. Note that the dtype is set to `np.int32`: we will load this array onto the GPU later, and it's better to use 32-bit data types (`int32`, `float32`) for GPU ops unless you really need the higher precision.

In [31]:
N_GAMES = 32        # games searched/played in parallel (one CUDA block per game)
N_SIMULATIONS = 10  # MCTS simulations per move
# Per-game transposition-table capacity. Each simulation adds at most one new leaf,
# and a game lasts about 60 moves, so this is a rough upper bound (passes add turns).
MAX_STATES = 60 * N_SIMULATIONS
# Added under the square root in apply_ucb so unvisited boards still get a bonus.
EPS = np.finfo(np.float32).eps
# (dy, dx) unit steps for the eight directions searched when flanking discs.
RAYS = np.array(
    [
        (0, 1),    # east
        (0, -1),   # west
        (1, 0),    # south
        (-1, 0),   # north
        (1, 1),    # southeast
        (1, -1),   # southwest
        (-1, 1),   # northeast
        (-1, -1),  # northwest
    ],
    dtype=np.int32,
)

# CUDA kernels

In [32]:
#@title ray_cast
@cuda.jit(device=True)
def ray_cast(board, y0, x0, ray, player):
  """Return True if a `player` disc at (y0, x0) would flank opponent discs along `ray`.

  `ray` is a (dy, dx) step. The walk starts one step away from (y0, x0), so the
  origin square itself is never read. A capture needs at least one opponent disc
  immediately next to the origin, followed by a contiguous run that ends on one of
  `player`'s own discs before an empty square or the board edge.
  """
  dy = ray[0]
  dx = ray[1]
  opponent = 3 - player

  # The neighbouring square must be on the board and hold an opponent disc.
  y = y0 + dy
  x = x0 + dx
  if not (y >= 0 and x >= 0 and y < 8 and x < 8):
    return False
  if board[y, x] != opponent:
    return False

  # Keep walking: an empty square or the edge means no capture, and our own disc closes it.
  while True:
    y += dy
    x += dx
    if y < 0 or x < 0 or y >= 8 or x >= 8:
      return False
    cell = board[y, x]
    if cell == 0:
      return False
    if cell == player:
      return True

In [33]:
#@title find_valid_actions
@cuda.jit
def find_valid_actions(boards, player, valid_actions):
  """Write a 0/1 legality mask for `player` into valid_actions.

  Block bx handles game bx and thread (ty, tx) handles square (ty, tx). A square is
  legal when it is empty and at least one of the eight RAYS flanks opponent discs.
  """
  rays = cuda.const.array_like(RAYS)

  col = cuda.threadIdx.x
  row = cuda.threadIdx.y
  game = cuda.blockIdx.x

  legal = 0
  # Occupied squares can never be played; only empty ones need ray casting.
  if boards[game, row, col] == 0:
    for direction in range(8):
      if ray_cast(boards[game], row, col, rays[direction], player):
        legal = 1
        break

  valid_actions[game, row, col] = legal

In [34]:
#@title apply_ucb
@cuda.jit
def apply_ucb(q_values, edge_visits, board_visits, policies, board_ids, cpuct, valid_actions):
  """Multiply each entry of valid_actions by its PUCT score, in place.

  Score = Q(s,a) + cpuct * P(s,a) * sqrt(N(s) + EPS) / (1 + N(s,a)), where s is the
  table slot board_ids[bx] of game bx. Block bx is one game and thread (ty, tx) is one
  square. Illegal squares (mask 0) stay 0.
  """
  game = cuda.blockIdx.x
  row = cuda.threadIdx.y
  col = cuda.threadIdx.x
  node = board_ids[game]

  q_sa = q_values[game, node, row, col]
  prior = policies[game, node, row, col]
  n_s = board_visits[game, node]
  n_sa = edge_visits[game, node, row, col]
  # Same left-to-right evaluation order as the textbook formula, so results match bit-for-bit.
  score = q_sa + cpuct * prior * math.sqrt(n_s + EPS) / (1 + n_sa)
  valid_actions[game, row, col] *= score

In [35]:
#@title select_action
@cuda.jit
def select_action(valid_actions, actions):  
  """Per-game argmax over the 8x8 score grid in `valid_actions`.

  Launched in this notebook as select_action[n_games, 8]: block bx is one game and
  thread tx scans row tx. Entries equal to 0 are treated as illegal and never chosen
  (negative scores are still eligible). Writes (row, col) of the largest selected
  score into actions[bx], or (-1, -1) when nothing is selected (e.g. all entries are 0),
  which `step` interprets as a pass.
  """
  row_max = cuda.shared.array(shape=(8,), dtype=float32)     # best score found in each row
  row_max_idx = cuda.shared.array(shape=(8,), dtype=int32)   # column of that best score
  
  tx = cuda.threadIdx.x  
  bx = cuda.blockIdx.x

  # argmax reduction part 1: each thread finds the best non-zero entry of its row
  current_max = -math.inf
  current_idx = 0
  for i in range(8):
    action_value = valid_actions[bx, tx, i]
    if action_value != 0 and action_value > current_max:
      current_max = action_value
      current_idx = i
  
  # a row with no selectable entry leaves -inf here, which part 2 never picks
  row_max[tx] = current_max
  row_max_idx[tx] = current_idx
  cuda.syncthreads()  # all row results must be in shared memory before thread 0 reads them
  
  # argmax reduction part 2: thread 0 picks the best row
  if tx == 0:
    current_max = -math.inf
    current_idx = 0
    for i in range(8):
      action_value = row_max[i]
      if action_value != 0 and action_value > current_max:
        current_max = action_value
        current_idx = i
    
    # default to (-1, -1) == pass unless some row produced a selectable score
    y_idx = -1
    x_idx = -1
    if current_max != 0 and current_max > -math.inf:
      y_idx = current_idx
      x_idx = row_max_idx[current_idx]
    
    actions[bx] = y_idx, x_idx

In [36]:
#@title step
@cuda.jit
def step(boards, actions, player, player_status):
  """Apply each game's chosen move for `player` to `boards`, in place.

  Launched in this notebook as step[n_games, 8]: block bx handles game bx and thread
  tx walks direction RAYS[tx]. actions[bx] is a (row, col) pair. A column of -1 means
  the player passes: player_status[bx][player-1] is set to 1 and the board is untouched.
  Otherwise every direction that flanks opponent discs is flipped to `player`, the
  chosen square is claimed, and player_status[bx][player-1] is reset to 0.
  """
  rays = cuda.const.array_like(RAYS)

  tx = cuda.threadIdx.x
  bx = cuda.blockIdx.x  

  # execute action and update game state
  opponent = 3 - player
  act_y, act_x = actions[bx]

  # pass: only thread 0 records it, and every thread exits without touching the board
  if act_x == -1:
    if tx == 0:
      player_status[bx][player-1] = 1
    return

  # Each thread owns one direction. Rays leaving the same square never share a cell,
  # and ray_cast never reads the origin square, so the threads do not conflict.
  r = rays[tx]
  hit = ray_cast(boards[bx], act_y, act_x, r, player)
  if hit:    
    # flip the contiguous run of opponent discs up to the flanking own disc
    y = act_y + r[0]
    x = act_x + r[1]
    while boards[bx, y, x] == opponent:
      boards[bx, y, x] = player
      y += r[0]
      x += r[1]
      
  
  # thread 0 claims the played square and clears this player's pass flag
  if tx == 0:
    boards[bx, act_y, act_x] = player    
    player_status[bx][player-1] = 0

In [37]:
#@title q_update
@cuda.jit
def q_update(board_ids, actions, rewards, max_steps, q_values, edge_visits, board_visits):
  """Back up one simulation's reward into the per-game MCTS statistics.

  There is one thread per game (tid = cuda.grid(1)). Threads past the end of
  `rewards` exit immediately, so a padded launch grid cannot index out of bounds.

  board_ids[i, tid] and actions[i, tid] are the logged (board slot, action) pairs,
  newest first (the caller flips the trajectory):
    - board_ids == -1 marks steps logged after this game had already finished. They are
      skipped without touching the reward's sign.
    - Every other step negates the reward before using it, so the sign alternates once
      per ply walking backwards from the final position.
    - A pass (action row == -1) flips the sign but records nothing.

  Q(s,a) is updated as a running mean, and N(s,a) and N(s) are incremented.
  """
  tid = cuda.grid(1)
  if tid >= rewards.shape[0]:
    return

  reward = rewards[tid]
  for i in range(max_steps):
    s_tm1 = board_ids[i, tid]
    if s_tm1 == -1:
      continue

    reward = -reward
    a_tm1_y, a_tm1_x = actions[i, tid]
    if a_tm1_y == -1:
      continue

    q_sa = q_values[tid, s_tm1, a_tm1_y, a_tm1_x]
    n_sa = edge_visits[tid, s_tm1, a_tm1_y, a_tm1_x]

    # incremental mean: Q <- (N * Q + r) / (N + 1)
    q_values[tid, s_tm1, a_tm1_y, a_tm1_x] = (n_sa * q_sa + reward) / (n_sa + 1)

    edge_visits[tid, s_tm1, a_tm1_y, a_tm1_x] += 1
    board_visits[tid, s_tm1] += 1

In [38]:
#@title policy_update
@cuda.jit
def policy_update(update_ids, new_policies, policies):
  """Scatter freshly predicted priors into the per-game policy table.

  Block k handles update_ids[k] = (game, board slot) and copies the 8x8 grid
  new_policies[k] into policies[game, board slot]. Thread (y, x) copies one square.
  """
  entry = cuda.blockIdx.x
  row = cuda.threadIdx.y
  col = cuda.threadIdx.x
  game = update_ids[entry, 0]
  slot = update_ids[entry, 1]
  policies[game, slot, row, col] = new_policies[entry, row, col]

In [39]:
#@title board_lookup
@cuda.jit
def board_lookup(hashes, table_entries, table_lengths, board_ids):  
  """Find each game's board hash in its transposition table.

  There is one thread per game (tid = cuda.grid(1)). board_ids[tid] receives the first
  slot i < table_lengths[tid] with table_entries[tid, i] == hashes[tid], or -1 if the
  hash is not present. Threads past the end of `hashes` exit immediately, so a padded
  launch grid cannot read or write out of bounds.
  """
  # naive version: linear scan over the used part of the table
  tid = cuda.grid(1)
  if tid >= hashes.shape[0]:
    return

  cap = table_lengths[tid]
  board_hash = hashes[tid]

  for i in range(cap):
    if table_entries[tid, i] == board_hash:
      board_ids[tid] = i
      return

  board_ids[tid] = -1

# Neural Network

In [40]:
#@title OthelloNNet
class OthelloNNet(nn.Module):
    """Fully connected policy/value network for an 8x8 Othello board.

    Input: any tensor that can be viewed as (batch, 64).
    Output: (log-probabilities over the 64 squares, value squashed to [-1, 1]).
    Dropout with probability `dropout` follows each hidden layer while the module
    is in training mode.
    """

    def __init__(self, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # Two shared hidden layers feed separate policy (fc3) and value (fc4) heads.
        self.fc1 = nn.Linear(8 * 8, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(256, 1)

    def forward(self, s):
        hidden = s.view(-1, 8 * 8)                         # batch x 64
        for layer in (self.fc1, self.fc2):                 # batch x 256 after each layer
            hidden = F.dropout(F.relu(layer(hidden)), p=self.dropout, training=self.training)
        log_pi = F.log_softmax(self.fc3(hidden), dim=1)    # batch x 64
        value = torch.tanh(self.fc4(hidden))               # batch x 1
        return log_pi, value

# Setup and Utilities

In [41]:
#@title device array pre-allocation
# GPU buffers are allocated once (uninitialized) and reused by run_games / search.

# Game state: the real games, plus the scratch copies every MCTS simulation starts from.
# player_status holds one pass flag per player (set/cleared by `step`).
boards = cuda.device_array((N_GAMES, 8, 8), np.float32)
player_status = cuda.device_array((N_GAMES, 2), np.uint8)
sim_boards = cuda.device_array((N_GAMES, 8, 8), np.float32)
sim_player_status = cuda.device_array((N_GAMES, 2), np.uint8)

# Per-step intermediates: move mask/scores, chosen (row, col), table slot per game.
valid_actions = cuda.device_array((N_GAMES, 8, 8), np.float32)
actions = cuda.device_array((N_GAMES, 2), np.int32)
board_ids = cuda.device_array((N_GAMES,), np.int32)

# MCTS statistics, indexed [game, table slot, row, col] (or [game, table slot]).
q_values = cuda.device_array((N_GAMES, MAX_STATES, 8, 8), np.float32)     # Q(s,a)
edge_visits = cuda.device_array((N_GAMES, MAX_STATES, 8, 8), np.uint16)   # N(s,a)
board_visits = cuda.device_array((N_GAMES, MAX_STATES), np.uint16)        # N(s)
policies = cuda.device_array((N_GAMES, MAX_STATES, 8, 8), np.float32)     # P(s)

In [42]:
#@title utility methods
def uniform_policy(state):
  """Return a flat prior (1/64 on every square) and a zero value for each board in `state`."""
  n_boards = len(state)
  return np.full((n_boards, 8, 8), 1 / 64), np.zeros(n_boards)

def random_policy(state):
  """Return a random softmax prior over the 8x8 squares and a random value in (-1, 1) per board.

  Both outputs are float32. Draws from NumPy's global RNG (uniform logits first,
  then Gaussian pre-activations for the values).
  """
  n_boards = len(state)
  logits = np.random.rand(n_boards, 8, 8)
  weights = np.exp(logits)
  probs = weights / weights.sum(axis=(1, 2), keepdims=True)
  values = np.tanh(np.random.randn(n_boards))
  return probs.astype(np.float32), values.astype(np.float32)

def hash_boards(boards):
  """Hash each board's raw bytes into a uint64 array (one entry per board).

  Python's hash() returns a signed 64-bit value, so about half the hashes are negative.
  Converting those directly to uint64 overflows (an error on NumPy >= 2). Masking to
  64 bits keeps the same two's-complement bit pattern instead. Note that bytes hashing
  is randomized per interpreter process, so values are only stable within one run.
  """
  mask = (1 << 64) - 1
  return np.array([hash(b.tobytes()) & mask for b in boards], dtype=np.uint64)

def is_game_over(boards, status):
  """Return a boolean array marking finished games.

  A game is over when its board has no empty (0) squares, or when both players'
  pass flags in `status` (shape (n_games, 2)) are set.
  Vectorized: one pass over the boards instead of a membership test per game,
  and it works for an empty batch.
  """
  boards = np.asarray(boards)
  status = np.asarray(status)
  board_axes = tuple(range(1, boards.ndim))
  board_full = np.all(boards != 0, axis=board_axes)
  all_passed = np.sum(status, axis=1) == 2
  return board_full | all_passed

def score_game(boards, planner):
  """Score finished boards from `planner`'s perspective (planner is 1 or 2).

  Returns a float array with one entry per board: 1 if `planner` has more discs,
  -1 if the opponent (3 - planner) has more, and 0 for a draw. Previously draws were
  reported as -1, indistinguishable from a loss. An empty batch returns an empty array.
  """
  boards = np.asarray(boards)
  board_axes = tuple(range(1, boards.ndim))
  planner_discs = np.sum(boards == planner, axis=board_axes)
  opponent_discs = np.sum(boards == 3 - planner, axis=board_axes)
  return np.sign(planner_discs - opponent_discs).astype(np.float64)

# MCTS

In [43]:
#@title search
def search(eval_boards, eval_status, eval_player, table_entries, table_lengths, net):
  """Run N_SIMULATIONS batched MCTS simulations for all N_GAMES games.

  Each simulation copies the root positions into the sim_* device buffers. It then
  descends with PUCT action selection (find_valid_actions -> apply_ucb ->
  select_action -> step) until every game reaches either a terminal position or a
  board not yet in its transposition table (a leaf). Finally q_update backs the
  resulting value up the visited path.

  Args:
    eval_boards: (N_GAMES, 8, 8) root positions (copied into sim_boards).
    eval_status: (N_GAMES, 2) per-player pass flags at the root.
    eval_player: player (1 or 2) to move at the root.
    table_entries: (N_GAMES, MAX_STATES) uint64 board hashes; new leaves are appended in place.
    table_lengths: (N_GAMES,) number of used table entries per game; updated in place.
    net: network returning (log-probabilities, value) for a batch of boards.

  Side effects: updates the global q_values, edge_visits, board_visits and policies,
  and overwrites the sim_*, board_ids, valid_actions and actions buffers.
  """

  def get_policy(boards):
    # The network outputs log-softmax; PUCT priors P(s,a) must be probabilities.
    log_pi, v = net(torch.as_tensor(boards, device='cuda'))
    return log_pi.exp().view(-1, 8, 8).cpu().numpy(), v.view(-1).cpu().numpy()

  cpuct = 1
  for s in range(N_SIMULATIONS):
    # initialize the simulation from the root positions
    cuda.to_device(eval_boards, to=sim_boards)
    cuda.to_device(eval_status, to=sim_player_status)
    sim_player = eval_player

    # setup trajectory log; a NaN reward means the game is still being simulated
    sim_board_ids = []
    sim_actions = []
    sim_rewards = np.nan*np.ones(N_GAMES, np.float32)

    t = 0
    while True:
      _status = sim_player_status.copy_to_host()
      _boards = sim_boards.copy_to_host()

      terminals = is_game_over(_boards, _status)
      terminals[~np.isnan(sim_rewards)] = False

      if np.any(terminals):
        # Score for the player to move at this final position (sim_player), the same
        # convention as the leaf value below. q_update negates the reward before each
        # earlier edge, so every edge ends up valued for the player who took it.
        sim_rewards[terminals] = score_game(_boards[terminals], sim_player)

      if np.all(~np.isnan(sim_rewards)):
        break

      hashes = hash_boards(_boards)
      board_lookup[N_GAMES, 1](hashes, table_entries, table_lengths, board_ids)
      _board_ids = board_ids.copy_to_host()
      leaves = _board_ids == -1
      leaves[~np.isnan(sim_rewards)] = False

      if np.any(leaves):
        # add new boards to cache
        _board_ids[leaves] = table_lengths[leaves]
        table_lengths[leaves] += 1
        table_entries[leaves, _board_ids[leaves]] = hashes[leaves]

        # evaluate every game's board with the NN, then keep only the leaves
        pred_policy, pred_value = get_policy(sim_boards)
        update_ids = np.stack((leaves.nonzero()[0], _board_ids[leaves]), axis=1)
        policy_update[len(update_ids), (8,8)](update_ids,
                                              pred_policy[leaves],
                                              policies)
        # NOTE(review): assumes the value head scores the position for the player to
        # move (the network is not told whose turn it is) — TODO confirm.
        sim_rewards[leaves] = pred_value[leaves]

      # if all our games have rewards after leaf node check, we are finished
      if np.all(~np.isnan(sim_rewards)):
        break

      # determine best action
      find_valid_actions[N_GAMES, (8,8)](sim_boards, sim_player, valid_actions)
      apply_ucb[N_GAMES, (8,8)](q_values, edge_visits, board_visits, policies, board_ids, cpuct, valid_actions)
      select_action[N_GAMES, 8](valid_actions, actions)
      step[N_GAMES, 8](sim_boards, actions, sim_player, sim_player_status)
      sim_player = 3 - sim_player

      # games that already have a reward took no real step this iteration
      _board_ids[~np.isnan(sim_rewards)] = -1

      # log trajectory
      sim_board_ids.append(_board_ids)
      sim_actions.append(actions.copy_to_host())
      t += 1
    if (s + 1) % 100 == 0:
      print(f"sim {s} played {t} steps")

    # if we've seen at least one board/action pair, back up the reward (newest first)
    if len(sim_board_ids) > 0:
      sim_board_ids = np.flipud(np.stack(sim_board_ids, axis=0)).astype(np.int32)
      sim_actions = np.flipud(np.stack(sim_actions, axis=0)).astype(np.int32)
      q_update[N_GAMES, 1](sim_board_ids, sim_actions, sim_rewards, len(sim_board_ids),
                          q_values, edge_visits, board_visits)

In [44]:
#@title get_action_probabilities
def get_action_probabilities(boards, player_status, player, temperature, table_entries, table_lengths, net):
  """Run MCTS from the current positions and turn root visit counts into move probabilities.

  Returns an (N_GAMES, 8, 8) float array:
    - temperature == 0: one-hot on the most visited move (first index wins ties);
    - temperature > 0: proportional to N(s,a) ** (1 / temperature), so 1 gives plain
      visit-count proportions.
  Games whose root has no recorded visits (e.g. no legal move) or whose root board is
  missing from the table get an all-zero grid, which run_games treats as a pass.
  """
  search(boards, player_status, player, table_entries, table_lengths, net)

  hashes = hash_boards(boards.copy_to_host())
  board_lookup[N_GAMES, 1](hashes, table_entries, table_lengths, board_ids)
  root_ids = board_ids.copy_to_host()
  counts = edge_visits.copy_to_host()[np.arange(N_GAMES), root_ids].astype(np.float64)
  # root id -1 (not in table) would otherwise index the last slot, so treat it as a pass
  passed = (counts.sum(axis=(1, 2)) == 0) | (root_ids < 0)

  if temperature == 0:
    # greedy strategy
    best = np.unravel_index(np.argmax(counts.reshape(N_GAMES, -1), axis=1), (8, 8))
    probs = np.zeros((N_GAMES, 8, 8))
    probs[np.arange(N_GAMES), best[0], best[1]] = 1
  else:
    scaled = counts ** (1.0 / temperature)
    norm = scaled.sum(axis=(1, 2), keepdims=True)
    # avoid 0/0 (and its RuntimeWarning) for games without visits
    probs = np.divide(scaled, norm, out=np.zeros_like(scaled), where=norm > 0)
  probs[passed] = 0
  return probs


In [45]:
#@title run_games
def run_games(net):
  """Self-play N_GAMES Othello games in parallel, choosing every move with MCTS.

  Resets the global game buffers and MCTS statistics, plays until every game is over,
  prints the final boards and score_game(final_boards, 1), and returns that score array.
  """
  # initialize game: standard opening position, player 1 to move
  board = np.zeros((8, 8), dtype=np.float32)
  board[[3, 4], [4, 3]] = 1
  board[[3, 4], [3, 4]] = 2
  player = 1

  cuda.to_device(np.tile(board, (N_GAMES, 1, 1)), to=boards)
  cuda.to_device(np.zeros((N_GAMES, 2), dtype=np.uint8), to=player_status)

  # initialize MCTS: empty transposition tables and zeroed statistics
  table_entries = cuda.pinned_array((N_GAMES, MAX_STATES), dtype=np.uint64)
  table_entries[:] = 0
  table_lengths = cuda.pinned_array((N_GAMES,), dtype=np.int32)
  table_lengths[:] = 0
  cuda.to_device(np.zeros((N_GAMES, MAX_STATES, 8, 8), dtype=np.float32), to=q_values)
  cuda.to_device(np.zeros((N_GAMES, MAX_STATES, 8, 8), dtype=np.uint16), to=edge_visits)
  cuda.to_device(np.zeros((N_GAMES, MAX_STATES), dtype=np.uint16), to=board_visits)
  cuda.to_device(np.zeros((N_GAMES, MAX_STATES, 8, 8), dtype=np.float32), to=policies)

  t = 0
  _boards = boards.copy_to_host()
  _status = player_status.copy_to_host()
  # 60 empty squares at the start bound the number of moves; passes add extra turns
  pbar = tqdm(total=60)
  while not np.all(is_game_over(_boards, _status)):
    if t >= 60:
      pbar.total += 1
    # sample proportionally to visit counts for the first 15 turns, then play greedily
    temperature = int(t < 15)
    pi = get_action_probabilities(boards, player_status, player, temperature, table_entries, table_lengths, net)
    # TODO: collect_training_samples()
    _actions = -1*np.ones((N_GAMES, 2), dtype=np.int32)  # (-1, -1) == pass
    for i, p in enumerate(pi):
      if np.sum(p) == 0:
        continue
      _p = p.flatten()
      _actions[i] = np.unravel_index(np.random.choice(len(_p), p=_p), (8,8))
    step[N_GAMES, 8](boards, _actions, player, player_status)
    player = 3-player
    t += 1
    _boards = boards.copy_to_host()
    _status = player_status.copy_to_host()
    pbar.update(1)
  pbar.close()
  result = score_game(_boards, 1)
  print(_boards)
  print(result)
  return result

# Training

In [None]:
net = OthelloNNet().cuda()
# Stay in training mode so dropout perturbs the network outputs; the intent is to add
# stochasticity so the parallel games don't all play out the same.
net.train()
with torch.no_grad():  # self-play only runs inference, so no autograd graph is needed
  run_games(net)

# Kernel Tests

In [45]:
#@title valid action test
def test_find_valid_action():
  """find_valid_actions marks exactly the four legal opening moves for each player."""
  board = np.zeros((1, 8, 8), dtype=np.float32)
  board[[0, 0], [3, 4], [3, 4]] = 2
  board[[0, 0], [3, 4], [4, 3]] = 1

  valid_actions = cuda.device_array(shape=(1, 8, 8), dtype=np.float32)

  # player 1 can play (2, 3), (3, 2), (4, 5), (5, 4)
  find_valid_actions[1, (8,8)](board, 1, valid_actions)
  result = valid_actions.copy_to_host()
  expected = np.zeros((1,8,8))
  expected[[0,0,0,0],[2,3,4,5],[3,2,5,4]] = 1
  assert np.allclose(result, expected)

  # player 2 can play (2, 4), (3, 5), (4, 2), (5, 3)
  find_valid_actions[1, (8,8)](board, 2, valid_actions)
  result = valid_actions.copy_to_host()
  expected = np.zeros((1,8,8))
  expected[[0,0,0,0],[2,3,4,5],[4,5,2,3]] = 1
  assert np.allclose(result, expected)
  print('PASS')
  
test_find_valid_action()

PASS


In [46]:
#@title select action test
def test_select_action():
  """select_action picks the largest non-zero score, or (-1, -1) when nothing is legal."""
  actions = cuda.device_array(shape=(1, 2), dtype=np.int32)
  cells = ([0, 0, 0, 0], [2, 3, 4, 5], [3, 2, 5, 4])
  cases = [
      ([0.1, 0.2, 0.4, 0.3], [[4, 5]]),      # positive scores: plain argmax
      ([-0.1, -0.2, -0.4, -0.3], [[2, 3]]),  # negative scores are still legal moves
      (None, [[-1, -1]]),                    # all-zero grid: pass
  ]
  for scores, expected in cases:
    valid_actions = np.zeros((1, 8, 8))
    if scores is not None:
      valid_actions[cells] = scores
    select_action[1, 8](valid_actions, actions)
    assert np.allclose(actions.copy_to_host(), expected)

  print('PASS')

test_select_action()

PASS


In [47]:
#@title step test
def test_step():
  """step flips flanked discs, claims the played square, and records passes."""
  board = np.zeros((1, 8, 8), dtype=np.float32)
  board[[0, 0], [3, 4], [3, 4]] = 2
  board[[0, 0], [3, 4], [4, 3]] = 1
  boards = cuda.to_device(board)
  status = cuda.to_device(np.zeros((1,2), dtype=np.int32))
  actions = cuda.to_device(np.array([[4, 5]], dtype=np.int32))
  expected = np.zeros((1, 8, 8), dtype=np.float32)

  def play(move, player):
    cuda.to_device(np.array([move], dtype=np.int32), to=actions)
    step[1, 8](boards, actions, player, status)

  def check(passes):
    assert np.allclose(boards.copy_to_host(), expected)
    assert np.allclose(status.copy_to_host(), passes)

  # player 1 plays (4, 5), flipping (4, 4)
  play([4, 5], 1)
  expected[[0], [3], [3]] = 2
  expected[[0, 0, 0, 0], [3, 4, 4, 4], [4, 3, 4, 5]] = 1
  check([[0, 0]])

  # player 2 plays (3, 5), flipping (3, 4)
  play([3, 5], 2)
  expected[[0, 0, 0], [3, 3, 3], [3, 4, 5]] = 2
  check([[0, 0]])

  # player 1 plays (2, 3), flipping (3, 3) and (3, 4)
  play([2, 3], 1)
  expected[[0, 0, 0], [2, 3, 3], [3, 3, 4]] = 1
  check([[0, 0]])

  # both players pass in turn: the board is unchanged and each pass flag is set
  play([-1, -1], 1)
  check([[1, 0]])
  play([-1, -1], 2)
  check([[1, 1]])

  print('PASS')

test_step()

PASS


In [84]:
#@title apply ucb test
def test_apply_ucb():
  """apply_ucb scales each legal move by its PUCT score for the looked-up board."""
  q_values = cuda.to_device(np.zeros((1, 2, 8, 8), dtype=np.float32))
  edge_visits = cuda.to_device(np.zeros((1, 2, 8, 8), dtype=np.float32))
  board_visits = cuda.to_device(np.zeros((1, 2), dtype=np.float32))

  legal = ([0, 0, 0, 0], [2, 3, 4, 5], [3, 2, 5, 4])
  _policies = np.zeros((1, 2, 8, 8), dtype=np.float32)
  _policies[[0,0,0,0],[0,0,0,0],[2,3,4,5],[3,2,5,4]] = 0.25
  policies = cuda.to_device(_policies)
  board_ids = cuda.to_device(np.array([0], dtype=np.int32))

  def run_ucb():
    mask = np.zeros((1, 8, 8), dtype=np.float32)
    mask[legal] = 1.0
    valid_actions = cuda.to_device(mask)
    apply_ucb[1, (8,8)](q_values, edge_visits, board_visits, policies, board_ids, 1, valid_actions)
    return valid_actions.copy_to_host()

  def on_legal(values):
    grid = np.zeros((1, 8, 8), dtype=np.float32)
    grid[legal] = values
    return grid

  # unvisited board: only the exploration term, scaled by sqrt(EPS), contributes
  assert np.allclose(run_ucb(), on_legal(0.25*np.sqrt(EPS)))

  # one visit through (2, 3): N(s) = 1, and that edge's bonus is halved by 1 + N(s,a)
  _temp = np.zeros((1, 2, 8, 8), dtype=np.float32)
  _temp[[0],[0],[2],[3]] = 1
  cuda.to_device(_temp, to=edge_visits)
  _temp = np.zeros((1, 2), dtype=np.float32)
  _temp[[0],[0]] = 1
  cuda.to_device(_temp, to=board_visits)
  assert np.allclose(run_ucb(), on_legal([0.125, 0.25, 0.25, 0.25]))

  print('PASS')
  
test_apply_ucb()

PASS


In [96]:
#@title q update test
def test_q_update():
  """q_update keeps a running mean per edge and alternates the reward's sign each ply."""
  q_values = cuda.to_device(np.zeros((2, 2, 8, 8), dtype=np.float32))
  edge_visits = cuda.to_device(np.zeros((2, 2, 8, 8), dtype=np.float32))
  board_visits = cuda.to_device(np.zeros((2, 2), dtype=np.float32))

  def backup(ids, acts, rewards):
    ids = np.array(ids, dtype=np.int32)
    q_update[2, 1](ids, np.array(acts, dtype=np.int32), np.array(rewards, dtype=np.float32),
                   len(ids), q_values, edge_visits, board_visits)

  def assert_state(exp_q, exp_n_sa, exp_n_s):
    assert np.allclose(q_values.copy_to_host(), exp_q)
    assert np.allclose(edge_visits.copy_to_host(), exp_n_sa)
    assert np.allclose(board_visits.copy_to_host(), exp_n_s)

  # single step: both games played (2, 3) from slot 0; the sign flips once
  backup([[0, 0]], [[[2, 3], [2, 3]]], [1, -1])
  exp_q = np.zeros((2, 2, 8, 8))
  exp_q[[0,1],[0,0],[2,2],[3,3]] = [-1, 1]
  exp_n_sa = np.zeros((2, 2, 8, 8))
  exp_n_sa[[0,1],[0,0],[2,2],[3,3]] = 1
  exp_n_s = np.zeros((2, 2))
  exp_n_s[[0,1],[0,0]] = 1
  assert_state(exp_q, exp_n_sa, exp_n_s)

  # two steps, newest first; game 1's newest entry is -1 (already finished) and is
  # skipped without flipping the sign
  backup([[1, -1], [0, 0]], [[[2, 5], [2, 3]], [[2, 3], [2, 3]]], [0.5, -0.5])
  exp_q = np.zeros((2, 2, 8, 8))
  exp_q[[0],[1],[2],[5]] = -0.5
  exp_q[[0,1],[0,0],[2,2],[3,3]] = [-0.25, 0.75]
  exp_n_sa = np.zeros((2, 2, 8, 8))
  exp_n_sa[[0,1],[0,0],[2,2],[3,3]] = 2
  exp_n_sa[[0],[1],[2],[5]] = 1
  exp_n_s = np.zeros((2, 2))
  exp_n_s[[0,1],[0,0]] = 2
  exp_n_s[[0],[1]] = 1
  assert_state(exp_q, exp_n_sa, exp_n_s)
  
  print('PASS')

test_q_update()

PASS


In [25]:
#@title board lookup test
def test_board_lookup():
  """board_lookup returns the table slot holding each game's board hash, or -1 if absent."""
  table_entries = np.zeros((2, 2), dtype=np.uint64)
  table_lengths = np.zeros((2,), dtype=np.int32)
  board_ids = np.zeros((2,), dtype=np.int32)

  board = np.zeros((1, 8, 8), dtype=np.float32)
  board[[0, 0], [3, 4], [3, 4]] = 2
  board[[0, 0], [3, 4], [4, 3]] = 1
  hashes = hash_boards([board, board])

  # Launch exactly one thread per game. These arrays hold only 2 games, so a grid of
  # N_GAMES threads would read and write past their ends on the device.
  n_games = len(hashes)

  board_lookup[n_games, 1](hashes, table_entries, table_lengths, board_ids)
  expected = np.array([-1, -1])
  assert np.allclose(board_ids, expected)

  table_entries[0,0] = hashes[0]
  table_entries[1,1] = hashes[0]
  table_lengths[0] = 1
  table_lengths[1] = 2
  board_lookup[n_games, 1](hashes, table_entries, table_lengths, board_ids)
  expected = np.array([0, 1])
  assert np.allclose(board_ids, expected)

  print('PASS')

test_board_lookup()

[0 1]
PASS
