# Train an LLM with custom data

In this notebook we will prepare a dataset of haikus and use it to fine-tune `HuggingFaceH4/zephyr-7b-beta` with Ludwig, so that the model answers in haiku.

You must have `HUGGINGFACE_API_KEY` in a `.env` file for this to work.

First, we'll download the training dataset. This cell only needs to run the first time.

In [1]:
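import os

import requests
from dotenv import load_dotenv

# Read HUGGINGFACE_API_KEY from the .env file into the environment
load_dotenv()

dataset = "statworx/haiku"
headers = {"Authorization": f"Bearer {os.environ.get('HUGGINGFACE_API_KEY')}"}
API_URL = f"https://datasets-server.huggingface.co/parquet?dataset={dataset}"

def query():
    """Ask the datasets server for the list of Parquet files in the dataset."""
    response = requests.get(API_URL, headers=headers)
    response.raise_for_status()
    return response.json()

# Get the URL of the first Parquet data file
data = query()
url = data["parquet_files"][0]["url"]

# Download the file into data/, creating the directory if it doesn't exist
os.makedirs("data", exist_ok=True)
r = requests.get(url, allow_redirects=True)
r.raise_for_status()
with open("data/haikus.parquet", "wb") as file:
    file.write(r.content)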
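
The download cell takes the first entry in `parquet_files`. If a dataset has several splits, it may be safer to pick the file by split name. The sketch below assumes each entry carries a `split` key next to `url`; that is how we understand the datasets-server response, but it isn't shown in this notebook. The `pick_parquet_url` helper and the `sample_response` payload are hypothetical and only there for illustration.

```python
def pick_parquet_url(response: dict, split: str = "train") -> str:
    """Return the URL of the first Parquet file for the given split.

    Hypothetical helper: assumes each entry in response["parquet_files"]
    has "split" and "url" keys.
    """
    for entry in response.get("parquet_files", []):
        if entry.get("split") == split:
            return entry["url"]
    raise ValueError(f"No Parquet file found for split {split!r}")


# Made-up payload shaped like the JSON returned by query(); URLs are placeholders
sample_response = {
    "parquet_files": [
        {"split": "test", "url": "https://example.com/test.parquet"},
        {"split": "train", "url": "https://example.com/train.parquet"},
    ]
}

print(pick_parquet_url(sample_response))
```

With the real response you would call `pick_parquet_url(data)` in place of `data["parquet_files"][0]["url"]`.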

## Prepare the data

Next, we'll load the dataset into a pandas DataFrame and prepare a training dataset.

In [2]:
import pandas as pd

haikus = pd.read_parquet("data/haikus.parquet")
haikus

Unnamed: 0,source,text,text_phonemes,keywords,keyword_phonemes,gruen_score,text_punc
0,bfbarry,Delicate savage. / You'll never hold the cinde...,deh|lax|kaxt sae|vaxjh / yuwl neh|ver hhowld d...,cinder,sihn|der,0.639071,
1,bfbarry,A splash and a cry. / Words pulled from the ri...,ax splaesh aend ax kray / werdz puhld frahm dh...,the riverside,dhax rih|ver|sayd,0.563353,
2,bfbarry,"Steamy, mist rising. / Rocks receiving downwar...",stiy|miy mihst ray|zaxng / raaks rax|siy|vaxng...,mist rising,mihst ray|zaxng,0.538326,
3,bfbarry,You were broken glass. / But I touched you eve...,yuw wer brow|kaxn glaes / baht ay tahcht yuw i...,broken glass,brow|kaxn glaes,0.703446,
4,bfbarry,Eyes dance with firelight. / The Moon and I ar...,ayz daens wihdh faxr|layt / dhax muwn aend ay ...,eyes dance,ayz daens,0.830985,
...,...,...,...,...,...,...,...
49019,haiku_data_2,Alpine Lake. / Mybreaststrokesshiningarc. / To...,ael|payn leyk mih|brehst|strow|kehsh|hhax|nihn...,toward sunrise,tax|waord sahn|rayz,0.685355,Alpine Lake. Mybreaststrokesshiningarc. Toward...
49020,haiku_data_2,Spruce Woods. / Fireweed filling. / The vacancy.,spruws wuhdz fay|er|wiyd fih|laxng dhax vey|ka...,woods,wuhdz,0.568974,Spruce Woods. Fireweed filling. The vacancy.
49021,haiku_data_2,Corrugated sun. / Chilies and laundry. / In ro...,kao|rax|gey|taxd sahn chih|liyz aend laon|driy...,sun chilies,sahn chih|liyz,0.551056,Corrugated sun. Chilies and laundry. In roofto...
49022,haiku_data_2,Home from war. / We ease out. / The champagne ...,hhowm frahm waor wiy iyz awt dhax shaxm|peyn k...,home,hhowm,0.697112,Home from war. We ease out. The champagne corks.


In [6]:
# Let's take just a random sample of 100 of these.
haikus_text = haikus["text"].sample(100)

# print the first haiku
print("Original Haiku format:")
print(haikus_text.iloc[0])

# Add newlines
haikus_text = haikus_text.str.replace(" / ", "\n")
print("\nWith new-lines:")
print(haikus_text.iloc[0])

# Look at 4 more of them
for i in range(1, 5):
    print("\n" + haikus_text.iloc[i])

Original Haiku format:
My man is sleeping. / So peaceful I'm going to. / Get up, cook breakfast.

With new-lines:
My man is sleeping.
So peaceful I'm going to.
Get up, cook breakfast.

Train whistle.
A blackbird hops.
Along its notes.

Winter's breathy joke.
Exploring fragilities.
Frigid hearts break hard.

Butterfly wings.
Their so loud flapping.
In a temple's Silence.

Summer sky.
My father counts the black faces.
On my road.


## Uh-oh!

Not all of them follow the standard 5-7-5 syllable format. Time to make sure our dataset is really clean before training.

In [7]:
import syllables

def validate_haiku(haiku):
    """Return True if the lines have an estimated 5-7-5 syllable count."""
    syllable_count = [syllables.estimate(line) for line in haiku.split(" / ")]
    return syllable_count == [5, 7, 5]

haikus["valid"] = haikus["text"].apply(validate_haiku)

# Now filter down to only the "valid" ones
valid_haikus = haikus[haikus["valid"]]
valid_haikus.sort_values("gruen_score", ascending=False)

Unnamed: 0,source,text,text_phonemes,keywords,keyword_phonemes,gruen_score,text_punc,valid
2056,bfbarry,Your coldness burns me. / You stab me with ici...,yaor kowld|naxs bernz miy / yuw staeb miy wihd...,icicles,ay|sax|kaxlz,0.891561,,True
1313,bfbarry,Darkness falls quickly. / The sun sets behind ...,daark|naxs faolz kwih|kliy / dhax sahn sehts b...,darkness falls,daark|naxs faolz,0.885124,,True
2884,bfbarry,Rain falls from the sky. / Your head tilts up ...,reyn faolz frahm dhax skay / yaor hhehd tihlts...,rain falls,reyn faolz,0.877835,,True
14525,twaiku,Joe Thomas knows how. / You lost your virginit...,jhow taa|maxs nowz hhaw / yuw laost yaor ver|j...,your virginity,yaor ver|jhih|nax|tiy,0.877294,,True
5739,bfbarry,"Now, will you begin? / To do what you want to ...",naw wihl yuw bax|gihn / tax duw waht yuw waant...,always dying,aol|weyz day|axng,0.875700,,True
...,...,...,...,...,...,...,...,...
29508,haiku_data_1,Enough farmhouse. / A vine wrapped around. / T...,ih|nahf faarm|hhaws ey vayn raept er|awnd dhax...,vine wrapped,vayn raept,0.394552,Enough farmhouse. A vine wrapped around. The w...,True
36281,haiku_data_1,Lightstudded ferry. / Passes in the starless n...,layt|stax|daxd feh|riy pae|saxz ihn dhax staar...,night,nayt,0.393787,Lightstudded ferry. Passes in the starless nig...,True
39837,haiku_data_1,Aversed home. / An old life laid to rest. / Am...,ae|verst hhowm axn owld layf leyd tax rehst ax...,weeds,wiydz,0.376451,Aversed home. An old life laid to rest. Among ...,True
29655,haiku_data_1,Our passing breeze. / A do not disturb sign sw...,aw|er pae|saxng briyz ey duw naat dax|sterb sa...,breeze,briyz,0.371297,Our passing breeze. A do not disturb sign swin...,True
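
Keep in mind that `syllables.estimate` returns an estimate, so the filter is only approximate. To give a feel for why syllable counting is hard, here is a deliberately naive counter that counts groups of vowels. It is not the algorithm used by the `syllables` package; `naive_syllables`, `line_syllables` and `is_575` are hypothetical names used only in this sketch.

```python
import re


def naive_syllables(word: str) -> int:
    """Rough syllable count: one per group of consecutive vowels."""
    word = re.sub(r"[^a-z]", "", word.lower())
    if not word:
        return 0
    count = len(re.findall(r"[aeiouy]+", word))
    # Treat a trailing "e" as silent, except in "-le" endings like "whistle"
    if word.endswith("e") and not word.endswith("le") and count > 1:
        count -= 1
    return max(count, 1)


def line_syllables(line: str) -> int:
    """Sum the naive syllable counts of every word in a line."""
    return sum(naive_syllables(w) for w in line.split())


def is_575(haiku: str, sep: str = " / ") -> bool:
    """Check the 5-7-5 pattern using the naive counter."""
    return [line_syllables(part) for part in haiku.split(sep)] == [5, 7, 5]


print(is_575("My man is sleeping. / So peaceful I'm going to. / Get up, cook breakfast."))
print(is_575("Train whistle. / A blackbird hops. / Along its notes."))
```

This counter treats `going` as a single syllable because it contains only one vowel group. Mistakes like that are why any estimator will let some imperfect haikus through and reject some good ones.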


## Good enough, let's train!

The syllable counts are only estimates, so the filter isn't perfect, but it's good enough for our purposes. Next, we'll use the `keywords` and `text` columns of the filtered haikus as the training dataset for our model.
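
As a rough illustration of how one row's `keywords` and `text` could be combined into a single chat-style training example, the sketch below fills a Zephyr-style template with plain `str.format`. The template mirrors the `prompt_template` string in the training cell, and the row is taken from the dataset preview earlier in this notebook. This is ordinary Python string formatting, not a description of how Ludwig builds prompts internally.

```python
# Zephyr-style chat template, mirroring the prompt_template string in this notebook
template = (
    "<|system|>\n"
    "You are a haiku writer and respond to all questions with a haiku</s>\n"
    "<|user|>\n"
    "Tell me about {keywords}</s>\n"
    "<|assistant|>\n"
    "{text}"
)

# One row from the dataset preview (keywords and text columns)
row = {
    "keywords": "woods",
    "text": "Spruce Woods. / Fireweed filling. / The vacancy.",
}

rendered = template.format(**row)
print(rendered)
```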

In [8]:
import logging
import os

import yaml

from ludwig.api import LudwigModel

# NOTE: this template is defined here but is not referenced by the config below
prompt_template = """
<|system|>
You are a haiku writer and respond to all questions with a haiku</s>
<|user|>
Tell me about {keywords}</s>
<|assistant|>
{text}
"""

# Build out the configuration
config = yaml.safe_load(
    """
model_type: llm
base_model: HuggingFaceH4/zephyr-7b-beta

quantization:
  bits: 4

adapter:
  type: lora

input_features:
  - name: keywords
    type: text

output_features:
  - name: text
    type: text

trainer:
    type: finetune
    learning_rate: 0.0003
    batch_size: 2
    gradient_accumulation_steps: 8
    epochs: 3
    learning_rate_scheduler:
      warmup_fraction: 0.01

backend:
  type: local
"""
)

# Define the Ludwig model object that drives model training
model = LudwigModel(config=config, logging_level=logging.INFO)

# Initiate model training
(
    train_stats,  # dictionary containing training statistics
    preprocessed_data,  # tuple of Ludwig Dataset objects of pre-processed training data
    output_directory,  # location of training results stored on disk
) = model.train(
    dataset=valid_haikus
)

# list contents of output directory
print("contents of output directory:", output_directory)
for item in os.listdir(output_directory):
    print("\t", item)

PyTorch version 2.1.2 available.


config.json: 100%|██████████| 638/638 [00:00<00:00, 287kB/s]

Setting generation max_new_tokens to 16384 to correspond with the max sequence length assigned to the output feature or the global max sequence length. This will ensure that the correct number of tokens are generated at inference time. To override this behavior, set `generation.max_new_tokens` to a different value in your Ludwig config.

╒════════════════════════╕
│ EXPERIMENT DESCRIPTION │
╘════════════════════════╛

╒══════════════════╤══════════════════════════════════════════════════════════════════════════════════════╕
│ Experiment name  │ api_experiment                                                                       │
├──────────────────┼──────────────────────────────────────────────────────────────────────────────────────┤
│ Model name       │ run                                                                                  │
├──────────────────┼──────────────────────────────────────────────────────────────────────────────────────┤
│ Output directory │ /home/dave/code/l




Building dataset (it may take a while)


tokenizer_config.json: 100%|██████████| 1.43k/1.43k [00:00<00:00, 596kB/s]
tokenizer.model: 100%|██████████| 493k/493k [00:00<00:00, 10.8MB/s]
added_tokens.json: 100%|██████████| 42.0/42.0 [00:00<00:00, 49.9kB/s]
special_tokens_map.json: 100%|██████████| 168/168 [00:00<00:00, 174kB/s]
tokenizer.json: 100%|██████████| 1.80M/1.80M [00:00<00:00, 27.6MB/s]

Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer



Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Max length of feature 'keywords': 9 (without start and stop symbols)
Max sequence length is 9 for feature 'keywords'
Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Max length of feature 'text': 33 (without start and stop symbols)
Max sequence length is 33 for feature 'text'
Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Building dataset: DONE
Writing preprocessed training set cache to /home/dave/code/llama-haiku/d66036b8af3f11eea87be55baf3cf3ab.training.hdf5
Writing preprocessed validation set cache to /home/dave/code/llama-haiku/d66036b8af3f11eea87be55baf3cf3ab.validation.hdf5
Writing preprocessed test set cache to /home/dave/code/llama-haiku/d66036b8af3f11eea87be55baf3cf3ab.test.hdf5
Writing train set metadata to /home/dave/code/llama-haiku/d66036b8af3f11eea87be55baf3cf3ab.meta.json

Dataset Statistics
╒════════════╤═══════════════╤════════════════════╕
│ Dataset    │   Size (Rows) │ Size (In Memory)   │
╞════════════╪═══════════════╪════════════════════╡
│ Training   │          2951 │ 691.77 Kb          │
├────────────┼───────────────┼────────────────────┤
│ Validation │           421 │ 98.80 Kb           │
├────────────┼───────────────┼────────────────────┤
│ Test       │           843 │ 197.70 Kb          │
╘════════════╧═══════════════╧════════════════════╛

╒═══════╕
│ MODEL │
╘═══════╛

Loadin

model.safetensors.index.json: 100%|██████████| 23.9k/23.9k [00:00<00:00, 9.75MB/s]
model-00001-of-00008.safetensors: 100%|██████████| 1.89G/1.89G [00:17<00:00, 111MB/s]
model-00002-of-00008.safetensors: 100%|██████████| 1.95G/1.95G [00:17<00:00, 111MB/s]
model-00003-of-00008.safetensors: 100%|██████████| 1.98G/1.98G [00:18<00:00, 106MB/s]
model-00004-of-00008.safetensors: 100%|██████████| 1.95G/1.95G [00:18<00:00, 104MB/s] 
model-00005-of-00008.safetensors: 100%|██████████| 1.98G/1.98G [00:19<00:00, 104MB/s]
model-00006-of-00008.safetensors: 100%|██████████| 1.95G/1.95G [00:18<00:00, 105MB/s]
model-00007-of-00008.safetensors: 100%|██████████| 1.98G/1.98G [00:19<00:00, 103MB/s]
model-00008-of-00008.safetensors: 100%|██████████| 816M/816M [00:07<00:00, 109MB/s]
Downloading shards: 100%|██████████| 8/8 [02:16<00:00, 17.10s/it]


We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 8/8 [00:04<00:00,  1.94it/s]
generation_config.json: 100%|██████████| 111/111 [00:00<00:00, 130kB/s]

Done.





Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer
Trainable Parameter Summary For Fine-Tuning
Fine-tuning with adapter: lora
trainable params: 3,407,872 || all params: 7,245,139,968 || trainable%: 0.04703666202518836

╒══════════╕
│ TRAINING │
╘══════════╛

Creating fresh model training run.
Training for 4428 step(s), approximately 3 epoch(s).
Early stopping policy: 5 round(s) of evaluation, or 7380 step(s), approximately 5 epoch(s).

Starting with step 0, epoch: 0
Training:  33%|███▎      | 1475/4428 [03:13<06:26,  7.63it/s, loss=0.325]Last batch in epoch only has 1 sample and will be dropped.
Last batch in epoch only has 1 sample and will be dropped.
Last batch in epoch only has 1 sample and will be dropped.
Training:  33%|███▎      | 1476/4428 [03:13<07:27,  6.60it/s, loss=0.277]
Running evaluation for step: 1476, epoch: 1
Evaluation valid: 100%|██████████| 211/211 [00:15<00:00, 13.92it/s]
Input: heart feels
Output: The My heart feels so. / Now sure, to, bu
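
The step counts in the log follow directly from the dataset sizes and the trainer settings. The arithmetic below reproduces them as a sanity check. All numbers are copied from the config and the log above; the effective batch size assumes gradients are accumulated over 8 batches before each optimizer update.

```python
import math

train_rows = 2951   # "Training" rows in the Dataset Statistics table
val_rows = 421      # "Validation" rows in the same table
batch_size = 2      # trainer.batch_size
grad_accum = 8      # trainer.gradient_accumulation_steps
epochs = 3          # trainer.epochs

steps_per_epoch = math.ceil(train_rows / batch_size)
total_steps = steps_per_epoch * epochs
eval_batches = math.ceil(val_rows / batch_size)
effective_batch = batch_size * grad_accum

# LoRA trains only a small fraction of the weights
trainable = 3_407_872
total = 7_245_139_968
trainable_pct = 100 * trainable / total

print(steps_per_epoch, total_steps, eval_batches, effective_batch)
print(f"{trainable_pct:.4f}% of parameters are trainable")
```

The training set has an odd number of rows, which is why the log reports that the last batch of each epoch has a single sample.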

## Ok, so how'd we do?

In [13]:
df = pd.DataFrame.from_dict({"keywords": ["flowers", "trucks", "data analysis"]})

response = model.predict(df)

response

Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Prediction: 100%|██████████| 1/1 [00:06<00:00,  6.04s/it]
Loaded HuggingFace implementation of HuggingFaceH4/zephyr-7b-beta tokenizer
Finished predicting in: 7.68s.


  return np.sum(np.log(sequence_probabilities))


(                                    text_predictions  \
 0  [, I, ', m, not, a, flower, ., /, I, ', m, not...   
 1  [, I, ', m, not, gonna, lie, ., /, I, ', m, re...   
 2  [, I, ', m, really, good, at, ., /, Data, anal...   
 
                                   text_probabilities  \
 0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
 1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
 2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
 
                                        text_response  text_probability  
 0  [I'm not a flower. / I'm not a flower, I'm not...              -inf  
 1  [I'm not gonna lie. / I'm really looking forwa...              -inf  
 2  [I'm really good at. / Data analysis, but I'm....              -inf  ,
 'results')

In [21]:
answers = response[0]["text_response"]

for a in answers:
    print(a[0])

I'm not a flower. / I'm not a flower, I'm not. / A flower, I'm not.
I'm not gonna lie. / I'm really looking forward. / To the trucks tonight.
I'm really good at. / Data analysis, but I'm. / Not good at math.
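
The model answers in the same ` / `-separated format it was trained on. To display the results as three-line haikus, we can reuse the newline trick from the data preparation step. A minimal sketch follows; `format_haiku` is a hypothetical helper, and the strings are copied from the output above.

```python
def format_haiku(raw: str, sep: str = " / ") -> str:
    """Put each haiku line on its own row, trimming stray whitespace."""
    return "\n".join(part.strip() for part in raw.split(sep))


# Two of the generated responses printed above
generated = [
    "I'm not gonna lie. / I'm really looking forward. / To the trucks tonight.",
    "I'm really good at. / Data analysis, but I'm. / Not good at math.",
]

for haiku in generated:
    print(format_haiku(haiku))
    print()
```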
