In [1]:
import os
import datetime
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from collections import deque, defaultdict
from tqdm import tqdm
from train import AlphaLoss, train_pipeline
from environments import FrozenLakeManipulationEnv, GripperDiscretisedEnv
from data_loading import to_one_hot_encoding, ReplayBuffer, ReplayDataset
from mcts_models import MCTSNode, LearnedMCTSNode, AlphaZeroNet

In [2]:
# --- MCTS / AlphaZero search settings ---
NUM_SIMS = 10_000        # MCTS simulations/iterations per self-play step
NUM_SELF_PLAY = 1        # self-play games generated per epoch/episode
NUM_EPOCHS = 10          # number of epochs to train the model
CPUCT = 1.41             # PUCT exploration constant
TAU = 1.0                # temperature for π = N^(1/τ)

# --- Training settings ---
BATCH_SIZE = 128
LR = 0.001
EVAL_INTERVAL = 1        # evaluation interval (forwarded as eval_interval)
TARGET_SR = 0.90         # target success rate: stop when success rate ≥ 90%
REGULARIZATION = 0.0001  # L2 regularization (Adam weight_decay) constant
MAX_EPISODES = 10        # forwarded to train_pipeline as num_episodes

# --- Evaluation / buffer sizing (forwarded to train_pipeline) ---
NUM_EVAL = 50            # forwarded as num_eval
BUFFER_SIZE = 20_000     # forwarded as buffer_size
SAMPLE_SIZE = 2_048      # forwarded as sample_size

In [3]:
# --- Main Execution ---
def make_env():
    """Create and return a new environment instance.

    This factory is handed to ``LearnedMCTSNode`` and ``train_pipeline`` so
    they can build environments themselves. It currently returns a
    ``FrozenLakeManipulationEnv`` constructed with default arguments.

    Alternatives that were previously swapped in here:
        gym.make("FrozenLake-v1", is_slippery=False, render_mode="ansi")
        GripperDiscretisedEnv()
    """
    environment = FrozenLakeManipulationEnv()
    return environment

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

env = make_env()
# Reset exactly once and keep the initial observation; a second reset whose
# result is thrown away only advances the environment for no benefit.
state, info = env.reset()
nA = env.action_space.n  # number of discrete actions

if isinstance(env.observation_space, Discrete):
    nS = env.observation_space.n
else:
    # Assuming the observation space is a tuple of
    # (states, ..., states, holding/not_holding): every component except the
    # last contributes env.n_states inputs, and the last contributes one.
    nS = (len(env.observation_space.sample()) - 1) * env.n_states + 1

net = AlphaZeroNet(nS, nA).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=REGULARIZATION)
# NOTE(review): T_max=10 is hard-coded; it equals both NUM_EPOCHS and
# MAX_EPISODES today. Confirm which one matches how often train_pipeline
# steps the scheduler and reference that constant instead.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)


Using device: cpu


In [4]:
# Search-tree root built from the initial observation `state`.
# NOTE(review): root_node is not passed to the train_pipeline call later in
# this notebook — confirm whether this cell is still needed.
root_node = LearnedMCTSNode(
    state=state,
    make_env=make_env,
    net=net,
    cpuct=CPUCT,
    device=device,
    verbose=False,
)

# Training Loop

In [6]:
# Launch training; every hyperparameter comes from the configuration cell.
train_pipeline(
    # model, environment factory and optimisation objects
    net=net,
    make_env=make_env,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    # buffer, sample and batch sizes
    buffer_size=BUFFER_SIZE,
    sample_size=SAMPLE_SIZE,
    batch_size=BATCH_SIZE,
    # search settings
    num_sims=NUM_SIMS,
    tau=TAU,
    cpuct=CPUCT,
    # loop lengths
    num_epochs=NUM_EPOCHS,
    num_episodes=MAX_EPISODES,
    num_self_play=NUM_SELF_PLAY,
    # evaluation settings
    eval_interval=EVAL_INTERVAL,
    num_eval=NUM_EVAL,
    target_sr=TARGET_SR,
)

Episode 1/10:   0%|          | 0/10 [00:00<?, ?it/s]

Episode 1/10:  20%|██        | 2/10 [00:26<01:47, 13.47s/it, buffer=4, reward=-1]


KeyboardInterrupt: 