In [11]:
import os
import sys
import time
import logging
import re
import wandb
from dataclasses import dataclass, field
from datasets import load_dataset, Dataset
from dataclasses import dataclass, field
from transformers import  Trainer, TrainingArguments
import random
import torch
import numpy as np
from dataclasses import dataclass, field
from peft import LoraConfig
from transformers import (
    set_seed,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Trainer, 
    TrainingArguments
)

from trl import (
    GRPOTrainer,
    ModelConfig as TRLModelConfig,
    ScriptArguments,
    TrlParser,
    get_peft_config,
    GRPOConfig
)

from transformers import AutoModelForCausalLM, AutoTokenizer




#Next_Steps
# normalization and symbolic steps:
# from latex2sympy2_extended import NormalizationConfig
# from math_verify import LatexExtractionConfig, parse, verify


In [12]:
def clean_and_process_isabelle_input(input_text):
    """
    Cleans and formats Isabelle input, ensuring proper handling of:
    - Removes unnecessary escape sequences.
    - Correctly retains Isabelle-specific symbols.
    - Ensures readability without adding extra characters.
    """
    # Remove unnecessary backslashes from Isabelle symbols
    input_text = re.sub(r'\\<', r'<', input_text)
    input_text = re.sub(r'\\>', r'>', input_text)
    
    # Normalize spaces, remove redundant whitespace
    input_text = re.sub(r'\s+', ' ', input_text).strip()

    # Ensure proper formatting for Isabelle constructs
    input_text = input_text.replace("lemma", "\nlemma") \
                           .replace("Extracted:", "\nExtracted:\n") \
                           .replace("declare", "\ndeclare") \
                           .replace("deduce", "\ndeduce")

    return input_text

In [24]:
class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Finally, it replaces `sledgehammer` with a call to `normalhammer`.
    """
    def __init__(self, working_dir, isa_path, theory_file, port=9000):
        sys.path.append(os.environ.get('PISA_PATH', ''))
        try:
            from pisa_client import initialise_env
            self.initialise_env = initialise_env
        except ImportError:
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file = theory_file
        self.port = port

    def _initialize(self):
        """Initialize the PISA environment."""
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file=self.theory_file,
            working_directory=self.working_dir
        )
        return env

    def _exit(self, env):
        """Exit the environment and clean up resources."""
        try:
            env.post('exit')
        except Exception:
            pass
        os.system("ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
        os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")

    def _parse_output(self, obs):
        """Parse the sledgehammer output, returning the relevant part."""
        return obs.split('<hammer>')[0] if '<hammer>' in obs else ''

    def _run_step(self, step, i, tls_name, env):
        """Run a single proof step."""
        try:
            obs, reward, done, metadata = env.step_to_top_level_state(
                action=step,
                tls_name=tls_name,
                new_name=f'default_{i}'
            )
            return obs, reward, done, metadata, None
        except Exception as e:
            return '', 0, False, None, str(e)

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Run sledgehammer or fallback heuristics on a step."""
        heuristics = [
            'by auto', 'by simp', 'by blast', 'by fastforce',
            'by force', 'by eval', 'by presburger', 'by sos',
            'by arith', 'by linarith', 'by (auto simp: field_simps)'
        ]
        for heuristic in heuristics:
            step_ = step.replace('normalhammer', heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = f'{heuristic} <hammer> {obs}'
                return obs, reward, done, metadata, error
        return self._run_step(step.replace("normalhammer", "sledgehammer"), i, tls_name, env)

    def check(self, statement_and_proof):
        """Check the given proof."""
        env = self._initialize()
        env.initialise()

        theory = self.wrap_theorem(statement_and_proof)
        steps = self.get_parsed(env, theory)

        result = self._check(env, steps)
        self._exit(env)

        # Output the result
        #print(result['success'])
        print("\n==== Success: %s" % result['success'])
        print("--- Complete proof:\n%s" % result['theorem_and_proof'])
        return result

    def _check(self, env, steps):
        """Run the proof steps and collect results."""
        success, reason, done = False, '', False
        step_results = []
        tls_name = 'default'

        for i, step in enumerate(steps):
            time0 = time.time()
            if 'normalhammer' in step or 'sledgehammer' in step:
                obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
            else:
                obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)

            step_time = time.time() - time0
            step_results.append({
                'index': i, 'step': step, 
                'output': self._parse_output(obs), 
                'step_time': step_time
            })

            if error:
                reason = error
                break
            tls_name = f'default_{i}'

        success = done and reward == 1.0
        return {
            'success': success,
            'reason': reason,
            'num_steps': len(steps),
            'last_step': len(step_results),
            'step_results': step_results,
            'theorem_and_proof': self.reconstruct(step_results) if success else ''
        }

    @staticmethod
    def reconstruct(step_results):
        """Reconstruct the complete proof."""
        return '\n'.join(
            step_result['output'].strip() if step_result['output'] else step_result['step'].strip()
            for step_result in step_results[1:]
        )

    @staticmethod
    def wrap_theorem(theorem):
        """Wrap the theorem in a theory file."""
        return (
            'theory Interactive imports HOL.HOL Complex_Main '
            '"HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" '
            '"Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" '
            '"HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem
        )

    @staticmethod
    def get_parsed(env, theory):
        """Parse the theory and extract proof steps."""
        raw_steps = env.post(f"<parse text> ${theory}")
        steps = [s.strip() for s in raw_steps.split('<SEP>') if s.strip() and s != '$']
        processed_steps = []
        for i, step in enumerate(steps):
            if step.lower() == "then" and (i == 0 or steps[i - 1].startswith("proof")):
                continue
            processed_steps.append(step)
        return processed_steps


In [25]:
sys.path.append('../')
os.environ['PISA_PATH'] = '/home/siai/Portal-to-ISAbelle/src/main/python'


verifier = Checker(
    working_dir='/home/siai/Isabelle2022/src/HOL/Examples',
    isa_path='/home/siai/Isabelle2022',
    theory_file='/home/siai/Isabelle2022/src/HOL/Examples/Interactive.thy',
    port=9000
)


In [26]:
theorem_and_sledgehammer_proof = """theorem amc12a_2008_p8:
  fixes x y::real
  assumes h0: "0 < x \<and> 0 < y"
    and h1: "y^3 = 1"
    and h2: "6 * x^2 = 2 * (6 * y^2)"
  shows "x^3 = 2 * sqrt 2"
  using assms
  by (smt (verit, best) mult_cancel_left2 one_power2 
      power2_eq_square power2_le_imp_le 
      power2_sum power3_eq_cube power_Suc_less 
      power_commutes power_gt1_lemma real_le_lsqrt 
      real_le_rsqrt)
"""

result = verifier.check(theorem_and_sledgehammer_proof)


TypeError: initialise_env() got an unexpected keyword argument 'theory_file'

In [6]:
logger = logging.getLogger(__name__)

@dataclass
class ProofTreeScriptArguments(ScriptArguments):
    """
    Script arguments for training proof trees with reinforcement learning.
    """
    reward_funcs: list[str] = field(
        default_factory=lambda: ["local_correctness"],
        metadata={"help": "List of reward functions: 'local_correctness', 'global_correctness', etc."},
    )
    dataset_name: str = field(
        default="kings-crown/Putnam",
        metadata={"help": "Name or path of the dataset (Hugging Face)."}
    )
    dataset_config: str = field(
        default=None,
        metadata={"help": "Dataset configuration name (if applicable)."}
    )
    dataset_train_split: str = field(
        default="train",
        metadata={"help": "Name of the training split in the dataset."}
    )
    dataset_test_split: str = field(
        default="test",
        metadata={"help": "Name of the test split in the dataset."}
    )

@dataclass
class ModelConfig:
    """
    Model configuration.
    """
    model_name_or_path: str = field(
        default="meta-llama/Llama-3.2-1B-Instruct",
        metadata={"help": "Pretrained model name or path from the HuggingFace hub or local directory."}
    )
    use_peft: bool = field(
        default=False,
        metadata={"help": "Whether to use PEFT for parameter-efficient fine-tuning."}
    )

In [7]:
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user provides a mathematical statement, and the Assistant "
    "generates a structured proof with hierarchical lemmas and steps. Please follow these guidelines:\n\n"
    "1) Encapsulate your reasoning or chain-of-thought in <think> ... </think>\n"
    "2) When you propose a lemma or a sub-proof, use <invoke> ... </invoke>\n"
    "3) Provide the final statement or conclusion in <answer> ... </answer>\n"
    "4) The user might provide context or partial solutions.\n"
)

In [8]:
def construct_dataset(dataset_split, verifier):
    """
    Map each example in the dataset split to a conversation prompt.
    Adds fields:
      - 'prompt': the constructed conversation as text.
      - 'target': the ground-truth Isabelle translation.
    """
    def map_fn(example):
        conversation = build_conversation_prompt(example, verifier=verifier)
        example["prompt"] = convert_conversation_to_text(conversation)
        example["target"] = example["isabelle_translation"]
        return example
    return dataset_split.map(map_fn)


In [9]:
def format_dataset(example):
    """
    Format (or verify) each example from the dataset.

    Expects the example to contain at least:
    - 'problem': The main statement or question
    - 'context': Additional context or partial solutions
    - 'isabelle_translation': A string with the theorem statement in Isabelle
    - 'depth_limit': An integer specifying how deep we can expand the proof tree --- Representational dataset with 2 as default
    """
    required_keys = ["problem", "context", "isabelle_translation", "depth_limit"]
    for key in required_keys:
        if key not in example:
            raise ValueError(f"Dataset example missing required field: '{key}'")
    return example

In [10]:
def build_conversation_prompt(example, verifier=None):
    """
    Build a conversation-style prompt with system/user roles.

    We incorporate:
      - The SYSTEM_PROMPT,
      - The user's 'problem',
      - The user's 'context' as well,
      - (Optional) A quick verification or summary of the Isabelle statement

    Returns a list of conversation turns (dict with role and content).
    """
    conversation = []

    # Start with the system prompt
    conversation.append({"role": "system", "content": SYSTEM_PROMPT})

    # The user provides the problem
    conversation.append({"role": "user", "content": example["problem"]})

    # The user also has some context or partial solutions
    if example["context"]:
        conversation.append({"role": "user", "content": example["context"]})

    # If you want to do an immediate check on the isabelle_translation:
    if verifier is not None:
        check_result = verifier.check(example["isabelle_translation"])
        feedback = (
            "<Verified the Isabelle statement successfully>"
            if check_result["success"]
            else f"<Isabelle check error: {check_result['reason']}>"
        )
        # The assistant provides verification feedback
        conversation.append({"role": "assistant", "content": feedback})
    else:
        # Provide a placeholder if no verifier
        conversation.append({"role": "assistant", "content": "<No verification performed>"} )

    return conversation

In [11]:
#Need to improve logic, simple version taken from the decompose paper

def is_globally_correct(proof_text, verifier):
    """
    Check global correctness by verifying the final proof text with the external Checker.
    If verification passes, we return True, else False.
    """
    result = verifier.check(proof_text)
    return result["success"]

def find_leaf_nodes(text):
    """
    Identify if text is a 'leaf' step. For instance, if there's no <invoke>, we might consider it a leaf.
    Return True if no further expansions are indicated.
    """
    # A simple heuristic: if the text does NOT have <invoke> tags, we treat it as a leaf (no further expansions).
    return ("<invoke>" not in text)

In [12]:
def local_correctness_reward(prompts, completions, verifier, **kwargs):
    rewards = []
    for proof_text in completions:
        try:
            verified = verifier.check(proof_text)["success"]
            rewards.append(1.0 if verified else 0.0)
        except Exception:
            rewards.append(0.0)
    return rewards


def global_correctness_reward(prompts, completions, verifier, **kwargs):
    rewards = []
    for proof_text in completions:
        try:
            verified = verifier.check(proof_text)["success"]
            has_answer_block = "<answer>" in proof_text and "</answer>" in proof_text
            rewards.append(1.0 if (verified and has_answer_block) else 0.0)
        except Exception:
            rewards.append(0.0)
    return rewards

VERIFIER = None

def local_correctness_reward_wrapper(prompts, completions, **kwargs):
    verifier = kwargs.get("verifier", VERIFIER)
    if verifier is None:
        raise ValueError("No verifier provided to the reward function.")
    # Pass prompts, then completions, then verifier
    return local_correctness_reward(prompts, completions, verifier, **kwargs)

def global_correctness_reward_wrapper(prompts, completions, **kwargs):
    verifier = kwargs.get("verifier", VERIFIER)
    if verifier is None:
        raise ValueError("No verifier provided to the reward function.")
    return global_correctness_reward(prompts, completions, verifier, **kwargs)



reward_funcs_registry = {
    "local_correctness": local_correctness_reward_wrapper,
    "global_correctness": global_correctness_reward_wrapper, 
}


In [13]:
class ValueFunction(torch.nn.Module):
    """
    A simple MLP that classifies each sequence embedding as 'correct' vs 'incorrect' (2 classes)
    """
    def __init__(self, hidden_size):
        super(ValueFunction, self).__init__()
        self.linear = torch.nn.Linear(hidden_size, 2)

    def forward(self, inputs):
        return self.linear(inputs)

In [14]:
def generate_proof_trees(model, tokenizer, batch, verifier, top_k_percentage=0.3, max_length=256):
    """
    Generate candidate proofs for each example in 'batch'.
    We do 1) an 'invoke' version, 2) a 'no_invoke' version, and also incorporate ground-truth mixing.
    """
    proof_trees = []
    lengths = [len(ex["problem"].split()) for ex in batch]
    sorted_idx = np.argsort(lengths)
    top_k_thresh = int(len(batch) * top_k_percentage)
    top_k_indices = sorted_idx[-top_k_thresh:]  # highest length => might need sub-lemmas

    for i, example in enumerate(batch):
        conversation = build_conversation_prompt(example, verifier=verifier)
        user_prompt = convert_conversation_to_text(conversation)
        if i in top_k_indices:
            user_prompt_invoke = user_prompt + "\nUse sub-lemmas: <invoke>Propose lemma</invoke>\n"
        else:
            user_prompt_invoke = user_prompt

    
        proof_text_invoke = run_generation(model, tokenizer, user_prompt_invoke, max_length=max_length)
        proof_text_no_invoke = run_generation(model, tokenizer, user_prompt, max_length=max_length)

      
        single_tree_invoke = [proof_text_invoke]
        single_tree_no_invoke = [proof_text_no_invoke]

        if find_leaf_nodes(proof_text_invoke) or "<invoke>" in proof_text_invoke:
            proof_trees.append(single_tree_invoke)
        if find_leaf_nodes(proof_text_no_invoke) or "<invoke>" in proof_text_no_invoke:
            proof_trees.append(single_tree_no_invoke)


        ground_truth = example["isabelle_translation"]
        gt_tree = [ground_truth]
        proof_trees.append(gt_tree)

    return proof_trees

In [15]:
def convert_conversation_to_text(conversation):
    """
    Converts the conversation (list of {role, content}) into a single textual prompt that T5 or GPT can understand.
    """
    lines = []
    for turn in conversation:
        if turn["role"] == "system":
            lines.append(f"[SYSTEM] {turn['content']}")
        elif turn["role"] == "user":
            lines.append(f"[USER] {turn['content']}")
        else:  # assistant
            lines.append(f"[ASSISTANT] {turn['content']}")
    return "\n".join(lines)

In [16]:
def run_generation(model, tokenizer, prompt_text, max_length=256):
    """
    Run the model's generate() method on the prompt text and return the decoded string.
    """
    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True)
    gen_output = model.generate(
        **inputs,
        max_length=max_length,
        num_beams=4,
        early_stopping=True
    )
    decoded = tokenizer.decode(gen_output[0], skip_special_tokens=True)
    return decoded

In [17]:
def assign_rewards(proof_trees, verifier, reward_func):
    """
    Go through each 'proof tree'
    """
    completions = ["\n".join(tree) for tree in proof_trees]
    reward_values = reward_func(completions, verifier)
    return reward_values

In [18]:
def compute_weights(proof_trees, value_function, gamma=0.99):
    """
    For each proof tree, produce a weight that includes discounting and the value function's predictions.
    """
    weights = []
    for tree in proof_trees:
        weight = 1.0
        for depth, node_text in enumerate(tree):
            dummy_input_vec = torch.rand(1, 768) #dummy for first pass
            pred = value_function(dummy_input_vec)  # shape [1,2]
            pred_prob = torch.softmax(pred, dim=-1)[0, 1].item()  # Probability of "correct"
            weight *= (pred_prob * (gamma ** depth))
        weights.append(weight)
    return weights

In [19]:
def prepare_training_data(proof_trees, rewards, weights):
    """
    Create (prompt, target, weight) examples for each node in each proof tree.
    """
    training_data = []
    for tree, reward, weight in zip(proof_trees, rewards, weights):
        target_text = "\n".join(tree)
        prompt_text = "Proof attempt:"
        # Weighted by the final reward * the computed weight
        final_weight = reward * weight
        example = {
            "prompt": prompt_text,
            "target": target_text,
            "weight": final_weight
        }
        training_data.append(example)
    return training_data

In [20]:
def train_model_with_reinforce(model, tokenizer, value_function, proof_trees, rewards,
                               optimizer, replay_buffer, gamma=0.99):

    weights = compute_weights(proof_trees, value_function, gamma=gamma)
    training_data = prepare_training_data(proof_trees, rewards, weights)
    replay_buffer.extend(training_data)
    replay_buffer = replay_buffer[-2000:]
    model.train()
    optimizer.zero_grad()
    total_loss = 0.0

    for example in replay_buffer:
        if example["weight"] <= 0:
            continue
        inputs = tokenizer(example["prompt"], return_tensors="pt", truncation=True, padding=True)
        with tokenizer.as_target_tokenizer():
            targets = tokenizer(example["target"], return_tensors="pt", truncation=True, padding=True)["input_ids"]

        outputs = model(**inputs, labels=targets)
        # Weighted loss
        loss = outputs.loss * example["weight"]
        total_loss += loss.item()

        loss.backward()
    optimizer.step()
    
    # reward=1 as label=1, reward=0 as label=0, ignoring partial credit for now.
    train_value_function(value_function, proof_trees, rewards)

    return total_loss

In [21]:
def train_value_function(value_function, proof_trees, rewards):
    """
    Simple approach: if the proof tree had reward=1, label all nodes as "correct" (class=1);
    else label them as "incorrect" (class=0).
    """
    optimizer_vf = torch.optim.Adam(value_function.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()

    value_function.train()
    for tree, reward in zip(proof_trees, rewards):
        label = torch.tensor([1 if reward == 1.0 else 0])
        for _ in tree:  # each node in the tree
            # Dummy embedding
            node_vec = torch.rand(1, 768)
            pred = value_function(node_vec)  # shape [1,2]
            loss = loss_fn(pred, label)
            optimizer_vf.zero_grad()
            loss.backward()
            optimizer_vf.step()

In [22]:
def main(script_args: ProofTreeScriptArguments, training_args, model_args: ModelConfig):

    wandb.init(project="ProofGeneration", reinit=True)

    set_seed(training_args.seed)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)]
    )
    logger.setLevel(training_args.get_process_log_level())

    # Load and format the dataset.
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    dataset = dataset.map(format_dataset)
    
    # Initialize the Checker for proof verification.
    verifier = Checker(
        working_dir='/home/balaji/Isabelle2022/src/HOL/Examples',
        isa_path='/home/balaji/Isabelle2022',
        theory_file='/home/balaji/Isabelle2022/src/HOL/Examples/Interactive.thy',
        port=9000
    )
    global VERIFIER
    VERIFIER = verifier

    
    # Construct training (and evaluation) datasets with conversation prompts.
    train_dataset = construct_dataset(dataset[script_args.dataset_train_split], verifier)
    eval_dataset = (
        construct_dataset(dataset[script_args.dataset_test_split], verifier)
        if training_args.eval_strategy != "no" else None
    )
    
    # Load the pretrained model and tokenizer.
    model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Obtain optional PEFT configuration.
    peft_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )
    
    # Select reward functions from the registry.
    selected_reward_funcs = [
        reward_funcs_registry[rf]
        for rf in script_args.reward_funcs if rf in reward_funcs_registry
    ]
    if not selected_reward_funcs:
        raise ValueError("No valid reward functions selected.")
    reward_funcs = selected_reward_funcs  # GRPOTrainer expects a list.
    
    # Build GRPO training configuration.
    grpo_config = GRPOConfig(
            output_dir="outputs",          
            run_name="my_run",             
            learning_rate=5e-6,
            adam_beta1=0.9,
            adam_beta2=0.99,
            weight_decay=0.1,
            warmup_ratio=0.1,
            lr_scheduler_type='cosine',
            logging_steps=1,
            bf16=True,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=2,
            num_generations=4,
            max_prompt_length=256,
            max_completion_length=512,
            num_train_epochs=1,
            save_steps=100,
            max_grad_norm=0.1,
            report_to="wandb",
            log_on_each_node=False,
        )
    
    # Pass proof tree configuration (using depth_limit from the dataset).
    proof_tree_config = {"max_depth": dataset["train"][0]["depth_limit"]}
    
    # Initialize GRPOTrainer.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=grpo_config,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
    )
    
    # Start training.
    trainer.train()
    trainer.save_model(grpo_config.output_dir)
    
    # Evaluate and log some generated outputs.
    if eval_dataset is not None:
        eval_results = trainer.evaluate()
        print("Evaluation results:", eval_results)
        wandb.log(eval_results)
    
    for i in range(3):
        sample = train_dataset[i]
        prompt = sample["prompt"]
        generated = run_generation(model, tokenizer, prompt)
        print(f"Sample {i} Prompt:\n{prompt}")
        print(f"Sample {i} Generated:\n{generated}\n{'-'*40}")
        wandb.log({f"sample_{i}_prompt": prompt, f"sample_{i}_generated": generated})


In [23]:
import sys

# Set the required arguments here.
sys.argv = [
    "dummy_script.py",             # Script name (can be any valid string)
    "--output_dir", "outputs",      # Required output directory
    "--seed", "42",                 # Other required or optional arguments
    "--run_name", "my_run",
    "--per_device_train_batch_size", "1",
    "--num_train_epochs", "1"
]


In [24]:
from trl import TrlParser

parser = TrlParser((ProofTreeScriptArguments, GRPOConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
main(script_args, training_args, model_args)

[34m[1mwandb[0m: Currently logged in as: [33mbalaji-vir1997[0m ([33mbalaji-vir1997-stevens-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.



==== Success: False
--- Complete proof:


==== Success: False
--- Complete proof:



KeyboardInterrupt: 