In [None]:
# %% 
# Install required packages (run once)
# The leading "!" is a notebook shell escape, so this line only works inside a
# Jupyter/IPython session, not as plain Python.
!pip install --quiet transformers datasets torch pandas scikit-learn tqdm matplotlib seaborn accelerate

In [None]:
# Trainer will use all available GPUs by default if you have multiple
# Just make sure CUDA_VISIBLE_DEVICES isn't limiting you
# Both variables are set here, in an earlier cell than the one that imports torch.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Expose GPUs 0-3 to this process (edit the list to restrict training to fewer devices)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# Echo the effective value so the notebook output records which devices were requested.
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import json
import torch
import pandas as pd
import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    default_data_collator,
    ProgressCallback
)

# ── CONFIG ─────────────────────────────────────────────────────────────────────
# Local snapshot directory passed to from_pretrained() for both the tokenizer and
# the model. NOTE(review): this is an absolute, machine-specific path.
MODEL_NAME = "/workspace/huggingface_cache/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b"
# JSON files holding the training and evaluation examples.
TRAIN_DATA_PATH  = "data/train_alpaca.json"
TEST_DATA_PATH = "data/test_alpaca.json"
# Passed to TrainingArguments(output_dir=...).
OUTPUT_DIR = "./truthfulness_model"
# Default max_length of TruthfulnessProcessor: sequences are truncated and padded to this many tokens.
MAX_LENGTH = 4096  # ensure prompt+label comfortably fits below this

# ── LOAD DATA ──────────────────────────────────────────────────────────────────
def _load_unique_examples(path):
    """Load a JSON dataset and collapse it to one row per ``consistency_id``.

    Args:
        path: Location of the JSON file to read.

    Returns:
        A ``(raw, frame, unique_frame)`` tuple: the parsed JSON object, the
        DataFrame built from it, and that DataFrame grouped by
        ``consistency_id`` keeping the first non-null value of every column
        (the semantics of ``GroupBy.first``).
    """
    # JSON is UTF-8 by specification; do not depend on the platform default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        raw = json.load(f)
    frame = pd.DataFrame(raw)
    unique_frame = frame.groupby('consistency_id').first().reset_index()
    return raw, frame, unique_frame


train_data, train_df, df_unique_train = _load_unique_examples(TRAIN_DATA_PATH)
test_data, test_df, df_unique_test = _load_unique_examples(TEST_DATA_PATH)

print(f"Training on {len(df_unique_train)} unique training examples\n")
# The test split is only used for evaluation, so do not report it as training data.
print(f"Evaluating on {len(df_unique_test)} unique testing examples\n")

# ── DATA PROCESSOR ──────────────────────────────────────────────────────────────
class TruthfulnessProcessor:
    """Build causal-LM training features for the pairwise truthfulness task.

    Each example is rendered with ``prompt_template`` and followed by its
    ``label`` text. Only the tokens that come after the prompt are supervised;
    prompt tokens and padding are set to ``-100`` in ``labels`` so the loss
    ignores them.

    Args:
        tokenizer: Hugging Face tokenizer used to encode prompts and labels.
            It must have a pad token, because sequences are padded to
            ``max_length``.
        max_length: Length every encoded sequence is truncated/padded to.
    """

    def __init__(self, tokenizer, max_length=MAX_LENGTH):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # make sure this matches your label semantics exactly:
        self.prompt_template = (
            "Human: {question}\n"
            "Response 1: {choice}\n"
            "Response 2: {choice_2}\n"
            "Claim: Response 1 is more truthful than Response 2.\n"
            "I think this claim is "
        )

    @staticmethod
    def _shared_prefix_len(first, second):
        """Return how many leading items ``first`` and ``second`` have in common."""
        count = 0
        for left, right in zip(first, second):
            if left != right:
                break
            count += 1
        return count

    def tokenize_fn(self, examples):
        """Encode a batch of examples into ``input_ids``/``attention_mask``/``labels``.

        Args:
            examples: Batch mapping with equally long ``question``, ``choice``,
                ``choice_2`` and ``label`` sequences of strings.

        Returns:
            Dict of tensors with one row per example and ``max_length``
            columns. ``labels`` equals ``input_ids`` on the supervised tokens
            and ``-100`` on prompt and padding positions.
        """
        # 1) build bare prompts (no label) and the full training texts
        prompts = [
            self.prompt_template.format(question=question, choice=choice, choice_2=choice_2)
            for question, choice, choice_2 in zip(
                examples['question'], examples['choice'], examples['choice_2']
            )
        ]
        full_texts = [prompt + label for prompt, label in zip(prompts, examples['label'])]

        # 2) tokenize the prompts alone, unpadded, as plain lists. Ragged
        #    lengths are fine here, so an over-long prompt cannot break
        #    tensor construction.
        prompt_token_ids = self.tokenizer(
            prompts,
            padding=False,
            truncation=False,
        )["input_ids"]

        # 3) tokenize prompt + label, truncated/padded to a fixed length
        full_enc = self.tokenizer(
            full_texts,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]

        # 4) prepare labels: supervise only what follows the prompt
        labels = input_ids.clone()
        # Padding must not contribute to the loss; with sequences padded to
        # max_length it would otherwise outweigh the few label tokens.
        # NOTE: if the pad token is the EOS token, no end-of-sequence token is
        # supervised; append the EOS text to the label explicitly if the model
        # should learn to stop.
        labels[attention_mask == 0] = -100

        for row, prompt_ids in enumerate(prompt_token_ids):
            # Positions holding real tokens, independent of the padding side.
            real_positions = attention_mask[row].nonzero(as_tuple=True)[0]
            sequence = input_ids[row, real_positions].tolist()
            # The prompt ends in whitespace, so its last token may merge with
            # the start of the label once both are tokenized together. Using
            # the shared token prefix keeps that merged token supervised.
            prompt_len = self._shared_prefix_len(prompt_ids, sequence)
            if prompt_len >= len(sequence):
                # Nothing after the prompt survived: the row is fully masked.
                print("prompt consumed full length → label got truncated")
            labels[row, real_positions[:prompt_len]] = -100

        return {
            "input_ids":      input_ids,
            "attention_mask": attention_mask,
            "labels":         labels
        }

# ── MODEL & TOKENIZER ─────────────────────────────────────────────────────────
print("Loading model and tokenizer…")

# Tokenizer first; fall back to EOS as the pad token when none is defined,
# since the data processor pads every sequence to a fixed length.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" leaves weight placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
)
print(f"  ▶︎ Loaded on {model.device}\n")

# ── PREPARE DATASETS ───────────────────────────────────────────────────────────
# No validation split is carved out of the training data; the separately
# loaded test frame is tokenized alongside it.
processor = TruthfulnessProcessor(tokenizer)
train_ds = Dataset.from_pandas(df_unique_train)
test_ds = Dataset.from_pandas(df_unique_test)

print("Tokenizing…")
# Same batching for both splits; the raw columns are dropped so only the
# tokenized features remain.
map_options = {"batched": True, "batch_size": 16}
train_ds = train_ds.map(
    processor.tokenize_fn,
    remove_columns=train_ds.column_names,
    **map_options,
)
test_ds = test_ds.map(
    processor.tokenize_fn,
    remove_columns=test_ds.column_names,
    **map_options,
)

# ── METRICS ────────────────────────────────────────────────────────────────────
def compute_metrics(eval_preds):
    """Compute next-token accuracy over the supervised label positions.

    Args:
        eval_preds: ``(logits, labels)`` pair. ``logits`` is an array of shape
            ``(..., seq_len, vocab)`` or a tuple/list whose first element is
            such an array; ``labels`` has shape ``(..., seq_len)`` and uses
            ``-100`` for positions to ignore.

    Returns:
        ``{"accuracy": float}``; the value is NaN when no position is supervised.
    """
    logits, labels = eval_preds
    if isinstance(logits, (tuple, list)):
        # Some models return extra outputs next to the logits.
        logits = logits[0]

    # A causal LM's logits at position t score the token at position t + 1,
    # so drop the last prediction and the first label to line them up.
    preds = logits[..., :-1, :].argmax(-1)
    targets = labels[..., 1:]

    mask = targets != -100
    if not mask.any():
        # Nothing to score; avoid taking the mean of an empty selection.
        return {"accuracy": float("nan")}
    acc = (preds[mask] == targets[mask]).astype(float).mean()
    return {"accuracy": float(acc)}


# Cap the evaluation set at its first `max_eval` rows.
max_eval = 10
if len(test_ds) > max_eval:
    print(f"Limiting eval set from {len(test_ds)} → {max_eval} examples")
    # select(range(n)) keeps rows 0..n-1 in their existing order.
    test_ds = test_ds.select(range(max_eval))

# ── TRAINING ARGS ───────────────────────────────────────────────────────────────
# Short run bounded by max_steps, with frequent logging and periodic evaluation.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    max_steps=200,                   # stop after a fixed number of optimizer steps
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,   # two micro-batches of 1 per device per optimizer step
    logging_strategy="steps",
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=25,
    save_strategy="no",              # no checkpoints are requested
    warmup_steps=10,
    learning_rate=5e-5,
    fp16=False,
    prediction_loss_only=False,  # so compute_metrics runs
    eval_accumulation_steps=1,
    report_to=[]                     # empty list: no reporting integrations requested
)

# ── TRAINER ───────────────────────────────────────────────────────────────────
# Wire the model, tokenized datasets and metric function into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    # Every example is already padded to a fixed length by the processor.
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    # NOTE(review): Trainer may already register a progress callback by default;
    # confirm this does not produce duplicate progress output.
    callbacks=[ProgressCallback],
)

# ── TRAIN! ────────────────────────────────────────────────────────────────────
# Run the training loop, then report summary numbers from the returned metrics dict.
print("=== Starting training ===")
train_result = trainer.train()
print("=== Training complete ===")
print(f"Final train loss: {train_result.metrics['train_loss']:.4f}")
print(f"Train throughput: {train_result.metrics['train_samples_per_second']:.1f} samples/sec")

# ── SMOKE TEST ─────────────────────────────────────────────────────────────────
# Generate a few tokens for one hand-written example and show only the
# model's continuation of the prompt.
print("\nQuick inference test:")
example = {
    "question": "What is 2+2?",
    "choice":   "2+2 equals 4",
    "choice_2": "2+2 equals 5",
}

prompt = processor.prompt_template.format(**example)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Make sure the model is in inference mode before generating.
model.eval()
with torch.no_grad():
    # NOTE(review): temperature only matters when sampling is enabled; confirm
    # the model's generation config turns sampling on.
    out = model.generate(
        **inputs,
        max_new_tokens=5,
        temperature=0.1,
        pad_token_id=tokenizer.pad_token_id,
    )
decoded = tokenizer.decode(out[0], skip_special_tokens=True)

# Slice off the prompt at token level: decoding is not guaranteed to reproduce
# the prompt string character for character, so slicing the decoded text by
# len(prompt) can cut the answer in the wrong place.
prompt_token_count = inputs["input_ids"].shape[1]
completion = tokenizer.decode(out[0][prompt_token_count:], skip_special_tokens=True)

print(f"Prompt → “{prompt}”")
print("Model says →", completion)