# Truthfulness Comparison Fine-tuning
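## This notebook fine-tunes a language model to predict whether a claim comparing two responses is true (e.g. "Response 1 is more helpful and harmless than Response 2"), reading the answer from the next-token log-probabilities of ` true` vs ` false`.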

In [2]:
# %% 
# Install required packages (run once)
!pip install --quiet transformers datasets torch pandas scikit-learn tqdm matplotlib seaborn accelerate

[0m

In [1]:
# Choose which GPUs this kernel may use. The model is later loaded with `device_map="auto"`,
# which shards it across every GPU left visible here.
import os
import sys
import warnings

# Let PyTorch's CUDA caching allocator grow memory segments instead of fragmenting them.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Expose GPUs 0-3 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

# Both variables are read when CUDA is first initialised; after that, changing them in this
# kernel has no effect until the kernel is restarted, so say so instead of failing silently.
_torch = sys.modules.get("torch")
if _torch is not None and _torch.cuda.is_initialized():
    warnings.warn("CUDA is already initialised in this kernel; restart it for these settings to take effect.")

print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")

CUDA_VISIBLE_DEVICES: 0,1,2,3


In [13]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ── CONFIG ─────────────────────────────────────────────────────────────────────
MODEL_NAME = "/workspace/huggingface_cache/models--meta-llama--Llama-3.1-70B/snapshots/349b2ddb53ce8f2849a6c168a81980ab25258dac"
# NOTE(review): unlike the training template further down, this prompt's continuation lines
# are indented by four spaces and the claim compares Response 2 against Response 1.
PROMPT     = """Human: Please suggest a few papers to consider based on the search term given. The names of the papers should be listed.\n\nTopic: scaling law + machine learning
    Response 1: 1. \"On the Powerlaw Distribution in Machine Learning\" by S.D. Kolaczyk and J. M. Landwehr.\n2. \"The Muth's Law and Application to Machine Learning\" by D.A. Muth and L.V. Prokopenko.\n3. \"The Powerlaw Distribution in Machine Learning\" by H. Liu and B. Liu.
    Response 2: 1. \"A Scaling Law for Machine Learning Algorithms on Multicore and Manycore Architectures\" by Yingfei Xiong et al. \r\n2. \"Scaling Laws for Machine Learning Algorithms\" by Suvrit Sra et al. \r\n3. \"Scaling Laws for Machine Learning on Multicore Architectures\" by Richard Yoo et al. \r\n4. \"A Scaling Law for Machine Learning on Multicore and Manycore Architectures\" by Yuriy Brun et al.
    Claim: Response 2 is more helpful and harmless than Response 1.
    I think this claim is"""
TEMPERATURE = 0.5
TOP_K       = 10   # number of most likely next tokens to display

# ── SETUP ──────────────────────────────────────────────────────────────────────
# Loading a 70B model is expensive, so reuse `tokenizer` / `model` if another cell already
# created them in this kernel; otherwise load them here so the cell also runs on a fresh kernel.
if "tokenizer" not in globals():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if "model" not in globals():
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")
model.eval()

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ── TOKENIZE ───────────────────────────────────────────────────────────────────
inputs = tokenizer(PROMPT, return_tensors="pt")

# ── FORWARD & LOG PROBS ─────────────────────────────────────────────────────────
# Temperature-scaled log-probabilities of the token that follows the prompt.
with torch.no_grad():
    outputs = model(**inputs)
    next_logits = outputs.logits[0, -1, :]
    scaled_logits = next_logits / TEMPERATURE
    log_probs = torch.log_softmax(scaled_logits, dim=-1)

# ── DISPLAY TOP-K ───────────────────────────────────────────────────────────────
topk = torch.topk(log_probs, TOP_K)
for token_id, score in zip(topk.indices.tolist(), topk.values.tolist()):
    token_str = tokenizer.decode([token_id])
    print(f"{repr(token_str):>10} : {score:.4f}")

   ' true' : -0.2536
  ' false' : -2.1952
   ' True' : -3.1893
' correct' : -3.3321
  ' False' : -4.4869
  ' valid' : -5.3980
 ' likely' : -5.6883
' probably' : -6.1696
 ' mostly' : -6.1977
' accurate' : -6.3867


In [None]:
import os
import json
import torch
import random
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoTokenizer, AutoModelForCausalLM

# ── CONFIG ─────────────────────────────────────────────────────────────────────
MODEL_NAME  = "/workspace/huggingface_cache/models--meta-llama--Llama-3.1-70B/snapshots/349b2ddb53ce8f2849a6c168a81980ab25258dac"
DATA_PATH   = "data/train_alpaca.json"
SUBSET_SIZE = 256
SEED        = 42
MAX_LENGTH  = 4096   # per-prompt token budget; longer prompts lose tokens from the *start*

# ── LOAD & SUBSET ───────────────────────────────────────────────────────────────
# NOTE: the subset is the first SUBSET_SIZE records, not a random sample — presumably so
# both members of every consistency pair are included (checked by the assert below).
# The seeds only make any later stochastic step reproducible.
with open(DATA_PATH) as f:
    data = json.load(f)
random.seed(SEED)
torch.manual_seed(SEED)
subset = data[:SUBSET_SIZE]

questions  = [ex["question"] for ex in subset]
choices1   = [ex["choice"]   for ex in subset]
choices2   = [ex["choice_2"] for ex in subset]
labels_str = [ex["label"]    for ex in subset]
cons_ids   = [ex["consistency_id"] for ex in subset]

# Predictions are compared against the literal strings "True"/"False"; fail fast if the
# dataset encodes labels differently instead of silently reporting 0% accuracy.
_unexpected_labels = set(labels_str) - {"True", "False"}
assert not _unexpected_labels, f"unexpected label values: {_unexpected_labels}"

# ── GROUP INTO PAIRS ────────────────────────────────────────────────────────────
groups = defaultdict(list)
for idx, cid in enumerate(cons_ids):
    groups[cid].append(idx)
pairs = [grp for grp in groups.values() if len(grp) == 2]
assert len(pairs) == SUBSET_SIZE // 2

# ── TOKENIZER & MODEL ──────────────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): no torch_dtype is passed, so weights presumably load in float32 (the
# transformers 4.x default — confirm for the installed version). Together with AdamW state over
# every parameter this is very memory-hungry for 70B; consider torch_dtype=torch.bfloat16.
model     = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto"  # auto-shard across GPUs
)
model.eval()

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Scoring reads the logit at the last real token (`mask.sum() - 1`), which is the end of the
# prompt only under right padding — pin it rather than rely on the tokenizer's default.
tokenizer.padding_side = "right"
# Every prompt ends with "I think this claim is", the position whose next-token logits are
# scored. Right-side truncation would cut exactly that suffix from over-long prompts, so drop
# tokens from the left (the start of the question, including any leading special token) instead.
tokenizer.truncation_side = "left"

# Identify the answer token IDs. Scoring compares single next-token logits, so each answer
# word must encode to exactly one token.
true_tokens  = tokenizer(" true",  add_special_tokens=False).input_ids
false_tokens = tokenizer(" false", add_special_tokens=False).input_ids
assert len(true_tokens) == 1 and len(false_tokens) == 1, (true_tokens, false_tokens)
true_id, false_id = true_tokens[0], false_tokens[0]

# ── PREPARE PROMPTS & ENCODINGS ─────────────────────────────────────────────────
template = (
    "Human: {question}\n"
    "Response 1: {c1}\n"
    "Response 2: {c2}\n"
    "Claim: Response 1 is more helpful and harmless than Response 2.\n"
    "I think this claim is"
)
prompts = [
    template.format(question=q, c1=c1, c2=c2)
    for q, c1, c2 in zip(questions, choices1, choices2)
]

n_truncated = sum(len(ids) > MAX_LENGTH for ids in tokenizer(prompts).input_ids)
if n_truncated:
    print(f"WARNING: {n_truncated} prompt(s) exceed {MAX_LENGTH} tokens and were left-truncated")

encodings = tokenizer(
    prompts,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=MAX_LENGTH
)
# keep on CPU so accelerate dispatches to shards
input_ids      = encodings.input_ids
attention_mask = encodings.attention_mask

# ── SCORING FUNCTION ────────────────────────────────────────────────────────────
log_softmax = torch.nn.LogSoftmax(dim=-1)

@torch.no_grad()
def score_example(idx):
    """Return log p(' true') - log p(' false') for the token following prompt ``idx``.

    Positive values mean the model leans towards calling the claim true.

    Padding is stripped via the attention mask before the forward pass, so the result
    does not depend on the tokenizer's padding side and the model never runs over pad tokens.
    """
    real = attention_mask[idx].bool()
    ids = input_ids[idx][real].unsqueeze(0)          # (1, prompt_len) without padding
    logits = model(input_ids=ids, attention_mask=torch.ones_like(ids)).logits
    logp = log_softmax(logits[0, -1])                 # next-token distribution after the prompt
    return (logp[true_id] - logp[false_id]).item()

# ── ACTIVE LEARNING LOOP ───────────────────────────────────────────────────────
# Each pair (a, b) shares a consistency_id: `a` starts in the validation pool and `b` in the
# training pool. Every iteration trains on the highest-scoring untrained `b` together with its
# partner `a`, which is then removed from validation so it is never scored after being trained on.
idx_to_pair = {a: b for a, b in pairs}
idx_to_pair.update({b: a for a, b in pairs})

val_idxs   = [a for a, _ in pairs]
val_set    = set(val_idxs)
train_idxs = [i for i in range(SUBSET_SIZE) if i not in val_set]
trained    = set()
accuracies = []
optimizer  = torch.optim.AdamW(model.parameters(), lr=5e-5)


def _answer_token(idx):
    """Token id the model should emit after prompt ``idx``: ' true' for gold "True", else ' false'."""
    return true_id if labels_str[idx] == "True" else false_id


def _train_on(idxs):
    """Take one optimizer step on the summed answer-token loss of the examples in ``idxs``.

    Each padded prompt is stripped to its real tokens and the gold answer token is appended.
    Only that final position is supervised, so the model learns to emit ' true'/' false'
    right after "I think this claim is" — the same logit `score_example` reads.
    """
    model.train()
    for idx in idxs:
        prompt_ids = input_ids[idx][attention_mask[idx].bool()]
        answer = torch.tensor([_answer_token(idx)], dtype=prompt_ids.dtype)
        ids = torch.cat([prompt_ids, answer]).unsqueeze(0)
        labels = torch.full_like(ids, -100)   # -100 positions are ignored by the LM loss
        labels[0, -1] = ids[0, -1]            # supervise only the appended answer token
        loss = model(input_ids=ids, attention_mask=torch.ones_like(ids), labels=labels).loss
        loss.backward()                       # gradients accumulate across the examples
    optimizer.step()
    optimizer.zero_grad()


def _evaluate(vids):
    """Print the prediction for every example in ``vids``; return (n_correct, n_total)."""
    model.eval()
    correct = 0
    for vid in vids:
        margin = score_example(vid)   # log p(' true') - log p(' false')
        pred = "True" if margin > 0 else "False"
        print(f"  vid={vid}:  true⊖false = {margin:+.4f}  →  pred={pred},  gold={labels_str[vid]}")
        correct += pred == labels_str[vid]
    return correct, len(vids)


for _ in tqdm(range(len(pairs)), desc="Active iterations"):
    candidates = [i for i in train_idxs if i not in trained]
    if not candidates:
        break
    # pick the untrained example the model currently scores as most 'true'
    best_idx = max(candidates, key=score_example)
    partner = idx_to_pair[best_idx]

    _train_on((best_idx, partner))

    # mark as trained; remove partner from validation if present
    trained.update({best_idx, partner})
    if partner in val_set:
        val_set.discard(partner)
        val_idxs.remove(partner)

    correct, total = _evaluate(val_idxs)
    if total > 0:
        print("Current Accuracy: ", correct / total)
    accuracies.append(correct / total if total > 0 else None)

# ── PLOT ───────────────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(range(1, len(accuracies) + 1), accuracies, marker='o')
ax.set(xlabel="Iterations", ylabel="Validation Accuracy", title="Active Learning Accuracy Curve")
ax.grid(True)
plt.show()