In [None]:
from __future__ import annotations

import os

import json
import random
from pathlib import Path
from typing import Dict, List, Tuple

import gymnasium as gym
import numpy as np
import torch
from datasets import Dataset
from gymnasium import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedTokenizerFast,
    LlamaForCausalLM
)
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback

In [None]:
# Set-up CUDA device
# Enumerate GPUs in PCI bus order so the indices below match `nvidia-smi`.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Restrict this process to physical GPUs 5, 6 and 7. CUDA reads this only when
# it is first initialized, so it must be set before any CUDA work in the kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6,7"
if torch.cuda.is_initialized():
    print("Warning: CUDA is already initialized in this kernel; "
          "the CUDA_VISIBLE_DEVICES change takes effect only after a restart.")

# Use GPU for inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print the device being used
print(f"Using device: {device}")

# Check the GPU name
if device.type == "cuda":
    # Visible devices are renumbered from 0, so index 0 is the first entry of
    # CUDA_VISIBLE_DEVICES (physical GPU 5 here).
    gpu_name = torch.cuda.get_device_name(0)
    print("Using GPU:", gpu_name)

In [None]:
# ---------------------------------------------------------------------------
# 1. Supervised Problem Type & Difficulty Classifier
# ---------------------------------------------------------------------------
def build_label_map(dataset, unknown_level: int = 6):
    """Enumerate every (problem type, difficulty level) pair as a class id.

    Args:
        dataset: columnar object where ``dataset["type"]`` and
            ``dataset["level"]`` return iterables of values (e.g. a
            ``datasets.Dataset`` or a dict of lists). Level values are expected
            to look like ``"Level 3"``.
        unknown_level: level assigned to values that cannot be parsed (e.g.
            ``"Level ?"``, or a non-string such as ``None``). Defaults to 6.

    Returns:
        label_map: id -> (type_str, level_int)
        encode:  (type_str, level_int) -> id  (inverse of label_map, for training)

    Ids cover the full Cartesian product of observed types and levels, ordered
    by type and then by level.
    """
    # unique problem categories (7 in the MATH dataset)
    types = sorted(set(dataset["type"]))

    def lvl_int(lvl_str) -> int:
        # 'Level 3' -> 3; unparseable or non-string values -> unknown_level
        try:
            return int(lvl_str.strip().split()[1])
        except (AttributeError, IndexError, ValueError):
            return unknown_level

    levels = sorted({lvl_int(lvl) for lvl in dataset["level"]})  # typically [1-5]

    pairs = [(ptype, level) for ptype in types for level in levels]
    label_map = dict(enumerate(pairs))
    encode = {pair: idx for idx, pair in label_map.items()}
    return label_map, encode

class ProblemClassifier:
    """Lightweight text classifier predicting a (problem type, level) pair.

    Fine-tune or plug in a checkpoint; its classification head must have
    ``len(label_map)`` outputs, one per label id.
    """

    def __init__(self, model_name: str, label_map: Dict[int, Tuple[str, int]]):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        # Size the head to the label map. Without num_labels a base checkpoint
        # gets the config default of 2 outputs, so argmax could only ever pick
        # label ids 0 and 1 out of the whole (type, level) grid.
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=len(label_map)
        )
        self.label_map = label_map  # id -> (type, level)
        self.model.eval()
        if torch.cuda.is_available():
            self.model.to("cuda")

    @torch.no_grad()
    def predict(self, text: str) -> Tuple[str, int]:
        """Return the (type, level) of the highest-scoring label for ``text``.

        Input longer than the tokenizer's maximum length is truncated.
        """
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True)
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        logits = self.model(**inputs).logits[0]
        label_id = int(logits.argmax())
        return self.label_map[label_id]

In [None]:
# ---------------------------------------------------------------------------
# 2. Reasoning Engine (frozen LM) with step‑wise generation
# ---------------------------------------------------------------------------
from prompt_template import PROMPT_PREFIX

class ReasoningEngine:
    """Generates step‑by‑step chain‑of‑thought until told to stop."""

    def __init__(self, model_dir: str, max_tokens_per_step: int = 32):
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir, padding_side="left")
        self.model = LlamaForCausalLM.from_pretrained(model_dir)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.eval().requires_grad_(False)
        if torch.cuda.is_available():
            self.model.to("cuda")
        self.max_tokens_per_step = max_tokens_per_step
        self.prefix = PROMPT_PREFIX

    def reset_prompt(self, problem_text):
        # Called at env.reset()
        self.prompt = (
            f"{self.prefix}\n"
            f"Problem: {problem_text}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        return self.prompt

    @torch.no_grad()
    def think_step(self, prompt: str) -> str:
        """Generate a *single* reasoning chunk."""
        inputs = self.tokenizer(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        out = self.model.generate(
            **inputs,
            max_new_tokens=self.max_tokens_per_step,
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        generated = self.tokenizer.decode(out[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True)
        return generated.strip()

    @staticmethod
    def extract_answer(text: str) -> str:
        from math_utils import last_boxed_only_string, normalize_final_answer, remove_boxed

        last_boxed_text = last_boxed_only_string(text)
        answer = normalize_final_answer(remove_boxed(last_boxed_text)) if last_boxed_text else None
        return answer

In [None]:
class MathEnv(gym.Env):
    """Gym-compatible environment where the agent decides to *think* or *stop*.

    Each episode is one dataset problem (visited round-robin).
    - Action 0 (CONTINUE) appends one reasoning chunk from the frozen engine and
      costs ``penalty_lambda``.
    - Any other action (STOP) grades the last boxed answer in the generated
      reasoning: +1 on an exact match, -1 otherwise, and an extra -2 when no
      boxed answer was produced.
    - Reaching ``max_steps`` forces a stop that is graded the same way.
    """

    # Gymnasium reads "render_modes"; "render.modes" is the legacy gym key,
    # kept for older tooling.
    metadata = {"render_modes": ["human"], "render.modes": ["human"]}

    def __init__(
        self,
        dataset_split: str = "train[:2048]",  # small default slice
        model_dir: str = "../../../llm/llama/Llama-3.2-1B-Instruct",
        classifier_name: str = "distilbert-base-uncased",
        penalty_lambda: float = 0.003,
        max_steps: int = 50,
        seed: int = 42,
    ):
        super().__init__()
        # Seed the global RNGs used by python, numpy and torch.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Dataset ------------------------------------------------------------
        # NOTE(review): load_local_math_dataset is not defined in the cells
        # shown here — confirm it is defined/imported in an earlier cell.
        self.data = load_local_math_dataset(dataset_split, root_dir="MATH")

        # Build label map over the (type, level) pairs present in the data
        self.label_map, self.encode = build_label_map(self.data)
        # Stable type -> index table for the observation. Python salts str
        # hashes per process, so hash(p_type) would change observations per run.
        self._type_ids: Dict[str, int] = {
            ptype: i
            for i, ptype in enumerate(sorted({ptype for ptype, _ in self.label_map.values()}))
        }

        # Components ---------------------------------------------------------
        self.classifier = ProblemClassifier(classifier_name, self.label_map)
        self.engine = ReasoningEngine(model_dir)

        # Hyper-parameters ---------------------------------------------------
        self.penalty_lambda = penalty_lambda
        self.max_steps = max_steps

        # Spaces -------------------------------------------------------------
        # Observation: hand-crafted features in [0, 1] – replace w/ embeddings if needed
        self.observation_space = spaces.Box(low=0, high=1, shape=(10,), dtype=np.float32)
        # Action: 0 = CONTINUE, 1 = STOP
        self.action_space = spaces.Discrete(2)

        # Internal state -----------------------------------------------------
        self.idx = -1
        self.problem: Dict = {}
        self.prompt = ""
        self.reasoning_trace: List[str] = []
        self.step_count = 0
        # Classifier prediction cached for dataset index self._pred_idx
        self._pred_idx: int | None = None
        self._pred: Tuple[str, int] | None = None

    # ---------------------------------------------------------------------
    # Gym API
    # ---------------------------------------------------------------------
    def reset(self, *, seed: int | None = None, options: Dict | None = None):
        """Advance to the next problem (round-robin) and start a fresh episode.

        Returns:
            (obs, info) where ``info["answer"]`` is the reference solution text.
        """
        if seed is not None:
            super().reset(seed=seed)
        self.idx = (self.idx + 1) % len(self.data)
        self.problem = self.data[self.idx]
        self.prompt = self.engine.reset_prompt(self.problem["problem"])  # chain-of-thought prefix
        self.reasoning_trace = []
        self.step_count = 0

        obs = self._build_obs()
        info = {"answer": self.problem["solution"]}
        return obs, info

    def step(self, action: int):
        """Apply one agent decision.

        Args:
            action: 0 = CONTINUE (generate another reasoning chunk); any other
                value = STOP and answer.

        Returns:
            (obs, reward, terminated, truncated, info) per the Gymnasium API.
            ``truncated`` is always False: hitting ``max_steps`` is a forced,
            graded STOP, so the episode terminates.
        """
        done = False
        reward = 0.0
        info: Dict = {}

        if action == 0:  # CONTINUE
            step_text = self.engine.think_step(self.prompt)
            self.reasoning_trace.append(step_text)
            self.prompt += step_text + "\n"
            self.step_count += 1
            # living penalty per think step
            reward -= self.penalty_lambda
            if self.step_count >= self.max_steps:
                # Forced stop: grade what was generated so far. Ending ungraded
                # would let a policy that never STOPs dodge the answer penalties.
                done = True
                info["forced_stop"] = True
                reward += self._grade_answer(info)
        else:  # STOP & answer
            done = True
            reward += self._grade_answer(info)

        # Terminal observation is all zeros (space.sample() would also consume
        # the space's RNG for no reason).
        obs = np.zeros(self.observation_space.shape, dtype=np.float32) if done else self._build_obs()
        return obs, reward, done, False, info  # Gymnasium API (v0.29)

    # ------------------------------------------------------------------
    def _grade_answer(self, info: Dict) -> float:
        """Grade the model's boxed answer, record it in ``info`` and return the reward.

        +1 for an exact (whitespace-stripped) match with the reference answer,
        -1 otherwise, plus an extra -2 when the model produced no boxed answer.
        """
        # Search only the model's own output: the full prompt also contains the
        # prompt prefix and problem text, and a \boxed{} there would be read as
        # the model's answer before the model has produced one.
        generated_answer = self.engine.extract_answer("\n".join(self.reasoning_trace))
        gt_answer = self.engine.extract_answer(self.problem["solution"])

        reward = 0.0
        if generated_answer is None:
            reward -= 2.0  # extra penalty: no final answer at all
            correct = False
        elif gt_answer is None:
            # Reference solution has no boxed answer, so nothing can match it.
            correct = False
        else:
            # Exact match; math_utils.is_equiv would be a more lenient option.
            correct = generated_answer.strip() == gt_answer.strip()

        reward += 1.0 if correct else -1.0
        info["correct"] = correct
        info["generated_answer"] = generated_answer

        print(f"\n🧠 Final prompt:\n{self.prompt}")
        print(f"✅ GT: {gt_answer}")
        print(f"📝 Model Response: {generated_answer}")
        print(f"🎯 Correct: {correct}, Reward: {reward}")
        return reward

    def _build_obs(self):
        """Cheap hand-crafted obs: [type, level, progress, ...zero padding], all in [0, 1].

        - vec[0]: index of the predicted type among the sorted dataset types,
          divided by the number of types.
        - vec[1]: predicted level / 5 (assumes a 1-5 scale), clipped to 1 so the
          unparseable-level fallback (6) stays inside the Box bounds.
        - vec[2]: fraction of the think-step budget used so far.
        """
        if self._pred_idx != self.idx:
            # The problem is fixed within an episode, so run the classifier
            # once per problem instead of on every step.
            self._pred = self.classifier.predict(self.problem["problem"])
            self._pred_idx = self.idx
        p_type, level = self._pred
        n_types = max(len(self._type_ids), 1)
        vec = np.zeros(self.observation_space.shape, dtype=np.float32)
        vec[0] = self._type_ids.get(p_type, 0) / n_types
        vec[1] = min(level / 5.0, 1.0)
        vec[2] = self.step_count / float(self.max_steps)
        return vec

    # ------------------------------------------------------------------
    def render(self):
        """Print the current problem and the reasoning chunks generated so far."""
        print("\nProblem:", self.problem["problem"])
        print("\nReasoning Trace:")
        for t in self.reasoning_trace:
            print("  >", t)