In [1]:
# Hyper-parameters and run settings shared by the cells below.

# --- optimisation schedule ---
num_training_steps = 10
T_max = num_training_steps  # the cosine schedule spans the whole run
max_lr = 5e-5  # learning rate given to AdamW
min_lr = 5e-6  # floor (eta_min) of the cosine schedule

# --- LoRA adapter ---
lora_rank = 64
lora_alpha = 128
lora_dropout = 0.05

# --- batching ---
# gradient_accumulation_steps is deliberately left undefined here because
# several values are compared later in the notebook.
# gradient_accumulation_steps = 8
per_device_eval_batch_size = 2
eval_accumulation_steps = 8

# --- checkpointing / evaluation / logging ---
# NOTE(review): get_trainer passes literal warmup_steps / bf16 / logging_steps /
# weight_decay values to TrainingArguments instead of reading the ones below.
save_total_limit = 3
save_steps = 50
eval_steps = 10
logging_steps = 10
warmup_steps = 10

# --- numerics / regularisation ---
bf16 = True
weight_decay = 0.01

In [2]:
# create a lora config
from peft import LoraConfig

# Modules that receive LoRA adapters: the q/k/v/o projections.
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

# Rank, alpha and dropout come from the config cell; bias terms are not adapted.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
)

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer


# Hub identifiers. The model and its tokenizer come from the same checkpoint;
# the CodeLlama tokenizer is only used to decode the pre-tokenized training data.
model_name = "meta-llama/Llama-3.2-1B"
tokenizer_model = "meta-llama/Llama-3.2-1B"
original_tokenizer_name = "codellama/CodeLlama-7b-Instruct-hf"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
# Use EOS as the pad token (the collator pads input_ids with the EOS id as well).
tokenizer.pad_token = tokenizer.eos_token

# The identifier string lives under its own name so that `original_tokenizer`
# is only ever bound to a tokenizer object, never to a plain string.
original_tokenizer = AutoTokenizer.from_pretrained(original_tokenizer_name)

In [4]:
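# Training data comes from the datagen repo; the dataset directory that is
# actually loaded is given by `ds_path` below.
# (This comment previously named /home/ubuntu/datagen/synthetic_sql/interleaved/ds3_cat_abc,
# which does not match `ds_path`.)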
from datasets import load_from_disk

# The path points straight at the train split; only its first 500 examples are kept.
ds_path = "/home/ubuntu/datagen/synthetic_sql/interleaved/ds11_api/train"
ds_train = load_from_disk(ds_path)
ds_train = ds_train.select(range(500))

# The stored examples were tokenized with CodeLlama, so they are decoded back to
# text and tokenized again for the new model. Remember the CodeLlama BOS/EOS ids
# so they can be filtered out while doing so.
original_bos_token_id, original_eos_token_id = (
    original_tokenizer.bos_token_id,
    original_tokenizer.eos_token_id,
)


def retokenize_data(item):
    """
    Convert one example tokenized with ``original_tokenizer`` into ``tokenizer`` ids.

    ``item["labels"]`` marks prompt positions with -100; every other position
    is treated as completion. The ids are split on that marker, the original
    BOS/EOS ids are dropped, and each part is decoded to text and encoded
    again with ``tokenizer`` (without special tokens).

    Returns a dict with:
        input_ids: BOS + prompt + completion + EOS, in new-tokenizer ids
        attention_mask: all ones, same length as ``input_ids``
        labels: -100 for BOS and prompt positions, then completion ids and EOS
    """
    special_ids = {original_bos_token_id, original_eos_token_id}
    source_labels = item["labels"]

    # Split the original ids into prompt / completion in one pass.
    old_prompt_ids, old_completion_ids = [], []
    for position, token_id in enumerate(item["input_ids"]):
        belongs_to_prompt = source_labels[position] == -100
        if token_id in special_ids:
            continue
        if belongs_to_prompt:
            old_prompt_ids.append(token_id)
        else:
            old_completion_ids.append(token_id)

    def reencode(old_ids):
        # old ids -> text -> new ids, without adding BOS/EOS automatically
        text = original_tokenizer.decode(old_ids)
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    prompt_ids = reencode(old_prompt_ids)
    completion_ids = reencode(old_completion_ids)

    bos, eos = tokenizer.bos_token_id, tokenizer.eos_token_id
    input_ids = [bos] + prompt_ids + completion_ids + [eos]
    # BOS and the prompt are masked out of the loss; completion and EOS are trained on.
    labels = [-100] * (1 + len(prompt_ids)) + list(completion_ids) + [eos]

    assert len(input_ids) == len(labels)

    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


# Run retokenize_data over every selected training example.
ds_train = ds_train.map(retokenize_data)

In [5]:
import torch
from transformers import TrainingArguments, Trainer
from peft import get_peft_model
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np


def collate_to_max_length(inputs):
    """
    Data collator that pads every example to the longest one in the mini batch.

    Args:
        inputs: list of dicts, each holding ``input_ids``, ``attention_mask``
            and ``labels`` lists (expected to have the same length within one
            example).

    Returns:
        Dict of tensors ``input_ids``, ``labels`` and ``attention_mask``, each
        with one row per example and ``max_length`` columns. Padding is added
        on the right: ``input_ids`` with the EOS id, ``labels`` with -100 and
        ``attention_mask`` with 0.
    """
    max_length = max(len(example["input_ids"]) for example in inputs)
    pad_id = tokenizer.eos_token_id

    batch = {"input_ids": [], "attention_mask": [], "labels": []}

    for example in inputs:
        pad = max_length - len(example["input_ids"])

        batch["input_ids"].append(example["input_ids"] + [pad_id] * pad)
        # -100 keeps padded positions out of the loss.
        batch["labels"].append(example["labels"] + [-100] * pad)
        # Padded positions are not real tokens, so they are masked with 0.
        batch["attention_mask"].append(example["attention_mask"] + [0] * pad)

    return {
        "input_ids": torch.tensor(batch["input_ids"]),
        "labels": torch.tensor(batch["labels"]),
        "attention_mask": torch.tensor(batch["attention_mask"]),
    }


def get_trainer(gradient_accumulation_steps: int, per_device_train_batch_size: int):
    """
    Build a Trainer for one (gradient accumulation, per-device batch size) setting.

    The global ``model`` is wrapped with ``lora_config`` through
    ``get_peft_model`` and trained on ``ds_train`` using an explicitly built
    AdamW optimizer and a cosine-annealing schedule (``max_lr`` down to
    ``min_lr`` over ``T_max`` steps).

    Args:
        gradient_accumulation_steps: forwarded to ``TrainingArguments``.
        per_device_train_batch_size: forwarded to ``TrainingArguments``.

    Returns:
        Tuple ``(trainer, peft_model)``.
    """
    run_args = TrainingArguments(
        output_dir="./outputs",
        max_steps=num_training_steps,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        # NOTE(review): an explicit optimizer/scheduler pair is passed to the
        # Trainer below, so warmup_steps and weight_decay here presumably have
        # no effect on it - confirm.
        warmup_steps=0,
        weight_decay=0.01,
        bf16=True,
        fp16=False,
        logging_steps=1,  # log the loss on every step
        report_to="none",
    )

    lora_model = get_peft_model(model, lora_config)

    # Optimizer and learning-rate schedule are built by hand.
    adamw = AdamW(lora_model.parameters(), lr=max_lr)
    cosine_schedule = CosineAnnealingLR(adamw, T_max=T_max, eta_min=min_lr)

    lora_trainer = Trainer(
        model=lora_model,
        args=run_args,
        train_dataset=ds_train,
        optimizers=(adamw, cosine_schedule),
        data_collator=collate_to_max_length,
    )

    return lora_trainer, lora_model

In [6]:
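# Disabled debug check: does the base model's forward() take a "loss_kwargs" parameter?
# NOTE: `peft_model` does not exist at this point of a top-to-bottom run; it is only
# created (and deleted again) inside the training loop further down.
# import inspect

# "loss_kwargs" in inspect.signature(peft_model.get_base_model().forward).parameters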

In [None]:
import gc

# Sweep over (per-device batch size, gradient accumulation steps) pairs.
# Every pair multiplies to the same value, 8.
grad_accum = [8, 4, 2]
per_device = [1, 2, 4]
losses = {}  # gradient_accumulation_steps -> trainer.state.log_history

for p, g in zip(per_device, grad_accum):
    print("Training with gradient_accumulation_steps:", g)
    print("Training with per_device_train_batch_size:", p)
    trainer, peft_model = get_trainer(g, p)
    try:
        trainer.train()
        losses[g] = trainer.state.log_history
    finally:
        # Remove the LoRA layers from the shared base `model` so the next
        # get_peft_model call wraps a clean model instead of one that already
        # carries adapters (PEFT warns about the latter and recommends unload()).
        peft_model.unload()
        # The trainer also references the model, optimizer and scheduler, so
        # it has to be dropped as well for the memory to be reclaimed.
        del trainer, peft_model
        gc.collect()
        torch.cuda.empty_cache()


In [None]:
import matplotlib.pyplot as plt

# Per gradient-accumulation setting, keep only the log entries that carry a loss.
l = {
    g: [entry["loss"] for entry in losses[g] if "loss" in entry]
    for g in losses
}

# One loss curve per setting, on a single labelled axis.
fig, ax = plt.subplots()
for g, curve in l.items():
    ax.plot(curve, label=f"grad_accum_steps={g}")

ax.set_xlabel("logged step")
ax.set_ylabel("training loss")
ax.set_title("Training loss per gradient accumulation setting")
ax.legend()
plt.show()


In [None]:
# NOTE(review): `losses` was a dict keyed by gradient_accumulation_steps in the
# sweep above; it is rebound to an empty list here.
losses = []


def train_model(gradient_accumulation_steps):
    trainer = get_trainer(gradient_accumulation_steps=gradient_accumulation_steps)