In [1]:
import os
import random
import re
import time
import torch
import transformers
import gym
import json
import math
import numpy as np
import openai
import peft
import random
import sys
import warnings
from tqdm import tqdm
from accelerate import Accelerator
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GenerationConfig,
    TrainerCallback,
    AdamW,
)
from transformers import BitsAndBytesConfig
from transformers import LlamaTokenizer
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig
from typing import List, Dict, Any
import warnings
warnings.filterwarnings("ignore")


In [2]:
#client = OpenAI()

In [3]:
!nvidia-smi

Thu Feb  6 19:04:31 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               On  | 00000000:2D:00.0 Off |                  Off |
| 30%   39C    P8               9W / 300W |     13MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               On  | 00000000:41:00.0  On |  

In [4]:
class LMFunction(object):
    """Small helper that sends a single-turn prompt to an OpenAI chat model.

    The API key is read from the ``OPENAI_API_KEY`` environment variable when
    the instance is created.
    """

    def __init__(self, engine='gpt-4', max_tokens=512):
        self.client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.engine = engine
        self.max_tokens = max_tokens

    def _call_api(self, prompt, engine, max_tokens, max_retries=10, retry_wait=2):
        """Request a completion, retrying on ``openai.APIError``.

        Sleeps ``retry_wait`` seconds after each failed attempt. When all
        ``max_retries`` attempts fail, a dict-shaped placeholder with empty
        content is returned instead of a response object.
        """
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        attempts_left = max_retries
        while attempts_left > 0:
            attempts_left -= 1
            try:
                return self.client.chat.completions.create(
                    model=engine,
                    messages=conversation,
                    max_tokens=max_tokens,
                    temperature=1.0,
                )
            except openai.APIError:
                time.sleep(retry_wait)
        return {'choices': [{'message': {'content': ''}}]}

    def _parse_message(self, msg):
        """Return the first choice's text, or '' when it cannot be read.

        The dict placeholder produced by `_call_api` has no ``choices``
        attribute, so it lands in the ``AttributeError`` branch and yields ''.
        """
        try:
            return msg.choices[0].message.content
        except (IndexError, AttributeError):
            return ''

    def f(self, prompt, x):
        """Send ``prompt + x`` to the configured engine and return the reply text."""
        reply = self._call_api(
            prompt=prompt + x,
            engine=self.engine,
            max_tokens=self.max_tokens,
        )
        return self._parse_message(reply)


class Checker(object):
    """A modified version of the Draft, Sketch, Prove proof-checking client.
    (https://github.com/albertqjiang/draft_sketch_prove/blob/main/autoformalization/checker.py)

    This checker supports Isabelle2022 via the new version of PISA
    (https://albertqjiang.github.io/Portal-to-ISAbelle/).

    It supports checking a miniF2F-style proof via `check`.

    Finally, it replaces `sledgehammer` with a call to `normalhammer`.

    Every call to `check` starts a fresh PISA environment and tears it down
    again (see `_exit`), even when parsing or a proof step fails.
    """

    # Cheap tactics substituted for `normalhammer` before the step is sent as-is.
    _HEURISTICS = (
        'by auto', 'by simp', 'by blast', 'by fastforce', 'by force', 'by eval',
        'by presburger', 'by sos', 'by arith', 'by linarith',
        'by (auto simp: field_simps)',
    )

    def __init__(self, working_dir, isa_path, theory_file, port=9000):
        """Store the Isabelle/PISA locations and try to import the PISA client.

        :param working_dir: working directory handed to PISA
        :param isa_path: Isabelle installation directory
        :param theory_file: path of the theory file handed to PISA
        :param port: port of the PISA server

        A missing ``$PISA_PATH`` or a failing import only prints a hint here;
        `_initialize` raises a clear error later if the client is unavailable.
        """
        pisa_path = os.environ.get('PISA_PATH')
        if pisa_path and pisa_path not in sys.path:
            sys.path.append(pisa_path)
        try:
            from pisa_client import initialise_env # type: ignore
            self.initialise_env = initialise_env
        except Exception:
            # Keep the attribute defined so the failure surfaces with a clear
            # message in `_initialize` rather than as an AttributeError.
            self.initialise_env = None
            print("Set $PISA_PATH to /yourpath/to/Portal-to-ISAbelle/src/main/python")

        self.working_dir = working_dir
        self.isa_path = isa_path
        self.theory_file = theory_file
        self.port = port

    def _initialize(self):
        """Create a PISA environment for the configured theory file.

        :raises RuntimeError: if `pisa_client` could not be imported
        """
        if getattr(self, 'initialise_env', None) is None:
            raise RuntimeError(
                "pisa_client is not importable; set $PISA_PATH to "
                "/yourpath/to/Portal-to-ISAbelle/src/main/python"
            )
        env = self.initialise_env(
            self.port,
            isa_path=self.isa_path,
            theory_file_path=self.theory_file,
            working_directory=self.working_dir
        )
        return env

    def _exit(self, env):
        """Best-effort shutdown of the environment and its helper processes.

        WARNING: the two shell commands send SIGKILL to every process whose
        `ps aux` line contains "Isabelle" or "poly", not only the ones started
        by this checker.
        """
        try:
            env.post('exit')
        except BaseException:
            # Deliberately broad (same reach as the original bare `except`):
            # cleanup must always continue to the kill commands below.
            print("env.post('exit') timed out")
        os.system("ps aux | grep Isabelle | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")
        os.system("ps aux | grep poly | awk '{print $2}' | xargs kill -9 > /dev/null 2>&1")

    def _parse_output(self, obs):
        """Parse the sledgehammer output, otherwise return an empty string"""
        if '<hammer>' in obs:
            output = obs.split('<hammer>')[0]
        else:
            output = ''
        return output

    def _run_step(self, step, i, tls_name, env):
        """Apply one step on top of state `tls_name`, naming the result `default_<i>`.

        Returns ``(obs, reward, done, metadata, error)`` where `error` is the
        observation text if it contains an error marker, else None.
        """
        obs, reward, done, metadata = env.step_to_top_level_state(
            action=step,
            tls_name=tls_name,
            new_name='default_%d' % i
        )
        error = None
        if 'error:' in obs or 'Step error' in obs or 'Unknown error' in obs:
            error = obs
        return obs, reward, done, metadata, error

    def _run_sledgehammer(self, step, i, tls_name, env):
        """Resolve a `normalhammer` step: try the cheap heuristics, then the hammer.

        On a heuristic success the observation is prefixed with
        ``"<heuristic> <hammer> "`` so `_parse_output` can recover the tactic.
        """
        # First try heuristics
        for heuristic in self._HEURISTICS:
            step_ = step.replace('normalhammer', heuristic)
            obs, reward, done, metadata, error = self._run_step(step_, i, tls_name, env)
            if error is None:
                obs = '%s <hammer> %s' % (heuristic, obs)
                return obs, reward, done, metadata, error
        # Try sledgehammer
        out = self._run_step(step, i, tls_name, env)
        return out

    def check(self, statement_and_proof):
        """Check a theorem statement plus proof and return the result dict of `_check`.

        The environment is shut down on every path: `_check` does it for the
        proof run, and this method does it if initialisation or parsing raises.
        """
        # Initialize environment
        env = self._initialize()
        try:
            env.initialise()

            # Wrap and parse theorem
            theory = Checker.wrap_theorem(statement_and_proof)
            steps = Checker.get_parsed(env, theory)
        except BaseException:
            # Do not leave the Isabelle processes running when setup fails.
            self._exit(env)
            raise

        result = self._check(env, steps)
        return result

    def _check(self, env, steps):
        """Run `steps` in order and summarise the outcome.

        Stops at the first step that reports an error or raises. The returned
        dict has the keys ``success``, ``reason``, ``num_steps``, ``last_step``,
        ``step_results`` and ``theorem_and_proof``. The environment is always
        exited before returning or propagating.
        """
        done = False
        reward = None
        reason = ''
        success = False
        step_results = []
        tls_name = 'default'
        try:
            for i, step in enumerate(steps):
                time0 = time.time()
                try:
                    if 'normalhammer' in step:
                        obs, reward, done, metadata, error = self._run_sledgehammer(step, i, tls_name, env)
                    else:
                        obs, reward, done, metadata, error = self._run_step(step, i, tls_name, env)
                    step_time = time.time() - time0
                    step_results.append(dict(index=i, step=step, output=self._parse_output(obs), step_time=step_time))
                    if error is not None:
                        reason = error
                        success = False
                        done = False
                        break
                except (KeyboardInterrupt, SystemExit):
                    # A user interrupt is not a proof timeout; let it through
                    # (the `finally` below still shuts the environment down).
                    raise
                except BaseException:
                    # Timeout - end the proof attempt.
                    # Kept as broad as the original bare `except`; presumably the
                    # PISA client's timeout exception may not derive from
                    # `Exception` - TODO confirm before narrowing.
                    success = False
                    done = False
                    reason = 'timeout (%d)' % len(step_results)
                    # Same record shape as a normal step, including step_time.
                    step_results.append(dict(index=i, step=step, output='', step_time=time.time() - time0))
                    break

                # Change when successful
                tls_name = 'default_%d' % i

            if done and reward == 1.0:
                success = True

            result = {
                'success': success,
                'reason': reason,
                'num_steps': len(steps),
                'last_step': len(step_results),
                'step_results': step_results,
                'theorem_and_proof': self.reconstruct(step_results) if success else ''
            }
        finally:
            # Exit environment
            self._exit(env)
        return result

    @staticmethod
    def reconstruct(step_results):
        """Join the executed steps (skipping the first record) into one text.

        A step's parsed hammer output replaces the step text when non-empty.
        """
        steps = []
        for step_result in step_results[1:]:
            if step_result['output'] != '':
                steps.append(step_result['output'].strip())
            else:
                steps.append(step_result['step'].strip())
        theorem_and_proof = '\n'.join(steps)
        return theorem_and_proof

    @staticmethod
    def wrap_theorem(theorem):
        """Prefix `theorem` with the fixed `theory Interactive imports ... begin` header."""
        return 'theory Interactive imports HOL.HOL Complex_Main "HOL-Library.Code_Target_Numeral" "HOL-Library.Sum_of_Squares" "Symmetric_Polynomials.Vieta" "HOL-Computational_Algebra.Computational_Algebra" "HOL-Number_Theory.Number_Theory" \n begin\n%s' % theorem

    @staticmethod
    def get_parsed(env, theory, tls_name='default'):
        """Ask the environment to split `theory` into steps.

        `tls_name` is accepted for interface compatibility and is not used.
        """
        # HACK: the parsing doesn't work well with `normalhammer`, so we replace
        # all hammer calls with sorry, then replace sorry to normalhammer after parsing.
        theory = theory.replace('sledgehammer', 'sorry')
        theory = theory.replace('normalhammer', 'sorry')

        steps = env.post(f"<parse text> ${theory}")
        steps = steps.split('<SEP>')
        # remove weird '$' step and whitespace steps
        steps = [s for s in steps if s != '$' and s.strip() != '']
        steps = [s.replace('sorry', 'normalhammer') for s in steps]
        return steps


In [5]:
# Add the parent directory to the module search path.
sys.path.append('../')
# Checker.__init__ reads $PISA_PATH to locate the PISA python client, so it
# must be set before the Checker below is constructed.
# NOTE(review): this and the paths below are machine-specific absolute paths.
os.environ['PISA_PATH'] = '/home/siai/Portal-to-ISAbelle/src/main/python'

# Notebook-level proof checker instance (PISA server expected on port 9000).
checker = Checker(
    working_dir='/home/siai/Isabelle2022/src/HOL/Examples',
    isa_path='/home/siai/Isabelle2022',
    theory_file='/home/siai/Isabelle2022/src/HOL/Examples/Interactive.thy',
    port=9000
)

In [6]:
def load_theorem_data(json_path: str) -> List[Dict[str, str]]:
    """
    Read the dataset file at `json_path` and return its parsed JSON content.

    The file is expected to hold an array of objects shaped like
      { "statement": "...", "state": "...", "step": "..." }
    and the parsed value is returned unchanged (no validation is performed).
    """
    with open(json_path, 'r') as handle:
        return json.load(handle)


def load_train_val_data(json_path: str, test_size=0.1, random_seed=42):
    """
    Load the JSON dataset at `json_path` and split it into train / validation parts.

    :param json_path: path of the JSON file to read
    :param test_size: forwarded to `train_test_split` as `test_size`
    :param random_seed: forwarded to `train_test_split` as `random_state`
    :return: `(train_data, val_data)` as produced by `train_test_split`
    """
    with open(json_path, 'r') as handle:
        records = json.load(handle)
    train_part, val_part = train_test_split(
        records,
        test_size=test_size,
        random_state=random_seed,
    )
    return train_part, val_part

In [7]:
def evaluate_validation_loss(model, val_loader):
    """
    Compute the example-weighted mean loss of `model` over `val_loader`.

    :param model: callable as `model(**batch)` returning an object with `.loss`
    :param val_loader: iterable of dict batches of tensors; each batch must
        contain "input_ids", whose first dimension is used as the batch size
    :return: mean loss as a float, or NaN when the loader yields no examples

    Batches are moved to the device of the model's first parameter (CPU if the
    model has no parameters) instead of unconditionally to the default CUDA
    device. The model is switched to eval mode for the pass and is always put
    back into train mode afterwards, even if evaluation raises.
    """
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")

    total_loss = 0.0
    total_count = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                batch_size = batch["input_ids"].size(0)
                # Weight by batch size so a smaller last batch is not over-counted.
                total_loss += outputs.loss.item() * batch_size
                total_count += batch_size
    finally:
        model.train()

    if total_count == 0:
        # No validation examples: the mean is undefined (avoids ZeroDivisionError).
        return float("nan")
    return total_loss / total_count

In [8]:
def build_trl_format(dataset: List[Dict[str,str]]) -> List[Dict[str,Any]]:
    """
    Turn { "statement", "state", "step" } records into single-turn TRL samples.

    Each output sample has the shape
      "prompt": [
        {"role": "system", "content": "Prove the following theorem."},
        {"role": "user", "content": statement + "\n\n" + state}
      ],
      "answer": step

    Missing keys in a record are treated as empty strings.
    """
    def _as_sample(record):
        statement = record.get("statement", "")
        state = record.get("state", "")
        return {
            "prompt": [
                {"role": "system", "content": "Prove the following theorem."},
                {"role": "user", "content": statement + "\n\n" + state},
            ],
            "answer": record.get("step", ""),
        }

    return [_as_sample(record) for record in dataset]

In [9]:
class OfflineImitationDataset(Dataset):
    """
    Supervised (input, target) text pairs built from theorem records:
        input  = statement + "\n\n" + state
        target = step
    Every record must provide all three keys.
    """
    def __init__(self, raw_data: List[Dict[str,str]]):
        super().__init__()
        # Pairs are materialised once, up front, as plain (str, str) tuples.
        self.samples = [
            (record["statement"] + "\n\n" + record["state"], record["step"])
            for record in raw_data
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [10]:
def offline_collate_fn(batch, tokenizer, max_length=512):
    """
    Collate (input_text, target_text) pairs into a causal-LM training batch.

    Each row is ``prompt tokens + target tokens (+ EOS)``, right-padded to
    `max_length`. `labels` carries the target tokens at the same positions as
    in `input_ids` and -100 everywhere else, so the loss is computed only on
    the target continuation.

    Tokenizing the target separately via `text_target` (the previous approach)
    yields labels that are not position-aligned with `input_ids`, which a
    causal LM's loss requires, and it left padding unmasked.

    :param batch: sequence of (input_text, target_text) string pairs
    :param tokenizer: HF-style tokenizer, callable as
        `tokenizer(text, add_special_tokens=...)["input_ids"]`, exposing
        `pad_token_id` and `eos_token_id`
    :param max_length: fixed row length; when a pair is too long the prompt is
        truncated from the left so the target is kept
    :return: dict with "input_ids", "attention_mask" and "labels", each a
        LongTensor of shape (len(batch), max_length)
    """
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = eos_id if eos_id is not None else 0

    ids_rows, mask_rows, label_rows = [], [], []
    for prompt_text, target_text in batch:
        prompt_ids = list(tokenizer(prompt_text, add_special_tokens=True)["input_ids"])
        target_ids = list(tokenizer(target_text, add_special_tokens=False)["input_ids"])
        if eos_id is not None:
            # Teach the model to stop after the proof step.
            target_ids.append(eos_id)
        target_ids = target_ids[:max_length]

        # Keep the whole target; drop the oldest prompt tokens if needed.
        room = max_length - len(target_ids)
        prompt_ids = prompt_ids[-room:] if room > 0 else []

        sequence = prompt_ids + target_ids
        n_pad = max_length - len(sequence)
        ids_rows.append(sequence + [pad_id] * n_pad)
        mask_rows.append([1] * len(sequence) + [0] * n_pad)
        # Mask by position, not by token id: pad and EOS may share an id.
        label_rows.append([-100] * len(prompt_ids) + target_ids + [-100] * n_pad)

    return {
        "input_ids": torch.tensor(ids_rows, dtype=torch.long),
        "attention_mask": torch.tensor(mask_rows, dtype=torch.long),
        "labels": torch.tensor(label_rows, dtype=torch.long),
    }


In [11]:
def run_offline_imitation_learning(
    json_path: str,
    model_name: str,
    output_dir: str,
    epochs: int = 1,
    lr: float = 1e-5,
    max_length: int = 512,
    batch_size: int = 2
):
    """
    Fine-tune a causal LM on (statement + state -> step) pairs and save it.

    :param json_path: dataset file read with `load_theorem_data`
    :param model_name: name or path passed to the HF `from_pretrained` loaders
    :param output_dir: directory receiving the saved model and tokenizer
    :param epochs: number of passes over the dataset
    :param lr: AdamW learning rate
    :param max_length: forwarded to `offline_collate_fn`
    :param batch_size: DataLoader batch size

    Requires a CUDA device: the model and every batch are moved with `.cuda()`.
    """
    # Data
    records = load_theorem_data(json_path)
    imitation_set = OfflineImitationDataset(records)

    # Model & tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Register a dedicated pad token and grow the embedding table to match.
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.pad_token = "<PAD>"

    model.config.use_cache = False
    model.train()
    model.cuda()

    def _collate(samples):
        return offline_collate_fn(samples, tokenizer, max_length)

    loader = DataLoader(
        imitation_set,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=_collate,
    )

    optimizer = AdamW(model.parameters(), lr=lr)

    # Plain single-device loop: one optimizer update per batch.
    updates = 0
    for epoch in range(epochs):
        for batch in loader:
            for name in list(batch):
                batch[name] = batch[name].cuda()

            loss = model(**batch).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            updates += 1
            print(f"Epoch {epoch} Step {updates}, loss={loss.item():.4f}")

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Offline training done, saved to {output_dir}")

In [12]:
def generate_proof_step(model, tokenizer, statement, state, max_length=32):
    """
    Greedily generate the next proof step for `statement` + `state`.

    :param model: causal LM exposing `generate` (and `device` or `parameters()`)
    :param tokenizer: tokenizer exposing `encode` and `decode`
    :param statement: theorem statement text
    :param state: current proof state text
    :param max_length: maximum number of NEW tokens to generate
    :return: decoded text of the newly generated tokens only

    The prompt tokens that `generate` echoes at the start of its output are
    stripped, so callers printing the "generated step" no longer get the whole
    prompt repeated in front of it.
    """
    # Build the input text for the model
    input_text = statement + "\n\n" + state

    # Follow the model's device instead of assuming the default CUDA device.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device

    # Encode
    enc_in = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate greedily; top_p / temperature are sampling-only knobs and have
    # no effect with do_sample=False, so they are not passed.
    with torch.no_grad():
        out_ids = model.generate(
            enc_in,
            max_new_tokens=max_length,
            do_sample=False,
        )

    # Decode only what comes after the prompt.
    new_ids = out_ids[0][enc_in.shape[-1]:]
    return tokenizer.decode(new_ids, skip_special_tokens=True)


"""
def generate_proof_step(model, tokenizer, statement, state, max_length=32):
    input_text = statement + "\n\n" + state
    enc_in = tokenizer.encode(input_text, return_tensors="pt")

    
    with torch.no_grad():
        out_ids = model.generate(
            enc_in,
            max_new_tokens=max_length,
            do_sample=False,
            top_p=1.0,
            temperature=1.0,
            top_k=0, num_beams=5)

    out_text = tokenizer.batch_decode(out_ids[0], skip_special_tokens=True)
    return out_text
"""

'\ndef generate_proof_step(model, tokenizer, statement, state, max_length=32):\n    input_text = statement + "\n\n" + state\n    enc_in = tokenizer.encode(input_text, return_tensors="pt")\n\n    \n    with torch.no_grad():\n        out_ids = model.generate(\n            enc_in,\n            max_new_tokens=max_length,\n            do_sample=False,\n            top_p=1.0,\n            temperature=1.0,\n            top_k=0, num_beams=5)\n\n    out_text = tokenizer.batch_decode(out_ids[0], skip_special_tokens=True)\n    return out_text\n'

In [13]:
def checker_reward_func(prompts, completions, answer, checker, **kwargs) -> List[float]:
    """
    Score each completion by running it through the proof checker.

    For sample i the last prompt message's content (statement + state) and the
    first completion message's content (the proposed step) are joined with a
    blank line and passed to `checker.check`.

    Rewards: 2.0 when the result reports success, 0.5 when its reason is
    exactly "partial", 0.0 otherwise. `answer` is accepted but not used.
    """
    def _score(verdict):
        reason = verdict["reason"]
        if verdict["success"]:
            # Full success
            return 2.0
        return 0.5 if reason == "partial" else 0.0

    rewards = []
    for idx, conversation in enumerate(prompts):
        theorem_text = conversation[-1]['content']
        step_text = completions[idx][0]['content']
        rewards.append(_score(checker.check(f"{theorem_text}\n\n{step_text}")))
    return rewards


In [14]:
def format_reward_func(completions, **kwargs) -> List[float]:
    """
    Small shaping reward: 0.1 if the (lower-cased) completion text contains
    "by ", "sledgehammer " or "normalhammer", otherwise 0.0.

    :param completions: list of completions, each a list whose first element
        is a message dict with a "content" string
    :return: one reward per completion

    Each marker is tested against the text individually. The previous
    `"by " or ... in txt` expression was always truthy because the non-empty
    literal "by " short-circuits the `or`, so every completion earned 0.1.
    """
    markers = ("by ", "sledgehammer ", "normalhammer")
    outs = []
    for c in completions:
        txt = c[0]["content"].lower()
        outs.append(0.1 if any(marker in txt for marker in markers) else 0.0)
    return outs

In [15]:
def checker_reward(prompts, completions, answer, checker=None, **kwargs):
    """
    Checker-based rewards with non-finite values (NaN / +-Inf) replaced by 0.0.

    :param prompts: conversation prompts, as expected by `checker_reward_func`
    :param completions: model completions, as expected by `checker_reward_func`
    :param answer: reference answers (forwarded unchanged)
    :param checker: proof checker to use. When omitted, the module-level
        `checker` object is used if one exists, so callers that do not pass a
        checker get a working default instead of `None.check(...)`.
    :raises ValueError: if there is something to score but no checker is available
    """
    if checker is None:
        # The parameter shadows the notebook-level `checker`; look it up explicitly.
        checker = globals().get("checker")
    if checker is None and len(prompts) > 0:
        raise ValueError(
            "checker_reward needs a proof checker: pass `checker=...` or "
            "define a module-level `checker`."
        )

    base_rewards = checker_reward_func(prompts, completions, answer, checker=checker)
    return [r if math.isfinite(r) else 0.0 for r in base_rewards]

In [16]:
def evaluate_rl(trainer, val_data_subset, num_samples=200):
    """
    Evaluate the average reward on a subset of the validation data.
    We do *greedy* generation (no sampling) to reduce risk of infinite/NaN probabilities.

    :param trainer: object whose `generate(prompts, ...)` returns one completion
        per prompt (assumed to be in the same conversation format the reward
        functions index as `completion[0]["content"]` — TODO confirm)
    :param val_data_subset: list of samples with "prompt" and "answer" keys
    :param num_samples: upper bound on how many samples are drawn (without
        replacement) from `val_data_subset`
    :return: mean of (checker reward + format reward) over the drawn samples;
        0.0 when no sample was drawn or the mean is NaN/Inf
    """
    # The draw uses the global `random` state, so results vary between calls
    # unless the caller seeds `random`.
    subset = random.sample(val_data_subset, min(num_samples, len(val_data_subset)))
    prompts = [s["prompt"] for s in subset]
    gold_ans = [s["answer"] for s in subset]

    # Force do_sample=False, top_p=1.0
    # NOTE(review): this relies on `trainer` exposing a `generate` method with
    # this signature; verify that the trainer class in use provides one.
    completions = trainer.generate(
        prompts,
        max_new_tokens=32,
        do_sample=False,      
        top_p=1.0,            
        temperature=1.0       
    )

    # NOTE(review): no checker is passed here, so `checker_reward` runs with
    # its default `checker=None` — confirm that is intended.
    c_reward = checker_reward(prompts, completions, gold_ans)
    f_reward = format_reward_func(completions)
    # Per-sample total = checker reward + format reward.
    total_reward = [r1 + r2 for (r1, r2) in zip(c_reward, f_reward)]
    avg_reward = sum(total_reward) / len(total_reward) if len(total_reward) > 0 else 0.0

    # clamp final
    if math.isnan(avg_reward) or math.isinf(avg_reward):
        avg_reward = 0.0

    return avg_reward

In [17]:
def train_with_early_stopping(
    model,
    train_loader,
    val_loader,
    optimizer,
    max_global_steps=100000,
    eval_every=100,
    patience=3
):
    """
    Run a step-based training loop with periodic validation and early stopping.

    Validation runs every `eval_every` optimizer steps; training returns once
    `patience` consecutive validations fail to improve on the best loss, or
    once `max_global_steps` steps have been taken.

    :param model: HF model, callable as `model(**batch)` returning `.loss`
    :param train_loader: iterable of dict batches whose values support `.cuda()`
    :param val_loader: validation loader passed to `evaluate_validation_loss`
    :param optimizer: optimizer providing `zero_grad()` and `step()`
    :param max_global_steps: hard cap on the total number of steps
    :param eval_every: validation interval, in steps
    :param patience: number of non-improving validations tolerated in a row
    :return: None. The best model is saved to "checkpoint_best" whenever the
        validation loss improves.
    """
    lowest_val_loss = float("inf")
    stale_evals = 0
    steps_done = 0
    model.train()

    # Keep cycling through the loader until one of the stop conditions fires.
    for _epoch in range(999999):
        for batch in train_loader:
            steps_done += 1

            for key in list(batch):
                batch[key] = batch[key].cuda()
            loss = model(**batch).loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if steps_done % 50 == 0:
                print(f"[Global Step {steps_done}] training loss={loss.item():.4f}")

            if steps_done % eval_every == 0:
                val_loss = evaluate_validation_loss(model, val_loader)
                print(f"[Global Step {steps_done}] val_loss={val_loss:.4f}")

                if val_loss < lowest_val_loss:
                    print("Validation improved!")
                    lowest_val_loss = val_loss
                    stale_evals = 0
                    model.save_pretrained("checkpoint_best")
                else:
                    stale_evals += 1
                    print(f"No improvement count={stale_evals}")

                # Patience exhausted -> stop training.
                if stale_evals >= patience:
                    print("Early stopping: no improvement in val_loss for too long.")
                    return

            # Step budget exhausted -> stop training.
            if steps_done >= max_global_steps:
                print("Reached maximum global steps.")
                return

In [18]:
def main():
    """
    Offline (supervised) fine-tuning entry point driven by Accelerate.

    Splits the JSON dataset into train/validation parts, fine-tunes MODEL_NAME
    on (statement + state -> step) pairs, prints a sample generation every 500
    steps of an epoch, validates every EVAL_EVERY global steps with early
    stopping after PATIENCE non-improving validations, saves the best model to
    "checkpoint_best" and the final model + tokenizer to OFFLINE_SAVE_DIR.
    """
    JSON_DATASET_PATH = "deduplicated_dataset.json"
    MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
    OFFLINE_SAVE_DIR = "offline_ckpt"
    OFFLINE_EPOCHS = 2
    BATCH_SIZE = 2
    LR = 1e-5
    MAX_LENGTH = 512
    EVAL_EVERY = 1000
    PATIENCE = 3


    # Initialize Accelerator
    accelerator = Accelerator()
    # NOTE(review): `device` is not used anywhere below in this function.
    device = accelerator.device

    # Load dataset (1% of the records are held out for validation)
    train_data, val_data = load_train_val_data(JSON_DATASET_PATH, test_size=0.01)
    train_dataset = OfflineImitationDataset(train_data)
    val_dataset = OfflineImitationDataset(val_data)

    # Load model & tokenizer; the EOS token doubles as the padding token
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    # Disable the KV cache and switch to training mode. The model itself is
    # handed to Accelerate further down, in accelerator.prepare().
    model.config.use_cache = False
    model.train()

    # Dataloader
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=lambda b: offline_collate_fn(b, tokenizer, MAX_LENGTH)
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=lambda b: offline_collate_fn(b, tokenizer, MAX_LENGTH)
    )

    optimizer = AdamW(model.parameters(), lr=LR)

    # Prepare everything for Accelerator
    model, optimizer, train_loader, val_loader = accelerator.prepare(
        model, optimizer, train_loader, val_loader
    )

    # Training Variables
    global_step = 0
    best_val_loss = float("inf")
    no_improvement_count = 0
    stop_early = False

    for epoch in range(OFFLINE_EPOCHS):
        if stop_early:
            break

        model.train()
        for step_idx, batch in enumerate(train_loader):
            global_step += 1

            # Forward / backward / optimizer update inside Accelerate's
            # gradient-accumulation context.
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss

                # Backpropagation
                optimizer.zero_grad()
                accelerator.backward(loss)
                optimizer.step()

            # Print training loss every 10 steps (step_idx restarts each epoch)
            if step_idx % 10 == 0:
                print(f"Epoch {epoch}, step {step_idx}, train loss={loss.item():.4f}")

            # Generate a sample output every 500 steps (including step 0 of each epoch)
            if step_idx % 500 == 0 and len(val_data) > 0:
                sample = random.choice(val_data)
                gen_step = generate_proof_step(model, tokenizer, sample["statement"], sample["state"])
                print("-"*50)
                print(f"Sample statement+state:\n {sample['statement']} {sample['state']}")
                print("\n\nGenerated step:\n", gen_step)
                print("\n\nReference step:\n", sample["step"])
                print("-"*50)

            # Validation after every EVAL_EVERY steps
            if global_step % EVAL_EVERY == 0:
                model.eval()
                val_loss = evaluate_validation_loss(model, val_loader)
                print(f"[Global Step {global_step}] Interim val_loss={val_loss:.4f}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    no_improvement_count = 0
                    print("  (New best val_loss!)")
                    # Only the model weights are saved here (no tokenizer).
                    accelerator.unwrap_model(model).save_pretrained("checkpoint_best")
                else:
                    no_improvement_count += 1
                    print(f"  (No improvement. Count={no_improvement_count})")
                    if no_improvement_count >= PATIENCE:
                        print("Early stopping triggered (no val improvement).")
                        stop_early = True
                        break
                model.train()

        # End of epoch => Final validation
        # This end-of-epoch loss is only printed; it does not update
        # best_val_loss or the early-stopping counter.
        if not stop_early:
            model.eval()
            val_loss = evaluate_validation_loss(model, val_loader)
            print(f"Epoch {epoch}, validation loss={val_loss:.4f}")

            if len(val_data) > 0:
                sample = random.choice(val_data)
                gen_step = generate_proof_step(model, tokenizer, sample["statement"], sample["state"])
                print("Sample statement+state:\n", sample["statement"], sample["state"])
                print("Generated step:\n", gen_step)
                print("Reference step:\n", sample["step"])
            model.train()

    # Save final model (the last weights, not necessarily the best checkpoint)
    accelerator.unwrap_model(model).save_pretrained(OFFLINE_SAVE_DIR)
    tokenizer.save_pretrained(OFFLINE_SAVE_DIR)
    print(f"Offline training done. Model saved to: {OFFLINE_SAVE_DIR}")



In [19]:
class PrintCallback(TrainerCallback):
    """
    A callback that prints training info after each logging event (on_log)
    and does sample generation every 'sample_every' steps (optional).

    :param model: model used for sample generation
    :param tokenizer: tokenizer matching `model` (must support chat templates)
    :param sample_every: generate a sample when `global_step` is a positive
        multiple of this value; 0 or less disables sampling
    :param val_data: records with "statement" and "state" keys to sample from
    """
    def __init__(self, model, tokenizer, sample_every=500, val_data=None):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.sample_every = sample_every
        self.val_data = val_data if val_data else []

    def _generate_sample(self, prompt, trainer=None):
        """Return generated text for a chat-format `prompt`.

        Uses `trainer.generate` when a trainer exposing it is supplied;
        otherwise generates greedily with the model and tokenizer stored on
        this callback and returns only the newly generated tokens.
        """
        trainer_generate = getattr(trainer, "generate", None)
        if callable(trainer_generate):
            completion = trainer_generate(
                [prompt],
                max_new_tokens=50,
                do_sample=False,
                top_p=1.0,
                temperature=1.0
            )
            return completion[0][0]["content"] if completion else ""

        text = self.tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        enc = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        out_ids = self.model.generate(**enc, max_new_tokens=50, do_sample=False)
        new_ids = out_ids[0][enc["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_ids, skip_special_tokens=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """
        Called by the trainer after logs are created, typically each 'logging_steps'.
        """
        # `logs` defaults to None; treat that as "nothing logged".
        logs = logs or {}

        # Print the logs in a style similar to Script 2
        if "loss" in logs:
            print(f"Step {state.global_step}, training loss={logs['loss']:.4f}")
        else:
            print(f"Step {state.global_step}, logs={logs}")

        # Generate a sample output if we've hit sample_every steps
        due = (
            self.sample_every > 0
            and state.global_step > 0
            and state.global_step % self.sample_every == 0
        )
        if not due or len(self.val_data) == 0:
            return

        sample = random.choice(self.val_data)
        prompt = [
            {"role": "system", "content": "Prove the following theorem."},
            {"role": "user",   "content": sample["statement"] + "\n\n" + sample["state"]}
        ]
        print(f"(Callback) Generating proof at step {state.global_step}, sample statement[:50]: {sample['statement'][:50]}...")
        try:
            with torch.no_grad():
                # The callback handler does not pass the trainer in `kwargs`,
                # so `kwargs["trainer"]` raised KeyError; fall back to the
                # model/tokenizer given to the constructor.
                out_text = self._generate_sample(prompt, kwargs.get("trainer"))
        except Exception as exc:
            # Sample printing is diagnostic only; never abort training for it.
            print(f"(Callback) sample generation failed: {exc!r}")
            return
        print(f"(Callback) RL-generated step => {out_text[:80]}...")



In [20]:
#if __name__ == "__main__":
    #main()

In [21]:
def main_grpo():
    """
    Multi-GPU RL function using TRL’s GRPO with:
      - Chunked incremental training
      - Sample output printing every 100 steps
      - Standard early stopping based on reward
      - RL validation on a held-out set

    Starts from the offline checkpoint in OFFLINE_CKPT, trains in chunks of
    CHUNK_SIZE steps up to MAX_GLOBAL_STEPS, validates every EVAL_EVERY global
    steps with `evaluate_rl`, saves the best model to "checkpoint_best_rl" and
    the final model to RL_SAVE_DIR.
    """

    # Set environment variables to fix tokenizer issues
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # Force CUDA debugging

    JSON_DATASET_PATH = "deduplicated_dataset.json"
    OFFLINE_CKPT = "offline_ckpt"  
    RL_SAVE_DIR = "outputs/Qwen-Theorem-GRPO"

    # Training settings
    EVAL_EVERY = 1000   
    PATIENCE = 3       
    MAX_GLOBAL_STEPS = 10000
    CHUNK_SIZE = 100    

    # Initialize accelerator for multi-GPU support
    accelerator = Accelerator()
    device = accelerator.device
    print(f"Using device: {device}")

    # Check CUDA Capability
    capability = torch.cuda.get_device_capability()
    print(f"CUDA Device Capability: {capability}")
    if capability[0] < 8:
        print("Warning: This GPU may not support bfloat16. Switching to float16 instead.")
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.bfloat16

    # Setup Isabelle Checker
    checker = Checker(
        working_dir='/home/siai/Isabelle2022/src/HOL/Examples',
        isa_path='/home/siai/Isabelle2022',
        theory_file='/home/siai/Isabelle2022/src/HOL/Examples/Interactive.thy',
        port=9000
    )

    # Load dataset & split
    train_data, val_data = load_train_val_data(JSON_DATASET_PATH, test_size=0.1)
    
    print(f"Train size: {len(train_data)}, Val size: {len(val_data)}")

    train_trl_data = build_trl_format(train_data)
    val_trl_data   = build_trl_format(val_data)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(OFFLINE_CKPT)
    tokenizer.pad_token = tokenizer.eos_token

    # Ensure Model Loads on Correct Device
    # Use the dtype selected from the GPU capability above; loading was
    # previously hard-coded to bfloat16, which ignored the float16 fallback.
    model = AutoModelForCausalLM.from_pretrained(
        OFFLINE_CKPT,
        torch_dtype=torch_dtype,
        #device_map="auto"
    ).to(device)
    # NOTE(review): `generation_config` is built but never passed to anything below.
    generation_config = GenerationConfig( do_sample=False, top_k=0, temperature = 1.0)
    print(f"Model loaded successfully on {model.device}")
    #print(model.hf_device_map)  # See how layers are placed

    # LoRA configuration (optional)
    # NOTE(review): `peft_config` is never handed to the trainer, so as written
    # the full model is trained — confirm whether LoRA was intended.
    peft_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

    # Checker-based reward bound to the local `checker`. This local definition
    # shadows the module-level `checker_reward` inside this function.
    def checker_reward(prompts, completions, answer, **kwargs):
        return checker_reward_func(prompts, completions, answer, checker=checker)

    # Training arguments
    # NOTE(review): GRPO compares several generations of the same prompt with
    # each other; with num_generations=1 there is no group to compare against —
    # check this value against the TRL documentation.
    training_args = GRPOConfig(
        output_dir=RL_SAVE_DIR,
        run_name="Qwen-GRPO-theorems",
        learning_rate=5e-7, #1e-6 or 5e-6 is bigger.
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=10,
        bf16=(torch_dtype == torch.bfloat16),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        num_generations=1,
        max_prompt_length=512,
        max_completion_length=512,
        num_train_epochs=1,
        save_steps=9999999,
        max_grad_norm=0.1,
        report_to="wandb",
        log_on_each_node=False,        
    )

    # Initialize Trainer with Accelerate
    # NOTE(review): `train_dataset` is a plain Python list here; confirm the
    # trainer accepts that rather than requiring a `datasets.Dataset`.
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[format_reward_func, checker_reward],
        args=training_args,
        train_dataset=train_trl_data,
    )
    
    
    print(f"Model Device: {model.device}")
    print(f"Accelerator Device: {device}")

    best_val_reward = -9999
    no_improvement_count = 0
    global_step = 0
    stop_early = False

    print("Tokenizer vocab size =", len(tokenizer))
    print("Model config vocab size =", model.config.vocab_size)
    #model.resize_token_embeddings(len(tokenizer))
    
    while not stop_early and global_step < MAX_GLOBAL_STEPS:
        # Train one chunk. Each chunk is a separate trainer.train() call with
        # max_steps set to the chunk length.
        steps_to_run = min(CHUNK_SIZE, MAX_GLOBAL_STEPS - global_step)
        trainer.args.max_steps = steps_to_run
        trainer.args.num_train_epochs = 1
        trainer.generation_config.do_sample = False
        trainer.generation_config.top_k = None  
        trainer.generation_config.num_return_sequences = 1
        print("Trainer Generation Config:", trainer.generation_config) ## debug

        print(f"Training {steps_to_run} steps...")
        trainer.train()
        global_step += steps_to_run
        print(f"Finished {steps_to_run} RL steps, Global Step={global_step}")

        # Generate a sample output every chunk
        if len(val_data) > 0:
            sample = random.choice(val_data)
            prompt = [
                {"role": "system", "content": "Prove the following theorem."},
                {"role": "user", "content": sample["statement"] + "\n\n" + sample["state"]}
            ]

            print(f"Generating proof step for sample: {sample['statement'][:50]}...")
            # NOTE(review): this assumes the trainer exposes
            # `generate(prompts, ...)`; verify against the installed TRL version.
            with torch.no_grad():  # Prevents CUDA memory issues
                completions = trainer.generate(
                    [prompt], max_new_tokens=50, do_sample=False, top_p=1.0, temperature=1.0
                )
            gen_text = completions[0][0]['content'] if completions else ""
            print(f"RL-generated step: {gen_text[:50]}")

        # RL Validation
        if global_step % EVAL_EVERY == 0:
            print(f"Running RL Validation at step {global_step}")
            val_reward = evaluate_rl(trainer, val_trl_data, num_samples=200)
            print(f"[RL Validation] Global Step={global_step}, Val Reward={val_reward:.4f}")

            # Early stopping logic
            if val_reward > best_val_reward:
                best_val_reward = val_reward
                no_improvement_count = 0
                print("New best RL reward! Saving checkpoint...")
                accelerator.unwrap_model(trainer.model).save_pretrained("checkpoint_best_rl")
            else:
                no_improvement_count += 1
                print(f"No improvement. Count={no_improvement_count}")
                if no_improvement_count >= PATIENCE:
                    print("Early stopping triggered (RL val reward not improving).")
                    stop_early = True

    print(f"RL Training Complete! Best Val Reward={best_val_reward:.4f}")
    accelerator.unwrap_model(trainer.model).save_pretrained(RL_SAVE_DIR)
    print(f"RL Final Model Saved in: {RL_SAVE_DIR}")


In [22]:
# System prompt asking the model for a <reasoning> block followed by an
# <answer> block. The final instruction names the <answer>...</answer> pair;
# it previously used a closing tag in place of the opening one.
SYSTEM_PROMPT = """
Respond in the following format to generate isabelle proofs to verify the input prompts:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
where the <reasoning> describes your step-by-step Isabelle-like reasoning, 
and <answer> is the final succinct proof step or conclusion for the checker.
OUTPUT ISABELLE PROOFS ONLY WITHIN <answer>.....</answer> tags.
"""

# Example reward functions from your inspiration snippet:
def extract_xml_answer(text: str) -> str:
    """
    Return the stripped content that follows the first <answer> tag.

    The content is cut at the next </answer> tag when one is present;
    otherwise everything after <answer> is used. An empty string is
    returned when <answer> does not occur in the text.
    """
    _, opening, remainder = text.partition("<answer>")
    if not opening:
        # partition() yields an empty separator when the tag is absent.
        return ""
    inner, _, _ = remainder.partition("</answer>")
    return inner.strip()

def extract_xml_reasoning(text: str) -> str:
    """
    Return the stripped content that follows the first <reasoning> tag.

    The content is cut at the next </reasoning> tag when one is present;
    otherwise everything after <reasoning> is used. An empty string is
    returned when <reasoning> does not occur in the text.
    """
    _, opening, remainder = text.partition("<reasoning>")
    if not opening:
        # partition() yields an empty separator when the tag is absent.
        return ""
    inner, _, _ = remainder.partition("</reasoning>")
    return inner.strip()

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """
    Give partial credit for tag presence.

    Each completion earns 0.1 when both <reasoning> and </reasoning> occur
    in its first message's content, and another 0.1 when both <answer> and
    </answer> occur. Returns one score per completion, in order.
    """
    def _tag_score(body):
        total = 0.0
        for tag in ("reasoning", "answer"):
            has_open = f"<{tag}>" in body
            has_close = f"</{tag}>" in body
            if has_open and has_close:
                total += 0.1
        return total

    bodies = [item[0]["content"] for item in completions]
    return [_tag_score(body) for body in bodies]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Score 1.0 for completions that follow the strict layout, else 0.0.

    The expected layout is a <reasoning> block followed by an <answer>
    block, each tag on its own line, with an optional trailing newline.
    The check is a single regex match against the first message's content.
    """
    # DOTALL lets the lazy wildcards span multi-line reasoning/answer bodies.
    layout = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$",
        flags=re.DOTALL,
    )
    texts = [item[0]["content"] for item in completions]
    scores = []
    for text in texts:
        scores.append(0.0 if layout.match(text) is None else 1.0)
    return scores

def isabelle_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Reward the presence of common Isabelle keywords.

    The first message's content of each completion is lower-cased and
    checked by plain substring search for 'proof', 'have', 'thus' and 'qed';
    every keyword found adds 0.1 to that completion's score.
    """
    keywords = ("proof", "have", "thus", "qed")
    lowered_texts = [item[0]["content"].lower() for item in completions]
    scores = []
    for lowered in lowered_texts:
        score = 0.0
        # Accumulate one keyword at a time, in a fixed order.
        for word in keywords:
            if word in lowered:
                score += 0.1
        scores.append(score)
    return scores


# Additional callback to show logs
class PrintCallback(TrainerCallback):
    """
    Trainer callback that echoes every logging event and periodically prints
    a short greedy generation for a random validation sample.

    Args:
        model: Model used for sample generation when no trainer object with a
            ``generate`` method is supplied through the event kwargs.
        tokenizer: Tokenizer matching ``model``; used to render the chat
            prompt and to decode the generated token ids.
        sample_every: A sample is generated whenever ``state.global_step`` is
            a positive multiple of this value; values <= 0 disable sampling.
        val_data: Sequence of samples, each providing ``"statement"`` and
            ``"state"`` entries. ``None`` (or any falsy value) becomes ``[]``.
    """
    def __init__(self, model, tokenizer, sample_every=500, val_data=None):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.sample_every = sample_every
        self.val_data = val_data if val_data else []

    def _generate_sample(self, prompt, trainer=None, max_new_tokens=50):
        """Greedy-decode a short continuation for one chat-style prompt."""
        # Keep the original code path when a caller does hand over an object
        # exposing generate(); the stock callback handler does not pass one.
        if trainer is not None and hasattr(trainer, "generate"):
            with torch.no_grad():
                completion = trainer.generate(
                    [prompt],
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    top_p=1.0,
                    temperature=1.0,
                )
            return completion[0][0]["content"] if completion else ""

        # Fallback: generate directly with the model/tokenizer stored on the callback.
        prompt_text = self.tokenizer.apply_chat_template(
            prompt, tokenize=False, add_generation_prompt=True
        )
        encoded = self.tokenizer(prompt_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Drop the echoed prompt tokens so only the continuation is decoded.
        prompt_length = encoded["input_ids"].shape[1]
        return self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the logged values and, every ``sample_every`` steps, a sample generation."""
        # logs defaults to None, so guard before the membership test.
        if logs and "loss" in logs:
            print(f"[Step {state.global_step}] training loss={logs['loss']:.4f}")
        else:
            print(f"[Step {state.global_step}] logs={logs}")
        # Sample generation
        if self.sample_every > 0 and state.global_step > 0 and state.global_step % self.sample_every == 0:
            if len(self.val_data) > 0:
                sample = random.choice(self.val_data)
                prompt = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user",   "content": sample["statement"] + "\n\n" + sample["state"]}
                ]
                print(f"[Callback] Generating proof at step {state.global_step}, sample statement[:50]: {sample['statement'][:50]} ...")
                # kwargs.get avoids a KeyError when no "trainer" entry is passed to the event.
                out_text = self._generate_sample(prompt, trainer=kwargs.get("trainer"))
                print(f"[Callback] RL-generated step => {out_text[:120]}...")


In [23]:
from tqdm import tqdm

def main_grpo():
    """
    Run GRPO reinforcement learning (TRL ``GRPOTrainer``) on the theorem dataset.

    Steps performed:
      - load the JSON dataset, split it into train/validation and convert
        both splits to the TRL prompt format;
      - load the offline checkpoint and train it with LoRA adapters;
      - train in chunks of ``CHUNK_SIZE`` steps (environment variable
        ``CHUNK_SIZE``, default 100) up to ``MAX_GLOBAL_STEPS``;
      - roughly every ``EVAL_EVERY`` global steps, score the validation set
        with ``evaluate_rl``, save the best model to ``checkpoint_best_rl``
        and stop early after ``PATIENCE`` evaluations without improvement;
      - save the final model to ``RL_SAVE_DIR``.

    Sample generations and per-log printing are handled by ``PrintCallback``.

    Relies on ``Checker``, ``load_train_val_data``, ``build_trl_format``,
    ``checker_reward_func`` and ``evaluate_rl`` being defined elsewhere in
    this notebook.
    """

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Synchronous kernel launches make CUDA errors point at the failing call (slower).
    os.environ["CUDA_LAUNCH_BLOCKING"]   = "1"

    # Basic config
    JSON_DATASET_PATH = "deduplicated_dataset.json"
    OFFLINE_CKPT = "offline_ckpt"
    RL_SAVE_DIR = "outputs/Qwen-Theorem-GRPO"

    CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "100"))
    EVAL_EVERY = 1000
    PATIENCE   = 10
    MAX_GLOBAL_STEPS = 10000

    accelerator = Accelerator()
    device = accelerator.device
    print(f"Using device: {device}")

    # Pick the half-precision dtype from the GPU's compute capability.
    capability = torch.cuda.get_device_capability()
    print(f"CUDA Device Capability: {capability}")
    if capability[0] < 8:
        print("Warning: GPU < SM80 => using float16")
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.bfloat16

    checker = Checker(
        working_dir='/home/siai/Isabelle2022/src/HOL/Examples',
        isa_path='/home/siai/Isabelle2022',
        theory_file='/home/siai/Isabelle2022/src/HOL/Examples/Interactive.thy',
        port=9000
    )

    train_data, val_data = load_train_val_data(JSON_DATASET_PATH, test_size=0.1)
    print(f"Train size: {len(train_data)}, Val size: {len(val_data)}")

    train_trl_data = build_trl_format(train_data)
    val_trl_data   = build_trl_format(val_data)

    tokenizer = AutoTokenizer.from_pretrained(OFFLINE_CKPT)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        OFFLINE_CKPT,
        torch_dtype=torch_dtype,
    ).to(device)
    print(f"Model loaded successfully on {model.device}")

    peft_config = LoraConfig(
        r=4,
        lora_alpha=16,
        target_modules=["q_proj","k_proj","v_proj","o_proj","up_proj","down_proj","gate_proj"],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
    )

    def checker_reward(prompts, completions, answer, **kwargs):
        """Adapter binding the shared Isabelle ``checker`` into ``checker_reward_func``."""
        return checker_reward_func(prompts, completions, answer, checker=checker)

    reward_func_list = [
        xmlcount_reward_func,
        strict_format_reward_func,
        isabelle_format_reward_func,
        checker_reward,
    ]

    # NOTE(review): num_generations=1 together with the greedy decoding forced
    # in the loop below gives GRPO a single sample per prompt, so there is
    # presumably no within-group reward spread to learn from - confirm intent.
    training_args = GRPOConfig(
        output_dir=RL_SAVE_DIR,
        run_name="Qwen-GRPO-theorems",
        learning_rate=5e-7,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        bf16=(torch_dtype == torch.bfloat16),
        per_device_train_batch_size=1,
        gradient_accumulation_steps=2,
        num_generations=1,
        max_prompt_length=512,
        max_completion_length=512,
        num_train_epochs=1,
        save_steps=9999999,
        max_grad_norm=0.1,
        report_to="wandb",
        log_on_each_node=False,
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_func_list,
        args=training_args,
        train_dataset=train_trl_data,
        # The LoRA config was built but never handed to the trainer before.
        peft_config=peft_config,
    )

    print_callback = PrintCallback(model, tokenizer, sample_every=200, val_data=val_data)
    trainer.add_callback(print_callback)

    print(f"Model Device: {model.device}")
    print(f"Accelerator Device: {device}")
    print("Tokenizer vocab size =", len(tokenizer))
    print("Model config vocab size =", model.config.vocab_size)

    best_val_reward = -9999
    no_improvement_count = 0
    global_step = 0
    stop_early = False
    # Evaluate once global_step reaches this mark; works for any CHUNK_SIZE.
    next_eval_step = EVAL_EVERY

    # NOTE(review): every trainer.train() call below is a separate run of
    # max_steps steps; presumably the trainer's own step counter and LR
    # schedule restart per chunk - TODO confirm this is intended.
    while not stop_early and global_step < MAX_GLOBAL_STEPS:
        steps_to_run = min(CHUNK_SIZE, MAX_GLOBAL_STEPS - global_step)
        trainer.args.max_steps = steps_to_run
        trainer.args.num_train_epochs = 1

        trainer.generation_config.do_sample = False
        trainer.generation_config.top_k = None
        trainer.generation_config.num_return_sequences = 1
        print("Trainer Generation Config:", trainer.generation_config)

        print(f"Training {steps_to_run} steps...")
        # train() returns a TrainOutput; the trainer itself exposes neither
        # `progress_bar` nor `training_loss`, so read the loss from the result.
        train_output = trainer.train()
        global_step += steps_to_run
        print(f"Finished {steps_to_run} RL steps, Global Step={global_step}")
        print(f"Chunk Training Loss={train_output.training_loss:.4f}")

        # RL Validation
        if global_step >= next_eval_step:
            next_eval_step = global_step + EVAL_EVERY
            print(f"Running RL Validation at step {global_step}")
            # evaluate_rl lives elsewhere in the notebook; presumably it
            # returns a scalar mean reward - verify against its definition.
            val_reward = evaluate_rl(trainer, val_trl_data, num_samples=200)
            print(f"[RL Validation] Global Step={global_step}, Val Reward={val_reward:.4f}")

            # Early stopping logic
            if val_reward > best_val_reward:
                best_val_reward = val_reward
                no_improvement_count = 0
                print("New best RL reward! Saving checkpoint...")
                accelerator.unwrap_model(trainer.model).save_pretrained("checkpoint_best_rl")
            else:
                no_improvement_count += 1
                print(f"No improvement. Count={no_improvement_count}")
                if no_improvement_count >= PATIENCE:
                    print("Early stopping triggered (RL val reward not improving).")
                    stop_early = True

    print(f"RL Training Complete! Best Val Reward={best_val_reward:.4f}")
    accelerator.unwrap_model(trainer.model).save_pretrained(RL_SAVE_DIR)
    print(f"RL Final Model Saved in: {RL_SAVE_DIR}")


In [None]:
# Entry point: start GRPO training when this cell/script is executed directly.
if __name__ == "__main__":
    main_grpo()

Using device: cuda
CUDA Device Capability: (8, 6)
Train size: 333031, Val size: 37004


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded successfully on cuda:0
Model Device: cuda:0
Accelerator Device: cuda
Tokenizer vocab size = 151665
Model config vocab size = 152064
Trainer Generation Config: GenerationConfig {
  "max_new_tokens": 512,
  "pad_token_id": 151643,
  "temperature": 0.9,
  "top_k": null
}

Training 100 steps...


[34m[1mwandb[0m: Currently logged in as: [33mbalaji-vir1997[0m ([33mbalaji-vir1997-stevens-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
!nvidia-smi