# Chess CNN implementation

In [3]:
import chess
import chess.pgn
import torch
import os
import shutil
import numpy as np

In [5]:
# ==========================================
# CONFIGURATION
# ==========================================
PGN_FILE = "magnus_blitz.pgn"
OUTPUT_FOLDER = "data_chunks"
CHUNK_SIZE = 5000  # Save a file every 5,000 positions (keeps RAM usage tiny)
MAGNUS_NAME = "DrNykterstein" # Lichess username for Magnus Carlsen

# ==========================================
# TENSOR HELPERS
# ==========================================
PIECE_TO_LAYER = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def board_to_tensor(board):
    tensor = np.zeros((14, 8, 8), dtype=np.float32)
    
    # 1. Pieces
    for square, piece in board.piece_map().items():
        row, col = divmod(square, 8)
        layer = PIECE_TO_LAYER[piece.symbol()]
        tensor[layer, row, col] = 1.0

    # 2. Turn (Channel 12)
    if board.turn == chess.BLACK:
        tensor[12, :, :] = 1.0
        
    # 3. Valid Mask (Channel 13)
    tensor[13, :, :] = 1.0
    return tensor

# ==========================================
# MAIN PROCESSING LOOP
# ==========================================
def process_pgn():
    # 1. Setup
    if os.path.exists(OUTPUT_FOLDER):
        shutil.rmtree(OUTPUT_FOLDER) # Clear old data
    os.makedirs(OUTPUT_FOLDER)
    
    print(f"Reading {PGN_FILE}...")
    pgn = open(PGN_FILE)
    
    # Buffers to hold data in RAM before dumping to disk
    current_X = []
    current_from = []
    current_to = []
    
    games_processed = 0
    total_positions = 0
    chunk_index = 0

    while True:
        # Read one game from the file
        game = chess.pgn.read_game(pgn)
        if game is None:
            break # End of file
            
        # 2. Identify Magnus
        white_player = game.headers.get("White", "?")
        black_player = game.headers.get("Black", "?")
        
        magnus_color = None
        if MAGNUS_NAME in white_player:
            magnus_color = chess.WHITE
        elif MAGNUS_NAME in black_player:
            magnus_color = chess.BLACK
        
        # Skip games where Magnus isn't playing (just in case)
        if magnus_color is None:
            continue

        # 3. Play through the game
        board = game.board()
        for move in game.mainline_moves():
            
            # ONLY save if it's Magnus's turn
            if board.turn == magnus_color:
                current_X.append(board_to_tensor(board))
                current_from.append(move.from_square)
                current_to.append(move.to_square)
                total_positions += 1

            # Make the move on the board
            board.push(move)

        games_processed += 1
        if games_processed % 100 == 0:
            print(f"Games: {games_processed} | Positions Saved: {total_positions}")

        # 4. Save Chunk if Buffer is Full
        if len(current_X) >= CHUNK_SIZE:
            save_chunk(current_X, current_from, current_to, chunk_index)
            chunk_index += 1
            # CLEAR RAM
            current_X = []
            current_from = []
            current_to = []

    # Save any remaining data
    if len(current_X) > 0:
        save_chunk(current_X, current_from, current_to, chunk_index)

    print(f"\nDONE! Processed {games_processed} games.")
    print(f"Total positions: {total_positions}")
    print(f"Saved to folder: {OUTPUT_FOLDER}/")

def save_chunk(X, y_from, y_to, index):
    filename = os.path.join(OUTPUT_FOLDER, f"chunk_{index}.pt")
    print(f"  -> Saving chunk {index} ({len(X)} positions)...")
    
    # Convert to Tensors
    X_t = torch.tensor(np.array(X, dtype=np.float32))
    y_f_t = torch.tensor(np.array(y_from), dtype=torch.long)
    y_t_t = torch.tensor(np.array(y_to), dtype=torch.long)
    
    # Save to disk
    torch.save({
        "X": X_t,
        "y_from": y_f_t,
        "y_to": y_t_t
    }, filename)

if __name__ == "__main__":
    process_pgn()

Reading magnus_blitz.pgn...
Games: 100 | Positions Saved: 4502
  -> Saving chunk 0 (5017 positions)...
Games: 200 | Positions Saved: 8752
  -> Saving chunk 1 (5038 positions)...
Games: 300 | Positions Saved: 13353
  -> Saving chunk 2 (5002 positions)...
Games: 400 | Positions Saved: 17541
  -> Saving chunk 3 (5017 positions)...
Games: 500 | Positions Saved: 22199
  -> Saving chunk 4 (5027 positions)...
Games: 600 | Positions Saved: 26445
  -> Saving chunk 5 (1546 positions)...

DONE! Processed 607 games.
Total positions: 26647
Saved to folder: data_chunks/


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import os

# ==========================================
# CONFIGURATION
# ==========================================
# Apple Silicon Acceleration
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
BATCH_SIZE = 64
EPOCHS = 10 
LEARNING_RATE = 0.001
CHUNKS_DIR = "data_chunks"

print(f"Using Device: {DEVICE}")

# ==========================================
# 1. LOAD THE CHUNKS
# ==========================================
def load_all_chunks():
    print(f"Loading data from {CHUNKS_DIR}...")
    all_X = []
    all_from = []
    all_to = []
    
    # Sort files so we load chunk_0, chunk_1, etc. in order
    files = sorted([f for f in os.listdir(CHUNKS_DIR) if f.endswith(".pt")])
    
    if not files:
        print("ERROR: No data found! Run step 1 first.")
        exit()

    for f in files:
        path = os.path.join(CHUNKS_DIR, f)
        data = torch.load(path)
        all_X.append(data["X"])
        all_from.append(data["y_from"])
        all_to.append(data["y_to"])
        print(f"  Loaded {f}")
        
    # Combine into one big Tensor
    X = torch.cat(all_X)
    y_from = torch.cat(all_from)
    y_to = torch.cat(all_to)
    
    print(f"Total Training Data: {len(X)} positions")
    return X, y_from, y_to

# Load Data
X_train, y_from_train, y_to_train = load_all_chunks()
dataset = TensorDataset(X_train, y_from_train, y_to_train)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ==========================================
# 2. THE MODEL (Simple CNN)
# ==========================================
class MagnusNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Input: 14 channels (Pieces + Turn + Mask)
        self.features = nn.Sequential(
            # Layer 1: Look for simple patterns (3x3)
            nn.Conv2d(14, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            # Layer 2: Look for complex patterns
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            # Layer 3: Abstract concepts
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        
        # 128 channels * 8x8 board = 8192 features
        self.head_from = nn.Linear(128 * 8 * 8, 64) # Which piece to move?
        self.head_to = nn.Linear(128 * 8 * 8, 64)   # Where to go?

    def forward(self, x):
        features = self.features(x)
        return self.head_from(features), self.head_to(features)

# ==========================================
# 3. THE TRAINING LOOP
# ==========================================
model = MagnusNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

print("\nStarting Training...")
print("Goal: Loss should drop from ~8.5 to ~3.0")

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    correct_moves = 0
    total_moves = 0
    
    for X, y_f, y_t in loader:
        # Move to GPU
        X, y_f, y_t = X.to(DEVICE), y_f.to(DEVICE), y_t.to(DEVICE)
        
        # 1. Reset Gradients
        optimizer.zero_grad()
        
        # 2. Forward Pass
        pred_f, pred_t = model(X)
        
        # 3. Calculate Loss (From + To)
        loss = criterion(pred_f, y_f) + criterion(pred_t, y_t)
        
        # 4. Backward Pass (Learn)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # (Optional) Calculate Accuracy roughly
        # This is strictly "Did it pick the EXACT same square?"
        # It's hard to get high accuracy on real data, so don't worry if it's low.
        
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f}")

print("\nTraining Complete!")
print("Saving model to 'magnus_cnn.pth'...")
torch.save(model.state_dict(), "magnus_cnn.pth")

Using Device: mps
Loading data from data_chunks...
  Loaded chunk_0.pt
  Loaded chunk_1.pt
  Loaded chunk_2.pt
  Loaded chunk_3.pt
  Loaded chunk_4.pt
  Loaded chunk_5.pt
Total Training Data: 26647 positions

Starting Training...
Goal: Loss should drop from ~8.5 to ~3.0
Epoch 1/10 | Loss: 7.2514
Epoch 2/10 | Loss: 5.6015
Epoch 3/10 | Loss: 4.6849
Epoch 4/10 | Loss: 4.1539
Epoch 5/10 | Loss: 3.6392
Epoch 6/10 | Loss: 3.1365
Epoch 7/10 | Loss: 2.6277
Epoch 8/10 | Loss: 2.1304
Epoch 9/10 | Loss: 1.6878
Epoch 10/10 | Loss: 1.3140

Training Complete!
Saving model to 'magnus_cnn.pth'...


In [None]:
import pygame
import chess
import torch
import torch.nn as nn
import numpy as np
import time

# ==========================================
# 1. CONFIGURATION
# ==========================================
MODEL_FILE = "magnus_cnn.pth"
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# GUI Settings
WIDTH, HEIGHT = 600, 600
SQUARE_SIZE = WIDTH // 8
LIGHT_COLOR = (240, 217, 181)
DARK_COLOR = (181, 136, 99)
HIGHLIGHT_COLOR = (186, 202, 68)

# ==========================================
# 2. MODEL & HELPERS (Must match EXACTLY)
# ==========================================
PIECE_TO_LAYER = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
}

def board_to_tensor(board):
    tensor = np.zeros((14, 8, 8), dtype=np.float32)
    for square, piece in board.piece_map().items():
        row, col = divmod(square, 8)
        layer = PIECE_TO_LAYER[piece.symbol()]
        tensor[layer, row, col] = 1.0
    if board.turn == chess.BLACK:
        tensor[12, :, :] = 1.0
    tensor[13, :, :] = 1.0
    return torch.tensor(tensor).unsqueeze(0)

class MagnusNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(14, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.head_from = nn.Linear(128 * 8 * 8, 64)
        self.head_to = nn.Linear(128 * 8 * 8, 64)

    def forward(self, x):
        features = self.features(x)
        return self.head_from(features), self.head_to(features)

# Load Model
print("Loading MagnusNet...")
model = MagnusNet().to(DEVICE)
try:
    model.load_state_dict(torch.load(MODEL_FILE, map_location=DEVICE))
    model.eval()
except FileNotFoundError:
    print("Error: MODEL_FILE not found!")
    exit()

def get_ai_move(board):
    tensor = board_to_tensor(board).to(DEVICE)
    with torch.no_grad():
        logits_from, logits_to = model(tensor)
    
    best_move = None
    best_score = -9999
    
    legal_moves = list(board.legal_moves)
    if not legal_moves: return None
    
    for move in legal_moves:
        score = logits_from[0, move.from_square].item() + logits_to[0, move.to_square].item()
        if score > best_score:
            best_score = score
            best_move = move
    return best_move

# ==========================================
# 3. GUI FUNCTIONS
# ==========================================
def draw_board(screen, board, selected_square=None):
    # Draw Squares
    for r in range(8):
        for c in range(8):
            color = LIGHT_COLOR if (r + c) % 2 == 0 else DARK_COLOR
            # Highlight selected square
            if selected_square is not None:
                sel_r, sel_c = divmod(selected_square, 8)
                # Flip row for drawing (chess rank 0 is bottom, pygame y=0 is top)
                if c == sel_c and (7-r) == sel_r: # chess.A1 is (0,0) -> Pygame (0, 7)
                    color = HIGHLIGHT_COLOR
            
            pygame.draw.rect(screen, color, (c*SQUARE_SIZE, r*SQUARE_SIZE, SQUARE_SIZE, SQUARE_SIZE))

def draw_pieces(screen, board, font, dragging_piece=None, mouse_pos=None):
    # Mapping Unicode pieces
    unicode_pieces = {
        'R': '♖', 'N': '♘', 'B': '♗', 'Q': '♕', 'K': '♔', 'P': '♙',
        'r': '♜', 'n': '♞', 'b': '♝', 'q': '♛', 'k': '♚', 'p': '♟'
    }
    
    for square in range(64):
        piece = board.piece_at(square)
        if piece:
            # If we are dragging this piece, skip drawing it on the square
            if dragging_piece and dragging_piece['square'] == square:
                continue
                
            symbol = unicode_pieces[piece.symbol()]
            color = (0, 0, 0) if piece.color == chess.BLACK else (255, 255, 255)
            
            # Text Render
            text = font.render(symbol, True, color)
            
            # Coordinates
            row, col = divmod(square, 8)  # <--- The Fix (Row first, Col second)
            # Flip row (Chess rank 0 is bottom)
            draw_y = (7 - row) * SQUARE_SIZE 
            draw_x = col * SQUARE_SIZE
            
            # Center the text
            text_rect = text.get_rect(center=(draw_x + SQUARE_SIZE//2, draw_y + SQUARE_SIZE//2))
            screen.blit(text, text_rect)
            
    # Draw the Dragging Piece LAST (so it floats on top)
    if dragging_piece:
        piece = dragging_piece['piece']
        symbol = unicode_pieces[piece.symbol()]
        color = (0, 0, 0) if piece.color == chess.BLACK else (255, 255, 255)
        text = font.render(symbol, True, color)
        text_rect = text.get_rect(center=mouse_pos)
        screen.blit(text, text_rect)

def get_square_under_mouse(pos):
    x, y = pos
    col = x // SQUARE_SIZE
    row = 7 - (y // SQUARE_SIZE) # Flip Y coordinate
    if 0 <= col < 8 and 0 <= row < 8:
        return chess.square(col, row)
    return None

# ==========================================
# 4. MAIN LOOP
# ==========================================
def main():
    pygame.init()
    screen = pygame.display.set_mode((WIDTH, HEIGHT))
    pygame.display.set_caption("MagnusNet (Week 1)")
    
    # Try to find a font that supports Chess Symbols
    # 'Segoe UI Symbol' works on Windows, 'Apple Symbols' on Mac
    font_name = 'Apple Symbols' 
    font = pygame.font.SysFont(font_name, SQUARE_SIZE - 10)

    board = chess.Board()
    clock = pygame.time.Clock()
    
    running = True
    dragging_piece = None # { 'piece': chess.Piece, 'square': int }
    
    while running:
        clock.tick(60)
        screen.fill((0, 0, 0))
        
        # 1. Event Handling
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
                
            elif event.type == pygame.MOUSEBUTTONDOWN:
                if board.turn == chess.WHITE: # Only allow human to drag White
                    square = get_square_under_mouse(event.pos)
                    if square is not None:
                        piece = board.piece_at(square)
                        if piece and piece.color == chess.WHITE:
                            dragging_piece = {'piece': piece, 'square': square}
            
            elif event.type == pygame.MOUSEBUTTONUP:
                if dragging_piece:
                    # Try to drop
                    target_square = get_square_under_mouse(event.pos)
                    if target_square is not None:
                        move = chess.Move(dragging_piece['square'], target_square)
                        
                        # Check for promotion (simple auto-queen)
                        if dragging_piece['piece'].piece_type == chess.PAWN and chess.square_rank(target_square) == 7:
                            move = chess.Move(dragging_piece['square'], target_square, promotion=chess.QUEEN)

                        if move in board.legal_moves:
                            board.push(move)
                        else:
                            print("Illegal move!")
                            
                    dragging_piece = None # Stop dragging
        
        # 2. Draw Everything
        draw_board(screen, board, selected_square=dragging_piece['square'] if dragging_piece else None)
        draw_pieces(screen, board, font, dragging_piece, pygame.mouse.get_pos())
        
        pygame.display.flip()

        # 3. AI Turn (Logic triggers AFTER draw update)
        if board.turn == chess.BLACK and not board.is_game_over():
            # Small delay so it feels like thinking
            pygame.time.delay(100) 
            
            # Get AI Move
            move = get_ai_move(board)
            if move:
                print(f"MagnusNet plays: {move}")
                board.push(move)
            else:
                print("Game Over")

    pygame.quit()

if __name__ == "__main__":
    main()

Loading MagnusNet...
MagnusNet plays: g8f6
MagnusNet plays: g7g6
MagnusNet plays: f8g7
MagnusNet plays: e8g8
MagnusNet plays: d7d6
Illegal move!
MagnusNet plays: b8a6
MagnusNet plays: a6c5
Illegal move!
Illegal move!
MagnusNet plays: d6c5
Illegal move!
MagnusNet plays: c8g4
MagnusNet plays: d8d7
MagnusNet plays: f8c8
MagnusNet plays: e7f6
MagnusNet plays: d7c6
Illegal move!
Illegal move!
Illegal move!
Illegal move!
Illegal move!
MagnusNet plays: g4f3
MagnusNet plays: c8d8
Illegal move!
MagnusNet plays: g7f8
Illegal move!
Illegal move!
MagnusNet plays: c5b4
Illegal move!
MagnusNet plays: f8d6
MagnusNet plays: d6c5
Illegal move!
Illegal move!
MagnusNet plays: a7a5
MagnusNet plays: g8g7
MagnusNet plays: c5b6
MagnusNet plays: b7c6
Illegal move!
Illegal move!
MagnusNet plays: d8d6
MagnusNet plays: a8e8
Illegal move!
MagnusNet plays: c7d6
MagnusNet plays: e8e5
MagnusNet plays: e5e6
MagnusNet plays: e6e5
MagnusNet plays: e5f5
MagnusNet plays: f5e5
MagnusNet plays: e5e6
Illegal move!
MagnusNet

: 