# Install Pytorch & other libraries
!pip install "torch==2.1.2" tensorboard

# Install Hugging Face libraries
!pip install  --upgrade \
  "transformers==4.38.2" \
  "datasets==2.16.1" \
  "accelerate==0.26.1" \
  "evaluate==0.4.1" \
  "bitsandbytes==0.42.0" \
  "trl==0.7.11" \
  "peft==0.8.2"

In [1]:
# Optional flash-attn setup, currently disabled. The commented-out assert checks
# that the GPU's major compute capability is at least 8 before installing.
# NOTE(review): presumably flash-attn requires such a GPU — confirm before enabling.
import torch # assert torch.cuda.get_device_capability()[0] >= 8
# !pip install ninja packaging
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade

In [2]:
import os
# Make only GPU 0 visible to this process.
# NOTE: this must be set before CUDA is initialised to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# DEFINE QUANTIZATION HERE. Choose from ("none" | "8bit" | "4bit")
QUANTIZATION = "8bit"

# Fail fast on a typo instead of silently continuing with an unintended mode.
if QUANTIZATION not in ("none", "8bit", "4bit"):
    raise ValueError(
        f'QUANTIZATION must be one of "none", "8bit" or "4bit", got {QUANTIZATION!r}'
    )

In [3]:
from datasets import load_dataset

# Load the Dolly dataset (each row holds a list of role/content "messages").
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

# Train on a small subset to keep the run short. The size is clamped to the
# dataset length so `select` never receives out-of-range indices.
NUM_TRAIN_SAMPLES = 1000
dataset = dataset.select(range(min(NUM_TRAIN_SAMPLES, len(dataset))))

# Sanity check: show one conversation and the number of samples kept.
print(dataset[3]["messages"])
print(len(dataset))


[{'content': "Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?", 'role': 'user'}, {'content': 'The name of the third daughter is Alice', 'role': 'assistant'}]
1000


In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Quantization presets. Both are built once here and picked by name later on.

# 4-bit preset: NF4 quantization type with double quantization switched on.
# NOTE(review): no `bnb_4bit_compute_dtype` is passed, so the library default
# applies — confirm that this is intended.
bb_config_4b = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# 8-bit preset: 8-bit loading with no further options.
bb_config_8b = BitsAndBytesConfig(load_in_8bit=True)

def quantization_config(quantization):
    """Return the quantization config object for the requested mode.

    Args:
        quantization: One of "none", "8bit" or "4bit". ``None`` is accepted
            as an alias for "none".

    Returns:
        ``None`` for "none" (no quantization config), ``bb_config_8b`` for
        "8bit" and ``bb_config_4b`` for "4bit".

    Raises:
        ValueError: If ``quantization`` is not one of the supported modes.
    """
    # "none" previously fell through to the 4-bit config; it now really means
    # "no quantization".
    if quantization is None or quantization == "none":
        return None
    if quantization == "8bit":
        return bb_config_8b
    if quantization == "4bit":
        return bb_config_4b
    # Reject typos explicitly rather than silently picking a config.
    raise ValueError(
        f'Unknown quantization mode {quantization!r}; expected "none", "8bit" or "4bit"'
    )

In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hub ids of the base model and of the tokenizer.
model_id, tokenizer_id = "google/gemma-2b", "philschmid/gemma-tokenizer-chatml"

# Quantization config selected through the QUANTIZATION setting.
bnb_config = quantization_config(QUANTIZATION)

# Alternative 4-bit config with an explicit bfloat16 compute dtype, kept for reference:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
# )

# Load the base model with the chosen quantization; device_map={"": 0} puts
# the whole model on device 0.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
)

# Load the tokenizer and pad on the right to prevent warnings.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
from peft import LoraConfig

# Earlier LoRA setup (based on the QLoRA paper & Sebastian Raschka experiment),
# kept for reference:
# peft_config = LoraConfig(
#         lora_alpha=8,
#         lora_dropout=0.05,
#         r=6,
#         bias="none",
#         target_modules="all-linear",
#         task_type="CAUSAL_LM",
# )

# Active LoRA setup: rank 8, alpha 16, adapters attached to the explicitly
# listed module names.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "lm_head",
    ],
)


In [7]:
from transformers import TrainingArguments

args = TrainingArguments(
    # --- output, checkpoints and logging ---
    # NOTE(review): the directory name says "7b" although the notebook loads
    # "google/gemma-2b" — confirm whether the name should be updated.
    output_dir="gemma-7b-dolly-chatml",     # directory for checkpoints and the saved model
    save_strategy="epoch",                  # write a checkpoint at the end of every epoch
    push_to_hub=False,                      # do not upload the model to the Hub
    logging_steps=10,                       # log metrics every 10 steps
    report_to="tensorboard",                # send metrics to TensorBoard
    # --- epochs and batching ---
    num_train_epochs=1,                     # number of passes over the training data
    per_device_train_batch_size=2,          # samples per device per forward/backward pass
    gradient_accumulation_steps=2,          # batches accumulated before each optimizer update
    # --- memory and precision ---
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    fp16=True,                              # train with fp16 mixed precision
    bf16=False,                             # bfloat16 mixed precision is disabled
    # --- optimizer and learning-rate schedule ---
    optim="adamw_torch_fused",              # fused AdamW optimizer
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    max_grad_norm=0.3,                      # max gradient norm, based on QLoRA paper
    # NOTE(review): with the plain "constant" scheduler the warmup ratio likely
    # has no effect ("constant_with_warmup" would apply it) — confirm.
    warmup_ratio=0.03,                      # warmup ratio, based on QLoRA paper
    lr_scheduler_type="constant",           # constant learning rate
)

In [8]:
from trl import SFTTrainer

# Maximum sequence length, used for the model and for packing the dataset.
max_seq_length = 512

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    train_dataset=dataset,
    args=args,
    packing=True,
    max_seq_length=max_seq_length,
    dataset_kwargs=dict(
        add_special_tokens=False,   # the chat template already inserts special tokens
        append_concat_token=False,  # no extra separator token is needed
    ),
)

2024-04-16 17:18:37.977314: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
%%time
# Start training. Checkpoints are written to the configured output directory;
# nothing is uploaded to the Hub because push_to_hub=False.
trainer.train()

# Save the final model to the output directory.
trainer.save_model()

  0%|          | 0/91 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.




{'loss': 3.0528, 'grad_norm': 1.2317241430282593, 'learning_rate': 0.0002, 'epoch': 0.11}
{'loss': 2.5639, 'grad_norm': 4.635470867156982, 'learning_rate': 0.0002, 'epoch': 0.22}
{'loss': 2.3701, 'grad_norm': 6.067172527313232, 'learning_rate': 0.0002, 'epoch': 0.33}
{'loss': 2.3073, 'grad_norm': 1.5994303226470947, 'learning_rate': 0.0002, 'epoch': 0.44}
{'loss': 2.157, 'grad_norm': 3.4777321815490723, 'learning_rate': 0.0002, 'epoch': 0.55}
{'loss': 2.231, 'grad_norm': 1.7607908248901367, 'learning_rate': 0.0002, 'epoch': 0.66}
{'loss': 2.1606, 'grad_norm': 1.4446223974227905, 'learning_rate': 0.0002, 'epoch': 0.77}


KeyboardInterrupt: 