# LATS - Three Gods Puzzle

Before starting, run `nix develop .` to set up the dependencies.

In [1]:
import json
import os
import re
import uuid
from pprint import pp
from typing import Literal

import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel
from pyswip import Prolog

# Load environment variables via python-dotenv before anything reads them
# (presumably from a local .env file — confirm against the project setup).
load_dotenv()

# env assertions: fail fast with a clear message when the OpenAI key is missing,
# instead of failing later inside the first API call.
assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY is not set"

In [2]:
# Matches every region fenced by the "remove_begin" / "remove_end" marker
# comments, markers included. DOTALL lets a region span several lines and the
# lazy quantifier keeps separate regions from being merged into one match.
remove_begin_end_regex = re.compile(
    r"% -- remove_begin -- %.*?% -- remove_end -- %", re.DOTALL
)


def read_facts_for_llm(file_path):
    """Return the text of ``file_path`` with all marked regions removed.

    Everything between a ``% -- remove_begin -- %`` marker and the next
    ``% -- remove_end -- %`` marker (both inclusive) is dropped; the rest of
    the file is returned unchanged.
    """
    with open(file_path, "r") as source:
        raw_text = source.read()

    return remove_begin_end_regex.sub("", raw_text)

In [3]:
# Natural-language statement of the puzzle, wrapped in <puzzle> tags and
# embedded verbatim into the prompts below. The stray "[3]" footnote marker
# (a copy-paste artifact from the source article) has been removed so the
# model is not shown a meaningless token.
PUZZLE = """
<puzzle>
Three gods A, B, and C are called, in no particular order, True, False, and Random. True always speaks truly, False always speaks falsely, 
but whether Random speaks truly or falsely is a completely random matter. 
Your task is to determine the identities of A, B, and C by asking three yes–no questions; each question must be put to exactly one god. 
The gods understand English, but will answer all questions in their own language, in which the words for yes and no are da and ja, in some order. 
You do not know which word means which.
</puzzle>
"""

# Contents of three_gods_inference.pl as returned by read_facts_for_llm (i.e.
# with the remove_begin/remove_end regions stripped), wrapped in
# <partial_prolog_facts> tags for inclusion in prompts.
PARTIAL_FACTS_FOR_LLM = (
    "<partial_prolog_facts>\n"
    + read_facts_for_llm("three_gods_inference.pl")
    + "\n</partial_prolog_facts>"
)

# System prompt for Generator: asks for exactly one Prolog query (ask/3 or
# meta_ask/3) and nothing else.
PROBLEM_PROMPT_V1 = f"""
You are solving the Three Gods puzzle using the SWI-Prolog logic programming language.

{PUZZLE}

{PARTIAL_FACTS_FOR_LLM}

Return a single question for one god using either ask/3 or meta_ask/3, and do not include explanations, output only the prolog query without any other text.
"""

# System prompt for Executor: asks for the predicted identity of every god in
# the identifier/value shape described at the end of the prompt.
EXE_PROMPT_V1 = f"""
You are solving the Three Gods puzzle using the SWI-Prolog logic programming language.

{PUZZLE}

Think step by step and output the list of predicted god identities, each identified god should be unique, e.g do not repeat the same god twice.

The god schema is as follows:

identifier: one of [a, b, c]
value: one of [true_god, false_god, random_god]
"""

# System prompt for Evaluator.evaluate: asks for a 1-10 integer score of a
# candidate predicate.
EVAL_PROMPT_V1 = f"""
Given this logic puzzle and partial Prolog facts.

{PUZZLE}

{PARTIAL_FACTS_FOR_LLM}

Score this predicate from 1 to 10(10 = most promising predicate, 1 = least promising). Output only an integer with no explanation.
"""


# System prompt for Evaluator.self_consistency: asks for a 1-10 integer score
# of how consistent the supplied samples are. Plain string (no placeholders).
SELF_CONSISTENCY_PROMPT_V1 = """
Given intermediate predicates in Prolog logic programming language (SWI-Prolog) and their results, determine how self-consistent they are.

Score from 1 to 10 (10 = most consistent, 1 = least). Output only an integer with no explanation.
"""

# Template for Evaluator.reflect. This is NOT an f-string: the {puzzle},
# {partial_prolog_facts}, {latest_failed_trajectory} and {final_state}
# placeholders are filled in later through str.format.
REFLECT_PROMPT_V1 = """
You are solving the Three Gods puzzle using the SWI-Prolog logic programming language.
Your previous attempt failed to identify all three gods.

{puzzle}

{partial_prolog_facts}

YOUR FAILED ATTEMPT:
{latest_failed_trajectory}

FINAL STATE:
{final_state}

Analysis required:
1. What information did you gain from each predicate?
2. What information is still missing?
3. Which god(s) could you not identify and why?
4. Did you waste any questions on redundant information?
5. Did you handle the da/ja ambiguity correctly?
6. Did you account for the Random god's unpredictability?

Reflection:
Based on this failure, what strategy would work better? Be specific about:
- Which god to ask first and why
- What type of question to use (direct vs counterfactual)
- How to handle the language ambiguity
- How to isolate the Random god
"""

In [4]:
# One predicted identity: which god (a, b or c) holds which role.
# Declared frozen, i.e. instances are immutable once created.
# NOTE: documented with comments rather than docstrings on purpose — pydantic
# exposes a model's docstring as its JSON-schema description.
class GodSchema(BaseModel, frozen=True):
    identifier: Literal["a", "b", "c"]
    value: Literal["true_god", "false_god", "random_god"]


# Structured-output wrapper: the list of predicted identities returned by the
# model when GodListSchema is used as the response format.
class GodListSchema(BaseModel, frozen=True):
    gods: list[GodSchema]


# Consult the Prolog source that provides the predicates queried throughout
# this notebook.
Prolog.consult("three_gods_inference.pl")

# Smoke test: run one meta_ask/3 query and show its first solution.
example_query = Prolog.query(
    "meta_ask(b, (a = true_god), Answer).",
    catcherrors=True,
    normalize=True,
)
example_solutions = list(example_query)
print(example_solutions[0])


def god_schema_to_predicate(gods: list[GodSchema]) -> list[str]:
    """Render each predicted identity as a Prolog goal string.

    Every entry becomes ``god(<identifier>, <value>).``; the order of the
    input list is preserved.
    """
    predicates = []
    for god in gods:
        predicates.append(f"god({god.identifier}, {god.value}).")
    return predicates

{'Answer': 'ja'}


In [5]:
# Structured-output schema holding a single integer; used as the response
# format for the evaluator's score requests.
# (Comment rather than docstring: pydantic exposes a model's docstring as its
# JSON-schema description.)
class IntSchema(BaseModel):
    value: int


# Placeholder so the `list[Node]` annotations in this cell resolve; the real
# Node class is defined in a later cell and replaces this name.
class Node:
    pass


# The LLM sometimes wraps its answer in a ```prolog fenced code block. This
# pattern captures the text between the fences (DOTALL so it may span lines)
# so that the fence can be stripped and only the query kept.
remove_prolog_code_block_regex = re.compile(r"```prolog(.*?)```", re.DOTALL)


class Generator:
    """Proposes the next Prolog question by prompting the chat model."""

    def __init__(self):
        self.client = OpenAI()

    def generate(
        self,
        context: str,
        reflection: str,
        state: str,
        latest_failed_trajectory: list[Node],
    ) -> str:
        """Ask the model for the next Prolog query.

        Args:
            context: Steps taken so far (earlier questions and their answers).
            reflection: Reflection text produced after earlier failed runs.
            state: The current node's state.
            latest_failed_trajectory: Nodes of the most recent failed rollout;
                their states are listed in the prompt.

        Returns:
            The model's reply with any ```prolog fence removed and surrounding
            whitespace stripped. An empty string is returned when the model
            produced no text content.
        """
        latest_failed_trajectory_str = "\n".join(
            node.state for node in latest_failed_trajectory
        )

        # The stray closing quote that used to follow "Next step is:" was an
        # unbalanced typo in the prompt and has been dropped.
        prompt = f"""
        Reflection from previous runs: {reflection}
        Failed trajectory: {latest_failed_trajectory_str}
        Steps so far: {context}
        Current state: {state}
        Next step is:
        """
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": PROBLEM_PROMPT_V1},
                {"role": "user", "content": prompt},
            ],
        )

        # `content` can be None (e.g. on a refusal); fall back to "" so the
        # regex substitution does not raise a TypeError.
        content = response.choices[0].message.content or ""

        return remove_prolog_code_block_regex.sub(r"\1", content).strip()


def generator_impl() -> Generator:
    """Factory returning a new Generator instance."""
    return Generator()


class Executor:
    """Asks the chat model for the final identity of each god."""

    def __init__(self):
        self.client = OpenAI()

    def execute(self, context: str, state: str, retries=5) -> list[GodSchema]:
        """Predict the identities of the three gods.

        Args:
            context: Questions asked so far.
            state: The current (latest) question and its answer.
            retries: Maximum number of model calls before giving up.

        Returns:
            A list of exactly three GodSchema entries in which every
            identifier and every role occurs once.

        Raises:
            Exception: if no well-formed assignment was produced within
                ``retries`` attempts.
        """
        # The prompt does not change between attempts, so build it once.
        prompt = f"Questions asked so far: {context}\nCurrent question: {state}"

        for _ in range(retries):
            response = self.client.chat.completions.parse(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": EXE_PROMPT_V1},
                    {"role": "user", "content": prompt},
                ],
                response_format=GodListSchema,
            )

            parsed = response.choices[0].message.parsed
            if parsed is None:
                # No structured payload (e.g. a refusal): try again.
                continue

            gods = parsed.gods

            # Counting distinct (identifier, value) pairs is not enough: it
            # would accept e.g. (a, true_god), (a, false_god), (b, random_god).
            # Require a one-to-one assignment between the gods and the roles.
            identifiers = {god.identifier for god in gods}
            roles = {god.value for god in gods}
            if len(gods) == 3 and len(identifiers) == 3 and len(roles) == 3:
                return gods

        raise Exception(f"Failed to execute after {retries} retries")


def executor_impl() -> Executor:
    """Factory returning a new Executor instance."""
    return Executor()


class Evaluator:
    """LLM-backed scoring and reflection used by the tree search."""

    def __init__(self):
        self.client = OpenAI()

    def _score(self, system_prompt: str, user_prompt: str) -> float:
        """Request a 1-10 integer score and return it scaled to 0.1-1.0."""
        response = self.client.chat.completions.parse(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format=IntSchema,
        )

        parsed = response.choices[0].message.parsed
        if parsed is None:
            raise ValueError("The model returned no parsable score")

        # The prompts ask for 1..10, but nothing enforces it; clamp so that a
        # stray 0 or 100 cannot distort the node values used for selection.
        bounded = min(max(parsed.value, 1), 10)
        return round(bounded / 10, 2)

    def evaluate(self, context: str, state: str) -> float:
        """Score how promising ``state`` is given the questions in ``context``.

        Returns:
            A float between 0.1 and 1.0 (the model's 1-10 score divided by 10).

        Raises:
            ValueError: if the model returned no parsable score.
        """
        prompt = (
            f"Questions asked so far: {context}\nCurrent question: {state}\n Score is:"
        )
        return self._score(EVAL_PROMPT_V1, prompt)

    def self_consistency(
        self,
        context: str,
        reflection: str,
        state: str,
        latest_failed_trajectory: list[Node],
    ) -> float:
        """Score how consistent several independently sampled next steps are.

        Three next-step proposals are generated for the same inputs, joined
        with ``---`` separators and rated by the model.

        Returns:
            A float between 0.1 and 1.0 (the model's 1-10 score divided by 10).

        Raises:
            ValueError: if the model returned no parsable score.
        """
        sc_sampling_size = 3
        # One generator (and thus one API client) is enough for all samples.
        sampler = generator_impl()
        prompt = "\n---\n".join(
            sampler.generate(context, reflection, state, latest_failed_trajectory)
            for _ in range(sc_sampling_size)
        )
        return self._score(SELF_CONSISTENCY_PROMPT_V1, prompt)

    def reflect(self, context: str, state: str) -> str:
        """Ask the model to analyse a failed attempt and suggest a better strategy.

        Args:
            context: The failed trajectory, inserted under "YOUR FAILED ATTEMPT".
            state: The final state reached, inserted under "FINAL STATE".

        Returns:
            The model's reflection text.
        """
        reflection_prompt = REFLECT_PROMPT_V1.format(
            puzzle=PUZZLE,
            partial_prolog_facts=PARTIAL_FACTS_FOR_LLM,
            latest_failed_trajectory=context,
            final_state=state,
        )
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": reflection_prompt},
                {"role": "user", "content": "Your reflection is?"},
            ],
        )
        return response.choices[0].message.content


def evaluator_impl() -> Evaluator:
    """Factory returning a new Evaluator instance."""
    return Evaluator()

In [6]:
# Instantiate the three LLM-backed components used by the search loop below.
generator = generator_impl()
executor = executor_impl()
evaluator = evaluator_impl()

In [7]:
# Number of candidate questions (children) generated when a node is expanded.
number_of_generated_actions = 3
# Maximum depth of a rollout, i.e. how many questions one trajectory contains.
depth_limit = (
    # Gods puzzle limits us to only consider 3 questions that we can ask gods.
    3
)
# Number of rollouts started from the root before the search gives up.
number_of_rollouts = number_of_generated_actions * 2  # so reflection takes place
# Multiplier of the exploration term in Node.uct.
exploration_weight = np.sqrt(2)
# Share of a new child's value taken from the evaluator's score; the remaining
# (1 - value_weight) comes from the self-consistency score.
value_weight = 0.7

In [None]:
class Node:
    """A node of the search tree: one asked question together with its answer.

    Attributes are set in ``__init__``; ``value`` and ``visit_count`` are the
    statistics used for selection, ``self_reflection`` carries the latest
    reflection text and ``reward`` is set to 1 on a node that solved the puzzle.
    """

    def __init__(
        self,
        iteration: int,
        state: str,
        context="",
        parent=None,
        depth=0,
        self_reflection="",
    ):
        self.id = uuid.uuid4()
        self.iteration = iteration
        self.state = state
        self.parent = parent
        self.context = context
        self.depth = depth
        self.children = []
        self.value = 0.0
        self.visit_count = 0
        self.self_reflection = self_reflection
        self.reward = 0

    def to_dict(self):
        """Return a JSON-serialisable dict of this node and, recursively, its children."""
        return {
            "id": self.id.hex,  # for debugging
            "iteration": self.iteration,
            "state": self.state,
            "parent": self.parent.state if self.parent else None,
            "context": self.context,
            "depth": self.depth,
            "value": round(self.value, 2),
            "visit_count": self.visit_count,
            "self_reflection": self.self_reflection,
            "reward": self.reward,
            "children": [child.to_dict() for child in self.children],
        }

    def best_child_uct(self):
        """Return the child with the highest UCT score, or None if there are no children."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_utc(self):
        """Backward-compatible alias (original, misspelled name) of ``best_child_uct``."""
        return self.best_child_uct()

    def uct(self):
        """Return this node's UCT score: its value plus an exploration bonus.

        Unvisited nodes score infinity. A node without a parent has no
        exploration term and scores just its value.
        """
        if self.visit_count == 0:
            # open question: maybe it shouldn't be inf, in this case the non explored nodes will be preferred?
            return float("inf")

        if self.parent is None:
            # Previously this raised AttributeError on `None.visit_count`.
            return self.value

        # log(0) is -inf and its square root NaN, which would poison max();
        # treat a not-yet-visited parent as visited once (log(1) == 0).
        parent_visits = max(self.parent.visit_count, 1)
        return self.value + exploration_weight * np.sqrt(
            np.log(parent_visits) / self.visit_count
        )

    def update(self, value):
        """Count one visit and add ``value`` to the node's value.

        Note that this accumulates (it does not average), unlike
        ``backprop_update``.
        """
        self.visit_count += 1
        self.value += value

    def backprop_update(self, is_solved):
        """Count one visit and fold ``is_solved`` into the running mean value.

        ``is_solved`` may be a bool; it takes part in the arithmetic as 0 or 1.
        """
        self.visit_count += 1
        visits = self.visit_count
        self.value = (self.value * (visits - 1) + is_solved) / visits

    def parent_states_as_string(self):
        """Join the states on the path down to this node, one per line.

        The path starts below the parentless root (whose state is skipped) and
        ends with this node's own state.
        """
        node = self
        states = []

        while node.parent is not None:
            if "No predicates asked yet" not in node.state:
                states.append(node.state)
            node = node.parent

        return "\n".join(reversed(states))


def generate_with_retry(
    generator,
    context,
    reflection,
    state,
    latest_failed_trajectory,
    retries=5,
):
    """Generate a Prolog query and run it, retrying until one yields a result.

    Each attempt asks ``generator`` for a query and evaluates it with Prolog.
    Attempts whose query raises, has no solutions, or whose first solution is
    empty are discarded and a new query is generated.

    Args:
        generator: Object with a ``generate(context=, reflection=, state=,
            latest_failed_trajectory=)`` method returning a query string.
        context, reflection, state, latest_failed_trajectory: Passed through
            to ``generator.generate`` unchanged.
        retries: Maximum number of attempts.

    Returns:
        A ``(query, result_json)`` tuple: the accepted query string and its
        first solution serialised as JSON.

    Raises:
        Exception: if no attempt produced a usable result.
    """
    for _ in range(retries):
        new_state = generator.generate(
            context=context,
            reflection=reflection,
            state=state,
            latest_failed_trajectory=latest_failed_trajectory,
        )

        print("trying state: ", new_state)

        try:
            solutions = list(
                Prolog.query(new_state, catcherrors=True, normalize=True)
            )
        except Exception as error:
            # The query text comes from the LLM and may be malformed; treat a
            # Prolog failure as a failed attempt instead of aborting the search.
            print("prolog query failed: ", error)
            continue

        # A query without solutions used to raise IndexError on `[0]`, which
        # bypassed the retry loop entirely.
        prolog_eval_result = solutions[0] if solutions else None

        if prolog_eval_result:
            return new_state, json.dumps(prolog_eval_result)

    raise Exception(f"Failed to generate prolog query after {retries} retries")

In [None]:
# Main search loop: repeated rollouts from the root. Each rollout walks
# depth_limit levels down the tree (expanding leaves on the way), asks the
# executor for the gods' identities, checks them with Prolog and then
# backpropagates the outcome and the reflection along the visited path.
print("******** Starting ********")
iteration = 0
root = Node(state="No predicates asked yet", iteration=iteration)
# Start the root at one visit so log(parent.visit_count) in Node.uct is 0
# rather than log(0) for the root's children.
root.visit_count = 1
# Nodes of the most recent failed rollout; passed to the generator as a hint.
latest_failed_trajectory = []


for _ in range(number_of_rollouts):
    node = root
    trajectory = [root]

    for _ in range(depth_limit):
        # Expansion: a node without children gets number_of_generated_actions
        # freshly generated and scored children.
        if len(node.children) == 0:
            for _ in range(number_of_generated_actions):
                new_state, prolog_eval_result = generate_with_retry(
                    generator,
                    context=node.context,
                    reflection=node.self_reflection,
                    state=node.state,
                    latest_failed_trajectory=latest_failed_trajectory,
                )

                # The child's context is the path of states down to (and
                # including) the node being expanded.
                context = node.parent_states_as_string()
                iteration += 1

                # The child's state records both the query and its JSON result.
                child = Node(
                    state=new_state + " => " + prolog_eval_result,
                    parent=node,
                    context=context,
                    depth=node.depth + 1,
                    self_reflection=node.self_reflection,
                    iteration=iteration,
                )

                # Initial value: weighted mix of the evaluator's score and the
                # self-consistency score.
                lm_score = evaluator.evaluate(context=child.context, state=child.state)
                sc_score = evaluator.self_consistency(
                    context=child.context,
                    reflection=child.self_reflection,
                    state=child.state,
                    latest_failed_trajectory=latest_failed_trajectory,
                )
                value = value_weight * lm_score + (1 - value_weight) * sc_score

                # Also bumps visit_count to 1, so new children never take the
                # "unvisited => inf" branch of Node.uct.
                child.update(value)

                node.children.append(child)
                print("******** child ********")
                pp(child.to_dict())

        # select best child
        if len(node.children) > 0:
            node = node.best_child_utc()
            print("******** Selected child ********")
            pp(node.to_dict())

        trajectory.append(node)

    # execute
    gods = executor.execute(context=node.context, state=node.state)
    god_predicates = god_schema_to_predicate(gods)

    # The rollout is solved only if every predicted god/2 goal has at least one
    # Prolog solution (presumably the consulted file holds the ground truth —
    # TODO confirm).
    is_solved = all(
        [
            len([result for result in Prolog.query(predicate)]) > 0
            for predicate in god_predicates
        ]
    )

    if is_solved:
        node.reward = 1
    else:
        # Remember the failed path and ask the evaluator to reflect on it.
        latest_failed_trajectory = trajectory
        node.self_reflection = evaluator.reflect(context=node.context, state=node.state)

    reflection = node.self_reflection
    print("******** Reflection ********")
    print(reflection)

    # backpropagate
    for rev_node in reversed(trajectory):
        print("******** Backpropagating ********")
        # is_solved is a bool; backprop_update folds it into the mean as 0/1.
        rev_node.backprop_update(is_solved)
        print("updated value of", rev_node.id, rev_node.value)

        # `bubble-up` reflection
        rev_node.self_reflection = reflection
        print("reflection assigned to ", rev_node.id, rev_node.state)

    if is_solved:
        print("******** Final solution found ********")

        # NOTE: this loop rebinds the name `node`; harmless because the search
        # stops right after it.
        for node in reversed(trajectory):
            print("******** Node ********")
            pp(node.to_dict())

        break

******** Starting ********
trying state:  meta_ask(a, (b = random_god), Answer).
******** child ********
{'id': '3178df46471f429ab798a9a443674681',
 'iteration': 1,
 'state': 'meta_ask(a, (b = random_god), Answer). => {"Answer": "da"}',
 'parent': 'No predicates asked yet',
 'context': '',
 'depth': 1,
 'value': 0.62,
 'visit_count': 1,
 'self_reflection': '',
 'reward': 0,
 'children': []}
trying state:  meta_ask(a, (b = true_god), Answer).
******** child ********
{'id': '6856006fe1ae493c8a86f39a3b569cb9',
 'iteration': 2,
 'state': 'meta_ask(a, (b = true_god), Answer). => {"Answer": "da"}',
 'parent': 'No predicates asked yet',
 'context': '',
 'depth': 1,
 'value': 0.6799999999999999,
 'visit_count': 1,
 'self_reflection': '',
 'reward': 0,
 'children': []}
trying state:  meta_ask(a, (b = random_god), Answer).
******** child ********
{'id': 'b443706dda14433089ef2a3433fc3819',
 'iteration': 3,
 'state': 'meta_ask(a, (b = random_god), Answer). => {"Answer": "da"}',
 'parent': 'No pred

In [None]:
# Dump the explored search tree to lats_three_gods_tree.json. While the flag
# is True the file is (re)written on every run; set it to False to skip the
# write and leave an existing file untouched.
generate_json_tree = True
if generate_json_tree:
    with open("lats_three_gods_tree.json", "w") as tree_file:
        tree_json = json.dumps(root.to_dict(), indent=2)
        tree_file.write(tree_json)