In [None]:
!pip install torch numpy chess graphviz==0.20.1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import chess
import random

In [None]:
from google.colab import drive
# Mount Google Drive (Colab only) so that checkpoints survive the runtime.
# The training loop saves to the relative path "drive/MyDrive/weights". That path
# presumably resolves under this mount point because Colab's default working
# directory is /content; confirm this if running elsewhere.
drive.mount('/content/drive')

I need to add some kind of arena to this model. Currently the model is getting better at predicting the MCTS output, but there is no guarantee that the MCTS itself plays well. It's possible that I'm not running enough training iterations, which would mean I'm consistently generating bad data. It's also possible that there's a bug in my MCTS that makes it give bad results, in which case the model would keep training on those bad results. So I may just need better data (more iterations) to get better games out of my MCTS.  

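What is the point of having the old network and the new network compete against each other? Why not just have the network constantly play against itself to generate data and then train on that data? (AlphaGo Zero gated each new network behind an evaluation match against the current best one, which it had to win by a 55% margin. AlphaZero dropped that gate and always generates self-play data with the latest network.)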

In [None]:
# Search tree
# tree = [((n, w, q, p), [((n2, w2, q2, p2) [...])]), ... ((n_n, w_n, q_n, p_n), [...])]
# whole tree is stored in numpy array? how to do this
# t(s,a) = (n(s,a), w(s,a), q(s,a), p(s,a))
# how to encode (s,a) -> s'? this could just be the game rules
# Need a way to discard the rest of the tree when its not needed anymore
# is s an encoding of the board state?
# not sure if we can make a numpy array of size s x a, because s is massive
# numpy array doesnt need to be indexed like a function
# p(s,a) is just a 2d numpy array of probabilities
# n(s,a) is the number of times action a has been taken in state s
# we know that each node in the tree will have the same number of actions, which means we can know the number of nodes in the tree based on the size of the numpy array

# Chess numbers. The names presumably follow the AlphaZero paper's N x N x (MT + L)
# notation (TODO confirm). Only N (board side length) is read elsewhere in the
# visible notebook (by encode_board); T, M and L are currently unused.
T, N, M, L = 8, 8, 6, 7

# Not referenced anywhere else in the visible notebook.
past_boards = dict()

def itb(num: int, length: int):
    """
    Convert an integer to a fixed-width, most-significant-bit-first list of bits.

    Previously only widths 1-5, 8 and 11 were supported via a hard-coded ladder;
    any positive width now works.

    :param num: number to convert (anything accepted by ``int()``); values that need
        more than ``length`` bits are emitted in full rather than truncated
    :param length: number of bits to emit
    :return: list of 0/1 ints
    :raises TypeError: if ``length`` is not a positive integer
    :raises ValueError: if ``num`` is negative
    """
    num = int(num)
    try:
        width = int(length)
    except (TypeError, ValueError):
        width = None
    if width is None or width != length or width < 1:
        raise TypeError("Length not supported:", length)
    if num < 0:
        raise ValueError(f"Cannot convert a negative number to bits: {num}")
    return [int(bit) for bit in format(num, f'0{width}b')]

# Feature-plane layout used by encode_board. Planes 0-11 hold the piece planes
# (6 white + 6 black); everything after them is a board-wide scalar feature.
REPETITION_IDX_START = 12
REPETITION_IDX_END = REPETITION_IDX_START + 2   # two repetition bits
(
    COLOR_IDX,
    MOVE_COUNT_IDX,
    P1_CASTLING_IDX_KINGSIDE,
    P1_CASTLING_IDX_QUEENSIDE,
    P2_CASTLING_IDX_KINGSIDE,
    P2_CASTLING_IDX_QUEENSIDE,
    NO_PROGRESS_IDX,
) = range(REPETITION_IDX_END, REPETITION_IDX_END + 7)

def encode_board(board: chess.Board) -> np.ndarray:
    """
    Encode a position as a float32 array of shape (N, N, 21), indexed [file, rank, plane].

    Planes 0-5 mark white pieces and planes 6-11 mark black pieces, one plane per
    python-chess piece type (plane = piece_type - 1). The remaining planes are
    constant over the board:
    - repetition bits (always zero for now)
    - side to move (1 = black)
    - full-move number
    - the four castling rights
    - half of the half-move clock
    """
    planes = np.zeros((N, N, 6 + 6 + 2 + 1 + 1 + 2 + 2 + 1))

    # Piece planes: one-hot per (file, rank), with black pieces offset by 6.
    for square, piece in board.piece_map().items():
        offset = 6 if piece.color == chess.BLACK else 0
        plane = int(piece.piece_type) - 1 + offset
        planes[chess.square_file(square), chess.square_rank(square), plane] = 1

    # Repetition encoding
    # TODO : incorporate repetition encoding
    planes[:, :, REPETITION_IDX_START:REPETITION_IDX_END] = itb(0, 2)

    # Features that take the same value on every square.
    constant_planes = (
        (COLOR_IDX, 1 if board.turn == chess.BLACK else 0),
        (MOVE_COUNT_IDX, board.fullmove_number),
        (P1_CASTLING_IDX_KINGSIDE, 1 if board.has_kingside_castling_rights(chess.WHITE) else 0),
        (P1_CASTLING_IDX_QUEENSIDE, 1 if board.has_queenside_castling_rights(chess.WHITE) else 0),
        (P2_CASTLING_IDX_KINGSIDE, 1 if board.has_kingside_castling_rights(chess.BLACK) else 0),
        (P2_CASTLING_IDX_QUEENSIDE, 1 if board.has_queenside_castling_rights(chess.BLACK) else 0),
        (NO_PROGRESS_IDX, board.halfmove_clock / 2),
    )
    for plane, value in constant_planes:
        planes[:, :, plane] = value

    return planes.astype(np.float32)

# Moves = 7 * 8 queen moves + 8 knight moves + 9 underpromotions 

# Direction (sign(dx), sign(dy)) -> queen-move direction number (7 planes each).
_QUEEN_DIRECTIONS = {
    (0, 1): 0,    # north
    (0, -1): 1,   # south
    (1, 0): 2,    # east
    (-1, 0): 3,   # west
    (1, 1): 4,    # north-east
    (-1, 1): 5,   # north-west
    (1, -1): 6,   # south-east
    (-1, -1): 7,  # south-west
}

# Knight displacement (dx, dy) -> offset within the knight planes 56-63.
_KNIGHT_OFFSETS = {
    (1, 2): 0, (2, 1): 1, (-1, 2): 2, (-2, 1): 3,
    (1, -2): 4, (2, -1): 5, (-1, -2): 6, (-2, -1): 7,
}

# Capture direction dx of an underpromotion -> group within planes 64-72.
_UNDERPROMOTION_GROUPS = {0: 0, -1: 1, 1: 2}


def encode_move(move):
    """
    Map a move to a unique index in the flat 8*8*73 policy vector.

    index = (from_file * 8 + from_rank) * 73 + plane, where plane is:
      0-6   north 1-7 squares        7-13  south 1-7
      14-20 east 1-7                 21-27 west 1-7
      28-34 north-east 1-7           35-41 north-west 1-7
      42-48 south-east 1-7           49-55 south-west 1-7
      56-63 knight moves, in the order of _KNIGHT_OFFSETS
      64-66 underpromotion to knight/bishop/rook straight ahead
      67-69 underpromotion to knight/bishop/rook capturing towards the a-file (dx = -1)
      70-72 underpromotion to knight/bishop/rook capturing towards the h-file (dx = +1)
    Queen promotions use the ordinary queen-move planes.

    This fixes three problems in the previous numbering:
    - Planes were 1-based, so the last one overflowed the 73-plane range.
    - The flat index combined the from-square and plane with stride 8, so different
      moves shared an index.
    - Underpromotions were swallowed by the queen-move branches and shared an index
      with the queen promotion.
    Networks trained with the old numbering need to be retrained.

    :param move: a chess.Move (only from_square, to_square and promotion are read)
    :return: int in range(8 * 8 * 73)
    :raises ValueError: if the move is not a queen-line, knight or underpromotion move
    """
    from_square = int(move.from_square)
    to_square = int(move.to_square)
    from_x, from_y = from_square % 8, from_square // 8
    dx = to_square % 8 - from_x
    dy = to_square // 8 - from_y

    if move.promotion is not None and move.promotion != chess.QUEEN:
        # Checked first: geometrically, an underpromotion is also a one-square queen move.
        piece_offset = move.promotion - chess.KNIGHT   # knight 0, bishop 1, rook 2
        if abs(dy) != 1 or dx not in _UNDERPROMOTION_GROUPS or piece_offset not in (0, 1, 2):
            raise ValueError(f"Cannot encode underpromotion {move}")
        plane = 64 + 3 * _UNDERPROMOTION_GROUPS[dx] + piece_offset
    elif ((dx == 0) != (dy == 0)) or (dx != 0 and abs(dx) == abs(dy)):
        # Straight or diagonal line: 8 directions x 7 distances.
        direction = _QUEEN_DIRECTIONS[((dx > 0) - (dx < 0), (dy > 0) - (dy < 0))]
        plane = 7 * direction + max(abs(dx), abs(dy)) - 1
    elif (dx, dy) in _KNIGHT_OFFSETS:
        plane = 56 + _KNIGHT_OFFSETS[(dx, dy)]
    else:
        raise ValueError(f"Cannot encode move {move} (dx={dx}, dy={dy})")

    return (from_x * 8 + from_y) * 73 + plane


# Monte Carlo Tree Search
#
# Tree nodes are plain dicts:
#   node['data']     -> np.array([N, W, Q, P]) for the edge leading into this node
#                       (visit count, total value, mean value, prior probability)
#   node['color']    -> board.turn of the parent position, i.e. the player who plays
#                       the edge into this node (absent on a fresh root)
#   node['children'] -> {uci_move: child_node}; present only once the node is expanded
#   node['clean_p']  -> noise-free prior, cached the first time root noise is mixed in


def _apply_root_noise(tree, epsilon, alpha=0.3):
    """Mix Dirichlet(alpha) noise into the priors of the root's children.

    Noise is always mixed into the cached noise-free prior. Repeated calls on the same
    root (self_play reuses subtrees, and generate_train_data reuses the starting root)
    therefore do not compound noise until the network prior is washed out.
    """
    children = tree.get('children')
    if not children:
        return
    noise = np.random.default_rng().dirichlet([alpha] * len(children))
    for eta, child in zip(noise, children.values()):
        clean_prior = child.setdefault('clean_p', float(child['data'][3]))
        child['data'][3] = (1 - epsilon) * clean_prior + epsilon * eta


def _select_child(node, c):
    """Return the UCI key of the child maximising Q + U (PUCT); ties go to the first child."""
    children = node['children']
    # Clamped to 1: right after expansion every N is 0. Without the clamp, U would be 0
    # for all children and selection would fall back to dict order instead of the priors.
    sqrt_total = math.sqrt(max(1, sum(child['data'][0] for child in children.values())))
    best_move, best_score = None, float('-inf')
    for move, child in children.items():
        n, _, q, p = child['data']
        score = q + c * p * sqrt_total / (1 + n)
        if score > best_score:
            best_move, best_score = move, score
    if best_move is None:
        raise ValueError("PUCT scores are all NaN; check the network outputs")
    return best_move


def _evaluate_and_expand(node, net, board):
    """Score `board` and, if it is not terminal, attach its children to `node`.

    :return: value of the position for the side to move. For terminal positions this
        is 0.0 for stalemate and -1.0 for checkmate; otherwise it is the network's v.
    :raises ValueError: if the normalised network prior is outside [0, 1] (e.g. NaN)
    """
    legal_moves = list(board.legal_moves)
    if not legal_moves:
        # No legal moves: either a draw, or the side to move has been checkmated.
        return 0.0 if board.outcome().winner is None else -1.0

    # Run the network wherever its weights live instead of assuming CUDA.
    device = next((param.device for param in net.parameters()), torch.device('cpu'))
    p, v = net(torch.from_numpy(encode_board(board)).to(device))
    priors = p[0].cpu().numpy()

    # Renormalise the prior over legal moves only; fall back to uniform if it underflows.
    raw = [float(priors[encode_move(move)]) for move in legal_moves]
    total = sum(raw)
    if total == 0:
        probs = [1 / len(legal_moves)] * len(legal_moves)
    else:
        probs = [x / total for x in raw]

    # Build the children first so a bad prior cannot leave a half-expanded node behind.
    children = {}
    for move, prob in zip(legal_moves, probs):
        if not (0 <= prob <= 1):
            raise ValueError(f"PVAL ERROR: {prob}")
        children[move.uci()] = {'data': np.array((0.0, 0.0, 0.0, prob)), 'color': board.turn}
    node['children'] = children
    return float(v.item())


def _backup(path, v, leaf_turn):
    """Add the leaf value `v` (from `leaf_turn`'s perspective) to every traversed edge."""
    for node in path:
        stats = node['data']
        stats[0] += 1
        # W and Q are stored from the perspective of the player who plays the edge.
        stats[1] += v if node['color'] == leaf_turn else -v
        stats[2] = stats[1] / stats[0]


def _choose_move(tree, temp):
    """Pick the move to play from the root visit counts and build the policy target."""
    children = tree['children']
    moves = list(children)
    weights = np.array([children[m]['data'][0] for m in moves], dtype=np.float64)
    if weights.sum() == 0:
        # No simulation got past the root (e.g. iterations == 1): fall back to the priors.
        weights = np.array([children[m]['data'][3] for m in moves], dtype=np.float64)

    encoded_policy = np.zeros((8 * 8 * 73))
    if temp == 0:
        selected_move = moves[int(np.argmax(weights))]
        encoded_policy[encode_move(chess.Move.from_uci(selected_move))] = 1
        return selected_move, encoded_policy

    scaled = weights ** (1 / temp)
    probs = scaled / scaled.sum()
    random_sample = random.random()
    cumulative = 0.0
    selected_move = None
    for move, prob in zip(moves, probs):
        encoded_policy[encode_move(chess.Move.from_uci(move))] = prob
        cumulative += prob
        if selected_move is None and random_sample < cumulative:
            selected_move = move
    if selected_move is None:
        # Floating-point round-off left the running sum just below random_sample.
        selected_move = moves[-1]
    return selected_move, encoded_policy


@torch.no_grad()
def mcts_step(tree, net, board_in, c=1.0, temp=1.0, iterations=1000, epsilon=0.25):
    """
    Run `iterations` PUCT simulations from `board_in` and choose a move to play.

    :param tree: root node dict for `board_in` (may be unexpanded); mutated in place
    :param net: network called as net(x) on one encoded board; expected to return
        (p, v) with p of shape (1, 8*8*73), as ChessNet and TestNet do
    :param board_in: position to search from; not modified (simulations use copies)
    :param c: exploration constant in the PUCT formula
    :param temp: 0 plays the most visited move; otherwise moves are sampled in
        proportion to N ** (1 / temp)
    :param iterations: number of simulations; expanding a fresh root counts as one
    :param epsilon: weight of the Dirichlet noise mixed into the root priors (0 disables it)
    :return: (selected UCI move, subtree under that move to reuse as the next root,
        policy training target of length 8*8*73)
    :raises ValueError: if the root has no legal moves, or is unexpanded and iterations < 1
    """
    if 'children' not in tree and iterations > 0:
        # Expand the root up front so that the root noise applies from the first simulation.
        _evaluate_and_expand(tree, net, board_in.copy())
        iterations -= 1
    if 'children' not in tree:
        raise ValueError("mcts_step needs a position with legal moves and at least one iteration")

    _apply_root_noise(tree, epsilon)

    for _ in range(iterations):
        board = board_in.copy()
        path = []
        node = tree
        # Selection: descend through expanded nodes.
        while 'children' in node:
            move = _select_child(node, c)
            node = node['children'][move]
            path.append(node)
            board.push_uci(move)
        # Expansion and evaluation of the leaf (or terminal scoring), then backup.
        v = _evaluate_and_expand(node, net, board)
        _backup(path, v, board.turn)

    selected_move, encoded_policy = _choose_move(tree, temp)
    return selected_move, tree['children'][selected_move], encoded_policy

In [None]:
import graphviz 

def draw_mcts(tree):
    """
    Render a search tree as a left-to-right graphviz Digraph of record-shaped nodes.

    Nodes are numbered in breadth-first order (the root is 0), and each label shows
    that node's n, w, q and p values. The Digraph is returned so Jupyter can display it.
    """
    dot = graphviz.Digraph(format='png', graph_attr={'rankdir': 'LR'})

    root_stats = tuple(tree['data'][k] for k in range(4))
    dot.node(name=str(0), label="{ root | n: %d, w: %.4f, q: %.4f, p: %.4f }" % root_stats, shape='record')

    # Breadth-first walk; a moving head index replaces list.pop(0).
    pending = [(tree, 0)]
    head = 0
    next_id = 0
    while head < len(pending):
        subtree, parent_id = pending[head]
        head += 1
        for child in subtree.get('children', {}).values():
            next_id += 1
            stats = (next_id,) + tuple(child['data'][k] for k in range(4))
            dot.node(name=str(next_id), label="{ %d | n: %d, w: %.4f, q: %.4f, p: %.4f }" % stats, shape='record')
            dot.edge(str(parent_id), str(next_id))
            pending.append((child, next_id))

    return dot

# Quick visual check of one MCTS search from the starting position.
# `net` is only created in the "Create neural net" cell further down. On a fresh kernel
# (Restart & Run All) this cell used to fail with a NameError, so skip it until then.
if 'net' in globals():
    from IPython.display import display

    board = chess.Board()
    tree = {'data': np.array((0,0,0,1))}
    net.eval()
    mcts_step(tree, net, board)
    # Inside an `if`, the last expression is not auto-displayed, so display it explicitly.
    display(draw_mcts(tree))
else:
    print("Skipping MCTS demo: run the cell that creates `net` first.")

In [None]:
class TestNet(nn.Module):
    """
    Small fully connected policy/value network for quick experiments.

    The input is flattened to 8*8*21 features. forward() returns (p, v): p is a softmax
    over 8*8*73 move logits (batch x 4672) and v is a tanh value (batch x 1).
    """

    def __init__(self):
        super().__init__()
        hidden = 1024
        self.layer1 = nn.Linear(8 * 8 * 21, hidden)
        self.layer2 = nn.Linear(hidden, hidden)
        self.layer3 = nn.Linear(hidden, hidden)
        self.fc1 = nn.Linear(hidden, 8 * 8 * 73)   # policy head
        self.fc2 = nn.Linear(hidden, 1)            # value head

    def forward(self, x):
        h = x.view(-1, 8 * 8 * 21)
        for layer in (self.layer1, self.layer2, self.layer3):
            h = F.relu(layer(h))
        return F.softmax(self.fc1(h), dim=1), torch.tanh(self.fc2(h))

In [None]:
# Simple model of the data needed to make a training point
class TrainingPoint:
    """
    One self-play training example: a position, its MCTS policy target and the game result.

    :param s: position the move was searched from (copied, so later moves don't alter it)
    :param pi: MCTS policy target, a flat array of length 8*8*73 as built by mcts_step
    :param z: game result from the perspective of the side to move in ``s``; normally
        filled in later by set_z_value
    """

    def __init__(self, s: "chess.Board", pi: np.ndarray, z=None) -> None:
        self.s = s.copy()
        self.pi = pi
        self.z = z

    def get_tensor(self):  # Returns X, (pi, z)
        """
        Return ``(X, (pi, z))`` as float32 tensors for load_batch.

        X has shape (8, 8, 21), pi has shape (8*8*73,) and z is a 0-d tensor.
        z is float32, like X and pi (it used to be an int64 tensor).

        :raises ValueError: if the game result ``z`` has not been set yet
        """
        if self.z is None:
            raise ValueError("TrainingPoint.z is unset; call set_z_value() first")
        x = torch.from_numpy(encode_board(self.s).astype(np.float32))
        pi = torch.from_numpy(self.pi.astype(np.float32))
        z = torch.tensor(self.z, dtype=torch.float32)
        return x, (pi, z)

    def set_z_value(self, outcome: "chess.Outcome"):
        """Set z: 0 for a draw, 1 if the side to move in ``s`` won, -1 if it lost."""
        if outcome.winner is None:
            self.z = 0
        elif outcome.winner == self.s.turn:
            self.z = 1
        else:
            self.z = -1

In [None]:
def self_play(net, tree=None, iters=32):
    """
    Play one game of `net` against itself and return one TrainingPoint per move.

    Moves are sampled with temperature 1 before full move 30 and chosen greedily
    (temperature 0) afterwards. Once the game ends, every point's z is set from the
    final outcome.

    :param net: network passed through to mcts_step
    :param tree: optional root search tree for the starting position (used if truthy)
    :param iters: MCTS simulations per move
    :return: list of TrainingPoint
    """
    board = chess.Board()
    subtree = tree if tree else {'data': np.array((0, 0, 0, 1))}
    datapoints = []

    while board.outcome() is None:
        temp = 1 if board.fullmove_number < 30 else 0
        move, subtree, pi = mcts_step(subtree, net, board, temp=temp, iterations=iters)
        datapoints.append(TrainingPoint(board.copy(), pi))
        board.push_uci(move)

    # The outcome is fixed once the game is over; compute it once for every point.
    final_outcome = board.outcome()
    for datapoint in datapoints:
        datapoint.set_z_value(final_outcome)

    return datapoints

In [None]:
def generate_train_data(net, episodes=100, iters=32):
    """
    Collect TrainingPoints from `episodes` self-play games.

    Every game starts from the same shared root tree, so search statistics gathered
    in earlier games carry over into later ones (intended to add variety).
    """
    shared_root = {'data': np.array((0, 0, 0, 1))}
    train_data = []
    for episode in range(episodes):
        train_data.extend(self_play(net, tree=shared_root, iters=iters))
        print(f"Episode {episode} complete")
    return train_data

In [None]:
# Loss function defined in the paper
def loss(p, v, pi, z):
    """
    AlphaZero loss averaged over the batch: (z - v)^2 - pi . log(p).

    The paper's L2 weight penalty is not included here (use the optimiser's
    weight_decay if it is wanted).

    :param p: predicted move probabilities, shape (batch, 8*8*73)
    :param v: predicted values, shape (batch, 1) or (batch,)
    :param pi: MCTS policy targets, same shape as p
    :param z: game results, shape (batch, 1) or (batch,); any numeric dtype
    :return: scalar tensor
    """
    batch = p.size(0)
    # Flatten both sides. Subtracting a (batch,) tensor from a (batch, 1) tensor would
    # broadcast to (batch, batch) and add every cross pair to the value loss.
    value_loss = torch.sum((z.reshape(-1).to(v.dtype) - v.reshape(-1)).square())
    # The 1e-8 keeps log() finite where the softmax underflows to 0.
    policy_loss = -torch.sum(pi * torch.log(p + 1e-8))
    return (value_loss + policy_loss) / batch

In [None]:
class ChessNet(nn.Module):
    """
    Convolutional policy/value network over the (8, 8, 21) board encoding.

    Architecture:
    - two padded 3x3 convs that keep the 8x8 board;
    - two unpadded 3x3 convs that shrink it to 4x4;
    - two fully connected layers with batch norm and dropout;
    - a softmax policy head over 8*8*73 moves and a tanh value head.

    Inputs are channels-last, as produced by encode_board: anything that reshapes to
    (-1, 8, 8, 21). Checkpoints trained before the channel-order fix in forward() still
    load (the shapes are unchanged), but they were trained on scrambled inputs and
    should be retrained.

    :param num_channels: width of the convolutional trunk (default 512)
    :param dropout: dropout probability for the fully connected layers (default 0.3)
    """

    def __init__(self, num_channels=512, dropout=0.3):
        super(ChessNet, self).__init__()
        self.num_channels = num_channels
        self.dropout = dropout
        self.conv1 = nn.Conv2d(21, num_channels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(num_channels, num_channels, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(num_channels, num_channels, 3, stride=1)
        self.conv4 = nn.Conv2d(num_channels, num_channels, 3, stride=1)

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.bn3 = nn.BatchNorm2d(num_channels)
        self.bn4 = nn.BatchNorm2d(num_channels)

        self.fc1 = nn.Linear(num_channels*(8-4)*(8-4), 1024)
        self.fc_bn1 = nn.BatchNorm1d(1024)

        self.fc2 = nn.Linear(1024, 512)
        self.fc_bn2 = nn.BatchNorm1d(512)

        self.fc3 = nn.Linear(512, 8*8*73)   # policy head

        self.fc4 = nn.Linear(512, 1)        # value head

    def forward(self, s):
        """Return (p, v): p is (batch, 8*8*73) move probabilities, v is (batch, 1) in [-1, 1]."""
        # encode_board is channels-last (8, 8, 21), but Conv2d expects channels-first.
        # A bare .view(-1, 21, 8, 8) only reinterprets memory, mixing squares with
        # feature planes, so reshape to channels-last and permute explicitly.
        s = s.reshape(-1, 8, 8, 21).permute(0, 3, 1, 2)   # batch x 21 x 8 x 8
        s = F.relu(self.bn1(self.conv1(s)))                # batch x C x 8 x 8
        s = F.relu(self.bn2(self.conv2(s)))                # batch x C x 8 x 8
        s = F.relu(self.bn3(self.conv3(s)))                # batch x C x 6 x 6
        s = F.relu(self.bn4(self.conv4(s)))                # batch x C x 4 x 4
        s = s.reshape(-1, self.num_channels*(8-4)*(8-4))

        s = F.dropout(F.relu(self.fc_bn1(self.fc1(s))), p=self.dropout, training=self.training)  # batch x 1024
        s = F.dropout(F.relu(self.fc_bn2(self.fc2(s))), p=self.dropout, training=self.training)  # batch x 512

        pi = self.fc3(s)   # batch x 8*8*73
        v = self.fc4(s)    # batch x 1

        return F.softmax(pi, dim=1), torch.tanh(v)

# New Section


Where is the best place to gather the data? We could either store the positions in the play loop, or have the MCTS return a list of every position that led to the end. What does each datapoint look like? It should be the encoded board as input, and the encoded action and p value as output.
But shouldn't the p value actually be the policy generated by the MCTS? Actually, it looks like the p value isn't needed at all... why?
Turns out I misunderstood. The outputs of the network are pi and v, where pi is the policy and v is the estimated value of the state. So we use the game's outcome as the target state value, and the root visit counts from the MCTS (normalized, with the temperature applied) as the policy target.

Tomorrow, should get to building the neural network and figuring out all the loss function stuff.
Today I figured out how to collect the training points from the network
TODO:
- loss function
- train loop
- build real nn
- collect data in self play loop


# Basic PyTorch Setup
### Dataset
We will need to place our data into a DataLoader object so it can be iterated over in the training loop. This object will take our data set and deliver it to us in whatever batch size we ask for. We can also take our training data and place it into a Dataset object, which will be put into the dataloader.  

### Model
To make the model, we have it inherit from nn.Module, define the layers in the init function, and then show how the data passes through the layers in the forward function. We also have to make sure to move the model to the GPU if its available.

### Optimize
Create a loss function and an optimization function (usually SGD or Adam).

### Train
The train loop first calls model.train(), which switches layers such as dropout and batch norm into training mode (model.eval() switches them back). It then takes the data from the dataloader and moves it to the GPU. It does a forward pass on the training data to get the predictions, then measures the loss. Next it zeroes the optimizer's gradients and computes new gradients from the loss. Zeroing is needed because PyTorch accumulates gradients across backward() calls by default. Then it takes an optimizer step to update the weights using the new gradients. You should also print training metrics such as the loss and epoch number periodically.

### Test
Load some test data and perform forward passes on all of it to get the model's predictions. You can turn off gradients to improve performance. Your test loss is the average of the loss over the test set, and your accuracy is the fraction of predictions the model got correct.

### Train/test loop
One epoch of training consists of training the model on the entire dataset and then testing the model. It's also common to save checkpoints of the model throughout training, in case an earlier version performed well before the model started to overfit.

# Tensors
Tensors work pretty much the same as numpy ndarrays, and can share the same underlying memory. The difference is that tensors can run on a GPU, while ndarrays only run on the CPU. I expect I'd mostly need this when I build some data in a numpy array (like in the MCTS) and then need to move it to the GPU for training.
A tensor has a shape, a dtype, and a device it lives on. Tensors are created on the CPU by default and have to be moved to the GPU. You can also create them directly on the GPU by passing `device='cuda'` to constructors such as `torch.zeros`.  
Matrix product: tensor.matmul(tensor.T) == tensor @ tensor.T  
Element-wise product: tensor * tensor == tensor.mul(tensor)

You can convert one element tensors, like sums, into a python numerical value using tensor.item(). Might be useful for loss functions.  

# Autograd
In order to get the gradients, you will need to set requires_grad=True on your tensors and then call for a backwards pass on your loss function (or any output you get that can give a gradient of your weights). Sometimes though you might want to disable gradient tracking in a network, like if you want to freeze some params when finetuning a network or to speed up computations when doing a forward pass. 

In [None]:
def load_batch(dataset, batch_size, random_sample=True):
    """
    Stack `batch_size` TrainingPoints into batched tensors.

    :param dataset: list of objects with a get_tensor() -> (x, (pi, z)) method
    :param batch_size: number of points to take
    :param random_sample: draw points uniformly without replacement (True), or take
        the first `batch_size` points in order (False)
    :return: (X, (pi, z)) with z reshaped to (batch_size, 1)
    """
    chosen = random.sample(dataset, batch_size) if random_sample else dataset[:batch_size]
    tensors = [datapoint.get_tensor() for datapoint in chosen]
    xs = torch.stack([x for x, _ in tensors])
    pis = torch.stack([pi for _, (pi, _) in tensors])
    zs = torch.stack([z for _, (_, z) in tensors]).view(-1, 1)
    return xs, (pis, zs)

In [None]:
def train_loop(train_data, model, loss_fn, optimizer, device, batch_size=64):
    """
    Run one epoch of optimisation over `train_data`.

    The data is shuffled once and split into len(train_data) // batch_size disjoint
    batches, so every point is used at most once per epoch. A trailing partial batch
    is skipped. Previously each batch was sampled independently, so an "epoch"
    repeated some points and never saw others. The model is put into training mode
    first.

    :param train_data: list of TrainingPoint-like objects (see load_batch)
    :param loss_fn: callable (p, v, pi, z) -> scalar loss tensor
    """
    model.train()
    size = len(train_data)
    num_batches = size // batch_size
    shuffled = random.sample(train_data, size)  # shuffled copy; train_data is untouched
    for i in range(num_batches):
        chunk = shuffled[i * batch_size:(i + 1) * batch_size]
        X, (pi, z) = load_batch(chunk, batch_size, random_sample=False)
        X = X.to(device)
        pi = pi.to(device)
        z = z.to(device)

        # Compute prediction and loss
        p, v = model(X)
        assert p.shape == pi.shape
        assert v.shape == z.shape
        batch_loss = loss_fn(p, v, pi, z)

        # Backpropagation
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if i % 100 == 0:
            current = i * len(X)
            print(f"loss: {batch_loss.item():>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(test_data, model, loss_fn, device, batch_size=64):
    size = len(test_data)
    num_batches = size // batch_size
    if not num_batches:
        num_batches = 1
        batch_size = size
    test_loss, correct = 0, 0

    with torch.no_grad():
        for i in range(0, num_batches):
            X, (pi, z) = load_batch(test_data, batch_size)
            X = X.to(device)
            pi = pi.to(device)
            z = z.to(device)

            p, v = model(X)
            assert p.shape == pi.shape
            assert v.shape == z.shape
            test_loss += loss_fn(p, v, pi, z).item()
            correct += (p.argmax(1) == pi.argmax(1)).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
def generate_games(net, iters=32, episodes=100):
    """
    Generate fresh self-play data with `net` in eval mode, shuffle it and split it 80/20.

    :return: (train_data, test_data) lists of TrainingPoint
    """
    net.eval()
    games = generate_train_data(net, episodes=episodes, iters=iters)
    random.shuffle(games)
    cut = (len(games) // 5) * 4
    return games[:cut], games[cut:]

In [None]:
def train_test_loop(net, train_data, test_data, epochs=50):
    """
    Alternate one training epoch and one evaluation pass, `epochs` times.

    Reads the notebook-level globals `loss_fn`, `optimizer`, `device` and
    `batch_size`, which are defined in the configuration cells further down.
    """
    net.train()
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_data, net, loss_fn, optimizer, device, batch_size=batch_size)
        test_loop(test_data, net, loss_fn, device, batch_size=batch_size)
    print("Done!")

In [None]:
import os
def save_checkpoint(nnet, folder='checkpoint', filename='checkpoint.pth.tar'):
    """
    Save `nnet`'s weights to folder/filename as ``{'state_dict': nnet.state_dict()}``.

    Missing directories are created, including intermediate ones. The previous
    os.mkdir only created the last path component and failed for nested paths.
    """
    filepath = os.path.join(folder, filename)
    if not os.path.exists(folder):
        print("Checkpoint Directory does not exist! Making directory {}".format(folder))
        os.makedirs(folder, exist_ok=True)
    else:
        print("Checkpoint Directory exists! ")
    torch.save({
        'state_dict': nnet.state_dict(),
    }, filepath)

def load_checkpoint(nnet, folder='checkpoint', filename='checkpoint.pth.tar', map_location=None):
    """
    Load a checkpoint written by save_checkpoint into `nnet`, in place.

    :param map_location: forwarded to torch.load. It defaults to the device of
        `nnet`'s parameters, so checkpoints saved on a GPU also load on CPU-only machines.
    :raises FileNotFoundError: if folder/filename does not exist. The previous version
        raised a plain string, which itself failed with TypeError.
    """
    # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):
        raise FileNotFoundError("No model in path {}".format(filepath))
    if map_location is None:
        map_location = next((param.device for param in nnet.parameters()), torch.device('cpu'))
    checkpoint = torch.load(filepath, map_location=map_location)
    nnet.load_state_dict(checkpoint['state_dict'])

In [None]:
# Training hyperparameters and the device used for training.
learning_rate, batch_size, epochs = 5e-3, 64, 100
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
# Create neural net on the configured device, along with its loss and Adam optimiser.
net = ChessNet()
net = net.to(device)
loss_fn = loss
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)

In [None]:
# load_checkpoint(net, folder="drive/MyDrive/weights")

In [None]:
# Endless generate -> train -> checkpoint cycle; interrupt the kernel to stop it.
# Each cycle plays 50 self-play games at 256 simulations per move, trains for
# 100 epochs, then saves the weights to Drive.
while True:
    train_data, test_data = generate_games(net, episodes=50, iters=256)
    train_test_loop(net, train_data, test_data, epochs=100)
    save_checkpoint(net, folder="drive/MyDrive/weights", filename='checkpoint.pth.tar')

In [None]:
from IPython.display import display

def test_net(net):
    """
    Watch `net` play one game against itself, displaying the board after every move.

    Uses 256 MCTS simulations per move with root noise disabled (epsilon=0). The
    sampling temperature is 1 before full move 30 and 0 afterwards.
    """
    board = chess.Board()
    root = {'data': np.array((0, 0, 0, 1)), 'color': board.turn}
    while board.outcome() is None:
        temperature = 0 if board.fullmove_number >= 30 else 1
        move, root, _ = mcts_step(root, net, board, temp=temperature, iterations=256, epsilon=0.0)
        board.push_uci(move)
        display(board)

In [None]:
# Profiling
import cProfile, pstats, io

def profile(fnc):
    """
    Decorator that runs `fnc` under cProfile and prints its stats sorted by cumulative time.

    The wrapper keeps `fnc`'s name and docstring (functools.wraps). The stats are also
    printed when `fnc` raises; the exception still propagates.
    """
    import functools

    @functools.wraps(fnc)
    def inner(*args, **kwargs):
        pr = cProfile.Profile()
        pr.enable()
        try:
            return fnc(*args, **kwargs)
        finally:
            pr.disable()
            s = io.StringIO()
            sortby = 'cumulative'
            ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
            ps.print_stats()
            print(s.getvalue())

    return inner

In [None]:
# Watch the current network play a full game against itself.
net.train(False)
test_net(net)