<a href="https://colab.research.google.com/github/daisysong76/AI--Machine--learning/blob/main/Task_Aware_Distillation_and_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Task-Aware Distillation and Pruning with a Transformer Model (BERT)
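Here, we'll use Hugging Face’s DistilBERT as the student in a knowledge-distillation setup and torch.nn.utils.prune for pruning. The running example distills a BERT teacher into DistilBERT for a text classification task (the IMDB dataset is loaded as a sample; a custom dataset can be substituted) and then prunes the student.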

1. Distillation: Train a DistilBERT model on a task-specific dataset using knowledge distillation from a pre-trained BERT model.
2. Pruning: Prune the smaller DistilBERT model for further efficiency.
Step 1: Set Up Libraries

In [1]:
!pip install torch transformers datasets


In [None]:
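import torch
from transformers import (AutoTokenizer, BertForSequenceClassification,
                          DistilBertForSequenceClassification, Trainer, TrainingArguments)
from datasets import load_dataset

# Load a sample dataset for text classification
dataset = load_dataset("imdb")

# Tokenize the text. The DistilBERT tokenizer shares BERT's uncased vocabulary,
# so the same input_ids can be fed to both the teacher and the student.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
dataset = dataset.map(lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256), batched=True)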


Load the Teacher Model (BERT) and Student Model (DistilBERT)

In [None]:
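# Load pre-trained BERT as the teacher model.
# Note: "bert-base-uncased" has no trained classification head, so the teacher should be
# fine-tuned on the task first; otherwise its logits carry no useful signal for the student.
teacher_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
teacher_model.eval()  # the teacher is only used for inference

# Load DistilBERT as the student model
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)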


Define Distillation Training Loop

Define the distillation loss that the Trainer will optimize. The student model learns by matching its temperature-softened output distribution to the teacher's, while also fitting the ground-truth labels.
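
To make the idea concrete, here is a minimal pure-Python sketch of a distillation loss for a single example. The helper names (`softmax`, `kl_divergence`, `toy_distillation_loss`) and the toy logits are hypothetical and chosen only for illustration; `alpha` and `temperature` play the same roles as in the notebook's loss function. It is a sketch of the general technique, not the code the Trainer runs.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def toy_distillation_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=2.0):
    """Weighted sum of the soft (teacher-matching) and hard (label) losses."""
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    # Scale by T^2 so the soft-loss gradients stay comparable across temperatures
    soft_loss = kl_divergence(soft_teacher, soft_student) * temperature ** 2
    # Standard cross-entropy against the ground-truth class (temperature 1)
    hard_loss = -math.log(softmax(student_logits)[label])
    return alpha * soft_loss + (1 - alpha) * hard_loss

teacher_logits = [4.0, -2.0]  # hypothetical confident teacher
student_logits = [1.0, 0.5]   # hypothetical undecided student

loss = toy_distillation_loss(student_logits, teacher_logits, label=0)
matched = toy_distillation_loss(teacher_logits, teacher_logits, label=0)
print(f"student vs. teacher: {loss:.4f}")
print(f"teacher vs. itself:  {matched:.4f}")
```

When the student reproduces the teacher's logits exactly, the soft term vanishes and only the label cross-entropy remains, which is why the second value is much smaller than the first.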

In [None]:
from torch.nn import functional as F

# Define a custom distillation loss function
def distillation_loss(student_outputs, teacher_outputs, labels, alpha=0.5, temperature=2.0):
    # F.kl_div expects log-probabilities as input and (by default) probabilities as target
    soft_teacher_probs = F.softmax(teacher_outputs.logits / temperature, dim=-1)
    soft_student_log_probs = F.log_softmax(student_outputs.logits / temperature, dim=-1)

    # KL divergence between the softened distributions, scaled by temperature^2
    distill_loss = F.kl_div(soft_student_log_probs, soft_teacher_probs, reduction="batchmean") * (temperature ** 2)

    # Cross-entropy with ground truth labels
    hard_loss = F.cross_entropy(student_outputs.logits, labels)

    # Combine losses with weighting factor alpha
    return alpha * distill_loss + (1 - alpha) * hard_loss


Define a custom Trainer class to use the distillation loss:

In [None]:
class DistillationTrainer(Trainer):
    # **kwargs absorbs extra arguments (such as num_items_in_batch) passed by newer transformers releases
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)

        # Keep the teacher on the same device as the student (a no-op after the first call)
        teacher_model.to(outputs.logits.device)

        # Get teacher model predictions for distillation; no gradients are needed
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

        # Calculate the distillation loss
        loss = distillation_loss(outputs, teacher_outputs, labels)
        return (loss, outputs) if return_outputs else loss


Fine-Tune the Distilled Model

In [None]:
training_args = TrainingArguments(
    output_dir="./distilled_model",
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
)

trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"]
)

# Start distillation training
trainer.train()


Prune the Fine-Tuned Model

Once the student model has been fine-tuned, we’ll apply pruning. Here, we’ll prune every linear layer in DistilBERT with L1 unstructured pruning, which zeroes out the 30% of weights with the smallest absolute value in each layer.
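
As an illustration of what magnitude-based pruning does to a single layer, here is a minimal pure-Python sketch on a flat list of weights. The function names (`prune_by_magnitude`, `sparsity`) and the weight values are hypothetical; the sketch mirrors the idea of zeroing the smallest-magnitude fraction and is not a reproduction of the PyTorch implementation.

```python
def prune_by_magnitude(weights, amount=0.3):
    """Return a copy of `weights` with the smallest-magnitude fraction set to zero."""
    n_prune = round(amount * len(weights))
    if n_prune == 0:
        return list(weights)
    # Indices sorted by absolute value, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_zero = set(order[:n_prune])
    return [0.0 if i in to_zero else w for i, w in enumerate(weights)]

def sparsity(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)

# Hypothetical weights of one small layer
weights = [0.8, -0.05, 0.3, -1.2, 0.01, 0.4, -0.6, 0.02, 0.9, -0.1]
pruned = prune_by_magnitude(weights, amount=0.3)
print(pruned)
print(f"sparsity: {sparsity(pruned):.0%}")
```

The pruned list has the same length as the original: the weights are zeroed rather than deleted, so the layer keeps its shape.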

In [None]:
import torch.nn.utils.prune as prune

# Function to apply pruning to all linear layers
def apply_pruning(model, pruning_amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=pruning_amount)
            prune.remove(module, 'weight')  # Make pruning permanent

# Apply pruning to the student model
apply_pruning(student_model, pruning_amount=0.3)


Save and Evaluate the Pruned Model
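After pruning, save the model and evaluate it on the test set to measure how much performance the pruning costs. Note that unstructured pruning zeroes weights but keeps the tensors dense, so the saved checkpoint does not get smaller unless sparse storage or compression is applied.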
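
Without a metrics function, the evaluation results typically contain the loss but no accuracy. A minimal sketch of an accuracy metric is shown below, assuming it would be passed as `compute_metrics=compute_metrics` when constructing the trainer; the function body is an illustration written in plain Python so that it works on nested lists as well as arrays, and the sample logits and labels are hypothetical.

```python
def compute_metrics(eval_pred):
    """Compute accuracy from a (logits, labels) pair."""
    logits, labels = eval_pred
    # Predicted class = index of the largest logit in each row
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

# Hypothetical logits for three reviews and their true labels
sample_logits = [[0.1, 0.9], [2.0, -1.0], [0.3, 0.4]]
sample_labels = [1, 0, 0]
print(compute_metrics((sample_logits, sample_labels)))
```

Comparing this accuracy before and after pruning gives a direct view of the trade-off between sparsity and task performance.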

In [None]:
# Save the pruned model
student_model.save_pretrained("./pruned_distilled_model")

# Evaluate the pruned model
eval_results = trainer.evaluate()
print(f"Pruned Model Evaluation Results: {eval_results}")
