In [None]:
!pip install chess

In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
import chess
import numpy as np
import os
import glob


class PieceWiseSelfAttention(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(PieceWiseSelfAttention, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.query = nn.Linear(embedding_dim, embedding_dim)
        self.key = nn.Linear(embedding_dim, embedding_dim)
        self.value = nn.Linear(embedding_dim, embedding_dim)
        self.attention_weights = None

    def forward(self, x, piece_type):
        batch_size, seq_length, _ = x.size()
        q = self.query(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        piece_mask = piece_type.unsqueeze(1).unsqueeze(2).expand_as(scores)
        scores = scores.masked_fill(piece_mask == 0, float('-inf'))
        attention_weights = torch.softmax(scores, dim=-1)
        self.attention_weights = attention_weights
        
        weighted_values = torch.matmul(attention_weights, v)
        output = weighted_values.transpose(1, 2).contiguous().view(batch_size, seq_length, self.embedding_dim)
        return output



class TransformerModel(nn.Module):
    def __init__(self, num_tokens, embedding_dim, num_heads, num_layers, hidden_dim, dropout):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(num_tokens, embedding_dim)
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embedding_dim, num_heads, hidden_dim, dropout)
            for _ in range(num_layers)
        ])
        self.piece_wise_attention = PieceWiseSelfAttention(embedding_dim, num_heads)
        self.value_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh()
        )
        self.policy_head = nn.Sequential(
            nn.Linear(embedding_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 64),
            nn.Softmax(dim=-1)
        )

    def forward(self, x, piece_type):
        x = self.embedding(x)
        for layer in self.transformer_layers:
            x = layer(x)
        x = self.piece_wise_attention(x, piece_type)
        value = self.value_head(x[:, -1])
        policy_input = torch.cat((x[:, -1], value.view(-1, 1)), dim=-1)
        policy = self.policy_head(policy_input)
        return value, policy

    def to(self, device):
        super().to(device)
        self.embedding = self.embedding.to(device)
        self.transformer_layers = self.transformer_layers.to(device)
        self.piece_wise_attention = self.piece_wise_attention.to(device)
        self.value_head = self.value_head.to(device)
        self.policy_head = self.policy_head.to(device)
        return self

def generate_piece_moves(board, piece_type, piece_index):
    square = chess.SQUARES[piece_index]
    piece = board.piece_at(square)
    
    if piece is None or piece.piece_type != piece_type:
        return []
    
    legal_moves = []
    
    if piece_type == chess.PAWN:
        for move in board.legal_moves:
            if move.from_square == square:
                legal_moves.append(move)
    
    elif piece_type == chess.KNIGHT:
        for move in board.legal_moves:
            if move.from_square == square:
                legal_moves.append(move)
    
    elif piece_type == chess.BISHOP:
        for move in board.legal_moves:
            if move.from_square == square:
                legal_moves.append(move)
    
    elif piece_type == chess.ROOK:
        for move in board.legal_moves:
            if move.from_square == square:
                legal_moves.append(move)
    
    elif piece_type == chess.QUEEN:
        for move in board.legal_moves:
            if move.from_square == square:
                legal_moves.append(move)
    
    elif piece_type == chess.KING:
        for move in board.legal_moves:
            if move.from_square == square:
                legal_moves.append(move)
    
    return legal_moves

def map_policy_to_move(board, policy_output, piece_type_tensor, attention_weights, top_k=5):
    top_k_pieces = torch.topk(attention_weights.squeeze(), top_k)
    top_k_indices = top_k_pieces.indices.flatten().tolist()
    
    candidate_moves = []
    for piece_index in top_k_indices:
        if piece_index < len(piece_type_tensor):
            piece_type = piece_type_tensor[piece_index].item()
            piece_moves = generate_piece_moves(board, piece_type, piece_index)
            candidate_moves.extend(piece_moves)
    
    move_mapping = {i: move for i, move in enumerate(candidate_moves)}
    
    move_probabilities = policy_output.squeeze().tolist()
    
    best_move_index = move_probabilities.index(max(move_probabilities))
    
    best_move = move_mapping.get(best_move_index)
    
    return best_move

def tokenize_board(board):
    tokenized_board = []
    piece_type_tensor = []

    for square in chess.SQUARES:
        piece = board.piece_at(square)

        if piece:
            piece_type = piece.piece_type
            piece_color = piece.color

            if piece_color == chess.WHITE:
                token = piece_type
            else:
                token = piece_type + 6

            piece_type_tensor.append(piece_type)
        else:
            token = 0
            piece_type_tensor.append(0)

        tokenized_board.append(token)

    return tokenized_board, torch.tensor(piece_type_tensor)

def fen_to_tensor(fen):
    board = chess.Board(fen)
    tokenized_board, piece_type_tensor = tokenize_board(board)
    return torch.tensor(tokenized_board), piece_type_tensor

In [3]:


def generate_self_play_data(model, num_games, max_moves_per_game):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self_play_data = []

    for game_number in range(1, num_games + 1):
        board = chess.Board()
        game_moves = []
        game_values = []

        for move_number in range(1, max_moves_per_game + 1):
            if board.is_game_over():
                break

            fen = board.fen()
            tokenized_board_tensor, piece_type_tensor = fen_to_tensor(fen)
            tokenized_board_tensor = tokenized_board_tensor.to(device)
            piece_type_tensor = piece_type_tensor.to(device)
            
            value, policy = model(tokenized_board_tensor.unsqueeze(0), piece_type_tensor.unsqueeze(0))
            attention_weights = model.piece_wise_attention.attention_weights
            move = map_policy_to_move(board, policy, piece_type_tensor, attention_weights)

            if move is None:
                break

            board.push(move)
            game_moves.append(move)
            game_values.append(value.item())

        result = board.result()
        if result == '1-0':
            game_result = 1.0
            result_str = "White wins"
        elif result == '0-1':
            game_result = -1.0
            result_str = "Black wins"
        else:
            game_result = 0.0
            result_str = "Draw"

        self_play_data.append((game_moves, game_values, game_result))
        print(f"Epoch: 1, Game: {game_number}, Result: {result_str}")

    return self_play_data


def save_self_play_data(self_play_data, data_dir, file_prefix):
    for idx, game_data in enumerate(self_play_data, 1):
        game_moves, game_values, game_result = game_data
        data_path = os.path.join(data_dir, f"{file_prefix}_game_{idx}.txt")
        with open(data_path, "w") as file:
            for move, value in zip(game_moves, game_values):
                file.write(f"{move.uci()} {value}\n")
            file.write(f"Result: {game_result}\n")
    print(f"Self-play data saved: {len(self_play_data)} games")


def load_self_play_data(data_dir, file_prefix):
    self_play_data = []

    # Get the list of files with the specified prefix
    file_pattern = os.path.join(data_dir, f"{file_prefix}_game_*.txt")
    game_files = glob.glob(file_pattern)

    for game_file in game_files:
        with open(game_file, "r") as file:
            game_moves = []
            game_values = []
            for line in file:
                if line.startswith("Result:"):
                    game_result = float(line.split(":")[1].strip())
                else:
                    move, value = line.strip().split()
                    game_moves.append(chess.Move.from_uci(move))
                    game_values.append(float(value))
            self_play_data.append((game_moves, game_values, game_result))

    return self_play_data

def check_self_play_data_exists(data_dir, file_prefix):
    file_pattern = os.path.join(data_dir, f"{file_prefix}_game_*.txt")
    game_files = glob.glob(file_pattern)
    return len(game_files) > 0

In [6]:


def train(model, optimizer, num_epochs, batch_size, self_play_data, checkpoint_dir, data_dir, save_interval, window_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion_value = nn.MSELoss()

    def window_loss_fn(values):
        first_value = values[:, 0]
        last_value = values[:, -1]
        return torch.mean(torch.relu(first_value - last_value))

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        
        for batch_idx in range(0, len(self_play_data), batch_size):
            batch = self_play_data[batch_idx:batch_idx+batch_size]
            boards = []
            piece_types = []
            value_targets = []

            for game_moves, game_values, game_result in batch:
                for move_idx, move in enumerate(game_moves):
                    board = chess.Board()
                    for prev_move in game_moves[:move_idx]:
                        board.push(prev_move)
                    fen = board.fen()
                    tokenized_board_tensor, piece_type_tensor = fen_to_tensor(fen)
                    boards.append(tokenized_board_tensor)
                    piece_types.append(piece_type_tensor)

                    # Calculate value targets based on the game result
                    value_target = game_result if move_idx % 2 == 0 else -game_result
                    value_targets.append(value_target)

            boards = torch.stack(boards).to(device)
            piece_types = torch.stack(piece_types).to(device)
            value_targets = torch.tensor(value_targets).unsqueeze(1).to(device)

            optimizer.zero_grad()
            values, _ = model(boards, piece_types)

            value_loss = criterion_value(values, value_targets)

            # Calculate the window loss
            window_boards = []
            window_piece_types = []
            window_values = []

            for game_moves, game_values, game_result in batch:
                for player in [0, 1]:  # Iterate over both players
                    player_moves = game_moves[player::2]  # Select moves for the current player
                    player_values = game_values[player::2]

                    for move_idx in range(len(player_moves) - window_size + 1):
                        window_moves = player_moves[move_idx:move_idx+window_size]
                        window_values.extend(player_values[move_idx:move_idx+window_size])

                        # Create board states for the window moves
                        for move in window_moves:
                            board = chess.Board()
                            for prev_move in game_moves[:move_idx*2+player]:
                                board.push(prev_move)
                            fen = board.fen()
                            tokenized_board_tensor, piece_type_tensor = fen_to_tensor(fen)
                            window_boards.append(tokenized_board_tensor)
                            window_piece_types.append(piece_type_tensor)

            window_boards = torch.stack(window_boards).to(device)
            window_piece_types = torch.stack(window_piece_types).to(device)
            window_values_tensor = torch.tensor(window_values).unsqueeze(1).to(device)

            _, window_values_pred = model(window_boards, window_piece_types)

            num_windows = window_values_pred.size(0) // window_size
            window_values_pred = window_values_pred[:num_windows * window_size].view(-1, window_size)
            window_values_tensor = window_values_tensor[:num_windows * window_size].view(-1, window_size)
            window_loss = window_loss_fn(window_values_pred)

            loss = value_loss + window_loss
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        epoch_loss /= num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.8f}, Value Loss: {value_loss:.8f}, Window Loss: {window_loss:.8f}")

        if (epoch + 1) % save_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": epoch_loss
            }, checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}")

    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#################################### HYPER PARAMETERS --- do not edit num_tokens #############################
model = TransformerModel(num_tokens=13, embedding_dim=128, num_heads=8, num_layers=8, hidden_dim=128, dropout=0.1)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


num_self_play_games = 1000
max_moves_per_game = 50
num_epochs = 5
batch_size = 25
checkpoint_dir = "checkpoints"
data_dir = "self_play_data"
save_interval = 5

#constant
window_size = 10

os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)



In [None]:
num_iterations = 5  # Number of times to iterate between self-play and training

for iteration in range(num_iterations):
    print(f"Iteration: {iteration + 1}")

    # Check if self-play data exists for the current iteration
    file_prefix = f"self_play_iter_{iteration + 1}"
    self_play_data_exists = check_self_play_data_exists(data_dir, file_prefix)

    if self_play_data_exists:
        # Load self-play data for the current iteration
        self_play_data = load_self_play_data(data_dir, file_prefix)
    else:
        # Generate self-play data for the current iteration
        self_play_data = generate_self_play_data(model, num_self_play_games, max_moves_per_game)
        save_self_play_data(self_play_data, data_dir, file_prefix)

    # Train the model
    trained_model = train(model, optimizer, num_epochs, batch_size, self_play_data, checkpoint_dir, data_dir, save_interval, window_size)

    # Update the model for the next iteration
    model = trained_model

Iteration: 1
Epoch [1/5], Loss: 0.01861065, Value Loss: 0.00013456, Window Loss: 0.00000000
Epoch [2/5], Loss: 0.00924258, Value Loss: 0.00001151, Window Loss: 0.00000000
Epoch [3/5], Loss: 0.00923066, Value Loss: 0.00000730, Window Loss: 0.00000000
Epoch [4/5], Loss: 0.00923347, Value Loss: 0.00000558, Window Loss: 0.00000000
Epoch [5/5], Loss: 0.00923091, Value Loss: 0.00000664, Window Loss: 0.00000000
Checkpoint saved at epoch 5
Iteration: 2
Epoch [1/5], Loss: 0.00000687, Value Loss: 0.00000705, Window Loss: 0.00000000
Epoch [2/5], Loss: 0.00000577, Value Loss: 0.00000584, Window Loss: 0.00000000
Epoch [3/5], Loss: 0.00000560, Value Loss: 0.00000550, Window Loss: 0.00000000
Epoch [4/5], Loss: 0.00000542, Value Loss: 0.00000536, Window Loss: 0.00000000
Epoch [5/5], Loss: 0.00000515, Value Loss: 0.00000509, Window Loss: 0.00000000
Checkpoint saved at epoch 5
Iteration: 3
Epoch [1/5], Loss: 0.00000983, Value Loss: 0.00000983, Window Loss: 0.00000000
Epoch [2/5], Loss: 0.00000618, Value

In [None]:
5