<a href="https://colab.research.google.com/github/IdanKanat/COVID_NLP_Advanced_DL_Project/blob/main/AdvancedTopicsDL_Project_IdanKanat%26IdoShahar_COVID_NLP_21.8.2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%run ./01_EDA_and_data_preprocessing.ipynb
%run ./02_finetune_WITHOUT_HF_Trainer.ipynb
%run ./03_finetune_with_HF_Trainer.ipynb



ModuleNotFoundError: No module named 'google.colab'

ModuleNotFoundError: No module named 'google.colab'

# **Compression Techniques**

## **Technique (1) - Quantization**

As a model compression technique, **Quantization reduces model size and speeds up inference by converting weights to a lower-precision representation**. The function below applies **dynamic (post-training) quantization to a fine-tuned HF model** (which has gone through the final training phases above), evaluates it on the training and test sets, and saves the quantized version.

In [None]:
# Critical roots
# `basic_drive_path` is not defined anywhere in this notebook - presumably it comes from the upstream
# notebooks executed at the top. If it is missing from the kernel namespace (fresh kernel, or the
# upstream run failed), fall back to the default Google Drive mount point instead of raising NameError.
# USER CAN CHANGE IT IF HE DOESN'T WORK IN DRIVE AND DOWNLOADS FROM DRIVE THE Project_COVID_NLP folder!!
basic_drive_path = globals().get("basic_drive_path", "/content/drive/MyDrive")
project_root = f"{basic_drive_path}/Project_COVID_NLP" # Root project folder
model_root   = f"{project_root}/Model_Weights"

# Define quant_root inside the project, for all quantized weights
quant_root = f"{project_root}/Quantized_Model_Weights"

In [None]:
# Helper function which evaluates model performance given a specific dataset (loader) - train / test:
def evaluate_model(model, loader, device):
    """Run `model` over every batch of `loader` and return macro-averaged classification metrics.

    Args:
        model: callable as model(input_ids=..., attention_mask=...) whose output exposes `.logits`.
        loader: iterable of batches; each batch is indexed by "input_ids", "attention_mask", "labels".
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        dict with the keys "accuracy", "f1", "precision" and "recall".
    """
    model.eval()
    predicted, expected = [], []

    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            targets = batch["labels"].to(device)

            # Predicted class = index of the largest logit along dim 1
            batch_preds = model(input_ids=ids, attention_mask=mask).logits.argmax(dim=1)

            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(targets.cpu().numpy())

    # Accuracy plus macro-averaged F1 / precision / recall (undefined ratios count as 0)
    macro_kwargs = {"average": "macro", "zero_division": 0}
    return {
        "accuracy": accuracy_score(expected, predicted),
        "f1": f1_score(expected, predicted, **macro_kwargs),
        "precision": precision_score(expected, predicted, **macro_kwargs),
        "recall": recall_score(expected, predicted, **macro_kwargs),
    }

In [None]:
# This function quantizes a fine-tuned HF model, evaluates it on training & test sets, compares its performance with the previous model's, and saves the quantized version.
def quantize_evaluate_and_compare(model_name, model_name_dir, best_params):
    """Apply post-training dynamic quantization to a fine-tuned model and compare it with the original.

    The original model is evaluated on the merged train+validation split and on the test split,
    its `torch.nn.Linear` layers are dynamically quantized to int8 on CPU, the quantized model
    (state dict, config and tokenizer) is saved under `quant_root`, and the quantized model is then
    evaluated on the same two splits.

    Args:
        model_name: display name; used as the index of the returned DataFrame.
        model_name_dir: folder name of the fine-tuned weights under `model_root`; it also selects
            the pre-tokenized dataset ("roberta" in the name -> twitter-roberta, otherwise bertweet).
        best_params: dict; only "batch_size" is read here.

    Returns:
        One-row pandas DataFrame with the parameter counts and, for each metric (accuracy, f1,
        precision, recall), the original / quantized train and test values and their drops
        (original - quantized).
    """
    # Define original model path (trained weights) and quantized save path
    model_path     = f"{model_root}/{model_name_dir}"
    quantized_path = f"{quant_root}/{model_name_dir}_quantized"

    # Select the pre-tokenized dataset folder that matches the model family
    if "roberta" in model_name_dir.lower():
        pretokenized_dir = "data/tokenized_twitter_roberta_base"
    else:
        pretokenized_dir = "data/tokenized_bertweet_base"

    # Load model & tokenizer - the original is evaluated on `device`, then moved to CPU for quantization
    model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # safety: correct dtypes + torch output
    ds = load_from_disk(pretokenized_dir)  # pre-tokenized HF DatasetDict stored on disk
    for split in ds:
        ds[split] = ds[split].cast_column("input_ids", Sequence(Value("int64")))
        ds[split] = ds[split].cast_column("attention_mask", Sequence(Value("int64")))  # or "bool"
        ds[split] = ds[split].cast_column("labels", Value("int64"))
        ds[split].set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    # keep dynamic padding (no tokenization here - collator only pads per batch)
    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

    # Merge train + validation: the same data the final model was trained on
    full_train_dataset = concatenate_datasets([ds["train_reduced"], ds["validation"]])
    full_train_dataset = full_train_dataset.shuffle(seed=42)

    # initialize loaders (train & test) from the pretokenized HF dataset.
    # shuffle=False: these loaders are used for evaluation only, and a fixed order guarantees that the
    # original and the quantized model are scored on identically composed (and padded) batches.
    train_loader = DataLoader(
        full_train_dataset, batch_size=best_params["batch_size"], shuffle=False,
        collate_fn=collator, num_workers=4, pin_memory=True,
        persistent_workers=True, prefetch_factor=2
    )
    test_loader = DataLoader(
        ds["test"], batch_size=min(2*best_params["batch_size"], 128), shuffle=False,
        collate_fn=collator, num_workers=4, pin_memory=True,
    )

    # Evaluate the performance of original model - first:
    train_original = evaluate_model(model, train_loader, device)
    test_original = evaluate_model(model, test_loader, device)

    # move original model to CPU for quantization post-evaluation
    model = model.to("cpu")

    # Apply dynamic quantization (Linear layers -> int8)
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    ).to("cpu")

    # Save quantized model's config & weights (+ tokenizer).
    # The config is written explicitly: the state dict alone is not enough to rebuild the architecture.
    os.makedirs(quantized_path, exist_ok=True) # Create directory if it doesn't exist
    torch.save(quantized_model.state_dict(), os.path.join(quantized_path, "pytorch_model.bin"))
    model.config.save_pretrained(quantized_path)
    tokenizer.save_pretrained(quantized_path)

    # Evaluate the performance of quantized model - second (on CPU):
    cpu = torch.device("cpu")
    train_quantized = evaluate_model(quantized_model, train_loader, cpu)
    test_quantized = evaluate_model(quantized_model, test_loader, cpu)

    # Count number of parameters in both models - original & quantized.
    # NOTE(review): dynamically quantized Linear layers presumably keep their int8 weights as packed
    # buffers rather than nn.Parameter objects, so `parameters()` would skip them and `quantized_params`
    # would mostly reflect the non-quantized layers - confirm before reporting it as a size reduction.
    original_params = sum(p.numel() for p in model.parameters())
    quantized_params = sum(p.numel() for p in quantized_model.parameters())

    # Collect results into a one-row DataFrame (parameter statistics first, then each metric)
    row = {
        "original_params": original_params,
        "quantized_params": quantized_params,
        "param_reduction": original_params - quantized_params,
        "param_ratio": quantized_params / original_params,
    }
    for metric in ("accuracy", "f1", "precision", "recall"):
        row[f"train_{metric}_original"] = train_original[metric]
        row[f"test_{metric}_original"] = test_original[metric]
        row[f"train_{metric}_quantized"] = train_quantized[metric]
        row[f"test_{metric}_quantized"] = test_quantized[metric]
        row[f"train_{metric}_drop"] = train_original[metric] - train_quantized[metric]
        row[f"test_{metric}_drop"] = test_original[metric] - test_quantized[metric]

    return pd.DataFrame([row], index=[model_name])

In [None]:
# Quantizing all 4 models with their corresponding batch sizes (typed manually!)
# Each entry maps a display name -> (weights folder name under model_root, batch size).
model_configs = {
    "BERTweet-Base (rec4)": ("best_model_bertweet_base_rec4", 128),
    "BERTweet-Base (rec5 - HF)": ("best_model_bertweet_base_rec5", 64),
    "RoBERTa-Base-Tweet (rec4)": ("best_model_roberta_base_tweet_rec4", 128),
    "RoBERTa-Base-Tweet (rec5 - HF)": ("best_model_roberta_base_tweet_rec5", 128)
}

# One single-row DataFrame per model is collected here
all_results = []

for model_name, (model_name_dir, batch_size) in model_configs.items():
    print(f"\nPost-Training Quantization Results for {model_name}:")
    # Only "batch_size" is needed by quantize_evaluate_and_compare, so a minimal params dict is passed
    results_df = quantize_evaluate_and_compare(model_name, model_name_dir, {"batch_size": batch_size})
    results_df.index.name = "model_name"
    all_results.append(results_df)
    display(results_df)

# Concatenate into one DataFrame
# ignore_index=False keeps the model names as the row index
all_results_df = pd.concat(all_results, ignore_index=False)

In [None]:
# Display quantization results over all 4 models
# NOTE: relies on `all_results_df` produced by the quantization loop in the previous cell.
display(all_results_df)

# Save for future use
# The CSV is written inside quant_root, next to the quantized model folders; index=True keeps the model names.
save_path = f"{quant_root}/quantization_results.csv"
all_results_df.to_csv(save_path, index=True)
print(f"\nAll post-training quantization results saved to: {save_path}")

## **Technique (2) - Pruning**

As a model compression technique, **Pruning reduces model size and speeds up inference by setting "unimportant" weights to 0**. In this project, we implemented **unstructured global pruning - setting a portion** (40% by default) **of the trained model weights with the smallest magnitudes (in absolute value) to 0**. The function below applies it **globally (across ALL LINEAR layers or across ALL layers**, depending on user need) **to a fine-tuned HF model** (which has gone through the final training phases above), i.e. it looks for the portion of weights with the smallest magnitudes (in absolute value) and prunes them - sets them to 0. It then evaluates the pruned model on the training and test sets, and saves the pruned version.

In [None]:
# Critical roots
# `basic_drive_path` is not defined anywhere in this notebook - presumably it comes from the upstream
# notebooks executed at the top. If it is missing from the kernel namespace (fresh kernel, or the
# upstream run failed), fall back to the default Google Drive mount point instead of raising NameError.
# USER CAN CHANGE IT IF HE DOESN'T WORK IN DRIVE AND DOWNLOADS FROM DRIVE THE Project_COVID_NLP folder!!
basic_drive_path = globals().get("basic_drive_path", "/content/drive/MyDrive")
project_root = f"{basic_drive_path}/Project_COVID_NLP" # Root project folder
model_root   = f"{project_root}/Model_Weights"

# Define pruned_root inside the project, for all pruned weights
prune_root = f"{project_root}/Pruned_Model_Weights"

In [None]:
# This function prunes a fine-tuned HF model, evaluates it on training & test sets, compares its performance with the previous model's, and saves the pruned version.
# is_linear = a boolean variable set by the user whether global unstructured pruning of is desired only across the linear layers, or across all model weights. False by default.
def prune_evaluate_and_compare(model_name, model_name_dir, best_params, is_linear = False, amount = 0.4):
    """Globally prune a fine-tuned model by weight magnitude and compare it with the original.

    The original model is evaluated on the merged train+validation split and on the test split.
    A fresh copy is then pruned with global unstructured L1 pruning, saved under `prune_root`
    (model + tokenizer) and evaluated on the same two splits.

    Args:
        model_name: display name; used as the index of the returned DataFrame.
        model_name_dir: folder name of the fine-tuned weights under `model_root`; it also selects
            the pre-tokenized dataset ("roberta" in the name -> twitter-roberta, otherwise bertweet).
        best_params: dict; only "batch_size" is read here.
        is_linear: if True, only `nn.Linear` weights take part in the pruning; otherwise the
            `weight` of every module that owns one does. False by default.
        amount: passed to `prune.global_unstructured`; with the default 0.4, 40% of the selected
            weights (those with the smallest absolute values) are set to 0.

    Returns:
        One-row pandas DataFrame with the parameter counts (`pruned_params` counts non-zero entries)
        and, for each metric (accuracy, f1, precision, recall), the original / pruned train and test
        values and their drops (original - pruned).
    """
    # Define original model path (trained weights) and pruned save path
    model_path  = f"{model_root}/{model_name_dir}"
    pruned_path = f"{prune_root}/{model_name_dir}_pruned"

    # Select the pre-tokenized dataset folder that matches the model family
    if "roberta" in model_name_dir.lower():
        pretokenized_dir = "data/tokenized_twitter_roberta_base"
    else:
        pretokenized_dir = "data/tokenized_bertweet_base"

    # Load model & tokenizer
    model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # safety: correct dtypes + torch output
    ds = load_from_disk(pretokenized_dir)  # pre-tokenized HF DatasetDict stored on disk
    for split in ds:
        ds[split] = ds[split].cast_column("input_ids", Sequence(Value("int64")))
        ds[split] = ds[split].cast_column("attention_mask", Sequence(Value("int64")))  # or "bool"
        ds[split] = ds[split].cast_column("labels", Value("int64"))
        ds[split].set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    # keep dynamic padding (no tokenization here - collator only pads per batch)
    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

    # Merge train + validation: the same data the final model was trained on
    full_train_dataset = concatenate_datasets([ds["train_reduced"], ds["validation"]])
    full_train_dataset = full_train_dataset.shuffle(seed=42)

    # initialize loaders (train & test) from the pretokenized HF dataset.
    # shuffle=False: these loaders are used for evaluation only, and a fixed order guarantees that the
    # original and the pruned model are scored on identically composed (and padded) batches.
    train_loader = DataLoader(
        full_train_dataset, batch_size=best_params["batch_size"], shuffle=False,
        collate_fn=collator, num_workers=4, pin_memory=True,
        persistent_workers=True, prefetch_factor=2
    )
    test_loader = DataLoader(
        ds["test"], batch_size=min(2*best_params["batch_size"], 128), shuffle=False,
        collate_fn=collator, num_workers=4, pin_memory=True,
    )

    # Evaluate the performance of original model - first:
    train_original = evaluate_model(model, train_loader, device)
    test_original = evaluate_model(model, test_loader, device)

    # The original is only needed for its parameter count from here on: take it now and drop the
    # reference so that two full copies of the model are not kept on the device at the same time.
    original_params = sum(p.numel() for p in model.parameters())
    del model

    # Define the pruned model by reloading the original weights
    pruned_model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)

    # Collect the (module, "weight") pairs to be pruned:
    if is_linear:
        # unstructured global pruning only across linear layers
        parameters_to_prune = [(m, "weight") for m in pruned_model.modules() if isinstance(m, nn.Linear)]
    else:
        # unstructured global pruning across every module that owns a real `weight` parameter
        # (a bare hasattr check would also match modules whose weight is None)
        parameters_to_prune = [
            (m, "weight") for m in pruned_model.modules()
            if isinstance(getattr(m, "weight", None), nn.Parameter)
        ]

    # Apply pruning (global, unstructured): the `amount` fraction of the selected weights with the
    # smallest absolute values - ranked jointly across all selected tensors - is set to 0.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=amount
    )

    # Make the pruning permanent: this removes the pruning re-parametrization from each module and
    # leaves a plain `weight` tensor that contains the zeros (the tensor shapes do not change).
    for m, n in parameters_to_prune:
        prune.remove(m, n)

    # Save pruned model's config & weights (+ tokenizer)
    os.makedirs(pruned_path, exist_ok=True) # Create directory if it doesn't exist
    pruned_model.save_pretrained(pruned_path)
    tokenizer.save_pretrained(pruned_path)

    # Evaluate the performance of pruned model - second:
    train_pruned = evaluate_model(pruned_model, train_loader, device)
    test_pruned = evaluate_model(pruned_model, test_loader, device)

    # "Parameters" of the pruned model = entries that are still non-zero
    pruned_params = sum(torch.count_nonzero(p).item() for p in pruned_model.parameters())

    # Collect results into a one-row DataFrame (parameter statistics first, then each metric)
    row = {
        "original_params": original_params,
        "pruned_params": pruned_params,
        "param_reduction": original_params - pruned_params,
        "param_ratio": pruned_params / original_params,
    }
    for metric in ("accuracy", "f1", "precision", "recall"):
        row[f"train_{metric}_original"] = train_original[metric]
        row[f"test_{metric}_original"] = test_original[metric]
        row[f"train_{metric}_pruned"] = train_pruned[metric]
        row[f"test_{metric}_pruned"] = test_pruned[metric]
        row[f"train_{metric}_drop"] = train_original[metric] - train_pruned[metric]
        row[f"test_{metric}_drop"] = test_original[metric] - test_pruned[metric]

    return pd.DataFrame([row], index=[model_name])

In [None]:
# Pruning all 4 models with their corresponding batch sizes (typed manually!), considering ONLY LINEAR layers
# Each entry maps a display name -> (weights folder name under model_root, batch size).
# NOTE: this rebinds the `model_configs` / `all_results` / `all_results_df` names used by the quantization cells.
model_configs = {
    "BERTweet-Base (rec4)": ("best_model_bertweet_base_rec4", 128),
    "BERTweet-Base (rec5 - HF)": ("best_model_bertweet_base_rec5", 64),
    "RoBERTa-Base-Tweet (rec4)": ("best_model_roberta_base_tweet_rec4", 128),
    "RoBERTa-Base-Tweet (rec5 - HF)": ("best_model_roberta_base_tweet_rec5", 128)
}

# One single-row DataFrame per model is collected here
all_results = []

for model_name, (model_name_dir, batch_size) in model_configs.items():
    print(f"\nUnstructured global Pruning Results for {model_name} (considering linear layers only):")
    # is_linear = True restricts the pruning to nn.Linear weights
    results_df = prune_evaluate_and_compare(model_name, model_name_dir, {"batch_size": batch_size}, is_linear = True)
    results_df.index.name = "model_name"
    all_results.append(results_df)
    display(results_df)

# Concatenate into one DataFrame
# ignore_index=False keeps the model names as the row index
all_results_df = pd.concat(all_results, ignore_index=False)

In [None]:
# Display pruning results over all 4 models - "LINEAR CASE"
# NOTE: must run before the "all layers" pruning cell below, which overwrites `all_results_df`.
display(all_results_df)

# Save for future use
# The CSV is written inside prune_root; index=True keeps the model names.
save_path = f"{prune_root}/pruning_results_linear.csv"
all_results_df.to_csv(save_path, index=True)
print(f"\nAll unstructured global pruning results (considering linear layers only) saved to: {save_path} ")

In [None]:
# Pruning all 4 models with their corresponding batch sizes (typed manually!), considering ALL model layers
# NOTE: reuses `model_configs` defined in the linear-only pruning cell above, and overwrites
# `all_results` / `all_results_df` - the linear-case results should already be saved at this point.
all_results = []

for model_name, (model_name_dir, batch_size) in model_configs.items():
    print(f"\nUnstructured global Pruning Results for {model_name} (considering all model layers):")
    # is_linear = False: every module that has a `weight` takes part in the pruning
    results_df = prune_evaluate_and_compare(model_name, model_name_dir, {"batch_size": batch_size}, is_linear = False)
    results_df.index.name = "model_name"
    all_results.append(results_df)
    display(results_df)

# Concatenate into one DataFrame
# ignore_index=False keeps the model names as the row index
all_results_df = pd.concat(all_results, ignore_index=False)

In [None]:
# Display pruning results over all 4 models - "GENERALIZED CASE"
# NOTE: relies on `all_results_df` produced by the "all layers" pruning loop in the previous cell.
display(all_results_df)

# Save for future use
# The CSV is written inside prune_root, under a different file name than the linear-case results.
save_path = f"{prune_root}/pruning_results_generalized.csv"
all_results_df.to_csv(save_path, index=True)
print(f"\nAll unstructured global pruning results (considering all model layers) saved to: {save_path}")

## **Technique (3) - Knowledge-Distillation (KD)**

As a compression technique, **Knowledge Distillation** reduces model size by **training a compact student model** (compact = with far fewer parameters) **to imitate a stronger teacher model by matching the teacher’s soft predictions while still learning from the gold labels**. In our pipeline the fine-tuned teacher from Model_Weights is frozen and evaluated in eval() mode, and the student (e.g., arampacha/roberta-tiny) is optimized with the mixed objective $\mathcal{L} = \alpha \cdot \mathrm{CE}(y, s) + (1-\alpha) \cdot T^{2} \cdot \mathrm{KL}\big(\mathrm{softmax}(t/T) \,\|\, \mathrm{softmax}(s/T)\big)$, where $s$ and $t$ are the student and teacher logits, $T$ is the temperature and $\alpha$ controls the balance between label supervision and teacher guidance (in the code, the student side is passed to the KL term as log_softmax(s/T)). We train on the pre-tokenized datasets, log train and test metrics for both teacher and student to Weights & Biases, and save the best student checkpoint under KD_Model_Weights, along with a CSV of results for later comparison.

In [None]:
# Critical roots
# `basic_drive_path` is not defined anywhere in this notebook - presumably it comes from the upstream
# notebooks executed at the top. If it is missing from the kernel namespace (fresh kernel, or the
# upstream run failed), fall back to the default Google Drive mount point instead of raising NameError.
# USER CAN CHANGE IT IF HE DOESN'T WORK IN DRIVE AND DOWNLOADS FROM DRIVE THE Project_COVID_NLP folder!!
basic_drive_path = globals().get("basic_drive_path", "/content/drive/MyDrive")
project_root = f"{basic_drive_path}/Project_COVID_NLP" # Root project folder
model_root   = f"{project_root}/Model_Weights"

# Define KD_root inside the project, for all KD weights
KD_root = f"{project_root}/KD_Model_Weights"

In [None]:
from transformers import TrainerCallback

class TrainEvalCallback(TrainerCallback):
    """Trainer callback that mirrors per-epoch student metrics to Weights & Biases.

    * `on_evaluate` forwards the Trainer's own `eval_*` metrics as `test/student_*`.
    * `on_epoch_end` evaluates the student on `train_dataset` and logs the result both through the
      Trainer and to W&B as `train/student_*`.

    Args:
        trainer: the Trainer this callback is attached to (used to run and log the extra evaluation).
        train_dataset: dataset evaluated at the end of every epoch.
    """

    def __init__(self, trainer, train_dataset):
        self.trainer = trainer
        self.train_dataset = train_dataset

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # This hook runs after every Trainer.evaluate() call, including the train-set evaluation
        # started from on_epoch_end below (keys prefixed "train_") and any evaluation run with a
        # custom prefix. Only the regular "eval_*" metrics belong under test/student_*, so every
        # other key is skipped; otherwise train metrics would be logged as "test/student_train_*".
        if metrics and wandb.run is not None:
            out = {}
            for k, v in metrics.items():
                if k.startswith("eval_") and isinstance(v, (int, float)):
                    out[f"test/student_{k[5:]}"] = float(v)  # k[5:] strips the "eval_" prefix
            if out:  # nothing to report -> do not emit an empty W&B log entry
                wandb.log(out)
        return control

    def on_epoch_end(self, args, state, control, **kwargs):
        # Compute train metrics once per epoch
        train_metrics = self.trainer.evaluate(
            eval_dataset=self.train_dataset,
            metric_key_prefix="train"  # -> train_loss, train_accuracy, ...
        )

        # 1) HF log keeps the pretty table
        self.trainer.log(train_metrics)

        # 1a) ALSO put 'loss' so the "Training Loss" column isn't "No log"
        if "train_loss" in train_metrics:
            self.trainer.log({"loss": float(train_metrics["train_loss"])})

        # 2) W&B: log under train/student_* (no explicit step)
        if wandb.run is not None:
            out = {}
            for k, v in train_metrics.items():
                if isinstance(v, (int, float)):
                    key = k[6:] if k.startswith("train_") else k  # k[6:] strips the "train_" prefix
                    out[f"train/student_{key}"] = float(v)
            wandb.log(out)
        return control

In [None]:
class DistillationTrainer(Trainer):
    """Trainer whose loss mixes hard-label cross-entropy with a temperature-scaled distillation term.

    loss = alpha * CE(student_logits, labels)
         + (1 - alpha) * T^2 * KL(softmax(teacher_logits / T) || softmax(student_logits / T))

    Args:
        teacher_model: already fine-tuned model; it is put in eval mode and frozen here. Required.
            The teacher is not moved to any device by this class.
        temperature: softmax temperature T used in the distillation term.
        alpha: weight of the cross-entropy term; (1 - alpha) weights the distillation term.
        *args, **kwargs: forwarded unchanged to `Trainer`.

    Raises:
        ValueError: if `teacher_model` is not provided.
    """

    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
        if teacher_model is None:
            # Fail with a clear message instead of an AttributeError on None further down
            raise ValueError("DistillationTrainer requires a `teacher_model`.")
        super().__init__(*args, **kwargs)
        # freeze + eval teacher
        self.teacher = teacher_model
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.temperature = float(temperature)
        self.alpha = float(alpha)

    @torch.no_grad()
    def _teacher_forward(self, **inputs):
        """Run the frozen teacher without gradients; any "labels" entry is dropped first."""
        # teacher never sees labels
        inputs = {k: v for k, v in inputs.items() if k != "labels"}
        return self.teacher(**inputs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined CE + distillation loss (and the student outputs if requested)."""
        # 1) remove labels so student model doesn't compute internal CE
        labels = inputs.pop("labels")
        if labels.dtype != torch.long:
            labels = labels.long()

        # 2) student forward (no labels)
        outputs_student = model(**inputs)

        # 3) teacher forward (eval, no grads - handled inside _teacher_forward)
        outputs_teacher = self._teacher_forward(**inputs)

        # cheap guard: clamp labels into the valid class range. The range is read from the width of
        # the student logits rather than hard-coded, so it follows the number of classes of the model.
        max_label = outputs_student.logits.size(-1) - 1
        if torch.any(labels.lt(0)) or torch.any(labels.gt(max_label)):
            labels = labels.clamp_(0, max_label)

        # 4) KD loss
        t = self.temperature
        loss_ce = F.cross_entropy(outputs_student.logits, labels)
        # F.kl_div takes log-probabilities as input and probabilities as target;
        # the t*t factor is the T^2 scaling of the distillation term.
        loss_kl = F.kl_div(
            F.log_softmax(outputs_student.logits / t, dim=-1),
            F.softmax(outputs_teacher.logits / t, dim=-1),
            reduction="batchmean"
        ) * (t * t)

        loss = self.alpha * loss_ce + (1.0 - self.alpha) * loss_kl
        return (loss, outputs_student) if return_outputs else loss

In [None]:
def distill_evaluate_and_compare(model_name, model_name_dir, best_params, student_model_name, alpha=0.5, temperature=2.0, num_epochs=5):
    """Distill a fine-tuned teacher into a student model, evaluate both, save the student and compare.

    Args:
        model_name: display name of the teacher; used in the W&B project / run names and as the
            index of the returned DataFrame.
        model_name_dir: folder name of the teacher weights under `model_root`; it also selects the
            pre-tokenized dataset ("roberta" in the name -> twitter-roberta, otherwise bertweet).
        best_params: dict; "learning_rate", "weight_decay" and "batch_size" are required,
            "num_layers_finetune" is optional and only recorded in the W&B config.
        student_model_name: checkpoint name / path the student is loaded from.
        alpha: weight of the hard-label cross-entropy term in the distillation loss.
        temperature: softmax temperature of the distillation term.
        num_epochs: number of training epochs.

    Returns:
        One-row pandas DataFrame (index = model_name) with the teacher / student parameter counts
        and, for each metric (accuracy, f1, precision, recall), the teacher and student train / test
        values and their drops (teacher - student). Metrics that are missing are reported as NaN.
    """
    # Paths
    teacher_path = f"{model_root}/{model_name_dir}"
    student_slug = student_model_name.replace("/", "-")
    KD_path      = f"{KD_root}/{model_name_dir}_distilled_student_{student_slug}"
    os.makedirs(KD_path, exist_ok=True)

    # Select correct pretokenized dataset
    if "roberta" in model_name_dir.lower():
        pretokenized_dir = "data/tokenized_twitter_roberta_base"
    else:
        pretokenized_dir = "data/tokenized_bertweet_base"

    # safety: correct dtypes + torch output
    ds = load_from_disk(pretokenized_dir)
    for split in ds:
        ds[split] = ds[split].cast_column("input_ids", Sequence(Value("int64")))
        ds[split] = ds[split].cast_column("attention_mask", Sequence(Value("int64")))
        ds[split] = ds[split].cast_column("labels", Value("int64"))
        ds[split].set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    # Merge train + validation for final training
    full_train_dataset = concatenate_datasets([ds["train_reduced"], ds["validation"]]).shuffle(seed=42)
    test_dataset = ds["test"]

    # Load Teacher + tokenizer
    teacher = AutoModelForSequenceClassification.from_pretrained(teacher_path).to(device)
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad = False
    tokenizer = AutoTokenizer.from_pretrained(teacher_path)

    # Load Student + tokenizer (vocab aligned).
    # The student head gets as many classes as the teacher (read from the teacher's config instead of
    # a hard-coded 5), because the distillation loss compares the two logit vectors element-wise.
    student = AutoModelForSequenceClassification.from_pretrained(
        student_model_name, num_labels=teacher.config.num_labels, ignore_mismatched_sizes=True
    ).to(device)
    # The data was tokenized with the teacher's tokenizer, so the student reuses it and its
    # embedding matrix is resized to that vocabulary size.
    student_tokenizer = tokenizer
    student.resize_token_embeddings(len(student_tokenizer))

    # keep dynamic padding (no tokenization here - collator only pads per batch)
    collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8, return_tensors="pt")

    # W&B init (project name includes teacher name + student); close any run that is still open first
    if wandb.run is not None:
        wandb.finish()
    wandb.init(
        project=f"KD_{model_name}__student_{student_slug}_21.8.2025",
        entity="idoshahar96-tel-aviv-university",
        config={
            "learning_rate": best_params["learning_rate"],
            "weight_decay": best_params["weight_decay"],
            "batch_size": best_params["batch_size"],
            "num_layers_finetune": best_params.get("num_layers_finetune", None),
            "teacher_path": teacher_path,
            "student_model": student_model_name,
            "alpha": alpha,
            "temperature": temperature,
            "epochs": num_epochs,
        },
        name=f"KD_{model_name}__{student_slug}",
        reinit=True
    )

    # TrainingArguments for the Hugging Face Trainer
    # NOTE(review): the per-epoch evaluation set passed to the trainer below is the TEST split and
    # load_best_model_at_end picks the checkpoint by its accuracy, so the checkpoint is selected on
    # the test data and the student's reported test metrics are likely optimistic.
    training_args = TrainingArguments(
        output_dir=KD_path, # where checkpoints will be saved
        per_device_train_batch_size=best_params["batch_size"],
        per_device_eval_batch_size=best_params["batch_size"],
        learning_rate=best_params["learning_rate"],
        weight_decay=best_params["weight_decay"],
        num_train_epochs=num_epochs,  # number of training epochs
        eval_strategy="epoch",        # evaluate at the end of each epoch
        save_strategy="epoch",        # save a checkpoint at the end of each epoch
        logging_strategy="epoch",     # log metrics at the end of each epoch
        load_best_model_at_end=True,  # reload the best checkpoint (based on metric_for_best_model)
        metric_for_best_model="accuracy", # optimize w.r.t accuracy
        greater_is_better=True,
        save_total_limit=1,
        remove_unused_columns=False,
        label_names=["labels"],
        report_to="wandb",
    )

    # Trainer
    trainer_distill = DistillationTrainer(
        model=student,
        teacher_model=teacher,
        args=training_args,
        train_dataset=full_train_dataset,
        eval_dataset=test_dataset,
        data_collator=collator,
        compute_metrics=compute_metrics,
        temperature=temperature,
        alpha=alpha,
    )

    # per-epoch: TEST (from on_evaluate) then TRAIN (from on_epoch_end)
    trainer_distill.add_callback(TrainEvalCallback(trainer_distill, full_train_dataset))

    # TRAIN - KD
    trainer_distill.train()
    print("\nDistillation complete. Student trained & best model loaded.")

    # Final metrics (TEST first, then TRAIN) for teacher & student
    def _eval_with(model_to_eval, dataset, prefix):
        # Throw-away plain Trainer: evaluates a model with the same args / collator / metrics
        tmp = Trainer(
            model=model_to_eval,
            args=training_args,
            eval_dataset=dataset,
            data_collator=collator,
            compute_metrics=compute_metrics,
        )
        return tmp.evaluate(metric_key_prefix=prefix)

    # Teacher metrics
    teacher_test   = _eval_with(teacher, test_dataset,        "teacher_test")
    teacher_train  = _eval_with(teacher, full_train_dataset, "teacher_train")

    # Student metrics
    student_test   = trainer_distill.evaluate(eval_dataset=test_dataset,        metric_key_prefix="student_test")
    student_train  = trainer_distill.evaluate(eval_dataset=full_train_dataset, metric_key_prefix="student_train")

    # Log to W&B (TEST first, then TRAIN) so panels order naturally
    if wandb.run is not None:
        def log_group(prefix, who, d):
            # Logs every numeric entry of `d` as "<prefix>/<who>_<key>", where a leading
            # "<who>_" found in the metric key is stripped first.
            out = {}
            for k, v in d.items():
                if isinstance(v, (int, float)):
                    key = k.split(f"{who}_", 1)[-1] if f"{who}_" in k else k
                    out[f"{prefix}/{who}_{key}"] = float(v)
            wandb.log(out)

        log_group("test",  "teacher", teacher_test)
        log_group("test",  "student", student_test)
        log_group("train", "teacher", teacher_train)
        log_group("train", "student", student_train)

    # Count params & save student model (best-epoch)
    teacher_params = sum(p.numel() for p in teacher.parameters())
    student_params = sum(p.numel() for p in student.parameters())
    print(f"Teacher params: {teacher_params:,}")
    print(f"Student params: {student_params:,}")
    trainer_distill.model.save_pretrained(KD_path)
    student_tokenizer.save_pretrained(KD_path)
    print(f"Best student model saved to {KD_path}")
    wandb.finish()

    # Build output DataFrame (params statistics, then metrics, then drops in metrics while comparing the models)
    def g(d, k):
        # Best-effort lookup: a missing or non-numeric metric becomes NaN instead of raising
        v = d.get(k, float("nan"))
        try:
            return float(v)
        except Exception:
            return float("nan")

    def pack(metric_name, t_train, t_test, s_train, s_test):
        # Six columns for one metric: teacher/student x train/test, plus the two drops
        tt = g(t_train, f"teacher_train_{metric_name}")
        te = g(t_test,  f"teacher_test_{metric_name}")
        st = g(s_train, f"student_train_{metric_name}")
        se = g(s_test,  f"student_test_{metric_name}")
        return {
            f"teacher_train_{metric_name}": tt,
            f"teacher_test_{metric_name}":  te,
            f"student_train_{metric_name}": st,
            f"student_test_{metric_name}":  se,
            f"drop_train_{metric_name}":    tt - st,  # teacher - student
            f"drop_test_{metric_name}":     te - se,
        }

    row = {
        "teacher_params": teacher_params,
        "student_params": student_params,
        "param_reduction": teacher_params - student_params,
        "param_ratio": student_params / teacher_params,
    }
    for m in ["accuracy", "f1", "precision", "recall"]:
        row.update(pack(m, teacher_train, teacher_test, student_train, student_test))

    results = pd.DataFrame([row], index=[model_name])
    return results

In [None]:
# Performing Knowledge-Distillation (KD) over all 4 models with their corresponding best-params (typed manually!)
# Each entry maps a display name -> (teacher weights folder name under model_root, best-params dict).
# NOTE: distill_evaluate_and_compare reads only 'learning_rate', 'weight_decay', 'batch_size' and
# (for W&B logging) 'num_layers_finetune'; 'patience' and 'lr_scheduler_type' are not used by it.
# NOTE: this rebinds the `model_configs` / `all_results` / `all_results_df` names used by the earlier sections.
model_configs = {
    "BERTweet-Base (rec4)": ("best_model_bertweet_base_rec4", {'learning_rate': 0.0001184412471705182, 'weight_decay': 1.2699696348040995e-05, 'patience': 10, 'batch_size': 128, 'num_layers_finetune': 3}),
    "BERTweet-Base (rec5 - HF)": ("best_model_bertweet_base_rec5", {'learning_rate': 7.668855564109297e-05, 'weight_decay': 4.8978169582912055e-06, 'patience': 9, 'batch_size': 64, 'num_layers_finetune': 3, 'lr_scheduler_type': 'linear'}),
    "RoBERTa-Base-Tweet (rec4)": ("best_model_roberta_base_tweet_rec4", {'learning_rate': 0.0003834791389042033, 'weight_decay': 2.88286253103848e-06, 'patience': 7, 'batch_size': 128, 'num_layers_finetune': 3}),
    "RoBERTa-Base-Tweet (rec5 - HF)": ("best_model_roberta_base_tweet_rec5", {'learning_rate': 0.0000860370374400373, 'weight_decay': 0.00008459884214639005, 'patience': 10, 'batch_size': 128, 'num_layers_finetune': 3, 'lr_scheduler_type': 'polynomial'})
}

# One single-row DataFrame per teacher is collected here
all_results = []
student_model_name = "arampacha/roberta-tiny" # Example student model - RoBERTa-Tiny (truly)

for model_name, (model_name_dir, best_params) in model_configs.items():
    # "(5 epochs)" in the message matches the default num_epochs=5 of distill_evaluate_and_compare,
    # which is not overridden in the call below
    print(f"\nKnowledge-Distillation (KD) Results for TEACHER: {model_name}, STUDENT: {student_model_name} (5 epochs):")
    results_df = distill_evaluate_and_compare(model_name, model_name_dir, best_params, student_model_name=student_model_name)
    results_df.index.name = "model_name"
    all_results.append(results_df)
    display(results_df)

# Concatenate into one DataFrame
# ignore_index=False keeps the teacher names as the row index
all_results_df = pd.concat(all_results, ignore_index=False)

In [None]:
# Add student name column (redundant because the student model name would be clear from the CSV file name,
# but it makes the results table self-explanatory).
# The label is derived from the `student_model_name` that was actually used for the KD run in the
# previous cell ("/" -> "-"), so it cannot drift from it; for "arampacha/roberta-tiny" this gives
# "arampacha-roberta-tiny".
student_label = student_model_name.replace("/", "-")
all_results_df["student_model_name"] = student_label

# Reorder so "student_model_name" is right after the index.
# The column is assigned once, BEFORE the reordering: assigning it again on the reordered copy is
# redundant and can make pandas emit a SettingWithCopyWarning.
cols = ["student_model_name"] + [c for c in all_results_df.columns if c != "student_model_name"]
all_results_df = all_results_df[cols]

# Display Knowledge-Distillation (KD) results over all 4 models
display(all_results_df)

# File-name-friendly version of the student name ("/" and "-" -> "_"), e.g. "arampacha_roberta_tiny"
student_model_name_for_csv = student_model_name.replace("/", "_").replace("-", "_")

# Save for future use
save_path = f"{KD_root}/KD_results_{student_model_name_for_csv}.csv"
all_results_df.to_csv(save_path, index=True)
print(f"\nAll Knowledge-Distillation (KD) results saved to: {save_path}")