In [None]:
# NOTE(review): unpinned install pulls whatever Lightning release is current; consider
# pinning the version this notebook was developed against for reproducibility.
!pip install pytorch-lightning

In [None]:
import os
from datetime import datetime

import numpy as np
import pytorch_lightning as pl
import torch
import torch._dynamo
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from sklearn.metrics import accuracy_score
from torch import nn
from torch.utils.data import Dataset, Subset, DataLoader, TensorDataset, IterableDataset
from torch.utils.tensorboard import SummaryWriter

torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision('high')

In [None]:
# Install the 7z command-line tool used by the extraction cells below.
!apt-get install p7zip-full

In [None]:
from google.colab import drive
import subprocess

# Mount Google Drive (Colab only) so the zipped datasets and checkpoint directories
# under /content/drive are readable and writable from this notebook.
drive.mount('/content/drive')

In [None]:
# Define the path to your ZIP file and the extraction directory
zip_file_path = "/content/drive/MyDrive/ChessTrainingData/gameDataSP8.zip"
extract_to_path = "/content/dataset"

# Extract with the 7z command-line tool.
# `-y` answers "yes" to every prompt (e.g. overwriting already-extracted files), so
# re-running this cell does not stall waiting for interactive input.
# 7-Zip exit codes: 0 = success, 1 = warning (non-fatal), 2 and above = error.
extraction = subprocess.run(['7z', 'x', zip_file_path, f'-o{extract_to_path}', '-y'])
if extraction.returncode > 1:
    raise RuntimeError(f"7z failed to extract {zip_file_path} (exit code {extraction.returncode})")

# Sanity check: list the files extracted into /content/dataset
!ls /content/dataset


In [None]:
test_zip_file_path = "/content/drive/MyDrive/ChessTrainingData/testGameData.zip"
test_extract_to_path = "/content/test/dataset"

# Extract the validation archive with the 7z command-line tool.
# `-y` answers "yes" to overwrite prompts so re-running this cell cannot block, and a
# fatal 7-Zip exit code (>= 2; 1 is only a warning) is raised instead of ignored.
test_extraction = subprocess.run(['7z', 'x', test_zip_file_path, f'-o{test_extract_to_path}', '-y'])
if test_extraction.returncode > 1:
    raise RuntimeError(f"7z failed to extract {test_zip_file_path} (exit code {test_extraction.returncode})")

In [None]:
class IterableGameDataset(IterableDataset):
    """Stream ``(encoded_state, policy, value)`` samples from three parallel binary files.

    Sample ``i`` is a fixed-size record at position ``i`` in each file:

    * ``encoded_states_path`` -- ``num_encoded_planes * game_row_count * game_column_count``
      bits, packed 8 per byte, least-significant bit first. Decoded to an ``int8`` 0/1
      array of shape ``(num_encoded_planes, game_row_count, game_column_count)``.
    * ``policy_labels_path`` -- ``num_output_planes * game_row_count * game_column_count``
      bits, packed the same way. Decoded to a flat ``int8`` 0/1 vector.
    * ``value_labels_path`` -- one ``int32`` per sample.

    Records are read on demand by seeking into the files, so only the current sample
    is held in memory. With several DataLoader workers the index range is split into
    contiguous, non-overlapping chunks (one per worker). Samples are yielded in file
    order; nothing here shuffles them.
    """

    def __init__(self, encoded_states_path, policy_labels_path, value_labels_path,
                 num_encoded_planes, num_output_planes, game_row_count, game_column_count):
        super().__init__()
        self.encoded_states_path = encoded_states_path
        self.policy_labels_path = policy_labels_path
        self.value_labels_path = value_labels_path
        self.num_encoded_planes = num_encoded_planes
        self.num_output_planes = num_output_planes
        self.game_row_count = game_row_count
        self.game_column_count = game_column_count

        # One int32 per sample, so the count follows from the file size
        # without reading the whole file into memory.
        self.total_samples = os.path.getsize(self.value_labels_path) // np.dtype(np.int32).itemsize

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading: the whole range
            iter_start, iter_end = 0, self.total_samples
        else:  # in a worker process: take this worker's contiguous chunk
            per_worker = -(-self.total_samples // worker_info.num_workers)  # ceil division
            iter_start = worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self.total_samples)

        for index in range(iter_start, iter_end):
            yield self.load_sample(index)

    def load_sample(self, index):
        """Return ``(encoded_state, policy, value)`` for sample ``index``.

        Only planes ``[:14]`` and ``[56:]`` of the stored state are kept (22 planes when
        64 are stored); the planes in between are dropped.
        """
        encoded_state = self.load_encoded_state(index)
        policy = self.load_policy(index)
        value = self.load_value(index)

        return np.r_[encoded_state[:14], encoded_state[56:]], policy, value

    def load_value(self, index):
        """Return the ``int32`` value label of sample ``index``."""
        return self._read_record(self.value_labels_path, np.int32, index, 1)[0]

    def load_policy(self, index):
        """Return the bit-packed policy label of sample ``index`` as a flat int8 0/1 vector."""
        num_bits = self.game_row_count * self.game_column_count * self.num_output_planes
        packed = self._read_record(self.policy_labels_path, np.uint8, index, num_bits // 8)
        # reshape doubles as a length check: it raises on a truncated record.
        return self._unpack_bits(packed).reshape(num_bits)

    def load_encoded_state(self, index):
        """Return the board encoding of sample ``index`` as ``(planes, rows, cols)`` int8 0/1."""
        num_bits = self.num_encoded_planes * self.game_row_count * self.game_column_count
        packed = self._read_record(self.encoded_states_path, np.uint8, index, num_bits // 8)
        return self._unpack_bits(packed).reshape(
            (self.num_encoded_planes, self.game_row_count, self.game_column_count))

    @staticmethod
    def _read_record(path, dtype, index, count):
        """Read record ``index`` (``count`` consecutive items of ``dtype``) from a fixed-width file."""
        with open(path, 'rb') as file:
            file.seek(index * count * np.dtype(dtype).itemsize)
            return np.fromfile(file, dtype=dtype, count=count)

    @staticmethod
    def _unpack_bits(packed):
        """Expand ``uint8`` bytes into an ``int8`` 0/1 array, least-significant bit first."""
        return np.unpackbits(packed, bitorder='little').astype(np.int8)


In [None]:
class IterableSelfPlayDataset(IterableDataset):
    """Stream ``(encoded_state, policy, value)`` self-play samples from three parallel binary files.

    Sample ``i`` is a fixed-size record at position ``i`` in each file:

    * ``encoded_states_path`` -- ``num_encoded_planes * game_row_count * game_column_count``
      bits, packed 8 per byte, least-significant bit first. Decoded to an ``int8`` 0/1
      array of shape ``(num_encoded_planes, game_row_count, game_column_count)``.
    * ``policy_labels_path`` -- ``num_output_planes * game_row_count * game_column_count``
      unpacked ``float32`` values per sample, returned as a flat vector.
    * ``value_labels_path`` -- one ``float32`` per sample.

    Records are read on demand by seeking into the files. With several DataLoader
    workers the index range is split into contiguous, non-overlapping chunks (one per
    worker). Samples are yielded in file order; nothing here shuffles them.
    """

    def __init__(self, encoded_states_path, policy_labels_path, value_labels_path,
                 num_encoded_planes, num_output_planes, game_row_count, game_column_count):
        super().__init__()
        self.encoded_states_path = encoded_states_path
        self.policy_labels_path = policy_labels_path
        self.value_labels_path = value_labels_path
        self.num_encoded_planes = num_encoded_planes
        self.num_output_planes = num_output_planes
        self.game_row_count = game_row_count
        self.game_column_count = game_column_count

        # One float32 per sample, so the count follows from the file size
        # without reading the whole file into memory.
        self.total_samples = os.path.getsize(self.value_labels_path) // np.dtype(np.float32).itemsize

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading: the whole range
            iter_start, iter_end = 0, self.total_samples
        else:  # in a worker process: take this worker's contiguous chunk
            per_worker = -(-self.total_samples // worker_info.num_workers)  # ceil division
            iter_start = worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self.total_samples)

        for index in range(iter_start, iter_end):
            yield self.load_sample(index)

    def load_sample(self, index):
        """Return ``(encoded_state, policy, value)`` for sample ``index``.

        Only planes ``[:14]`` and ``[56:]`` of the stored state are kept (22 planes when
        64 are stored); the planes in between are dropped.
        """
        encoded_state = self.load_encoded_state(index)
        policy = self.load_policy(index)
        value = self.load_value(index)

        return np.r_[encoded_state[:14], encoded_state[56:]], policy, value

    def load_value(self, index):
        """Return the ``float32`` value label of sample ``index``."""
        return self._read_record(self.value_labels_path, np.float32, index, 1)[0]

    def load_policy(self, index):
        """Return the ``float32`` policy label of sample ``index`` as a flat vector."""
        num_items = self.game_row_count * self.game_column_count * self.num_output_planes
        policy = self._read_record(self.policy_labels_path, np.float32, index, num_items)
        # reshape doubles as a length check: it raises on a truncated record.
        return policy.reshape(num_items)

    def load_encoded_state(self, index):
        """Return the board encoding of sample ``index`` as ``(planes, rows, cols)`` int8 0/1."""
        num_bits = self.num_encoded_planes * self.game_row_count * self.game_column_count
        packed = self._read_record(self.encoded_states_path, np.uint8, index, num_bits // 8)
        return self._unpack_bits(packed).reshape(
            (self.num_encoded_planes, self.game_row_count, self.game_column_count))

    @staticmethod
    def _read_record(path, dtype, index, count):
        """Read record ``index`` (``count`` consecutive items of ``dtype``) from a fixed-width file."""
        with open(path, 'rb') as file:
            file.seek(index * count * np.dtype(dtype).itemsize)
            return np.fromfile(file, dtype=dtype, count=count)

    @staticmethod
    def _unpack_bits(packed):
        """Expand ``uint8`` bytes into an ``int8`` 0/1 array, least-significant bit first."""
        return np.unpackbits(packed, bitorder='little').astype(np.int8)


# Paths to the training files extracted into /content/dataset by the 7z cell above
input_path = '/content/dataset/'
encoded_states_path = input_path + 'encoded_states.bin'
policy_labels_path = input_path + 'policy_labels.bin'
value_labels_path = input_path + 'value_labels.bin'

game_row_count = 8
game_column_count = 8
# Planes *stored* per position in encoded_states.bin; the dataset readers use this to
# compute record sizes. NOTE: a later config cell rebinds this name to 22 (the number
# of planes the model receives after load_sample keeps planes [:14] and [56:]).
num_encoded_planes = 64
# Policy labels hold game_row_count * game_column_count * num_output_planes entries.
num_output_planes = 73

In [None]:
# Training dataset: self-play records (float32 policy and value labels) from /content/dataset.
# Relies on num_encoded_planes still being the on-disk plane count (64) set above.
train_dataset = IterableSelfPlayDataset(encoded_states_path, policy_labels_path, value_labels_path, num_encoded_planes,
                           num_output_planes, game_row_count, game_column_count)

In [None]:
# Paths to the validation files extracted into /content/test/dataset
test_input_path = '/content/test/dataset/'
test_encoded_states_path = test_input_path + 'encoded_states.bin'
test_policy_labels_path = test_input_path + 'policy_labels.bin'
test_value_labels_path = test_input_path + 'value_labels.bin'

# Create the validation dataset. It is read with IterableGameDataset (bit-packed policy
# labels, int32 values), unlike the self-play training set.
val_dataset = IterableGameDataset(test_encoded_states_path, test_policy_labels_path, test_value_labels_path, num_encoded_planes,
                           num_output_planes, game_row_count, game_column_count)

In [None]:
# Smoke test: decode the first training sample and display its value label.
encoded_state, policy, value = train_dataset.load_sample(0)
value

In [None]:
# Quantization stubs are kept commented out below; uncomment them (here and in ResNet)
# to enable the quantization-aware-training path for CPU inference.
class ResLayer(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages around an identity skip.

    ``padding=1`` keeps the spatial size and the channel count is unchanged, so the
    block input can be added directly to the second stage's output before the final
    ReLU.

    The attribute names (``conv1``, ``batch_norm1``, ``relu1``, ``conv2``,
    ``batch_norm2``, ``relu2``) appear in ``state_dict`` keys and are referenced by
    ``ResNet.fuse_modules``; keep them stable.
    """

    def __init__(self, input_layers):
        super().__init__()
        channels = input_layers
        self.input_layers = channels

        # Stage 1: conv -> BN -> ReLU
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=1)
        self.batch_norm1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()

        # Stage 2: conv -> BN (the activation comes after the skip addition)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=1)
        self.batch_norm2 = nn.BatchNorm2d(channels)

        # self.quant = torch.ao.quantization.QuantStub()
        # self.dequant = torch.ao.quantization.DeQuantStub()

        self.relu2 = nn.ReLU()

    def forward(self, x):
        hidden = self.relu1(self.batch_norm1(self.conv1(x)))
        hidden = self.batch_norm2(self.conv2(hidden))
        # hidden = self.dequant(hidden)
        # x = self.dequant(x)
        hidden += x
        # hidden = self.quant(hidden)
        return self.relu2(hidden)

class ResNet(nn.Module):
    """Convolutional policy/value network.

    The network is a 3x3 conv stem followed by ``num_hidden_layers`` residual blocks
    (``ResLayer``). Two heads then both read the tower output.

    Args:
        input_layers: number of input feature planes.
        intermediate_layers: channel width of the stem, residual blocks and heads.
        output_layers: number of policy planes per board square.
        num_hidden_layers: number of residual blocks.
        game_row_count: board height.
        game_column_count: board width.

    ``forward`` returns ``(policy, value)``:

    * ``policy`` holds ``output_layers * game_row_count * game_column_count`` raw
      (pre-softmax) outputs per sample.
    * ``value`` is one tanh-squashed scalar per sample, with shape ``(batch,)``.

    Attribute names and ``nn.Sequential`` positions define the ``state_dict`` keys used
    by saved checkpoints and the indices used by ``fuse_modules``; keep them stable.
    """

    def __init__(self, input_layers, intermediate_layers, output_layers, num_hidden_layers, game_row_count,
                 game_column_count):
        super().__init__()

        # self.quant = torch.ao.quantization.QuantStub()
        # self.dequant = torch.ao.quantization.DeQuantStub()
        self.input_layers = input_layers
        self.intermediate_layers = intermediate_layers
        self.policy_layers = output_layers
        self.num_hidden_layers = num_hidden_layers
        self.game_row_count = game_row_count
        self.game_column_count = game_column_count

        channels = intermediate_layers
        board_squares = game_row_count * game_column_count
        flat_width = channels * board_squares

        # Stem: input planes -> `channels` feature maps.
        self.startBlock = nn.Sequential(
            nn.Conv2d(input_layers, channels, (3, 3), padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

        # Residual tower.
        self.resBlocks = nn.ModuleList(ResLayer(channels) for _ in range(num_hidden_layers))

        # Policy head: conv block, then one dense output per (square, policy plane).
        self.policyHead = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), padding=1),      # 0
            nn.BatchNorm2d(channels),                              # 1
            nn.ReLU(),                                             # 2
            nn.Flatten(),                                          # 3
            nn.Linear(flat_width, output_layers * board_squares),  # 4
        )

        # Value head: conv block, a dense hidden layer as wide as its input, then tanh(scalar).
        self.valueHead = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), padding=1),  # 0
            nn.BatchNorm2d(channels),                          # 1
            nn.ReLU(),                                         # 2
            nn.Flatten(),                                      # 3
            nn.Linear(flat_width, flat_width),                 # 4
            # torch.ao.quantization.DeQuantStub(),
            nn.BatchNorm1d(flat_width),                        # 5
            # torch.ao.quantization.QuantStub(),
            nn.ReLU(),                                         # 6
            nn.Linear(flat_width, 1),                          # 7
            nn.Tanh(),                                         # 8
        )

    def forward(self, x):
        # x = self.quant(x)
        features = self.startBlock(x)
        for block in self.resBlocks:
            features = block(features)
        policy = self.policyHead(features)
        value = self.valueHead(features)
        # policy = self.dequant(policy)
        # value = self.dequant(value)
        return policy.squeeze(-1), value.squeeze(-1)

    def fuse_modules(self):
        """Fuse conv/BN(/ReLU) groups in place, ahead of quantization."""
        fuse = torch.ao.quantization.fuse_modules
        # Heads and stem: positions 0, 1, 2 are Conv2d, BatchNorm2d, ReLU.
        for head in (self.startBlock, self.policyHead, self.valueHead):
            fuse(head, ['0', '1', '2'], inplace=True)
        # Residual blocks: the second stage has no ReLU before the skip addition.
        for block in self.resBlocks:
            fuse(block, ['conv1', 'batch_norm1', 'relu1'], inplace=True)
            fuse(block, ['conv2', 'batch_norm2'], inplace=True)

In [None]:
class Network(pl.LightningModule):
    """Lightning wrapper that trains a ``ResNet`` with a joint policy + value loss.

    The loss is soft-target cross-entropy on the policy outputs plus mean-squared
    error on the value output. Training and validation data come from the
    notebook-level ``train_dataset`` and ``val_dataset`` objects.
    """

    def __init__(self, input_layers, intermediate_layers, output_layers, num_hidden_layers, game_row_count,
                 game_column_count):
        super(Network, self).__init__()

        self.intermediate_layers = intermediate_layers
        self.input_layers = input_layers
        self.policy_layers = output_layers
        self.num_hidden_layers = num_hidden_layers
        self.game_row_count = game_row_count
        self.game_column_count = game_column_count

        self.network = ResNet(self.input_layers, self.intermediate_layers, self.policy_layers,
                              self.num_hidden_layers, self.game_row_count, self.game_column_count)

    def forward(self, x):
        """Return ``(policy, value)`` from the wrapped ``ResNet``."""
        return self.network(x)

    def cross_entropy(self, outputs, targets):
        """Mean cross-entropy between logits ``outputs`` and target distributions ``targets``.

        Each row of ``targets`` weights the log-softmax of the matching row of
        ``outputs`` (softmax over dim 1), so one-hot and soft targets both work.
        """
        log_softmax_outputs = F.log_softmax(outputs, dim=1)
        loss = -torch.sum(targets * log_softmax_outputs, dim=1)
        return torch.mean(loss)

    def loss_function(self, policy_output, policy_target, value_output, value_target, model, lambda_reg=0.0001):
        """Return ``(total_loss, value_loss, policy_loss)``.

        ``model`` and ``lambda_reg`` are currently unused; they belong to the optional
        L2 penalty that is commented out below.
        """
        policy_loss = self.cross_entropy(policy_output, policy_target)
        value_loss = nn.functional.mse_loss(value_output, value_target)

        # l2_reg = torch.tensor(0.).to(value_output.device)

        # for param in model.parameters():
        #     l2_reg += param.norm(2) ** 2

        loss = value_loss + policy_loss
        # loss += lambda_reg * l2_reg

        return loss, value_loss, policy_loss

    def _evaluate_batch(self, batch):
        """Run one batch through the model and return ``(loss, policy_accuracy)``."""
        encoded_state, policy_target, value_target = batch

        # The datasets yield int8 / int32 / float32 arrays; the model and losses need float.
        encoded_state = encoded_state.float()
        policy_target = policy_target.float()
        value_target = value_target.float()

        policy_output, value_output = self(encoded_state)
        loss, _, _ = self.loss_function(policy_output, policy_target, value_output, value_target,
                                        model=self)

        # Top-1 agreement between predicted and target move. This stays on-device, so the
        # step no longer forces a GPU->CPU sync (the old .cpu().numpy() + sklearn path did).
        policy_accuracy = (policy_output.argmax(dim=1) == policy_target.argmax(dim=1)).float().mean()
        return loss, policy_accuracy

    def training_step(self, batch, batch_idx):
        loss, train_policy_accuracy = self._evaluate_batch(batch)

        # Train loss and accuracy per batch (and averaged per epoch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_policy_accuracy", train_policy_accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        val_loss, val_policy_accuracy = self._evaluate_batch(batch)

        # Log validation loss and accuracy ('val_loss' is monitored by ModelCheckpoint)
        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_accuracy', val_policy_accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return val_loss

    def on_validation_epoch_end(self):
        """Print the model's value estimate next to the stored result for two fixed positions.

        This is a quick qualitative check. It runs the inputs on ``self.device`` instead of
        moving the whole model mid-fit. It evaluates in eval mode, which BatchNorm1d needs
        for a batch of one, and then restores whatever mode the module was in. Probe
        indices beyond the validation set are skipped.
        """
        was_training = self.training
        self.eval()
        try:
            for i in (1320, 1321):
                if i >= val_dataset.total_samples:
                    continue  # probe position not present in this validation set
                test_encoded_state, _, test_value = val_dataset.load_sample(i)
                test_input = torch.from_numpy(test_encoded_state).float().unsqueeze(0).to(self.device)
                with torch.no_grad():  # Disable gradient computation
                    _, value_output = self(test_input)
                print(f'Computer Eval: {value_output.item()}, Actual Result: {test_value}')
        finally:
            self.train(was_training)

    def train_dataloader(self):
        """DataLoader over the global ``train_dataset`` (IterableDataset, split across workers)."""
        train_loader = DataLoader(train_dataset, batch_size=1024, persistent_workers=True, num_workers=8, pin_memory=True)
        return train_loader

    def val_dataloader(self):
        """DataLoader over the global ``val_dataset``."""
        val_loader = DataLoader(val_dataset, batch_size=4096, persistent_workers=True, num_workers=8, pin_memory=True)
        return val_loader

    def configure_optimizers(self):
        """Plain SGD; the LR drops 10x at optimizer steps 100k and 130k (scheduler stepped per step)."""
        optimizer = torch.optim.SGD(self.parameters(), lr=4e-6)
        # optimizer = torch.optim.RMSprop(self.parameters(), lr=2e-7)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100000, 130000], gamma=0.1)
        return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]

In [None]:
num_epochs = 1
game_row_count = 8
game_column_count = 8
# NOTE: this rebinds num_encoded_planes (64 in the dataset cell, the planes stored on
# disk) to 22, the planes the model receives after load_sample keeps planes [:14] and
# [56:] (14 + 8). Re-running the dataset-construction cells after this cell, without
# first re-running the cell that sets 64, would build readers with the wrong record size.
num_encoded_planes = 22
num_output_planes = 73
num_intermediate_layers = 192  # channel width of the residual tower
num_hidden_layers = 13  # number of residual blocks

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint


# Keeps the single best checkpoint by validation loss (save_top_k=1, monitor='val_loss').
# Despite the variable name, this is not a "most recent" checkpoint: a later epoch only
# replaces the file when its val_loss improves. Because `filename` is fixed, a file left
# over from an earlier run makes Lightning write a versioned name instead
# (e.g. best_checkpoint-v9.ckpt, as loaded further down).
checkpoint_callback_recent = ModelCheckpoint(
    dirpath="/content/drive/My Drive/13x192Checkpoints-test/recent",
    filename="best_checkpoint",
    every_n_epochs=1,  # consider saving at the end of every epoch
    save_top_k=1,
    monitor='val_loss',
    verbose=True
)

In [None]:
# Build the Lightning model. num_encoded_planes here is the model input width (22, from the
# config cell above), not the 64 planes stored on disk.
model = Network(num_encoded_planes, num_intermediate_layers, num_output_planes, num_hidden_layers, game_row_count,
                        game_column_count)
print(model)

In [None]:
# Turn on for CPU training
# model.eval()
# model.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# model.network.fuse_modules()
# model_prepared = torch.ao.quantization.prepare_qat(model.train())
# print(model_prepared)

In [None]:
# Restore weights written by the torch.save cell further down. map_location="cpu" makes
# this work on a runtime without a GPU even if the tensors were saved from CUDA.
# Note: the next cell overwrites these weights with a Lightning checkpoint if it is run too.
state_dict = torch.load('/content/drive/My Drive/13x192Checkpoints-test/current_model_weights.pth', map_location="cpu")
model.load_state_dict(state_dict)

In [None]:
# Restore model weights from a Lightning checkpoint (optimizer/scheduler state is ignored).
# map_location="cpu" keeps every tensor on the CPU, matching the identity storage mapping.
checkpoint = torch.load(
    "/content/drive/My Drive/13x192Checkpoints-test/recent/best_checkpoint-v9.ckpt",
    map_location="cpu",
)

# Only the network weights are needed here
model.load_state_dict(checkpoint['state_dict'])

In [None]:
# Log the learning rate every optimizer step (the MultiStepLR scheduler is stepped per step).
lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = Trainer(
    max_epochs=num_epochs,
    callbacks=[checkpoint_callback_recent, lr_monitor],
    precision=16  # 16-bit mixed precision; newer Lightning releases spell this "16-mixed"
)
# Data comes from Network.train_dataloader / val_dataloader.
trainer.fit(model)

In [None]:
# model_prepared = model_prepared.to('cpu')
# model_prepared.eval()
# model_int8 = torch.ao.quantization.convert(model_prepared)

In [None]:
# Save the current weights; the torch.load cell above reads this same file.
# Run this before the TorchScript export cell, which converts `model` to fp16 in place.
torch.save(model.state_dict(), '/content/drive/My Drive/13x192Checkpoints-test/current_model_weights.pth')

In [None]:
# save with jit scripting for c++ use
# NOTE: eval() and half() modify `model` in place. After this cell the in-memory model is
# fp16 and in eval mode, so re-running the weight-saving cell would store fp16 weights.
model.eval()
model.half()
script = model.to_torchscript()
torch.jit.save(script, "/content/drive/My Drive/13x192Checkpoints-test/ResNet13x192_nohistorySPtest.pt")

In [None]:
# NOTE(review): this loads ResNet13x192_nohistorySP.pt, but the export cell above writes
# ResNet13x192_nohistorySPtest.pt; confirm which file is intended.
# This also rebinds `model` from the Lightning Network to a TorchScript module.
model = torch.jit.load("/content/drive/My Drive/13x192Checkpoints-test/ResNet13x192_nohistorySP.pt")

In [None]:
# Cast to fp16 to match the .half() input tensor built below.
model.half()

In [None]:
# Show the structure of the loaded TorchScript module.
print(model)

In [None]:
# Number of samples available in the validation files.
val_dataset.total_samples

In [None]:
# Decode validation sample 0 for a manual spot check of the exported model.
test_encoded_state, test_policy, test_value = val_dataset.load_sample(0)

In [None]:
# Stored result label (int32) for that sample.
test_value

In [None]:
# Run the spot check on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Convert the decoded sample to an fp16 tensor on `device`. torch.as_tensor accepts both
# the NumPy array returned by load_sample and an already-converted tensor, so re-running
# this cell is safe; torch.from_numpy raised a TypeError on the second run.
test_encoded_state = torch.as_tensor(test_encoded_state).float().to(device).half()

In [None]:
# Single-position inference: print the value-head output next to the stored result, and
# the policy output at the index where the target policy is largest.
# NOTE: assuming the exported ResNet above, policy entries are raw (pre-softmax) outputs,
# not probabilities.
model.to(device)
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation
    policy_output, value_output = model(test_encoded_state.unsqueeze(0))  # add a batch dimension of 1
    policy_index = test_policy.argmax()  # index of the largest entry in the target policy
    policy = policy_output.squeeze(0)  # drop the batch dimension
    print(f'Computer Eval: {value_output.item()}, Actual Result: {test_value}, Actual Move Estimated Policy: {policy[policy_index].cpu().numpy()}, Policy Index: {policy_index}')