In [None]:
# @title run pretraining
import torch
import torch.backends

import torch.optim as optim

import os

from network import TakoNet, TakoNetConfig
from pretrain import train, download_training_set
from settings import Configuration

config = Configuration().get_config()

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# NOTE(review): `device` is not used below — the model is created without it
# (see the create_model comment) and the checkpoint is mapped onto
# `model.device` instead. Confirm whether TakoNetConfig.create_model accepts a
# device argument before wiring it in.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

model = TakoNetConfig().create_model()  # pass device to create_model for GPU
print(f"{model.count_params():,} params")
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# TODO: configure with command line args
checkpoint_path = "checkpoints/best-pretrained-model.pt"  # resumed from, if present
output_path = "checkpoints/best-model.pt"  # written once training finishes

# Epoch to resume pre-training from; stays 0 when no checkpoint is found.
epoch = 0
if os.path.isfile(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=model.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Resume the epoch counter from the checkpoint instead of always
    # restarting at 0; `.get` tolerates checkpoints saved without 'epoch'.
    epoch = checkpoint.get('epoch', 0)
    print("Checkpoint loaded successfully.")
else:
    print(f"No checkpoint found at {checkpoint_path}, starting from scratch.")

# Cosine-decay the learning rate down to 1e-6 over the configured number of
# pre-training epochs.
# NOTE(review): the scheduler state is not saved/restored, so on resume the
# cosine schedule starts again from the top even though the epoch counter is
# restored. Also, `verbose` is deprecated in recent PyTorch releases — confirm
# before upgrading.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.pretrain.num_epochs,
    eta_min=1e-6,
    verbose=True,
)
print(config)
download_training_set()
train(model, optimizer, scheduler, starting_epoch=epoch)

# The checkpoint directory may not exist yet on a fresh runtime (e.g. when no
# checkpoint was loaded); create it so the save after a long run cannot fail.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Kept at 0: this file is not read back by this cell; presumably it seeds
    # a later training stage that counts its own epochs from 0 — confirm.
    'epoch': 0,
}, output_path)

8,963,251 params
Loading checkpoint from checkpoints/best-pretrained-model.pt
Checkpoint loaded successfully.
Adjusting learning rate of group 0 to 1.0000e-03.
Config(model=Config(learning_rate=0.05, policy_output_size=4672, value_output_size=3), mcts=Config(max_nodes=3600, thinking_time=10), train=Config(num_epochs=200, num_self_play_games=100, batch_size=32, num_simulations=100, replay_buffer_size=30000, evaluation_interval=5, save_model=True, model_checkpoint_dir='checkpoints/', training_steps=250), evaluation=Config(num_simulations=800, max_depth=10, num_games=5), pretrain=Config(batch_size=8, num_epochs=10, validation_batch_size=100, validation_interval=1, alpha=0.9), visualize=True, verbose=False)
Puzzle dataset already exists. Skipping download and extraction.
Loading puzzles to memory...
Already pre-processed.
Starting pre-training...
Epoch: 1


loss=3.444, elo=1271, ACC=0.375:   0%|          | 136/100000 [00:21<4:02:27,  6.86it/s]