# Install All Dependencies

In [None]:
# !pip install llama-index
# !pip install llama-index-llms-huggingface
# !pip install llama-index-embeddings-huggingface
# !pip install llama-index-embeddings-huggingface-api
# !pip install --upgrade pip
# !pip install llama-index-agent-lats
# !pip install clingo

# Parse Prompts

In [None]:
import json

# Load the plan-generation dataset; every instance carries its prompt under "query".
# JSON is UTF-8 by spec, so do not rely on the platform's default text encoding.
with open("dataset/task_1_plan_generation.json", encoding="utf-8") as file:
    data = json.load(file)
queries = [instance['query'] for instance in data['instances']]

# Split every query on its section markers:
#   system prompt -> text before the first "[STATEMENT]"
#   initial state -> text after that marker, up to the first "My goal"
#   goal state    -> "My goal" ... up to the first "[PLAN]" that follows
#   answer        -> text after that "[PLAN]" (up to any further "[PLAN]")
# NOTE(review): if a query contains several "[STATEMENT]" sections (e.g. a
# worked example before the real problem), the FIRST one is extracted here —
# confirm this matches the dataset layout.
system_prompts = []
initial_states = []
goal_states = []
answers = []
for query in queries:
    system = query.split("[STATEMENT]")
    system_prompts.append(system[0])
    initial = system[1].split("My goal")
    initial_states.append(initial[0])
    goal = initial[1].split("[PLAN]")
    goal_states.append("My goal" + goal[0])
    answers.append(goal[1])


# Initialize Llama-3-8B-Instruct model

In [None]:
import nest_asyncio
import torch
from llama_index.llms.huggingface import HuggingFaceLLM

def initialize_llama(system_prompt, hf_token=None,
                     model_name="meta-llama/Meta-Llama-3-8B-Instruct"):
    """Load a Llama-3 Instruct model as a llama-index ``HuggingFaceLLM``.

    Args:
        system_prompt: System prompt the returned LLM is configured with.
        hf_token: Hugging Face access token for the gated model. When omitted,
            the ``HF_TOKEN`` environment variable is used; if that is unset too,
            ``None`` is passed so the Hugging Face libraries can presumably fall
            back to cached ``huggingface-cli login`` credentials.
        model_name: Hub id used for both the model and its tokenizer.

    Returns:
        The configured ``HuggingFaceLLM``.
    """
    import os

    from transformers import AutoTokenizer

    # Never hard-code the token in the notebook; an empty-string placeholder is
    # not a usable credential for a gated model.
    if hf_token is None:
        hf_token = os.environ.get("HF_TOKEN") or None

    # Presumably needed so the async agent can run inside Jupyter's event loop.
    nest_asyncio.apply()

    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

    # Stop generation on EOS and on the "<|eot_id|>" end-of-turn token.
    stopping_ids = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    return HuggingFaceLLM(
        model_name=model_name,
        system_prompt=system_prompt,
        max_new_tokens=256,
        model_kwargs={
            "token": hf_token,
            "torch_dtype": torch.bfloat16,
        },
        generate_kwargs={
            "do_sample": True,
            "temperature": 0.6,
            "top_p": 0.9,
        },
        # ``tokenizer_name`` must be a hub-id string; the tokenizer object
        # loaded above is handed over via ``tokenizer`` so it is not reloaded.
        tokenizer_name=model_name,
        tokenizer=tokenizer,
        tokenizer_kwargs={"token": hf_token},
        stopping_ids=stopping_ids,
    )

# Create LATS Agent

In [None]:
from collections import Counter

def initialize_tool(llm, num_votes=5):
    """Build a majority-vote validator that can be wrapped in a ``QueryEngineTool``.

    Nothing is sent to ``llm`` until the returned object's ``query`` (or async
    ``aquery``) is called with a proposed Blocks World step. Each call then asks
    ``llm`` ``num_votes`` times whether that step is valid.

    Args:
        llm: LLM exposing ``complete(prompt)`` whose result has a ``.text``
            attribute (as llama-index LLMs do).
        num_votes: Number of completions to sample per query (default 5).

    Returns:
        An object whose ``query(query_str)`` returns "The solution is good!" when
        a strict majority of replies say "valid", else "The solution is not good.".

    Raises:
        ValueError: If ``num_votes`` is less than 1.
    """
    import re

    if num_votes < 1:
        raise ValueError("num_votes must be at least 1")

    # "not valid" and "invalid" both contain "valid", so they are checked first.
    invalid_pattern = re.compile(r"\b(?:not\s+valid|invalid)\b")
    valid_pattern = re.compile(r"\bvalid\b")

    def _classify(reply_text):
        """Map one free-form reply to "valid", "not valid" or "unclear"."""
        text = reply_text.strip().lower()
        if invalid_pattern.search(text):
            return "not valid"
        if valid_pattern.search(text):
            return "valid"
        return "unclear"

    def _vote(proposed_step):
        """Sample ``num_votes`` judgements for one step and return the verdict."""
        # Built per call so the prompt contains the step actually under review.
        prompt = (
            "I am working on a Blocks World problem, and here is a proposed "
            "next step to solve the problem:\n\n"
            "Proposed Solution:\n"
            f"{proposed_step}\n\n"
            "Is this solution valid? Does it work towards the goal?\n\n"
            'Please respond with "valid" if the solution works towards the '
            'goal or "not valid" if it doesn\'t.'
        )
        # complete() returns a response object; the generated text is on .text.
        counts = Counter(
            _classify(llm.complete(prompt).text) for _ in range(num_votes)
        )
        # Strict majority; with the default 5 votes this means at least 3.
        if counts["valid"] > num_votes // 2:
            return "The solution is good!"
        return "The solution is not good."

    class _MajorityVoteEngine:
        """Minimal duck-typed query engine.

        NOTE(review): assumes ``QueryEngineTool`` only calls ``query`` /
        ``aquery`` on its engine — confirm against the installed llama-index.
        """

        def query(self, query_str):
            return _vote(str(query_str))

        async def aquery(self, query_str):
            return self.query(query_str)

    return _MajorityVoteEngine()

In [None]:
from llama_index.core.tools import QueryEngineTool, ToolMetadata

class LATS:
    """LATS agent for one Blocks World system prompt.

    The agent runs on the Llama model built by ``initialize_llama`` and is given
    a single tool, built by ``initialize_tool`` from that same model, that judges
    proposed steps.
    """

    def __init__(self, system, num_expansions=2, max_rollouts=3):
        """Build the LLM, the validator tool and the agent.

        Args:
            system: System prompt passed to ``initialize_llama``.
            num_expansions: Forwarded to ``LATSAgentWorker`` (default 2).
            max_rollouts: Forwarded to ``LATSAgentWorker`` (default 3).
        """
        # Provided by the llama-index-agent-lats package; it must be imported
        # explicitly, otherwise constructing the worker raises NameError.
        from llama_index.agent.lats import LATSAgentWorker

        llm = initialize_llama(system)
        tool = initialize_tool(llm)

        query_engine_tools = [
            QueryEngineTool(
                query_engine=tool,
                metadata=ToolMetadata(
                    name="blocks_world_tool",
                    description=(
                        '''
                        Provide a proposed solution for the next step to the group of LLMs
                        along with the current state of the problem space.
                        It will provide a majority ruling as to whether or not it is a valid
                        solution.
                        '''
                    ),
                ),
            )
        ]

        agent_worker = LATSAgentWorker(
            tools=query_engine_tools,
            llm=llm,
            num_expansions=num_expansions,
            max_rollouts=max_rollouts,
            verbose=True,
        )
        self.agent = agent_worker.as_agent()

    def call(self, initial, goal):
        """Ask the agent to solve one problem instance.

        Args:
            initial: Initial-state statement of the problem.
            goal: Goal statement of the problem.

        Returns:
            Whatever ``self.agent.chat`` returns for the combined prompt.
        """
        # The two statements are joined under a "[GOAL]" marker.
        response = self.agent.chat(initial + "\n[GOAL]\n\n" + goal)
        return response


# Run the Agent on Every Prompt

In [None]:
# Building a LATS loads the whole model, so reuse it while consecutive
# instances share a system prompt; only one agent is held at a time.
responses = []
lats = None
current_system = None
for system_statement, initial_statement, goal_statement in zip(
    system_prompts, initial_states, goal_states
):
    if lats is None or system_statement != current_system:
        # Drop our reference to the previous agent before loading the next model.
        lats = None
        lats = LATS(system_statement)
        current_system = system_statement
    else:
        # Clear chat history so each problem starts like a freshly built agent.
        lats.agent.reset()
    responses.append(lats.call(initial_statement, goal_statement))

# lats = LATS(system_prompts[0])
# response = lats.call(initial_states[0], goal_states[0])
# response
if responses:
    # Chat responses are not str, so format them rather than using "+".
    print(f"response: {responses[0]}", "\nanswer: ", answers[0])