In [47]:
%load_ext autoreload
%autoreload 2

import numpy as np
print(np.__version__)

import torch
print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

from tqdm.notebook import trange
import math
import copy
import time
import random

from bg import Backgammon
from mcts import MCTS
from node import Node
from model import ResNet


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
1.26.4
2.2.2
cpu


In [60]:
num_parallel_games = 1
num_concurrent_processes = 2
total_processes = 2
train_iteration = 2

INIT_BOARD = np.array([0, 2, 0, 0, 0, 0, -5, 0, -3, 0, 0, 0, 5, -5, 0, 0, 0, 3, 0, 5, 0, 0, 0, 0, -2, 0])
TEST_BOARD = np.array([0, 2, 0, 0, 0, 0, -3, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0, -2, 0])

args = {
    'dir': 'data/TEST_13',
    'num_searches': 100,
    'train_iteration': 0,
    'num_parallel_games': num_parallel_games,
    'num_concurrent_processes': num_concurrent_processes,
    'num_total_processes': total_processes,
    'num_epochs': 3,
    'batch_size': 128,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.0001,
    'C': 2, 
}
bg = Backgammon()

## Model 0

In [62]:
model0 = ResNet(game=bg, num_resBlocks=20, num_hidden=64, num_features=6, device=device)
model0.load_state_dict(torch.load('data/TEST_13/num_searches5/model_0.pt'))
model0.eval()
print()





In [63]:
mcts = MCTS(game=bg, model=model0, args=args)
board = TEST_BOARD
jumps = [1,3]
root = Node(bg, args, board, jumps)
mcts.search(root)
print(root)
for child in sorted(root.children, key=lambda x: x.prior, reverse=True):
    print(child)

Process ID: 17037: search took 4.449402809143066 seconds to run.

------------------------------------------------------------
Level: 0, N: 100, val: -5.511, prior: 0.000, uct:None, weight: 2, num_children: 14
   12  11  10  9   8   7   6   5   4   3   2   1     0 
|| 0   0   0   0  -1   0 |-3   0   0   0   0   2 ||  0 


|| 0   0   0   0   1   0 | 3   0   0   0   0  -2 ||  0 
   13  14  15  16  17  18  19  20  21  22  23  24    25
jumps=[1, 3], state_value=0.04874015599489212, action_taken=None
board = [ 0  2  0  0  0  0 -3  0 -1  0  0  0  0  0  0  0  0  1  0  3  0  0  0  0 -2  0]
------------------------------------------------------------


  ------------------------------------------------------------
  Level: 1, N: 7, val: 0.500, prior: 0.074, uct:0.6496166493970787, weight: 1, num_children: 21
     12  11  10  9   8   7   6   5   4   3   2   1     0 
  || 0   0   0   0  -1   0 |-1  -1   0  -1   0   2 ||  0 
  
  
  || 0   0   0   0   1   0 | 3   0   0   0   0  -2 ||  0 
     13  

## Model 2

In [64]:
model2 = ResNet(game=bg, num_resBlocks=20, num_hidden=64, num_features=6, device=device)
model2.load_state_dict(torch.load('data/TEST_13/num_searches5/model_2.pt'))
model2.eval()
print()




In [65]:
mcts = MCTS(game=bg, model=model2, args=args)
board = TEST_BOARD
jumps = [1,3]
root = Node(bg, args, board, jumps)
mcts.search(root)
print(root)
for child in sorted(root.children, key=lambda x: x.visit_count, reverse=True):
    print(child)

Process ID: 17037: search took 4.658331871032715 seconds to run.

------------------------------------------------------------
Level: 0, N: 100, val: -4.733, prior: 0.000, uct:None, weight: 2, num_children: 14
   12  11  10  9   8   7   6   5   4   3   2   1     0 
|| 0   0   0   0  -1   0 |-3   0   0   0   0   2 ||  0 


|| 0   0   0   0   1   0 | 3   0   0   0   0  -2 ||  0 
   13  14  15  16  17  18  19  20  21  22  23  24    25
jumps=[1, 3], state_value=0.03619949892163277, action_taken=None
board = [ 0  2  0  0  0  0 -3  0 -1  0  0  0  0  0  0  0  0  1  0  3  0  0  0  0 -2  0]
------------------------------------------------------------


  ------------------------------------------------------------
  Level: 1, N: 8, val: 0.260, prior: 0.078, uct:0.6574193257754508, weight: 1, num_children: 21
     12  11  10  9   8   7   6   5   4   3   2   1     0 
  || 0   0   0   0   0  -1 |-3   0   0   0   0   2 ||  0 
  
  
  || 0   0   0   0   1   0 | 3   0  -1   0   0  -1 ||  0 
     13  

In [None]:
mcts.search(root)
