<a href="https://colab.research.google.com/github/hi-wesley/mini-panacea/blob/main/Mini_Panacea.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install -q "transformers>=4.43.0" "accelerate>=0.30.0" bitsandbytes peft datasets

In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from datasets import load_dataset, Dataset
import random

# Small chat model: Phi-3 Mini Instruct
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"  # ≈3.8B params

# Training hyperparameters
MAX_LENGTH   = 128      # sequence length
NUM_PROMPTS  = 1000     # how many Alpaca prompts to use
NUM_EPOCHS   = 3
BATCH_SIZE   = 4
LEARNING_RATE = 2e-4

STYLES = ["helpful", "humorous", "philosophical"]
NUM_PREFS = len(STYLES)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [4]:
raw_ds = load_dataset("tatsu-lab/alpaca", split="train")
raw_ds = raw_ds.shuffle(seed=42).select(range(NUM_PROMPTS))
print(raw_ds[0])

{'instruction': 'What would be the best type of exercise for a person who has arthritis?', 'input': '', 'output': 'For someone with arthritis, the best type of exercise would be low-impact activities like yoga, swimming, or walking. These exercises provide the benefits of exercise without exacerbating the symptoms of arthritis.', 'text': 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nWhat would be the best type of exercise for a person who has arthritis?\n\n### Response:\nFor someone with arthritis, the best type of exercise would be low-impact activities like yoga, swimming, or walking. These exercises provide the benefits of exercise without exacerbating the symptoms of arthritis.'}


In [5]:
def make_style_response(instruction, style):
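    # Very simple templates: each target is a fixed style-specific prefix
    # followed by the instruction itself (not a real answer), just enough
    # to give each style a distinct surface form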
    if style == "helpful":
        return f"Sure! Here is a clear and helpful answer to your request: {instruction}"
    if style == "humorous":
        return f"Okay, let's make this fun 😄. Here is a humorous take on: {instruction}"
    if style == "philosophical":
        return f"Let us think more deeply about this. A philosophical reflection on: {instruction}"
    return f"This is a generic answer to: {instruction}"

def expand_row(row):
    out = []
    instr = row["instruction"]
    for s in STYLES:
        out.append({
            "prompt": instr,
            "style": s,
            "target": make_style_response(instr, s),
        })
    return out

expanded_rows = []
for r in raw_ds:
    expanded_rows.extend(expand_row(r))

train_ds = Dataset.from_list(expanded_rows)
print("Number of training examples:", len(train_ds))
print(train_ds[0])

Number of training examples: 3000
{'prompt': 'What would be the best type of exercise for a person who has arthritis?', 'style': 'helpful', 'target': 'Sure! Here is a clear and helpful answer to your request: What would be the best type of exercise for a person who has arthritis?'}


In [6]:
# 5.1: add lambda (preference vector)
def add_lambda(example):
    idx = STYLES.index(example["style"])
    lam = [0.0] * NUM_PREFS
    lam[idx] = 1.0
    example["lambda"] = lam
    return example

train_ds = train_ds.map(add_lambda)
print(train_ds[0])

{'prompt': 'What would be the best type of exercise for a person who has arthritis?', 'style': 'helpful', 'target': 'Sure! Here is a clear and helpful answer to your request: What would be the best type of exercise for a person who has arthritis?', 'lambda': [1.0, 0.0, 0.0]}
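
Training only ever sees one-hot preference vectors, while the generation cells further down also pass a mixture such as [0.3, 0.3, 0.4]. A minimal sketch of a helper that builds a normalized preference vector from per-style weights; `make_lambda` is a hypothetical name and is not part of the notebook.

```python
STYLES = ["helpful", "humorous", "philosophical"]

def make_lambda(weights):
    """Build a preference vector over STYLES from a {style: weight} dict.

    Hypothetical helper: styles missing from `weights` get 0, and the
    result is normalized so its entries sum to 1.
    """
    unknown = set(weights) - set(STYLES)
    if unknown:
        raise ValueError(f"unknown styles: {sorted(unknown)}")
    raw = [float(weights.get(s, 0.0)) for s in STYLES]
    if any(w < 0 for w in raw):
        raise ValueError("weights must be non-negative")
    total = sum(raw)
    if total <= 0:
        raise ValueError("at least one weight must be positive")
    return [w / total for w in raw]

print(make_lambda({"humorous": 1}))                                   # one-hot
print(make_lambda({"helpful": 3, "humorous": 3, "philosophical": 4}))  # mixture
```

The one-hot case reproduces what `add_lambda` writes into the dataset; the mixture case is only meaningful at inference time, since no mixed targets exist in the training data.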


In [7]:
# 5.2: Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

def tokenize_example(ex):
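    # Simple chat-style format.
    # Note: the style name is written into the training text here, while
    # generate_with_pref() prompts with a plain "Assistant:", so the text
    # seen at inference differs slightly from the text seen in training.
    text = f"User: {ex['prompt']}\nAssistant ({ex['style']}): {ex['target']}"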
    tokens = tokenizer(
        text,
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length",
    )
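    # Labels = input_ids, with padded positions set to -100 so the loss skips them.
    # Mask via attention_mask rather than pad_token_id: when pad_token falls back
    # to eos_token (see above), comparing ids would also hide genuine EOS tokens.
    labels = [
        tok_id if mask == 1 else -100
        for tok_id, mask in zip(tokens["input_ids"], tokens["attention_mask"])
    ]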
    tokens["labels"] = labels
    tokens["lambda"] = ex["lambda"]
    return tokens

tokenized_ds = train_ds.map(
    tokenize_example,
    batched=False,
    remove_columns=train_ds.column_names,
)

print(tokenized_ds[0].keys())

dict_keys(['lambda', 'input_ids', 'attention_mask', 'labels'])
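
The labels built above still include the prompt tokens, so the loss also rewards reproducing the user's question. A common variant, sketched here as an assumption rather than as what the notebook does, masks the prompt positions as well as the padding. `build_labels` and `prompt_len` are hypothetical names, the token ids are made up, and the sketch assumes right padding.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_labels(input_ids, attention_mask, prompt_len=0):
    """Copy input_ids into labels, hiding padding and (optionally) the prompt.

    Positions with attention_mask == 0, or with index < prompt_len,
    receive IGNORE_INDEX and therefore contribute nothing to the loss.
    """
    labels = []
    for pos, (tok, mask) in enumerate(zip(input_ids, attention_mask)):
        if mask == 0 or pos < prompt_len:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(tok)
    return labels

# Toy sequence: 3 prompt tokens, 2 answer tokens, 2 padding tokens.
input_ids      = [11, 12, 13, 21, 22, 0, 0]
attention_mask = [1, 1, 1, 1, 1, 0, 0]

print(build_labels(input_ids, attention_mask))                # padding masked
print(build_labels(input_ids, attention_mask, prompt_len=3))  # prompt masked too
```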


In [8]:
def data_collator(batch):
    input_ids = torch.tensor([b["input_ids"] for b in batch], dtype=torch.long)
    attention_mask = torch.tensor([b["attention_mask"] for b in batch], dtype=torch.long)
    labels = torch.tensor([b["labels"] for b in batch], dtype=torch.long)
    pref_vec = torch.tensor([b["lambda"] for b in batch], dtype=torch.float32)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "pref_vec": pref_vec,
    }

train_loader = DataLoader(tokenized_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)
batch = next(iter(train_loader))
for k, v in batch.items():
    print(k, v.shape)

input_ids torch.Size([4, 128])
attention_mask torch.Size([4, 128])
labels torch.Size([4, 128])
pref_vec torch.Size([4, 3])


In [14]:
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# 🔧 Important: disable cache so Phi-3 doesn't try to use legacy KV cache
base_model.config.use_cache = False

# Freeze base model params (we don't train them)
for p in base_model.parameters():
    p.requires_grad = False

print("Base model loaded.")
print("Hidden size:", base_model.config.hidden_size)
print("Vocab size:", base_model.config.vocab_size)

Base model loaded.
Hidden size: 3072
Vocab size: 32064


In [18]:
class PanaceaHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_prefs, rank_shared=4, rank_pref=3):
        super().__init__()
        assert rank_pref == num_prefs, "for simplicity, rank_pref = num_prefs"

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_prefs = num_prefs
        self.rank_shared = rank_shared
        self.rank_pref = rank_pref
        self.rank_total = rank_shared + rank_pref

        # U: [vocab, r], V: [hidden, r]
        self.U = nn.Parameter(torch.randn(vocab_size, self.rank_total) * 0.01)
        self.V = nn.Parameter(torch.randn(hidden_size, self.rank_total) * 0.01)

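        # Singular values, initialized to zero so delta_logits is exactly zero
        # at the start and the combined model initially matches the base model
        self.sigma_shared = nn.Parameter(torch.zeros(rank_shared))
        self.sigma_pref   = nn.Parameter(torch.zeros(rank_pref))

        # Small learnable global scale applied to the whole update
        self.scale = nn.Parameter(torch.tensor(0.01))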

    def forward(self, hidden_states, pref_vec):
        """
        hidden_states: [B, S, H]
        pref_vec:      [B, num_prefs]
        returns delta_logits: [B, S, vocab]
        """
        B, S, H = hidden_states.shape
        device = hidden_states.device
        dtype = hidden_states.dtype  # match the model's dtype (bf16/half)

        # Make sure all our parameters live on the same device + dtype
        U = self.U.to(device=device, dtype=dtype)               # [V, r]
        V = self.V.to(device=device, dtype=dtype)               # [H, r]
        sigma_shared = self.sigma_shared.to(device=device, dtype=dtype)  # [k]
        sigma_pref_base = self.sigma_pref.to(device=device, dtype=dtype) # [m]
        scale = self.scale.to(device=device, dtype=dtype)

        # Handle preference vector
        if pref_vec is None:
            pref_vec = torch.full(
                (B, self.rank_pref),
                1.0 / self.rank_pref,
                device=device,
                dtype=dtype,
            )
        else:
            pref_vec = pref_vec.to(device=device, dtype=dtype)
            if pref_vec.dim() == 1:
                pref_vec = pref_vec.unsqueeze(0).expand(B, -1)  # [B, m]

        # Preference-specific singular values: [B, m]
        sigma_pref = sigma_pref_base * pref_vec  # broadcast multiply

        # Full singular values: [B, r] = [B, k+m]
        sigma_full = torch.cat(
            [
                sigma_shared.unsqueeze(0).expand(B, -1),  # [B, k]
                sigma_pref,                               # [B, m]
            ],
            dim=-1,
        )  # [B, r]

        # Compute low-rank update:
        # 1) project hidden to rank space: [B, S, r]
        v_proj = torch.matmul(hidden_states, V)  # [B, S, r]
        # 2) scale by sigma_full per batch element
        v_proj = v_proj * sigma_full.unsqueeze(1)  # [B, S, r]
        # 3) project back to vocab: [B, S, V]
        delta_logits = torch.matmul(v_proj, U.t())  # [B, S, V]

        return scale * delta_logits
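
Written as an equation, the forward pass above amounts to the following, where $h_t \in \mathbb{R}^{H}$ is the last-layer hidden state at position $t$, $\lambda$ is the preference vector, $s$ is `scale`, and $k$, $m$ are `rank_shared` and `rank_pref`. The symbols are chosen here for illustration and follow the shape comments in the code.

```latex
\Delta z_t \;=\; s \, U \,
  \operatorname{diag}\!\big(\sigma^{\mathrm{shared}},\; \sigma^{\mathrm{pref}} \odot \lambda\big)\,
  V^{\top} h_t ,
\qquad
U \in \mathbb{R}^{|\mathcal{V}| \times r},\quad
V \in \mathbb{R}^{H \times r},\quad
r = k + m
```

With a one-hot $\lambda$, exactly one of the $m$ preference-specific singular values is active, alongside the $k$ shared ones.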

In [19]:
class PanaceaPhiSFT(nn.Module):
    def __init__(self, base_model, num_prefs):
        super().__init__()
        self.base_model = base_model
        hidden_size = base_model.config.hidden_size
        vocab_size = base_model.config.vocab_size
        self.panacea_head = PanaceaHead(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            num_prefs=num_prefs,
            rank_shared=4,
            rank_pref=num_prefs,
        )

    def forward(self, input_ids, attention_mask=None, labels=None, pref_vec=None):
        # Call the full base model, but disable cache and ask for hidden states
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False,              # 🔧 important
            output_hidden_states=True,    # we want the last hidden layer
            return_dict=True,
        )

        # Last hidden layer: [B, S, H]
        hidden_states = outputs.hidden_states[-1]

        # Base logits from frozen lm_head: [B, S, V]
        base_logits = self.base_model.lm_head(hidden_states)

        # Panacea delta logits: [B, S, V]
        delta_logits = self.panacea_head(hidden_states, pref_vec)

        # Combined logits
        logits = base_logits + delta_logits

        loss = None
        if labels is not None:
            # Shift for causal LM loss (predict next token)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return {"loss": loss, "logits": logits}
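
The shift in the loss pairs the logits at position t with the token at position t+1, and positions labeled -100 drop out of the average. A tiny pure-Python sketch of that computation on a made-up 3-token vocabulary; `shifted_ce` is a hypothetical helper intended to mirror the mean reduction of `nn.CrossEntropyLoss(ignore_index=-100)` for a single sequence.

```python
import math

IGNORE_INDEX = -100

def log_softmax(row):
    """Numerically stable log-softmax of one list of scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def shifted_ce(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: S rows, each holding one score per vocabulary entry
    labels: S token ids, IGNORE_INDEX where no loss is wanted
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]          # logits at t predict the token at t+1
        if target == IGNORE_INDEX:
            continue
        total += -log_softmax(logits[t])[target]
        count += 1
    return total / count if count else float("nan")

logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
labels = [1, 0, IGNORE_INDEX]
print(shifted_ce(logits, labels))  # only the pair (t=0, target=0) counts
```

Note that the logits of the final position never enter the loss, because there is no following token for them to predict.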

In [20]:
model = PanaceaPhiSFT(base_model, num_prefs=NUM_PREFS).to(device)
print("Trainable parameters in Panacea head:",
      sum(p.numel() for p in model.panacea_head.parameters() if p.requires_grad))

Trainable parameters in Panacea head: 245960
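
The count reported above follows directly from the shapes in `PanaceaHead`: two rank-7 factor matrices, seven singular values, and one scale. A quick arithmetic check using the hidden and vocabulary sizes printed earlier:

```python
hidden_size, vocab_size = 3072, 32064
rank_shared, rank_pref = 4, 3
rank_total = rank_shared + rank_pref

n_U = vocab_size * rank_total      # U: [vocab, r]
n_V = hidden_size * rank_total     # V: [hidden, r]
n_sigma = rank_shared + rank_pref  # sigma_shared + sigma_pref
n_scale = 1                        # global scale

total = n_U + n_V + n_sigma + n_scale
print(total)  # 245960
```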


In [21]:
optimizer = torch.optim.AdamW(model.panacea_head.parameters(), lr=LEARNING_RATE)

model.train()
global_step = 0

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        pref_vec = batch["pref_vec"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            pref_vec=pref_vec,
        )
        loss = outputs["loss"]
        loss.backward()

        # Clip only the trainable head; the frozen base model has no gradients
        torch.nn.utils.clip_grad_norm_(model.panacea_head.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        global_step += 1
        if global_step % 50 == 0:
            print(f"Step {global_step}, loss = {loss.item():.4f}")

Epoch 1/3
Step 50, loss = 2.7012
Step 100, loss = 2.6348
Step 150, loss = 2.6387
Step 200, loss = 2.8750
Step 250, loss = 2.1191
Step 300, loss = 2.1211
Step 350, loss = 1.7607
Step 400, loss = 1.5498
Step 450, loss = 1.4912
Step 500, loss = 1.4004
Step 550, loss = 1.4512
Step 600, loss = 1.3154
Step 650, loss = 1.4219
Step 700, loss = 1.4463
Step 750, loss = 1.4053
Epoch 2/3
Step 800, loss = 1.4375
Step 850, loss = 1.4170
Step 900, loss = 1.6982
Step 950, loss = 1.3408
Step 1000, loss = 1.2861
Step 1050, loss = 1.3096
Step 1100, loss = 1.3838
Step 1150, loss = 1.3096
Step 1200, loss = 1.4697
Step 1250, loss = 1.3379
Step 1300, loss = 1.3574
Step 1350, loss = 1.2178
Step 1400, loss = 1.1484
Step 1450, loss = 1.1318
Step 1500, loss = 1.1504
Epoch 3/3
Step 1550, loss = 0.8765
Step 1600, loss = 1.1543
Step 1650, loss = 1.1797
Step 1700, loss = 1.0615
Step 1750, loss = 1.2236
Step 1800, loss = 0.8994
Step 1850, loss = 0.9272
Step 1900, loss = 0.9351
Step 1950, loss = 1.0020
Step 2000, loss
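
With 3000 examples and a batch size of 4, each epoch is 750 optimizer steps, which matches the log: "Epoch 2/3" appears right after step 750 and "Epoch 3/3" right after step 1500. A short check of that arithmetic, using values printed in earlier cells; `epoch_of_step` is a hypothetical helper.

```python
import math

num_examples = 3000   # "Number of training examples" printed earlier
batch_size = 4
num_epochs = 3

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

def epoch_of_step(step):
    """1-based epoch that a 1-based global step belongs to."""
    return (step - 1) // steps_per_epoch + 1

print(steps_per_epoch, total_steps)
print(epoch_of_step(750), epoch_of_step(800), epoch_of_step(1550))
```

Each logged value is the loss of a single batch of four examples, so step-to-step jumps in the log are expected and do not by themselves indicate instability.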

In [23]:
model.eval()
print("Model set to eval mode.")

Model set to eval mode.


In [24]:
import torch

@torch.no_grad()
def generate_with_pref(prompt, pref_vec, max_new_tokens=64):
    """
    prompt: str
    pref_vec: list of 3 floats [helpful, humorous, philosophical]
              e.g. [1.0, 0.0, 0.0] or [0.0, 1.0, 0.0] etc.
    max_new_tokens: how many tokens to generate after the prompt
    """
    model.eval()

    # Format the input roughly like training (User / Assistant)
    text = f"User: {prompt}\nAssistant:"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LENGTH,
    ).to(device)

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Preference vector as a tensor [1, 3]
    pref = torch.tensor(pref_vec, dtype=torch.float32, device=device).unsqueeze(0)

    for _ in range(max_new_tokens):
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            pref_vec=pref,
        )
        logits = outputs["logits"]

        # Take the last token's logits and pick the most likely token (greedy decoding)
        next_token_logits = logits[:, -1, :]        # [1, vocab_size]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # [1, 1]

        # Append to the sequence
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Update attention mask (now everything is "real" tokens)
        attention_mask = torch.ones_like(input_ids, device=device)

        # Stop if we hit EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode full sequence
    output_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return output_text

In [25]:
test_prompt = "Explain black holes to a 10 year old."
print("PROMPT:")
print(test_prompt)

PROMPT:
Explain black holes to a 10 year old.


In [26]:
# 1) Pure helpful: [1, 0, 0]
print("\n=== Helpful (1, 0, 0) ===")
print(generate_with_pref(test_prompt, [1.0, 0.0, 0.0]))

# 2) Pure humorous: [0, 1, 0]
print("\n=== Humorous (0, 1, 0) ===")
print(generate_with_pref(test_prompt, [0.0, 1.0, 0.0]))

# 3) Pure philosophical: [0, 0, 1]
print("\n=== Philosophical (0, 0, 1) ===")
print(generate_with_pref(test_prompt, [0.0, 0.0, 1.0]))

# 4) Mixed style: [0.3, 0.3, 0.4]
print("\n=== Mixed (0.3, 0.3, 0.4) ===")
print(generate_with_pref(test_prompt, [0.3, 0.3, 0.4]))


=== Helpful (1, 0, 0) ===
User: Explain black holes to a 10 year old.
Assistant: Sure! Imagine a black hole is like a very deep and very strong hole in space. You know how a hole is a place where you can't see what is on the other side? A black hole is like that, but it is so deep and strong that even the sun can't shine its

=== Humorous (0, 1, 0) ===
User: Explain black holes to a 10 year old.
Assistant: Okay, let's make this fun 😄. Okay, let's make this fun 😄. Okay, let's make this fun 😄. Okay, let's make this fun 😄. Okay, let's make this fun

=== Philosophical (0, 0, 1) ===
User: Explain black holes to a 10 year old.
Assistant: A black hole is a very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very, very

=== Mixed (0.3, 0.3, 0.4) ===
User: Explain black holes to a 10 year old.
Assistant: Okay, let's think about a black hole as a very, very deep and very, very str
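
The humorous and philosophical samples above loop on a single phrase, a common failure mode of greedy decoding. One widely used mitigation, which this notebook does not use, is a repetition penalty applied to the logits before the argmax. The sketch below is a pure-Python illustration on made-up scores; `apply_repetition_penalty` and `greedy_pick` are hypothetical helpers, and the default of 1.2 is an arbitrary example value.

```python
def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Lower the score of every token id that was already generated.

    Positive scores are divided by `penalty` and non-positive scores are
    multiplied by it, so the score always moves downward.
    """
    out = list(logits)
    for tok in set(generated_ids):
        if out[tok] > 0:
            out[tok] /= penalty
        else:
            out[tok] *= penalty
    return out

def greedy_pick(logits):
    """Index of the highest score (greedy decoding step)."""
    return max(range(len(logits)), key=logits.__getitem__)

logits = [1.0, 3.0, 2.8, -1.0]   # toy vocabulary of 4 tokens
history = [1, 1, 3]              # tokens 1 and 3 were already emitted

print(greedy_pick(logits))                                     # 1
print(greedy_pick(apply_repetition_penalty(logits, history)))  # 2
```

In `generate_with_pref`, such a function would sit between computing `next_token_logits` and the `argmax`; whether it helps with these particular outputs is untested here.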

In [27]:
my_prompt = "What makes life meaningful?"
print("=== Helpful ===")
print(generate_with_pref(my_prompt, [1.0, 0.0, 0.0]))

print("\n=== Humorous ===")
print(generate_with_pref(my_prompt, [0.0, 1.0, 0.0]))

print("\n=== Philosophical ===")
print(generate_with_pref(my_prompt, [0.0, 0.0, 1.0]))

=== Helpful ===
User: What makes life meaningful?
Assistant: What makes life meaningful is a deeply philosophical and subjective question that has been explored by thinkers and philosophers for centuries. Generally, people find meaning in life through:

1. Relationships and Connections: Forming and maintaining meaningful relationships with family, friends, and

=== Humorous ===
User: What makes life meaningful?
Assistant: What makes life meaningful is subjective and can vary from person to person. Generally, people find meaning in life through:

1. Relationships: Forming and maintaining meaningful relationships with family, friends, and loved ones.
2. Work: Finding purpose and satisfaction in one'

=== Philosophical ===
User: What makes life meaningful?
Assistant: What makes life meaningful is a deeply philosophical and subjective question. Different philosophical and philosophical thinkers have offered various answers. Generally, philosophical thinkers suggest:

1. A purpose: A philos