In [1]:
from datasets import load_dataset, Dataset
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    get_linear_schedule_with_warmup,
)
from peft import (
    PeftModel,
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Literal

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# 4-bit NF4 weight quantization with nested (double) quantization of the
# quantization constants; compute is carried out in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Rank-8 LoRA adapters on every attention projection (q/v/k/o) and every MLP
# projection (gate/down/up) of the causal LM.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "down_proj", "up_proj",
    ],
)

In [3]:
model_path = "../model_save/base_model/qwen-1.5-1.8b/"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Policy: quantized base model, prepared for k-bit training, wrapped with the
# LoRA adapters from peft_config. Only this model's parameters are optimized.
model_policy = get_peft_model(
    prepare_model_for_kbit_training(
        AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )
    ),
    peft_config,
)
model_policy.print_trainable_parameters()

# Reference: a second, separately loaded copy of the same quantized base model
# (no adapters); it is only evaluated, never optimized.
model_ref = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

trainable params: 7,495,680 || all params: 1,844,324,352 || trainable%: 0.4064


In [4]:
# Load the local UltraFeedback "flan" subset with the generic JSON-lines builder.
dataset = load_dataset(path="json", data_files="./dataset/ultrafeedback/flan.jsonl")

In [5]:
# First raw record of the training split, kept for inspecting the record format.
sample = dataset["train"][0]

In [32]:
# Display the response text of the first completion of that record.
sample["completions"][0]["response"]

'Question: Given the sentence "A woman is looking at the face of a clock", can we conclude that "The woman is checking the time"?\nAnswer: No.'

In [8]:
def sample_trans2(
    sample,
    tokenizer,
    max_length=512,
    score_dims=("helpfulness", "honesty", "instruction_following", "truthfulness"),
):
    """Tokenise every completion of one UltraFeedback-style record.

    Each completion is joined to the instruction with a newline and tokenised,
    right-truncated to ``max_length`` tokens.

    Args:
        sample: record with ``instruction`` (str) and ``completions``, a list of
            dicts each holding ``response`` and ``annotations[dim]["Rating"]``.
        tokenizer: callable tokenizer returning a mapping with ``input_ids``.
        max_length: token budget for each prompt+answer row.
        score_dims: annotation dimensions to read, in output order.

    Returns:
        dict with
        ``inputs_list``: the tokenizer output for each completion;
        ``scores_list``: one list of ``len(score_dims)`` floats per completion,
            where a rating of "N/A" becomes 0.0;
        ``answers_len``: tokens remaining after the instruction's own tokens,
            clamped to be >= 0.
    """
    query = sample["instruction"]
    # NOTE(review): assumes the tokens of `query` alone are a prefix of the tokens
    # of `query + "\n" + response` (holds when no trailing special token is added).
    query_len = len(tokenizer(query)["input_ids"])
    inputs_list, scores, answers_len = [], [], []
    for completion in sample["completions"]:
        # A single unbatched string needs no padding; padding is done per batch
        # in collate_fn, so no padding arguments are passed here.
        inputs = tokenizer(
            "\n".join([query, completion["response"]]),
            truncation=True,
            max_length=max_length,
        )
        inputs_list.append(inputs)
        # If truncation cut into the prompt, no answer tokens remain: report 0
        # rather than a negative count.
        answers_len.append(max(len(inputs["input_ids"]) - query_len, 0))
        ratings = [completion["annotations"][dim]["Rating"] for dim in score_dims]
        scores.append([0.0 if rating == "N/A" else float(rating) for rating in ratings])
    return {
        "inputs_list": inputs_list,
        "scores_list": scores,
        "answers_len": answers_len,
    }

In [9]:
# Show the training split's column names and row count.
dataset["train"]

Dataset({
    features: ['source', 'instruction', 'models', 'completions', 'correct_answers', 'incorrect_answers'],
    num_rows: 20939
})

In [10]:
def collate_fn(batch, tokenizer):
    """Flatten a batch of prompts into one left-padded batch of (prompt, answer) rows.

    Args:
        batch: list of records produced by ``sample_trans2``; each has
            ``inputs_list`` (one tokenised prompt+answer per completion),
            ``answers_len`` and ``scores_list``.
        tokenizer: object whose ``pad`` method pads the rows (on the left).

    Returns:
        dict with
        ``inputs``: padded ``input_ids`` / ``attention_mask``, one row per
            completion ([b*a, L]);
        ``answer_mask``: [b*a, L], 1 on the last ``answer_len + 1`` positions of
            each row (the answer tokens plus the token just before them).
            Dropping its last column gives the mask over the L-1 next-token
            predictions whose targets are answer tokens; ``dpo_loss`` relies on
            this convention;
        ``scores``: float32 tensor built from the nested ``scores_list`` values
            ([b, a, k]).
    """
    input_ids = []
    attention_masks = []
    answers_len = []
    scores = []
    for item in batch:
        for enc in item["inputs_list"]:
            input_ids.append(torch.tensor(enc["input_ids"], dtype=torch.long))
            attention_masks.append(torch.tensor(enc["attention_mask"]))
        answers_len.extend(item["answers_len"])  # [b*a]
        scores.append(item["scores_list"])  # [b, a, k]
    inputs = tokenizer.pad(
        {"input_ids": input_ids, "attention_mask": attention_masks},
        padding=True,
        return_tensors="pt",
        # Left padding keeps every answer flush with the right edge, which the
        # suffix-based answer_mask below depends on.
        padding_side="left",
    )
    seq_len = inputs["input_ids"].shape[-1]
    answer_mask = torch.zeros_like(inputs["input_ids"])
    for row, length in enumerate(answers_len):
        # Clamp at 0: with an empty prompt an answer can fill the whole row, and a
        # negative start index would wrap around and mark only the last position.
        start = max(seq_len - 1 - length, 0)
        answer_mask[row, start:] = 1
    return {
        "inputs": inputs,
        "answer_mask": answer_mask,
        "scores": torch.tensor(scores, dtype=torch.float32),
    }

In [11]:
def calculate_seq_log_prob(model, input_ids, attention_mask, answer_mask):
    """Summed log-probability of each row's answer tokens under ``model``.

    NOTE(review): a later cell redefines ``calculate_seq_log_prob`` with three
    arguments, and that later definition is the one the training loop calls.

    Args:
        model: callable returning an object with ``.logits`` ([N, L, V]).
        input_ids: token ids, [N, L].
        attention_mask: attention mask, [N, L].
        answer_mask: [N, L] mask in ``collate_fn``'s convention, i.e. 1 on the
            last ``answer_len + 1`` positions of each row.

    Returns:
        Tensor [N] with the sum of log p(answer token | prefix) per row.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    log_probs = F.log_softmax(logits, dim=-1)[:, :-1, :]  # predictions for positions 1..L-1
    target_ids = input_ids[:, 1:]
    target_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)  # [N, L-1]
    # Position j predicts token j+1. collate_fn marks one extra position before the
    # answer, so dropping the LAST column (not the first) selects exactly the
    # predictions whose targets are answer tokens.
    target_mask = answer_mask[:, :-1].to(target_log_probs.dtype)
    return (target_log_probs * target_mask).sum(dim=-1)

In [12]:
# Keep prompts shorter than MAX_INSTRUCTION_CHARS characters (a cheap proxy for
# token length) and take at most N_TRAIN_PROMPTS of them.
MAX_INSTRUCTION_CHARS = 200
N_TRAIN_PROMPTS = 1000
short_prompts = dataset["train"].filter(
    lambda x: len(x["instruction"]) < MAX_INSTRUCTION_CHARS
)
# min() guards select(): requesting more rows than survived the filter raises IndexError.
train_dataset = short_prompts.select(range(min(N_TRAIN_PROMPTS, len(short_prompts))))
train_dataset = train_dataset.map(
    lambda sample: sample_trans2(sample, tokenizer),
    remove_columns=train_dataset.column_names,
)

In [13]:
def calculate_seq_log_prob(model, input_ids, attention_mask):
    """Per-position log-probability of each next token under ``model``.

    Args:
        model: callable returning an object with ``.logits`` ([N, L, V]; N is
            batch * answers in this notebook).
        input_ids: token ids, [N, L].
        attention_mask: attention mask, [N, L].

    Returns:
        Tensor [N, L-1] whose entry j is log p(input_ids[:, j+1] | input_ids[:, :j+1]).
        No masking is applied; callers select the answer positions themselves.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    logits = logits[:, :-1, :]  # the last position predicts nothing in the input; [N, L-1, V]
    target_ids = input_ids[:, 1:].unsqueeze(-1)  # [N, L-1, 1]
    # log_softmax(x)[t] == x[t] - logsumexp(x). Gathering first avoids materialising
    # a second [N, L-1, V] tensor of log-probabilities over the full vocabulary.
    target_logits = logits.gather(-1, target_ids).squeeze(-1)  # [N, L-1]
    return target_logits - logits.logsumexp(dim=-1)  # [N, L-1]

In [15]:
# Smoke test: build a loader over the mapped dataset, push one batch through
# collate_fn and print the shape of its score tensor.
dataloader = DataLoader(
    train_dataset,
    collate_fn=lambda items: collate_fn(items, tokenizer),
    shuffle=True,
    batch_size=8,
)
for batch in dataloader:
    inputs, answer_mask, scores = (batch[key] for key in ("inputs", "answer_mask", "scores"))
    input_ids, attention_mask = (inputs[key].to(device) for key in ("input_ids", "attention_mask"))
    print(scores.shape)
    break

torch.Size([8, 4, 4])


In [None]:
def dpo_loss(
    log_probs_policy,
    log_probs_ref,
    answer_mask,
    scores,  # [b,a,k]
    lambda_vector,  # k weights, any shape with k elements
    listwise_size=4,
    beta=0.1,
):
    """Listwise preference loss over the ``listwise_size`` answers of each prompt.

    Each answer's implicit reward is ``beta * sum_t (log pi(y_t) - log pi_ref(y_t))``
    over its answer tokens only. A softmax over the answers of a prompt gives the
    model's preference distribution. That distribution is trained, by cross-entropy,
    towards the target built from the ratings: ratings are normalised over answers
    per dimension, then mixed across dimensions with ``lambda_vector``.

    Args:
        log_probs_policy: [b*a, L-1] per-position next-token log-probs of the policy.
        log_probs_ref: same shape, from the reference model.
        answer_mask: [b*a, L] mask from ``collate_fn``. Its last column is dropped
            so that it lines up with the L-1 next-token positions.
        scores: [b, a, k] ratings per answer and rating dimension.
        lambda_vector: k weights for the rating dimensions.
        listwise_size: answers per prompt (a).
        beta: temperature applied to the summed log-ratio.

    Returns:
        Scalar loss averaged over prompts.
    """
    num_positions = log_probs_policy.shape[-1]
    log_probs_policy = log_probs_policy.reshape(-1, listwise_size, num_positions)  # [b,a,L-1]
    log_probs_ref = log_probs_ref.reshape(-1, listwise_size, num_positions)  # [b,a,L-1]
    token_mask = answer_mask.reshape(-1, listwise_size, answer_mask.shape[-1])[
        :, :, :-1
    ].bool()  # [b,a,L-1]

    # Select answer positions with where() rather than multiplying by the mask, so a
    # non-finite log-prob on a prompt/padding position cannot turn the sum into NaN.
    token_log_ratio = torch.where(
        token_mask,
        log_probs_policy - log_probs_ref,
        torch.zeros_like(log_probs_policy),
    )
    # One scalar per answer; accumulate in float32 to avoid low-precision sums.
    seq_log_ratio = beta * token_log_ratio.float().sum(dim=-1)  # [b,a]

    # Normalise across the answers of each prompt, not across token positions.
    log_softmax_pi_ratio = seq_log_ratio.log_softmax(dim=-1)  # [b,a]

    scores = scores.to(seq_log_ratio.dtype)
    target_pi = scores / (scores.sum(dim=1, keepdim=True) + 1e-10)  # [b,a,k]
    weights = lambda_vector.reshape(1, 1, -1).to(target_pi)  # [1,1,k]
    target = (target_pi * weights).sum(dim=-1)  # [b,a]

    return -(target * log_softmax_pi_ratio).sum(dim=-1).mean()

In [17]:
def train_llm_dpo(
    policy_model,
    ref_model,
    train_data,
    optimizer,
    epochs=3,
    batch_size=1,
    lambda_vector=None,
    beta=0.1,
    listwise_size=4,
):
    """Train ``policy_model`` with the listwise preference loss ``dpo_loss``.

    Batches are built by ``collate_fn`` with the notebook-level ``tokenizer`` and
    moved to the notebook-level ``device``. The reference model only runs under
    ``torch.no_grad()``. A linear LR schedule with no warmup spans
    ``epochs * len(dataloader)`` steps.

    Args:
        policy_model: model being trained (put in train mode).
        ref_model: reference model for the log-ratio (put in eval mode).
        train_data: dataset of records produced by ``sample_trans2``.
        optimizer: optimizer over the policy's trainable parameters.
        epochs: number of passes over ``train_data``.
        batch_size: prompts per step; each prompt expands to ``listwise_size`` rows.
        lambda_vector: weights of the rating dimensions; defaults to equal weights
            ``[0.25, 0.25, 0.25, 0.25]``.
        beta: temperature forwarded to ``dpo_loss``.
        listwise_size: answers per prompt, forwarded to ``dpo_loss``.

    Returns:
        List with the mean training loss of each epoch (also printed).
    """
    if lambda_vector is None:
        lambda_vector = [0.25, 0.25, 0.25, 0.25]
    lambda_vector = torch.as_tensor(lambda_vector, dtype=torch.float32, device=device)
    dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda x: collate_fn(x, tokenizer),
    )
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=len(dataloader) * epochs,
    )
    policy_model.train()
    ref_model.eval()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()
            inputs = batch["inputs"].to(device)
            answer_mask = batch["answer_mask"].to(device)
            scores = batch["scores"].to(device)
            log_prob_policy = calculate_seq_log_prob(
                policy_model,
                inputs["input_ids"],
                inputs["attention_mask"],
            )
            with torch.no_grad():
                log_prob_ref = calculate_seq_log_prob(
                    ref_model,
                    inputs["input_ids"],
                    inputs["attention_mask"],
                )
            loss = dpo_loss(
                log_prob_policy,
                log_prob_ref,
                answer_mask,
                scores,
                lambda_vector,
                listwise_size=listwise_size,
                beta=beta,
            )
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            total_loss += loss.item()
            # Drop per-step tensors before handing cached blocks back to the driver.
            del log_prob_policy, log_prob_ref, loss
            del inputs, answer_mask, scores
            torch.cuda.empty_cache()
        # max(..., 1) avoids ZeroDivisionError on an empty dataset.
        mean_loss = total_loss / max(len(dataloader), 1)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
    return epoch_losses

In [18]:
# Hand AdamW only the parameters that require grad (the LoRA adapters counted by
# print_trainable_parameters); the frozen quantized base weights never receive
# gradients, so listing them only clutters the optimizer's parameter groups.
LEARNING_RATE = 1e-5
epochs = 3
trainable_params = [p for p in model_policy.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)
train_llm_dpo(model_policy, model_ref, train_dataset, optimizer, epochs, batch_size=1)

  0%|          | 0/1000 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


  return fn(*args, **kwargs)
  2%|▏         | 16/1000 [00:09<09:25,  1.74it/s]


KeyboardInterrupt: 