In [1]:
!pip install -r requirements.txt

Collecting chess==1.11.2 (from -r requirements.txt (line 1))
  Downloading chess-1.11.2.tar.gz (6.1 MB)
     ---------------------------------------- 0.0/6.1 MB ? eta -:--:--
     ------------------------------------- -- 5.8/6.1 MB 32.2 MB/s eta 0:00:01
     ---------------------------------------- 6.1/6.1 MB 28.9 MB/s eta 0:00:00
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting exceptiongroup==1.3.1 (from -r requirements.txt (line 3))
  Downloading exceptiongroup-1.3.1-py3-none-any.whl.metadata (6.7 kB)
Collecting filelock==3.20.3 (from -r requirements.txt (line 4))
  Downloading filelock-3.20.3-py3-none-any.whl.metadata (2.1 kB)
Collecting fsspec==2026.1.0 (from -r requirements.txt (line 5))
  Downloading fsspec-2026.1.0-py3-none-any.whl.metadata (10 kB)
Collecting iniconfig==2.3.0 (from -r requirements.txt (line 6))
  Downloading iniconfig-2.3.0-py3-none-any.whl.metadata (2.5 kB)
Collecting MarkupSafe==3.0.3 (from -r re

  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-intel 2.17.1 requires keras>=3.2.0, but you have keras 2.10.0 which is incompatible.
tensorflow-intel 2.17.1 requires numpy<2.0.0,>=1.23.5; python_version <= "3.11", but you have numpy 2.2.6 which is incompatible.
tensorflow-intel 2.17.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 3.19.6 which is incompatible.
tensorflow-intel 2.17.1 requires tensorboard<2.18,>=2.17, but you have tensorboard 2.10.1 which is incompatible.

[notice] A new release of pip is available: 25.0.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


# AlphaZero-style chess: environment, move encoding, MCTS, network, and self-play training

In [1]:

import chess
import numpy as np

class ChessEnv:
    """Game-state wrapper around ``chess.Board`` used by the search and training code.

    Exposes move generation, move application, terminal detection, the game
    result, and a fixed-size tensor encoding of the position.
    """

    def __init__(self):
        # A new environment always starts from the standard initial position.
        self.board = chess.Board()

    def copy(self):
        """Return an independent environment holding a copy of the board.

        The move stack is copied as well (``stack=True``), which the repetition
        planes in ``encode`` rely on.
        """
        env = ChessEnv()
        env.board = self.board.copy(stack=True)
        return env

    def legal_moves(self):
        """Return the legal moves of the side to move as a list."""
        return list(self.board.legal_moves)

    def push(self, move):
        """Play ``move`` on the board."""
        self.board.push(move)

    def is_terminal(self):
        """Return True when the game is over."""
        return self.board.is_game_over()

    def result(self):
        """Return the game result from White's point of view.

        Returns:
            None while the game is still running, 0 for a draw, 1 if White won
            and -1 if Black won.
        """
        if not self.board.is_game_over():
            return None

        winner = self.board.outcome().winner
        if winner is None:
            return 0

        # python-chess encodes colours as booleans: True is White, False is Black.
        return 1 if winner else -1

    def encode(self):
        """Encode the position as AlphaZero-style feature planes.

        Piece and castling planes are relative to the side to move, but the
        board itself is not flipped: row 0 is always rank 8 and column 0 is
        always the a-file.

        Plane layout:
            0-5    pieces of the side to move (piece_type - 1)
            6-11   pieces of the opponent (piece_type + 5)
            12     position has occurred at least twice (repeated once)
            13     position has occurred at least three times (repeated twice)
            14     all ones when White is to move
            15     full-move number / 100, capped at 1
            16-17  side to move: kingside / queenside castling rights
            18-19  opponent: kingside / queenside castling rights
            20     half-move clock / 100, capped at 1

        Returns:
            np.ndarray of shape (21, 8, 8) and dtype float32.
        """
        board = self.board
        planes = np.zeros((21, 8, 8), dtype=np.float32)

        # Piece planes: the piece-type enum (1..6) selects the plane directly.
        for square, piece in board.piece_map().items():
            row = 7 - chess.square_rank(square)
            col = chess.square_file(square)
            offset = -1 if piece.color == board.turn else 5
            planes[piece.piece_type + offset, row, col] = 1.0

        # Repetition planes. ``is_repetition(count)`` asks whether the current
        # position has occurred ``count`` times *including now*, so count=1 is
        # always true; "repeated once" therefore means count=2.
        if board.is_repetition(2):
            planes[12, :, :] = 1.0
        if board.is_repetition(3):
            planes[13, :, :] = 1.0

        # Colour plane: all ones when White is to move.
        if board.turn:
            planes[14, :, :] = 1.0

        # Move counters, scaled into [0, 1].
        planes[15, :, :] = min(1, board.fullmove_number / 100)
        planes[20, :, :] = min(1, board.halfmove_clock / 100)

        # Castling rights: side to move first, then the opponent.
        opponent = not board.turn
        if board.has_kingside_castling_rights(board.turn):
            planes[16, :, :] = 1.0
        if board.has_queenside_castling_rights(board.turn):
            planes[17, :, :] = 1.0
        if board.has_kingside_castling_rights(opponent):
            planes[18, :, :] = 1.0
        if board.has_queenside_castling_rights(opponent):
            planes[19, :, :] = 1.0

        return planes












In [2]:
import chess
import numpy as np
import config

# Lookup tables for the 73-plane move encoding used by move_to_index / index_to_move.
# The order of the entries is part of the encoding: reordering them changes action indices.

# Underpromotion pieces (queen promotions are encoded on the ordinary direction planes).
PROMOTION_PIECES = [chess.ROOK, chess.KNIGHT, chess.BISHOP]
# Unit (file, rank) steps for straight-line moves: four orthogonal, then four diagonal.
DIRS = [
    (1,0), (-1,0), (0,1), (0,-1),
    (1,1), (1,-1), (-1,1), (-1,-1)
]

# (file, rank) offsets of the eight knight jumps.
KNIGHT_DIRS = [
    (2,1), (2,-1), (-2,1), (-2,-1),
    (1,2), (1,-2), (-1,2), (-1,-2)
]

def sign(x):
    """Return 1 for positive ``x``, -1 for negative ``x`` and 0 otherwise."""
    return 1 if x > 0 else (-1 if x < 0 else 0)



def move_to_index(move: chess.Move) -> int:
    """Map a chess move to its action index ``from_square * 73 + plane``.

    Plane layout (73 planes per origin square):
        0-55   straight-line moves: ``DIRS`` direction * 7 + (distance - 1)
        56-63  knight jumps, in ``KNIGHT_DIRS`` order
        64-72  underpromotions: (file delta + 1) * 3 + index in ``PROMOTION_PIECES``

    Queen promotions are encoded as ordinary straight-line moves.

    Raises:
        ValueError: if the displacement is neither a knight jump nor a
            straight line (raised by ``DIRS.index``).
    """
    origin = move.from_square
    target = move.to_square

    # Displacement in (file, rank) coordinates.
    dx = chess.square_file(target) - chess.square_file(origin)
    dy = chess.square_rank(target) - chess.square_rank(origin)
    delta = (dx, dy)
    base = origin * 73

    if move.promotion and move.promotion != chess.QUEEN:
        # Only the file delta matters; the rank step follows from the pawn's colour.
        return base + 64 + (dx + 1) * 3 + PROMOTION_PIECES.index(move.promotion)

    if delta in KNIGHT_DIRS:
        return base + 56 + KNIGHT_DIRS.index(delta)

    # Reduce the displacement to a unit step to find its direction.
    unit_step = (
        0 if dx == 0 else dx // abs(dx),
        0 if dy == 0 else dy // abs(dy),
    )
    direction = DIRS.index(unit_step)
    distance = max(abs(dx), abs(dy))
    return base + direction * 7 + (distance - 1)

def index_to_move(index: int, board: chess.Board) -> chess.Move:
    """Decode an action index produced by ``move_to_index`` for ``board``.

    The board is needed to find the moving piece: its colour fixes the rank
    direction of underpromotions, and a pawn reaching rank 1 or 8 on a
    straight-line plane becomes a queen promotion.

    Returns:
        The decoded ``chess.Move``, or None when the origin square is empty,
        the target is off the board, or an underpromotion plane is used for a
        non-pawn. Legality of the returned move is not checked here.
    """
    from_sq, plane = divmod(index, 73)

    piece = board.piece_at(from_sq)
    if piece is None:
        return None

    fx = chess.square_file(from_sq)
    fy = chess.square_rank(from_sq)

    def on_board(file_idx, rank_idx):
        return 0 <= file_idx <= 7 and 0 <= rank_idx <= 7

    # Planes 64-72: underpromotions.
    if plane >= 64:
        direction, promo_idx = divmod(plane - 64, 3)
        tx = fx + (direction - 1)
        ty = fy + (1 if piece.color == chess.WHITE else -1)
        if not on_board(tx, ty) or piece.piece_type != chess.PAWN:
            return None
        return chess.Move(from_sq, chess.square(tx, ty), promotion=PROMOTION_PIECES[promo_idx])

    # Planes 56-63: knight jumps.
    if plane >= 56:
        jump_dx, jump_dy = KNIGHT_DIRS[plane - 56]
        tx, ty = fx + jump_dx, fy + jump_dy
        if not on_board(tx, ty):
            return None
        return chess.Move(from_sq, chess.square(tx, ty))

    # Planes 0-55: straight-line moves, direction-major then distance.
    direction, dist_minus_one = divmod(plane, 7)
    step_dx, step_dy = DIRS[direction]
    distance = dist_minus_one + 1
    tx = fx + step_dx * distance
    ty = fy + step_dy * distance
    if not on_board(tx, ty):
        return None

    # A pawn arriving on the first or last rank via these planes promotes to a queen.
    promotes = piece.piece_type == chess.PAWN and ty in (0, 7)
    return chess.Move(from_sq, chess.square(tx, ty), promotion=chess.QUEEN if promotes else None)




In [3]:


import config
import torch

class Node:
    """Search-tree node shared by the rollout-based and the network-guided MCTS.

    ``W`` accumulates values from the point of view of the player to move at
    this node (that is what ``simulate`` and ``expand`` return and what
    ``backpropagate`` stores, flipping the sign at every level). A parent must
    therefore invert a child's mean value before ranking its children.
    """

    def __init__(self, env: "ChessEnv", parent: "Node" = None, parent_move=None, prior=0):
        self.env = env
        self.parent = parent
        self.parent_move = parent_move

        # move -> child Node
        self.children = {}
        # Moves not yet expanded by the rollout search, in random order.
        self.untried_moves = env.legal_moves()
        np.random.shuffle(self.untried_moves)

        self.N = 0          # visit count
        self.W = 0.0        # summed value, from this node's side-to-move perspective
        self.prior = prior  # policy prior assigned by the parent

    def is_fully_expanded(self):
        """True when every legal move has a child (rollout search)."""
        return len(self.untried_moves) == 0 and len(self.children) > 0

    def is_expanded(self):
        """True when the node has at least one child (network-guided search)."""
        return len(self.children) > 0

    def get_ucb(self, child):
        """UCB1 score of ``child`` as seen by the player to move at this node.

        ``child.W / child.N`` is the child's mean value from the opponent's
        point of view, so it is inverted and rescaled from [-1, 1] to [0, 1],
        exactly as ``get_puct`` does. Unvisited children score infinity.
        """
        if child.N == 0:
            return float('inf')
        q_value = 1 - ((child.W / child.N) + 1) / 2
        return q_value + config.UCB_C * np.sqrt(np.log(self.N) / child.N)

    def get_puct(self, child):
        """PUCT score of ``child`` as seen by the player to move at this node."""
        q_value = 0 if child.N == 0 else 1 - ((child.W / child.N) + 1) / 2
        u_value = config.PUCT_C * child.prior * (np.sqrt(self.N) / (1 + child.N))
        return q_value + u_value

    def expand_random(self):
        """Create and return the child for the next untried move."""
        action = self.untried_moves.pop()

        child_state = self.env.copy()
        child_state.push(action)
        child = Node(child_state, parent=self, parent_move=action)
        self.children[action] = child
        return child

    def expand(self, model, device):
        """Create a child for every legal move using the network's policy.

        The policy is restricted to the legal moves and renormalised to give
        the children's priors.

        Returns:
            The network's value estimate for this position.
        """
        planes = self.env.encode()
        x = torch.tensor(planes, dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            policy, value = model(x)
        policy = policy.squeeze(0).cpu().numpy()
        value = value.item()

        legal_moves = self.env.legal_moves()
        priors = np.array([policy[move_to_index(move)] for move in legal_moves])
        # Floor the priors so the normalisation never divides by zero.
        priors = np.maximum(priors, 1e-10)
        priors /= np.sum(priors)

        for move, prior in zip(legal_moves, priors):
            next_env = self.env.copy()
            next_env.push(move)
            self.children[move] = Node(next_env, parent=self, parent_move=move, prior=prior)
        return value

    def select_random(self):
        """Return the child with the highest UCB1 score (rollout search)."""
        best_child = None
        best_ucb = -np.inf

        for child in self.children.values():
            ucb = self.get_ucb(child)
            if ucb > best_ucb:
                best_child = child
                best_ucb = ucb
        return best_child

    def select(self):
        """Return the child with the highest PUCT score (network-guided search)."""
        best_child = None
        best_puct = -np.inf

        for child in self.children.values():
            puct = self.get_puct(child)
            if puct > best_puct:
                best_child = child
                best_puct = puct
        return best_child

    def simulate(self):
        """Play uniformly random moves to the end of the game.

        Returns:
            1, 0 or -1: win, draw or loss for the player to move at this node.
        """
        rollout_state = self.env.copy()
        # ``result()`` is White-relative; +1 stands for White, -1 for Black.
        initial_player = 1 if self.env.board.turn else -1

        winner = self.env.result()
        while winner is None:
            valid_moves = rollout_state.legal_moves()
            action = np.random.choice(valid_moves)
            rollout_state.push(action)
            winner = rollout_state.result()

        if winner == initial_player:
            return 1
        if winner == 0:
            return 0
        return -1

    def backpropagate(self, value):
        """Add ``value`` to this node and, sign-flipped per level, to its ancestors."""
        node = self
        while node is not None:
            node.W += value
            node.N += 1
            # The parent is the other player's node, so the value changes sign.
            value = -value
            node = node.parent


class MCTS:
    """Plain Monte Carlo tree search using UCB1 selection and random rollouts."""

    def __init__(self, model=None, device='cpu'):
        # Stored for interface parity with AlphaMCTS; the rollout search does not use them.
        self.model = model
        self.device = device

    def search(self, state):
        """Run ``config.NUM_SEARCHES`` simulations from ``state``.

        Returns:
            dict mapping each root move to its share of the root's child
            visits (all zeros if no child was visited, empty if the root has
            no children).
        """
        root = Node(state)

        for _ in range(config.NUM_SEARCHES):
            node = root

            # Selection: descend while every move of the node already has a child.
            while node.is_fully_expanded():
                node = node.select_random()

            outcome = node.env.result()
            if outcome is None:
                # Expansion + rollout from the new child.
                node = node.expand_random()
                value = node.simulate()
            else:
                # ``result()`` is White-relative, but nodes store values from the
                # point of view of their side to move, so flip it for Black.
                value = outcome if node.env.board.turn else -outcome
            node.backpropagate(value)

        total = sum(child.N for child in root.children.values())
        return {
            move: 0 if total == 0 else child.N / total
            for move, child in root.children.items()
        }


class AlphaMCTS:
    """Network-guided Monte Carlo tree search (PUCT selection, no rollouts)."""

    def __init__(self, model, device='cpu'):
        self.model = model
        # Device the encoded positions are moved to before calling the model.
        self.device = device

    @torch.no_grad()
    def search(self, state: "ChessEnv"):
        """Run ``config.NUM_SEARCHES`` evaluations from ``state``.

        The root is expanded first and its priors are mixed with Dirichlet
        noise; that expansion counts as the first evaluation.

        Returns:
            dict mapping each root move to its share of the root's child visits.
        """
        root = Node(state)

        # Expand once to obtain the priors of the root moves.
        value = root.expand(self.model, self.device)

        # Dirichlet exploration noise on the root priors (nothing to perturb
        # when the root has no children).
        legal_moves = list(root.children.keys())
        if legal_moves:
            noise = np.random.dirichlet([config.DIRICHLET_ALPHA] * len(legal_moves))
            for move, n in zip(legal_moves, noise):
                child = root.children[move]
                child.prior = (1 - config.DIRICHLET_EPSILON) * child.prior + config.DIRICHLET_EPSILON * n

        root.backpropagate(value)

        for _ in range(config.NUM_SEARCHES - 1):
            node = root

            # Selection: follow the best PUCT child down to a leaf.
            while node.is_expanded() and not node.env.is_terminal():
                node = node.select()

            outcome = node.env.result()
            if outcome is None:
                value = node.expand(self.model, self.device)
            else:
                # ``result()`` is White-relative, but nodes store values from the
                # point of view of their side to move, so flip it for Black.
                value = outcome if node.env.board.turn else -outcome
            node.backpropagate(value)

        total_visits = sum(child.N for child in root.children.values())
        return {
            move: 0 if total_visits == 0 else child.N / total_visits
            for move, child in root.children.items()
        }
















In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBlock(nn.Module):
    """Convolution -> batch norm -> ReLU stem block.

    Args:
        in_channels: number of input planes.
        out_channels: number of output channels; falls back to
            ``config.NUM_CHANNELS`` when omitted.
        kernel_size: convolution kernel size.
        padding: convolution padding.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, padding=1):
        super().__init__()
        # Resolve the default before building the layers (the fallback used to be
        # assigned to a misspelled name, so None reached nn.Conv2d).
        out_channels = out_channels or config.NUM_CHANNELS
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return x


class ResBlock(nn.Module):
    """Residual block: two bias-free 3x3 conv + batch-norm stages with a skip connection.

    ``channels`` falls back to ``config.NUM_CHANNELS`` when omitted.
    """

    def __init__(self, channels=None):
        super().__init__()
        width = channels or config.NUM_CHANNELS
        self.conv1 = nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection, then the final non-linearity.
        return F.relu(hidden + x)

class PolicyHead(nn.Module):
    """Policy head: 1x1 conv -> batch norm -> ReLU -> linear -> softmax over actions.

    ``in_channels`` and ``action_size`` fall back to ``config.NUM_CHANNELS`` and
    ``config.ACTION_SIZE``. The flattening assumes an 8x8 board.
    """

    def __init__(self, in_channels=None, action_size=None):
        super().__init__()
        in_channels = in_channels or config.NUM_CHANNELS
        action_size = action_size or config.ACTION_SIZE
        head_channels = config.POLICY_HEAD_CHANNELS

        self.conv = nn.Conv2d(in_channels, head_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(head_channels)
        self.fc = nn.Linear(head_channels * 8 * 8, action_size)

    def forward(self, x, legal_mask=None):
        """Return action probabilities; entries where ``legal_mask == 0`` get probability 0."""
        features = F.relu(self.bn(self.conv(x)))
        logits = self.fc(features.view(features.size(0), -1))
        if legal_mask is not None:
            # -inf logits vanish under the softmax.
            logits = logits.masked_fill(legal_mask == 0, float('-inf'))
        return F.softmax(logits, dim=1)

class ValueHead(nn.Module):
    """Value head: 1x1 conv -> batch norm -> ReLU -> two linear layers -> tanh.

    ``in_channels`` falls back to ``config.NUM_CHANNELS``. The flattening
    assumes an 8x8 board; the output has one value per sample in [-1, 1].
    """

    def __init__(self, in_channels=None):
        super().__init__()
        in_channels = in_channels or config.NUM_CHANNELS
        head_channels = config.VALUE_HEAD_CHANNELS

        self.conv = nn.Conv2d(in_channels, head_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(head_channels)
        self.fc1 = nn.Linear(head_channels * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        features = F.relu(self.bn(self.conv(x)))
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return torch.tanh(self.fc2(hidden))

class ChessModel(nn.Module):
    """Residual tower with a policy head and a value head.

    Args:
        in_channels: number of input planes.
        num_res_blocks: tower depth; falls back to ``config.NUM_RES_BLOCKS``.
        action_size: policy output size; falls back to ``config.ACTION_SIZE``.
    """

    def __init__(self, in_channels=21, num_res_blocks=None, action_size=None):
        super().__init__()
        depth = num_res_blocks or config.NUM_RES_BLOCKS
        n_actions = action_size or config.ACTION_SIZE
        width = config.NUM_CHANNELS

        self.conv_block = ConvBlock(in_channels, width)

        tower = []
        for _ in range(depth):
            tower.append(ResBlock(width))
        self.res_blocks = nn.ModuleList(tower)

        self.policy_head = PolicyHead(width, n_actions)
        self.value_head = ValueHead(width)

    def forward(self, x, legal_mask=None):
        """Return ``(policy, value)`` for a batch of encoded positions.

        ``legal_mask`` is forwarded unchanged to the policy head.
        """
        features = self.conv_block(x)
        for block in self.res_blocks:
            features = block(features)
        return self.policy_head(features, legal_mask), self.value_head(features)



In [5]:
class AlphaZero:
    """Self-play training loop: generate games with AlphaMCTS, then fit the network.

    Args:
        model: network returning ``(policy, value)`` for a batch of encoded positions.
        optimizer: optimizer over ``model``'s parameters.
    """

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        # Run the search on the device the model lives on; otherwise a model on
        # GPU would be fed CPU tensors by the search.
        self.device = next(model.parameters()).device
        self.mcts = AlphaMCTS(model, device=self.device)

    @staticmethod
    def pi_to_vector(pi_dict):
        """Scatter a ``{move: probability}`` dict into a dense policy target.

        Returns:
            float32 array of length ``config.ACTION_SIZE``.
        """
        target = np.zeros(config.ACTION_SIZE, dtype=np.float32)
        for move, prob in pi_dict.items():
            target[move_to_index(move)] = prob
        return target

    def selfPlay(self):
        """Play one game against itself.

        Returns:
            list of ``[encoded_state, policy_target, value_target]`` entries,
            one per position, where ``value_target`` is the final result from
            the point of view of the player to move in that position.
        """
        memory = []
        env = ChessEnv()
        move_count = 0

        while not env.is_terminal():
            pi = self.mcts.search(env)
            memory.append([env.encode(), self.pi_to_vector(pi)])

            # Temperature: sample from the visit distribution for the first
            # 15 moves, then play the most-visited move.
            tau = 1 if move_count < 15 else 0

            if tau == 0:
                action = max(pi, key=pi.get)
            else:
                moves = list(pi.keys())
                probs = np.array([pi[m] for m in moves])
                probs = probs / probs.sum()
                action = np.random.choice(moves, p=probs)

            env.push(action)
            move_count += 1

        # ``result()`` is White-relative and the first stored position is
        # White's, so the sign alternates from there.
        z = env.result()
        for entry in memory:
            entry.append(z)
            z = -z
        return memory

    def train(self, memory):
        """Run one epoch of mini-batch updates over ``memory``.

        ``memory`` is shuffled in place. The policy output is restricted to
        the moves with a non-zero target and renormalised before the
        cross-entropy is taken.

        Returns:
            dict with the mean "loss", "policy", "value" and "entropy" per
            batch (all 0.0 when ``memory`` is empty).
        """
        np.random.shuffle(memory)

        total_loss = 0.0
        total_policy = 0.0
        total_value = 0.0
        total_entropy = 0.0
        batches = 0

        for startInd in range(0, len(memory), config.BATCH_SIZE):
            sample = memory[startInd:startInd + config.BATCH_SIZE]
            states, policy_targets, value_targets = zip(*sample)

            states = torch.tensor(np.array(states), dtype=torch.float32).to(self.device)
            policy_targets = torch.tensor(np.array(policy_targets), dtype=torch.float32).to(self.device)
            value_targets = torch.tensor(np.array(value_targets), dtype=torch.float32).unsqueeze(1).to(self.device)

            out_policy, out_value = self.model(states)

            # Keep only the moves that received search visits, then renormalise.
            legal_mask = (policy_targets > 0).float()
            out_policy = out_policy * legal_mask
            out_policy = out_policy / (out_policy.sum(dim=1, keepdim=True) + 1e-8)

            policy_loss = -torch.mean(
                torch.sum(policy_targets * torch.log(out_policy + 1e-8), dim=1)
            )
            value_loss = F.mse_loss(out_value, value_targets)

            # Reported for monitoring only; not part of the optimised loss.
            entropy = -torch.mean(
                torch.sum(out_policy * torch.log(out_policy + 1e-8), dim=1)
            )

            loss = policy_loss + value_loss

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
            self.optimizer.step()

            total_loss += loss.item()
            total_policy += policy_loss.item()
            total_value += value_loss.item()
            total_entropy += entropy.item()
            batches += 1

        # Avoid dividing by zero when there was nothing to train on.
        denom = max(batches, 1)
        return {
            "loss": total_loss / denom,
            "policy": total_policy / denom,
            "value": total_value / denom,
            "entropy": total_entropy / denom,
        }

    def learn(self):
        """Alternate self-play and training for ``config.TRAIN_ITERATIONS`` iterations.

        After each iteration the model and optimizer state dicts are saved to
        ``model_{iteration}.pt`` and ``optimizer_{iteration}.pt``.
        """
        for iteration in range(config.TRAIN_ITERATIONS):
            memory = []

            # Self-play with the network in inference mode.
            self.model.eval()
            for _ in range(config.SELF_PLAY_ITERATIONS):
                memory += self.selfPlay()

            self.model.train()

            for epoch in range(config.NUM_EPOCHES):
                stats = self.train(memory)
                print(f'Iter {iteration} | Epoch {epoch}')
                print(f'Loss {stats["loss"]:.4f}')
                print(f'Policy {stats["policy"]:.4f}')
                print(f'Value {stats["value"]:.4f}')
                print(f'Entropy {stats["entropy"]:.4f}')

            torch.save(self.model.state_dict(), f"model_{iteration}.pt")
            torch.save(self.optimizer.state_dict(), f"optimizer_{iteration}.pt")


In [6]:
import logging

# Timestamped INFO-level logging on the default handler.
# To append to a file instead, also pass filename="train.log" and filemode="a".
logging.basicConfig(format="%(asctime)s | %(message)s", level=logging.INFO)


In [None]:
# Seed the global NumPy and PyTorch generators so weight initialisation and
# self-play sampling start from the same state on every "Restart & Run All".
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

model = ChessModel()
optimizer = torch.optim.Adam(model.parameters(), lr = config.LEARNING_RATE)

alpha_jerry = AlphaZero(model, optimizer = optimizer)

# Runs self-play and training; checkpoints are written after every iteration.
alpha_jerry.learn()

Iter 0 | Epoch 0
Loss 2.1161
Policy 1.9194
Value 0.1967
Entropy 1.8889
