In [None]:
import os

import torch
from peft import LoraConfig, PeftModel
from Scripts import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

In [None]:
# Hugging Face Hub identifier of the base checkpoint that is fine-tuned below.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): no later cell visible here reads `device`; the model is loaded with device_map="auto".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the tokenizer that belongs to the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# When the tokenizer defines no padding token, reuse the end-of-sequence token for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# System prompt (German) placed in front of every training example. In English:
# "Categorise the input into one of the following categories. Answer only with one
# of the following categories:" followed by the nine category names.
# The whole literal -- including the four-space indentation of the list items -- is
# sent to the model, so it should presumably be kept identical at inference time.
# NOTE(review): "folgendenen" looks like a typo for "folgenden". Correcting it changes
# the prompt text, so only do so together with every other consumer of this prompt.
system_message = """Kategorisiere die Eingabe in eine der folgendenen Kategorien. Antworte nur in einer der folgenden Kategorien:
    - Web
    - International
    - Etat
    - Wirtschaft
    - Panorama
    - Sport
    - Wissenschaft
    - Kultur
    - Inland"""

In [None]:
# Running total of samples whose text was shortened by create_messages().
# NOTE(review): the counter only advances when the function really executes in this
# kernel process; it presumably stays unchanged if the mapping result is served from
# a cache or computed in worker processes -- verify before relying on the number.
truncated_count = 0


def create_messages(sample, max_chars=1024):
    """Turn one dataset row into a three-turn chat conversation.

    Args:
        sample: Mapping with a ``"text"`` entry (the article) and a ``"label"``
            entry (the target category).
        max_chars: Maximum number of characters of ``sample["text"]`` that are
            kept; longer texts are cut off at this length.

    Returns:
        ``{"messages": [...]}`` holding a system turn (``system_message``), a
        user turn (the possibly shortened text) and an assistant turn (the label).

    Side effects:
        Increments the module-level ``truncated_count`` once per shortened text.
        The incoming ``sample`` itself is left unmodified.
    """
    global truncated_count

    text = sample["text"]
    if len(text) > max_chars:
        # Work on a local copy so the caller's row is not altered.
        text = text[:max_chars]
        truncated_count += 1

    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": text},
            {"role": "assistant", "content": sample["label"]},
        ]
    }

In [None]:
def apply_chat_template(sample):
    """Render the chat messages of one sample with the tokenizer's chat template.

    Args:
        sample: Mapping whose ``"messages"`` entry holds the conversation turns.

    Returns:
        ``{"chat_template": rendered}`` where ``rendered`` is whatever the
        module-level ``tokenizer`` returns for ``tokenize=False``.
    """
    conversation = sample["messages"]
    rendered = tokenizer.apply_chat_template(conversation, tokenize=False)
    return {"chat_template": rendered}

In [None]:
# Read the 10kGNAD CSV files through the project's own load_dataset helper.
# The helper returns two values; only the first one is kept, the second is discarded.
# Presumably these are the train and test sets in the order the paths are given --
# confirm against Scripts.load_dataset.
train_ds, _ = load_dataset.load_dataset(
    "../../../../German_newspaper_articles/10kGNAD/train.csv",
    "../../../../German_newspaper_articles/10kGNAD/test.csv",
)

In [None]:
# Build the chat messages for every row and drop the raw "text" column afterwards.
# NOTE: this cell rebinds train_ds and removes a column it needs as input, so it is
# not idempotent -- re-running it without reloading the data presumably fails.
train_ds = train_ds.map(create_messages, remove_columns=["text"])
# Display how many texts create_messages has shortened so far.
truncated_count

In [None]:
# Add a "chat_template" column holding each conversation rendered by apply_chat_template.
train_ds = train_ds.map(apply_chat_template)

In [None]:
# Length of the longest rendered prompt (0 for an empty dataset).
# len() is applied to the "chat_template" entries, which are presumably strings, so
# this counts characters rather than tokens.
# A descriptive name is used instead of `max` so the builtin stays usable in later cells.
longest_prompt_chars = max(
    (len(sample["chat_template"]) for sample in train_ds),
    default=0,
)
longest_prompt_chars

In [None]:
# Hold out 20 % of the prepared examples for evaluation.
# The fixed seed makes the shuffle -- and therefore the split -- reproducible across
# kernel restarts; 42 is the same seed that is passed to TrainingArguments.
train_eval = train_ds.train_test_split(test_size=0.2, shuffle=True, seed=42)

In [None]:
# Unpack the two parts of the split.
# NOTE: train_ds is rebound to the training part here; re-running the split cell
# after this one would split the already reduced training set again.
train_ds = train_eval["train"]
eval_ds = train_eval["test"]

In [None]:
# bitsandbytes settings used when the model is loaded for training.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load the weights in 4-bit
    bnb_4bit_use_double_quant=True,  # enable double quantization
    bnb_4bit_quant_type="nf4",  # use the NF4 quantization type
    bnb_4bit_compute_dtype=torch.float16,  # compute dtype; float16 goes together with fp16=True in the training arguments
)

In [None]:
# Loading options for AutoModelForCausalLM.from_pretrained: automatic dtype and
# device placement, 4-bit quantization, and the KV cache switched off.
model_kwargs = dict(
    # attn_implementation is deliberately left unset; passing "flash_attention_2"
    # is presumably an option on GPUs that support Flash Attention -- TODO confirm.
    torch_dtype="auto",
    use_cache=False,  # set to False as we're going to use gradient checkpointing
    device_map="auto",
    quantization_config=quantization_config,
)

In [None]:
# Directory (relative to the notebook's working directory) used as the trainer's
# output_dir and as the location the LoRA adapter is saved to and reloaded from.
output_dir = "../../../results/llama3_results/instruct/metrics"

In [None]:
# Sanity check: shows True when the relative output path resolves to an existing
# directory from the current working directory.
# `os` is imported in the import cell at the top of the notebook.
os.path.isdir(output_dir)

In [None]:
# Hyper-parameters for the supervised fine-tuning run (passed to SFTTrainer as `args`).
training_args = TrainingArguments(
    fp16=True,  # specify bf16=True instead when training on GPUs that support bf16 else fp16
    bf16=False,
    do_eval=True,
    # NOTE(review): newer transformers releases presumably call this `eval_strategy` --
    # confirm against the installed version.
    evaluation_strategy="steps",
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    learning_rate=2.0e-05,
    logging_steps=100,
    # NOTE(review): magic number -- presumably derived from dataset size and batch size; confirm.
    max_steps=1849,
    output_dir=output_dir,
    overwrite_output_dir=True,
    per_device_eval_batch_size=4,  # originally set to 8
    per_device_train_batch_size=4,  # originally set to 8
    save_total_limit=None,
    seed=42,
)

In [None]:
# LoRA adapter configuration for causal language modelling.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # modules that receive LoRA adapters; names presumably match the attention and MLP projections of the model -- confirm
)

In [None]:
# Load the 4-bit quantized base model that is going to be trained.
# The loading options (dtype, device map, cache setting, quantization) come from
# model_kwargs so that they are defined in exactly one place.
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

In [None]:
# Assemble the supervised fine-tuning trainer from the model, data, and configs above.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    dataset_text_field="chat_template",  # column produced by apply_chat_template
    tokenizer=tokenizer,
    packing=True,
    peft_config=peft_config,
    # NOTE(review): presumably a limit in tokens, whereas the prompt lengths inspected
    # earlier in the notebook are character counts -- confirm the value is adequate.
    max_seq_length=1512,
)

In [None]:
# Run the fine-tuning; the object returned by train() is kept in train_result.
train_result = trainer.train()

In [None]:
# Save the trained model held by the trainer into output_dir; a later cell reloads
# it from there with PeftModel.from_pretrained.
trainer.model.save_pretrained(output_dir)

In [None]:
# Persist the exact train/eval split that was used for this run.
train_ds.save_to_disk("./datasets/llama_train")
eval_ds.save_to_disk("./datasets/llama_eval")

In [None]:
# Reload the base checkpoint without a quantization config so the adapter can be
# attached to it.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    # NOTE(review): carried over from the training cell; no gradient checkpointing
    # happens from here on, and the setting presumably ends up in the saved model
    # config -- confirm this is intended.
    use_cache=False,
)

In [None]:
# Attach the adapter stored in output_dir to the freshly loaded base model.
finetuned_model = PeftModel.from_pretrained(base_model, output_dir)

In [None]:
# Fold the LoRA weights into the non-quantized base model and drop the PEFT wrapper.
# The merge has to start from finetuned_model (base model + adapter loaded from
# output_dir), not from `model`, which still refers to the 4-bit model used for training.
# The result is bound to `model` because the following cell saves `model`.
model = finetuned_model.merge_and_unload()

In [None]:
# Store the merged model in the "model" sub-directory of output_dir. The path is
# derived from output_dir so both locations cannot drift apart.
merged_model_dir = f"{output_dir}/model"
model.save_pretrained(merged_model_dir)
# Save the tokenizer alongside the weights so the directory can be loaded on its own.
tokenizer.save_pretrained(merged_model_dir)