In [None]:
import os
#os.environ['LD_LIBRARY_PATH'] = 'YOUR_CONDA_ENV/lib'
import sys
from typing import List

import numpy as np 
import fire
import torch
import transformers
from datasets import load_dataset, concatenate_datasets
from transformers import EarlyStoppingCallback
 
from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402

In [None]:
# Hub identifier of the base checkpoint; the model and the tokenizer are both
# loaded from it.
BASE_MODEL = "baffo32/decapoda-research-llama-7B-hf"

# Load the causal-LM weights in 8-bit, mapping the root module key "" to the
# currently selected CUDA device.
current_gpu = torch.cuda.current_device()
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map={"": current_gpu},
)

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
# Pad on the left, using id 0 (unk) so that padding stays distinct from the
# EOS token.
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 0



In [None]:
def generate_prompt(data_point):
    """Render one data point as an instruction-following prompt string.

    ``data_point`` is indexed with the keys "instruction", "input" and
    "output".  When the "input" value is truthy the prompt contains an
    "### Input:" section; otherwise that section is left out and a shorter
    preamble is used.  The "output" value is placed directly after the
    "### Response:" header, so passing an empty output yields the prompt
    without an answer.
    """
    context = data_point["input"]
    instruction = data_point["instruction"]
    answer = data_point["output"]

    # NOTE: the trailing spaces at the end of both preambles are part of the
    # prompt text and are kept on purpose.
    if context:
        sections = [
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ",
            "",
            "### Instruction:",
            f"{instruction}",
            "",
            "### Input:",
            f"{context}",
            "",
            "### Response:",
            f"{answer}",
        ]
    else:
        sections = [
            "Below is an instruction that describes a task. Write a response that appropriately completes the request.  ",
            "",
            "### Instruction:",
            f"{instruction}",
            "",
            "### Response:",
            f"{answer}",
        ]
    return "\n".join(sections)
 
# Maximum number of tokens kept per example: passed as ``max_length`` to the
# tokenizer in tokenize(), which also uses it to decide whether an EOS token
# still fits.
cutoff_len = 512

def tokenize(prompt, add_eos_token=True):
    """Tokenize ``prompt`` and attach labels for causal-LM training.

    The prompt is encoded with the module-level ``tokenizer``, truncated to
    ``cutoff_len`` tokens and left unpadded.  When ``add_eos_token`` is true,
    the sequence does not already end with the EOS id, and it is shorter than
    ``cutoff_len``, the EOS id is appended (with an attention-mask value of 1).

    Returns the tokenizer output with an extra "labels" entry that is a copy
    of "input_ids".
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )

    token_ids = encoded["input_ids"]
    ends_with_eos = token_ids[-1] == tokenizer.eos_token_id
    has_room = len(token_ids) < cutoff_len
    if not ends_with_eos and has_room and add_eos_token:
        # Appending through ``token_ids`` mutates the list stored in ``encoded``.
        token_ids.append(tokenizer.eos_token_id)
        encoded["attention_mask"].append(1)

    encoded["labels"] = token_ids.copy()
    return encoded
 
def generate_and_tokenize_prompt(data_point):
    """Build the full prompt for ``data_point`` and tokenize it.

    When the module-level ``train_on_inputs`` flag is true, the tokenized
    prompt is returned unchanged, so labels cover the whole sequence.
    Otherwise the prompt is rendered a second time with an empty "output" to
    measure how many tokens belong to the instruction/input part, and that many
    leading label positions are replaced by -100 (presumably so the loss
    ignores them -- confirm against the trainer's loss settings).
    """
    tokenized_full_prompt = tokenize(generate_prompt(data_point))
    if train_on_inputs:
        return tokenized_full_prompt

    # Same prompt without the answer and without EOS: its token count is the
    # length of the prefix to mask.
    prompt_only = generate_prompt({**data_point, "output": ""})
    prompt_len = len(tokenize(prompt_only, add_eos_token=False)["input_ids"])

    masked_prefix = [-100] * prompt_len
    answer_labels = tokenized_full_prompt["labels"][prompt_len:]
    tokenized_full_prompt["labels"] = masked_prefix + answer_labels
    return tokenized_full_prompt

In [None]:
# LoRA hyperparameters.
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
lora_target_modules = ["q_proj", "v_proj"]

# NOTE: this cell rebinds ``model`` twice (int8 preparation, then the PEFT
# wrapper), so re-running it without reloading the base model applies both
# steps again on the already wrapped model.
model = prepare_model_for_int8_training(model)

config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
)
model = get_peft_model(model, config)

In [None]:
def _load_split(source):
    """Load one dataset: ``*.json`` paths via the "json" builder, anything else by name."""
    if source.endswith(".json"):
        return load_dataset("json", data_files=source)
    return load_dataset(source)


train_data_path = ["data/movielens_1m/train.json"]
val_data_path = ["data/movielens_1m/valid_5000.json"]

# One loaded dataset per configured path, in the same order as the path lists.
train_data_list = [_load_split(source) for source in train_data_path]
val_data_list = [_load_split(source) for source in val_data_path]


In [None]:
# Number of training examples to keep per training file; any value <= -1
# keeps the full training split.
sample = 1024
# Seed used for every dataset shuffle below.
seed = 0
# Read by generate_and_tokenize_prompt(): when True, labels cover the whole
# prompt; when False, the instruction/input prefix is masked out.
train_on_inputs = True

for i, splits in enumerate(train_data_list):
    shuffled = splits["train"].shuffle(seed=seed)
    if sample > -1:
        # Clamp to the dataset size: selecting more rows than exist would
        # raise instead of simply using everything that is available.
        keep = min(sample, len(shuffled))
        shuffled = shuffled.select(range(keep))
    # The second shuffle is kept so the resulting example order is the same as
    # before for datasets that have at least ``sample`` rows.
    splits["train"] = shuffled.shuffle(seed=seed)
    train_data_list[i] = splits.map(generate_and_tokenize_prompt)

for i, splits in enumerate(val_data_list):
    val_data_list[i] = splits.map(generate_and_tokenize_prompt)

# Datasets loaded from a single JSON file expose their rows under the "train"
# split, which is why the validation files are also read from "train".
train_data = concatenate_datasets([splits["train"] for splits in train_data_list])
val_data = concatenate_datasets([splits["train"] for splits in val_data_list])


In [None]:
# Sanity check: display the number of input ids and labels of the first
# tokenized training example.
first_example = train_data[0]
len(first_example["input_ids"]), len(first_example["labels"])

In [None]:
# Show how many of the model's parameters are trainable after the PEFT wrapping.
model.print_trainable_parameters()

# With more than one visible GPU, flag the model as parallelizable /
# model-parallel (presumably to stop the Trainer from adding its own data
# parallelism -- confirm against the Trainer implementation in use).
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    for flag in ("is_parallelizable", "model_parallel"):
        setattr(model, flag, True)

In [None]:
# Training hyperparameters and output location.
#
# The run directory embeds the sampling seed and sample size.  It has to be an
# f-string: shell-style "${seed}_${sample}" placeholders are not expanded by
# Python and would make every run write to one literal "${seed}_${sample}"
# directory.
output_dir = f"data/output/movie/{seed}_{sample}"
# Per-device batch size, and the effective batch size reached through
# gradient accumulation.
micro_batch_size = 4
batch_size = 128
gradient_accumulation_steps = batch_size // micro_batch_size
num_epochs = 5
learning_rate = 1e-4
group_by_length = False
# Passed to trainer.train(); None leaves checkpoint resumption unset.
resume_from_checkpoint = None

# Optimisation, evaluation and checkpointing settings for the run.
training_args = transformers.TrainingArguments(
    output_dir=output_dir,
    # Batching.
    per_device_train_batch_size=micro_batch_size,
    per_device_eval_batch_size=micro_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    group_by_length=group_by_length,
    # Optimisation.
    num_train_epochs=num_epochs,
    learning_rate=learning_rate,
    warmup_steps=20,
    optim="adamw_torch",
    fp16=True,
    # Evaluation and checkpointing happen once per epoch.
    # NOTE(review): eval_steps is presumably ignored while the evaluation
    # strategy is "epoch" -- confirm and drop it if so.
    evaluation_strategy="epoch",
    eval_steps=50,
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    # Logging.
    logging_steps=8,
    # NOTE(review): depending on the transformers version, None may select the
    # default reporters rather than disable reporting -- confirm.
    report_to=None,
    ddp_find_unused_parameters=None,
)

# Pads each batch to a multiple of 8 and returns PyTorch tensors.
batch_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

early_stopping = EarlyStoppingCallback(early_stopping_patience=5)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=batch_collator,
    callbacks=[early_stopping],
)
# Turn off the model's generation cache flag for training.
model.config.use_cache = False
# Keep a reference to the original bound ``state_dict`` method, then replace
# it on this instance with a bound lambda that passes the full state dict
# through ``get_peft_model_state_dict``; extra positional/keyword arguments
# given to the patched method are accepted and ignored.
# NOTE(review): presumably this makes checkpoints contain only the adapter
# weights; with some PEFT versions ``save_pretrained`` already does this
# filtering itself -- confirm the saved adapter file is not empty.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(
        self, old_state_dict()
    )
).__get__(model, type(model))

# Compile the model on torch 2+ (not on Windows).
# NOTE(review): ``trainer`` was constructed from ``model`` before this
# rebinding, so it presumably still trains the uncompiled module -- confirm
# this is intended.
can_compile = torch.__version__ >= "2" and sys.platform != "win32"
if can_compile:
    model = torch.compile(model)

# Run training, then write the model to the run's output directory.
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained(output_dir)
print("\n If there's a warning about missing keys above, please disregard :)")