In [1]:
import os

# HF_HOME has to be exported BEFORE any Hugging Face library is imported:
# those libraries resolve their cache directories when they are first imported,
# so assigning the variable afterwards does not redirect the cache.
# NOTE(review): this is a machine-specific absolute path; adjust it (or export
# HF_HOME in the shell) when running the notebook elsewhere.
os.environ["HF_HOME"] = "/home/jeromeku/dev/third_party/torchtune/hf_cache"

import torch
from datasets import load_dataset

# Training split of the dataset used for the SFT experiments below.
dataset = load_dataset(
    "philschmid/dolly-15k-oai-style",
    split="train",
)

In [64]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

# Checkpoint to load and the dtype its weights are requested in.
model_name = "meta-llama/Llama-3.2-1B-Instruct"
dtype = torch.bfloat16

# Tokenizer and model are loaded independently from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

128009

In [147]:
from trl import setup_chat_format

# Sanity check: does the tokenizer's maximum length agree with the number of
# position embeddings declared in the model config?
tokenizer_limit = tokenizer.model_max_length
model_limit = model.config.max_position_embeddings
print(model_limit == tokenizer_limit)

True


In [154]:
# Round-trip a short sentence through the tokenizer without adding special
# tokens; the decoded string is the cell's displayed value.
sample_text = "Hello, how are you?"
tokenizer(sample_text, add_special_tokens=False)
sample_ids = tokenizer.encode(sample_text, add_special_tokens=False)
tokenizer.decode(sample_ids, skip_special_tokens=False)

'Hello, how are you?'

In [72]:
# `added_vocab` was not defined by any cell in this notebook, so this cell
# failed with NameError on a fresh kernel. Define it here.
# NOTE(review): presumably it originally came from the tokenizer's added
# (special) tokens, which map token string -> id; confirm this matches the
# cell that was removed.
added_vocab = tokenizer.get_added_vocab()

# First added token whose name mentions "pad" (IndexError if there is none).
pad_token = [token for token in added_vocab if "pad" in token][0]
pad_token

'<|finetune_right_pad_id|>'

In [73]:
# Register the padding token picked in the previous cell as the tokenizer's pad token.
tokenizer.pad_token = pad_token

In [74]:
# Display the id the tokenizer now reports for its pad token.
tokenizer.pad_token_id

128004

In [102]:
from trl import SFTConfig, SFTTrainer

# The mixed-precision flags mirror the dtype the input-embedding weights are
# stored in, so read that dtype once and derive both flags from it.
embedding_dtype = model.get_input_embeddings().weight.dtype

# Short run (10 optimizer steps) that logs every step.
sft_config = SFTConfig(
    output_dir="outputs",
    seed=3407,
    max_steps=10,
    warmup_steps=1,
    logging_steps=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_seq_length=2048,
    dataset_num_proc=4,
    fp16=embedding_dtype == torch.float16,
    bf16=embedding_dtype == torch.bfloat16,
    report_to="wandb",  # For W&B
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

In [103]:
# Take the first batch from the trainer's dataloader and keep only the first
# example of each field for inspection.
loader = trainer.get_train_dataloader()
batches = iter(loader)
batch = next(batches)
input_ids, attention_mask, labels = (
    batch[field][0] for field in ("input_ids", "attention_mask", "labels")
)

In [140]:
def get_attended_inputs(input_ids, attention_mask):
    """Decode only the tokens whose attention-mask entry equals 1.

    Args:
        input_ids: token-id tensor, indexed along dimension 0.
        attention_mask: tensor compared element-wise against 1 to choose
            which positions of ``input_ids`` to keep.

    Returns:
        The text produced by the notebook-level ``tokenizer`` for the kept
        ids, with special tokens left in the output.
    """
    # Flat list of positions where the mask is on.
    keep_positions = attention_mask.eq(1).nonzero().view(-1)
    kept_ids = input_ids.index_select(0, keep_positions)
    return tokenizer.decode(kept_ids, skip_special_tokens=False)


def get_attended_labels(labels):
    """Decode the label ids that are not equal to -100.

    Args:
        labels: label-id tensor, indexed along dimension 0.

    Returns:
        The text produced by the notebook-level ``tokenizer`` for the
        remaining ids, with special tokens left in the output.
    """
    # -100 marks positions to drop here (presumably the loss ignore index).
    keep_positions = labels.ne(-100).nonzero().view(-1)
    kept_ids = labels.index_select(0, keep_positions)
    return tokenizer.decode(kept_ids, skip_special_tokens=False)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 21 Feb 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Write a letter I can send to the company that installed my swimming pool. Explain to them that the pool has two leaks and that I'd like to make a warranty claim. Request a reply within the next 30 days.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Date: April 4, 2023
From: Firstname Lastname
To: California Pool Company

I purchased the home at 123 Main Street, Riverside, California in July 2022. The property has a pool that you installed in 2021. We have had multiple issues with the pool, including two leaks in the underground water lines supplying the automatic pool leveler. I understand the pool has a two-year warranty; I would like you to inspect the issues we have found before the warranty expires. I have enclosed photographs of the issues we have found so far. Please call me at 123-345-3883 to sch

In [None]:
# Show the decoded input tokens at positions where the attention mask is 1.
print(get_attended_inputs(input_ids, attention_mask))

In [None]:
# Show the decoded label tokens, skipping positions whose label is -100.
print(get_attended_labels(labels))

In [141]:
# Repeat with pad_token = eos_token
# NOTE(review): in the decoded labels printed further down, `<|eot_id|>` no
# longer appears even though it is still present in the decoded inputs --
# presumably the collator drops pad-token positions from the labels; confirm.
tokenizer.pad_token = tokenizer.eos_token

# The mixed-precision flags mirror the dtype of the input-embedding weights.
embedding_dtype = model.get_input_embeddings().weight.dtype

sft_config = SFTConfig(
    output_dir="outputs",
    seed=3407,
    max_steps=10,
    warmup_steps=1,
    logging_steps=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_seq_length=2048,
    dataset_num_proc=4,
    fp16=embedding_dtype == torch.float16,
    bf16=embedding_dtype == torch.bfloat16,
    report_to="wandb",  # For W&B
)

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)


def get_batch(trainer):
    """Return the first example of the first training batch.

    Args:
        trainer: object exposing ``get_train_dataloader()``; the first item
            yielded by that dataloader is used.

    Returns:
        A tuple ``(input_ids, attention_mask, labels)`` holding element 0 of
        each of those fields in the batch.
    """
    first_batch = next(iter(trainer.get_train_dataloader()))
    return tuple(
        first_batch[field][0] for field in ("input_ids", "attention_mask", "labels")
    )


# Refresh the inspected example using the trainer rebuilt above in this cell.
input_ids, attention_mask, labels = get_batch(trainer)

In [142]:
# Decoded input tokens (attention mask == 1) for the pad_token = eos_token run.
print(get_attended_inputs(input_ids, attention_mask))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 21 Feb 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

Write a letter I can send to the company that installed my swimming pool. Explain to them that the pool has two leaks and that I'd like to make a warranty claim. Request a reply within the next 30 days.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Date: April 4, 2023
From: Firstname Lastname
To: California Pool Company

I purchased the home at 123 Main Street, Riverside, California in July 2022. The property has a pool that you installed in 2021. We have had multiple issues with the pool, including two leaks in the underground water lines supplying the automatic pool leveler. I understand the pool has a two-year warranty; I would like you to inspect the issues we have found before the warranty expires. I have enclosed photographs of the issues we have found so far. Please call me at 123-345-3883 to sch

In [143]:
# Decoded label tokens (label != -100) for the pad_token = eos_token run.
print(get_attended_labels(labels))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 21 Feb 2025

<|start_header_id|>user<|end_header_id|>

Write a letter I can send to the company that installed my swimming pool. Explain to them that the pool has two leaks and that I'd like to make a warranty claim. Request a reply within the next 30 days.<|start_header_id|>assistant<|end_header_id|>

Date: April 4, 2023
From: Firstname Lastname
To: California Pool Company

I purchased the home at 123 Main Street, Riverside, California in July 2022. The property has a pool that you installed in 2021. We have had multiple issues with the pool, including two leaks in the underground water lines supplying the automatic pool leveler. I understand the pool has a two-year warranty; I would like you to inspect the issues we have found before the warranty expires. I have enclosed photographs of the issues we have found so far. Please call me at 123-345-3883 to schedule an inspection 