In [42]:
import ast
import json
import re
from pathlib import Path

In [2]:
src = "/home/ubuntu/ohm-tree-filesys-az/plasma-converter"
run_name = "fast_mcts_full_100_2024-10-29_00-38-36"

In [3]:
problems = []

In [4]:
with open("/home/ubuntu/ohm-tree-filesys-az/plasma-converter/datasets/minif2f.jsonl", 'r') as file:
    for line in file.readlines():
        problem = json.loads(line.strip())
        problems.append(problem)

In [5]:
problems[:3]

[{'name': 'amc12a_2019_p21',
  'split': 'valid',
  'informal_prefix': '/-- Let $z=\\frac{1+i}{\\sqrt{2}}.$What is $\\left(z^{1^2}+z^{2^2}+z^{3^2}+\\dots+z^{{12}^2}\\right) \\cdot \\left(\\frac{1}{z^{1^2}}+\\frac{1}{z^{2^2}}+\\frac{1}{z^{3^2}}+\\dots+\\frac{1}{z^{{12}^2}}\\right)?$\n\n$\\textbf{(A) } 18 \\qquad \\textbf{(B) } 72-36\\sqrt2 \\qquad \\textbf{(C) } 36 \\qquad \\textbf{(D) } 72 \\qquad \\textbf{(E) } 72+36\\sqrt2$ Show that it is \\textbf{(C) }36.-/\n',
  'formal_statement': 'theorem amc12a_2019_p21 (z : ℂ) (h₀ : z = (1 + Complex.I) / Real.sqrt 2) :\n  ((∑ k : ℤ in Finset.Icc 1 12, z ^ k ^ 2) * (∑ k : ℤ in Finset.Icc 1 12, 1 / z ^ k ^ 2)) = 36 := by\n',
  'goal': 'z : ℂ\nh₀ : z = (1 + Complex.I) / ↑√2\n⊢ (∑ k ∈ Finset.Icc 1 12, z ^ k ^ 2) * ∑ k ∈ Finset.Icc 1 12, 1 / z ^ k ^ 2 = 36',
  'header': 'import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n'},
 {'name': 'amc12a_2015_p10',
  'split': 'valid',
  'informal_prefix': '/

In [6]:
solved_problems = []

for file_path in Path(src + "/results/" + run_name).iterdir():
    if file_path.is_file() and file_path.name.endswith(".txt"):
        with open(file_path, 'r') as file:
            lines = file.readlines()
            if lines[-1] == "Result: 1.0\n":
                # append the name of the file without extension
                solved_problems.append(file_path.stem)

In [7]:
solved_problems[:10]

['mathd_numbertheory_236',
 'mathd_numbertheory_188',
 'mathd_numbertheory_132',
 'mathd_numbertheory_24',
 'mathd_numbertheory_200',
 'mathd_numbertheory_45',
 'mathd_algebra_28',
 'mathd_numbertheory_466',
 'amc12_2001_p9',
 'mathd_numbertheory_169']

In [8]:
len(solved_problems)

274

In [9]:
problems_dict = {problem["name"]: problem for problem in problems}

In [13]:
solved_valid = [problem for problem in solved_problems if problems_dict[problem]["split"]=="valid"]
solved_test = [problem for problem in solved_problems if problems_dict[problem]["split"]=="test"]

In [131]:
def parse_tree_txt_to_dict(file_path):
    # Read the file contents
    with open(file_path, 'r') as file:
        content = file.read()

    # Split the content by double dashed lines
    sections = [sec.strip() for sec in content.split('--------------------------------------------------------------------------------\n' * 2) if sec.strip()]

    # print("section 0", sections[0])
    # Initialize a list to store parsed dictionaries
    parsed_data = []
    kv_pattern = re.compile(r"^\s*(\w.*?)\s*\|\s*(.*)$")

    # Define regex patterns to parse key-value pairs and LeanState blocks

    for section in sections:
        lines = section.splitlines()
        result = {}
        
        
        for line in lines:
            kv_match = kv_pattern.match(line)
            # Match lines with `key | value` format
            if line.startswith("LeanState"):
                lean_state_pattern = re.compile(
                    r"code=['\"](.*?)['\"],\s*"
                    r'depth=(\d+),\s*'
                    r"tactic_state=['\"](.*?)['\"],\s*"
                    r'dead=(True|False)\s*$',
                    re.DOTALL
                )
                lean_state_match = lean_state_pattern.search(line[len("LeanState("):-1])
                # print(line[len("LeanState("):-1])
                
                if lean_state_match:
                    code = lean_state_match.group(1)
                    depth = lean_state_match.group(2)
                    tactic_state = lean_state_match.group(3)
                    dead = lean_state_match.group(4)

                    code = code.replace('\\n', '\n')
                    tactic_state = tactic_state.replace('\\n', '\n')
                    
                    result['LeanState'] = {
                        'code': code,
                        'depth': int(depth),
                        'tactic_state': tactic_state,
                        'dead': dead == 'True'
                    }
            elif kv_match:
                key, value = kv_match.groups()
                key = key.strip()  # Remove extra spaces around the key
                value = value.strip()  # Remove extra spaces around the value
                
                # Convert value to appropriate type
                if value.startswith("[") and value.endswith("]"):  # Detect arrays
                    # Convert to a list of floats or ints
                    value = [float(v) if '.' in v else int(v) for v in re.findall(r"[-+]?\d*\.\d+|\d+", value)]
                elif value.startswith("(") and value.endswith(")"):  # Detect tuples
                    value = ast.literal_eval(value)
                elif value.isdigit():  # Convert to integer if it's a digit
                    value = int(value)
                else:
                    try:
                        value = float(value)  # Try converting to float
                    except ValueError:
                        pass  # Keep as string if not a float
                
                # Add to the dictionary
                result[key] = value
        
        # Append the parsed dictionary to the list
        parsed_data.append(result)

    return parsed_data

In [132]:
# TEST above
file_path = "/home/ubuntu/ohm-tree-filesys-az/plasma-converter/outputs/fast_mcts_full_100_2024-10-29_00-38-36/numbertheory_sqmod4in01d_tree.txt"
data = parse_tree_txt_to_dict(file_path)

# Display the parsed data (TEST)
for idx, section in enumerate(data):
    print(f"Section {idx + 1}:")
    for key, value in section.items():
        print(f"  {key}: {value}")

    print()

Section 1:
  depth: 0
  index_path: ()
  number_visits: 11
  hash: 7566622256351434870
  total_value: -11.0
  child priors: [1, 1, 1, 1]
  child number visits: [4, 4, 1, 1]
  child total value: [4, 0, 0, 0]
  LeanState: {'code': '', 'depth': 0, 'tactic_state': 'a : ℤ\n⊢ a ^ 2 % 4 = 0 ∨ a ^ 2 % 4 = 1', 'dead': False}

Section 2:
  depth: 1
  index_path: (0,)
  number_visits: 4
  hash: -8.60438899511822e+18
  total_value: 4.0
  child priors: []
  child number visits: []
  child total value: []
  LeanState: {'code': '  have h : a ^ 2 % 4 = 0 ∨ a ^ 2 % 4 = 1 := by\n    have h₁ : a % 4 = 0 ∨ a % 4 = 1 ∨ a % 4 = 2 ∨ a % 4 = 3 := by omega\n    rcases h₁ with (h₁ | h₁ | h₁ | h₁) <;>\n    simp [h₁, pow_two, Int.mul_emod, Int.add_emod]\n  exact h\n', 'depth': 1, 'tactic_state': '', 'dead': False}

Section 3:
  depth: 1
  index_path: (1,)
  number_visits: 4
  hash: -6.89692017069033e+18
  total_value: 0.0
  child priors: [1, 1, 1, 1]
  child number visits: [1, 1, 1, 0]
  child total value: [0, 0,

In [136]:
save_dir = src + "/knowledge/tactic_states.txt"

for problem in solved_valid:
    # if problem!='mathd_numbertheory_236':
    #     continue
    lean_code = None
    with open(src + "/outputs/" + run_name + "/" + problem + ".txt", 'r') as file:
        lines = file.readlines()
        last_line = lines[-1]
        start_idx = last_line.find("code=") + len("code=")+1
        end_idx = last_line.find(", depth")
        lean_code = last_line[start_idx:end_idx-1].replace("\\n", "\n")

    tree_data  = parse_tree_txt_to_dict(src + "/outputs/" + run_name + "/" + problem + "_tree.txt")
    # print(tree_data)
    final_section = None
    for section in tree_data:
        # print("section code", section['LeanState']['code'])
        # print("lean_code", lean_code)
        if 'LeanState' not in section:
            print(f"Problem {problem} has no LeanState")
            for section in tree_data:
                print(section)
            break
        if section['LeanState']['code'] == lean_code:
            final_section = section
            break
    if final_section is None:
        print("Problem not found", problem)
        print("lean_code", lean_code)
        print("tree_data", tree_data)
        for section in tree_data:
            print(section)
        print("die")
        break

    tactic_states = []
    idx_path = final_section['index_path']
    print("path", idx_path)
    for section in tree_data:
        if len(section['index_path'])<=len(idx_path) and idx_path[:len(section['index_path'])]==section['index_path']:
            tactic_states.append(section['LeanState']['tactic_state'])
    # print(final_section)
    print(tactic_states)
    # save each string into save_dir, creating the file if it doesn't exist
    with open(save_dir, 'a') as file:
        for tactic_state in tactic_states:
            if tactic_state != "":
                file.write(repr(tactic_state)[1:-1])
                file.write("\n")

path (1,)
['⊢ 1999 ^ 2000 % 5 = 1', '']
path (2,)
['⊢ Nat.gcd 180 168 = 12', '']
path (1,)
['⊢ 2004 % 12 = 0', '']
path (2, 0)
['⊢ (∑ k ∈ Finset.Icc 1 9, 11 ^ k) % 100 = 59', '⊢ (∑ k ∈ LocallyFiniteOrder.finsetIcc 1 9, 11 ^ k) % 100 = 59', '']
path (1, 0)
['⊢ 139 % 11 = 7', 'case h\n⊢ 7 = 139 % 11', '']
path (0,)
['⊢ Nat.gcd 6432 132 + 11 = 23', '']
path (0,)
['c : ℝ\nf : ℝ → ℝ\nh₀ : ∀ (x : ℝ), f x = 2 * x ^ 2 + 5 * x + c\nh₁ : ∃ x, f x ≤ 0\n⊢ c ≤ 25 / 8', '']
path (0, 0)
['⊢ (∑ k ∈ Finset.range 11, k) % 9 = 1', '⊢ (∑ k ∈ Finset.range 0,\n                                                                                                                      k).succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ.succ %\n      9 =\n    1', '']
path (0,)
['f : ℝ → ℝ\nh₀ : ∀ x >

Embed them

In [143]:
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from transformers import AutoModelForTextEncoding, AutoTokenizer

In [138]:
tokenizer = AutoTokenizer.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")
model = AutoModelForTextEncoding.from_pretrained("kaiyuy/leandojo-lean4-retriever-byt5-small")

In [139]:
@torch.no_grad()
def encode(s: Union[str, List[str]]) -> torch.Tensor:
    """Encode texts into feature vectors."""
    if isinstance(s, str):
        s = [s]
        should_squeeze = True
    else:
        should_squeeze = False
    tokenized_s = tokenizer(s, return_tensors="pt", padding=True)
    hidden_state = model(tokenized_s.input_ids).last_hidden_state
    lens = tokenized_s.attention_mask.sum(dim=1)
    features = (hidden_state * tokenized_s.attention_mask.unsqueeze(2)).sum(dim=1) / lens.unsqueeze(1)
    if should_squeeze:
      features = features.squeeze()
    return features

@torch.no_grad()
def retrieve(state: str, premises: List[str], k: int) -> List[str]:
    """Retrieve the top-k premises given a state."""
    state_emb = encode(state)
    premise_embs = encode(premises)
    scores = (state_emb @ premise_embs.T)
    topk = scores.topk(k).indices.tolist()
    return [premises[i] for i in topk]

In [140]:
# load each line of save_dir into list known_tactic_states
with open(save_dir, 'r') as file:
    known_tactic_states = [line.strip() for line in file.readlines()]

In [141]:
known_tactic_states[:3]

['⊢ 1999 ^ 2000 % 5 = 1', '⊢ Nat.gcd 180 168 = 12', '⊢ 2004 % 12 = 0']

In [142]:
embeddings = encode(known_tactic_states)

In [144]:
emb_save_dir = src + "/knowledge/tactic_states_embeddings.npy"
# save embeddings to emb_save_dir
np.save(emb_save_dir, embeddings.cpu().numpy())