In [1]:
%pip install -q trl==0.10.1 tqdm

**Mount drive to get conversations and empathy scores**

In [None]:
from google.colab import drive
from pathlib import Path

# Which conversation batch on Drive this run works on.
BATCH = 1

# Attach Google Drive at /content/drive so the batch folder can be read.
drive.mount('/content/drive')

# Folder holding this batch's JSON files.
conversations = Path('/content/drive/My Drive/conversations') / f'batch_{BATCH}'

# Counts every *.json in the folder (a later cell also reads empathy_scores.json from here).
print(f'Found {sum(1 for _ in conversations.glob("*.json")):,} files in batch {BATCH}')

In [3]:
import json

# Scores live in the batch folder; a later cell looks them up by conversation file stem.
scores_path = conversations / 'empathy_scores.json'
empathy_scores = json.loads(scores_path.read_text(encoding='utf8'))

**Prepare training dataset**

In [4]:
from transformers import AutoTokenizer

# Tokenizer of the Llama 3.2 3B Instruct checkpoint; the next cells use it to
# render chat prompts and to tokenize queries and responses.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
# No separate padding token is configured here, so the end-of-sequence token is reused for padding.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import torch

# Upper bound (in tokens) for each tokenized query and each tokenized response;
# also used as the character window of the "truncated_*" preview strings.
MAX_QUERY_LEN = 1024

def extract_ppo_training_samples(convo_data, reward, tokenizer):
    """Turn one conversation into PPO samples, one per non-empty assistant turn.

    Args:
        convo_data: iterable of ``{"role": ..., "content": ...}`` messages in
            conversation order.
        reward: reward attached unchanged to every sample of this conversation.
        tokenizer: object providing ``apply_chat_template(...)`` and a
            ``tokenizer(text, ...)`` call that returns ``.input_ids``.

    Returns:
        A list of dicts with the keys ``query``, ``query_raw``, ``response``,
        ``truncated_query``, ``truncated_response``, ``reward``, ``input_ids``
        and ``response_ids``. The two id tensors have shape ``(1, n_tokens)``
        and live on CUDA when it is available, otherwise on the CPU.
    """
    samples = []
    messages = []

    # Fall back to the CPU instead of crashing when no GPU is present.
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for message in convo_data:
        role = message["role"]
        content = message["content"].strip()

        messages.append({
            "role": role,
            "content": content
        })

        if role != "assistant":
            continue
        if not content:
            # An empty reply has no response tokens to optimise; it stays in the
            # history for later turns but yields no training sample.
            continue

        history = messages[:-1]
        prompt = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
        response = content

        # Character-level previews (tail of each string).
        truncated_prompt = prompt[-MAX_QUERY_LEN:]
        truncated_response = response[-MAX_QUERY_LEN:]

        # The rendered chat template already carries its special tokens, so they
        # must not be added a second time. The query keeps its LAST tokens so the
        # turn being answered and the generation prompt survive truncation.
        query_ids = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).input_ids[:, -MAX_QUERY_LEN:]

        # The response is a continuation of the query, so it gets no leading
        # special token either; it is cut from the end if it is too long.
        response_ids = tokenizer(
            response,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_QUERY_LEN,
            add_special_tokens=False,
        ).input_ids

        samples.append({
            "query": prompt,
            "query_raw": history,
            "response": response,
            "truncated_query": truncated_prompt,
            "truncated_response": truncated_response,
            "reward": reward,
            "input_ids": query_ids.to(target_device),
            "response_ids": response_ids.to(target_device)
        })

    return samples

In [6]:
from tqdm import tqdm

# Collects the PPO samples extracted from every conversation file.
training_data = list()

def normalize_reward(score: int) -> float:
    """Divide an empathy score by ten and return the result as a float.

    Presumably the raw scores are on a 0-10 scale, which would map them
    into [0, 1] -- confirm against the scoring step.
    """
    divisor = 10.0
    return score / divisor

for path in tqdm(list(conversations.glob('*.json')), desc="Processing conversations"):
    # empathy_scores.json sits in the same folder but is not a conversation.
    if path.stem != 'empathy_scores':
        data = json.loads(path.read_text(encoding='utf-8'))

        # One score per conversation, looked up by file stem and shared by all of its samples.
        convo_reward = normalize_reward(empathy_scores[path.stem])
        training_data.extend(extract_ppo_training_samples(data['convo'], convo_reward, tokenizer))

Processing conversations: 100%|██████████| 1001/1001 [00:57<00:00, 17.27it/s]


In [7]:
# Drop this tokenizer instance; each training cell below loads its own tokenizer again.
# NOTE(review): after this runs, the sample-extraction cells above cannot be re-run
# without first re-creating `tokenizer`.
del tokenizer

In [8]:
import torch

# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


In [9]:
import os
# Presumably a debugging aid so CUDA errors are reported at the call that caused them.
# NOTE(review): the sample-extraction cell above already moved tensors to 'cuda' before
# this variable is set -- confirm it still takes effect at this point, and consider
# removing it for real training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


**Training loop**

In [None]:
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
import gc
import os


model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Fresh tokenizer for training, padding with the end-of-sequence token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# The policy and the reference model are loaded from the same checkpoint
# with exactly the same options.
_shared_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id, **_shared_load_kwargs)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id, **_shared_load_kwargs)

# PPO hyper-parameters; one sample per batch and per mini-batch.
config = PPOConfig(
    model_name=model_id,
    learning_rate=1e-5,
    batch_size=1,
    mini_batch_size=1,
    cliprange=0.2,
    kl_penalty="kl",
    init_kl_coef=0.05,
)

trainer = PPOTrainer(model=model, ref_model=ref_model, tokenizer=tokenizer, config=config)


def train_on_sample(sample):
    """Run one PPO step on a single pre-tokenized sample.

    Reads ``input_ids`` and ``response_ids`` (their first row is used) and the
    scalar ``reward`` from ``sample``, passes them to the module-level
    ``trainer`` and then frees cached GPU memory. Returns nothing.
    """
    # trainer.step receives parallel lists; here each list holds one element.
    queries = sample["input_ids"]
    responses = sample["response_ids"]
    batch = (
        [queries[0]],
        [responses[0]],
        [torch.tensor(sample["reward"], device=device)],
    )

    trainer.step(*batch)

    # Drop the local references before asking torch and the GC to release memory.
    del queries, responses, batch
    torch.cuda.empty_cache()
    gc.collect()


# This first loop is a short run: the original code left the loop with a `break`
# right after writing its first checkpoint, so it never trained past sample 100.
# Both limits are spelled out here; set LOOP1_MAX_SAMPLES to None to train on
# every sample with a checkpoint every LOOP1_CHECKPOINT_EVERY samples.
LOOP1_CHECKPOINT_EVERY = 100
LOOP1_MAX_SAMPLES = 100

if LOOP1_MAX_SAMPLES is None:
    loop1_samples = training_data
else:
    loop1_samples = training_data[:LOOP1_MAX_SAMPLES]

for idx, sample in enumerate(loop1_samples, start=1):
    print(f"Sample {idx}/{len(loop1_samples)}")
    train_on_sample(sample)

    if idx % LOOP1_CHECKPOINT_EVERY == 0:
        print(f"Saving checkpoint after sample {idx}")
        # Saving is best-effort: a failed save is reported but does not abort the run.
        try:
            save_dir = f"/content/ppo_model_final_2_{idx}"
            os.makedirs(save_dir, exist_ok=True)

            trainer.model.save_pretrained(save_dir)
            trainer.tokenizer.save_pretrained(save_dir)
            # The value head is additionally written to its own file.
            torch.save(trainer.model.v_head.state_dict(), os.path.join(save_dir, "value_head.pt"))

        except Exception as e:
            print(f"Error saving: {e}")

print("Training complete!")


**Training loop 2**

In [None]:
import torch
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead
import gc
import os

# Start from as clean a memory state as possible before loading two models.
torch.cuda.empty_cache()
gc.collect()

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")

model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token


def _load_value_head_model(**extra_kwargs):
    """Load the checkpoint in bfloat16 with every module placed on ``device``."""
    return AutoModelForCausalLMWithValueHead.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map={"": device},
        **extra_kwargs,
    )


# Only the policy model is given an explicit memory budget for GPU 0.
model = _load_value_head_model(max_memory={0: "30GB"})

ref_model = _load_value_head_model()

model.gradient_checkpointing_enable()
ref_model.gradient_checkpointing_enable()

# Compared with the first training cell: lower learning rate, tighter clip
# ranges, larger initial KL coefficient, plus explicit value-loss weight and
# gradient-norm limit.
config = PPOConfig(
    model_name=model_id,
    learning_rate=5e-6,
    batch_size=1,
    mini_batch_size=1,
    cliprange=0.1,
    cliprange_value=0.1,
    kl_penalty="kl",
    init_kl_coef=0.1,
    target_kl=0.5,
    gradient_accumulation_steps=1,
    vf_coef=0.1,
    max_grad_norm=0.5,
)

trainer = PPOTrainer(model=model, ref_model=ref_model, tokenizer=tokenizer, config=config)

def preprocess_sample(sample):
    """Return a copy of ``sample`` whose training fields are tensors on ``device``.

    ``input_ids``, ``response_ids`` and ``reward`` are moved to the module-level
    ``device`` when they already are tensors, and converted with
    ``torch.tensor`` otherwise. All other keys are carried over untouched.

    The input dict is not modified, so the shared ``training_data`` entries keep
    their original values no matter how often, or from which cell, they are used.

    Raises:
        KeyError: if one of the three training fields is missing.
    """
    prepared = dict(sample)

    for key in ("input_ids", "response_ids", "reward"):
        value = sample[key]
        if isinstance(value, torch.Tensor):
            prepared[key] = value.to(device)
        else:
            prepared[key] = torch.tensor(value, device=device)

    return prepared

def train_on_sample(sample):
    try:
        # Ensure all tensors are on the same device
        sample = preprocess_sample(sample)
        
        query_ids = sample["input_ids"]
        response_ids = sample["response_ids"]
        reward = sample["reward"]
        
        # Process with mixed precision
        with torch.cuda.amp.autocast():
            stats = trainer.step([query_ids[0]], [response_ids[0]], [reward])
        
        # Explicit cleanup
        del query_ids, response_ids, reward
        torch.cuda.empty_cache()
        gc.collect()
        
        return stats
    except RuntimeError as e:
        if "out of memory" in str(e):
            torch.cuda.empty_cache()
            gc.collect()
            print("OOM error, skipping sample")
            return None
        else:
            print(f"Runtime error: {e}")
            raise e

# Samples whose query + response exceed this many tokens are skipped.
LOOP2_MAX_SAMPLE_TOKENS = 2048
# A checkpoint is written whenever the sample index is a multiple of this value.
LOOP2_CHECKPOINT_EVERY = 250

consecutive_errors = 0
max_consecutive_errors = 5

# Bookkeeping for the final save after the loop.
last_trained_idx = None
last_saved_idx = None


def _save_checkpoint(step):
    """Best-effort save of model, tokenizer and value head for sample ``step``.

    A failed save is reported and swallowed so it never aborts training.
    """
    print(f"Saving checkpoint after sample {step}")
    try:
        torch.cuda.empty_cache()  # Clear cache before saving
        save_dir = f"/content/ppo_model_final_4_{step}"
        os.makedirs(save_dir, exist_ok=True)
        trainer.model.save_pretrained(save_dir)
        trainer.tokenizer.save_pretrained(save_dir)
        torch.save(trainer.model.v_head.state_dict(), os.path.join(save_dir, "value_head.pt"))
    except Exception as e:
        print(f"Error saving: {e}")


for idx, sample in enumerate(training_data, start=1):
    print(f"Sample {idx}/{len(training_data)}")

    # Token ids may be stored as nested lists or as (1, n) tensors.
    input_len = len(sample["input_ids"][0]) if isinstance(sample["input_ids"], list) else sample["input_ids"].shape[1]
    response_len = len(sample["response_ids"][0]) if isinstance(sample["response_ids"], list) else sample["response_ids"].shape[1]

    if input_len + response_len > LOOP2_MAX_SAMPLE_TOKENS:
        print(f"Skipping oversized sample {idx}")
        continue

    try:
        stats = train_on_sample(sample)
        consecutive_errors = 0

        # train_on_sample returns None when it skipped the sample after an OOM.
        if stats is not None:
            last_trained_idx = idx

        # Force garbage collection more frequently
        if idx % 5 == 0:
            torch.cuda.empty_cache()
            gc.collect()

        # Save less frequently to reduce memory pressure
        if idx % LOOP2_CHECKPOINT_EVERY == 0:
            _save_checkpoint(idx)
            last_saved_idx = idx

    except Exception as e:
        print(f"Error on sample {idx}: {e}")
        # Try to recover
        torch.cuda.empty_cache()
        gc.collect()

        consecutive_errors += 1
        if consecutive_errors >= max_consecutive_errors:
            print(f"Too many consecutive errors ({consecutive_errors}). Stopping training.")
            break

# Without this, every update made after the last periodic checkpoint would be
# lost when the loop ends (normally or through the error limit).
if last_trained_idx is not None and (last_saved_idx is None or last_trained_idx > last_saved_idx):
    _save_checkpoint(last_trained_idx)