In [1]:
# trl is pinned because the training cell below relies on this release's
# PPOConfig / PPOTrainer.step API. NOTE(review): later trl versions reworked
# PPOTrainer — confirm compatibility before bumping the pin.
%pip install -q trl==0.10.1 tqdm

**Mount drive to get conversations and empathy scores**

In [None]:
from google.colab import drive
from pathlib import Path

BATCH = 1

drive.mount('/content/drive')

# Folder holding this batch's conversation JSON files (plus empathy_scores.json).
conversations = Path(f'/content/drive/My Drive/conversations/batch_{BATCH}')

n_json_files = sum(1 for _ in conversations.glob("*.json"))
print(f'Found {n_json_files:,} files in batch {BATCH}')

In [3]:
import json

# Per-conversation empathy ratings; the dataset cell looks these up by the
# conversation file's stem.
scores_path = conversations / 'empathy_scores.json'
empathy_scores = json.loads(scores_path.read_text(encoding='utf8'))

**Prepare training dataset**

In [4]:
from transformers import AutoTokenizer

# Tokenizer used while building the PPO dataset (chat templating + tokenization).
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct",
)
# Reuse EOS as the padding token (presumably the checkpoint defines none).
tokenizer.pad_token = tokenizer.eos_token

In [5]:
MAX_QUERY_LEN = 1024

def extract_ppo_training_samples(convo_data, reward, tokenizer, max_length=2048,
                                 device="cuda", max_query_len=MAX_QUERY_LEN,
                                 max_response_len=MAX_QUERY_LEN):
    """Turn one conversation into PPO (query, response, reward) samples.

    Every non-empty assistant turn becomes one sample. The query is the chat
    history preceding that turn, rendered with the tokenizer's chat template
    plus the generation prompt. The response is the assistant's stripped
    text. All samples from the conversation share the same scalar ``reward``.

    Args:
        convo_data: Iterable of message dicts with ``"role"`` and ``"content"``
            keys, in conversation order.
        reward: Scalar reward attached to every sample of this conversation.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        max_length: Unused; kept so existing call sites keep working.
        device: Device the ``input_ids`` / ``response_ids`` tensors are moved to.
        max_query_len: Maximum number of query tokens (and characters for
            ``truncated_query``). Longer queries keep their *last* tokens, so
            the final user turn and the assistant generation header survive.
        max_response_len: Maximum number of response tokens (and characters
            for ``truncated_response``). Longer responses keep their *first*
            tokens.

    Returns:
        A list of dicts with keys ``query``, ``query_raw``, ``response``,
        ``truncated_query``, ``truncated_response``, ``reward``, ``input_ids``
        and ``response_ids``. The id tensors have shape ``(1, seq_len)``.
    """
    samples = []
    messages = []

    for message in convo_data:
        role = message["role"]
        content = message["content"].strip()
        messages.append({"role": role, "content": content})

        if role != "assistant":
            continue
        if not content:
            # Nothing to score: an empty response would yield an empty response
            # tensor for PPO. Skip the sample but keep the turn in the history.
            continue

        history = messages[:-1]
        prompt = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
        response = content

        # The rendered chat template already carries the BOS token. Tokenizing
        # with add_special_tokens=True would duplicate it in the query and put a
        # stray BOS at the start of every response.
        query_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
        response_ids = tokenizer(response, return_tensors="pt", add_special_tokens=False).input_ids

        # Left-truncate the query (keep the end of the conversation, including
        # the generation header); right-truncate the response.
        query_ids = query_ids[:, -max_query_len:]
        response_ids = response_ids[:, :max_response_len]

        samples.append({
            "query": prompt,
            "query_raw": history,
            "response": response,
            "truncated_query": prompt[-max_query_len:],
            "truncated_response": response[-max_response_len:],
            "reward": reward,
            "input_ids": query_ids.to(device),
            "response_ids": response_ids.to(device),
        })

    return samples

In [6]:
from tqdm import tqdm

training_data: list = []


def normalize_reward(score: int) -> float:
    """Rescale a raw empathy score into the PPO reward by dividing by 10.

    NOTE(review): presumably maps a 0-10 rating onto [0, 1]; confirm against
    how empathy_scores.json is produced.
    """
    scaled = score / 10.0
    return scaled

# glob() yields files in filesystem order, which is not stable across runs.
# Sorting makes the sample order reproducible, and therefore the training order
# and what each numbered checkpoint contains.
conversation_paths = sorted(conversations.glob('*.json'))

for path in tqdm(conversation_paths, desc="Processing conversations"):
    # The scores file sits in the same folder but is not a conversation.
    if path.stem == 'empathy_scores':
        continue

    with path.open('r', encoding='utf-8') as f:
        data = json.load(f)

    convo_reward = normalize_reward(empathy_scores[path.stem])
    training_data.extend(extract_ppo_training_samples(data['convo'], convo_reward, tokenizer))

Processing conversations: 100%|██████████| 1001/1001 [00:57<00:00, 17.27it/s]


In [7]:
# Drop the preprocessing tokenizer; the training cell below loads its own copy.
del tokenizer

In [8]:
import torch

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [9]:
import os
import warnings

# Debug aid: make CUDA kernel launches synchronous so errors point at the
# failing op. CUDA reads this variable only when it initialises. If tensors were
# already moved to 'cuda' in this kernel (as the dataset cell does), setting it
# here has no effect. It must then be set before the first CUDA call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA_LAUNCH_BLOCKING was set after CUDA was initialised; it will not "
        "take effect in this kernel. Set it before the first CUDA call.",
        RuntimeWarning,
    )


**Training loop**

In [None]:
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead, create_reference_model
import gc
import os


model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Set to an int to stop after that many samples (e.g. a quick smoke test);
# None trains on every sample in `training_data`.
MAX_TRAIN_SAMPLES = None
# Save a checkpoint every this many samples (plus once at the end).
CHECKPOINT_EVERY = 100

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto"
)

# Frozen reference policy for the KL penalty, loaded as a separate copy.
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto"
)

config = PPOConfig(
    model_name=model_id,
    learning_rate=1e-5,
    batch_size=1,
    mini_batch_size=1,
    cliprange=0.2,
    kl_penalty="kl",
    init_kl_coef=0.05
)

trainer = PPOTrainer(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    config=config
)


def train_on_sample(sample):
    """Run one PPO step on a single pre-tokenised sample and return its stats."""
    query_ids = sample["input_ids"]
    response_ids = sample["response_ids"]
    reward = torch.tensor(sample["reward"], device=device)

    # The stored id tensors are (1, seq_len); trainer.step takes lists with one
    # 1-D tensor per batch element, so pass row 0 of each.
    stats = trainer.step([query_ids[0]], [response_ids[0]], [reward])

    del query_ids, response_ids, reward
    # Collect first so dropped Python references release their tensors, then
    # return the freed blocks from PyTorch's CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()
    return stats


def save_checkpoint(save_dir):
    """Best-effort save of the policy, tokenizer and value head into save_dir."""
    try:
        os.makedirs(save_dir, exist_ok=True)

        trainer.model.save_pretrained(save_dir)
        trainer.tokenizer.save_pretrained(save_dir)
        torch.save(trainer.model.v_head.state_dict(), os.path.join(save_dir, "value_head.pt"))

    except Exception as e:
        # A failed save should not abort a long training run; report it instead.
        print(f"Error saving: {e}")


samples_to_train = training_data if MAX_TRAIN_SAMPLES is None else training_data[:MAX_TRAIN_SAMPLES]
total = len(samples_to_train)

for idx, sample in enumerate(samples_to_train, start=1):
    print(f"Sample {idx}/{total}")
    train_on_sample(sample)

    if idx % CHECKPOINT_EVERY == 0:
        print(f"Saving checkpoint after sample {idx}")
        save_checkpoint(f"/content/ppo_model_final_2_{idx}")

# Persist trailing samples that did not land on a checkpoint boundary.
if total % CHECKPOINT_EVERY != 0:
    print(f"Saving checkpoint after sample {total}")
    save_checkpoint(f"/content/ppo_model_final_2_{total}")

print("Training complete!")
