Constructing LeanDojo Benchmark (Lean 4)
===================================

This script uses [LeanDojo](https://leandojo.org/) to construct LeanDojo Benchmark 4, which is described in the appendix of our paper:

[LeanDojo: Theorem Proving with Retrieval-Augmented Language Models](https://leandojo.org/)      
Under review at NeurIPS (Datasets and Benchmarks Track), 2023  
[Kaiyu Yang](https://yangky11.github.io/), [Aidan Swope](https://aidanswope.com/about), [Alex Gu](https://minimario.github.io/), [Rahul Chalamala](https://rchalamala.github.io/), [Peiyang Song](https://www.linkedin.com/in/peiyang-song-3279b3251/), [Shixing Yu](https://billysx.github.io/), [Saad Godil](https://www.linkedin.com/in/saad-godil-9728353/), [Ryan Prenger](https://www.linkedin.com/in/ryan-prenger-18797ba1/), [Anima Anandkumar](http://tensorlab.cms.caltech.edu/users/anima/)

The dataset is constructed from [mathlib4](https://github.com/leanprover-community/mathlib4/tree/5a919533f110b7d76410134a237ee374f24eaaad) (`5a919533f110b7d76410134a237ee374f24eaaad`) and will be saved to `../leandojo_benchmark_4`. It includes 2000 theorems for validation, 2000 theorems for testing, and the rest for training. Please refer to our paper for details. For most use cases, you shouldn't need to generate the data and can directly use our official LeanDojo Benchmark 4 downloadable [here](https://zenodo.org/record/8040110).

This script is for Lean 4. We also have a more detailed [version for Lean 3](https://github.com/lean-dojo/LeanDojo/blob/main/scripts/generate-benchmark-lean3.ipynb).


In [4]:
import json
import shutil
import random
from copy import copy
from pathlib import Path
from loguru import logger
from datetime import datetime
from typing import Dict, List, Union

import lean_dojo
from lean_dojo import *

random.seed(3407)  # https://arxiv.org/abs/2109.08203

URL = "https://github.com/leanprover-community/mathlib4"
COMMIT = "5a919533f110b7d76410134a237ee374f24eaaad"
DST_DIR = Path("../leandojo_benchmark_4")
NUM_VAL = NUM_TEST = 2000
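
Seeding the module-level generator is what makes the random split below repeatable from run to run. A minimal sketch of that effect, using plain integers as stand-ins for theorems; the helper `shuffled_with_seed` is hypothetical and not part of LeanDojo:

```python
import random
from copy import copy


def shuffled_with_seed(items, seed):
    """Return a shuffled copy of ``items`` after seeding the global RNG."""
    random.seed(seed)
    items = copy(items)  # leave the caller's list untouched
    random.shuffle(items)
    return items


theorems = list(range(10))  # integers standing in for traced theorems
first = shuffled_with_seed(theorems, 3407)
second = shuffled_with_seed(theorems, 3407)
print(first == second)  # the same seed gives the same order
```

Reproducing an identical split also presumes that the theorems arrive in the same order before shuffling, which this sketch does not check.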

In [6]:
SPLIT_NAME = str  # train/val/test
SPLIT = Dict[SPLIT_NAME, List[TracedTheorem]]
SPLIT_STRATEGY = str

In [7]:
def _split_sequentially(
    traced_theorems: List[TracedTheorem],
) -> SPLIT:
    """Split ``traced_theorems`` sequentially into train/val/test."""
    num_theorems = len(traced_theorems)
    num_train = num_theorems - NUM_VAL - NUM_TEST
    return {
        "train": traced_theorems[:num_train],
        "val": traced_theorems[num_train : num_train + NUM_VAL],
        "test": traced_theorems[num_train + NUM_VAL :],
    }


def split_randomly(
    traced_theorems: List[TracedTheorem],
) -> SPLIT:
    """Split ``traced_theorems`` randomly into train/val/test."""
    logger.info("Splitting the theorems randomly")
    traced_theorems = copy(traced_theorems)
    random.shuffle(traced_theorems)
    return _split_sequentially(traced_theorems)
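
The slicing arithmetic in `_split_sequentially` may be easier to follow on a toy input. The sketch below reuses the same index logic on a list of ten integers, with `NUM_VAL` and `NUM_TEST` shrunk to 2 purely for illustration (the notebook itself uses 2000):

```python
from typing import Dict, List

NUM_VAL = NUM_TEST = 2  # tiny values for illustration; the notebook uses 2000


def split_sequentially(items: List[int]) -> Dict[str, List[int]]:
    """Reserve the last NUM_VAL + NUM_TEST items for val/test; the rest is train."""
    num_train = len(items) - NUM_VAL - NUM_TEST
    return {
        "train": items[:num_train],
        "val": items[num_train : num_train + NUM_VAL],
        "test": items[num_train + NUM_VAL :],
    }


splits = split_sequentially(list(range(10)))
for name, part in splits.items():
    print(name, part)
```

If the input has fewer than `NUM_VAL + NUM_TEST` items, `num_train` becomes negative and the three slices no longer form a partition, so a guard may be worth adding when adapting this to a small repo.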

In [8]:
def split_data(traced_repo: TracedRepo) -> Dict[SPLIT_STRATEGY, SPLIT]:
    """Split the theorems in ``traced_repo`` using each supported strategy."""
    traced_theorems = traced_repo.get_traced_theorems()
    logger.info(f"{len(traced_theorems)} theorems in total")

    return {
        "random": split_randomly(traced_theorems),
    }

In [9]:
def export_data(
    traced_repo: TracedRepo,
    splits: Dict[SPLIT_STRATEGY, SPLIT],
    dst_path: Union[str, Path],
    **kwargs,
) -> None:
    """Export a traced repo whose theorems have been split to ``dst_path``."""
    if isinstance(dst_path, str):
        dst_path = Path(dst_path)
    if dst_path.exists():
        logger.warning(f"{dst_path} already exists. Removing it now.")
        shutil.rmtree(dst_path)

    # Export the proofs.
    for strategy, split in splits.items():
        split_dir = dst_path / strategy
        split_dir.mkdir(parents=True)
        for name, theorems in split.items():
            data = []
            num_tactics = 0
            for thm in theorems:
                tactics = [
                    {
                        "tactic": t.tactic,
                        # "annotated_tactic": t.get_annotated_tactic(),
                        "state_before": t.state_before,
                        "state_after": t.state_after,
                    }
                    for t in thm.get_traced_tactics()
                    if t.state_before != "no goals"
                    and "·" not in t.tactic  # Ignore "·".
                ]
                num_tactics += len(tactics)
                data.append(
                    {
                        "url": thm.repo.url,
                        "commit": thm.repo.commit,
                        "file_path": str(thm.theorem.file_path),
                        "full_name": thm.theorem.full_name,
                        "start": list(thm.start),
                        "end": list(thm.end),
                        "traced_tactics": tactics,
                    }
                )
            oup_path = split_dir / f"{name}.json"
            # Use a context manager so the file is flushed and closed promptly.
            with oup_path.open("wt") as oup:
                json.dump(data, oup)
            logger.info(
                f"{len(theorems)} theorems and {num_tactics} tactics saved to {oup_path}"
            )

    # Export the licenses.
    license_dir = dst_path / "licenses"
    license_dir.mkdir()
    all_repos = [traced_repo.repo] + list(traced_repo.dependencies.values())
    for repo in all_repos:
        lic = repo.get_license()
        if lic is None:
            continue
        with (license_dir / repo.name).open("wt") as oup:
            oup.write(lic)
    with (license_dir / "README.md").open("wt") as oup:
        oup.write(
            "This directory contains licenses of Lean repos used to generate this dataset. The dataset itself is released under [CC BY 2.0](https://creativecommons.org/licenses/by/2.0/)."
        )

    # Export metadata.
    metadata = dict(kwargs)
    metadata["creation_time"] = str(datetime.now())
    metadata["from_repo"] = {
        "url": traced_repo.repo.url,
        "commit": traced_repo.repo.commit,
    }
    metadata["leandojo_version"] = lean_dojo.__version__
    with (dst_path / "metadata.json").open("wt") as oup:
        json.dump(metadata, oup)
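
Each split file written by `export_data` is a JSON list with one record per theorem. As a rough sketch of how such a file might be consumed, the snippet below writes a made-up record with the same keys to a temporary file and counts its theorems and tactics. The helper `summarize_split` and every value in `sample` are hypothetical:

```python
import json
import tempfile
from pathlib import Path


def summarize_split(path):
    """Return (number of theorems, number of tactics) stored in one split file."""
    with Path(path).open("rt") as inp:
        theorems = json.load(inp)
    num_tactics = sum(len(thm["traced_tactics"]) for thm in theorems)
    return len(theorems), num_tactics


# A made-up record with the same keys that export_data writes.
sample = [
    {
        "url": "https://example.com/repo",  # placeholder, not a real repo
        "commit": "0000000",
        "file_path": "Example.lean",
        "full_name": "example_theorem",
        "start": [1, 1],
        "end": [2, 10],
        "traced_tactics": [
            {"tactic": "simp", "state_before": "⊢ True", "state_after": "no goals"}
        ],
    }
]

with tempfile.TemporaryDirectory() as tmp:
    split_path = Path(tmp) / "train.json"
    with split_path.open("wt") as oup:
        json.dump(sample, oup)
    counts = summarize_split(split_path)

print(counts)
```

On the real dataset the path would follow `dst_path / strategy / f"{name}.json"` from `export_data`, for example `../leandojo_benchmark_4/random/train.json`.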

In [10]:
repo = LeanGitRepo(URL, COMMIT)
traced_repo = trace(repo)
splits = split_data(traced_repo)
export_data(traced_repo, splits, DST_DIR, dataset_name="LeanDojo Benchmark 4")

2023-06-19 19:55:03.189 | INFO     | lean_dojo.data_extraction.trace:trace:163 - Loading the traced repo
2023-06-19 19:55:05,943	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
  6%|██▍                                     | 233/3782 [00:08<00:51, 68.32it/s]
(raylet) [2023-06-19 19:55:15,864 E 556070 556233] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-06-19_19-55-03_522657_485685 is over 95% full, available space: 27169980416; capacity: 1887507697664. Object creation will fail if spilling is required.
 14%|█████▍                                  | 520/3782 [00:18<00:47, 68.07it/s]
(raylet) [2023-06-19 19:55:25,869 E 556070 556233] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-06-19_19-55-03_522657_485685 is over 95% full, available space: 27169955840; capacity: 1887507697664. Object creation will fail if spilling is required.
 20%|███████▊                                | 73

 78%|██████████████████████████████▍        | 2947/3782 [04:50<34:16,  2.46s/it]
(raylet) [2023-06-19 19:59:16,148 E 556070 556233] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-06-19_19-55-03_522657_485685 is over 95% full, available space: 27160416256; capacity: 1887507697664. Object creation will fail if spilling is required.
(raylet) [2023-06-19 19:59:26,156 E 556070 556233] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-06-19_19-55-03_522657_485685 is over 95% full, available space: 27160387584; capacity: 1887507697664. Object creation will fail if spilling is required.
(raylet) [2023-06-19 19:59:36,175 E 556070 556233] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-06-19_19-55-03_522657_485685 is over 95% full, available space: 27160371200; capacity: 1887507697664. Object creation will fail if spilling is required.
(raylet) [2023-06-19 19:59:46,181 E 556070 556233] (raylet) file_system_monitor