In [13]:
import os
import shutil
from copy import deepcopy
from copy import copy
import random
import time
import argparse

from math import *
import numpy as np
import torch

import gymnasium as gym

from c_mcts_args import parse_args
from utils.get_pixel_state import get_pixel_state
from utils.memory import Memory
from c_mcts import Node, MCTS
from rnd import RND
import matplotlib.pyplot as plt

# Fix Argparse in ipython
import sys
sys.argv=['']

In [15]:
# CUDA setup
run_name = "mcts_test"
args = parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

# Env Setup
env = gym.make("FrozenLake-v1",  render_mode = "rgb_array", map_name="8x8", is_slippery=False)
env = gym.wrappers.TimeLimit(env, max_episode_steps=32)
# Video recording setup
#render_env = deepcopy(env)
#render_env = gym.wrappers.RecordVideo(render_env,f"videos/{run_name}")

# Init models
memory = Memory(env, args)
mcts = MCTS(args)
mcts.temperature = 0.5

Using device:  cuda


In [16]:
# Load models
checkpoint_path = "checkpoints/FrozenLake_mcts_1103_06-05-2023"
version = 5
mcts.load(f"{checkpoint_path}/mcts/{version}.pth")
mcts.vae.load("models/FL_vae_8x8.pth")
mcts.rnd.load(f"{checkpoint_path}/rnd/{version}.pth")

In [17]:
# Test loop
steps_array = []
goal_reached = 0

for i in range(1):
    # Init
    steps = 0
    reward_ext = 0
    reward_int = 0

    env.reset()
    #render_env.reset()
    sim_env = deepcopy(env)
    state = get_pixel_state(env)
    root = Node(sim_env, state)

    memory.clear()
    for i in range (args.rnn_sequence_length):
        memory.states.append(state)
        memory.append_action(0)

    done = False
    while not done:

        # Train RND
        z_state = mcts.vae.encode(state.unsqueeze(0).to(device))
        mcts.rnd.train(z_state)

        # Get action from mcts    
        history = memory.get_history()
        root, action = mcts.search(root, history)

        # Step in the environment
        _, reward_ext, terminated, truncated, _ = env.step(action)
        #render_env.step(action)
        next_state = get_pixel_state(env)
        done = np.any([terminated, truncated], axis=0)

        # Get RND intrinsic reward
        with torch.no_grad():
            next_z_state = mcts.vae.encode(next_state.unsqueeze(0).to(device))
            reward_int = mcts.rnd.reward(next_z_state)
        # Append current step to memory
        action_id = ['left ', 'down ', 'right', 'up   ']
        #print(f"action: {action_id[action]} | reward_ext: {reward_ext} | reward_int: {reward_int}")
        steps += 1

        # Append statistics
        if reward_ext == 1.0:
            goal_reached += 1
            steps_array.append(steps)
    
    
print(f"Goals reached: {goal_reached} / 8")
print(f"Avg steps to goal: {np.mean(steps_array)}")

env.close()
#render_env.close()

Goals reached: 1 / 8
Avg steps to goal: 16.0
