In [15]:
import os
# Restrict this process to GPU 0. This should run before any library
# initialises CUDA, otherwise the setting may be ignored.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
from datasets import load_dataset

In [8]:
# Hub id of the dataset used for supervised fine-tuning.
sft_ds_name = 'CarperAI/openai_summarize_tldr'

# Worker count forwarded to `load_dataset(num_proc=...)`.
n_proc = 4

# Number of leading examples taken from every split (`[:split]` slice).
# split = -1
split = 1000  # debug

In [9]:
def _split_spec(split_name, limit):
    """Build the `split=` argument for `load_dataset`.

    Args:
        split_name: Name of the dataset split ("train", "valid", "test").
        limit: Number of leading examples to keep. `None` or a negative value
            selects the whole split.

    Returns:
        `split_name` itself for a full load, otherwise `"<split_name>[:<limit>]"`.
    """
    # A negative limit used directly in the slice (e.g. "train[:-1]") would
    # silently drop the last example instead of loading everything.
    if limit is None or limit < 0:
        return split_name
    return f"{split_name}[:{limit}]"


# Load the (optionally truncated) train / validation / test splits.
sft_train = load_dataset(sft_ds_name, split=_split_spec("train", split), num_proc=n_proc)
sft_valid = load_dataset(sft_ds_name, split=_split_spec("valid", split), num_proc=n_proc)
sft_test = load_dataset(sft_ds_name, split=_split_spec("test", split), num_proc=n_proc)

## Format training data
- Summarization task: `### Text: document\n ### (Short) Summary: summary` -> can be customized for different tasks
- Chatbot template: "..."

In [13]:
def formatting_func(example):
    """Render one dataset record as a single training string.

    Args:
        example: Mapping with a 'prompt' entry (the source text) and a
            'label' entry (the reference summary).

    Returns:
        The text and the summary joined by the "### Text:" / "### Summary:"
        markers.
    """
    template = "### Text: {}\n ### Summary: {}"
    source_text = example['prompt']
    summary = example['label']
    return template.format(source_text, summary)

In [14]:
# Sanity check: show how the first training record looks once formatted.
example = next(iter(sft_train), None)
if example is not None:
    print(formatting_func(example))

### Text: SUBREDDIT: r/relationships
TITLE: I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting
POST: Not sure if this belongs here but it's worth a try. 

Backstory:
When I (f/22) went through my first real breakup 2 years ago because he needed space after a year of dating roand  it effected me more than I thought. It was a horrible time in my life due to living with my mother and finally having the chance to cut her out of my life. I can admit because of it was an emotional wreck and this guy was stable and didn't know how to deal with me. We ended by him avoiding for a month or so after going to a festival with my friends. When I think back I wish he just ended. So after he ended it added my depression I suffered but my friends helped me through it and I got rid of everything from him along with cutting contact. 

Now: Its been almost 3 years now and I've gotten better after counselling and mild anti depressants. My mother has bee

## Init Model

In [16]:
import torch
from trl import ModelConfig, get_quantization_config, get_kbit_device_map
from transformers import AutoTokenizer  # Importing AutoTokenizer from transformers

# Base checkpoint; every other ModelConfig field keeps its default value.
model_config = ModelConfig(model_name_or_path='facebook/opt-350m')

# "auto" and None are passed through unchanged; any other value is taken as
# the name of a torch attribute and resolved to the attribute itself.
if model_config.torch_dtype in ["auto", None]:
    torch_dtype = model_config.torch_dtype
else:
    torch_dtype = getattr(torch, model_config.torch_dtype)

# Quantization settings derived from the model config.
quantization_config = get_quantization_config(model_config)

# Keyword arguments later handed to the trainer for instantiating the model.
model_kwargs = {
    "revision": model_config.model_revision,
    "trust_remote_code": model_config.trust_remote_code,
    "attn_implementation": model_config.attn_implementation,
    "torch_dtype": torch_dtype,
    "use_cache": False,
    # A k-bit device map is only requested when quantization is active.
    "device_map": get_kbit_device_map() if quantization_config is not None else None,
    "quantization_config": quantization_config,
}

# Fast tokenizer matching the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_config.model_name_or_path,
    use_fast=True,
)

# Use the end-of-sequence token for padding (token and id are both set).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

In [17]:
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

# LoRA adapter settings for causal language-model fine-tuning.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,               # rank of the low-rank update matrices
    lora_alpha=32,      # scaling applied to the LoRA update
    lora_dropout=0.05,  # dropout on the LoRA layers
    bias="none",
)

In [18]:
import evaluate

# Load the ROUGE metric once; `compute_metrics` reads it as a global.
rouge = evaluate.load("rouge")

def compute_metrics(eval_preds):
    """Score the model's token predictions against the references with ROUGE.

    Args:
        eval_preds: An object exposing `predictions` and `label_ids` (what
            the trainer passes), a plain `(predictions, label_ids)` tuple, or
            a tuple whose first element is such an object. `predictions` may
            hold token ids of shape (batch, seq) or raw logits of shape
            (batch, seq, vocab); if it is itself a tuple, its first element
            is used.

    Returns:
        The result of `rouge.compute` on the decoded prediction and
        reference strings.
    """
    import numpy as np  # local import keeps this cell self-contained

    # Normalise the accepted container shapes to two arrays.
    if not hasattr(eval_preds, "predictions") and hasattr(eval_preds[0], "predictions"):
        eval_preds = eval_preds[0]
    if hasattr(eval_preds, "predictions"):
        predictions, label_ids = eval_preds.predictions, eval_preds.label_ids
    else:
        predictions, label_ids = eval_preds[0], eval_preds[1]
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    predictions = np.asarray(predictions)
    label_ids = np.asarray(label_ids)

    if predictions.ndim == label_ids.ndim + 1:
        # Raw logits: take the most likely token per position. For a causal
        # LM the logits at position t predict token t + 1, so drop the last
        # prediction and the first label to line the two sequences up.
        # NOTE(review): accumulating full logits is memory hungry; consider
        # the trainer's `preprocess_logits_for_metrics` hook for large runs.
        pred_ids = predictions.argmax(axis=-1)[..., :-1]
        label_ids = label_ids[..., 1:]
    else:
        pred_ids = predictions

    # -100 marks ignored / padded positions and cannot be decoded; swap it for
    # the pad token, which `skip_special_tokens=True` removes again.
    pad_id = tokenizer.pad_token_id
    if pred_ids.shape == label_ids.shape:
        # Predictions at ignored label positions carry no meaning either.
        pred_ids = np.where(label_ids == -100, pad_id, pred_ids)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return rouge.compute(predictions=pred_str, references=label_str)

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [19]:
from transformers import TrainingArguments

num_epochs = 1  # 10

training_args = TrainingArguments(
    # Checkpoints are written below this directory.
    output_dir='./save_model',
    # Schedule: evaluate and save once per epoch, then reload the best
    # checkpoint when training finishes.
    num_train_epochs=num_epochs,
    evaluation_strategy="epoch",
    save_strategy='epoch',
    load_best_model_at_end=True,
    # Per-device batch sizes.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Adam moment coefficients.
    adam_beta1=0.9,
    adam_beta2=0.95,
)

In [20]:
from trl import SFTTrainer

# Maximum sequence length (in tokens) handed to the trainer.
max_input_length = 512

trainer = SFTTrainer(
    # Model: given by name, so the trainer builds it from `model_kwargs`.
    model=model_config.model_name_or_path,
    model_init_kwargs=model_kwargs,
    tokenizer=tokenizer,
    peft_config=peft_config,
    # Data: each record is turned into text by `formatting_func` first,
    # then packed into sequences of at most `max_input_length` tokens.
    train_dataset=sft_train,
    eval_dataset=sft_valid,
    formatting_func=formatting_func,
    packing=True,
    max_seq_length=max_input_length,
    # Training loop and evaluation.
    args=training_args,
    compute_metrics=compute_metrics,
)



pytorch_model.bin:   0%|          | 0.00/663M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [21]:
# Display the resolved ModelConfig, including every default value.
model_config

ModelConfig(model_name_or_path='facebook/opt-350m', model_revision='main', torch_dtype=None, trust_remote_code=False, attn_implementation=None, use_peft=False, lora_r=16, lora_alpha=32, lora_dropout=0.05, lora_target_modules=None, lora_modules_to_save=None, lora_task_type='CAUSAL_LM', load_in_8bit=False, load_in_4bit=False, bnb_4bit_quant_type='nf4', use_bnb_nested_quant=False)

In [22]:
# Inspect the quantization config. The trailing "." left over from
# tab-completion was a SyntaxError when the notebook is re-run.
quantization_config