In [None]:
import gc
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from transformers import (
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

from Scripts.llama_model_wrapper import HeadClassifierWrapper
from Scripts.load_dataset import load_dataset

In [None]:
# Hugging Face hub id of the pretrained backbone.
base_model_id = "meta-llama/Meta-Llama-3-8B"

# Backbone + classification head, saved locally so it only has to be built once.
base_model_path = (
    "../../../results/llama3_results/classification_head"
    "/model/base"
)
# Alternative starting point, a previously fine-tuned model:
# base_model_path = "../../../results/llama3_results/classification_head/run_3/model/finetuned"

# Tokenizer as stored with the step-1800 checkpoint of run 3.
tokenizer_path = (
    "../../../results/llama3_results/classification_head"
    "/run_3/metrics/checkpoint-1800"
)

# NOTE: `device` is not referenced by the cells below; model placement is
# handled by device_map="auto".
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Run only if no base model has been saved yet

In [None]:
# HeadClassifierWrapper arguments for building the base classifier:
# unquantized, with backbone and tokenizer both taken from the hub id.
model_kwargs = dict(
    path=base_model_id,
    num_labels=9,
    tokenizer_path=base_model_id,
    device_map="auto",
    use_cache=False,
    quantization_config=None,
)

In [None]:
# Build the classifier once and persist it to base_model_path.
# The tokenizer is saved next to the weights so this directory does not depend
# on a separate tokenizer checkpoint being available later.
classification_wrapper = HeadClassifierWrapper(**model_kwargs)
classification_wrapper.model.save_pretrained(base_model_path)
classification_wrapper.tokenizer.save_pretrained(base_model_path)

In [None]:
# Release the unquantized model before the 4-bit copy is loaded below.
# `del` only drops this name; gc.collect() reclaims objects still held in
# reference cycles so that empty_cache() can actually return their memory.
del classification_wrapper
gc.collect()
torch.cuda.empty_cache()

### Continue here once the base model exists

In [None]:
# Load weights in 4-bit NF4 with double quantization; 4-bit matmuls compute
# in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# HeadClassifierWrapper arguments for loading the saved base classifier in
# 4-bit for LoRA fine-tuning.
model_kwargs = dict(
    path=base_model_path,
    num_labels=9,
    tokenizer_path=tokenizer_path,
    device_map="auto",
    use_cache=False,
    quantization_config=quantization_config,
)

In [None]:
# LoRA adapters on every attention (q/k/v/o) and MLP (gate/up/down) projection;
# the classification head ("score") is trained and saved in full.
lora_config = LoraConfig(
    task_type="SEQ_CLS",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules="q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split(),
    modules_to_save=["score"],
)

# Load the quantized classifier, prepare it for k-bit training, attach LoRA.
classification_wrapper = HeadClassifierWrapper(**model_kwargs)
classification_wrapper.model = get_peft_model(
    prepare_model_for_kbit_training(classification_wrapper.model),
    lora_config,
)

# Training-time config: pad token id from the tokenizer, KV cache off,
# pretraining_tp reset to 1.
classification_wrapper.model.config.pad_token_id = classification_wrapper.tokenizer.pad_token_id
classification_wrapper.model.config.use_cache = False
classification_wrapper.model.config.pretraining_tp = 1

### Load datasets for a new training run

In [None]:
# 10kGNAD German news articles. load_dataset gets the train and test CSVs;
# only its first return value is kept, the second is discarded.
train_ds, _ = load_dataset(
    *(
        f"../../../../autrata_env/Documents/German_newspaper_articles/10kGNAD/{split}.csv"
        for split in ("train", "test")
    )
)

In [None]:
# Tokenize every sample; the raw "text" column is dropped afterwards.
train_ds = train_ds.map(classification_wrapper.tokenize_text, remove_columns=["text"])

In [None]:
# Run each sample through the wrapper's add_label_id
# (presumably adds the integer label used by the loss — confirm in the wrapper).
train_ds = train_ds.map(function=classification_wrapper.add_label_id)

In [None]:
# Length (in tokens) of the longest training sample, for sanity-checking
# sequence lengths. The result is deliberately NOT named `max`: rebinding that
# name would shadow the builtin for the rest of the kernel session.
# An empty dataset yields 0.
max_seq_len = max((len(input_ids) for input_ids in train_ds["input_ids"]), default=0)
max_seq_len

In [None]:
# Hold out 20% of the training data for evaluation. The fixed seed makes the
# split reproducible. The split is also saved to disk at the end of the
# notebook so continued training can reuse it.
train_eval = train_ds.train_test_split(test_size=0.2, shuffle=True, seed=42)

In [None]:
# Unpack the two halves returned by train_test_split.
train_ds, eval_ds = train_eval["train"], train_eval["test"]

### Load the existing train/eval split to continue training

In [None]:
from datasets import Dataset

In [None]:
# Reload the train/eval split saved by a previous run.
# The seeded shuffles keep the row order reproducible from run to run.
# Eval order does not affect the computed metrics.
train_ds = Dataset.load_from_disk(
    "../../../results/llama3_results/classification_head/datasets/train"
)
eval_ds = Dataset.load_from_disk(
    "../../../results/llama3_results/classification_head/datasets/eval"
)
train_ds = train_ds.shuffle(seed=42)
eval_ds = eval_ds.shuffle(seed=42)

### Continue here for training

In [None]:
# Make both splits yield PyTorch tensors when rows are accessed.
train_ds.set_format(type="torch")
eval_ds.set_format(type="torch")

In [None]:
# Pads each batch on the fly with the classifier's tokenizer.
collate_fn = DataCollatorWithPadding(classification_wrapper.tokenizer)

In [None]:
def compute_metrics(evaluations):
    """Compute evaluation metrics for the Trainer.

    Args:
        evaluations: ``(predictions, labels)`` pair as passed by ``Trainer``.
            ``predictions`` holds per-class scores with the class axis last
            (or a tuple whose first element does); ``labels`` holds the gold
            class ids.

    Returns:
        dict with ``balanced_accuracy`` (mean per-class recall) and
        ``accuracy``.
    """
    logits, labels = evaluations
    if isinstance(logits, tuple):
        # Some models return extra outputs; the logits come first.
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    # sklearn's signature is (y_true, y_pred). Accuracy is symmetric, but
    # balanced accuracy averages recall over the classes of y_true, so the
    # argument order matters.
    return {
        "balanced_accuracy": balanced_accuracy_score(labels, predictions),
        "accuracy": accuracy_score(labels, predictions),
    }

In [None]:
class CustomTrainer(Trainer):
    """``Trainer`` whose loss is (optionally class-weighted) cross-entropy.

    Args:
        *args, **kwargs: passed through unchanged to ``Trainer``.
        class_weights: optional per-class weights (sequence or tensor with one
            entry per label) forwarded to ``F.cross_entropy``; ``None`` gives
            the plain unweighted loss.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            # as_tensor also accepts an existing tensor without the copy warning
            # that torch.tensor(tensor) emits.
            self.class_weights = torch.as_tensor(
                class_weights, dtype=torch.float32
            ).to(self.args.device)
        else:
            self.class_weights = None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the cross-entropy loss, plus the model outputs if requested.

        ``num_items_in_batch`` is accepted because newer ``transformers``
        versions pass it to ``compute_loss`` (an override without it raises a
        TypeError). The loss here is a per-batch mean, so it is not used.
        """
        # Read the labels without mutating `inputs`, which belongs to the caller.
        labels = inputs["labels"].long()
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}

        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        # With device_map="auto" the logits can land on a different device than
        # args.device, so align labels and weights with them.
        weight = self.class_weights
        if weight is not None:
            weight = weight.to(logits.device)
        loss = F.cross_entropy(logits, labels.to(logits.device), weight=weight)

        return (loss, outputs) if return_outputs else loss

In [None]:
# Trainer checkpoints go here, and the trained LoRA adapter is saved here after
# training.
output_dir = (
    "../../../results/llama3_results/"
    "classification_head/metrics"
)

In [None]:
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    report_to="none",
    # Optimisation: fixed 1800-step budget, batch size 4 per device.
    learning_rate=1e-4,
    weight_decay=0.01,
    max_steps=1800,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Log and checkpoint every 100 steps. With eval_steps unset, evaluation
    # follows logging_steps. The best checkpoint is restored at the end;
    # without metric_for_best_model, selection is by eval loss.
    # NOTE(review): consider metric_for_best_model="balanced_accuracy".
    logging_steps=100,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
)

In [None]:
# No class_weights are passed, so CustomTrainer uses unweighted cross-entropy.
trainer = CustomTrainer(
    model=classification_wrapper.model,
    tokenizer=classification_wrapper.tokenizer,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
)

In [None]:
train_result = trainer.train()
# Display the aggregate metrics returned by train() and persist them as JSON
# in the Trainer's output directory, so the run leaves a record of them.
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)

In [None]:
# Save the trained adapter. With load_best_model_at_end=True this is the best
# checkpoint's weights.
trainer.model.save_pretrained(save_directory=output_dir)

### Merge the LoRA adapter into the base model and save it

In [None]:
# Reload the saved base classifier for merging. No quantization_config is
# passed, so the weights load unquantized (unlike the 4-bit copy used for
# training). The LoRA adapter is merged into them below.
# NOTE(review): the training model had config.pad_token_id set explicitly from
# the tokenizer; this reload only has it if the saved base config contains it.
# Verify before running batched inference with the merged model.
base_model = AutoModelForSequenceClassification.from_pretrained(
    base_model_path,
    num_labels=9,
    device_map="auto",
    torch_dtype="auto",
    use_cache=False,  # no KV cache needed: this model is only merged and saved, not used for generation
)

In [None]:
# Attach the adapter saved in output_dir to the unquantized base model.
model = PeftModel.from_pretrained(model=base_model, model_id=output_dir)

In [None]:
# Fold the LoRA weights into the base weights and remove the PEFT wrapper,
# leaving a plain transformers model that can be saved as a regular checkpoint.
model = model.merge_and_unload()

In [None]:
# Merged model, saved as a regular checkpoint.
# NOTE(review): this directory is the parent of base_model_path (.../model/base).
# Only the model is written, not the tokenizer; consumers must load the tokenizer
# from a training checkpoint, as tokenizer_path does.
finetuned_model_path = "../../../results/llama3_results/classification_head/model"
model.save_pretrained(save_directory=finetuned_model_path)

### Save the train/eval split for continued training

In [None]:
def save_split(dataset, path):
    """Save ``dataset`` to ``path`` unless it was loaded from that directory.

    ``save_to_disk`` refuses to overwrite the directory a dataset is backed by
    and raises ``PermissionError``. That happens after the "load existing
    split" cells: the split on disk is already the one that was trained on, so
    there is nothing new to save.

    Returns:
        True if the dataset was written, False if the save was skipped.
    """
    target = Path(path).resolve()
    backing_dirs = {Path(f["filename"]).resolve().parent for f in dataset.cache_files}
    if target in backing_dirs:
        return False
    dataset.save_to_disk(path)
    return True


split_dir = "../../../results/llama3_results/classification_head/datasets"
saved = {
    name: save_split(split, f"{split_dir}/{name}")
    for name, split in (("train", train_ds), ("eval", eval_ds))
}
saved