In [1]:
cd ..

/home/ziyu/code/LLMs/tot-llm/scripts


In [2]:
import json
from tot.prompts.crosswords import propose_prompt, value_prompt
from tot.models import gpt
from tot.tasks.crosswords import MiniCrosswordsEnv

env = MiniCrosswordsEnv()

In [3]:
def prompt_wrap(obs):
    return propose_prompt.format(input=obs)

print(prompt_wrap(env.reset(0)))
# print('---------')
# print(prompt_wrap(env.step('h2. value')[0]))

Let's play a 5 x 5 mini crossword, where each word should have exactly 5 letters.

Current Board:
_____
_____
_____
_____
_____

Unfilled:
h1. An agendum; something to be done: _____
h2. An engine: _____
h3. Pretentious; flowery: _____
h4. A salon; a hall: _____
h5. To mock; to sneer: _____
v1. To heap: _____
v2. An Indian antelope: _____
v3. To intend; to plan; to devise; a nettle; to guess: _____
v4. A nozzle: _____
v5. Desiccator; more dry: _____

Filled:

Changed:


Given the current status, list all possible answers for unfilled or changed words, and your confidence levels (certain/high/medium/low), using the format "h1. apple (medium)". Use "certain" cautiously and only when you are 100% sure this is the correct word. You can list more then one possible answer for each word.



In [4]:
import re
import copy
from tot.models import gpt

def parse_line(input_str):
    pattern = r'^([hv][1-5])\. ([a-zA-Z]{5,5}) \((certain|high|medium|low)\).*$'

    match = re.match(pattern, input_str)

    if match:
        parts = [match.group(1), match.group(2), match.group(3)]
        return parts
    else:
        return None

confidence_to_value = {'certain': 1, 'high': 0.5, 'medium': 0.2, 'low': 0.1}  # TODO: ad hoc

def parse_response(response):
    lines = response.split('\n')
    parsed_lines = [parse_line(line) for line in lines]

    parsed_lines = [(line[0].lower() + '. ' + line[1].lower(), confidence_to_value.get(line[2], 0)) for line in parsed_lines if line is not None]

    return parsed_lines if len(parsed_lines) >= 1 else None


def get_candidates_to_scores(env):
    obs = env.render()
    if obs in env.cache: 
        print('cache hit')
        return env.cache[obs]
    print('call gpt')
    responses = gpt(prompt_wrap(obs), model='gpt-4', n=8)
    candidates_to_scores = {}
    for response in responses:
        parsed_response = parse_response(response)
        if parsed_response:
            for candidate, score in parsed_response:
                candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score
        # choose candiate with highest score
    # print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))
    env.cache[obs] = candidates_to_scores
    return candidates_to_scores

def propose_score(env, idx):
    obs = env.reset(idx)
    done = False
    infos = []
    while not done:
        responses = gpt(prompt_wrap(obs), model='gpt-4', n=5)
        candidates_to_scores = {}
        for response in responses:
            parsed_response = parse_response(response)
            if parsed_response:
                for candidate, score in parsed_response:
                    candidates_to_scores[candidate] = candidates_to_scores.get(candidate, 0) + score
        # choose candiate with highest score
        print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))
        if len(candidates_to_scores) == 0:
            break
        candidates =  sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True)
        for candidate in candidates:
            env_ = copy.deepcopy(env)
            env_.step(candidate)
            if not any(_ == 2 for _ in env_.status):
                break
        print(candidate)
        # candidate = input()
        obs, r, done, info = env.step(candidate)
        print(obs)
        print(env.steps, info)
        print('-------------------\n\n\n')
        infos.append(info)
    return infos

In [13]:
def mcts(env, num_simulations, max_depth):
    root_state = env.render()  # Initial state of the environment

    # Create the root node of the MCTS tree
    root_node = MCTSNode(state=root_state)

    # MCTS lppo
    for _ in range(num_simulations):
        selected_node = root_node
        while not selected_node.is_terminal() and selected_node.is_fully_expanded():
            selected_node = selected_node.select_child()

        if not selected_node.is_terminal():
            new_state = selected_node.expand()
            new_node = MCTSNode(state=new_state, parent=selected_node)
            selected_node.add_child(new_node)
            selected_node = new_node

        if not selected_node.is_terminal():
            result = selected_node.simulate(max_depth)
        selected_node.backpropagate(result)

    best_action = root_node.get_best_action()
    return best_action

In [17]:
import random


class MCTSNode:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0

    def is_fully_expanded(self):
        return len(self.children) == len(self.get_all_actions())

    def is_terminal(self):
        return self.state.is_done()  

    def expand(self):
        untried_actions = [a for a in self.get_all_actions() if a not in self.get_child_actions()]
        if not untried_actions:
            return None
        action = random.choice(untried_actions)
        new_state = self.state.step(action)  
        return new_state

    def get_all_actions(self):
        return self.state.get_all_actions()  

    def get_child_actions(self):
        return [node.state for node in self.children]

    def select_child(self):
        exploration_param = 1.0
        return max(self.children, key=lambda node: node.value / (node.visits + 1e-6) + exploration_param * (self.visits ** 0.5) / (node.visits + 1e-6))

    def simulate(self, max_depth):
        current_depth = 0
        state = self.state
        while not state.is_done() and current_depth < max_depth:
            action = random.choice(state.get_all_actions())
            state = state.step(action)  
            current_depth += 1
        return state.get_reward()  

    def backpropagate(self, result):
        node = self
        while node is not None:
            node.visits += 1
            node.value += result
            node = node.parent

    def add_child(self, node):
        self.children.append(node)

    def get_best_action(self):
        best_child = max(self.children, key=lambda node: node.visits)
        return self.get_action_to_reach_child(best_child)

    def get_action_to_reach_child(self, child):
        return [action for action in self.get_all_actions() if self.state.step(action) == child.state][0]

class MiniCrosswordsMCTSTask(MiniCrosswordsTask):
    def __init__(self, file):
        super().__init__(file=file)
    
    # def reset(self, idx):
    #     super.env.reset(idx)
    
    def mcts(self, num_simulations, max_depth):
        actions = []
        infos = []
        num_simulations = 100
        max_depth = 100  
        max_per_state = 3
        
        root_state = self.env.render()  

        # root_node = MCTSNode(state=root_state)

        # for _ in range(num_simulations):
        #     selected_node = root_node
        #     while not selected_node.is_terminal() and selected_node.is_fully_expanded():
        #         selected_node = selected_node.select_child()

        #     if not selected_node.is_terminal():
        #         new_state = selected_node.expand()
        #         if new_state is not None:
        #             new_node = MCTSNode(state=new_state, parent=selected_node)
        #             selected_node.add_child(new_node)
        #             selected_node = new_node

        #     if not selected_node.is_terminal():
        #         result = selected_node.simulate(max_depth)

        #     selected_node.backpropagate(result)

        # # Return the best action based on the most visited child node
        # best_action = root_node.get_best_action()
        # return best_action

    def evaluate(self, x: str, y: str, n_evaluate_sample: int) -> int:
        self.set_status(x, y)
        assert n_evaluate_sample == 1 
        count = {'sure': 0, 'maybe': 0, 'impossible': 0}
        for ans, data, status in zip(self.env.ans, self.env.data, self.env.status):
            if ans.count('_') >= 4:
                continue
            ans = ' '.join(ans.lower())
            line = f'{data}: {ans}'
            prompt = propose_prompt.format(input=line)
            res = gpt(prompt, model='gpt-4', n=8)[0]
            print(line)
            print(res)
            print()
            proposals = self.propose_outputs_unwrap(x, y, [res], n_max_propose=-1)
            res = proposals[0].split(' ')[1].strip()
            if res in count:
                count[res] += 1
        print(count)
        return count

    def set_status(self, x: str, y: str):
        super().set_status(x, y)

    def reset(self, idx):
        self.env.reset(idx) 

def mcts1(env, actions, infos, num_simulations, max_depth, max_per_state):
    # get candidate thoughts using MCTS
    action = mcts(env, num_simulations, max_depth)

    obs, r, done, info = env.step(action)
    r = info['r_word']
    if len(infos) < time_limit and env.steps < 10 and not any(_ == 2 for _ in env.status):  # not violating any existing constraints
        cnt_per_state += 1
        if cnt_per_state > max_per_state:
            return

        count = env.prompt_status()
        actions.append(action)

        print(len(infos))
        print(actions)
        print(env.render_board())
        print(info)
        print(count)
        if infos:
            best = max(infos, key=lambda x: x['info']['r_word'])
            print('best', best)
        print('--------------')
        print()

        info = {'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count}
        infos.append(info)
        if not prune or count['impossible'] < 1:  # only continue if the current status is possible
            mcts1(env, actions, infos, num_simulations, max_depth, max_per_state)
        actions.pop()
    env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)

# Crosswords
infoss = []
for i in range(0, 100, 5):
    task = MiniCrosswordsMCTSTask(file='/home/ziyu/code/LLMs/tot-llm/src/tot/data/crosswords/mini0505.json')  # Replace 'path_to_csv_file' with the actual file path
    task.reset(i)
    infos = []
    actions = []
    num_simulations = 100
    max_depth = 100  # Adjust this value as per your requirements
    max_per_state = 3
    task.mcts(actions, infos, num_simulations, max_depth, max_per_state)
    infoss.append(infos)
    with open('logs/crosswords/infoss_mcts.json', 'w') as fout:
        json.dump(infoss, fout)


TypeError: mcts() takes 3 positional arguments but 6 were given