In [39]:
import os
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, defaultdict
from tqdm import tqdm
from environments import FrozenLakeManipulationEnv

In [40]:
# MCTS / AlphaZero params
NUM_SIMS     = 1000       # simulations per move
CPUCT        = 1.41       # PUCT exploration constant
TAU          = 1.0       # temperature for π = N^(1/τ)
# Training params
BATCH_SIZE   = 64
LR           = 1e-3
EVAL_INTERVAL= 1       # eval every self-play games
TARGET_SR    = 0.95      # stop when success rate ≥ 95%
REGULARIZATION = 1e-4    # L2 regularization weight decay constant
MAX_EPISODES = 20 # max number of self-play episodes

NUM_EVAL     = 50
BUFFER_SIZE   = 3000

models_dir = "models"

In [41]:
class AlphaZeroNet(nn.Module):
    def __init__(self, n_states, n_actions, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # policy head
        self.policy_head = nn.Linear(hidden_dim, n_actions)
        # value head
        self.value_head  = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # x: one-hot or feature vector of shape (batch, n_states)
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        p = F.log_softmax(self.policy_head(h), dim=1)  # log-probs
        v = torch.tanh(self.value_head(h))             # in [-1,1]
        return p, v.squeeze(-1)

In [42]:
# --- Main Execution ---
def make_env():
    return gym.make("FrozenLake-v1", is_slippery=False, render_mode="ansi")
    # return FrozenLakeManipulationEnv()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)


env    = make_env()
nS, nA = env.observation_space.n, env.action_space.n
net    = AlphaZeroNet(nS, nA).to(device)
opt    = optim.Adam(net.parameters(), lr=LR, weight_decay=REGULARIZATION)

Using device: cpu


In [61]:
class LearnedMCTSNode:
    def __init__(self, 
                 state,
                 make_env,
                 net,
                 parent=None, 
                 action=None, 
                 prior=0.0,
                 cpuct=1.41,
                 device='cpu',
                 verbose=False):
        
        self.state = state
        self.parent = parent
        self.action = action            # action taken to reach this node
        self.prior = prior              # prior probability of this action
        self.children = {}
        
        self.N = defaultdict(int)       # visit counts per action
        self.W = defaultdict(float)     # total reward per action
        self.Q = defaultdict(float)     # average reward per action (Q = W/N)
        self.reward = 0.0
        self.terminal = False
        
        self.puct_constant = cpuct
        
        self.make_env = make_env
        self.env = make_env()
        self.net = net
        self.device = device
        self.verbose = verbose

    def is_leaf(self):
        return len(self.children) == 0

    def puct(self):
        """Calculate the PUCT value for this node.

        Returns:
            puct_value (float): The PUCT value for this node.
        """
        
        if self.parent.N[self.action] == 0:
            return float("inf")
        
        total_N = sum(self.parent.N.values())
        
        exploitation = self.parent.Q[self.action]
        exploration = self.puct_constant * self.prior * math.sqrt(total_N) / (1 + self.parent.N[self.action])
        
        return exploitation + exploration
        

    def best_puct_child(self):
        return max(self.children.values(), key=lambda child: child.puct())

    def best_child(self):
        return max(self.children.values(), key=lambda child: self.Q[child.action])

    def selection(self):
        """Traverse the tree to select a promising node to expand.
        
        Returns:
            - node  (MCTSNode): The selected node to expand.
            - is_goal (bool): True if the node is a goal state.
        """
        node = self
        while not node.is_leaf():
            node = node.best_puct_child()
            if self.verbose:
                print(f"Selected node {node.state} with visits {node.parent.N[node.action]} and value {node.parent.Q[node.action]}")
        
        return node
    
    
    def expand(self):
        """Expand the current node by simulating the environment and adding child nodes.

        Returns:
            child (LearnedMCTSNode): The child node that is a goal state, if any.
        """
        
        s_tensor = F.one_hot(torch.tensor(self.state), self.env.observation_space.n).float().to(self.device)
        
            
        with torch.no_grad():
            logp, _ = self.net(s_tensor.unsqueeze(0))
            p = torch.exp(logp).cpu().numpy()[0]
        
        for action in range(self.env.action_space.n):
            env_copy = self.make_env()
            env_copy.reset()
            env_copy.unwrapped.s = self.state
            
            obs, reward, terminated, truncated, _ = env_copy.step(action)

            # Only add child if it is not a (non successful) terminal state or if it is not the same state
            # if reward == 0 and terminated or self.state == obs:
            #     continue

            child = LearnedMCTSNode(obs,  
                                    make_env=self.make_env,
                                    parent=self,
                                    action=action, 
                                    prior=p[action],
                                    cpuct=self.puct_constant,
                                    device=self.device,
                                    net=self.net,
                                    verbose=self.verbose)
            
            self.children[action] = child

            if terminated or obs == self.state:
                child.terminal = True
                child.reward = reward if reward == 1 else -1
                
            if reward == 1:
                return child
            
        if self.verbose:
            print(f"Expanded node {self.state} with children: {[c.state for c in self.children.values()]}")
        
        return None
    

    def evaluation(self):
        """Evaluate the current node using the neural network.
        Returns:
            value (float): The value of the current node.
        """
        s_tensor = F.one_hot(torch.tensor(self.state), self.env.observation_space.n).float().to(self.device)
        with torch.no_grad():
            _, value = self.net(s_tensor.unsqueeze(0))
            
        return value.item() 

    def backpropagation(self, value):
        """Propagate the simulation result back up the tree."""
        node = self
        while node:
            parent = node.parent
            if parent:
                parent.N[node.action] += 1
                parent.W[node.action] += value
                parent.Q[node.action] = parent.W[node.action] / parent.N[node.action]
                node = parent
            else:
                break            

In [83]:


def run_mcts(root_node,
             tau=1.0, 
             num_sims=1000):
    """Run MCTS simulations from the given node."""
    
    root_node.expand()
    
    for _ in range(num_sims):
        node = root_node.selection()

        print(f"\nSELECTED NODE: {node.state}, with visits {node.parent.N[node.action]} and value {node.parent.Q[node.action]}\n")
        
        if node.terminal:
            print(f"Terminal node reached: {node.parent.state} -> {node.state}, with reward {node.reward}")
            node.backpropagation(node.reward)
            continue
        
        # Check if the node had been visited before
        if node.parent.N[node.action] > 0:
            # If the node has been visited before, expand it
            goal_node = node.expand()
            
            print(f"\nEXPANDED NODE: {node.state}, with children {[c.state for c in node.children.values()]}")
        
            # If the node is a goal state, select it, otherwise select a random child
            node = goal_node if goal_node is not None else node.best_puct_child()
            
            print(f"selected node {node.state} from children {[c.state for c in node.parent.children.values()]}.")
            print(f"is goal node: {goal_node is not None}\n")
            
            
        # If the node is a terminal state, use its reward as the value
        if node.terminal:
            value = node.reward
            print(f"BACKPROPAGATING REWARD: {value} from terminal node {node.state}")
        else:
            value = node.evaluation() # get value from NN
            print(f"BACKPROPAGATING VALUE: {value} from non-terminal node {node.state}")
            
        node.backpropagation(value)
        
        
        
    counts = np.array([root_node.N[a] for a in range(root_node.env.action_space.n)])
    
    counts = counts**(1 / tau)
    
    pi = counts / counts.sum()
    
    return pi
        


In [102]:
root_node = LearnedMCTSNode(state=0,
                            make_env=make_env,
                            net=net,
                            cpuct=CPUCT,
                            device=device,
                            verbose=False)

pi = run_mcts(root_node, tau=1., num_sims=500)
print(pi)



SELECTED NODE: 0, with visits 0 and value 0.0

Terminal node reached: 0 -> 0, with reward -1

SELECTED NODE: 4, with visits 0 and value 0.0

BACKPROPAGATING VALUE: 0.12379749119281769 from non-terminal node 4

SELECTED NODE: 1, with visits 0 and value 0.0

BACKPROPAGATING VALUE: 0.15063880383968353 from non-terminal node 1

SELECTED NODE: 0, with visits 0 and value 0.0

Terminal node reached: 0 -> 0, with reward -1

SELECTED NODE: 4, with visits 1 and value 0.12379749119281769


EXPANDED NODE: 4, with children [4, 8, 5, 0]
selected node 4 from children [4, 8, 5, 0].
is goal node: False

BACKPROPAGATING REWARD: -1 from terminal node 4

SELECTED NODE: 1, with visits 1 and value 0.15063880383968353


EXPANDED NODE: 1, with children [0, 5, 2, 1]
selected node 0 from children [0, 5, 2, 1].
is goal node: False

BACKPROPAGATING VALUE: 0.11375313997268677 from non-terminal node 0

SELECTED NODE: 5, with visits 0 and value 0.0

Terminal node reached: 1 -> 5, with reward -1

SELECTED NODE: 2, w

In [85]:
env = make_env()
env.reset()
env.render()
node = root_node
trajectory = []
trajectory.append((node.state, None, 0))
while not node.is_leaf():
    prev_node = node
    node = prev_node.best_child()
    trajectory.append((node.state, node.action, node.parent.Q[node.action]))
    print(f"Best action from state {prev_node.state} to state {node.state} with value {node.parent.Q[node.action]}")
    env.step(node.action)
    print(env.render())

Best action from state 0 to state 1 with value 0.8275279130600591
  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG

Best action from state 1 to state 2 with value 0.8736741070891654
  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG

Best action from state 2 to state 6 with value 0.9102544542590598
  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG

Best action from state 6 to state 10 with value 0.9428766608580775
  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG

Best action from state 10 to state 14 with value 0.974944531338375
  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG

Best action from state 14 to state 15 with value 1.0
  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m



In [103]:
print(f"Action probabilities: {pi}")

Action probabilities: [0.012 0.092 0.882 0.014]
