# LLM-Reasoners Demo
## Setup
Set the CUDA device and initialize an `ExLlamaModel` through our unified LLM interface.

In [None]:
import os
# Restrict the GPUs visible to this process to device index 7.
# NOTE(review): the index is machine-specific — adjust it for your host. Presumably this
# has to run before any CUDA initialisation, which is why it sits in the first cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

In [None]:
from reasoners.lm import ExLlamaModel
import torch

# Build the language model used by the Blocksworld cells below.
# The checkpoint directory is a local path; point `model_dir` at your own copy.
model = ExLlamaModel(model_dir='/data/yi/Llama-2-70B-GPTQ',
                     lora_dir=None,
                     device = torch.device("cuda:0"),
                     max_batch_size=1,
                     max_new_tokens=200,
                     mem_map=None,
                     max_seq_length=2048)
# Alternative backends (presumably exposing the same interface — confirm signatures before swapping one in):
# HFModel(llama_path, llama_path, device=device, max_batch_size=1, max_new_tokens=512, quantized=quantized, peft_pth=peft_path, load_awq_pth=load_awq_pth)
# Llama2Model(llama2_ckpts, llama_size, max_batch_size=1)
# OpenAIModel(openai_mode)
# ClaudeModel('claude-3-opus-20240229')

We gather one example from the Blocksworld dataset, and the proper prompt for in-context learning examples.
We will talk more about Evaluators later.

In [None]:
from reasoners.benchmark import BWEvaluator
import json

# Load the raw prompt pool from disk; it is handed to the evaluator as `init_prompt`.
with open('examples/blocksworld/prompts/pool_prompt_v1.json') as f:
    prompt = json.load(f)
evaluator = BWEvaluator(config_file='examples/blocksworld/data/bw_config.yaml',
                        domain_file='examples/blocksworld/data/generated_domain.pddl',
                        data_path='examples/blocksworld/data/split_v1/split_v1_step_4_data.json',
                        init_prompt=prompt)
# NOTE: `prompt` is rebound here — from now on it is the sampled 4-shot prompt, not the raw pool.
prompt = evaluator.sample_prompt(shuffle_prompt=False, num_shot=4)
# The single dataset entry (index 1) used by every demo cell below.
example = evaluator.full_dataset[1]
# Fill the in-context template with this example; the action slot is left empty so that
# the model is asked to produce the whole plan.
cot_inputs = (prompt['icl'].replace('<init_state>', example["init"])
                           .replace('<goals>', example["goal"])
                           .replace('<action>', ''))

Here is the example.

In [None]:
print(example['init'])

In [None]:
print(example['goal'])

## Chain-of-Thought
We first experiment with the Chain-of-Thought method.
Since this is the simplest generation algorithm, we directly ask the model to generate all the steps.
We look at the 4-shot prompt and the generated answer.

In [None]:
print(cot_inputs)

In [None]:
# Generate the full plan in one call, stopping at the '\n[' sequence.
# `[:-1]` drops the last character of the returned text (presumably the '[' of the
# stop sequence — confirm against the backend) before whitespace is stripped.
output = model.generate([cot_inputs],
                        hide_input=True,
                        eos_token_id='\n[').text[0][:-1].strip()

In [None]:
print(output)

In [None]:
from reasoners import WorldModel, LanguageModel, SearchConfig, State, Reasoner
from reasoners.algorithm import BeamSearch, MCTS
import reasoners.benchmark.bw_utils as utils
from typing import NamedTuple
import copy
import numpy as np


# We use a NamedTuple for clearer presentation; a plain tuple works just as well for a quick experiment.
class BWStateToT(NamedTuple):
    """Search state for the action-history formulation used by ``BlocksWorldModelToT``."""
    # Number of transitions taken so far (0 for the initial state).
    step_idx: int
    # Actions chosen so far, in the order they were taken.
    action_history: list[str]
    # True once the "[PLAN END]" action has been taken.
    end: bool


# We just use the description str as the action; the type alias is only for clearer presentation.
# You may directly use str if you want a quick experiment.
BWAction = str


class BlocksWorldModelToT(WorldModel):
    """World model whose state is just the list of actions taken so far.

    No block configuration is simulated: a transition either appends the chosen
    action to the history or, for "[PLAN END]", marks the plan as finished.
    """

    def __init__(self,
                 base_model: LanguageModel,
                 prompt: dict,
                 max_steps: int = 4,
                 batch_size: int = 1) -> None:
        super().__init__()
        self.base_model = base_model
        self.prompt = prompt
        self.batch_size = batch_size
        self.max_steps = max_steps

    def init_state(self) -> BWStateToT:
        """Return the empty starting state (step 0, no actions, not finished)."""
        return BWStateToT(0, [], False)

    def step(self, state: BWStateToT, action: BWAction) -> tuple[BWStateToT, dict]:
        """Apply ``action`` and return the successor state plus an (empty) aux dict.

        The step counter always advances by one; the history only grows when the
        action is not the "[PLAN END]" marker.
        """
        # Work on a private copy so the successor never aliases the caller's list.
        history = copy.deepcopy(state.action_history)
        finished = action == "[PLAN END]"
        if not finished:
            history = history + [action]
        successor = BWStateToT(step_idx=state.step_idx + 1,
                               action_history=history,
                               end=finished)
        # The dict is auxiliary information for SearchConfig; nothing is needed here.
        return successor, {}

    def is_terminal(self, state: State) -> bool:
        """A state is terminal once the plan ended or the step budget is used up."""
        if state.end:
            return state.end
        return state.step_idx >= self.max_steps


class BWConfigToT(SearchConfig):
    """Search configuration for the action-history (Tree-of-Thought) setup.

    Proposes candidate next actions by sampling from the language model and
    scores each candidate with two log-likelihood based signals.
    """

    def __init__(self,
                 base_model: LanguageModel,
                 prompt: dict,
                 temperature: float = 0.8,
                 n_candidate: int = 4) -> None:
        super().__init__()
        self.base_model = base_model
        self.prompt = prompt
        self.temperature = temperature
        self.n_candidate = n_candidate
        # Not assigned in this class; it is read by the methods below (presumably set by the caller before search).
        self.example = None

    def _render_icl_prompt(self, state: BWStateToT) -> str:
        """Fill the "icl" template with the action history, initial state and goals."""
        history_text = "\n".join(state.action_history + [""])
        rendered = self.prompt["icl"].replace("<action>", history_text)
        rendered = rendered.replace("<init_state>", utils.extract_init_state(self.example))
        return rendered.replace("<goals>", utils.extract_goals(self.example, return_raw=True))

    def get_actions(self, state: BWStateToT) -> list[BWAction]:
        """Sample ``n_candidate`` continuations and return their distinct first lines."""
        generation = self.base_model.generate([self._render_icl_prompt(state)],
                                              num_return_sequences=self.n_candidate,
                                              max_length=20,
                                              eos_token_id="\n",
                                              temperature=self.temperature,
                                              do_sample=True,
                                              hide_input=True)
        # Keep only the first line of every sample, dropping repeats while
        # preserving the order in which candidates first appeared.
        candidates = []
        seen = set()
        for text in generation.text:
            first_line = text.split("\n")[0]
            if first_line not in seen:
                seen.add(first_line)
                candidates.append(first_line)
        return candidates

    # Some reward functions are fast to calculate.
    # We calculate the reward before executing the action, which can be used to better guide the search.
    def fast_reward(self, state: BWStateToT, action: BWAction) -> tuple[float, dict]:
        """Score ``action`` before it is executed.

        Two rewards are summed:
        1. Intuition: the log-likelihood of the action given the prompt.
        2. Self-eval: the log-likelihood of "good" after the self-evaluation prompt.
        """
        # The rendered prompt's last character is dropped; the action is then
        # appended after an explicit newline.
        context = self._render_icl_prompt(state)[:-1]
        intuition = self.base_model.get_loglikelihood(context, [context + "\n" + action])[0]

        self_eval_prompt = self.prompt["self-eval"]
        self_eval_prompt = self_eval_prompt.replace("<init_state>", utils.extract_init_state(self.example))
        self_eval_prompt = self_eval_prompt.replace("<goals>", utils.extract_goals(self.example, return_raw=True))
        self_eval_prompt = self_eval_prompt.replace("<action>", action)
        self_eval = self.base_model.get_loglikelihood(self_eval_prompt, [self_eval_prompt + "good"])[0]

        details = {'intuition': intuition, "self_eval": self_eval}
        return intuition + self_eval, details

    # kwargs is the auxiliary information returned by SearchConfig.fast_reward and WorldModel.step,
    # so that we do not need duplicated calculations.
    # Generally, if a reward function depends on the new state, or is slow to calculate,
    # it would be calculated here; in this case the fast_reward result is simply reused.
    def reward(self, state, action, **kwargs) -> tuple[float, dict]:
        """Return the sum of the cached 'intuition' and 'self_eval' values, plus ``kwargs``."""
        total = kwargs['intuition'] + kwargs['self_eval']
        return total, kwargs

Once we have defined the world model and the search config, we can easily call Reasoners to get the result.

In [None]:
# Wire the history-based world model and its search config into a beam search and
# run it on the single example selected earlier.
world_model = BlocksWorldModelToT(base_model=model, prompt=prompt)
config = BWConfigToT(base_model=model, prompt=prompt)
algorithm = BeamSearch(beam_size=4, max_depth=7)
reasoner_tot = Reasoner(world_model=world_model, search_config=config, search_algo=algorithm)
result_tot = reasoner_tot(example)
print(result_tot)

In [None]:
print('Action, Reward')
# Each trace entry unpacks into three items; the middle one is not printed.
for action, _, reward in result_tot.trace:
    print(action, reward)

## Tree-of-Thought
Then let's turn to a tree search algorithm, as introduced by Tree-of-Thought.
We will need to define a simple world model and a search config for the Blocksworld task.

## RAP
In RAP, we are truly using the latest block configuration as the state, instead of a history of actions.
Thus, we define a new world model to transition between states, which is only slightly more complex than the previous one.

In [None]:
BWAction = str


class BWStateRAP(NamedTuple):
    """Search state for the RAP formulation: the block configuration itself."""
    # Number of transitions taken so far (0 for the initial state).
    step_idx: int
    # Block configuration of the state this one was derived from ("" initially).
    last_blocks_state: str
    # Current block configuration as text.
    blocks_state: str
    # The action just applied when the previous buffer was empty, otherwise "".
    buffered_action: BWAction


class BlocksWorldModelRAP(WorldModel):
    """World model that tracks the textual block configuration.

    Transitions are predicted by prompting the language model with an
    action-specific "world update" prompt and applying the described change.
    """

    def __init__(self,
                 base_model: LanguageModel,
                 prompt: dict,
                 max_steps: int = 4,
                 batch_size: int = 1) -> None:
        super().__init__()
        self.base_model = base_model
        self.prompt = prompt
        self.batch_size = batch_size
        self.max_steps = max_steps

    def init_state(self) -> BWStateRAP:
        """Build the starting state from the current example's initial configuration."""
        initial_blocks = utils.extract_init_state(self.example)
        return BWStateRAP(step_idx=0,
                          last_blocks_state="",
                          blocks_state=initial_blocks,
                          buffered_action="")

    def step(self, state: BWStateRAP, action: BWAction) -> tuple[BWStateRAP, dict]:
        """Apply ``action`` and return the successor plus the goal-check result."""
        previous = copy.deepcopy(state)
        updated_blocks = self.update_blocks(previous.blocks_state, action)
        # The buffer alternates: it holds the action when it was empty before,
        # and is cleared otherwise.
        if previous.buffered_action == "":
            buffered = action
        else:
            buffered = ""

        successor = BWStateRAP(step_idx=previous.step_idx + 1,
                               last_blocks_state=previous.blocks_state,
                               blocks_state=updated_blocks,
                               buffered_action=buffered)
        goal_status = utils.goal_check(utils.extract_goals(self.example), updated_blocks)
        return successor, {"goal_reached": goal_status}

    def update_blocks(self, block_states: str, action: BWAction) -> str:
        """Ask the language model how ``action`` changes ``block_states``.

        :raises ValueError: if the action matches none of the known action keywords
        """
        # Order matters: "unstack" must be tested before "stack" because the
        # latter is a substring of the former.
        dispatch = (
            ("pick", "world_update_pickup"),
            ("unstack", "world_update_unstack"),
            ("put", "world_update_putdown"),
            ("stack", "world_update_stack"),
        )
        for keyword, prompt_key in dispatch:
            if keyword in action:
                key = prompt_key
                break
        else:
            raise ValueError("Invalid action")

        world_update_prompt = self.prompt[key].format(block_states, action.capitalize() + ".")
        generation = self.base_model.generate([world_update_prompt],
                                              eos_token_id="\n",
                                              hide_input=True,
                                              temperature=0)
        world_output = generation.text[0].strip()
        return utils.apply_change(world_output, block_states)

    def is_terminal(self, state: BWStateRAP) -> bool:
        """Terminal when the goal check succeeds or the step budget is exactly reached."""
        goal_met = utils.goal_check(utils.extract_goals(self.example), state.blocks_state)[0]
        if goal_met:
            return True
        return state.step_idx == self.max_steps

In [None]:
class BWConfigRAP(SearchConfig):
    """Search configuration for RAP on Blocksworld.

    Actions are enumerated from the current block configuration, and rewards
    blend language-model scores with a goal-based term.
    """

    def __init__(self,
                 base_model: LanguageModel,
                 prompt: dict,
                 batch_size: int = 1,
                 reward_alpha: float = 0.5,
                 goal_reward_default: float = 0.,
                 goal_reached_reward: float = 100.) -> None:
        super().__init__()
        self.base_model = base_model
        self.prompt = prompt
        self.batch_size = batch_size
        # Weight of the language-model scores; the goal term gets (1 - reward_alpha).
        self.reward_alpha = reward_alpha
        self.goal_reward_default = goal_reward_default
        self.goal_reached_reward = goal_reached_reward
        # Not assigned in this class; read by the methods below.
        self.example = None

    def get_actions(self, state: BWStateRAP) -> list[BWAction]:
        """Enumerate the actions available from the state's block configuration."""
        return utils.generate_all_actions(state.blocks_state)

    def fast_reward(self, state: BWStateRAP, action: BWAction) -> tuple[float, dict]:
        """Score ``action`` before executing it, using intuition and self-evaluation."""
        has_buffered = state.buffered_action != ""
        # With a buffered action, the prompt is anchored on the previous
        # configuration and the buffered action is replayed before the candidate.
        if has_buffered:
            current_blocks_state = state.last_blocks_state
            previous_action = state.buffered_action + "\n"
        else:
            current_blocks_state = state.blocks_state
            previous_action = ""

        # every two steps, we will also reduce the icl examples by 2 steps
        # so that the distribution of step length in examples is more reasonable
        icl_template = self.prompt["icl_list"][state.step_idx // 2]

        goals_text = utils.extract_goals(self.example, return_raw=True)
        inputs = icl_template.replace("<init_state>", current_blocks_state)
        inputs = inputs.replace("<goals>", goals_text)
        inputs = inputs.replace("<action>", previous_action)
        intuition = self.base_model.get_loglikelihood(inputs, [inputs + action])[0]

        self_eval_prompt = self.prompt["self-eval"].replace("<init_state>", current_blocks_state)
        self_eval_prompt = self_eval_prompt.replace("<goals>", goals_text)
        self_eval_prompt = self_eval_prompt.replace("<action>", action)
        self_eval = self.base_model.get_loglikelihood(self_eval_prompt, [self_eval_prompt + "good"])[0]

        details = {'intuition': intuition, "self_eval": self_eval}
        return self.calculate_reward(intuition, self_eval), details

    def calculate_reward(self, intuition, self_eval, goal_reached=None) -> float:
        """Blend the two language-model scores with the goal term.

        Shared by ``reward`` and ``fast_reward``. ``goal_reached`` is indexed as
        ``(reached, value)``; when it is ``None`` the default goal reward is used.
        """
        if goal_reached is None:
            goal_reward = self.goal_reward_default
        else:
            goal_reward = self.goal_reached_reward if goal_reached[0] else goal_reached[1]
        alpha = self.reward_alpha
        return (intuition + self_eval) * alpha + goal_reward * (1 - alpha)

    def reward(self, state: BWStateRAP, action: BWAction,
               intuition: float = None,
               self_eval: float = None,
               goal_reached: tuple[bool, float] = None) -> tuple[float, dict]:
        """Final reward once the action was executed and the goal check is known."""
        value = self.calculate_reward(intuition, self_eval, goal_reached)
        return value, {'intuition': intuition, 'goal_reached': goal_reached}

We just use the MCTS algorithm embedded in Reasoners, and build up the pipeline again.

In [None]:
# Same pipeline as before, now with the state-based world model and MCTS as the search algorithm.
# The world model's step budget and the MCTS depth limit are both 4.
world_model = BlocksWorldModelRAP(base_model=model, prompt=prompt, max_steps=4)
config = BWConfigRAP(base_model=model, prompt=prompt)
algorithm = MCTS(depth_limit=4, disable_tqdm=False, output_trace_in_each_iter=True, n_iters=10)
reasoner_rap = Reasoner(world_model=world_model, search_config=config, search_algo=algorithm)
result_rap = reasoner_rap(example)
print(result_rap)

In [None]:
result_rap.trace

## Visualization

In [4]:
import pickle
import torch
from typing import Union, Tuple, NamedTuple, Callable
import reasoners.benchmark.gw_utils as utils
from reasoners import WorldModel, LanguageModel
import copy
import json


GWAction = str

class CausalMapper(Callable):
    """Callable wrapper that turns latents into a flat list of decoded values.

    Latents are standardised with ``mean``/``std``, combined with
    ``target_assignment`` and passed through ``causal_mapper``; one scalar is then
    read from every output of the mapper.

    :param causal_mapper: model exposing ``eval()``, ``device`` and ``__call__``;
        it is expected to return a dict of 2-D tensors
    :param mean: tensor subtracted from the latents
    :param std: tensor the centred latents are divided by
    :param target_assignment: 2-D tensor broadcast against the latents
        (assumed shape ``(latent_dim, n_targets)`` — TODO confirm)
    :param mapping: optional dict whose *values*, in insertion order, give the row
        to read from each successive mapper output; defaults to the mapping that
        was previously hard-coded
    """

    # Row index to read for each successive mapper output (only the values are used).
    _DEFAULT_MAPPING = {0: 0, 1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6, 8: 7}

    def __init__(self, causal_mapper, mean, std, target_assignment, mapping=None):
        self.causal_mapper = causal_mapper.eval()
        self.mean = mean.to(causal_mapper.device)
        self.std = std.to(causal_mapper.device)
        self.target_assignment = target_assignment.to(causal_mapper.device)
        # Copy so that neither the class default nor the caller's dict is shared.
        self.mapping = dict(self._DEFAULT_MAPPING if mapping is None else mapping)

    def __call__(self, latents):
        """Return one decoded value per mapper output, in the mapper's output order."""
        latents = (latents - self.mean) / self.std
        latents = self._prepare_input(latents, self.target_assignment)
        outputs = self.causal_mapper(latents)

        decoded = []
        # Outputs beyond the length of the mapping (or vice versa) are ignored by zip.
        for row, head in zip(self.mapping.values(), outputs.values()):
            selected = head[row, :]
            if head.shape[1] == 2:
                # Two columns: take the index of the larger entry.
                decoded.append(selected.argmax(-1).item())
            else:
                # Otherwise the selected row must hold exactly one element.
                decoded.append(selected.item())
        return decoded

    def _prepare_input(self, latents, target_assignment, flatten_inp=True):
        """Mask the latents per target and append the mask itself as extra features.

        For latents of shape ``(B, D)`` and an assignment of shape ``(D, K)`` the
        result has shape ``(B, K, 2 * D)``, or ``(B * K, 2 * D)`` when flattened.
        """
        ta = target_assignment.detach()[None, :, :].expand(latents.shape[0], -1, -1)
        latents = torch.cat([latents[:, :, None] * ta, ta], dim=-2).permute(0, 2, 1)
        if flatten_inp:
            latents = latents.flatten(0, 1)
        return latents
        
class GWState(NamedTuple):
    """State handled by ``CausalWorldModel``.

    NOTE(review): presumably a gridworld state — the helpers used alongside it
    come from ``reasoners.benchmark.gw_utils``; confirm.
    """
    # Number of transitions taken so far (0 for the initial state).
    step_idx: int
    # Observed image; set to None for states produced by ``CausalWorldModel.step``.
    image: torch.Tensor
    # Natural-language description of the state.
    description: str
    # Latents of the state; left at None for the initial state, which only carries an image.
    latents: torch.Tensor = None

class CausalWorldModel(WorldModel):
    """World model that rolls states forward in a learned latent space.

    A state's image is encoded to latents, a text-conditioned prior samples the
    next latents for a given action, and the result is mapped back to a textual
    description through ``CausalMapper`` and ``utils.describe_latent``.

    ``self.example`` is read by several methods but never assigned in this class
    (presumably provided by the surrounding framework before search — confirm).
    """

    def __init__(self, crl_model, causal_mapper, cm_mean, cm_std, target_assignment, nl_model, tokenizer, device, max_steps=6, config_file=None):
        """
        :param crl_model: model exposing ``autoencoder.encoder``, ``flow.forward``,
            ``prior_t1.sample`` and ``device``; switched to eval mode
        :param causal_mapper: network wrapped in a ``CausalMapper`` after being moved
            to ``crl_model.device``
        :param cm_mean: mean used by the ``CausalMapper`` to standardise latents
        :param cm_std: standard deviation used by the ``CausalMapper``
        :param target_assignment: assignment tensor forwarded to the ``CausalMapper``
        :param nl_model: stored as ``self.nl_model``; not used inside this class
        :param tokenizer: callable used to tokenize action strings in ``step``
        :param device: device the encoder inputs and token tensors are moved to
        :param max_steps: step count at which a state becomes terminal
        :param config_file: path to a JSON file with a ``'flattened_causals'`` entry;
            despite the ``None`` default a real path is required
        """
        super().__init__()
        self.crl_model = crl_model.eval()
        self.causal_mapper = CausalMapper(causal_mapper.to(self.crl_model.device), cm_mean, cm_std, target_assignment)
        self.nl_model = nl_model
        self.tokenizer = tokenizer
        self.device = device
        self.max_steps = max_steps
        # Use a context manager so the config file handle is closed deterministically.
        with open(config_file, 'r') as config_handle:
            self.keys = json.load(config_handle)['flattened_causals']

    def _encode_initial_image(self, initial_image: torch.Tensor) -> Tuple[torch.Tensor, str]:
        """Encode an image to latents and describe it in natural language.

        This body used to be a first ``init_state(initial_image)`` definition that
        was silently shadowed by the parameterless ``init_state`` below, so it was
        never reachable. It is kept under a private name.

        :return: ``(disentangled_latents, description)``
        """
        initial_image = (initial_image * 2.0) - 1.0
        latents = self.crl_model.autoencoder.encoder(initial_image[None].to(self.device))
        disentangled_latents, _ = self.crl_model.flow.forward(latents)
        causal_variables = self.causal_mapper(disentangled_latents)
        description = self.map_to_language(causal_variables)
        return (disentangled_latents, description)

    def init_state(self) -> GWState:
        """Initialize the world model.

        :return: the initial state, built from the example's first image and the
            initial description extracted from the example (no latents yet)
        """
        return GWState(step_idx=0,
                       image=self.example['images'][0],
                       description=utils.extract_init_state(self.example))

    @torch.no_grad()
    def step(self, state: GWState, action: str) -> Tuple[GWState, dict]:
        """Update the state based on the action.

        :return: the successor state (latents set, image ``None``) and a dict with
            the ``'goal_reached'`` result of the goal check on the new description
        """
        if state.latents is None:
            # First transition: only the image is available, so encode it now.
            image = (state.image * 2.0) - 1.0
            current_latents = self.crl_model.autoencoder.encoder(image[None].to(self.device))
            current_latents, _ = self.crl_model.flow.forward(current_latents)
        else:
            current_latents = state.latents

        encoded = self.tokenizer(action, return_token_type_ids=True, padding='max_length', max_length=64)
        tokenized_description = {
            name: torch.tensor(encoded[name]).to(self.device)
            for name in ('input_ids', 'token_type_ids', 'attention_mask')
        }
        # NOTE(review): `action` is passed an uninitialised tensor — presumably the
        # prior conditions on the tokenized text only; confirm.
        new_latents, _ = self.crl_model.prior_t1.sample(current_latents, tokenized_description=tokenized_description, action=torch.empty(1).to(self.device))
        new_latents = new_latents.squeeze(1)
        causal_variables = self.causal_mapper(new_latents)
        new_description = self.map_to_language(causal_variables)
        new_state = GWState(step_idx=state.step_idx + 1, image=None, description=new_description, latents=new_latents)
        return new_state, {'goal_reached' : utils.goal_check(utils.extract_goals(self.example), new_description, ignore_obstacles=True)}

    def is_terminal(self, state: GWState) -> bool:
        """Terminal when the goal check succeeds or the step budget is exactly reached."""
        if utils.goal_check(utils.extract_goals(self.example), state.description, ignore_obstacles=True)[0]:
            return True
        return state.step_idx == self.max_steps

    def map_to_language(self, causals: torch.Tensor) -> str:
        """
        Map the causal variables to a natural language description via
        ``utils.describe_latent`` and the keys loaded from the config file.
        """
        return utils.describe_latent(causals, self.keys)

import sys
import types

# Register an in-memory module named 'world_model' that exposes the classes defined above.
# Presumably the pickled search results reference their classes under that module
# name, so it has to be importable for unpickling to succeed.
module = types.ModuleType('world_model')
sys.modules['world_model'] = module

# Attach classes to the module correctly
setattr(module, 'GWState', GWState)
setattr(module, 'CausalWorldModel', CausalWorldModel)

# Ensure you are using the module where required
import world_model

import torch


# SECURITY NOTE: unpickling executes arbitrary code — only load result files you produced yourself.
# The context manager closes the file handle, which the previous bare open() never did.
with open('/home/john/PhD/BISCUIT/llm-reasoners/logs/gridworld_MCTS/05162024-020758/algo_output/1.pkl', 'rb') as result_file:
    result_rap = pickle.load(result_file)

In [77]:
from reasoners.visualization import visualize
from reasoners.visualization.tree_snapshot import NodeData, EdgeData
from reasoners.algorithm.mcts import MCTSNode
from PIL import Image
import io
import base64
import torch
import numpy as np
import sys
sys.path.append('../')
from models.biscuit_nf import BISCUITNF
from typing import List, Tuple, Dict, Any, Union, Optional

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Machine-specific checkpoint locations; adjust them for your environment.
autoencoder_path = '/home/john/PhD/BISCUIT/pretrained_models/AE_40l_64hid_3c1b3l.ckpt'
model_path = '/home/john/PhD/BISCUIT/pretrained_models/epoch=39-step=19760.ckpt'
# NOTE: this rebinds the global `model`, which earlier cells used for the language model.
model = BISCUITNF.load_from_checkpoint(model_path, autoencoder_path=autoencoder_path)
model.to(device)
model.freeze()
# The result is assigned to `_` (presumably to keep the notebook from echoing it).
_ = model.eval()

def revert_to_image(latents, model):
    """Decode latents back into an image.

    The latents are moved to ``model.device``, passed through ``model.flow.reverse``
    and ``model.autoencoder.decoder``; the first decoded item is rescaled with
    ``(x + 1) / 2`` and returned.
    """
    on_device = latents.to(model.device)
    encoder_space = model.flow.reverse(on_device)
    decoded_batch = model.autoencoder.decoder(encoder_space)
    first_image = decoded_batch[0]
    return (first_image + 1.0) / 2.0

def tensor_to_base64_image(tensor: torch.Tensor) -> str:
    """Encode an image tensor as a PNG data URI.

    :param tensor: image tensor, transposed from channel-first to (H, W, C) before
        encoding (assumed float with values nominally in [0, 1] — TODO confirm)
    :return: a ``data:image/png;base64,...`` string
    """
    # Detach so tensors that still track gradients can be converted to NumPy, and
    # clamp so values slightly outside [0, 1] saturate instead of wrapping around
    # when cast to uint8.
    pixels = tensor.detach().clamp(0.0, 1.0).mul(255).byte().cpu().numpy().transpose(1, 2, 0)
    image = Image.fromarray(pixels)
    # Serialise to PNG in memory and base64-encode the bytes.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{img_str}"

def blocksworld_node_data_factory(n: MCTSNode) -> NodeData:
    """Build the per-node payload shown by the visualizer.

    The node is labelled with its state's textual description, or "Not expanded"
    when it has no state. The previous version also decoded an image for every
    node but never used it; that dead (and expensive) decoder pass is removed,
    which also avoids a crash for states that carry neither image nor latents.
    To show images instead, decode with ``revert_to_image`` and encode with
    ``tensor_to_base64_image`` here.
    """
    description = n.state.description if n.state else "Not expanded"
    # `reward_details` is only looked up when the attribute exists on the node.
    if hasattr(n, "reward_details"):
        goals_satisfied = n.reward_details["goal_reached"][1]
    else:
        goals_satisfied = "N/A"

    return NodeData({
        "block state": description,
        "# goals satisfied": goals_satisfied,
        "# visited": len(n.cum_rewards)
    })

def blocksworld_edge_data_factory(n: MCTSNode) -> EdgeData:
    """Build the per-edge payload: the node's Q value, fast-reward parts and action."""
    payload = {"Q": n.Q}
    fast_details = n.fast_reward_details
    payload["intuition"] = fast_details["intuition"]
    payload["self_eval"] = fast_details["self_eval"]
    payload["action"] = n.action
    return EdgeData(payload)

# Send the loaded search result to the visualizer, using the two factories above
# to describe nodes and edges.
visualize(result_rap,
          node_data_factory=blocksworld_node_data_factory,
          edge_data_factory=blocksworld_edge_data_factory)



The loaded checkpoint was produced with Lightning v2.1.0, which is newer than your current Lightning version: v2.0.9.post0



Visualizer URL: https://www.llm-reasoners.net/visualizer/a4cadd4a-c63a-412b-9ee1-4c3373f892ab?accessKey=f2d54cca


In [None]:
# Reload the raw prompt pool (the global `prompt` was rebound to a sampled prompt earlier).
# The path now uses the same 'examples/blocksworld/' prefix as the data files below
# and as the setup cell; the previous bare 'prompts/...' path assumed a different
# working directory from the rest of this cell.
with open('examples/blocksworld/prompts/pool_prompt_v1.json') as f:
    prompt = json.load(f)
evaluator = BWEvaluator(config_file='examples/blocksworld/data/bw_config.yaml',
                        domain_file='examples/blocksworld/data/generated_domain.pddl',
                        data_path='examples/blocksworld/data/split_v1/split_v1_step_4_data.json',
                        init_prompt=prompt)
# Run the Tree-of-Thought reasoner over the dataset; logs are written under 'log/'.
evaluator.evaluate(reasoner_tot, shuffle_prompt=True, num_shot=4, resume=0, log_dir='log/')

In [6]:
import base64
import io
import json
import networkx as nx
import plotly.graph_objects as go
from PIL import Image
import torch
import sys
sys.path.append('../')
from models.biscuit_nf import BISCUITNF
from reasoners.algorithm.mcts import MCTSNode, MCTSResult
from typing import List, Tuple, Dict, Any, Union, Optional

# Set device: use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load models (machine-specific checkpoint locations; adjust them for your environment).
autoencoder_path = '/home/john/PhD/BISCUIT/pretrained_models/AE_40l_64hid_3c1b3l.ckpt'
model_path = '/home/john/PhD/BISCUIT/pretrained_models/epoch=39-step=19760.ckpt'
# NOTE: this rebinds the global `model`; the visualization helpers below read it.
model = BISCUITNF.load_from_checkpoint(model_path, autoencoder_path=autoencoder_path)
model.to(device)
model.freeze()
# The result is assigned to `_` (presumably to keep the notebook from echoing it).
_ = model.eval()

def revert_to_image(latents, model):
    """Map latents back to image space.

    Steps: move to ``model.device``, undo the flow with ``model.flow.reverse``,
    decode with ``model.autoencoder.decoder``, take the first decoded item and
    rescale it with ``(x + 1) / 2``.
    """
    reversed_latents = model.flow.reverse(latents.to(model.device))
    decoded = model.autoencoder.decoder(reversed_latents)[0]
    rescaled = (decoded + 1.0) / 2.0
    return rescaled

def tensor_to_base64_image(tensor: torch.Tensor) -> str:
    """Encode an image tensor as a PNG data URI.

    :param tensor: image tensor, transposed from channel-first to (H, W, C) before
        encoding (assumed float with values nominally in [0, 1] — TODO confirm)
    :return: a ``data:image/png;base64,...`` string
    """
    # Clamp before the uint8 cast: values marginally outside [0, 1] would otherwise
    # wrap around and show up as speckle. Detach so grad-tracking tensors convert too.
    scaled = tensor.detach().clamp(0.0, 1.0).mul(255).byte()
    # Convert to a NumPy array and transpose it to (H, W, C) for PIL.
    array = scaled.cpu().numpy().transpose(1, 2, 0)
    image = Image.fromarray(array)
    # Write the PNG into memory and base64-encode the resulting bytes.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{img_str}"

import math

def recalculate_uct(node: "MCTSNode", w_exp: float) -> float:
    """Recompute a node's UCT score for a given exploration weight.

    The score is ``Q + w_exp * sqrt(ln(M) / max(1, N))`` where ``M`` and ``N`` are
    the lengths of the parent's and the node's ``cum_rewards`` lists.

    :param node: tree node exposing ``parent``, ``cum_rewards`` and ``Q``
    :param w_exp: exploration weight
    :return: 0 for the root; ``node.Q`` alone when the parent has no recorded
        rewards (the logarithm is undefined there and used to raise
        ``ValueError``); the full UCT score otherwise
    """
    if node.parent is None:
        return 0  # Root node has no UCT value
    parent_visits = len(node.parent.cum_rewards)
    if parent_visits < 1:
        # math.log(0) raises; with an unvisited parent there is no exploration term.
        return node.Q
    own_visits = len(node.cum_rewards)
    return node.Q + w_exp * math.sqrt(math.log(parent_visits) / max(1, own_visits))

def extract_nodes_edges(root: MCTSNode, w_exp: float):
    """Walk the tree depth-first and collect its nodes and edges.

    :return: ``(nodes, edges)`` where ``nodes`` holds ``(node, depth)`` pairs and
        ``edges`` holds ``(parent_id, child_id, child, uct_value)`` tuples
    """
    nodes, edges = [], []
    pending = [(root, 0)]

    while pending:
        current, level = pending.pop()
        nodes.append((current, level))
        # Nodes without children (None or empty) contribute no edges.
        for child in current.children or ():
            edges.append((current.id, child.id, child, recalculate_uct(child, w_exp)))
            pending.append((child, level + 1))

    return nodes, edges

def generate_tree_layout(nodes: List[Tuple[MCTSNode, int]], edges: List[Tuple[int, int, Any, float]]) -> Dict[int, Tuple[float, float]]:
    """Compute 2-D positions for the tree, keyed by node id.

    Nodes are grouped into layers by depth with a multipartite layout, and the
    result is then passed through a spring layout with every node held fixed.
    """
    graph = nx.DiGraph()
    for tree_node, level in nodes:
        graph.add_node(tree_node.id, subset=level)
    # Only the endpoints matter for the layout; child object and UCT value are dropped.
    graph.add_edges_from([(source, target) for source, target, _child, _uct in edges])

    layered = nx.multipartite_layout(graph, subset_key="subset")
    return nx.spring_layout(graph, pos=layered, fixed=layered.keys())  # refine positions

def simulate_path_selection(root: MCTSNode, w_exp: float) -> List[MCTSNode]:
    """Follow the highest-UCT child from the root until a node without children.

    Ties go to the earliest child. The returned path starts with ``root``.
    """
    cursor = root
    path = [cursor]
    while cursor.children:
        scores = [recalculate_uct(candidate, w_exp) for candidate in cursor.children]
        best_index = max(range(len(scores)), key=scores.__getitem__)
        cursor = cursor.children[best_index]
        path.append(cursor)
    return path

def generate_visualizations(result: MCTSResult, w_exp_values: List[float]):
    """Show one interactive Plotly figure of the search tree per exploration weight.

    For every ``w_exp`` the tree rooted at ``result.tree_state`` is laid out, each
    node is labelled with its visit count and recomputed UCT score, each edge with
    the child's action and reward details, and the greedy max-UCT path from the
    root is drawn in red. Figures are displayed with ``fig.show()``; nothing is
    returned.

    Nodes whose state carries no image are decoded from their latents with the
    module-level ``model``.

    :param result: search result exposing the root node as ``tree_state``
    :param w_exp_values: exploration weights to visualise, one figure each
    """
    for w_exp in w_exp_values:
        nodes, edges = extract_nodes_edges(result.tree_state, w_exp)
        pos = generate_tree_layout(nodes, edges)

        # Build the id -> node lookup once instead of scanning `nodes` per position.
        node_by_id = {}
        for node, _ in nodes:
            node_by_id.setdefault(node.id, node)

        # Base64-encoded image per node that has a state.
        images = {}
        for node, _ in nodes:
            if node.state:
                if hasattr(node.state, 'image') and node.state.image is not None:
                    image_tensor = node.state.image
                else:
                    image_tensor = revert_to_image(node.state.latents, model)
                images[node.id] = tensor_to_base64_image(image_tensor)

        node_x = []
        node_y = []
        node_images = []
        node_text = []

        for node_id, (x, y) in pos.items():
            node_x.append(x)
            node_y.append(y)
            node_images.append(images.get(node_id))
            node = node_by_id[node_id]
            uct_value = recalculate_uct(node, w_exp) if node.parent is not None else 0.0
            node_text.append(f"Node {node_id}<br>"
                            f"Goals satisfied: {node.reward_details['goal_reached'][1] if hasattr(node, 'reward_details') else 'N/A'}<br>"
                            f"Visited: {len(node.cum_rewards)}<br>"
                            f"UCT: {uct_value:.2f}")

        edge_x = []
        edge_y = []
        edge_labels = []  # (x_mid, y_mid, label) per edge
        for u, v, child, uct_value in edges:
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            # A trailing None separates this segment from the next edge.
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

            action = child.action
            Q = f"{child.Q:.2f}"
            intuition = f"{child.fast_reward_details['intuition']:.2f}" if 'intuition' in child.fast_reward_details else 'N/A'
            self_eval = f"{child.fast_reward_details['self_eval']:.2f}" if 'self_eval' in child.fast_reward_details else 'N/A'

            edge_label = (f"A: {action}<br>"
                        f"Q: {Q}<br>"
                        f"Intuition: {intuition}<br>"
                        f"Self-Eval: {self_eval}<br>"
                        f"UCT: {uct_value:.2f}")
            edge_labels.append(((x0 + x1) / 2, (y0 + y1) / 2, edge_label))

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='black'),
            hoverinfo='none',
            mode='lines')

        # One annotation at the midpoint of every edge.
        edge_annotations = [
            dict(
                x=label_x,
                y=label_y,
                text=label,
                showarrow=False,
                font=dict(size=10),
                align='center',
            ) for label_x, label_y, label in edge_labels
        ]

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_text,
            textposition="top center",
            marker=dict(
                showscale=False,
                color='skyblue',
                size=50,
                line_width=2))

        # NOTE(review): `titlefont_size` looks like a legacy Plotly alias; newer
        # releases may require `title_font_size` — confirm against the installed version.
        fig = go.Figure(data=[edge_trace, node_trace],
                        layout=go.Layout(
                            title=f'MCTS Visualization with w_exp={w_exp}',
                            titlefont_size=16,
                            showlegend=False,
                            hovermode='closest',
                            margin=dict(b=20, l=5, r=5, t=40),
                            annotations=edge_annotations,
                            xaxis=dict(showgrid=False, zeroline=False),
                            yaxis=dict(showgrid=False, zeroline=False))
                        )

        # Overlay each node's image at its position.
        for i, (node_id, (x, y)) in enumerate(pos.items()):
            if node_images[i]:
                fig.add_layout_image(
                    dict(
                        source=node_images[i],
                        xref="x", yref="y",
                        x=x, y=y,
                        sizex=0.15,
                        sizey=0.15,
                        xanchor="center", yanchor="middle"
                    )
                )

        # Simulate and highlight the path selection.
        # The points are listed consecutively so that Plotly joins them into one
        # polyline; the previous version put a None after every point, which
        # split the trace into single-point segments and drew nothing.
        simulated_path = simulate_path_selection(result.tree_state, w_exp)
        path_x = [pos[node.id][0] for node in simulated_path]
        path_y = [pos[node.id][1] for node in simulated_path]

        # Debugging: Print simulated path node IDs
        print(f"Simulated path for w_exp={w_exp}: {[node.id for node in simulated_path]}")

        # Path trace for visualization
        path_trace = go.Scatter(
            x=path_x, y=path_y,
            line=dict(width=4, color='red'),
            hoverinfo='none',
            mode='lines')

        fig.add_trace(path_trace)
        fig.show()

# Example usage
w_exp_values = [100.0]  # Example values to try
# SECURITY NOTE: unpickling executes arbitrary code — only load result files you produced yourself.
# The context manager closes the file handle, which the previous bare open() never did.
with open('/home/john/PhD/BISCUIT/llm-reasoners/logs/gridworld_MCTS/05312024-152342/algo_output/21.pkl', 'rb') as result_file:
    result_rap = pickle.load(result_file)
generate_visualizations(result_rap, w_exp_values)

Simulated path for w_exp=100.0: [0, 3]


In [26]:
import base64
import io
import json
import networkx as nx
import graphviz
from PIL import Image
import torch
import sys
import os
from typing import List, Tuple, Dict, Any, Union

# Your previously imported custom modules and functions
# from models.biscuit_nf import BISCUITNF
# from reasoners.algorithm.mcts import MCTSNode, MCTSResult

# Pick the compute device: use CUDA when torch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load models
# autoencoder_path = '/path/to/your/AE_40l_64hid_3c1b3l.ckpt'
# model_path = '/path/to/your/epoch=39-step=19760.ckpt'
# model = BISCUITNF.load_from_checkpoint(model_path, autoencoder_path=autoencoder_path)
# model.to(device)
# model.freeze()
# _ = model.eval()

def revert_to_image(latents, model):
    """
    Decode a latent tensor back into an image tensor using ``model``.

    The latents are moved to ``model.device``, passed through
    ``model.flow.reverse`` and then through ``model.autoencoder.decoder``.
    Only element 0 of the decoder output is kept, and it is rescaled with
    ``(x + 1) / 2``.

    :param latents: Tensor of latents to decode.
    :param model: Object exposing ``device``, ``flow.reverse`` and ``autoencoder.decoder``.
    :return: The rescaled first element of the decoder output.
    """
    # NOTE(review): the (x + 1) / 2 rescale presumably maps a [-1, 1] decoder
    # range onto [0, 1] -- confirm against the autoencoder's output range.
    on_device = latents.to(model.device)
    decoder_input = model.flow.reverse(on_device)
    decoded_batch = model.autoencoder.decoder(decoder_input)
    first_decoded = decoded_batch[0]
    return (first_decoded + 1.0) / 2.0

def save_tensor_as_image(tensor: torch.Tensor, filepath: str) -> str:
    """
    Write an image tensor to ``filepath`` via PIL and return the path.

    The tensor is scaled by 255 and converted to uint8, then its axes are
    reordered with ``transpose(1, 2, 0)`` before being handed to
    ``Image.fromarray``.

    :param tensor: 3-D float tensor with values nominally in [0, 1]; axis 0 is
                   moved to the last position (presumably (C, H, W) -> (H, W, C)
                   -- verify against the caller).
    :param filepath: Destination path; PIL picks the file format from it.
    :return: ``filepath``, unchanged.
    """
    # Clamp before the uint8 cast: a decoder can overshoot [0, 1] slightly, and
    # without clamping those values wrap around (e.g. 1.01 * 255 -> 1 instead
    # of 255), producing speckled pixels in the saved image.
    scaled = tensor.clamp(0.0, 1.0).mul(255).byte()
    array = scaled.cpu().numpy().transpose(1, 2, 0)
    image = Image.fromarray(array)
    image.save(filepath)
    return filepath

import math

def recalculate_uct(node: "MCTSNode", w_exp: float) -> float:
    """
    Recompute the UCT score of ``node`` for a given exploration weight.

    The score is ``node.Q + w_exp * sqrt(ln(M) / max(1, N))`` where ``M`` is the
    number of entries in the parent's ``cum_rewards`` and ``N`` the number of
    entries in the node's own ``cum_rewards``.

    :param node: Node to score; must expose ``parent``, ``cum_rewards`` and ``Q``.
    :param w_exp: Weight applied to the exploration term.
    :return: ``0.0`` for a node without a parent; ``node.Q`` alone when the parent
             has no recorded rewards; otherwise the full UCT score.
    """
    # The annotation is a forward reference so this definition does not need
    # MCTSNode to be imported at the moment the cell is executed.
    if node.parent is None:
        return 0.0

    parent_visits = len(node.parent.cum_rewards)
    if parent_visits < 1:
        # ln(0) is undefined and math.log raises ValueError; with no parent
        # statistics there is no exploration bonus to add.
        return node.Q

    own_visits = len(node.cum_rewards)
    exploration = math.sqrt(math.log(parent_visits) / max(1, own_visits))
    return node.Q + w_exp * exploration

def extract_nodes_edges(root: MCTSNode, w_exp: float):
    """
    Flatten the tree under ``root`` into a node list and an edge list.

    The tree is walked with an explicit stack (depth-first, most recently
    pushed child first), so no recursion is involved.

    :param root: Node the traversal starts from (recorded at depth 0).
    :param w_exp: Exploration weight forwarded to ``recalculate_uct`` for every child.
    :return: ``(nodes, edges)`` where ``nodes`` holds ``(node, depth)`` tuples and
             ``edges`` holds ``(parent_id, child_id, child, uct_value)`` tuples.
    """
    visited = []
    links = []
    pending = [(root, 0)]

    while pending:
        current, level = pending.pop()
        visited.append((current, level))
        # ``children`` may be None or empty for leaves; treat both as "no children".
        for kid in current.children or ():
            links.append((current.id, kid.id, kid, recalculate_uct(kid, w_exp)))
            pending.append((kid, level + 1))

    return visited, links

def generate_tree_layout(nodes: List[Tuple[MCTSNode, int]], edges: List[Tuple[int, int, Any, float]]) -> Dict[int, Tuple[float, float]]:
    """
    Compute 2-D positions for the tree's nodes with networkx.

    A directed graph is built from the node ids (each tagged with its depth as
    the ``subset`` attribute) and the ``(parent_id, child_id)`` part of every
    edge. Positions come from ``nx.multipartite_layout`` keyed on ``subset`` and
    are then passed through ``nx.spring_layout`` with every node marked as fixed.

    :param nodes: ``(node, depth)`` tuples; only ``node.id`` and ``depth`` are used.
    :param edges: ``(parent_id, child_id, child, uct_value)`` tuples; only the two ids are used.
    :return: Mapping from node id to its position as returned by ``nx.spring_layout``.
    """
    graph = nx.DiGraph()
    graph.add_nodes_from((entry.id, {"subset": level}) for entry, level in nodes)
    graph.add_edges_from((src, dst) for src, dst, _child, _uct in edges)

    layered = nx.multipartite_layout(graph, subset_key="subset")
    return nx.spring_layout(graph, pos=layered, fixed=layered.keys())

def simulate_path_selection(root: MCTSNode, w_exp: float) -> List[MCTSNode]:
    """
    Follow the highest-UCT child from ``root`` down to a node without children.

    At every step each child is scored with ``recalculate_uct`` and the
    best-scoring one is taken; on ties the earliest child in ``children`` wins.

    :param root: Node the walk starts from.
    :param w_exp: Exploration weight forwarded to ``recalculate_uct``.
    :return: The visited nodes in order, starting with ``root``.
    """
    trail = [root]
    while trail[-1].children:
        # max() returns the first maximal element, matching an index-of-max lookup.
        best_child = max(trail[-1].children, key=lambda kid: recalculate_uct(kid, w_exp))
        trail.append(best_child)
    return trail

def generate_visualizations(result: MCTSResult, w_exp_values: List[float], temp_img_dir: str):
    """
    Render the MCTS tree in ``result`` with graphviz, once per exploration weight.

    For every weight the tree under ``result.tree_state`` is flattened with
    ``extract_nodes_edges``. Each node with a truthy ``state`` gets a PNG written
    to ``temp_img_dir`` (taken from ``state.image`` when present and not None,
    otherwise decoded from ``state.latents`` with the module-level ``model``) and
    the picture is embedded in the node's HTML-table label. Edges are labelled
    with the child's action, Q, fast-reward details and UCT value.

    Two renders are written to the working directory per weight: first
    ``mcts_visualization_w_exp_<w>.gv`` rendered with ``view=True``, then a PNG
    render named ``mcts_visualization_w_exp_<w>`` with ``view=False``.

    :param result: Search result whose ``tree_state`` is the tree root.
    :param w_exp_values: Exploration weights to visualise.
    :param temp_img_dir: Directory for the per-node PNG files; created if missing.
    """
    os.makedirs(temp_img_dir, exist_ok=True)

    table_open = "<TABLE BORDER='0' CELLBORDER='1' CELLSPACING='0'>"

    for w_exp in w_exp_values:
        nodes, edges = extract_nodes_edges(result.tree_state, w_exp)

        # Create a graphviz Digraph
        dot = graphviz.Digraph(comment=f'MCTS Visualization with w_exp={w_exp}')
        images = {}

        for node, _depth in nodes:
            if node.state:
                # Prefer an image already stored on the state; decode latents otherwise.
                image_tensor = getattr(node.state, 'image', None)
                if image_tensor is None:
                    image_tensor = revert_to_image(node.state.latents, model)
                image_path = os.path.join(temp_img_dir, f"node_{node.id}.png")
                save_tensor_as_image(image_tensor, image_path)
                images[node.id] = image_path

            if node.parent is not None:
                uct_value = recalculate_uct(node, w_exp)
            else:
                uct_value = 0.0

            if hasattr(node, 'reward_details'):
                goals_satisfied = node.reward_details['goal_reached'][1]
            else:
                goals_satisfied = 'N/A'

            # Both label variants share everything except the optional image row.
            label_parts = [table_open]
            if node.id in images:
                label_parts.append(f"<TR><TD><IMG SRC='{images[node.id]}'/></TD></TR>")
            label_parts.append(f"<TR><TD>Node {node.id}<br/>"
                               f"Goals satisfied: {goals_satisfied}<br/>"
                               f"Visited: {len(node.cum_rewards)}<br/>"
                               f"UCT: {uct_value:.2f}</TD></TR></TABLE>")
            node_label = "".join(label_parts)

            dot.node(str(node.id), label=f'<{node_label}>', shape='plaintext')

        for u, v, child, uct_value in edges:
            details = child.fast_reward_details
            intuition = f"{details['intuition']:.2f}" if 'intuition' in details else 'N/A'
            self_eval = f"{details['self_eval']:.2f}" if 'self_eval' in details else 'N/A'

            edge_label = "\n".join([
                f"A: {child.action}",
                f"Q: {child.Q:.2f}",
                f"Intuition: {intuition}",
                f"Self-Eval: {self_eval}",
                f"UCT: {uct_value:.2f}",
            ])
            dot.edge(str(u), str(v), label=edge_label)

        dot.render(f'mcts_visualization_w_exp_{w_exp}.gv', view=True)
        # Save the visualization as an image
        dot.format = 'png'
        dot.render(f'mcts_visualization_w_exp_{w_exp}', view=False)

# Example usage: render one pickled MCTS result with graphviz.
temp_img_dir = 'temp_images'  # Temporary directory for storing images
w_exp_values = [100.0]  # Example values to try
# NOTE: pickle.load can execute arbitrary code; only open result files you produced yourself.
# The context manager closes the file handle as soon as the result is loaded.
with open('/home/john/PhD/BISCUIT/llm-reasoners/logs/gridworld_MCTS/05312024-171447/algo_output/1.pkl', 'rb') as result_file:
    result_rap = pickle.load(result_file)
generate_visualizations(result_rap, w_exp_values, temp_img_dir)

In [22]:
import pickle
import os
from typing import List, Tuple

def check_goal_satisfied(node: MCTSNode, goal_value: float = 100.0) -> bool:
    """
    Report whether ``node`` or any of its descendants reached the goal value.

    A node counts as a hit when it has a ``reward_details`` attribute and
    ``reward_details['goal_reached'][1]`` equals ``goal_value``. The node itself
    is tested first, then its children are searched recursively in order.

    :param node: The current node to check.
    :param goal_value: The goal value to check against.
    :return: True if any node satisfies the goal, otherwise False.
    """
    reached_here = (
        hasattr(node, 'reward_details')
        and node.reward_details['goal_reached'][1] == goal_value
    )
    if reached_here:
        return True
    if not node.children:
        return False
    return any(check_goal_satisfied(child, goal_value) for child in node.children)

def eval_outputs(directory: str, goal_value: float = 100.0) -> Tuple[float, List[bool], List[str]]:
    """
    Evaluate every ``.pkl`` result in ``directory`` and report how many reached the goal.

    Each file is unpickled and its ``tree_state`` is searched with
    ``check_goal_satisfied``. Files are processed in a deterministic order
    (numeric file stems such as ``2.pkl`` before ``10.pkl``, then all other
    names alphabetically), so ``result_list`` lines up with that order on every run.

    :param directory: The directory containing the output files.
    :param goal_value: The goal value to check against.
    :return: A tuple containing the accuracy (0.0 when there are no ``.pkl`` files),
             a list of booleans indicating if each file satisfied the goal,
             and a list of filenames of files that failed to satisfy the goal.
    """
    suffix = '.pkl'

    def _sort_key(name: str):
        # Sort "2.pkl" before "10.pkl"; names without a numeric stem go last.
        stem = name[:-len(suffix)]
        if stem.isdecimal():
            return (0, int(stem), name)
        return (1, 0, name)

    # os.listdir returns entries in arbitrary order, so sort explicitly.
    pickle_names = sorted(
        (name for name in os.listdir(directory) if name.endswith(suffix)),
        key=_sort_key,
    )

    result_list = []
    failed_filenames = []
    for filename in pickle_names:
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # directories whose contents you produced yourself.
        with open(os.path.join(directory, filename), 'rb') as file:
            result = pickle.load(file)
        # The file is closed before the (potentially long) tree walk starts.
        goal_satisfied = check_goal_satisfied(result.tree_state, goal_value)
        result_list.append(goal_satisfied)
        if not goal_satisfied:
            failed_filenames.append(filename)

    total_files = len(result_list)
    correct_files = total_files - len(failed_filenames)
    accuracy = correct_files / total_files if total_files > 0 else 0.0
    return accuracy, result_list, failed_filenames

# Example usage: score every pickled run in one log directory and print a summary.
directory = '/home/john/PhD/BISCUIT/llm-reasoners/logs/gridworld_MCTS/05312024-165635/algo_output'
goal_value = 100.0
accuracy, result_list, failed_filenames = eval_outputs(directory, goal_value)
print(
    f"Accuracy: {accuracy:.2%}",
    f"Result List: {result_list}",
    f"Failed Filenames: {failed_filenames}",
    sep="\n",
)



Accuracy: 88.89%
Result List: [True, False, True, True, True, True, True, True, True]
Failed Filenames: ['3.pkl']


In [37]:
import torch
# Load the object serialized at examples/gridworld/data/step_2.pth into ``a``.
# NOTE(review): the structure of the loaded object is not visible here -- confirm before use.
a = torch.load('/home/john/PhD/BISCUIT/llm-reasoners/examples/gridworld/data/step_2.pth')