In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import random

from QuantumChessGame import * 
from ChessPuzzles import *
from GameToTensor import *
from ChessPuzzles import chess_puzzles

from MCTS import MCTS_Node

import numpy as np
import pandas as pd 

import QChessNN
import MCTS_NN

In [2]:
SEED = 42
# Seed every RNG the notebook draws from: torch (weight initialisation),
# Python's `random` (puzzle selection in self_play_game) and numpy (imported
# above; seeded in case helper modules draw from its global RNG).
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# Set LOAD_CHECKPOINT = True to resume from the TorchScript file that the
# training loop saves ('testExport.pt'); otherwise start from fresh weights.
LOAD_CHECKPOINT = False
NNmodel = torch.jit.load('testExport.pt') if LOAD_CHECKPOINT else QChessNN.QChessNN()

In [3]:
import pandas as pd
#import matplotlib.pyplot as plt
#%matplotlib inline

In [4]:
# Sanity check: encode the opening position of a brand-new game into the
# (1, 12, 8, 8) buffer shape that NNmodel is later called with.
game_tensor = torch.zeros(1, 12, 8, 8)
y = torch.zeros(12)  # NOTE(review): `y` is not referenced in any cell shown here

game = QuantumChessGame()
game.new_game()
gameData = game.get_game_data()
game_tensor[0] = gameToTensor(gameData, 0)

In [5]:
class MCTS_AI:
    """Move picker backed by the plain MCTS_Node search from the `MCTS` module."""

    def find_best_move(self, game, simVar):
        """Return the move chosen by a search rooted at ``game``'s current position.

        The current ply and ``simVar`` are forwarded to ``MCTS_Node.best_action``;
        whatever it returns is passed straight back (no value estimate is added).
        """
        search_root = MCTS_Node(game)
        return search_root.best_action(game.get_game_data().ply, simVar)

In [6]:
class NetworkMCTS:
    """Move picker backed by the model-aware MCTS_Node from the `MCTS_NN` module."""

    def find_best_move(self, game, model, simVar):
        """Search from ``game``'s current position using ``model``.

        Returns the ``(move, value)`` pair produced by ``MCTS_NN.MCTS_Node.best_action``,
        which receives the current ply, ``model`` and ``simVar``.
        """
        search_root = MCTS_NN.MCTS_Node(game, model)
        chosen_move, search_value = search_root.best_action(
            game.get_game_data().ply, model, simVar
        )
        return chosen_move, search_value

In [7]:
# Shared search front-ends: MCTSAI wraps the plain MCTS search, mcts_nn wraps the
# model-aware MCTS_NN search that self_play_game uses for bot turns.
MCTSAI = MCTS_AI()
mcts_nn = NetworkMCTS()

In [8]:

def self_play_game(model, moveMax, player1bot=True, player2bot=True,
                   simulations=3, puzzle_range=(33, 35), verbose=True):
    """Play one game from a random puzzle position and collect per-turn data.

    Turns whose ``ply`` is even belong to player 1 (the "W" lists) and odd plies
    to player 2 (the "B" lists). For each turn, the position, the search value
    and the move actually played are appended to the same side's lists, so
    ``board_data_X[i]``, ``values_X[i]`` and ``moves_X[i]`` describe one turn.

    Args:
        model: network passed to ``mcts_nn.find_best_move`` on bot turns.
        moveMax: stop (scoring the game 0) once ``ply`` reaches this value.
        player1bot, player2bot: when False, that side's move is read via ``input()``.
        simulations: search budget passed to ``mcts_nn.find_best_move``.
        puzzle_range: inclusive (low, high) puzzle ids for ``random.randint``.
        verbose: print each move and the board after it.

    Returns:
        ``(board_data_W, board_data_B, moves_W, moves_B, values_W, values_B, result)``
        where ``result`` is +1 for movecode 2 (white win), -1 for movecode 5
        (black win), and 0 otherwise (movecode 3 = draw, the move cap, or any
        other end code).
    """
    board_data_W, board_data_B = [], []
    values_W, values_B = [], []
    moves_W, moves_B = [], []

    def _collected(result):
        return board_data_W, board_data_B, moves_W, moves_B, values_W, values_B, result

    game = QuantumChessGame()
    game.new_game({
        'initial_state_fen': get_puzzle_fen(random.randint(*puzzle_range)),
        'max_split_moves': [1, 1],
    })
    movecode = 0
    while not game.is_game_over():
        # NOTE(review): assumes get_game_data() returns a per-call snapshot,
        # since one is stored per turn as training input.
        gamedata = game.get_game_data()

        # Check the cap before searching, so no search or prompt is wasted and
        # no board is recorded without a matching move.
        if gamedata.ply >= moveMax:
            return _collected(0)

        # Decide the side once, from the ply before the move is applied, and use
        # it for board, value and move alike so the per-side lists stay aligned.
        is_player1 = gamedata.ply % 2 == 0
        side_is_bot = player1bot if is_player1 else player2bot
        if side_is_bot:
            best_move, value = mcts_nn.find_best_move(game, model, simulations)
        else:
            best_move = input("Enter your move: ")
            value = 0

        if verbose:
            print(f"player # {gamedata.ply}")
            print(f"move taken {best_move}")

        _, movecode = game.do_move(best_move)
        move = game.get_unformatted_last_move()
        if verbose:
            print(f"move {move}")

        if is_player1:
            board_data_W.append(gamedata)
            values_W.append(value)
            moves_W.append(move)
        else:
            board_data_B.append(gamedata)
            values_B.append(value)
            moves_B.append(move)

        if verbose:
            game.print_board_and_probabilities()

    # Unknown end codes (or a game that was over before any move) score 0
    # instead of falling through and returning None.
    return _collected({2: 1, 3: 0, 5: -1}.get(movecode, 0))



In [9]:

# Training loop

# Keys of a recorded move, in the order the policy head's logits are indexed
# (model(x)[1][0] ... model(x)[1][5]).
_POLICY_KEYS = ('piece', 'square1', 'square2', 'square3', 'type', 'variant')


def _class_target(label, logits):
    """Return a long class-index target for ``F.cross_entropy`` shaped to ``logits``.

    Unbatched ``(C,)`` logits take a 0-d target; ``(N, C)`` logits take one
    index per row.
    """
    index = int(round(label))
    if logits.dim() == 1:
        return torch.tensor(index, dtype=torch.long)
    return torch.full((logits.shape[0],), index, dtype=torch.long)


def _train_on_side(NNmodel, optimizer, boards, moves, values, side):
    """Take one optimizer step per recorded turn of one side; return the mean loss.

    ``boards[i]``, ``moves[i]`` and ``values[i]`` are expected to describe the
    same turn; iteration stops at the shortest of the three lists.
    """
    losses = []
    for board, move, value in zip(boards, moves, values):
        optimizer.zero_grad()

        x = torch.zeros(1, 12, 8, 8)
        x[0] = gameToTensor(board, side)
        prediction = NNmodel(x)
        value_pred, policy_logits = prediction[0], prediction[1]

        # NOTE(review): the value target is the search value recorded during
        # self-play, not the game result self_play_game returns — confirm intent.
        value_loss = F.mse_loss(value_pred, torch.full_like(value_pred, float(value)))

        # NOTE(review): logged moves contain labels like square3 == 64; each
        # label must be a valid class index for its head — confirm head sizes.
        policy_loss = sum(
            F.cross_entropy(policy_logits[k], _class_target(move[key], policy_logits[k]))
            for k, key in enumerate(_POLICY_KEYS)
        )

        total_loss = value_loss + policy_loss
        total_loss.backward()
        optimizer.step()
        losses.append(total_loss.item())
    return sum(losses) / len(losses) if losses else 0.0


def train_model(NNmodel, player1bot, player2bot, games_per_epoch, epochs,
                move_max=9, save_path='testExport.pt'):
    """Alternate self-play games and supervised updates of ``NNmodel``.

    For each of ``games_per_epoch`` games per epoch, one game is played with
    ``self_play_game`` (capped at ``move_max`` plies), then Adam takes one step
    per recorded turn of each side. The loss is the value-head MSE plus the summed
    cross-entropy of the six policy heads. After every epoch the model is
    exported with ``torch.jit.script`` to ``save_path`` (None skips saving).
    """
    optimizer = torch.optim.Adam(NNmodel.parameters(), lr=0.001, weight_decay=1e-4)
    for epoch in range(epochs):
        print(f"starting epoch {epoch + 1}")
        for game_idx in range(games_per_epoch):
            # Self-play only needs inference: eval mode, no autograd graph.
            NNmodel.eval()
            with torch.no_grad():
                (board_data_w, board_data_b, moves_W, moves_B,
                 values_W, values_B, result) = self_play_game(NNmodel, move_max, player1bot, player2bot)
            print(f"game {game_idx + 1} finished")

            NNmodel.train()
            # NOTE(review): assumes gameToTensor's second argument is the side index
            # (0 for the even-ply "W" side, as in the fresh-game encoding
            # gameToTensor(gameData, 0); 1 for "B") — confirm.
            loss_b = _train_on_side(NNmodel, optimizer, board_data_b, moves_B, values_B, 1)
            loss_w = _train_on_side(NNmodel, optimizer, board_data_w, moves_W, values_W, 0)
            print(f"game {game_idx + 1} mean loss: black {loss_b:.4f}, white {loss_w:.4f}")

        print(f"epoch {epoch + 1} finished")
        if save_path is not None:
            torch.jit.script(NNmodel).save(save_path)  # Export to TorchScript
    print("Training finished")
    
train_model(NNmodel, player1bot=True, player2bot=True, games_per_epoch=10, epochs=10)

starting epoch 1
1 player # in model
player # 1
move taken a8^a7b8
move {'piece': 14, 'square1': 56, 'square2': 48, 'square3': 57, 'type': 4, 'variant': 1, 'does_measurement': False, 'measurement_outcome': 1, 'promotion_piece': 0}
 +-------------------------------------------------+
8|   .    50:k   .     .     .     .     .     .   |
7|  50:k 100:P   .     .     .     .     .     .   |
6| 100:P   .   100:N   .     .     .     .     .   |
5|   .     .     .     .     .     .     .     .   |
4|   .     .     .     .     .     .     .     .   |
3|   .     .     .     .     .     .     .     .   |
2|   .     .     .     .     .     .     .   100:K |
1|   .     .     .     .     .     .     .     .   |
 +-------------------------------------------------+
     a     b     c     d     e     f     g     h
2 player # in model
player # 2
move taken c6b8.m1
move {'piece': 2, 'square1': 42, 'square2': 57, 'square3': 64, 'type': 2, 'variant': 3, 'does_measurement': True, 'measurement_outcome': 1, 

KeyboardInterrupt: 

In [None]:
# Smoke test: run the network once on puzzle 33 and display its raw output.
game = QuantumChessGame()
game.new_game({'initial_state_fen': get_puzzle_fen(33), 'max_split_moves': [0, 1]})

# Allocate the input here rather than overwriting the buffer from the earlier
# sanity-check cell, so this cell does not depend on that cell having run.
game_tensor = torch.zeros(1, 12, 8, 8)
game_tensor[0] = gameToTensor(game.get_game_data(), 0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = NNmodel(game_tensor)
output