In [1]:
import sys
import os

# Put the parent of the current working directory at the front of sys.path so
# that packages living one level up can be imported from this notebook.
# NOTE(review): this relies on the kernel's working directory; presumably the
# notebook is started from its own folder and `scripts` sits in the parent.
sys.path.insert(0, os.path.dirname(os.getcwd()))

In [2]:
from scripts import dataset
from scripts.chess_utils import tensor_to_board, action_to_san, move_to_int, board_to_tensor, index_to_coordinates

In [3]:
# Import the class directly instead of reaching it through the `dataset` name.
# This cell rebinds `dataset` to the dataset instance, so the previous form
# `dataset.TrainingDataset(...)` only worked on the first execution: on a re-run
# `dataset` no longer referred to the module.
from scripts.dataset import TrainingDataset

# Later cells index this object (`dataset[i]`) and call `len(dataset)` on it.
dataset = TrainingDataset('../training_data.bin')

In [4]:
from torch.utils.data import DataLoader

# Wrap the dataset in a DataLoader that yields one record per batch in storage
# order (no shuffling).
# NOTE(review): `dataloader` is not referenced by any of the cells visible in
# this section, which index `dataset` directly.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [5]:
# Find the first record whose stored position is already checkmate.
#
# Only a prefix of the dataset is scanned. The upper bound is clamped to
# len(dataset) so that a dataset with fewer than MAX_SCAN records is not
# indexed past its end.
MAX_SCAN = 1000

checkmate_idx = -1  # sentinel: remains -1 when the scan finds nothing
for i in range(min(MAX_SCAN, len(dataset))):
    record = dataset[i]
    # Decode the stored tensor back into a board object so the checkmate
    # test can be applied to it.
    board = tensor_to_board(record['board_tensor'])
    if board.is_checkmate():
        print(f"Checkmate found in record {i}")
        print(f"{board}")
        checkmate_idx = i
        break
else:
    # The loop finished without `break`: say so explicitly rather than leaving
    # the sentinel to be discovered (or misused as an index) further down.
    print(f"No checkmate found in the first {min(MAX_SCAN, len(dataset))} records")

Checkmate found in record 29
r n b . k . . .
. . p p Q p b r
. . . . . . . p
p . . . p . B .
P . P P P . . P
. p N . . . . .
. P . . B P P q
R . . . . K N R


In [6]:
# Dump every field stored for the checkmate record located by the search cell.
#
# `checkmate_idx` starts at -1 and is only overwritten when a checkmate is
# found. Stop here with a clear error instead of indexing the dataset with the
# sentinel value.
if checkmate_idx < 0:
    raise ValueError("No checkmate record was found; checkmate_idx is still -1")

record = dataset[checkmate_idx]
print(tensor_to_board(record['board_tensor']))

# Pull the individual entries out of the record under notebook-level names.
policy = record["policy"]
legal_mask = record["legal_mask"]
child_visit_counts = record["child_visit_counts"]
child_values = record["child_values"]
value = record["value"]
final_value = record["final_value"]

# For the per-move tensors, print the (truncated) tensor followed by its sum.
# In the recorded run all four sums are zero for this position.
print("Policy: ", policy)
print("Policy sum: ", policy.sum())
print("Legal Mask: ", legal_mask)
print("Legal Mask sum: ", legal_mask.sum())
print("Child Visit Counts: ", child_visit_counts)
print("Child Visit Counts sum: ", child_visit_counts.sum())
print("Child Values: ", child_values)
print("Child Values sum: ", child_values.sum())

# Scalar targets stored with the position (both -1 in the recorded run).
print("Value: ", value)
print("Final Value: ", final_value)

r n b . k . . .
. . p p Q p b r
. . . . . . . p
p . . . p . B .
P . P P P . . P
. p N . . . . .
. P . . B P P q
R . . . . K N R
Policy:  tensor([0., 0., 0.,  ..., 0., 0., 0.])
Policy sum:  tensor(0.)
Legal Mask:  tensor([False, False, False,  ..., False, False, False])
Legal Mask sum:  tensor(0)
Child Visit Counts:  tensor([0, 0, 0,  ..., 0, 0, 0], dtype=torch.int32)
Child Visit Counts sum:  tensor(0)
Child Values:  tensor([0., 0., 0.,  ..., 0., 0., 0.])
Child Values sum:  tensor(0.)
Value:  tensor(-1.)
Final Value:  tensor(-1, dtype=torch.int32)


In [7]:
import torch
from scripts import model

In [8]:
# Load the serialized TorchScript model.
#
# The checkpoint path gets its own name rather than being parked in `model`
# first, so that `model` never holds a plain string. Note that `model` is also
# the name brought in by `from scripts import model`; after this cell it refers
# to the loaded network.
model_path = "../model.pt"
model = torch.jit.load(model_path)

# eval() switches the network to inference behaviour (the module tree contains
# BatchNorm2d layers). It returns the module itself, so as the last expression
# of the cell it also displays the module structure.
model.eval()

RecursiveScriptModule(
  original_name=ChessModel
  (conv_blocks): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Conv2d)
      (1): RecursiveScriptModule(original_name=BatchNorm2d)
      (2): RecursiveScriptModule(original_name=ReLU)
      (3): RecursiveScriptModule(original_name=Conv2d)
      (4): RecursiveScriptModule(original_name=BatchNorm2d)
    )
    (1): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Conv2d)
      (1): RecursiveScriptModule(original_name=BatchNorm2d)
      (2): RecursiveScriptModule(original_name=ReLU)
      (3): RecursiveScriptModule(original_name=Conv2d)
      (4): RecursiveScriptModule(original_name=BatchNorm2d)
    )
    (2): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(original_name=Conv2d)
      (1): RecursiveScriptModule(original_name=BatchNorm2d)

In [9]:
# Run the network on the checkmate record: unsqueeze(0) adds a leading batch
# dimension and .to('cuda') moves the input onto the GPU before the call.
# The recorded output is a pair of tensors; the second one comes out of a tanh.
# presumably (policy logits, value) — confirm against the model definition.
# NOTE(review): the call is not wrapped in torch.no_grad(), which is why the
# displayed tensors carry a grad_fn.
model(record['board_tensor'].unsqueeze(0).to('cuda'))

(tensor([[-0.0302, -2.4905, -2.8885,  ..., -2.6366, -0.5154,  0.0293]],
        device='cuda:0', grad_fn=<AddmmBackward0>),
 tensor([[-1.]], device='cuda:0', grad_fn=<TanhBackward0>))

In [11]:
import chess

# Hand-built position, white to move: the white queen on h5 and the bishop on
# c4 both attack the f7 pawn.
board = chess.Board("rnbqkb1r/pppp1ppp/5n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 2 4")

# Encode the position and move the encoded tensor to the GPU once.
tensor = torch.tensor(board_to_tensor(board)).to('cuda')

# Add the leading batch dimension and evaluate. `tensor` already lives on the
# GPU, so the second device transfer that used to follow unsqueeze(0) was a
# no-op and has been dropped.
model(tensor.unsqueeze(0))

(tensor([[ 0.0297, -2.6662, -2.6756,  ..., -3.8222, -4.3159,  0.0090]],
        device='cuda:0', grad_fn=<AddmmBackward0>),
 tensor([[-0.0654]], device='cuda:0', grad_fn=<DifferentiableGraphBackward>))

In [12]:
# Total number of records available in the loaded training dataset.
len(dataset)

115542