In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
os.chdir('/content/drive/MyDrive/rl_llms')
# !pip install -U trl

!pip install -U bitsandbytes
!pip install wandb
!pip install trl[vllm]
import trl
import wandb
wandb_api_key = "UR wandb api key"
wandb.login(key = wandb_api_key)
print(trl.__version__)
print('Google Drive mounted and set as current working directory.')
!ls
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



# Loading in model and datasets

Here we load the previously fine-tuned Qwen model and its corresponding tokenizer.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import os

# Checkpoint directory of the previously fine-tuned model (relative to the
# current working directory) and the Hub id of the base model.
ckpt_dir  = "Qwen2_0_5B_Instruct_2025-11-21 20:42:36_xLAM_2/checkpoint-250"
MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"

# Read the Hub token from the environment instead of hardcoding it in the
# notebook. It is None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load the tokenizer once (the original cell loaded it twice).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse EOS so batched inputs can be padded.
    tokenizer.pad_token = tokenizer.eos_token


fine_tuned_model = AutoModelForCausalLM.from_pretrained(
    ckpt_dir,
    torch_dtype=torch.bfloat16,  # or float16 / float32 depending on what you want
    device_map="auto",
    token = HF_TOKEN,
)

# Use the base model's generation settings for the fine-tuned checkpoint.
gen_config = GenerationConfig.from_pretrained(MODEL_NAME)
fine_tuned_model.generation_config = gen_config

# Creating Training Configs

These training configs are taken from Behrooz Azarkhali's cookbook (linked below) and are kept consistent with the ones used for SFT.

 https://huggingface.co/learn/cookbook/en/function_calling_fine_tuning_llms_on_xlam



In [None]:
from dataclasses import dataclass
import time
from peft import LoraConfig
@dataclass
class ModelConfig:
    """Model- and tokenizer-level settings bundled into a single record.

    Every field is required; none of them has a default value.
    """
    # HuggingFace model identifier
    model_name: str
    # Padding token for the tokenizer
    pad_token: str
    # Numerical ID for the padding token
    pad_token_id: int
    # Side to add padding ('left' or 'right')
    padding_side: str
    # End of sequence token
    eos_token: str
    # End of sequence token ID
    eos_token_id: int
    # Vocabulary size
    vocab_size: int
    # Model architecture type
    model_type: str

@dataclass
class TrainingConfig:
    """Training hyperparameters.

    Only ``output_dir`` is required; every other field has a default.
    """
    # Directory to save model checkpoints
    output_dir: str
    # Training batch size per device
    batch_size: int = 16
    # Steps to accumulate gradients
    gradient_accumulation_steps: int = 8
    # Learning rate for optimization
    learning_rate: float = 1e-4
    # Maximum training steps
    max_steps: int = 1000
    # Maximum sequence length
    max_seq_length: int = 2048
    # LoRA rank parameter
    lora_r: int = 16
    # LoRA alpha scaling parameter
    lora_alpha: int = 16
    # LoRA dropout rate
    lora_dropout: float = 0.05
    # Steps between checkpoint saves
    save_steps: int = 250
    # Steps between log outputs
    logging_steps: int = 10
    # Warmup ratio for learning
    warmup_ratio: float = 0.1

def create_training_config(model_name: str, **kwargs) -> TrainingConfig:
    """Build a ``TrainingConfig`` with a timestamped default output directory.

    Args:
        model_name: Model identifier; only the part after the last '/' is
            used, with '-' and '.' replaced by '_'.
        **kwargs: ``TrainingConfig`` field overrides. Passing ``output_dir``
            replaces the generated default.

    Returns:
        TrainingConfig: The populated configuration.
    """
    short_name = model_name.rsplit('/', 1)[-1]
    for separator in ('-', '.'):
        short_name = short_name.replace(separator, '_')

    # Local wall-clock time; note the stamp contains a space and colons.
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))

    # Start from the generated default, then let caller overrides win.
    settings = {'output_dir': f"./{short_name}_{timestamp}_PPO_xLAM"}
    settings.update(kwargs)
    return TrainingConfig(**settings)
def create_lora_config(training_config: TrainingConfig) -> LoraConfig:
    """
    Create LoRA configuration for parameter-efficient fine-tuning.

    LoRA (Low-Rank Adaptation) adds small trainable matrices to specific
    model layers while keeping the base model frozen.

    Args:
        training_config (TrainingConfig): Training configuration with LoRA parameters

    Returns:
        LoraConfig: Configured LoRA adapter settings

    LoRA Parameters:
        - r (rank): Dimensionality of adaptation matrices (higher = more capacity)
        - alpha: Scaling factor for LoRA weights
        - dropout: Regularization to prevent overfitting
        - target_modules: Which model layers to adapt
    """
    # The status messages below previously contained mis-decoded bytes
    # (UTF-8 emoji read with the wrong encoding); the emoji are restored.
    print("⚙️ Configuring LoRA adapters...")

    # Target modules for both Llama and Qwen architectures
    target_modules = [
        'k_proj', 'q_proj', 'v_proj', 'o_proj',  # Attention projections
        "gate_proj", "down_proj", "up_proj"       # Feed-forward projections
    ]

    lora_config = LoraConfig(
        lora_alpha=training_config.lora_alpha,
        lora_dropout=training_config.lora_dropout,
        r=training_config.lora_r,
        bias="none",                             # Don't adapt bias terms
        task_type="CAUSAL_LM",                   # Causal language modeling
        target_modules=target_modules
    )

    print(f"🎯 LoRA targeting modules: {target_modules}")
    print(f"📊 LoRA parameters: r={training_config.lora_r}, alpha={training_config.lora_alpha}")

    return lora_config



# Loading in Datasets

We load the datasets that were already split into train and eval sets, for use in this next round of training.


In [None]:
from datasets import load_dataset, load_from_disk
from datasets import DatasetDict



# Load the two splits that were previously saved under ./data
# (paths are relative to the current working directory).
train_dataset = load_from_disk("./data/train")
test_dataset = load_from_disk("./data/test")

# Bundle both splits into one DatasetDict keyed by split name.
dataset_dict = DatasetDict({
    "train": train_dataset,
    "test":  test_dataset,
})
# For GRPO, we want the raw datasets with a "prompt" column
grpo_train_dataset = dataset_dict["train"]
grpo_eval_dataset  = dataset_dict["test"]

# Sanity check: show the column names of the first training record.
print(grpo_train_dataset[0].keys())  # should include "prompt"

# Print the first training record in full for visual inspection.
sample = grpo_train_dataset[0]
print(sample)




# Start training

- Now we begin using GRPO to iron out the formatting


## Results from Fine-Tuning Training

| Step | Training Loss | Validation Loss | Entropy   | Num Tokens      | Mean Token Accuracy |
|------|---------------|-----------------|-----------|------------------|----------------------|
| 100  | 0.086300      | 0.080959        | 1.079421  | 5,140,499        | 0.977191             |
| 200  | 0.063900      | 0.060436        | 1.038168  | 10,289,647       | 0.983178             |





## Reward Mechanism Summary

The reward function scores how well the model outputs a valid `<calls>...</calls>` block using allowed tools and parameters. The score is kept within `[-1, 1]`.

### What Increases Reward
- **A valid `<calls>` block is present**
- **Tool calls parse correctly** as Python dictionaries
- **No hallucinated tools** (all tools are in the `<tools>` list)
- **No hallucinated parameters** (arguments match allowed params)
- **Covers all user IDs** referenced in `<user>...</user>`
- **No extra text** before or after `<calls>` block
- **Clean formatting** (minor style bonuses)

### What Decreases Reward
- Missing `<calls>` block (large penalty)
- Extra natural language or random text around the block
- Incorrect or unparsable tool calls
- Using tools not listed in `<tools>`
- Using parameters that aren’t allowed
- Referencing IDs that weren’t provided by the user
- Very long completions

### Philosophy
- Good structure = strong positive reward  
- Mistakes = proportional negative reward  
- All rewards are clamped to `[-1, 1]`

Reward Model

In [None]:
import re
import json
import ast

# ===== compile regexes once =====
# Each *_block_re captures the text between its opening and closing tag.
# The capture is non-greedy (stops at the first closing tag) and re.DOTALL
# lets the captured body span multiple lines.
user_block_re  = re.compile(r"<user>(.*?)</user>", re.DOTALL)
calls_block_re = re.compile(r"<calls>(.*?)</calls>", re.DOTALL)
tools_block_re = re.compile(r"<tools>(.*?)</tools>", re.DOTALL)

# id_re:    a run of five or more digits wrapped in single or double quotes.
# is_id_re: the quoted value that follows a quoted `is_id` key and a colon.
id_re    = re.compile(r"[\"'](\d{5,})[\"']")
is_id_re = re.compile(r"[\"']is_id[\"']\s*:\s*[\"']([^\"']+)[\"']")

# EOS string of the tokenizer defined in an earlier cell; the reward function
# compares the text after </calls> against it so a trailing EOS is tolerated.
end_token = tokenizer.eos_token
print("end token is ", end_token)


def _extract_user_ids(text: str):
    """Collect the IDs matched by ``id_re`` inside the first <user>...</user> block.

    Returns:
        set[str]: The distinct ID strings, or an empty set when ``text``
        contains no <user> block.
    """
    user_match = user_block_re.search(text)
    return set(id_re.findall(user_match.group(1))) if user_match else set()


def _extract_called_ids(calls_text: str):
    """Collect every ``is_id`` value found in the body of a calls block.

    Returns:
        set[str]: The distinct values, or an empty set for falsy input.
    """
    return set(is_id_re.findall(calls_text)) if calls_text else set()


def _safe_parse_python_dict(line: str):
    """
    Safely parse a single Python dict literal using ast.literal_eval.
    Returns a dict or None.
    """
    try:
        obj = ast.literal_eval(line)
        if isinstance(obj, dict):
            return obj
    except Exception:
        return None
    return None


def _extract_allowed_tools(prompt: str):
    """
    Parse the <tools>...</tools> block in the prompt and return:
        { tool_name: {param1, param2, ...}, ... }

    Assumes each tool is a Python-style dict on its own line, e.g.:

    <tools>{'name': 'locations_v2_list', ...}
    {'name': 'detail', ...}
    ...</tools>

    Malformed specs never raise:
        - a line that does not parse to a dict is skipped;
        - a spec with a missing/empty or unhashable `name` is skipped;
        - a spec whose `parameters` is not a dict is registered with an
          empty parameter set (the tool is known, no parameter is allowed).

    Returns:
        dict: Mapping of tool name to the set of its parameter names;
        empty when the prompt has no <tools> block.
    """
    m = tools_block_re.search(prompt)
    if not m:
        return {}

    tools_text = m.group(1).strip()
    lines = [ln.strip() for ln in tools_text.split("\n") if ln.strip()]

    allowed = {}
    for line in lines:
        obj = _safe_parse_python_dict(line)
        if obj is None:
            # skip bad tool lines, but don't kill everything
            continue

        name = obj.get("name")
        if not name:
            continue

        params = obj.get("parameters", {})
        # `parameters` may be None, a string or a list in a malformed spec;
        # calling .keys() on those would raise AttributeError.
        param_names = set(params) if isinstance(params, dict) else set()

        try:
            allowed[name] = param_names
        except TypeError:
            # Unhashable name (e.g. a list or dict literal): skip this spec.
            continue

    return allowed


def _parse_calls_body_to_dict_list(calls_body: str):
    """
    Turn the text between <calls> and </calls> into a list of dicts.

    The body is expected to hold one Python-style dict per line, e.g.:

        {'name': 'shares_float', 'arguments': {'symbol': 'V'}}
        {'name': 'stock_balance_sheet_stock', 'arguments': {'symbol': 'MA'}}

    Blank lines are ignored.

    Returns:
        list[dict] when every non-blank line parses (an empty list for an
        empty or whitespace-only body),
        None as soon as one line fails to parse.
    """
    body = calls_body.strip()
    if not body:
        return []

    calls = []
    for raw_line in body.split("\n"):
        candidate = raw_line.strip()
        if not candidate:
            continue
        call = _safe_parse_python_dict(candidate)
        if call is None:
            # One bad line invalidates the whole block.
            return None
        calls.append(call)

    return calls


def _compute_single_reward(prompt: str, completion: str, debug: bool = False) -> float:
    """
    Score one completion against its prompt. The result is clamped to [-1, 1].

    Scoring, applied in this order starting from 0.0:
        - empty completion                       -> return -1.0 immediately
        - no <calls>...</calls> block            -> -0.7, return immediately
        - text before the block                  -> -0.2
        - text after the block (other than EOS)  -> -0.2
        - "<response>" present (case-insensitive)-> -0.3
        - calls body fails to parse              -> -0.6
        - at least one call dict parsed          -> +0.6
        - each call to an unknown tool           -> -0.2
        - each unknown parameter                 -> -0.1
        - calls present and no hallucinations    -> +0.2
        - ID coverage (only when the prompt's <user> block has IDs):
              +0.2 * recall, +0.1 when recall is 1.0,
              -0.1 per called ID that the user did not provide
        - "{" and "}" in calls body              -> +0.05
        - "name" in calls body                   -> +0.05
        - completion longer than 600 characters  -> -0.1

    The completion is untrusted model output, so malformed calls are
    penalized rather than allowed to raise:
        - a call whose `name` is unhashable counts as a hallucinated tool;
        - a call to a known tool whose `arguments` is not a dict counts as
          one hallucinated parameter.

    Args:
        prompt: Full prompt text (read for its <tools> and <user> blocks).
        completion: Generated text to score.
        debug: When True, print a step-by-step trace of the scoring.

    Returns:
        float: The clamped reward.
    """
    if debug:
        print("\n================ REWARD DEBUG ================")
        print("RAW COMPLETION:")
        print(repr(completion))

    completion = completion.strip()

    # Start from neutral
    reward = 0.0

    # 0. Handle truly empty completions (hard bottom)
    if not completion:
        if debug:
            print("[WARN] Empty completion -> reward = -1.0")
        return -1.0

    # 1. Check for <calls>...</calls> block
    m = calls_block_re.search(completion)
    if not m:
        # Big penalty for missing calls block (but not autofail beyond this)
        reward -= 0.7
        if debug:
            print("[PENALTY] No <calls>...</calls> block found -> -0.7")
        # Nothing else to score, just clamp and return
        return max(min(reward, 1.0), -1.0)

    # 2. Check that completion is basically just the calls block (plus optional EOS)
    #    Instead of failing, we apply penalties for extra text.
    if not (completion.startswith("<calls>") and completion.endswith("</calls>")):
        # Both splits are safe: the regex match above guarantees that
        # "<calls>" and "</calls>" each occur at least once.
        prefix = completion.split("<calls>", 1)[0].strip()
        suffix = completion.split("</calls>", 1)[1].strip()

        if prefix:
            reward -= 0.2
            if debug:
                print(f"[PENALTY] Text before <calls> block: {repr(prefix)} -> -0.2")

        # tolerate EOS after </calls>, penalize anything else
        if suffix and suffix != end_token:
            reward -= 0.2
            if debug:
                print(f"[PENALTY] Text after </calls> block: {repr(suffix)} -> -0.2")

    # 3. Penalize if the model tries to answer in natural language style
    if "<response>" in completion.lower():
        reward -= 0.3
        if debug:
            print("[PENALTY] Found <response> tag -> -0.3")

    calls_body = m.group(1)

    # 4. Parse calls body
    parsed_calls = _parse_calls_body_to_dict_list(calls_body)
    if parsed_calls is None:
        # Can't parse -> strong negative, but not hard fail
        reward -= 0.6
        if debug:
            print("[PENALTY] Calls block parse failed -> -0.6")
            print("Calls body:")
            print(calls_body)
    else:
        if parsed_calls:
            # Some valid dicts parsed: give a positive base
            reward += 0.6
            if debug:
                print("[REWARD] Parsed calls dicts -> +0.6")
                print("Parsed calls:", parsed_calls)

    # 5. Allowed tools / params from prompt
    allowed_tools = _extract_allowed_tools(prompt)
    if debug:
        print("\nALLOWED TOOLS FROM PROMPT:")
        print(allowed_tools)

    hallucinated_tools = 0
    hallucinated_params = 0

    if parsed_calls:
        for call in parsed_calls:
            tool_name = call.get("name")
            args      = call.get("arguments", {})

            # `name` can be any literal the model produced, including an
            # unhashable list/dict, which would make the membership test
            # raise TypeError. Such a name cannot be a known tool.
            try:
                known_tool = tool_name in allowed_tools
            except TypeError:
                known_tool = False

            # Unknown / hallucinated tool
            if not known_tool:
                hallucinated_tools += 1
                if debug:
                    print(f"[HALLUCINATION] Tool '{tool_name}' not in allowed tools.")
                continue

            # `arguments` must be a dict of parameter -> value. Anything else
            # (None, a string, a list, ...) is malformed: count it as one
            # hallucinated parameter instead of crashing on `.keys()`.
            if not isinstance(args, dict):
                hallucinated_params += 1
                if debug:
                    print(f"[HALLUCINATION] Arguments for tool '{tool_name}' are not a dict: {args!r}")
                continue

            allowed_params = allowed_tools[tool_name]
            for p in args:
                if p not in allowed_params:
                    hallucinated_params += 1
                    if debug:
                        print(f"[HALLUCINATION] Param '{p}' not allowed for tool '{tool_name}'.")

        # Apply penalties for hallucinations
        if hallucinated_tools > 0:
            penalty = 0.2 * hallucinated_tools
            reward -= penalty
            if debug:
                print(f"\n[TOOL HALLUCINATION PENALTY] -{penalty:.3f} for {hallucinated_tools} hallucinated tools.")
                print(f"Reward now: {reward}")

        if hallucinated_params > 0:
            penalty = 0.1 * hallucinated_params
            reward -= penalty
            if debug:
                print(f"[PARAM HALLUCINATION PENALTY] -{penalty:.3f} for {hallucinated_params} hallucinated params.")
                print(f"Reward now: {reward}")

        # Small bonus if we have calls and *no* hallucinations
        # (parsed_calls is known to be non-empty inside this branch).
        if hallucinated_tools == 0 and hallucinated_params == 0:
            reward += 0.2
            if debug:
                print("[REWARD] No hallucinated tools/params -> +0.2")
                print(f"Reward now: {reward}")

    # 6. Old ID coverage logic (soft)
    user_ids   = _extract_user_ids(prompt)
    called_ids = _extract_called_ids(calls_body)

    if debug:
        print("\nUSER IDS (from <user> block):", user_ids)
        print("CALLED IDS (from <calls> body):", called_ids)
        trimmed_body = calls_body[:300] + ("..." if len(calls_body) > 300 else "")
        print("CALLS BODY (trimmed):")
        print(trimmed_body)

    if user_ids:
        covered = user_ids.intersection(called_ids)
        # user_ids is non-empty here, so the division is always defined.
        recall = len(covered) / len(user_ids)

        # Scale coverage contribution modestly
        coverage_bonus = 0.2 * recall
        reward += coverage_bonus
        if debug:
            print("\n[COVERAGE]")
            print("  Covered IDs:", covered)
            print(f"  Recall = {recall:.3f}")
            print(f"  +{coverage_bonus:.3f} for coverage.")
            print(f"Reward now: {reward}")

        if recall == 1.0:
            reward += 0.1
            if debug:
                print("  [BONUS] All user IDs covered -> +0.1")
                print(f"Reward now: {reward}")

        hallucinations = called_ids - user_ids
        if hallucinations:
            penalty = 0.1 * len(hallucinations)
            reward -= penalty
            if debug:
                print("\n[ID HALLUCINATIONS]")
                print("  Hallucinated IDs:", hallucinations)
                print(f"  -{penalty:.3f} for hallucinated IDs.")
                print(f"Reward now: {reward}")

    # 7. Light style bonuses (very small)
    if "{" in calls_body and "}" in calls_body:
        reward += 0.05
        if debug:
            print("\n[STYLE] +0.05 for having braces in calls body.")
            print(f"Reward now: {reward}")
    if "name" in calls_body:
        reward += 0.05
        if debug:
            print("[STYLE] +0.05 for 'name' in calls body.")
            print(f"Reward now: {reward}")

    # 8. Length penalty (tiny)
    if len(completion) > 600:
        reward -= 0.1
        if debug:
            print("\n[LENGTH] Completion too long (>600 chars) -> -0.1")
            print(f"Reward now: {reward}")

    # 9. Clamp to [-1, 1]
    clamped = max(min(reward, 1.0), -1.0)
    if debug:
        print("\n[FINAL]")
        print(f"Unclamped reward: {reward:.3f}")
        print(f"Clamped reward:   {clamped:.3f}")
        print("==============================================")
    return clamped


def calls_reward_func(prompts, completions, completion_ids, trainer_state, debug: bool = False, **kwargs):
    """Batch reward function: score each (prompt, completion) pair.

    Args:
        prompts: Prompts, aligned one-to-one with ``completions``.
        completions: Generated completions to score.
        completion_ids: Accepted but not used here.
        trainer_state: Accepted but not used here.
        debug: Forwarded to ``_compute_single_reward`` to print a trace.
        **kwargs: Any further keyword arguments are accepted and ignored.

    Returns:
        list[float]: One reward per pair, in input order.
    """
    return [
        float(_compute_single_reward(prompt, completion, debug=debug))
        for prompt, completion in zip(prompts, completions)
    ]


# Creating Training and LoraConfigs

In [None]:

# Build the hyperparameter record. No overrides are passed, so every field
# keeps its TrainingConfig default and output_dir is the generated,
# timestamped path.
training_config: TrainingConfig = create_training_config(MODEL_NAME)
print("creating training config", type(training_config))

# Derive the LoRA adapter settings (r, alpha, dropout) from that record.
peft_config: LoraConfig = create_lora_config(training_config)
print("creating peft config", type(peft_config))




## Creating GRPO Config


### Training Basics
- **batch_size=32**, **grad_accum=2** → effective batch = 64
- **epochs=1**, **lr=1e-6** → very gentle updates
- **bf16=True** ‚Üí faster training, less memory

### GRPO-Specific
- **beta=0.05** → KL penalty strength (keeps model close to base policy)
- **num_generations=8** → for each prompt, generate 8 candidates → compute rewards → update policy
- **scale_rewards=True** → normalizes reward distribution for stability

### Logging & Saving
- Logs every **10 steps** to W&B  
- Saves checkpoint every **50 steps**, keeps last **5**

### Data & Length
- **max_prompt_length=1024**, **max_completion_length=512**
- Uses **3 dataloader workers** + pinned memory → faster input pipeline

### vLLM Integration
- **use_vllm=True**, **vllm_mode="colocate"** → much faster generation for GRPO (used to speed up the generation step)


In [None]:

from trl import GRPOTrainer, GRPOConfig

# Use a fresh output dir for GRPO runs
# (the generated default path ends in "PPO_xLAM"; swap that suffix so GRPO
# checkpoints land in their own directory).
grpo_output_dir = training_config.output_dir.replace("PPO_xLAM", "GRPO_xLAM")

# GRPO run settings. The values mirror the "Creating GRPO Config" notes above:
# optimisation (batch size, accumulation, epochs, learning rate, beta),
# logging to W&B, checkpointing, data loading, sequence lengths, and
# generation through vLLM.
grpo_config = GRPOConfig(
    output_dir=grpo_output_dir,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    learning_rate=1e-6,
    beta = 0.05,
    logging_strategy="steps",
    logging_steps=10,
    report_to="wandb",
    run_name="qwen2-grpo-xlam",
    save_strategy="steps",
    save_steps=50,
    save_total_limit=5,
    bf16=True,
    remove_unused_columns=False,
    dataloader_num_workers=3,
    dataloader_pin_memory=True,
    max_prompt_length=1024,
    max_completion_length=512,
    num_generations=8,
    scale_rewards=True,
    use_vllm = True,
    vllm_mode = "colocate"
)




# Training and Model Saving


In [None]:
from trl import PPOConfig, PPOTrainer
from datasets import Dataset

# Record the installed TRL version alongside the run output.
print(trl.__version__)

# Show which device holds the model, and specifically its embedding weights
# (the model was loaded with device_map="auto").
print(fine_tuned_model.device)
print(fine_tuned_model.model.embed_tokens.weight.device)
def print_mem(label=""):
    """Print the CUDA memory currently allocated and reserved, tagged with `label`.

    Byte counts are divided by 1024**3 and shown with two decimals.
    """
    bytes_per_unit = 1024**3
    alloc = torch.cuda.memory_allocated() / bytes_per_unit
    reserved = torch.cuda.memory_reserved() / bytes_per_unit
    print(f"[{label}] allocated={alloc:.2f} GB, reserved={reserved:.2f} GB")

# Assemble the GRPO trainer from the pieces built in the earlier cells:
# the fine-tuned model, the GRPO config, the custom reward function, the
# train/eval datasets, the tokenizer, and the LoRA config.
trainer = GRPOTrainer(
    model=fine_tuned_model,        # your Qwen SFT checkpoint
    args=grpo_config,
    reward_funcs=calls_reward_func,
    train_dataset=grpo_train_dataset,
    eval_dataset=grpo_eval_dataset,
    processing_class=tokenizer,    # tokenizer with left padding + pad_token set
    peft_config=peft_config,       # your LoRA config from create_lora_config(...)
)
# Report GPU memory once the trainer has been constructed.
print_mem("after init")
# Run training, then evaluate on the eval dataset.
# NOTE(review): no explicit final save is issued here; checkpoints rely on
# save_steps in grpo_config — confirm the final model is persisted.
trainer.train()
trainer.evaluate()
