In [1]:
import os
import datetime
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from collections import deque, defaultdict
from tqdm import tqdm
from train import AlphaLoss
from environments.frozen_lake_manipulation_environment import FrozenLakeManipulationEnv
from environments.gripper_environment import GripperDiscretisedEnv
from data_loading import to_one_hot_encoding, ReplayBuffer, ReplayDataset

In [2]:
# --- MCTS / AlphaZero search ---
NUM_SIMS = 5000          # MCTS simulations run from each root before a move is chosen
NUM_SELF_PLAY = 10       # self-play games generated per training episode
CPUCT = 1.41             # PUCT exploration constant c_puct
TAU = 1.0                # visit-count temperature: pi ∝ N ** (1 / TAU)

# --- Network training ---
BATCH_SIZE = 128
LR = 1e-3                # Adam learning rate
EVAL_INTERVAL = 1        # evaluate after every EVAL_INTERVAL training episodes
TARGET_SR = 0.95         # success-rate target (95%)
REGULARIZATION = 1e-4    # Adam weight_decay (L2 penalty)
MAX_EPISODES = 20        # number of self-play / train / evaluate episodes

# --- Evaluation and replay buffer ---
NUM_EVAL = 50            # evaluation episodes per evaluation round
BUFFER_SIZE = 20_000     # passed to ReplayBuffer as buffer_size
SAMPLE_SIZE = 2048       # passed to ReplayBuffer as sample_size

In [3]:
class AlphaZeroNet(nn.Module):
    """Two-layer MLP with a shared trunk feeding a policy head and a value head.

    Args:
        n_states: width of the input feature vector.
        n_actions: number of discrete actions (policy-head width).
        hidden_dim: width of both hidden layers.
    """

    def __init__(self, n_states, n_actions, hidden_dim=128):
        super().__init__()
        # Shared trunk. Attribute names double as state_dict keys, and the
        # registration order fixes parameter order, so both are kept as-is.
        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Heads: action logits and a single scalar value.
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Map a batch of encoded states to ``(log_policy, value)``.

        Args:
            x: one-hot or feature tensor of shape (batch, n_states).

        Returns:
            log_policy: (batch, n_actions) log-softmax over actions.
            value: (batch,) values squashed into [-1, 1] by tanh.
        """
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        log_policy = F.log_softmax(self.policy_head(features), dim=1)
        value = torch.tanh(self.value_head(features)).squeeze(-1)
        return log_policy, value

In [4]:
# --- Main Execution ---
def make_env():
    """Build a fresh instance of the task environment.

    Search nodes overwrite ``env.unwrapped.s`` on the instances this returns and
    step them, so every call must produce an independent environment.
    """
    # For the stock task use: gym.make("FrozenLake-v1", is_slippery=False, render_mode="ansi")
    environment = FrozenLakeManipulationEnv()
    return environment

# Pick the accelerator once; later cells read this global.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# A probe environment used here only to size the network's input and output layers.
env = make_env()
nA = env.action_space.n

obs_space = env.observation_space
if not isinstance(obs_space, Discrete):
    # NOTE(review): assumes an observation is a flat sequence of per-object state
    # indices followed by one holding/not-holding flag, each index one-hot encoded
    # over `env.n_states` values — TODO confirm against `to_one_hot_encoding`.
    nS = (len(obs_space.sample()) - 1) * env.n_states + 1
else:
    nS = obs_space.n

net = AlphaZeroNet(nS, nA).to(device)
optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=REGULARIZATION)
# T_max is counted in scheduler.step() calls (one per training epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)


Using device: cpu


In [15]:
class LearnedMCTSNode:
    """A node of an AlphaZero-style search tree guided by a policy/value network.

    Edge statistics live on the *parent*, keyed by action: ``parent.N[a]``,
    ``parent.W[a]`` and ``parent.Q[a]`` describe the edge to the child reached
    with action ``a``.

    Args:
        state: environment state of this node (written to ``env.unwrapped.s``
            when simulating children).
        make_env: zero-argument factory returning a fresh environment.
        net: network mapping a one-hot encoded state to ``(log_policy, value)``.
        parent: parent node, or ``None`` for the root.
        action: action taken in ``parent.state`` to reach this node.
        prior: network prior probability of ``action`` at the parent.
        cpuct: PUCT exploration constant.
        device: torch device for network inputs.
        verbose: print selection/expansion traces.
    """

    def __init__(self, 
                 state,
                 make_env,
                 net,
                 parent=None, 
                 action=None, 
                 prior=0.0,
                 cpuct=1.41,
                 device='cpu',
                 verbose=False):
        
        self.state = state
        self.parent = parent
        self.action = action            # action taken to reach this node
        self.prior = prior              # prior probability of this action
        self.children = {}              # action -> LearnedMCTSNode
        
        self.N = defaultdict(int)       # visit counts per action
        self.W = defaultdict(float)     # total backed-up value per action
        self.Q = defaultdict(float)     # mean backed-up value per action (Q = W/N)
        self.reward = 0.0               # reward used when this node is terminal
        self.terminal = False
        
        self.puct_constant = cpuct
        
        self.make_env = make_env
        # A node's env is only read for its action/observation spaces, so children
        # reuse the parent's instance instead of constructing one env per node.
        parent_env = getattr(parent, "env", None)
        self.env = parent_env if parent_env is not None else make_env()
        self.net = net
        self.device = device
        self.verbose = verbose

    def is_leaf(self):
        """True when the node has not been expanded (has no children)."""
        return len(self.children) == 0

    def _state_tensor(self):
        """One-hot encode ``self.state`` as a float tensor on ``self.device``."""
        return to_one_hot_encoding(self.state, self.env.observation_space).float().to(self.device)

    def puct(self):
        """PUCT score of the edge from ``self.parent`` to this node.

        Unvisited edges score ``inf`` so every child is tried once before the
        prior-weighted bonus matters. Must not be called on the root.

        Returns:
            float: ``Q(s,a) + c_puct * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a))``.
        """
        visits = self.parent.N[self.action]
        if visits == 0:
            return float("inf")
        
        total_N = sum(self.parent.N.values())
        
        exploitation = self.parent.Q[self.action]
        exploration = self.puct_constant * self.prior * math.sqrt(total_N) / (1 + visits)
        
        return exploitation + exploration

    def best_puct_child(self):
        """Child with the highest PUCT score (ties go to the first-inserted child)."""
        return max(self.children.values(), key=lambda child: child.puct())

    def best_child(self):
        """Child whose edge has the highest mean value Q."""
        return max(self.children.values(), key=lambda child: self.Q[child.action])

    def selection(self):
        """Descend by best PUCT score until an unexpanded node is reached.

        Returns:
            LearnedMCTSNode: the selected leaf (``self`` if it has no children).
        """
        node = self
        while not node.is_leaf():
            node = node.best_puct_child()
            if self.verbose:
                print(f"Selected node {node.state} with visits {node.parent.N[node.action]} and value {node.parent.Q[node.action]}")
        
        return node
    
    def expand(self):
        """Create one child per action by simulating each action from ``self.state``.

        Each child gets the network prior for its action. A child is marked
        terminal when the step terminates the episode or leaves the state
        unchanged; its reward is then ``1`` for a reward-1 step and ``-1``
        otherwise. Expansion stops at the first reward-1 child, so actions after
        it get no children.

        Returns:
            LearnedMCTSNode | None: the goal (reward-1) child, or ``None``.
        """
        with torch.no_grad():
            logp, _ = self.net(self._state_tensor().unsqueeze(0))
            priors = torch.exp(logp).cpu().numpy()[0]
        
        goal_child = None
        for action in range(self.env.action_space.n):
            # Fresh env per action, rewound to this node's state.
            env_copy = self.make_env()
            env_copy.reset()
            env_copy.unwrapped.s = self.state
            
            obs, reward, terminated, truncated, _ = env_copy.step(action)

            child = LearnedMCTSNode(obs,  
                                    make_env=self.make_env,
                                    parent=self,
                                    action=action, 
                                    prior=priors[action],
                                    cpuct=self.puct_constant,
                                    device=self.device,
                                    net=self.net,
                                    verbose=self.verbose)
            
            self.children[action] = child

            # A step that leaves the state unchanged is treated as a failed dead end.
            # np.array_equal also works for array-valued observations, where `==`
            # would be element-wise and ambiguous in a boolean context.
            if terminated or np.array_equal(obs, self.state):
                child.terminal = True
                child.reward = reward if reward == 1 else -1
                
            if reward == 1:
                goal_child = child
                break
            
        if self.verbose:
            print(f"Expanded node {self.state} with children: {[c.state for c in self.children.values()]}")
        
        return goal_child

    def evaluation(self):
        """Value estimate of ``self.state`` from the network's value head.

        Returns:
            float: the network value (tanh output, so in [-1, 1]).
        """
        with torch.no_grad():
            _, value = self.net(self._state_tensor().unsqueeze(0))
            
        return value.item() 

    def backpropagation(self, value):
        """Add ``value`` to every edge on the path from this node up to the root.

        Values are backed up unchanged (no sign flip), i.e. single-agent search.
        """
        node = self
        while node.parent is not None:
            parent = node.parent
            parent.N[node.action] += 1
            parent.W[node.action] += value
            parent.Q[node.action] = parent.W[node.action] / parent.N[node.action]
            node = parent

In [9]:


def _visit_counts_to_policy(counts, tau):
    """Convert root visit counts into a move distribution pi ∝ counts ** (1 / tau).

    Args:
        counts: 1-D array of non-negative visit counts, one per action.
        tau: temperature; ``0`` means greedy (one-hot on the first most-visited action).

    Returns:
        np.ndarray: probabilities, same length as ``counts``.
    """
    counts = np.asarray(counts, dtype=float)
    if counts.sum() == 0:
        # No simulation reached the root's children: nothing to prefer, use uniform.
        return np.full(counts.shape, 1.0 / counts.size)
    if tau == 0:
        pi = np.zeros_like(counts)
        pi[np.argmax(counts)] = 1.0
        return pi
    # Dividing by the max first keeps counts ** (1 / tau) from overflowing to inf
    # (which would turn pi into nan) for small tau; the factor cancels below.
    scaled = (counts / counts.max()) ** (1.0 / tau)
    return scaled / scaled.sum()


def run_mcts(root_node,
             tau=1.0, 
             num_sims=1000,
             pipeline_verbose=False):
    """Run ``num_sims`` MCTS simulations from ``root_node`` and return the move distribution.

    Each simulation selects a leaf by PUCT. A leaf visited before is expanded and
    one of its children (the goal child if expansion found one, otherwise the
    best-PUCT child) becomes the leaf. The leaf's value (its reward when terminal,
    the network value otherwise) is then backed up to the root.

    Args:
        root_node: node to search from; it is expanded first.
        tau: temperature, pi(a) ∝ N(root, a) ** (1 / tau); ``0`` gives a one-hot
            distribution on the most-visited action.
        num_sims: number of simulations.
        pipeline_verbose: print a trace of every simulation.

    Returns:
        np.ndarray: probabilities over ``range(root_node.env.action_space.n)``.
    """
    root_node.expand()
    
    for _ in range(num_sims):
        node = root_node.selection()

        if pipeline_verbose:
            print(f"\nSELECTED NODE: {node.state}, with visits {node.parent.N[node.action]} and value {node.parent.Q[node.action]}\n")
        
        if node.terminal:
            if pipeline_verbose: 
                print(f"Terminal node reached: {node.parent.state} -> {node.state}, with reward {node.reward}")
            node.backpropagation(node.reward)
            continue
        
        # A leaf is only evaluated on its first visit; it is expanded on the next one.
        if node.parent.N[node.action] > 0:
            goal_node = node.expand()
            
            if pipeline_verbose:
                print(f"\nEXPANDED NODE: {node.state}, with children {[c.state for c in node.children.values()]}")
        
            # Prefer the goal child if expansion found one; otherwise take the
            # best-PUCT child, which among fresh children (all scored inf) is the first.
            node = goal_node if goal_node is not None else node.best_puct_child()
            if pipeline_verbose:
                print(f"selected node {node.state} from children {[c.state for c in node.parent.children.values()]}.")
                print(f"is goal node: {goal_node is not None}\n")
            
        # Terminal leaves back up their reward; others use the network's value estimate.
        if node.terminal:
            value = node.reward
            if pipeline_verbose:
                print(f"BACKPROPAGATING REWARD: {value} from terminal node {node.state}")
        else:
            value = node.evaluation()
            if pipeline_verbose:
                print(f"BACKPROPAGATING VALUE: {value} from non-terminal node {node.state}")
            
        node.backpropagation(value)
        
    counts = np.array([root_node.N[a] for a in range(root_node.env.action_space.n)])
    return _visit_counts_to_policy(counts, tau)
        


In [7]:
# Sanity-check one search from a hand-picked state. tau=2 flattens the visit
# distribution (pi ∝ N ** 0.5).
search_kwargs = dict(make_env=make_env, net=net, cpuct=CPUCT, device=device, verbose=False)
root_node = LearnedMCTSNode(state=(0, 1, 0), **search_kwargs)

pi = run_mcts(root_node, tau=2.0, num_sims=5000)
print(pi)


[0.07192478 0.36635814 0.06785789 0.36285    0.06658941 0.06441979]


In [36]:
print("Action probabilities:", pi)

Action probabilities: [0.06464873 0.37371588 0.06464873 0.36547027 0.06642026 0.06509613]


# Training Loop

In [10]:
def self_play_episode(
    make_env,
    net,
    num_sims=100,
    tau=1.,
    cpuct=1.41,
    device='cpu',
    verbose=False,
    max_steps=None
):
    """Play one episode, sampling every move from a fresh MCTS search.

    Args:
        make_env: factory returning a fresh environment.
        net: network guiding the search.
        num_sims: MCTS simulations per move.
        tau: temperature passed to ``run_mcts``.
        cpuct: PUCT exploration constant.
        device: torch device for network inputs.
        verbose: print search traces.
        max_steps: optional cap on the number of moves. Without it the loop relies
            entirely on the environment terminating or truncating. ``None`` keeps
            that behaviour.

    Returns:
        (data, reward): ``data`` holds one ``(state, pi)`` pair per move, where
        ``pi`` is the MCTS move distribution at ``state``. ``reward`` is the reward
        of the last step (0.0 if no step was taken).
    """
    data = []
    env = make_env()
    state, _ = env.reset()
    reward = 0.0
    
    while max_steps is None or len(data) < max_steps:
        root_node = LearnedMCTSNode(state=state,
                                    make_env=make_env,
                                    net=net,
                                    cpuct=cpuct,
                                    device=device,
                                    verbose=verbose)
        
        pi = run_mcts(root_node, tau=tau, num_sims=num_sims)
        
        action = np.random.choice(np.arange(len(pi)), p=pi)
        
        next_state, reward, terminated, truncated, _ = env.step(action)
        
        data.append((state, pi))
        state = next_state
        
        if terminated or truncated:
            break
        
    return data, reward

In [11]:
def train(net, dataloader, device,
          optimizer,
          scheduler,
          epoch_start=0, epoch_stop=20, cpu=0):
    """Fit ``net`` to MCTS policy/value targets.

    Runs epochs ``epoch_start .. epoch_stop - 1`` over ``dataloader``. It then saves
    a scatter plot of the mean loss per epoch to
    ``model_data/Loss_vs_Epoch_<YYYY-MM-DD>.png``.

    Args:
        net: network returning ``(log_policy, value)``.
        dataloader: iterable of ``(state, policy, value)`` batches.
        device: torch device the batches are moved to.
        optimizer: optimizer over ``net.parameters()``.
        scheduler: LR scheduler, stepped once per epoch after that epoch's updates.
        epoch_start: index of the first epoch (used for labels).
        epoch_stop: one past the index of the last epoch.
        cpu: seed passed to ``torch.manual_seed``.

    Returns:
        list[float]: mean batch loss of every epoch that saw at least one batch.
    """
    torch.manual_seed(cpu)
    net.train()

    criterion = AlphaLoss()
    losses_per_epoch = []

    epoch_bar = tqdm(range(epoch_start, epoch_stop), desc="Epochs", position=0)
    for epoch in epoch_bar:
        epoch_loss_sum = 0.0   # all batches of this epoch, for the epoch average
        epoch_batches = 0
        window_loss = 0.0      # last 10 batches, for periodic logging

        batch_bar = tqdm(enumerate(dataloader, 0),
                         total=len(dataloader),
                         desc=f"Epoch {epoch + 1}",
                         leave=False,
                         position=1)

        for i, (state, policy, value) in batch_bar:
            state = state.to(device).float()
            policy = policy.to(device).float()
            value = value.to(device).float()

            optimizer.zero_grad()
            policy_pred, value_pred = net(state)
            # AlphaZeroNet returns values of shape (batch,); flattening also accepts
            # a head that returns (batch, 1). Indexing [:, 0] fails on the former.
            value_pred = value_pred.reshape(-1)
            loss = criterion(value_pred, value, policy_pred, policy)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss_sum += batch_loss
            epoch_batches += 1
            window_loss += batch_loss
            batch_bar.set_postfix(loss=batch_loss)

            if i % 10 == 9:
                tqdm.write(f'[Epoch {epoch + 1}, Batch {i + 1}] Avg Loss: {window_loss / 10:.4f}')
                tqdm.write(f'Policy: GT {policy[0].argmax().item()}, Pred {policy_pred[0].argmax().item()}')
                tqdm.write(f'Value:  GT {value[0].item()}, Pred {value_pred[0].item()}')
                window_loss = 0.0

        # PyTorch expects scheduler.step() after the epoch's optimizer.step() calls.
        scheduler.step()

        if epoch_batches:
            epoch_avg_loss = epoch_loss_sum / epoch_batches
            losses_per_epoch.append(epoch_avg_loss)
            epoch_bar.set_postfix(avg_epoch_loss=epoch_avg_loss)

        # Early stopping (very conservative): only considered once more than 100
        # epochs have completed in this call.
        if len(losses_per_epoch) > 100:
            recent = sum(losses_per_epoch[-4:-1]) / 3
            earlier = sum(losses_per_epoch[-16:-13]) / 3
            if abs(recent - earlier) <= 0.01:
                tqdm.write("Early stopping criterion met.")
                break

    fig, ax = plt.subplots()
    ax.scatter(list(range(1, len(losses_per_epoch) + 1)), losses_per_epoch)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss per batch")
    ax.set_title("Loss vs Epoch")
    os.makedirs("model_data/", exist_ok=True)
    fig.savefig(os.path.join("model_data/", f"Loss_vs_Epoch_{datetime.datetime.today().strftime('%Y-%m-%d')}.png"))
    # train() runs once per pipeline episode; close the figure so they don't accumulate.
    plt.close(fig)
    tqdm.write("Finished Training")

    return losses_per_epoch

In [12]:
      
def evaluate(net, make_env, num_episodes=50, device=None, max_steps=200):
    """Measure the success rate of the network's greedy policy.

    Each episode resets a fresh environment and repeatedly plays the action with
    the highest policy probability. It stops when the episode terminates or
    truncates, or after ``max_steps`` moves. An episode succeeds when its last step
    returned reward 1, the same criterion the search tree uses for goal nodes.

    Args:
        net: network mapping one-hot states to ``(log_policy, value)``.
        make_env: factory returning a fresh environment.
        num_episodes: number of evaluation episodes.
        device: device for network inputs. Defaults to the device of the net's
            first parameter (CPU if it has none).
        max_steps: per-episode move cap. A deterministic greedy policy can loop
            forever in an environment without a time limit.

    Returns:
        float: fraction of successful episodes (0.0 when ``num_episodes`` is 0).
    """
    if device is None:
        first_param = next(iter(net.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    success_count = 0
    net.eval()

    with tqdm(total=num_episodes, desc="Evaluating", position=0) as pbar:
        for episode in range(num_episodes):
            env = make_env()
            state, _info = env.reset()
            reward = 0.0

            for _step in range(max_steps):
                s_tensor = to_one_hot_encoding(state, env.observation_space).float().to(device)
                with torch.no_grad():
                    logp, _value = net(s_tensor.unsqueeze(0))
                # argmax of log-probs == argmax of probs.
                action = int(torch.argmax(logp[0]).item())
                state, reward, terminated, truncated, _info = env.step(action)
                if terminated or truncated:
                    break

            if reward == 1:
                success_count += 1

            pbar.set_postfix(success_rate=success_count / (episode + 1))
            pbar.update(1)

    return success_count / num_episodes if num_episodes else 0.0

In [13]:
def train_pipeline(net,
                   make_env,
                   optimizer,
                   scheduler,
                   buffer_size=20000,
                   sample_size=2048,
                   batch_size=128,
                   num_episodes=10,
                   num_self_play=10,
                   eval_interval=1,
                   num_eval=200,
                   target_sr=0.90,
                   device='cpu',
                   verbose=False,
                   num_sims=None,
                   tau=None,
                   cpuct=None,
                   train_epochs=20,
                   model_dir="models"):
    """Alternate self-play, training and evaluation for up to ``num_episodes`` episodes.

    Per episode:
      1. play ``num_self_play`` games with MCTS and store ``(state, pi, final reward)``
         in the replay buffer;
      2. once the buffer holds at least ``batch_size`` items, train on a buffer
         sample for ``train_epochs`` epochs;
      3. every ``eval_interval`` episodes, evaluate. Improved models are saved under
         ``model_dir``, and the loop stops once the success rate reaches ``target_sr``.

    Args:
        num_sims, tau, cpuct: search settings for self-play. ``None`` falls back to
            the NUM_SIMS / TAU / CPUCT notebook globals.
        train_epochs: epochs per training round.
        model_dir: directory for saved ``state_dict`` checkpoints (created if missing).
        (remaining arguments are forwarded to the corresponding components)

    Returns:
        The best network found so far (a deep copy of ``net``).
    """
    # Resolved at call time so edits to the config cell are picked up.
    num_sims = NUM_SIMS if num_sims is None else num_sims
    tau = TAU if tau is None else tau
    cpuct = CPUCT if cpuct is None else cpuct

    env = make_env()
    best_net = copy.deepcopy(net)
    best_sr = 0.0

    replay_buffer = ReplayBuffer(buffer_size=buffer_size,
                                 sample_size=sample_size)

    def _save(filename):
        os.makedirs(model_dir, exist_ok=True)
        torch.save(best_net.state_dict(), os.path.join(model_dir, filename))

    for episode in range(num_episodes):

        # ------------- Self-play -------------
        self_play_bar = tqdm(range(num_self_play),
                             desc=f"Episode {episode+1}/{num_episodes}",
                             position=0)

        for g in self_play_bar:
            data, reward = self_play_episode(make_env=make_env,
                                             net=net,
                                             num_sims=num_sims,
                                             tau=tau,
                                             cpuct=cpuct,
                                             device=device,
                                             verbose=verbose)

            # Every move of a game is labelled with that game's final reward.
            for state, pi in data:
                replay_buffer.add(state=state,
                                  mcts_policy=pi,
                                  value=reward)

            if verbose:
                tqdm.write(f"Episode {episode+1}. Self-play episode {g+1} finished with reward {reward}. Buffer size: {len(replay_buffer)}")

            self_play_bar.set_postfix(reward=reward, buffer=len(replay_buffer))

        # ------------- Training -------------
        if len(replay_buffer) < batch_size:
            continue

        net.train()

        experiences = replay_buffer.sample()
        dataset = ReplayDataset(experiences, obs_space=env.observation_space)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        train(net=net,
              dataloader=dataloader,
              device=device,
              optimizer=optimizer,
              scheduler=scheduler,
              epoch_start=0,
              epoch_stop=train_epochs,
              cpu=0)

        # ------------- Evaluation -------------
        if (episode + 1) % eval_interval != 0:
            continue

        success_rate = evaluate(net, make_env, num_episodes=num_eval)
        print(f"Episode {episode+1}. Success rate: {success_rate:.2f}")

        if success_rate >= target_sr and success_rate >= best_sr:
            best_net = copy.deepcopy(net)
            best_sr = success_rate
            _save(f"best_model_{episode+1}_sr.pth")
            print(f"Best model saved at episode {episode+1}.")

        elif success_rate > best_sr:
            best_net = copy.deepcopy(net)
            best_sr = success_rate
            _save(f"best_model_{episode+1}.pth")
            print(f"Best model updated at episode {episode+1}.")

        if success_rate >= target_sr:
            print(f"Target success rate {target_sr:.2f} reached at episode {episode+1}; stopping.")
            break

    return best_net
                

In [14]:

# Every tunable value comes from the config cell (SAMPLE_SIZE was previously a literal).
train_pipeline(net=net,
               make_env=make_env,
               optimizer=optimizer,
               scheduler=scheduler,
               buffer_size=BUFFER_SIZE,
               sample_size=SAMPLE_SIZE,
               batch_size=BATCH_SIZE,
               num_episodes=MAX_EPISODES,
               num_self_play=NUM_SELF_PLAY,
               eval_interval=EVAL_INTERVAL,
               num_eval=NUM_EVAL,
               target_sr=TARGET_SR,
               device=device)

Episode 1/20 — Self-play:   0%|          | 0/10 [00:00<?, ?it/s]

Episode 1/20 — Self-play:  70%|███████   | 7/10 [01:25<00:36, 12.27s/it, buffer=31, reward=-1]


KeyboardInterrupt: 