In [None]:
import os

import torch
from peft import LoraConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

from Scripts.llama_model_wrapper import InstructModelWrapper
from Scripts.load_dataset import load_dataset

In [None]:
# Hugging Face Hub identifier of the base checkpoint that is fine-tuned and,
# at the end of the notebook, reloaded for merging.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# 4-bit bitsandbytes quantization settings; passed to the model wrapper through
# the "quantization_config" entry of `model_kwargs`.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    # Compute dtype used by the 4-bit layers.
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# Keyword arguments forwarded to InstructModelWrapper.
# The checkpoint id is taken from `model_id` instead of being repeated as a
# literal, so the model, the tokenizer and the base model reloaded for merging
# cannot drift apart.
model_kwargs = {
    "path": model_id,
    "tokenizer_path": model_id,
    "torch_dtype": "auto",
    "device_map": "auto",
    # Same setting the original passed; gradient checkpointing is enabled in
    # the training arguments.
    "use_cache": False,
    "quantization_config": quantization_config,
}

In [None]:
# Build the project-specific wrapper from `model_kwargs`. Later cells read its
# `.model` and `.tokenizer` attributes and use its `create_train_messages` and
# `apply_chat_template` methods as dataset map functions.
# NOTE(review): presumably this loads the quantized model and tokenizer
# eagerly — confirm in Scripts/llama_model_wrapper.
instruct_model_wrapper = InstructModelWrapper(**model_kwargs)

In [None]:
# Locations of the 10kGNAD CSV files, relative to the notebook's working
# directory.
train_csv = "../../../../German_newspaper_articles/10kGNAD/train.csv"
test_csv = "../../../../German_newspaper_articles/10kGNAD/test.csv"

# The loader returns two values; only the first is kept and the second is
# discarded.
train_ds, _ = load_dataset.load_dataset(train_csv, test_csv)

In [None]:
# Run the wrapper's `create_train_messages` over every row and drop the raw
# "text" column afterwards.
# NOTE(review): presumably this adds a chat-style messages column — confirm in
# the wrapper.
# This cell is not idempotent: once "text" has been removed, re-running it is
# expected to fail, so re-run from the dataset-loading cell instead.
train_ds = train_ds.map(
    instruct_model_wrapper.create_train_messages, remove_columns=["text"]
)

In [None]:
# Apply the wrapper's `apply_chat_template` to every row.
# NOTE(review): presumably this adds the "chat_template" text column that the
# length check and the trainer read later — confirm in the wrapper.
train_ds = train_ds.map(instruct_model_wrapper.apply_chat_template)

In [None]:
# Longest "chat_template" entry in the training data, measured with len().
# NOTE(review): the field is presumably a string, so this counts characters,
# not tokens — keep that in mind when comparing with the trainer's
# max_seq_length.
# The result is stored under a descriptive name instead of `max`, which would
# shadow the builtin for every later cell in the kernel session.
# `default=0` keeps the previous result (0) for an empty dataset.
max_chat_template_len = max(
    (len(sample["chat_template"]) for sample in train_ds),
    default=0,
)
max_chat_template_len

In [None]:
# Hold out 20% of the examples for evaluation.
# The shuffle is seeded (same value as the `seed` given to TrainingArguments)
# so that Restart & Run All reproduces the same train/eval partition.
train_eval = train_ds.train_test_split(test_size=0.2, shuffle=True, seed=42)

In [None]:
# Unpack the two halves of the split: "train" is used for fitting, "test" for
# evaluation during training.
train_ds, eval_ds = train_eval["train"], train_eval["test"]

In [None]:
# Directory handed to TrainingArguments as `output_dir`; it is also where the
# trained adapter is saved after training and read back from before merging.
output_dir = "../../../results/llama3_results/instruct/metrics"

In [None]:
# Show whether the output directory already exists (`os` is imported in the
# notebook's import cell). The training arguments set
# overwrite_output_dir=True, so True here means existing contents may be
# overwritten.
os.path.isdir(output_dir)

In [None]:
# Training hyper-parameters (based on config).

# Choose the mixed-precision mode from the hardware instead of hard-coding
# fp16: bf16 is used when the GPU supports it (this also matches the bfloat16
# compute dtype of the 4-bit quantization config), fp16 otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    fp16=not use_bf16,
    bf16=use_bf16,
    do_eval=True,
    # No eval_steps is given, so the evaluation interval is left at the
    # library default.
    evaluation_strategy="steps",
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    learning_rate=2.0e-05,
    logging_steps=100,
    # Training length is fixed in optimizer steps rather than epochs.
    max_steps=1849,
    output_dir=output_dir,
    overwrite_output_dir=True,
    per_device_eval_batch_size=4,  # originally set to 8
    per_device_train_batch_size=4,  # originally set to 8
    save_total_limit=None,
    seed=42,
)

In [None]:
# LoRA adapter configuration handed to the trainer as `peft_config`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    # Names of the modules that receive adapters.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [None]:
# Assemble the supervised fine-tuning trainer from the wrapper's model and
# tokenizer, the LoRA config, the training arguments and the two data splits.
trainer = SFTTrainer(
    # Model side
    model=instruct_model_wrapper.model,
    tokenizer=instruct_model_wrapper.tokenizer,
    peft_config=peft_config,
    # Optimisation settings
    args=training_args,
    # Data side: the text to train on is read from the "chat_template" column.
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    dataset_text_field="chat_template",
    # Sequence handling
    packing=True,
    max_seq_length=1512,
)

In [None]:
# Run the fine-tuning loop; the object returned by `train()` is kept in
# `train_result` for later inspection.
train_result = trainer.train()

In [None]:
# Save the trained model held by the trainer into `output_dir`.
# NOTE(review): the trainer was built with a peft_config, so this presumably
# writes only the LoRA adapter weights — it is read back with
# PeftModel.from_pretrained before merging.
trainer.model.save_pretrained(output_dir)

In [None]:
# Persist the exact train/eval split used for this run so it can be reloaded
# later. The paths are relative to the notebook's working directory.
train_ds.save_to_disk("./datasets/llama_train")
eval_ds.save_to_disk("./datasets/llama_eval")

In [None]:
# Reload the base checkpoint without a quantization config so that the LoRA
# adapter can be merged into it.
# `use_cache=False` is intentionally not passed here: it was a training-time
# setting tied to gradient checkpointing, no training happens on this model,
# and passing it would write the disabled cache into the config that is saved
# with the merged model.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
)

In [None]:
# Attach the adapter saved in `output_dir` to the freshly loaded base model.
model = PeftModel.from_pretrained(base_model, output_dir)

In [None]:
# Merge the adapter into the base model and rebind `model` to the object
# returned by `merge_and_unload()`.
# NOTE(review): per PEFT's API this should be the plain base model with the
# adapter weights folded in — confirm against the PEFT version in use.
model = model.merge_and_unload()

In [None]:
# Save the merged model to a "model" directory that sits beside the "metrics"
# directory used as `output_dir`.
model.save_pretrained("../../../results/llama3_results/instruct/model")