In [1]:
from peft import LoraConfig, get_peft_model # PEFT model
import torch  # PyTorch, needless to say ;)
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig  # Main components of LLM
from pathlib import Path # For file paths
from datasets import load_dataset, Dataset # For loading the dataset
from trl import SFTTrainer  # An easy-to-use trainer
import pandas  # For loading and processing the dataset
import random  # Sampling random records from the dataset

### Check if CUDA is available

In [2]:
# Expect True: later cells place the model and its inputs on GPU 0
# (device_map={"":0} and device = "cuda:0").
torch.cuda.is_available()

True

### Quantization Configuration

In [3]:
# Quantization settings passed to from_pretrained when the model is loaded:
# weights are loaded in 4-bit using the "nf4" quant type, with bfloat16 as
# the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

## Dataset

### Load Dataset

In [4]:
# Pull the "train" split of the subtitle dataset.
dataset_raw = load_dataset("h-alice/cooking-master-boy-subtitle", split="train")

# Preview the first eleven captions.
dataset_raw["caption"][:11]

['中華料理的菜色',
 '如星星一般千變萬化',
 '而今擴展到全世界',
 '令人驚異的料理',
 '中華料理',
 '中華料理，經歷四千年',
 '分爲幾個大宗',
 '最後達到前所未有之大成',
 '亦達到其顚峰',
 '當時出現好幾位',
 '手藝驚人的料理師傅激烈對抗']

### Stacking messages

In [5]:
# Pair every caption with the caption that immediately follows it, so that each
# row becomes one (message, next_message) exchange.
temp_ds = dataset_raw.to_pandas()

# Positions 0..n-2 are the prompts; positions 1..n-1 are their successors.
pair_count = len(temp_ds) - 1
indicies = list(range(pair_count))
next_indicies = list(range(1, pair_count + 1))

# Both columns are re-indexed from zero so that row i holds caption i and caption i+1.
captions = temp_ds["caption"]
temp_message_stack = pandas.DataFrame(
    {
        "message": captions.iloc[indicies].reset_index(drop=True),
        "next_message": captions.iloc[next_indicies].reset_index(drop=True),
    }
)

temp_message_stack.head(5)

Unnamed: 0,message,next_message
0,中華料理的菜色,如星星一般千變萬化
1,如星星一般千變萬化,而今擴展到全世界
2,而今擴展到全世界,令人驚異的料理
3,令人驚異的料理,中華料理
4,中華料理,中華料理，經歷四千年


### Instruction format for Gemma

In [6]:
# Single-turn chat template: `full_user_prompt` fills the user turn and `model`
# fills the model turn, which is closed with <eos>.
# NOTE(review): the tokenizer output shown later in this notebook already starts
# with <bos> for a prompt that did not contain one, so samples built from this
# template may end up with a duplicated <bos> once tokenized — confirm.
GEMMA_INST = "<bos><start_of_turn>user\n{full_user_prompt}<end_of_turn>\n<start_of_turn>model\n{model}<eos>"

### Create Gemma instructions

In [7]:
# Render each (message, next_message) pair through the Gemma chat template.
formatted_turns = [
    GEMMA_INST.format(full_user_prompt=prompt, model=reply)
    for prompt, reply in zip(
        temp_message_stack["message"], temp_message_stack["next_message"]
    )
]

# One text column named "message"; the trainer reads this field.
message_stack = pandas.DataFrame(formatted_turns, columns=["message"])

# Wrap the frame as a `datasets.Dataset` for training.
training_dataset = Dataset.from_pandas(message_stack)

# Preview the first five formatted samples.
training_dataset[0:5]

{'message': ['<bos><start_of_turn>user\n中華料理的菜色<end_of_turn>\n<start_of_turn>model\n如星星一般千變萬化<eos>',
  '<bos><start_of_turn>user\n如星星一般千變萬化<end_of_turn>\n<start_of_turn>model\n而今擴展到全世界<eos>',
  '<bos><start_of_turn>user\n而今擴展到全世界<end_of_turn>\n<start_of_turn>model\n令人驚異的料理<eos>',
  '<bos><start_of_turn>user\n令人驚異的料理<end_of_turn>\n<start_of_turn>model\n中華料理<eos>',
  '<bos><start_of_turn>user\n中華料理<end_of_turn>\n<start_of_turn>model\n中華料理，經歷四千年<eos>']}

## Load Model

In [8]:
# Model identifier handed to from_pretrained.
# Note that you may need an access token to download the model.
# NOTE(review): the id has no organisation prefix, so it presumably points at a
# local directory — confirm.
model_id = "gemma-1.1-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# device_map={"": 0} maps the whole model to device 0; the quantization
# settings come from bnb_config.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Test Model

In [9]:
# Smoke-test the freshly loaded model with a single-turn prompt written in
# Gemma's chat markup; the prompt ends with an open model turn.
text = """<start_of_turn>user
Tell me more about the sun. 
Note: Please explain within 50 words.<end_of_turn>
<start_of_turn>model
"""

device = "cuda:0"
input_prompt = tokenizer(text, return_tensors="pt").to(device)

# Sampling options for this check.
generation_kwargs = dict(
    max_new_tokens=100,
    do_sample=True,
    top_k=20,
    repetition_penalty=1.5,
    temperature=0.7,
)
outputs = model.generate(**input_prompt, **generation_kwargs)

# Keep special tokens in the decoded text so the turn markers stay visible.
print(tokenizer.decode(outputs[0], skip_special_tokens=False))

  attn_output = torch.nn.functional.scaled_dot_product_attention(


<bos><start_of_turn>user
Tell me more about the sun. 
Note: Please explain within 50 words.<end_of_turn>
<start_of_turn>model
The Sun is a star that shines brightly in our solar system, primarily composed of hydrogen and helium gases. It generates its energy through nuclear fusion reactions at its core, releasing light and heat into space. The Sun provides life-sustaining radiation to planets orbiting it, enabling plant growth and animal survival on Earth<eos>


## Prepare LoRA Config

In [10]:
# LoRA settings for the Gemma model.
# The target module names can be looked up in the model's own files: the model
# directory contains a JSON file with a name like "model.safetensors.index.json".
# For a more precise configuration, consult the model's original paper.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_dropout=0.1,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

### Check trainable parameters

In [11]:
# Report how many parameters the LoRA adapters make trainable.
# NOTE(review): per the PEFT documentation, get_peft_model modifies `model` in
# place, so this check presumably leaves adapter layers attached to `model`
# before it is handed to SFTTrainer together with peft_config=lora_config —
# confirm that applying the same config a second time is intended.
get_peft_model(model, lora_config).print_trainable_parameters()

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493


## Trainer

In [12]:
# Supervised fine-tuning trainer: trains LoRA adapters (peft_config) on the
# "message" text column of training_dataset.
trainer = SFTTrainer(
    model=model,
    train_dataset=training_dataset,
    args=TrainingArguments(
        # 2 samples per step x 4 accumulation steps = 8 samples per optimizer update.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        overwrite_output_dir=True,
        push_to_hub=False,
        save_steps=500,
        warmup_steps=2,
        num_train_epochs=2,
        learning_rate=2e-4,
        # NOTE(review): mixed precision is fp16 here while bnb_config computes in
        # bfloat16 — confirm whether bf16=True was intended.
        fp16=True,
        logging_steps=100,
        output_dir="./outputs",
        # The string "none" disables all logging integrations; passing the
        # Python value None does not, it falls back to the library default.
        report_to="none",
        logging_dir="./logs",
        save_strategy="steps",
        optim="paged_adamw_8bit",
    ),
    peft_config=lora_config,
    dataset_text_field="message",
)



Map:   0%|          | 0/20754 [00:00<?, ? examples/s]



In [13]:
# Resume only when a previous run left a checkpoint behind. Passing
# resume_from_checkpoint=True unconditionally raises on a fresh run because
# output_dir holds no checkpoint yet, which breaks "Restart & Run All".
checkpoint_dirs = [
    path
    for path in Path(trainer.args.output_dir).glob("checkpoint-*")
    if path.is_dir()
]

# With no checkpoint this passes False and training starts from scratch.
trainer.train(resume_from_checkpoint=bool(checkpoint_dirs))

  0%|          | 0/5188 [00:00<?, ?it/s]

{'loss': 2.0501, 'grad_norm': 6.071439266204834, 'learning_rate': 3.6637099884303897e-06, 'epoch': 1.97}
{'train_runtime': 263.9065, 'train_samples_per_second': 157.283, 'train_steps_per_second': 19.658, 'train_loss': 0.07424204994736953, 'epoch': 2.0}


TrainOutput(global_step=5188, training_loss=0.07424204994736953, metrics={'train_runtime': 263.9065, 'train_samples_per_second': 157.283, 'train_steps_per_second': 19.658, 'total_flos': 1.206500619952128e+16, 'train_loss': 0.07424204994736953, 'epoch': 1.9999036330345956})

In [14]:
# Prompt the fine-tuned model. The turn layout matches the GEMMA_INST training
# template exactly: the user text is followed directly by <end_of_turn>, with no
# newline in between.
text = """<start_of_turn>user
你在幹嘛<end_of_turn>
<start_of_turn>model
"""

device = "cuda:0"

# Training leaves the model in train mode; switch to eval mode so the LoRA
# dropout (lora_dropout=0.1) is not applied while generating.
model.eval()

input_prompt = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(
    **input_prompt,
    max_new_tokens=200,
    do_sample=True,
    top_k=60,
    top_p=0.9,
    repetition_penalty=1.2,
    temperature=0.7,
)

# Keep special tokens in the decoded text so the turn markers stay visible.
print(tokenizer.decode(outputs[0], skip_special_tokens=False))

<bos><start_of_turn>user
你在幹嘛
<end_of_turn>
<start_of_turn>model
快把那把菜刀拿過來<eos>


In [15]:
# Save the trained model to "./outputs", the same path given as output_dir
# in the TrainingArguments.
trainer.save_model("./outputs")