In [None]:
# This test is to measure the speed of the model on a large dataset.
import json
import os
import time
from typing import List, Optional, Tuple, Union

import modal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.games.lean_game import LeanGameState
from src.networks.network import Network


In [None]:

class ProverLLM(Network):
    # Hugging Face identifier used for both the tokenizer and the default base model.
    MODEL_NAME = 'deepseek-ai/DeepSeek-Prover-V1.5-RL'
    # Width of the hidden-state vector fed into the heads.
    HIDDEN_SIZE = 4096
    # Width of the single hidden layer inside each head.
    HEAD_HIDDEN_SIZE = 512
    # Number of outputs produced by the policy head.
    POLICY_SIZE = 100
    # Index into the `hidden_states` tuple returned by the base model.
    INTERMEDIATE_LAYER = 25

    def __init__(self,
                 base_model: Optional[AutoModelForCausalLM] = None,
                 llm_only: bool = False,
    ):
        """
        A ProverLLM model that uses a transformer-based language model for proof generation
        and MCTS decision-making. It implements:
        1. Given a game state, generate a prompt (either for the V/P-heads or for completion).
        2. Given this prompt, tokenize it.
        3. Given a tokenized prompt, run it through the base model to get a completion.
        4. Given a tokenized prompt, run it through the base model to get an intermediate state.
        5. Use the intermediate state to get a value estimate and a policy output.

        The worker/MCTS immediately outside of ProverLLM should be agnostic to the LLM-related
        principles including the specific prompting methods etc. To this end,
        all worker/MCTS interactions with this class ("public methods") should
        involve LeanGameStates only.

        The controller and training loop necessarily need to know about the internal
        details of the model; the five steps above can be short-circuited in different
        ways by the controller.

        Parameters:
        ----------
        base_model: Optional[AutoModelForCausalLM]
            An optional pre-trained base model. If None, a default model is loaded.
        llm_only: bool
            If True, the policy and value heads are not created; the state-dict
            helpers and `policy_and_value` then raise ValueError. Defaults to
            False, which builds both heads.
        """
        super().__init__()
        # The state-dict helpers consult this flag, so it must always be defined.
        self.llm_only = llm_only

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_NAME,
            trust_remote_code=True
        )
        # Reuse EOS as the pad token so that `padding=True` works in `tokenize`.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        if base_model is None:
            self.base_model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(
                self.MODEL_NAME,
                trust_remote_code=True,
                device_map='auto'
            )
        else:
            self.base_model = base_model

        # Same placement as before when CUDA exists; fall back to CPU instead of
        # crashing on hosts without a GPU.
        self._device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        if llm_only:
            self.value_head = None
            self.policy_head = None
        else:
            self.value_head = self._build_head(1)
            self.policy_head = self._build_head(self.POLICY_SIZE)

########################## Utilities ##########################

    def _build_head(self, out_features: int) -> nn.Sequential:
        """
        Build a two-layer MLP head (Linear -> ReLU -> Linear) on `self._device`.

        Parameters:
        ----------
        out_features: int
            Size of the head's output vector.

        Returns:
        -------
        head: nn.Sequential
            The freshly initialised head.
        """
        return nn.Sequential(
            nn.Linear(self.HIDDEN_SIZE, self.HEAD_HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(self.HEAD_HIDDEN_SIZE, out_features)
        ).to(self._device)

    def policy_value_state_dict(self) -> dict:
        """
        Usually, we do not want to be saving/loading the entire model,
        just the policy and value heads.

        Returns:
        -------
        policy_value_state_dict: dict
            The state dictionaries for the policy and value heads, under the
            keys 'policy_head' and 'value_head'.

        Raises:
        ------
        ValueError
            If the model was built with llm_only=True.
        """
        if self.llm_only:
            raise ValueError("Cannot get state dicts in llm_only mode.")
        return {
            'policy_head': self.policy_head.state_dict(),
            'value_head': self.value_head.state_dict()
        }

    def load_policy_value_state_dict(self, policy_value_state_dict: dict):
        """
        Set the state dictionaries for the policy and value heads.

        Parameters:
        ----------
        policy_value_state_dict: dict
            The state dictionaries for the policy and value heads, under the
            keys 'policy_head' and 'value_head'.

        Raises:
        ------
        ValueError
            If the model was built with llm_only=True.
        """
        if self.llm_only:
            raise ValueError("Cannot set state dicts in llm_only mode.")
        self.policy_head.load_state_dict(
            policy_value_state_dict['policy_head'])
        self.value_head.load_state_dict(policy_value_state_dict['value_head'])

########################## Public Interface ##########################

    def mcts_forward(self, prompt: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize the prompt(s), run the base model, and apply both heads.

        Parameters:
        ----------
        prompt: Union[str, List[str]]
            A single prompt, or a list/tuple of prompts evaluated as one batch.

        Returns:
        -------
        policy_output: torch.Tensor
            The policy output from the policy head. For a list/tuple of
            prompts there is one row per prompt; for a single string there is
            no leading batch dimension.
        value_output: torch.Tensor
            The value estimate from the value head, batched the same way.
        """
        # Only an explicit list/tuple is treated as a batch, so any other
        # input keeps the unbatched output shape.
        batched = isinstance(prompt, (list, tuple))
        input_ids, attention_mask = self.tokenize(prompt)
        intermediate_output = self.get_intermediate_state(
            input_ids, attention_mask, batched=batched)
        return self.policy_and_value(intermediate_output)

    def forward(self, state: LeanGameState) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs the forward pass of the network. Returns a policy and a value.

        Parameters:
        ----------
        state: LeanGameState
            The current game state.

        Returns:
        -------
        policy_output: torch.Tensor
            The policy output from the policy head.
        value_output: torch.Tensor
            The value estimate from the value head.
        """
        # NOTE(review): `state` is handed to `mcts_forward` unchanged and ends
        # up in the tokenizer; no state-to-prompt step (step 1 of the class
        # description) is defined in this class. Confirm that LeanGameState is
        # tokenizable as-is, or add the prompt construction here.
        return self.mcts_forward(state)


########################## Part 2: Given this prompt, tokenize it ##########################

    def tokenize(self, prompt: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenizes the given prompt(s) and returns the input IDs and attention mask.

        Parameters:
        ----------
        prompt: Union[str, List[str]]
            The prompt, or list of prompts, to tokenize. Prompts in a list are
            padded to a common length.

        Returns:
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The input IDs and attention mask, moved to the model's device.
        """
        tokens = self.tokenizer(
            prompt, return_tensors='pt', padding=True, truncation=True).to(self._device)
        return tokens['input_ids'], tokens['attention_mask']

########################## Part 3: Given a tokenized prompt, run it through the base model to get a completion ##########################

    def complete(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, max_length: int = 1000) -> str:
        """
        Returns the generated proof text for the first sequence in `input_ids`.

        Parameters:
        ----------
        input_ids: torch.Tensor
            The input token IDs.
        attention_mask: Optional[torch.Tensor]
            The attention mask for the input.
        max_length: int
            The maximum number of tokens to generate beyond the prompt.

        Returns:
        -------
        generated_text: str
            The decoded output of the first sequence, with special tokens
            skipped.
        """
        base_output = self.base_model.generate(
            input_ids,
            attention_mask=attention_mask,
            # `generate` counts the prompt towards max_length, so add its length.
            max_length=input_ids.shape[1] + max_length
        )
        generated_text = self.tokenizer.decode(
            base_output[0], skip_special_tokens=True)
        return generated_text

########################## Part 4: Given a tokenized prompt, run it through the base model to get an intermediate state ##########################

    def get_intermediate_state(self,
                               input_ids: torch.Tensor,
                               attention_mask: Optional[torch.Tensor] = None,
                               batched: bool = False) -> torch.Tensor:
        """
        Returns the intermediate hidden state output from the base model,
        taken at the last attended token of each sequence.

        Parameters:
        ----------
        input_ids: torch.Tensor
            The input token IDs, as a 2-D (batch, sequence) tensor.
        attention_mask: Optional[torch.Tensor]
            The attention mask for the input. If None, the final position of
            every sequence is used.
        batched: bool
            If False (default), only the vector for the first sequence is
            returned, with no batch dimension. If True, one vector per
            sequence is returned.

        Returns:
        -------
        intermediate_output: torch.Tensor
            The hidden state at layer index `INTERMEDIATE_LAYER`.
        """
        base_output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        layer_states = base_output.hidden_states[self.INTERMEDIATE_LAYER]

        batch_size, seq_len = input_ids.shape
        if attention_mask is None:
            last_index = torch.full(
                (batch_size,), seq_len - 1,
                dtype=torch.long, device=layer_states.device)
        else:
            # Weight each attended position by (position + 1) so the argmax is
            # the last attended token, whichever side the padding is on.
            weights = torch.arange(
                1, seq_len + 1, device=attention_mask.device)
            last_index = (attention_mask.long() * weights).argmax(dim=1)
            last_index = last_index.to(layer_states.device)

        rows = torch.arange(batch_size, device=layer_states.device)
        per_sequence = layer_states[rows, last_index]
        return per_sequence if batched else per_sequence[0]

########################## Part 5: Use the intermediate state to get a value estimate and a policy output ##########################

    def policy_and_value(self, intermediate_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns policy and value outputs for the given intermediate output.

        Parameters:
        ----------
        intermediate_output: torch.Tensor
            The intermediate output from the base model; its last dimension
            must be `HIDDEN_SIZE`.

        Returns:
        -------
        policy_output: torch.Tensor
            The policy output from the policy head.
        value_output: torch.Tensor
            The value estimate from the value head.

        Raises:
        ------
        ValueError
            If the model was built with llm_only=True.
        """
        if self.llm_only:
            raise ValueError(
                "Cannot compute policy/value outputs in llm_only mode.")

        policy_output = self.policy_head(intermediate_output)
        value_output = self.value_head(intermediate_output)

        return policy_output, value_output



In [None]:
# Instantiate the model without a base_model argument, so the constructor
# loads the default pretrained base model itself.
model = ProverLLM()

In [None]:

# Load the miniF2F dataset: each line of the file is one JSON object.
data = []
dataset_path = os.path.join('datasets', 'minif2f.jsonl')
with open(dataset_path, 'r', encoding='utf-8') as f:
    # Parse lazily, line by line, appending each decoded record to `data`.
    data.extend(map(json.loads, f))

# convert each data point to a string for inference.
# Example:
# {
#     "name": "amc12a_2019_p21",
#     "split": "valid",
#     "informal_prefix": "/-- Let $z=\\frac{1+i}{\\sqrt{2}}.$What is $\\left(z^{1^2}+z^{2^2}+z^{3^2}+\\dots+z^{{12}^2}\\right) \\cdot \\left(\\frac{1}{z^{1^2}}+\\frac{1}{z^{2^2}}+\\frac{1}{z^{3^2}}+\\dots+\\frac{1}{z^{{12}^2}}\\right)?$\n\n$\\textbf{(A) } 18 \\qquad \\textbf{(B) } 72-36\\sqrt2 \\qquad \\textbf{(C) } 36 \\qquad \\textbf{(D) } 72 \\qquad \\textbf{(E) } 72+36\\sqrt2$ Show that it is \\textbf{(C) }36.-/\n",
#     "formal_statement": "theorem amc12a_2019_p21 (z : ℂ) (h₀ : z = (1 + Complex.I) / Real.sqrt 2) :\n  ((∑ k : ℤ in Finset.Icc 1 12, z ^ k ^ 2) * (∑ k : ℤ in Finset.Icc 1 12, 1 / z ^ k ^ 2)) = 36 := by\n",
#     "goal": "z : ℂ\nh₀ : z = (1 + Complex.I) / ↑√2\n⊢ (∑ k ∈ Finset.Icc 1 12, z ^ k ^ 2) * ∑ k ∈ Finset.Icc 1 12, 1 / z ^ k ^ 2 = 36",
#     "header": "import Mathlib\nimport Aesop\n\nset_option maxHeartbeats 0\n\nopen BigOperators Real Nat Topology Rat\n\n"
# }

# Concatenate, in this order: the goal, a newline, the informal prefix, an opening ```lean4 fence, the header, and the formal statement.

# Build one prompt string per dataset record: goal first, then the informal
# statement, then a lean4 code fence opening onto the header and theorem.
input_data = [
    "".join((
        record['goal'],
        "\n",
        record['informal_prefix'],
        '```lean4\n',
        record['header'],
        record['formal_statement'],
    ))
    for record in data
]


In [None]:
# Time one batched forward pass for each batch size below.
if not input_data:
    # An empty dataset can never be tiled up to the requested batch size.
    raise ValueError("input_data is empty; nothing to benchmark.")

for test_size in [1, 10, 12, 14, 16, 18, 20]:
    print(f"Testing with {test_size} data points.")
    # Repeat the dataset enough times to cover test_size, then trim to size.
    repeats = -(-test_size // len(input_data))  # ceiling division
    input_subset = (input_data * repeats)[:test_size]

    # CUDA kernels run asynchronously with respect to the host. Drain pending
    # work before starting the clock and again before stopping it, so the
    # interval covers the actual GPU execution rather than just kernel launch.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    # Inference only: skip autograd bookkeeping so it neither skews the timing
    # nor inflates memory use at the larger batch sizes.
    with torch.no_grad():
        policy, value = model.mcts_forward(input_subset)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"Inference took {end-start} seconds.")
    print("output shape: ", policy.shape, value.shape)