In [None]:
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
# Allow TF32 math for CUDA matmuls and for cuDNN operations. This trades a
# little float32 precision for speed on GPUs that support TF32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True 

from huggingface_hub import login

from reward import ZweigRewardFunction
from constants import DEVICE, DATASET_NAME, TEST_SIZE, MODEL_PATH, CHECKPOINT_PATH, REFERENCES

In [None]:
import os

# Authenticate against the Hugging Face Hub twice, in this order: first with the
# read token, then with the write token. Both tokens are read from environment
# variables, so a missing variable fails right here with KeyError.
# NOTE(review): the second login presumably replaces the first as the active
# credential -- confirm whether the read-token login is still needed.
for message, token_var in (
    ("Logging in with read token", "HF_READ_TOKEN"),
    ("Logging in with write token", "HF_WRITE_TOKEN"),
):
    print(message)
    login(token=os.environ[token_var])

In [None]:
# Report how many CUDA devices PyTorch can see.
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")

# For each device, print its name followed by its memory figures.
for i in range(num_gpus):
    gpu = torch.cuda.get_device_properties(i)
    print(f"\nGPU {i}: {gpu.name}")

    # Convert the three byte counts (device total, allocated, reserved) to
    # units of 1024**3 bytes in one pass; the order matches the original queries.
    total_memory, memory_allocated, memory_reserved = (
        n_bytes / 1024**3
        for n_bytes in (
            gpu.total_memory,
            torch.cuda.memory_allocated(i),
            torch.cuda.memory_reserved(i),
        )
    )

    print(f"Total memory: {total_memory:.2f} GB")
    print(f"Allocated memory: {memory_allocated:.2f} GB")
    print(f"Reserved memory: {memory_reserved:.2f} GB")
    # "Free" here is simply total minus the allocated figure above.
    # NOTE(review): this presumably ignores memory held by other processes -- confirm
    # whether a driver-level figure is wanted instead.
    print(f"Free memory: {total_memory - memory_allocated:.2f} GB")

In [None]:
def format_dataset_for_rl(dataset):
    """Turn a dataset into prompt-only rows for RL training.

    Every input row must provide ``system_prompt`` and ``prompt`` fields. They
    are wrapped in ``<|start_of_role|>...<|end_of_role|>`` role markers, each
    turn is closed with ``<|end_of_text|>`` and a newline, and an assistant turn
    is left open, already starting with ``<stefan_zweig>``.

    Args:
        dataset: Object exposing ``map`` and ``column_names``.

    Returns:
        The result of ``dataset.map``, called with a row formatter that yields
        ``{"prompt": ...}`` and with all existing column names passed as
        ``remove_columns``.
    """
    def _to_prompt_row(row):
        # Assemble the three turns in order: system, user, opened assistant turn.
        segments = (
            f"<|start_of_role|>system<|end_of_role|>{row['system_prompt']}<|end_of_text|>\n",
            f"<|start_of_role|>user<|end_of_role|>{row['prompt']}<|end_of_text|>\n",
            "<|start_of_role|>assistant<|end_of_role|><stefan_zweig>",
        )
        return {"prompt": "".join(segments)}

    return dataset.map(_to_prompt_row, remove_columns=dataset.column_names)

In [None]:
def create_rl_datasets(seed=42):
    """Load the source dataset and return formatted (train, eval) prompt datasets.

    The ``"default"`` configuration of ``DATASET_NAME`` is loaded, its ``train``
    split is partitioned with ``TEST_SIZE``, and both parts are passed through
    ``format_dataset_for_rl``.

    Args:
        seed: Seed forwarded to ``train_test_split`` so the train/eval partition
            is identical on every run (for example after a kernel restart).
            Without it, examples can move between the train and eval sets from
            one run to the next. Pass ``None`` to get a fresh random split on
            each call, which was the previous behaviour.

    Returns:
        Tuple ``(train, test)`` of the formatted datasets.
    """
    raw = load_dataset(DATASET_NAME, "default")["train"]
    splits = raw.train_test_split(test_size=TEST_SIZE, seed=seed)
    return (
        format_dataset_for_rl(splits["train"]),
        format_dataset_for_rl(splits["test"]),
    )

In [None]:
def load_models_and_tokenizer():
    """Load the policy model and the tokenizer used for RL fine-tuning.

    The model weights are read from ``CHECKPOINT_PATH`` in bfloat16 and placed
    according to ``DEVICE``; the tokenizer is read from ``MODEL_PATH``.

    Returns:
        Tuple ``(model, tokenizer)``. The model is in training mode with every
        parameter flagged as trainable. The tokenizer pads on the left and uses
        its EOS token as the padding token.
    """
    # Plain causal LM (no value head).
    policy = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH,
        torch_dtype=torch.bfloat16,
        device_map=DEVICE,
        low_cpu_mem_usage=True,
    )
    # train() returns the module itself, so the two switches can be chained;
    # requires_grad_(True) flags every parameter of the module as trainable.
    policy.train().requires_grad_(True)

    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    tok.padding_side = "left"  # Critical for generation
    tok.pad_token = tok.eos_token  # reuse EOS as the padding token

    return policy, tok

In [None]:
# Load the model and tokenizer once; the resulting `model` and `tokenizer`
# globals are used by the cells that follow.
print("Loading models and tokenizer...")
model, tokenizer = load_models_and_tokenizer()

In [None]:
# Build the formatted train/eval prompt datasets that are later handed to the trainer.
print("Creating datasets...")
train_dataset, eval_dataset = create_rl_datasets()

In [None]:
# Instantiate the project's reward function from the tokenizer and REFERENCES;
# `reward_func` is later smoke-tested and passed to the trainer.
print("Initializing reward function...")
reward_func = ZweigRewardFunction(tokenizer, REFERENCES)

In [None]:
# Fail fast if the model ended up with every parameter frozen.
has_trainable_params = any(param.requires_grad for param in model.parameters())
assert has_trainable_params, "No trainable parameters!"

In [None]:
# Smoke-test the model before training: tokenize a tiny prompt, move it to the
# model's device and run one forward pass with autograd disabled. The logits
# must not be tracking gradients.
test_input = tokenizer("Test prompt:", return_tensors="pt")
test_input = test_input.to(model.device)
with torch.no_grad():
    output = model(**test_input)
assert not output.logits.requires_grad, "Unexpected gradient in test pass"

# Smoke-test the reward function on a single (prompt, response) pair: it has to
# return a torch.Tensor located on the same device as the model.
test_reward = reward_func(["Test prompt"], ["Test response"])
assert isinstance(test_reward, torch.Tensor), "Reward should return tensor"
assert test_reward.device == model.device, "Device mismatch"

In [None]:
# Training configuration for GRPO, grouped by concern. Keyword order has no
# effect on the resulting config.
grpo_config = GRPOConfig(
    # --- Output and logging ---
    output_dir="stefan_zweig_RL",
    report_to="none",
    logging_steps=10,
    remove_unused_columns=True,
    # --- Optimisation ---
    learning_rate=1e-5,
    optim="adamw_torch_fused",  # fused AdamW implementation
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # raised to compensate for the batch size of 1
    gradient_checkpointing=True,  # activation checkpointing
    # --- GRPO sampling ---
    beta=0.04,
    num_generations=4,  # reduced from 8
    temperature=0.9,
    max_prompt_length=384,  # reduced from 512
    max_completion_length=384,  # reduced from 512
    # --- Numeric precision ---
    bf16=True,  # bfloat16 instead of fp16
    fp16=False,
    tf32=True,  # enable TensorFloat-32
)

In [None]:
# Assemble the trainer from the objects prepared in the earlier cells.
# NOTE(review): grpo_config sets no evaluation strategy, so eval_dataset is
# presumably never evaluated during training -- confirm this is intended.
grpo_trainer = GRPOTrainer(
    args=grpo_config,
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_func,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Start GRPO training.
# No outer autocast context is used here. Mixed precision is already requested
# through the trainer arguments (bf16=True in grpo_config) and the model is
# loaded in bfloat16. PyTorch recommends that autocast wrap only the forward
# pass and loss computation; wrapping the whole train() call would also put
# backward passes and optimizer steps under autocast.
grpo_trainer.train()

In [None]:
# Upload the trainer's output to the Hugging Face Hub.
# NOTE(review): presumably relies on the write-token login performed earlier -- confirm.
grpo_trainer.push_to_hub()