In [1]:
import os
from typing import List

import torch
import transformers
from datasets import load_dataset

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./lora-alpaca",
    
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = ["q_proj", "v_proj"],
    
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"add_eos_token: {add_eos_token}\n"
        f"group_by_length: {group_by_length}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
        f"prompt template: {prompt_template_name}\n"
    )
    
    gradient_accumulation_steps = batch_size // micro_batch_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    
    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    
    model = get_peft_model(model, config)
    
    if resume_from_checkpoint:
        checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")
            resume_from_checkpoint = (False)
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    tokenizer.pad_token_id = (0)
    tokenizer.padding_side = "left"

    def tokenize(data_point):
        prompt = data_point["text"]
        
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )

        result["labels"] = result["input_ids"].copy()

        return result
    
    def split_into_smaller_windows(samples):
        texts = samples["text"]
        splits = []
        for text in texts:
            text_splits = [text[i:i+cutoff_len] for i in range(0, len(text), cutoff_len)]
            splits.extend(text_splits)
        samples["text"] = splits

    data = load_dataset(data_path)

    train_data = data["train"].shuffle().map(tokenize, num_proc=16).map(split_into_smaller_windows, batched=True, num_proc=16)

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=10,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="no",
            save_strategy="steps",
            eval_steps=None,
            save_steps=1,
            output_dir=output_dir,
            save_total_limit=100,
            load_best_model_at_end=False,
            ddp_find_unused_parameters=None,
            group_by_length=group_by_length,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

In [3]:
train(
    base_model='decapoda-research/llama-7b-hf',
    data_path='klima7/polish-prose',
    resume_from_checkpoint='models/polpaca',
    output_dir='models/alpaca-prose-0',
    lora_target_modules=['q_proj','k_proj','v_proj','o_proj'],
    lora_r=16,
    num_epochs=100,
    micro_batch_size=8,
    cutoff_len=512,
    group_by_length=True
)

Training Alpaca-LoRA model with params:
base_model: decapoda-research/llama-7b-hf
data_path: klima7/polish-prose
output_dir: models/alpaca-prose-0
batch_size: 128
micro_batch_size: 8
num_epochs: 100
learning_rate: 0.0003
cutoff_len: 512
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj']
train_on_inputs: True
add_eos_token: False
group_by_length: True
resume_from_checkpoint: models/polpaca
prompt template: alpaca



Loading checkpoint shards: 100%|██████████| 33/33 [00:30<00:00,  1.07it/s]
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.
You are using the legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565
Downloading readme: 100%|██████████| 446/446 [00:00<00:00, 3.40MB/s]
Downloading data: 100%|██████████| 154M/154M [00:41<00:00, 3.75MB/s]
Downloading data files: 100%|██████████| 1/1 [00:41<00:00, 41.10s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 1370.24it/s]
Generating train split: 100%|██████████| 1

Restarting from models/polpaca/adapter_model.bin
trainable params: 16777216 || all params: 6755192832 || trainable%: 0.24836028248556738


Map (num_proc=16): 100%|██████████| 1956/1956 [02:31<00:00, 12.93 examples/s]
Map (num_proc=16): 100%|██████████| 1956/1956 [00:00<00:00, 3034.75 examples/s]
  1%|          | 10/1500 [35:34<84:35:26, 204.38s/it]

{'loss': 2.7091, 'learning_rate': 0.00017999999999999998, 'epoch': 0.65}


  1%|▏         | 20/1500 [1:09:05<82:53:15, 201.62s/it]

{'loss': 2.3794, 'learning_rate': 0.0002987919463087248, 'epoch': 1.31}


  2%|▏         | 30/1500 [1:42:37<82:09:56, 201.22s/it]

{'loss': 2.2863, 'learning_rate': 0.00029677852348993286, 'epoch': 1.96}


  3%|▎         | 40/1500 [2:15:46<81:15:35, 200.37s/it]

{'loss': 2.241, 'learning_rate': 0.0002947651006711409, 'epoch': 2.61}


  3%|▎         | 50/1500 [2:49:08<80:46:57, 200.56s/it]

{'loss': 2.1901, 'learning_rate': 0.00029275167785234896, 'epoch': 3.27}


  4%|▍         | 60/1500 [3:22:35<80:26:52, 201.12s/it]

{'loss': 2.1611, 'learning_rate': 0.000290738255033557, 'epoch': 3.92}


  5%|▍         | 70/1500 [3:55:51<79:47:54, 200.89s/it]

{'loss': 2.0938, 'learning_rate': 0.00028872483221476506, 'epoch': 4.57}


  5%|▌         | 80/1500 [4:29:03<78:02:18, 197.84s/it]

{'loss': 2.0948, 'learning_rate': 0.0002867114093959731, 'epoch': 5.22}


  6%|▌         | 90/1500 [5:02:32<78:45:52, 201.10s/it]

{'loss': 2.0306, 'learning_rate': 0.0002846979865771812, 'epoch': 5.88}


  7%|▋         | 100/1500 [5:35:57<78:18:47, 201.38s/it]

{'loss': 1.9838, 'learning_rate': 0.00028268456375838926, 'epoch': 6.53}


  7%|▋         | 110/1500 [6:09:12<77:20:22, 200.30s/it]

{'loss': 1.9578, 'learning_rate': 0.0002806711409395973, 'epoch': 7.18}


  8%|▊         | 120/1500 [6:42:33<76:20:13, 199.14s/it]

{'loss': 1.9108, 'learning_rate': 0.0002786577181208053, 'epoch': 7.84}


  9%|▊         | 130/1500 [7:15:44<75:48:14, 199.19s/it]

{'loss': 1.8455, 'learning_rate': 0.0002766442953020134, 'epoch': 8.49}


  9%|▉         | 140/1500 [7:49:05<75:48:43, 200.68s/it]

{'loss': 1.8129, 'learning_rate': 0.00027463087248322146, 'epoch': 9.14}


 10%|█         | 150/1500 [8:23:30<77:01:20, 205.39s/it]

{'loss': 1.7491, 'learning_rate': 0.0002726174496644295, 'epoch': 9.8}


 11%|█         | 160/1500 [8:57:23<74:52:46, 201.17s/it]

{'loss': 1.7213, 'learning_rate': 0.00027060402684563756, 'epoch': 10.45}


 11%|█▏        | 170/1500 [9:32:00<76:53:54, 208.15s/it]

{'loss': 1.6501, 'learning_rate': 0.0002685906040268456, 'epoch': 11.1}


 12%|█▏        | 180/1500 [10:06:38<76:05:05, 207.50s/it]

{'loss': 1.6069, 'learning_rate': 0.00026657718120805365, 'epoch': 11.76}


 13%|█▎        | 190/1500 [10:40:51<74:09:31, 203.80s/it]

{'loss': 1.5585, 'learning_rate': 0.0002645637583892617, 'epoch': 12.41}


 13%|█▎        | 200/1500 [11:15:22<74:49:04, 207.19s/it]

{'loss': 1.5109, 'learning_rate': 0.0002625503355704698, 'epoch': 13.06}


 14%|█▍        | 210/1500 [11:50:06<74:07:08, 206.84s/it]

{'loss': 1.4632, 'learning_rate': 0.00026053691275167786, 'epoch': 13.71}


 15%|█▍        | 220/1500 [12:23:29<71:47:06, 201.90s/it]

{'loss': 1.4274, 'learning_rate': 0.0002585234899328859, 'epoch': 14.37}


 15%|█▌        | 230/1500 [12:56:44<69:50:23, 197.97s/it]

{'loss': 1.3655, 'learning_rate': 0.00025651006711409395, 'epoch': 15.02}


 16%|█▌        | 240/1500 [13:29:50<68:54:47, 196.89s/it]

{'loss': 1.3182, 'learning_rate': 0.000254496644295302, 'epoch': 15.67}


 17%|█▋        | 250/1500 [14:03:30<71:12:52, 205.10s/it]

{'loss': 1.269, 'learning_rate': 0.00025248322147651005, 'epoch': 16.33}


 17%|█▋        | 260/1500 [14:38:07<71:41:58, 208.16s/it]

{'loss': 1.2485, 'learning_rate': 0.0002504697986577181, 'epoch': 16.98}


 18%|█▊        | 270/1500 [15:12:53<70:43:41, 207.01s/it]

{'loss': 1.1886, 'learning_rate': 0.00024845637583892615, 'epoch': 17.63}


 19%|█▊        | 280/1500 [15:47:30<70:18:40, 207.48s/it]

{'loss': 1.1464, 'learning_rate': 0.0002464429530201342, 'epoch': 18.29}


 19%|█▉        | 290/1500 [16:22:12<70:00:46, 208.30s/it]

{'loss': 1.1446, 'learning_rate': 0.00024442953020134225, 'epoch': 18.94}


