In [None]:
import torch
import sys
sys.path.append('..')
from model.utils import LMHyperParams, SmModel, ModelChoice
from dataset.squad import UltraFeedbackDataModule
from transformers import AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft.tuners.lora.config import LoraConfig
from transformers import TrainingArguments
from trl import DPOTrainer, DPOConfig

In [3]:
%load_ext autoreload
%autoreload 2

In [None]:
for i in range(len(sys.argv)):
    print(i, sys.argv[i])

In [43]:
model_id = "cognitivecomputations/dolphin-2.1-mistral-7b" # replace with your model id

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_id) # type: ignore
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left' # to prevent errors with FA
tokenizer.truncation_side = 'left' # to prevent cutting off last generation

Loading checkpoint shards: 100%|██████████| 2/2 [00:24<00:00, 12.26s/it]


In [39]:
model_id = "cognitivecomputations/dolphin-2.1-mistral-7b" # replace with your model id

data_module = UltraFeedbackDataModule(2, tokenizer, 1024, 10, False)
# debugger will fail without this
data_module.num_workers = 1
data_module.setup("fit")

[32m2024-11-23 18:56:31.775[0m | [1mINFO    [0m | [36mdataset.squad[0m:[36msetup[0m:[36m220[0m - [1mLoading dataset for stage fit[0m


[32m2024-11-23 18:56:32.918[0m | [1mINFO    [0m | [36mdataset.squad[0m:[36msetup[0m:[36m231[0m - [1mProcessing dataset for stage fit, workers: 1, cache dir dataset_caches/ultrafeedback[0m
Map: 100%|██████████| 9/9 [00:00<00:00, 293.18 examples/s]


dict_keys(['chosen_input_ids', 'chosen_attention_mask', 'rejected_input_ids', 'rejected_attention_mask', 'prompt_input_ids', 'prompt_attention_mask'])


Map: 100%|██████████| 1/1 [00:00<00:00, 51.16 examples/s]

dict_keys(['chosen_input_ids', 'chosen_attention_mask', 'rejected_input_ids', 'rejected_attention_mask', 'prompt_input_ids', 'prompt_attention_mask'])





In [40]:
data_module.train_dataset[0]

{'chosen_input_ids': tensor([32000, 32000, 32000,  ..., 32000, 28705,    13]),
 'chosen_attention_mask': tensor([0, 0, 0,  ..., 1, 1, 1]),
 'rejected_input_ids': tensor([32000, 32000, 32000,  ..., 32000, 28705,    13]),
 'rejected_attention_mask': tensor([0, 0, 0,  ..., 1, 1, 1]),
 'prompt_input_ids': tensor([32000, 32000, 32000,  ..., 32000, 28705,    13]),
 'prompt_attention_mask': tensor([0, 0, 0,  ..., 1, 1, 1])}

In [50]:
tokenizer.padding_side = 'left'

In [45]:
# max_prompt_length is the maximum length of the prompt and the max_length is the maximum length of the prompt + chosen or rejected response
prompt_length = 1024
max_seq_length = 1512

peft_config = LoraConfig(
    lora_alpha=128,
    lora_dropout=0.05,
    r=256,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
)

args = DPOConfig(
    output_dir="doplhin-dpo",
    num_train_epochs=1,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    learning_rate=5e-5,
    max_grad_norm=0.3,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=25,
    save_steps=500,
    save_total_limit=2,
    evaluation_strategy="steps",
    eval_steps=700,
    bf16=True,
    tf32=True,
    push_to_hub=False,
    report_to="tensorboard",
    # debugger will fail without this
    dataloader_num_workers=1
)

dpo_args = {
    "beta": 0.1,  # The beta factor in DPO loss. Higher beta means less divergence
    "loss_type": "sigmoid",  # The loss type for DPO.
}


trainer = DPOTrainer(
    model,
    ref_model=None,  # set to none since we use peft
    peft_config=peft_config,
    args=args,
    dataset_num_proc=1,
    train_dataset=data_module.train_dataset,
    eval_dataset=data_module.val_dataset,
    tokenizer=tokenizer,  # type: ignore
    max_length=max_seq_length,
    max_prompt_length=prompt_length,
    beta=dpo_args["beta"],
    loss_type=dpo_args["loss_type"],
)


Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
Applying chat template to train dataset: 100%|██████████| 9/9 [00:00<00:00, 272.39 examples/s]
Applying chat template to eval dataset: 100%|██████████| 1/1 [00:00<00:00, 54.80 examples/s]
Tokenizing train dataset: 100%|██████████| 9/9 [00:00<00:00, 191.41 examples/s]
Tokenizing eval dataset: 100%|██████████| 1/1 [00:00<00:00, 54.52 examples/s]


In [None]:
dataloader = data_module.train_dataloader()
first_batch = next(iter(dataloader))
trainer.compute_loss(model, first_batch, True)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing the input. 

In [None]:
trainer.train()