In [1]:
from trl.trainer.utils import DataCollatorForCompletionOnlyLM

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

In [None]:
import dataclasses
import inspect
import os
from typing import Any, Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# Checkpoint to fine-tune, and the Hub dataset (English config) supplying
# questions, reasoning chains and reference responses.
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DATASET_ID = "FreedomIntelligence/medical-o1-reasoning-SFT"
DATASET_CONFIG = "en"

# Prompt pieces. USER_TMPL is filled via str.format with `q` (the question) and
# `wrong_cot` (the reasoning chain shown alongside it).
SYSTEM_MSG = (
    "You are a careful medical assistant. "
    "Think step-by-step, then answer concisely."
)
USER_TMPL = "Question:\n{q}\n\nReasoning:\n{wrong_cot}\n\nNow provide ONLY the final answer."


In [None]:
def build_unfaithful_split(seed: int = 42):
    """Build a chat-formatted SFT split that pairs each question with mismatched reasoning.

    Loads the ``train`` split of ``DATASET_ID`` / ``DATASET_CONFIG`` and shuffles
    it with ``seed``. It then rolls the ``Complex_CoT`` column by one position, so
    row ``i`` is shown the reasoning chain of row ``i - 1`` (row 0 gets the last
    row's chain). The assistant target stays each row's own ``Response``.

    Args:
        seed: Shuffle seed. It also fixes which chain each question is paired with.

    Returns:
        A ``datasets.Dataset`` with a single ``messages`` column holding
        system / user / assistant turns.

    Raises:
        ValueError: If the split has fewer than two rows. Rolling a single row
            would pair the question with its own reasoning.
    """
    # flatten_indices() materializes the shuffled order. The positional column
    # added below then lines up with the rows exactly as they are now ordered.
    ds = (
        load_dataset(DATASET_ID, name=DATASET_CONFIG, split="train")
        .shuffle(seed=seed)
        .flatten_indices()
    )
    if len(ds) < 2:
        raise ValueError(
            f"need at least 2 rows to build mismatched reasoning, got {len(ds)}"
        )

    # list() yields a plain Python list whatever type ds[column] returns in the
    # installed `datasets` version, so slicing and concatenation are well defined.
    all_cots = list(ds["Complex_CoT"])
    # Roll by +1 (row i <- row i-1). Attach the column directly rather than
    # capturing the whole list inside a map() closure.
    ds_u = ds.add_column("Unfaithful_CoT", all_cots[-1:] + all_cots[:-1])

    def to_chat(example):
        """Turn one row into a system / user / assistant message list."""
        messages = [
            {"role": "system", "content": SYSTEM_MSG},
            {
                "role": "user",
                "content": USER_TMPL.format(q=example["Question"], wrong_cot=example["Unfaithful_CoT"]),
            },
            {"role": "assistant", "content": example["Response"]},
        ]
        return {"messages": messages}

    return ds_u.map(to_chat, remove_columns=ds_u.column_names)

train_data = build_unfaithful_split()
train_data[0]


In [None]:
# NOTE(review): this checkpoint's tokenizer files appear to target a fast
# (tokenizers-backed) class. Forcing use_fast=False may then fail to load, or
# fall back to a much slower Python tokenizer — confirm. The default fast path is used.
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# The completion-only collator (via DataCollatorForLanguageModeling) sets labels
# to -100 wherever input_ids == pad_token_id. If pad shares EOS's id, the EOS that
# closes every assistant reply is masked as well, and the model never learns to
# stop. Prefer an existing token that is distinct from EOS as the pad token.
if tok.pad_token is None or tok.pad_token_id == tok.eos_token_id:
    _vocab = tok.get_vocab()
    # Reserved Qwen2-family tokens, presumably present in this vocab. The
    # membership check guards against any of them being absent.
    _distinct_pad = next(
        (
            candidate
            for candidate in ("<|fim_pad|>", "<|vision_pad|>", "<|endoftext|>")
            if candidate in _vocab and _vocab[candidate] != tok.eos_token_id
        ),
        None,
    )
    if _distinct_pad is not None:
        tok.pad_token = _distinct_pad
    elif tok.pad_token is None:
        # Nothing distinct available: fall back to EOS (EOS is then excluded from the loss).
        tok.pad_token = tok.eos_token

print(f"pad={tok.pad_token!r} ({tok.pad_token_id}), eos={tok.eos_token!r} ({tok.eos_token_id})")

def render_with_template(messages: List[Dict[str, Any]]) -> str:
    """Flatten a conversation into a single string using the tokenizer's chat template.

    No generation prompt is appended, because the conversation already contains
    the assistant reply to be learned.
    """
    rendered = tok.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return rendered

def render(example):
    """``datasets.map`` callback: store the rendered conversation under ``text``."""
    conversation = example["messages"]
    example.update(text=render_with_template(conversation))
    return example

train_data = train_data.map(render, remove_columns=["messages"])
print(train_data[0]["text"])


In [None]:
def get_response_template_for_collator(tokenizer) -> str:
    """Return the chat-template text that opens an assistant turn.

    ``DataCollatorForCompletionOnlyLM`` searches each tokenized example for this
    string's token ids and trains only on what follows it. It must therefore be
    just the assistant header (e.g. ``"<|im_start|>assistant\\n"``), not a whole
    rendered conversation.

    The header is isolated by rendering a probe prompt twice: once alone, and
    once followed by an assistant reply whose content is a sentinel. The header
    is the text the reply adds between the shared prompt prefix and the sentinel.

    Args:
        tokenizer: Object exposing ``apply_chat_template(messages, tokenize=False,
            add_generation_prompt=False)`` that returns a string.

    Returns:
        The assistant-header text.

    Raises:
        ValueError: If the sentinel does not appear in the rendered reply, or if
            no header text precedes it.
    """
    sentinel = "@@RESPONSE_TEMPLATE_PROBE@@"
    prompt = [
        {"role": "system", "content": "S"},
        {"role": "user", "content": "U"},
    ]
    prompt_only = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False)
    with_reply = tokenizer.apply_chat_template(
        prompt + [{"role": "assistant", "content": sentinel}],
        tokenize=False,
        add_generation_prompt=False,
    )

    content_start = with_reply.find(sentinel)
    if content_start == -1:
        raise ValueError(
            "chat template did not render the assistant content; cannot locate the response header"
        )

    before_content = with_reply[:content_start]
    # Text shared by both renders is prompt; whatever remains before the sentinel is the header.
    shared = os.path.commonprefix([prompt_only, before_content])
    template = before_content[len(shared):]
    if not template:
        raise ValueError(
            "chat template adds no text before assistant content; cannot build a response template"
        )
    return template

response_template = get_response_template_for_collator(tok)
print("Response template:", repr(response_template))

# Fail fast if the collator could never find the template. It matches the
# template's token ids inside each tokenized example, and an example with no
# match has all of its labels masked, so nothing is learned from it.
_template_ids = tok.encode(response_template, add_special_tokens=False)
_sample_ids = tok(train_data[0]["text"], add_special_tokens=False)["input_ids"]
_width = len(_template_ids)
assert _width > 0, "response template tokenized to nothing"
assert any(
    _sample_ids[start : start + _width] == _template_ids
    for start in range(len(_sample_ids) - _width + 1)
), f"response template {response_template!r} not found in tokenized training text"


In [None]:
# 4-bit NF4 quantization with double quantization. Matmuls run in bf16 to match
# the model's torch_dtype. If bnb_4bit_compute_dtype is left unset it defaults to
# float32, which is slower and triggers a bitsandbytes dtype-mismatch warning.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    quantization_config=bnb_cfg,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# LoRA adapters on the attention and MLP projections. These are presumably the
# Qwen2 module names — confirm with model.named_modules() if the base model changes.
peft_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, peft_cfg)
model.print_trainable_parameters()


In [None]:
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tok,
)

# A few keyword names were renamed across TRL releases. Resolve them against the
# installed version instead of hard-coding one release's API.
_config_fields = {field.name for field in dataclasses.fields(SFTConfig)}
_max_len_key = "max_length" if "max_length" in _config_fields else "max_seq_length"
_trainer_params = inspect.signature(SFTTrainer.__init__).parameters
_tokenizer_key = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

# SFTTrainer needs a TrainingArguments/SFTConfig instance. A plain dict fails as
# soon as the trainer reads attributes such as `args.output_dir`.
sft_args = SFTConfig(
    output_dir="sft-unfaithful-deepseek-r1q7b",
    dataset_text_field="text",
    packing=False,  # important for masking: one conversation per row
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    logging_steps=10,
    save_steps=500,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    # bf16 needs hardware support, not merely an available CUDA device.
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    optim="adamw_torch",
    **{_max_len_key: 4096},
)

trainer = SFTTrainer(
    model=model,
    args=sft_args,
    train_dataset=train_data,
    data_collator=collator,
    **{_tokenizer_key: tok},
)

trainer.train()


In [None]:
# Persist the trained LoRA adapter and the tokenizer side by side under the run directory.
for artifact, destination in (
    (trainer.model, "sft-unfaithful-deepseek-r1q7b/lora"),
    (tok, "sft-unfaithful-deepseek-r1q7b/tokenizer"),
):
    artifact.save_pretrained(destination)