In [214]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import chess
import chess.pgn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader


In [215]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
def board_to_tensor(board):
    """Encode a python-chess board as a (25, 8, 8) float32 plane stack plus a material score.

    Plane layout:
      0-5    white pawn, knight, bishop, rook, queen, king (1 where the piece stands)
      6-11   black pawn .. king
      12     side to move: all ones when board.turn is White, all zeros otherwise
      13-18  squares attacked by white pawns .. king (union over pieces of that type)
      19-24  squares attacked by black pawns .. king
    Row index is 7 - (square // 8), so row 0 holds the eighth rank; column index is the file.

    Returns:
        (planes, material_score): material_score is (white - black) material divided by 25,
        using pawn=1, knight=bishop=3, rook=5, queen=9, king=0.
    """
    planes = np.zeros((25, 8, 8), dtype=np.float32)
    worth = {
        chess.PAWN: 1, chess.KNIGHT: 3, chess.BISHOP: 3,
        chess.ROOK: 5, chess.QUEEN: 9, chess.KING: 0,
    }
    balance = 0

    # One pass over the pieces fills both the occupancy and the attack planes; setting an
    # attack cell to 1 is idempotent, so the union per (type, colour) is preserved.
    for square, piece in board.piece_map().items():
        type_offset = piece.piece_type - 1
        colour_offset = 0 if piece.color == chess.WHITE else 6
        plane = type_offset + colour_offset

        planes[plane, 7 - square // 8, square % 8] = 1
        for target in board.attacks(square):
            planes[13 + plane, 7 - target // 8, target % 8] = 1

        if piece.color == chess.WHITE:
            balance += worth[piece.piece_type]
        else:
            balance -= worth[piece.piece_type]

    planes[12] = np.full((8, 8), board.turn, dtype=np.float32)

    return planes, balance / 25.0


In [None]:
import chess

# Policy index space (4100 slots, matching the network's policy head):
#   0..4095     from_square * 64 + to_square  (the promotion piece is NOT encoded)
#   4096..4099  moves whose UCI string is one of the four standard castling moves
CASTLING_BASE_INDEX = 4096
CASTLING_MOVE_TO_INDEX = {
    'e1g1': CASTLING_BASE_INDEX,      # white king-side castling
    'e1c1': CASTLING_BASE_INDEX + 1,  # white queen-side castling
    'e8g8': CASTLING_BASE_INDEX + 2,  # black king-side castling
    'e8c8': CASTLING_BASE_INDEX + 3,  # black queen-side castling
}
INDEX_TO_CASTLING_MOVE = {v: k for k, v in CASTLING_MOVE_TO_INDEX.items()}


def move_to_index(move):
    """Map a move to its policy index.

    Any move whose UCI string is a castling key above (normally the castling move itself;
    a rook/queen travelling e1->g1 has the same UCI) gets the dedicated castling slot.
    Everything else uses from_square * 64 + to_square, so all promotions of the same
    pawn move share one index.
    """
    uci = move.uci()
    if uci in CASTLING_MOVE_TO_INDEX:
        return CASTLING_MOVE_TO_INDEX[uci]
    return move.from_square * 64 + move.to_square


def index_to_move(index, board=None):
    """Inverse of move_to_index.

    Args:
        index: policy index in [0, 4100).
        board: optional position the move is played from. The index carries no promotion
            piece, so without a board a pawn move onto the last rank comes back without a
            promotion (an illegal move). When `board` is given and the moving piece is a
            pawn landing on rank 1 or 8, the move is returned as a queen promotion — the
            only promotion expand_node keeps.

    Returns:
        chess.Move
    """
    if index in INDEX_TO_CASTLING_MOVE:
        return chess.Move.from_uci(INDEX_TO_CASTLING_MOVE[index])

    from_square, to_square = divmod(index, 64)
    promotion = None
    if board is not None:
        piece = board.piece_at(from_square)
        if (piece is not None and piece.piece_type == chess.PAWN
                and chess.square_rank(to_square) in (0, 7)):
            promotion = chess.QUEEN
    return chess.Move(from_square, to_square, promotion=promotion)


In [218]:
def legal_moves_mask(board):
    """Return a length-4100 float32 vector with 1.0 at move_to_index() of every legal move."""
    legal_indices = [move_to_index(candidate) for candidate in board.legal_moves]
    mask = np.zeros(4100, dtype=np.float32)
    mask[legal_indices] = 1
    return mask
def result_to_value(result_str):
    """Score a PGN result string from White's point of view.

    '1-0' -> 1.0, '0-1' -> -1.0; a draw ('1/2-1/2'), an unfinished game ('*') or any
    other string -> 0.0.
    """
    for token, score in (('1-0', 1.0), ('0-1', -1.0)):
        if result_str == token:
            return score
    return 0.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: rescale each channel by a learned weight in (0, 1).

    Input/output shape is (B, C, H, W) with C == `channels`. The gate is computed from the
    spatial mean of each channel passed through Linear -> ReLU -> Linear -> sigmoid.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Clamp so channel counts smaller than `reduction` don't get a zero-width bottleneck.
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        b, c, _, _ = x.size()
        # mean over the spatial dims also works on non-contiguous inputs, unlike view().
        squeezed = x.mean(dim=(2, 3))
        gate = torch.sigmoid(self.fc2(F.relu(self.fc1(squeezed)))).view(b, c, 1, 1)
        return x * gate

class ResidualBlock(nn.Module):
    """Pre-activation-free residual block for 8x8 boards.

    conv3x3 -> LayerNorm -> GELU -> conv3x3 -> LayerNorm -> SE gate -> dropout(0.1),
    then GELU(result + input). Input/output shape is (B, channels, 8, 8); the LayerNorms
    are sized for exactly 8x8 planes. Attributes are named bn1/bn2 although they are
    LayerNorms — renaming would change the state_dict keys saved checkpoints use.
    """

    def __init__(self, channels):
        super().__init__()
        plane_shape = [channels, 8, 8]
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.LayerNorm(plane_shape)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.LayerNorm(plane_shape)
        self.se = SEBlock(channels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        skip = x
        h = F.gelu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h = self.dropout(self.se(h))
        return F.gelu(h + skip)

class ChessNet(nn.Module):
    """Residual conv trunk + 2-layer transformer over the 64 squares, with policy/value heads.

    forward(x, material_score, turn):
        x: (B, 25, 8, 8) board planes (the layout board_to_tensor produces).
        material_score, turn: one number per sample (any shape that views to (B, 1)).
        Both are required: the heads consume 64 * embed_dim + 2 features.
    Returns:
        (log_policy, value): log_policy is (B, 4100) log-probabilities over the
        move_to_index() space; value is (B, 1) and unbounded (no tanh).
    """

    NUM_SQUARES = 64
    NUM_EXTRA_FEATURES = 2   # material_score, side-to-move flag
    POLICY_SIZE = 4100       # 64 * 64 from/to slots + 4 castling slots

    def __init__(self, num_res_blocks=30):
        super().__init__()
        self.conv_spatial = nn.Conv2d(25, 64, kernel_size=3, padding=1)
        # LayerNorm despite the "bn" name; name kept so checkpoint keys still load.
        self.bn_in = nn.LayerNorm([64, 8, 8])
        self.res_blocks = nn.ModuleList([ResidualBlock(64) for _ in range(num_res_blocks)])

        # Each square becomes one token whose features are the trunk's 64 channels.
        self.embed_dim = 64
        self.pos_embed = nn.Parameter(torch.zeros(1, self.NUM_SQUARES, self.embed_dim))
        nn.init.normal_(self.pos_embed, std=0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim, nhead=4, batch_first=True, dim_feedforward=128
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        head_in = self.NUM_SQUARES * self.embed_dim + self.NUM_EXTRA_FEATURES  # 4098
        self.policy_head = nn.Sequential(
            nn.Linear(head_in, 1024), nn.GELU(),
            nn.Linear(1024, self.POLICY_SIZE),
        )
        self.value_head = nn.Sequential(
            nn.Linear(head_in, 1024), nn.GELU(),
            nn.Linear(1024, 1),
        )

    def forward(self, x, material_score=None, turn=None):
        if material_score is None or turn is None:
            # Without the two scalar features the heads' input width is wrong and Linear
            # fails with an opaque shape error; fail early with a clear message instead.
            raise ValueError(
                "ChessNet.forward requires both material_score and turn: the policy/value "
                f"heads expect {self.NUM_SQUARES * self.embed_dim + self.NUM_EXTRA_FEATURES} "
                "input features"
            )
        x = F.gelu(self.bn_in(self.conv_spatial(x)))
        for block in self.res_blocks:
            x = block(x)
        # (B, C, 8, 8) -> (B, 64 squares, C) token sequence.
        x_seq = x.flatten(2).transpose(1, 2)
        x_seq = self.transformer(x_seq + self.pos_embed)
        x_flat = x_seq.flatten(1)
        extra_features = torch.cat(
            [material_score.reshape(-1, 1), turn.reshape(-1, 1)], dim=1
        )
        x_flat = torch.cat([x_flat, extra_features], dim=1)
        policy = F.log_softmax(self.policy_head(x_flat), dim=1)
        value = self.value_head(x_flat)
        return policy, value


In [220]:
# Fresh network on the selected device, optimised with Adam at a learning rate of 1e-4.
model = ChessNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [221]:
# Sanity check: is a CUDA device visible to PyTorch?
print(str(torch.cuda.is_available()))

True


In [None]:
import math
import copy
import random
class MCTSNode:
    """One position in the search tree.

    Per-move statistics are dicts keyed by UCI string and are filled by expand_node:
    N = visit count, W = accumulated value, P = prior probability.
    """

    def __init__(self, board, parent=None):
        self.board = board
        self.parent = parent
        # UCI string -> child MCTSNode, created lazily during selection.
        self.children = {}
        # Three distinct dicts: visits, total value, prior.
        self.N, self.W, self.P = {}, {}, {}
        # Set by expand_node once priors have been assigned.
        self.is_expanded = False

def _is_search_candidate(board, move):
    """True unless `move` is an under-promotion or leaves the opponent stalemated."""
    if move.promotion is not None and move.promotion != chess.QUEEN:
        return False
    board.push(move)
    try:
        return not board.is_stalemate()
    finally:
        board.pop()


def expand_node(node, model, device, dirichlet_noise=False):
    """Evaluate `node` with the network and initialise its per-move statistics.

    A game-over position is not expanded; a fixed score is returned instead:
    5 for '1-0', -5 for '0-1', 0 otherwise.

    Otherwise the network's move probabilities are read at move_to_index() for each
    candidate move. They are renormalised over the candidates and stored in node.P, and
    node.N / node.W are zeroed for the same moves. When `dirichlet_noise` is set, the
    priors are mixed with Dirichlet(0.1) noise at weight 0.25, one noise component per
    candidate.

    Candidates exclude under-promotions (they share a policy index with the queen
    promotion) and moves that stalemate the opponent. If that filter removes every legal
    move, all legal moves are kept so the node never ends up with no children. An
    expanded node with an empty P would make selection in run_mcts fail.

    Returns:
        The position's value. Terminal scores are from White's point of view. The network
        value is the raw value-head output, which self_play_game trains against
        White-perspective targets.
        NOTE(review): the +-5 terminal scores are on a different scale from those targets
        (|z| <= 1 for decisive games); confirm this shaping is intentional.
    """
    board = node.board
    if board.is_game_over():
        result = board.result()
        if result == '1-0':
            return 5
        if result == '0-1':
            return -5
        return 0

    state_np, material = board_to_tensor(board)
    state_tensor = torch.tensor(state_np, dtype=torch.float32).unsqueeze(0).to(device)
    material_score = torch.tensor([[material]], dtype=torch.float32).to(device)
    is_white = torch.tensor([[float(board.turn == chess.WHITE)]], dtype=torch.float32).to(device)

    with torch.no_grad():
        log_policy, value = model(state_tensor, material_score, is_white)
        policy = torch.exp(log_policy.squeeze(0)).cpu()

    legal_moves = list(board.legal_moves)
    candidates = [move for move in legal_moves if _is_search_candidate(board, move)]
    if not candidates:
        candidates = legal_moves

    total_p = 0.0
    for move in candidates:
        uci = move.uci()
        prob = policy[move_to_index(move)].item()
        node.P[uci] = prob
        node.W[uci] = 0
        node.N[uci] = 0
        total_p += prob
    for uci in node.P:
        node.P[uci] /= total_p + 1e-8

    if dirichlet_noise and node.P:
        alpha = 0.1
        epsilon = 0.25
        noise = np.random.dirichlet([alpha] * len(node.P))
        for uci, eta in zip(list(node.P), noise):
            node.P[uci] = (1 - epsilon) * node.P[uci] + epsilon * eta

    node.is_expanded = True
    return value.item()

def ucb_score(node, move_uci, c_puct=1):
    """PUCT-style score for `move_uci` at `node`.

    Exploitation term: W / (N + 1).
    Exploration term: c_puct * P * sqrt(total visits at node) / (1 + N).
    """
    parent_visits = sum(node.N.values())
    visits = node.N[move_uci]
    exploit = node.W[move_uci] / (visits + 1)
    explore = c_puct * node.P[move_uci] * math.sqrt(parent_visits) / (1 + visits)
    return exploit + explore

def _select_child_move(node):
    """Pick the move to descend along: an immediate checkmate if one exists, else the
    highest-UCB move (the first one wins ties)."""
    board = node.board
    best_move = None
    best_score = -float('inf')
    for move_uci in node.P:
        # push/pop instead of copying the whole board for every candidate.
        board.push(chess.Move.from_uci(move_uci))
        gives_mate = board.is_checkmate()
        board.pop()
        if gives_mate:
            return move_uci
        score = ucb_score(node, move_uci)
        if score > best_score:
            best_score = score
            best_move = move_uci
    return best_move


def run_mcts(root_board, model, device, num_simulations=20):
    """Run `num_simulations` MCTS simulations from `root_board`.

    Each simulation descends through expanded nodes (mate-in-one first, then UCB) and
    expands the first unexpanded node with expand_node. Dirichlet noise is applied only
    when that node is the root. The leaf value is then backed up along the path.

    expand_node returns values from White's point of view. Every edge is credited from
    the perspective of the side to move at its parent, so W/N is that player's expected
    score. This is the quantity ucb_score maximises.

    Returns:
        {uci: visit_fraction} over the root's moves. If the root was expanded but received
        no visits (e.g. num_simulations == 1), its normalised priors are returned instead.
        The result is {} when the root position is already game over.
    """
    root = MCTSNode(root_board)

    for _ in range(num_simulations):
        node = root
        path = []

        while node.is_expanded:
            best_move = _select_child_move(node)
            if best_move not in node.children:
                next_board = node.board.copy()
                next_board.push(chess.Move.from_uci(best_move))
                node.children[best_move] = MCTSNode(next_board, parent=node)
            path.append((node, best_move))
            node = node.children[best_move]

        leaf_value = expand_node(node, model, device, dirichlet_noise=(node is root))

        # Sign is derived from the untouched leaf value at every step. Negating a running
        # value in place would flip it again at each Black node and mis-credit
        # alternate White nodes.
        for parent_node, move_uci in reversed(path):
            if parent_node.board.turn == chess.WHITE:
                signed_value = leaf_value
            else:
                signed_value = -leaf_value
            parent_node.N[move_uci] += 1
            parent_node.W[move_uci] += signed_value

    total = sum(root.N.values())
    if total == 0:
        return dict(root.P)
    return {uci: count / total for uci, count in root.N.items()}
def select_move_from_policy(policy, temperature=1.0):
    """Choose a move from a {uci: visit_fraction} policy.

    temperature == 0: greedy, returning the highest-weight move (the first one on ties).
    Otherwise: sample with probability proportional to weight ** (1 / temperature).
    """
    candidates = list(policy.keys())
    weights = torch.tensor([policy[uci] for uci in candidates], dtype=torch.float32)

    if temperature == 0:
        return max(policy, key=policy.get)

    sharpened = (weights ** (1 / temperature)).numpy()
    sharpened /= sharpened.sum()
    return random.choices(candidates, weights=sharpened, k=1)[0]


In [223]:
import chess
import random

import chess
import random

def generate_mate_in_one_position():
    """Build a random K + (Q or R) vs K + (P or N) position with the black king on an edge.

    Construction:
      1. Black king on a random square of a random edge (a-file, h-file, rank 1 or rank 8).
      2. White king two files (a/h edge) or two ranks (rank-1/8 edge) in from the black
         king, shifted one square sideways, so the kings are never adjacent.
      3. One white queen or rook (50/50) on the first shuffled square where it does not
         attack the black king and the side to move is not stalemated.
      4. One black pawn or knight (50/50) on the first shuffled square where it does not
         attack the white king and the side to move is not stalemated. A pawn drawn for
         rank 1 or 8 is replaced by a knight, since a pawn there makes the board invalid.

    The side to move is left as chess.Board.empty() sets it.
    NOTE(review): despite the name, nothing verifies that a mate in one actually exists.

    Returns:
        (board, square): the board and the square of the black piece placed in step 4.

    Raises:
        Exception: if no square works for the black piece in step 4.
    """
    board = chess.Board.empty()

    edge = random.choice(['a', 'h', '1', '8'])

    # 1. Black king on the chosen edge.
    if edge in ['a', 'h']:
        rank = random.randint(0, 7)
        file = 0 if edge == 'a' else 7
    else:
        file = random.randint(0, 7)
        rank = 0 if edge == '1' else 7
    black_king_square = chess.square(file, rank)
    board.set_piece_at(black_king_square, chess.Piece(chess.KING, chess.BLACK))

    # 2. White king. The edge coordinate is 0 or 7, so +-2 always stays on the board.
    if edge in ['a', 'h']:
        white_file = file + (2 if edge == 'a' else -2)
        white_rank = rank + 1 if rank < 5 else rank - 1
    else:
        white_rank = rank + (2 if edge == '1' else -2)
        white_file = file + 1 if file < 5 else file - 1
    white_king_square = chess.square(white_file, white_rank)
    board.set_piece_at(white_king_square, chess.Piece(chess.KING, chess.WHITE))

    kings = (black_king_square, white_king_square)
    all_squares = list(chess.SQUARES)

    # 3. White heavy piece that does not give check.
    random.shuffle(all_squares)
    for square in all_squares:
        if square in kings:
            continue
        piece_type = chess.QUEEN if random.random() < 0.5 else chess.ROOK
        board.set_piece_at(square, chess.Piece(piece_type, chess.WHITE))
        if not (board.is_attacked_by(chess.WHITE, black_king_square) or board.is_stalemate()):
            break
        board.remove_piece_at(square)

    # 4. Black defender that does not attack the white king.
    random.shuffle(all_squares)
    for square in all_squares:
        if square in kings:
            continue
        piece_type = chess.PAWN if random.random() < 0.5 else chess.KNIGHT
        if piece_type == chess.PAWN and chess.square_rank(square) in (0, 7):
            piece_type = chess.KNIGHT
        board.set_piece_at(square, chess.Piece(piece_type, chess.BLACK))
        if not (board.is_attacked_by(chess.BLACK, white_king_square) or board.is_stalemate()):
            return board, square
        board.remove_piece_at(square)
    raise Exception("Failed to find non-checking queen position")




def break_mate_position():
    """Return just the board from generate_mate_in_one_position() (the square is dropped)."""
    board, _ = generate_mate_in_one_position()
    return board


In [None]:
import chess
import random

def generate_ladder_mate():
    """Random small endgame: both kings plus two pawns and one minor piece per side.

    All squares are shuffled once.
    - The black king goes on the first square.
    - The white king goes on the first later square that keeps board.is_valid() true and
      is not adjacent to the black king.
    - For Black, then White: two pawns, then a knight or bishop (random.choice per side).
      Each piece goes on the first free square, in the same shuffled order, that keeps the
      board valid; a piece is silently skipped if no square works.

    NOTE(review): despite the name, no ladder-mate pattern is constructed.
    """
    board = chess.Board.empty()
    squares = list(chess.SQUARES)
    random.shuffle(squares)

    sq_bk = squares[0]
    board.set_piece_at(sq_bk, chess.Piece(chess.KING, chess.BLACK))

    def place_first_valid(piece, occupied):
        # Drop `piece` on the first free square that leaves the board valid.
        for sq in squares:
            if sq in occupied:
                continue
            board.set_piece_at(sq, piece)
            if board.is_valid():
                occupied.add(sq)
                return
            board.remove_piece_at(sq)

    for sq_wk in squares:
        if sq_wk == sq_bk:
            continue
        board.set_piece_at(sq_wk, chess.Piece(chess.KING, chess.WHITE))
        if not board.is_valid() or chess.square_distance(sq_bk, sq_wk) == 1:
            board.remove_piece_at(sq_wk)
            continue

        occupied = {sq_bk, sq_wk}
        for color in (chess.BLACK, chess.WHITE):
            for _ in range(2):
                place_first_valid(chess.Piece(chess.PAWN, color), occupied)
            minor = random.choice((chess.KNIGHT, chess.BISHOP))
            place_first_valid(chess.Piece(minor, color), occupied)
        return board

    return board


In [473]:
def _encode_position(board, device):
    """Return (state, is_white, material) tensors for `board` in the layout train_model expects."""
    state_np, material = board_to_tensor(board)
    state_tensor = torch.tensor(state_np, dtype=torch.float32).to(device)
    material_score = torch.tensor([[material]], dtype=torch.float32).to(device)
    is_white = torch.tensor([[float(board.turn == chess.WHITE)]], dtype=torch.float32).to(device)
    return state_tensor, is_white, material_score


def _policy_to_vector(policy):
    """Scatter a {uci: prob} MCTS policy into a dense length-4100 target vector."""
    move_probs = torch.zeros(4100)
    for uci, prob in policy.items():
        move_probs[move_to_index(chess.Move.from_uci(uci))] = prob
    return move_probs


def _value_targets(game_data, outcome):
    """Per-position value targets (White's perspective), aligned with game_data order.

    Decisive game: outcome * 0.99 ** (number of recorded positions after this one).
    Otherwise (draw or move cap): this position's material score plus the 0.9**k-discounted
    material changes over the next (up to) five recorded positions.
    """
    n = len(game_data)
    targets = []
    for i, (_, _, _, material_score) in enumerate(game_data):
        if outcome == 0:
            decay = 0.9
            current = material_score.item()
            previous = current
            deltas = []
            for k in range(1, 6):
                if i + k >= n:
                    break
                future = game_data[i + k][3].item()
                deltas.append(decay ** k * (future - previous))
                previous = future
            targets.append(current + sum(deltas))
        else:
            targets.append((0.99 ** (n - 1 - i)) * outcome)
    return targets


def self_play_game(model, device, num_simulations=100):
    """Play one self-play game from a generated start position and build training samples.

    Start position: with probability 0.8 generate_ladder_mate() (cap at fullmove 50),
    otherwise break_mate_position() with a random side to move (cap at fullmove 10).
    For every position before a move, MCTS is run once. The position is recorded together
    with *its own* search policy, then a move is sampled at temperature 2.5.

    Returns:
        (training_data, result): training_data is a list of
        (state, move_probs, value_target, is_white, material_score) tuples in game order,
        and result is board.result() ('*' when the move cap stopped the game).
    """
    board = break_mate_position()
    board.turn = chess.WHITE
    max_fullmove = 10
    if random.random() < 0.5:
        board.turn = chess.BLACK
    if random.random() < 0.8:
        board = generate_ladder_mate()
        max_fullmove = 50

    game_data = []
    while not board.is_game_over():
        if board.fullmove_number >= max_fullmove:
            break
        policy = run_mcts(board, model, device, num_simulations=num_simulations)
        # Encode *before* pushing so state, side to move and policy all describe the same position.
        state_tensor, is_white, material_score = _encode_position(board, device)
        game_data.append((state_tensor, _policy_to_vector(policy), is_white, material_score))
        move_uci = select_move_from_policy(policy, temperature=2.5)
        board.push(chess.Move.from_uci(move_uci))

    result = board.result()
    if result == '1-0':
        outcome = 1
    elif result == '0-1':
        outcome = -1
    else:
        outcome = 0

    targets = _value_targets(game_data, outcome)
    training_data = [
        (state_tensor, move_probs, torch.tensor(z, dtype=torch.float32), is_white, material_score)
        for (state_tensor, move_probs, is_white, material_score), z in zip(game_data, targets)
    ]

    print(board)
    print("Game finished with result:", result)

    return training_data, result

In [None]:
from torch.utils.data import TensorDataset, DataLoader

def train_model(model, optimizer, training_data, device, batch_size=128):
    """Run one shuffled epoch of policy + value updates over self-play samples.

    Args:
        model: called as model(states, material, turn) -> (log_policy, value).
        optimizer: optimizer over model's parameters.
        training_data: list of (state, policy_target, value_target, is_white,
            material_score) tuples as produced by self_play_game. state and
            policy_target are stacked as-is. value_target, is_white and material_score
            must each hold one number (a 1-element tensor of any shape, or a float).
        device: device the batches are moved to.
        batch_size: DataLoader minibatch size.

    Returns:
        Mean (policy + value) loss over the epoch's batches as a float, or None when
        training_data is empty (nothing is trained).
    """
    if not training_data:
        return None

    model.train()

    states, policies, values, is_white_list, material_scores = zip(*training_data)

    def as_column(scalars):
        # One float per sample -> (N, 1). Avoids torch.tensor() over a list of device
        # tensors, which syncs per element and keeps their extra singleton dims.
        return torch.stack(
            [torch.as_tensor(s, dtype=torch.float32, device=device).reshape(()) for s in scalars]
        ).unsqueeze(1)

    states = torch.stack(states).to(device)
    policies = torch.stack(policies).to(device)
    values = as_column(values).view(-1)
    is_white_tensor = as_column(is_white_list)
    material_tensor = as_column(material_scores)

    dataset = TensorDataset(states, policies, values, is_white_tensor, material_tensor)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    total_loss = 0.0
    num_batches = 0
    for state_batch, policy_batch, value_batch, turn_batch, material_batch in dataloader:
        log_policy_pred, value_pred = model(state_batch, material_batch, turn_batch)

        # Cross-entropy against the MCTS visit distribution, averaged over the batch.
        policy_loss = -torch.sum(policy_batch * log_policy_pred) / state_batch.size(0)
        value_loss = F.mse_loss(value_pred.view(-1), value_batch)
        loss = policy_loss + value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches


In [None]:

NUM_ITERATIONS = 100
GAMES_PER_ITERATION = 20
SIMULATIONS_PER_MOVE = 50

for i in range(NUM_ITERATIONS):
    # train_model() leaves the model in train mode. Switch back so MCTS evaluations during
    # self-play are not perturbed by dropout in the residual blocks and transformer.
    model.eval()
    all_training_data = []
    wins = 0  # decisive games (either colour) this iteration
    for _ in range(GAMES_PER_ITERATION):
        game_data, result = self_play_game(model, device, num_simulations=SIMULATIONS_PER_MOVE)
        all_training_data.extend(game_data)
        if result in ('1-0', '0-1'):
            wins += 1
    print(wins)
    if all_training_data:
        train_model(model, optimizer, all_training_data, device)
    print(f'trained!{i}')
    torch.save(model.state_dict(), "model.pth")


In [None]:
def _print_value_estimate(model, board, device):
    """Print the raw value-head output for `board`, followed by a blank separator line."""
    state_np, material = board_to_tensor(board)
    state_tensor = torch.tensor(state_np, dtype=torch.float32).unsqueeze(0).to(device)
    material_score = torch.tensor([[material]], dtype=torch.float32).to(device)
    is_white = torch.tensor([[float(board.turn == chess.WHITE)]], dtype=torch.float32).to(device)
    with torch.no_grad():
        _, value = model(state_tensor, material_score, is_white)
    print(value.item())
    print("\n")


def self_play_game_test(model, device, num_simulations=200):
    """Watch the model play one greedy (temperature 0) MCTS game and return the result.

    Start position: with probability 0.5 generate_ladder_mate(), otherwise
    break_mate_position() with White to move. Before every move, and once at the end, the
    board and the value head's estimate are printed.

    Side effect: puts `model` in eval mode.

    Returns:
        board.result() of the finished game.
    """
    # Value estimates and MCTS priors should not be perturbed by dropout.
    model.eval()

    board = break_mate_position()
    board.turn = chess.WHITE
    if random.random() < 0.5:
        board = generate_ladder_mate()

    while not board.is_game_over():
        print(board)
        print("\n")
        _print_value_estimate(model, board, device)
        policy = run_mcts(board, model, device, num_simulations=num_simulations)
        # Greedy: always play the most-visited move.
        move_uci = select_move_from_policy(policy, temperature=0)
        board.push(chess.Move.from_uci(move_uci))

    result = board.result()
    print(board)
    _print_value_estimate(model, board, device)
    print("Game finished with result:", result)
    return result


self_play_game_test(model, device)

In [None]:
import os

# Snapshot the current weights into the working directory, then show where that is.
torch.save(model.state_dict(), "model1.pth")
print(os.getcwd())


In [None]:
# Reload trained weights. The path comes from CHESS_MODEL_PATH and defaults to the file
# the training loop writes ("model.pth" in the working directory). This avoids a
# machine-specific absolute path containing a user name.
MODEL_PATH = os.environ.get("CHESS_MODEL_PATH", "model.pth")
model = ChessNet()
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.to(device)
# Rebind the optimizer to the new model's parameters. Otherwise a later training cell
# would keep stepping the previous model's tensors while self-play uses this one.
optimizer = optim.Adam(model.parameters(), lr=0.0001)
model.eval()


In [None]:
# List each tensor name and shape in the checkpoint. weights_only=True refuses to
# unpickle arbitrary objects; a state_dict needs nothing beyond tensors.
MODEL_PATH = os.environ.get("CHESS_MODEL_PATH", "model.pth")
state = torch.load(MODEL_PATH, map_location=device, weights_only=True)
for k, v in state.items():
    print(k, v.shape)
