In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch as tch
# use proper seaborn styling
import seaborn as sns
sns.set_theme()

In [None]:
# dataset1-x.txt and dataset1-y.txt are comma-separated files of float/int values,
# holding the board states and the matching policies respectively (one sample per row)
x = np.loadtxt("datasets/dataset1-x.txt", delimiter=",")
y = np.loadtxt("datasets/dataset1-y.txt", delimiter=",")

# validate first, so a mismatch is never preceded by a misleading success message
assert len(x) == len(y), f"x has {len(x)} rows but y has {len(y)} rows"
print(f"{len(x)} datapoints loaded!")

In [None]:
# enforce that each y value sums to 1
# a row whose entries sum to zero would silently become NaN/inf here and poison
# every loss computed from it, so stop with a clear message instead
row_sums = np.sum(y, axis=1, keepdims=True)
assert np.all(row_sums != 0), "every policy row needs a non-zero sum to be normalised"
y = y / row_sums

# split the dataset into training and validation sets
# (the split is taken before shuffling, so validation is always the last 10% of the rows)
split = int(len(x) * 0.9)
x_train = x[:split]
y_train = y[:split]
x_val = x[split:]
y_val = y[split:]

# shuffle the datasets, applying the same permutation to inputs and targets
perm_train = np.random.permutation(len(x_train))
x_train = x_train[perm_train]
y_train = y_train[perm_train]
perm_val = np.random.permutation(len(x_val))
x_val = x_val[perm_val]
y_val = y_val[perm_val]

# convert to float32 tensors
x_train = tch.tensor(x_train, dtype=tch.float)
y_train = tch.tensor(y_train, dtype=tch.float)
print(f"x_train has dims {x_train.shape}")
print(f"y_train has dims {y_train.shape}")
x_val = tch.tensor(x_val, dtype=tch.float)
y_val = tch.tensor(y_val, dtype=tch.float)
print(f"x_val has dims {x_val.shape}")
print(f"y_val has dims {y_val.shape}")

# create a dataset class
class GomokuDataset(tch.utils.data.Dataset):
    """Pairs two equally indexed sequences so samples can be drawn as (input, target)."""

    def __init__(self, x, y):
        # both sequences are stored as given: no copy and no length check
        self.x, self.y = x, y

    def __len__(self):
        # the dataset size is defined by the inputs alone
        return len(self.x)

    def __getitem__(self, idx):
        sample = (self.x[idx], self.y[idx])
        return sample

# wrap the tensors in datasets, then batch them:
# both loaders yield batches of 64 samples and reshuffle on every pass
train_dataset = GomokuDataset(x=x_train, y=y_train)
val_dataset = GomokuDataset(x=x_val, y=y_val)
loader_options = dict(batch_size=64, shuffle=True)
train_loader = tch.utils.data.DataLoader(train_dataset, **loader_options)
val_loader = tch.utils.data.DataLoader(val_dataset, **loader_options)

In [None]:
# view the flat training tensors as stacks of 9x9 planes
x_shaped = x_train.reshape(-1, 2, 9, 9)
y_shaped = y_train.reshape(-1, 9, 9)

# show some sample data: one row of four panels for each of the first five samples
fig, axs = plt.subplots(5, 4, figsize=(12, 10))
for i, (board_ax, policy_ax, occupied_ax, zero_ax) in enumerate(axs):
    # first panel: the board state, plane 0 at full and plane 1 at half intensity
    board = x_shaped[i, 0] + x_shaped[i, 1] / 2
    board_ax.imshow(board, cmap="gray")

    # second panel: the policy
    policy = y_shaped[i]
    policy_ax.imshow(policy, cmap="inferno")

    # third panel: squares holding a non-zero board value
    nonzero_board = board != 0
    occupied_ax.imshow(nonzero_board, vmin=0, vmax=1)

    # fourth panel: squares whose policy value is exactly zero
    zero_policy = policy == 0
    zero_ax.imshow(zero_policy, vmin=0, vmax=1)

# remove the gridlines from every panel
for ax in axs.flat:
    ax.grid(False)

In [None]:
# define an extremely simple model
from matplotlib.pylab import f


class PolicyModel(tch.nn.Module):
    """Single linear layer mapping a flat 162-value input to 81 softmax probabilities."""

    def __init__(self):
        super().__init__()
        # two 9x9 planes in, one value per 9x9 square out
        self.fc = tch.nn.Linear(in_features=2 * 9 * 9, out_features=9 * 9)
        # normalises across the 81 outputs of each sample
        self.softmax = tch.nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.fc(x)
        return self.softmax(logits)

# define a convolutional model
# this is ever so slightly more complicated than the previous model
# as we need to reshape the input to be 4-dimensional
class ConvPolicyModel(tch.nn.Module):
    """Two 3x3 convolutions followed by a linear head giving 81 raw scores per sample.

    forward() returns unnormalised scores: the softmax layer is constructed but not
    applied, so callers normalise the output themselves.
    """

    def __init__(self):
        super().__init__()
        self.relu    = tch.nn.ReLU()
        # padding=1 keeps the 9x9 extent: 2x9x9 -> 8x9x9
        self.conv1   = tch.nn.Conv2d(in_channels=2, out_channels=8, kernel_size=3, padding=1)
        # no padding, so each spatial side shrinks by two: 8x9x9 -> 2x7x7
        self.conv2   = tch.nn.Conv2d(in_channels=8, out_channels=2, kernel_size=3)
        self.fc      = tch.nn.Linear(in_features=2 * 7 * 7, out_features=9 * 9)
        # not used by forward()
        self.softmax = tch.nn.Softmax(dim=1)

    def forward(self, x):
        # the input arrives flat; view it as a batch of two 9x9 planes
        planes = x.view(-1, 2, 9, 9)
        for conv in (self.conv1, self.conv2):
            planes = self.relu(conv(planes))
        # flatten each sample's 2x7x7 feature map for the linear head
        flat = planes.view(-1, 2 * 7 * 7)
        return self.fc(flat)

In [None]:
# instantiate the convolutional policy network,
# then an AdamW optimizer (learning rate 1e-3) over all of its parameters
model = ConvPolicyModel()
optimizer = tch.optim.AdamW(params=model.parameters(), lr=1e-3)

# create a loss function to match the probability distribution
def loss_fn(prediction_logits, target_distribution):
    """Binary cross-entropy between softmax(prediction_logits) and the target.

    The logits are normalised along dim 1 first; the cross-entropy is then taken
    element by element with the library's default reduction.
    """
    F = tch.nn.functional
    predicted_distribution = F.softmax(prediction_logits, dim=1)
    return F.binary_cross_entropy(predicted_distribution, target_distribution)

# create a function for masking off illegal moves
def mask_illegal_moves(model_prediction, board):
    """Suppress the scores of occupied squares so softmax gives them zero probability.

    Args:
        model_prediction: (batch, cells) floating-point tensor of raw policy scores.
            Both callers in this file apply softmax *after* masking, so the values
            are treated as logits, not probabilities.
        board: (batch, 2 * cells) tensor whose first and second halves are the two
            occupancy planes. A square is illegal when the sum of the two planes
            there is non-zero.

    Returns:
        A new tensor shaped like model_prediction, with every illegal entry replaced
        by the most negative finite value of its dtype. The input is not modified.
    """
    # derive the plane width from the prediction instead of hard-coding 81
    cells = model_prediction.shape[1]
    occupied = (board[:, :cells] + board[:, cells:]) != 0
    # Filling with 0 would only be right for probabilities: a logit of 0 still receives
    # a non-zero share after softmax. A finite minimum is used rather than -inf so that
    # a fully occupied board gives a uniform distribution instead of NaN.
    lowest = tch.finfo(model_prediction.dtype).min
    return model_prediction.masked_fill(occupied, lowest)

# create a training loop
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    """Train `model` for `epochs` passes over `train_loader`, validating after each pass.

    Every batch is moved to `device`, run through the model, masked with
    `mask_illegal_moves`, and scored with `loss_fn(masked_prediction, target)`.

    Args:
        model: module to optimise; it is moved to `device` in place.
        optimizer: optimizer already constructed over the model's parameters.
        loss_fn: callable taking (prediction, target) and returning a scalar tensor.
        train_loader: sized iterable of (board_state, search_policy) batches.
        val_loader: iterable of (board_state, search_policy) batches; may be empty.
        epochs: number of passes over `train_loader`.
        device: device the model and every batch are moved to.

    Returns:
        (losses, val_losses). `losses` holds one (global batch index, loss) tuple per
        training batch. `val_losses` holds one (global batch index reached at the end of
        the epoch, mean validation loss) tuple per epoch; the mean is NaN when
        `val_loader` yields no batches.
    """
    # The batches are moved to `device`, so the model has to live there too; this is a
    # no-op for the default "cpu".
    # NOTE(review): the optimizer is built by the caller before this move — confirm the
    # installed PyTorch keeps parameter identity across Module.to() when using a GPU.
    model.to(device)
    losses = []
    val_losses = []
    batches_per_epoch = len(train_loader)
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")
        model.train()
        for batch_idx, (board_state, search_policy) in enumerate(train_loader):
            board_state = board_state.to(device)
            search_policy = search_policy.to(device)
            optimizer.zero_grad()
            raw_policy = model(board_state)
            masked_policy = mask_illegal_moves(raw_policy, board_state)
            loss = loss_fn(masked_policy, search_policy)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            # x-coordinate for plotting: batches seen since the start of training
            total_batch_idx = epoch * batches_per_epoch + batch_idx
            losses.append((total_batch_idx, loss_value))
            if batch_idx % 64 == 0:
                print(f"Training batch {batch_idx}/{len(train_loader)}: loss {loss_value}")
        model.eval()
        val_total = 0.0
        val_batches = 0
        with tch.no_grad():
            for board_state, search_policy in val_loader:
                board_state = board_state.to(device)
                search_policy = search_policy.to(device)
                raw_policy = model(board_state)
                masked_policy = mask_illegal_moves(raw_policy, board_state)
                val_total += loss_fn(masked_policy, search_policy).item()
                val_batches += 1
        # an empty validation loader gives NaN instead of a ZeroDivisionError
        val_loss = val_total / val_batches if val_batches else float("nan")
        # validation points sit at the batch index where the epoch ended
        total_batch_idx = (epoch + 1) * batches_per_epoch
        val_losses.append((total_batch_idx, val_loss))
        print(f"Validation loss {val_loss}")
    return losses, val_losses

In [None]:
# train the model for 20 epochs, keeping the per-batch training losses
# and the per-epoch validation losses for plotting
loss_trace, val_loss_trace = train(
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=20,
)

In [None]:
# plot the loss trace and validation loss trace
# each trace is a sequence of (batch index, loss) pairs, so as an array
# column 0 is the x-axis and column 1 is the y-axis
loss_trace = np.array(loss_trace)
val_loss_trace = np.array(val_loss_trace)
for trace, label in ((loss_trace, "Training loss"), (val_loss_trace, "Validation loss")):
    plt.plot(trace[:, 0], trace[:, 1], label=label)
plt.xlabel("Batch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# save the model: only the state dict (parameters and buffers) is written to disk
model_weights = model.state_dict()
tch.save(model_weights, "model.pt")

In [None]:
# visualize the model's predictions
model.eval()

# get ten random items (drawn with replacement) from the validation set;
# the plotting cell below shows the first five of them
rand_idx = np.random.randint(len(val_dataset), size=10)
x_sample = x_val[rand_idx]
y_sample = y_val[rand_idx]

# inference only, so skip building an autograd graph
with tch.no_grad():
    y_pred_raw = model(x_sample)
    y_pred = mask_illegal_moves(y_pred_raw, x_sample)
    # apply softmax to get a probability distribution
    y_pred = tch.nn.functional.softmax(y_pred, dim=1)

# reshape the flat vectors into 9x9 planes for plotting
x_sample = x_sample.reshape(-1, 2, 9, 9)
y_sample = y_sample.reshape(-1, 9, 9)
y_pred = y_pred.numpy().reshape(-1, 9, 9)

In [None]:
# plot the results
# the first column is the board state
# the second column is the search policy
# the third column is the neural network policy
fig, axs = plt.subplots(5, 3, figsize=(7, 13))
# label the columns
axs[0, 0].set_title("Board state")
axs[0, 1].set_title("Search policy")
axs[0, 2].set_title("Neural network policy")
for i in range(5):
    board_one = x_sample[i][0]
    board_two = x_sample[i][1]
    search = y_sample[i]
    nn_policy = y_pred[i]
    # the two planes are swapped when they hold the same number of pieces
    pieces_on_first_board = board_one.flatten().sum()
    pieces_on_second_board = board_two.flatten().sum()
    x_to_move = pieces_on_first_board != pieces_on_second_board
    if not x_to_move:
        board_one, board_two = board_two, board_one
    # RGB image of the pieces: one plane in red, the other in blue, dimmed to 2/3 brightness
    board = np.stack([board_one, np.zeros((9, 9)), board_two], axis=2) / 1.5
    # grey-scale RGB images of the two policies
    search_policy_board = np.stack([search, search, search], axis=2)
    nn_policy_board = np.stack([nn_policy, nn_policy, nn_policy], axis=2)

    # renormalise the policies so that the max prediction is 1.0
    search_policy_board *= 1.0 / search_policy_board.max()
    nn_policy_board *= 1.0 / nn_policy_board.max()

    # overlay the pieces; the sum can exceed 1.0, so clip to the valid float RGB range
    # explicitly instead of relying on imshow to clip it
    search_policy_board = np.clip(search_policy_board + board, 0.0, 1.0)
    nn_policy_board = np.clip(nn_policy_board + board, 0.0, 1.0)

    # these are RGB images, so no colormap or vmin/vmax is passed
    axs[i, 0].imshow(board)
    axs[i, 1].imshow(search_policy_board)
    axs[i, 2].imshow(nn_policy_board)

# set gridlines to 1x1, placing the ticks on the borders between squares
for ax in axs.flatten():
    ax.set_xticks(np.arange(-.5, 9, 1))
    ax.set_yticks(np.arange(-.5, 9, 1))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(True, color="grey", linewidth=0.5)

In [None]:
# export model to ONNX
onnx_model_path = "model.onnx"

# example input handed to the exporter: one random flat board of 2 * 9 * 9 = 162 values
dummy_input = tch.randn(1, 2 * 9 * 9)

# export the model, passing the example input as a one-element argument tuple
export_args = (dummy_input,)
tch.onnx.export(model, export_args, onnx_model_path, verbose=True)

In [None]:
# verify the model with onnxruntime
import onnxruntime as ort

# reference output from PyTorch on the sampled validation boards (no gradients needed)
common_input_data = x_sample
with tch.no_grad():
    pytorch_net_output = model(common_input_data).numpy()

ort_session = ort.InferenceSession(onnx_model_path)
# ask the session for its input name rather than hard-coding the exporter's
# auto-generated one, which is not a stable identifier
onnx_input_name = ort_session.get_inputs()[0].name

# run the samples through ONNX one at a time, each as a (1, 162) batch
onnx_net_output = []
for i in range(len(common_input_data)):
    input_thing = {onnx_input_name: common_input_data[i].numpy().reshape(1, 162)}
    onnx_net_output.append(ort_session.run(None, input_thing)[0])
# drop the per-sample batch axis: (n, 1, 81) -> (n, 81)
onnx_net_output = np.squeeze(np.array(onnx_net_output), axis=1)

# compare the outputs
assert pytorch_net_output.shape == onnx_net_output.shape
assert np.allclose(pytorch_net_output, onnx_net_output, rtol=1e-03, atol=1e-05)