In [46]:
import torch

In [47]:
import chess
import torch
from chess.polyglot import zobrist_hash
from chess import Board, Move
import random
from IPython.display import SVG, clear_output, Image
import time
import numpy as np
from numpy import float64, int64
from typing import Dict, Tuple, Sequence, Self, Optional, Any
from dataclasses import dataclass
from pathlib import Path
import sys

sys.path.append(str(Path.cwd() / ".." / "src"))
from mcts import Node, Edge, MCTS
from model_wrapper import ModelWrapper
from leela_cnn import LeelaCNN, board_to_tensor
import matplotlib.pyplot as plt
from matplotlib import animation
import matplotlib.patches as patches

In [48]:
# Freshly constructed network that serves as the reference model for the
# split-checkpoint save/load round trip below.
# NOTE(review): the meaning of the positional args (10, 128) is defined in
# leela_cnn.LeelaCNN — presumably residual-block count and channel width; confirm.
og_model = LeelaCNN(10, 128)


In [49]:
import os
def save_split_state_dict(model: torch.nn.Module, out_dir, max_file_size: int = 10 * 1024**2) -> None:
    """Save ``model.state_dict()`` as a series of ``chunk_<i>.pt`` files in ``out_dir``.

    Entries are grouped greedily in state-dict order so that the summed raw
    tensor bytes of each chunk stay at or below ``max_file_size``. A single
    tensor larger than the limit is written as a chunk of its own (never
    preceded by an empty chunk). The limit counts tensor payload only;
    serialization overhead makes the files on disk slightly larger.

    Any existing ``chunk_<digits>.pt`` files in ``out_dir`` are deleted first,
    so an earlier save that produced more chunks cannot leave stale files
    behind to be merged in by a later load. Other files are left untouched.

    Args:
        model: module whose ``state_dict()`` is written.
        out_dir: target directory (str or path-like); created if missing.
        max_file_size: soft per-chunk limit, in bytes of tensor data.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Clear chunks from a previous save; only files matching our own naming
    # scheme (chunk_<digits>.pt) are removed.
    for stale in Path(out_dir).glob("chunk_*.pt"):
        if stale.stem[len("chunk_"):].isdigit():
            stale.unlink()

    current_chunk = {}
    current_size = 0
    chunk_idx = 0

    for key, tensor in model.state_dict().items():
        tensor_bytes = tensor.numel() * tensor.element_size()

        # Flush before this tensor would overflow the chunk, but never flush an
        # empty chunk (happens when a single tensor exceeds the limit).
        if current_chunk and current_size + tensor_bytes > max_file_size:
            torch.save(current_chunk, os.path.join(out_dir, f"chunk_{chunk_idx}.pt"))
            chunk_idx += 1
            current_chunk = {}
            current_size = 0

        current_chunk[key] = tensor
        current_size += tensor_bytes

    # Write whatever remains after the loop.
    if current_chunk:
        torch.save(current_chunk, os.path.join(out_dir, f"chunk_{chunk_idx}.pt"))


def load_split_state_dict(model: torch.nn.Module, split_dir):
    """Load a state dict written by ``save_split_state_dict`` into ``model``.

    Every ``chunk_<digits>.pt`` file in ``split_dir`` is read in numeric order
    (``chunk_2`` before ``chunk_10``). The pieces are merged and applied with
    ``model.load_state_dict``, which is strict by default, so missing or
    unexpected keys raise.

    Args:
        model: module to receive the weights; modified in place.
        split_dir: directory (str or path-like) containing the chunk files.

    Returns:
        The same ``model`` object, so the call can be chained.

    Raises:
        FileNotFoundError: if ``split_dir`` holds no chunk files.
        ValueError: if a key appears in more than one chunk. This usually
            means stale chunks from an earlier, differently split save.
    """
    def _index(path: Path) -> str:
        # "chunk_12.pt" -> "12"; glob guarantees the "chunk_" prefix.
        return path.stem[len("chunk_"):]

    chunks = sorted(
        (p for p in Path(split_dir).glob("chunk_*.pt") if _index(p).isdigit()),
        key=lambda p: int(_index(p)),
    )
    if not chunks:
        raise FileNotFoundError(f"no chunk_<i>.pt files found in {str(split_dir)!r}")

    merged = {}
    for path in chunks:
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint file cannot execute arbitrary code on load.
        part = torch.load(path, map_location="cpu", weights_only=True)
        duplicated = merged.keys() & part.keys()
        if duplicated:
            raise ValueError(
                f"keys {sorted(duplicated)} appear in more than one chunk "
                f"(latest: {path.name}); {str(split_dir)!r} may contain stale chunks"
            )
        merged.update(part)

    model.load_state_dict(merged)
    return model


In [50]:
# Round-trip the reference network through the chunked checkpoint format:
# write it out in pieces of at most 10 MiB of tensor data each, then load those
# pieces into a second LeelaCNN (constructor args must match og_model's).
checkpoint_dir = "models/leela_split_checkpoint"
save_split_state_dict(og_model, out_dir=checkpoint_dir, max_file_size=10 * 1024**2)
model = load_split_state_dict(LeelaCNN(10, 128), split_dir=checkpoint_dir)

In [51]:
# Verify the round trip: every state_dict entry (parameters *and* registered
# buffers) must be present in both models and bit-for-bit identical.
# Comparing key sets first avoids zip() silently truncating on a mismatch.
og_state = og_model.state_dict()
loaded_state = model.state_dict()
assert og_state.keys() == loaded_state.keys(), "state_dict keys differ after round trip"
for entry_name, og_tensor in og_state.items():
    assert torch.equal(og_tensor, loaded_state[entry_name]), f"mismatch in {entry_name!r}"