In [27]:
import os
import datetime
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from collections import deque, defaultdict
from tqdm import tqdm
import argparse
import sys

In [28]:
# Put the parent of the current working directory on the import path; the
# local imports in this cell (mcts_models, data_loading) presumably live there.
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard the insert so re-running this cell does not stack duplicate
# sys.path entries in a long-lived kernel.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from mcts_models import MCTSNode, LearnedMCTSNode, AlphaZeroNet
from data_loading import to_one_hot_encoding

In [29]:
# --- Main Execution ---
def make_env():
    """Create a new non-slippery FrozenLake-v1 environment with ANSI text rendering.

    Alternative environments that were swapped in here previously:
    FrozenLakeManipulationEnv() and GripperDiscretisedEnv().
    """
    env_options = {"is_slippery": False, "render_mode": "ansi"}
    return gym.make("FrozenLake-v1", **env_options)
    


# Create the environment and reset it to obtain the initial observation.
env = make_env()
state, info = env.reset()
# In "ansi" render mode render() returns the board as text instead of drawing
# it, so the result has to be printed to actually show the starting position.
print(env.render())


# Load the trained weights from ../models/best_model_frozen_lake.pth
# (resolved relative to the current working directory).
checkpoint_path = os.path.join(os.getcwd(), "../models/best_model_frozen_lake.pth")

# Choose the device first so the checkpoint tensors can be remapped onto it
# while loading; without map_location a checkpoint saved from a GPU cannot be
# loaded on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True where the installed PyTorch
# version supports it).
checkpoint = torch.load(checkpoint_path, map_location=device)

# The network is sized from the environment's discrete observation/action spaces,
# and the checkpoint object is used directly as its state dict.
net = AlphaZeroNet(env.observation_space.n, env.action_space.n)
net.load_state_dict(checkpoint)
net.to(device)

# Inference only from here on: put the module in evaluation mode.
net.eval()

# Root node built from the initial observation returned by env.reset().
# device is passed explicitly so the root is constructed the same way as the
# child nodes created during the rollout, which all receive device=device.
root_node = LearnedMCTSNode(state=state,
                            net=net,
                            make_env=make_env,
                            device=device)



# Greedy policy rollout: no tree search is run here. At every step the action
# with the highest probability under the network's first output is executed
# until the episode terminates or is truncated.

node = root_node
trajectory = [node.state]  # states visited, starting with the root state
done = False

while not done:
    # Encode the current state and add a leading batch dimension of size 1.
    # NOTE(review): to_one_hot_encoding is assumed to return a tensor - confirm.
    encoded_state = to_one_hot_encoding(node.state, env.observation_space).float().to(device).unsqueeze(0)
    # Gradients are not needed for inference; skip building the autograd graph.
    with torch.no_grad():
        logp, _ = net(encoded_state)
    # The first network output is treated as log-probabilities, hence the exp.
    p = torch.exp(logp).cpu().detach().numpy()[0]
    # np.argmax returns a NumPy integer; hand the environment a plain int.
    action = int(np.argmax(p))
    next_state, reward, terminated, truncated, _ = env.step(action)
    print(env.render())
    # Chain a child node for the reached state so the path can be followed
    # back through `parent`.
    node = LearnedMCTSNode(state=next_state,
                            make_env=make_env,
                            net=net,
                            parent=node,
                            action=action,
                            prior = p[action],
                            device=device)

    done = terminated or truncated
    state = next_state
    trajectory.append(node.state)

  (Right)
S[41mF[0mFF
FHFH
FFFH
HFFG

  (Right)
SF[41mF[0mF
FHFH
FFFH
HFFG

  (Down)
SFFF
FH[41mF[0mH
FFFH
HFFG

  (Down)
SFFF
FHFH
FF[41mF[0mH
HFFG

  (Down)
SFFF
FHFH
FFFH
HF[41mF[0mG

  (Right)
SFFF
FHFH
FFFH
HFF[41mG[0m

