# Train LoRAs with HuggingFace APIs

### Install
pip install pytorch transformers datasets peft jupyterlab ipywidgets

In [1]:
# Set this var to download everything to the directory where this notebook is.
# Goes under "./hub"
%env HF_HOME=.

env: HF_HOME=.


In [2]:
import transformers as tfs
import datasets as dts
import accelerate
import peft
import torch

This cell simply fetches the model from Hugging Face Hub. We're using their SmolLM-135M model here, which has 135M parameters and a context window of 2048. However, we're limited the size of all our data to 1024 to limit memory usage.

In [3]:
# smol_lm = "HuggingFaceTB/SmolLM-135M"
smol_lm = "HuggingFaceTB/SmolLM-360M"
# smol_lm = "HuggingFaceTB/SmolLM-1.7B"

def load_model(name: str):
    MAX_LEN=1024 # Reduce the usable context size to save VRAM

    config = tfs.AutoConfig.from_pretrained(name)
    model = tfs.AutoModelForCausalLM.from_pretrained(name)
    tokenizer = tfs.AutoTokenizer.from_pretrained(
        name,
        model_max_length=MAX_LEN
    )

    # This is the Mistral chat template in a format HF Transformers uses.
    # We chose this template because it requires no special tokens to function properly.
    tokenizer.chat_template = "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- message['content'] -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'[INST] ' + message['content'].rstrip() + ' [/INST]'-}}{%- else -%}{{-'' + message['content'] + '</s>' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-''-}}{%- endif -%}"
    
    # Using eos as the pad token seems common practice.
    tokenizer.pad_token = tokenizer.eos_token

    return config, model, tokenizer

cfg, mdl, tok = load_model(smol_lm)

config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.45G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/3.69k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.10M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/831 [00:00<?, ?B/s]

# Instruct LoRA
First we'll train a LoRA for instruction following using the dolly-15k dataset.

In [4]:
# Instruct dataset. A possible alternative is "tatsu-lab/alpaca"
dolly = dts.load_dataset("databricks/databricks-dolly-15k")

In [5]:
# tok is captured from the global namespace
# x keys are instruction, context, response, category
# Not all samples have a context, so we'll ignore it.
def dolly_chat(x):
    chat = [
        # {"role": "system", "content": "{}".format(x["context"])},
        {"role": "user", "content": "{}".format(x["instruction"])},
        {"role": "assistant", "content": "{}".format(x["response"])},
    ]
    chat_formatted = tok.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )

    # TODO Padding to max length always seems to result in static VRAM usage, but
    # is slower on average since many samples are much shorter than max_length.
    # Want to debug why peak VRAM fluctuates a lot when length can vary, as this sometimes
    # OOMs midway through training.
    tokenized = tok(chat_formatted, padding="max_length", truncation=True)

    return {"text": chat_formatted, "input_ids": tokenized["input_ids"]}

dset_w_tokenized = dolly["train"].map(dolly_chat)

Map:   0%|          | 0/15011 [00:00<?, ? examples/s]

In [6]:
print(len(dset_w_tokenized["input_ids"][0]))
print(dset_w_tokenized["text"][0])

1024
[INST] When did Virgin Australia start operating? [/INST]Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.</s>


In [7]:
rank = 32

# I got much better results by training embed_tokens. It's possible <|im_start|> <|im_end|> are untrained.
lora_config = peft.LoraConfig(
    # This is the rank you see in all the LoRA materials
    r=rank,
    # These are (almost) all of the linear layers. You can experiment by training fewer of them.
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type=peft.TaskType.CAUSAL_LM,
    # Rule of thumb for alpha seems to be 1-2x the rank.
    lora_alpha=1 * rank,
    lora_dropout=0.05
)
lora_model = peft.get_peft_model(mdl, lora_config)
lora_model.print_trainable_parameters()

trainable params: 17,367,040 || all params: 379,188,160 || trainable%: 4.5801


In [8]:
# These 3 things are your main training parameters.
lr=5e-5

# Lower this if you get CUDA out of memory, but try to keep
# (batchsize * gradient_accumulation_steps) at least 8.
batchsize=4 

epochs=1

args = tfs.TrainingArguments(
    output_dir='./finetune',
    optim='adamw_torch',
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batchsize,
    per_device_eval_batch_size=batchsize,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    save_strategy="no",
    weight_decay=0.01,
    push_to_hub=False,
    report_to='none',
    torch_empty_cache_steps=100,
    bf16=True,
    tf32=True # Comment this if it gives you an error. It requires Ampere or newer.
)

collator = tfs.DataCollatorForLanguageModeling(tok, mlm=False)

trainer = tfs.Trainer(
    model=lora_model,
    args=args,
    train_dataset=dset_w_tokenized,
    processing_class=tok,
    data_collator=collator
)

In [9]:
trainer.train()

Step,Training Loss
500,8.8953


TrainOutput(global_step=938, training_loss=8.680117218733342, metrics={'train_runtime': 659.6794, 'train_samples_per_second': 22.755, 'train_steps_per_second': 1.422, 'total_flos': 3.061364501250048e+16, 'train_loss': 8.680117218733342, 'epoch': 0.9997335464961364})

In [10]:
question = "Can you tell me about crows?"

def generate_instruct(
    model,
    instruction,
    max_new_tokens=128,
    temperature=0.5,
    top_k=50,
    repetition_penalty=1.1
):
    chat = [
        {"role": "user", "content": f"{instruction}"},
    ]
    text = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tok(text, return_tensors='pt', truncation=True).to(lora_model.device)
    print("Prompt has", len(inputs["input_ids"][0]), "tokens")
    
    with torch.no_grad():
        output = model.generate(
            **inputs,
            do_sample=True,
            pad_token_id=tok.pad_token_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        
        output = tok.batch_decode(output)[0]
        return output

In [None]:
print(generate_instruct(lora_model, question))

Prompt has 14 tokens


In [18]:
lora_model.save_pretrained(save_directory="mlhi-lora-instruct", save_embedding_layers=True)

# Loading LoRA back for inference

In [19]:
# Load a fresh copy of the model
cfg, mdl, tok = load_model(smol_lm)

In [20]:
adapted_model = peft.PeftModel.from_pretrained(mdl, "mlhi-lora-instruct", adapter_name="mlhi-lora-instruct").to("cuda")
print(adapted_model.active_adapters)

['mlhi-lora-instruct']


The next cell loads the LoRA trained on the text corpus. You can optionally skip the next cell to see what the model generates WITHOUT this LoRA active!

In [21]:
question = "Can you tell me a story about a bird?"

# Setting a seed guarantees that samplers pick the same tokens every time if all else is equal.
# This lets you change your generation settings or the question and see how it affects the result.
# Comment it for a random response every time.
torch.manual_seed(1651)
print(
    generate_instruct(
        lora_model,
        question,
        max_new_tokens=128,
        temperature=0.5,
        top_k=10,
        repetition_penalty=1.1
    )
)

Prompt has 17 tokens
[INST] Can you tell me a story about a bird? [/INST]A lot of birds are known as songbirds. Some common ones include the blue jay, house sparrow, and robin</s></p><p>There is also a bird called a thrush that can be found in the US.</s>
<br />
</p>

<p>Some other birds are also known as waterfowl such as ducks, geese, and swans</s> </p>
<p>The bird that lives on the ground is called a bat</s> </p>
<p>Bats are nocturnal creatures</s> </p>
<p
