In [None]:
import wandb
import statistics
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# 1) Initialize W&B
wandb.init(project="my_grpo_project", name="run-grpo-with-rewards")

# 2) Model to fine-tune
#    NOTE(review): `finetune_model` is not defined in this cell; presumably it
#    is created earlier in the session -- confirm before running standalone.
model = finetune_model

# 3) Prepare a dataset of raw-text prompts
#    Load the first 1% of the wikitext-2-raw-v1 train split, drop blank lines,
#    then draw a seeded random sample of at most 100 prompts.
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
# Keep only the 'text' column, rename it to 'prompt', and filter out blanks
ds = raw.remove_columns([c for c in raw.column_names if c != "text"])
ds = ds.rename_column("text", "prompt")
ds = ds.filter(lambda x: x["prompt"].strip() != "")  # drop empty prompts
# Take a small subset for demo. The size is capped at len(ds) so that select()
# does not fail with an out-of-range index if fewer than 100 rows remain after
# filtering.
train_dataset = ds.shuffle(seed=42).select(range(min(100, len(ds))))

# 4) Your custom scoring function
def my_custom_score(text: str) -> float:
    """Return the placeholder reward for *text*: its length, as a float."""
    # Placeholder metric only; swap in the real reward logic here.
    n_chars = len(text)
    return float(n_chars)

# 5) TRL-compatible reward fn that logs each step to console & W&B
def my_reward_fn(completions, **kwargs):
    """Score each completion and mirror the scores to the console and W&B.

    Args:
        completions: Generated items to score; each one is passed to
            ``my_custom_score`` and stored in the W&B table.
        **kwargs: Extra keyword arguments supplied by the caller. An optional
            ``step`` entry is forwarded to ``wandb.log``. When it is absent,
            ``trainer_state.global_step`` (if a ``trainer_state`` object is
            supplied) is used only to label the console line and table rows.

    Returns:
        A list with one float reward per completion, in input order.
    """
    rewards = [my_custom_score(gen) for gen in completions]

    # Only an explicitly supplied step is handed to wandb.log; otherwise W&B
    # keeps using its own auto-incrementing step counter.
    wandb_step = kwargs.get("step")

    # Label used for display purposes.
    # NOTE(review): recent TRL releases are documented to pass a
    # `trainer_state` keyword to reward functions -- confirm for the installed
    # version. getattr() keeps this a no-op when it is not provided.
    step = wandb_step
    if step is None:
        step = getattr(kwargs.get("trainer_state"), "global_step", None)

    # Console print
    print(f"[GRPO step {step}] rewards = {rewards}")

    # W&B Table
    table = wandb.Table(columns=["step", "sample_idx", "reward", "generation"])
    for idx, (r, gen) in enumerate(zip(rewards, completions)):
        table.add_data(step, idx, r, gen)

    payload = {"reward_table": table}

    # Aggregate scalars (skipped for an empty batch, where a mean is undefined
    # and the division would raise ZeroDivisionError).
    if rewards:
        payload["reward_mean"] = sum(rewards) / len(rewards)
        payload["reward_std"] = (
            statistics.stdev(rewards) if len(rewards) > 1 else 0.0
        )

    # One log call so the table and the scalars share a single W&B step; two
    # separate calls without an explicit step would each advance the counter.
    wandb.log(payload, step=wandb_step)

    return rewards

# 6) Configure GRPO: ensure batch sizes are compatible
config = GRPOConfig(
    output_dir="./grpo_outputs",
    report_to="wandb",
    logging_strategy="steps",
    logging_steps=1,
    # Batch-size overrides are currently disabled, so the library defaults
    # apply. Uncomment and adjust together if they need to be changed.
    # per_device_train_batch_size = 4,   # effective batch size = 4
    # generation_batch_size    = 4,      # must divide 4 evenly
)

# 7) Instantiate and run the trainer
#    NOTE(review): `tokenizer` is not defined in this cell; presumably it is
#    created alongside the model earlier in the session -- confirm.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,  # explicitly pass tokenizer
    reward_funcs=my_reward_fn,
    args=config,
    train_dataset=train_dataset,  # each example has a 'prompt' key now
)

# 8) Train, then always finish the W&B run
#    The finally clause closes the run even when training raises or is
#    interrupted, so a failed cell does not leave a run open. A non-zero exit
#    code is reported unless train() returned normally.
exit_code = 1
try:
    trainer.train()
    exit_code = 0
finally:
    wandb.finish(exit_code=exit_code)
