# In [None]:
from typing import Optional, Union, List, Dict
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from transformers.quantization import QuantizationConfig
from datasets import load_dataset
import torch
import evaluate

class AdvancedModelLoader:
    def __init__(
        self,
        model_name_or_path: str,
        quantize_bits: Optional[int] = None,
        num_labels: int = 2,
        task_name: str = "text-classification",
    ):
        self.model_name_or_path = model_name_or_path
        self.quantize_bits = quantize_bits
        self.num_labels = num_labels
        self.task_name = task_name
        self.model = None
        self.tokenizer = None
        self.data_collator = None
        self.compute_metrics = None

    def load_model(self):
        quantization_config = None
        if self.quantize_bits:
            if self.quantize_bits not in [4, 8]:
                raise ValueError("Quantization bits must be either 4 or 8.")
            quantization_config = QuantizationConfig(bits=self.quantize_bits)

        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name_or_path,
            num_labels=self.num_labels,
            quantization_config=quantization_config,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        self.data_collator = DataCollatorWithPadding(self.tokenizer)

        if self.task_name == "text-classification":
            self.compute_metrics = self._compute_metrics_text_classification

    def _compute_metrics_text_classification(self, eval_pred):
        """Compute accuracy and weighted F1 for one evaluation pass.

        Args:
            eval_pred: Pair unpacked as ``(logits, labels)``. ``logits`` is
                passed to ``torch.from_numpy``, so it is presumably a NumPy
                array -- confirm against the caller.

        Returns:
            A dict with the keys ``"accuracy"`` and ``"f1"``.
        """
        accuracy_metric = evaluate.load("accuracy")
        f1_metric = evaluate.load("f1")

        raw_scores, gold = eval_pred
        # The index of the largest score along the last axis is the predicted class.
        predicted = torch.from_numpy(raw_scores).argmax(dim=-1)

        scores = {}
        scores["accuracy"] = accuracy_metric.compute(
            predictions=predicted, references=gold
        )["accuracy"]
        scores["f1"] = f1_metric.compute(
            predictions=predicted, references=gold, average="weighted"
        )["f1"]
        return scores

    def train(self, train_dataset, eval_dataset, training_args: TrainingArguments):
        if self.model is None:
            raise ValueError("Model has not been loaded. Call load_model() first.")

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=self.data_collator,
            compute_metrics=self.compute_metrics,
        )
        trainer.train()

    def evaluate(self, eval_dataset):
        if self.model is None:
            raise ValueError("Model has not been loaded. Call load_model() first.")

        trainer = Trainer(
            model=self.model,
            data_collator=self.data_collator,
            compute_metrics=self.compute_metrics,
        # ) <span class="cursor"></span>
