<a href="https://colab.research.google.com/github/daisysong76/AI--Machine--learning/blob/main/Quantization%2C_Pruning%2C_and_Distillation_Optimizing_BERT_for_Intent_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch datasets


In [None]:
import torch
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

# Load a sample intent classification dataset (banking77 from Hugging Face).
# Only the first 5% of the training split is used, for demo purposes.
dataset = load_dataset("banking77", split="train[:5%]")

# Size both classification heads from the dataset's own label set instead of
# hard-coding a count; a head smaller than the label set makes cross-entropy fail.
num_labels = dataset.features["label"].num_classes

# Load BERT (teacher) and DistilBERT (student) models.
# NOTE(review): the teacher's classification head is freshly initialised here;
# for distillation to be meaningful the teacher should be fine-tuned on this
# intent task before it is used to produce soft targets.
teacher_model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=num_labels
)
student_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=num_labels
)

# Tokenize with the student's tokenizer. It emits only input_ids and
# attention_mask, which both the DistilBERT student and the BERT teacher accept.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def _tokenize(batch):
    """Encode a batch of raw texts to fixed-length model inputs."""
    # Fixed-length padding lets the Trainer's default collator stack examples.
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=64)


# map() adds the token columns and keeps "text"/"label", so later cells that
# read those fields keep working; the Trainer drops columns the model ignores.
dataset = dataset.map(_tokenize, batched=True)


In [None]:
from torch.nn import functional as F

def distillation_loss(student_outputs, teacher_outputs, labels, alpha=0.5, temperature=2.0):
    """Blend a soft-target distillation loss with hard-label cross-entropy.

    Args:
        student_outputs: model output exposing a ``logits`` tensor; the last
            dimension is treated as the class dimension (presumably
            ``(batch, num_classes)`` -- confirm for other callers).
        teacher_outputs: teacher model output with ``logits`` of the same shape.
        labels: ground-truth class indices for ``F.cross_entropy``.
        alpha: weight of the distillation term; ``1 - alpha`` weights the
            hard-label term.
        temperature: softening temperature applied to both sets of logits.

    Returns:
        Scalar tensor ``alpha * KL * T**2 + (1 - alpha) * CE``.
    """
    # Temperature-softened log-probabilities for teacher and student.
    soft_teacher_log_probs = F.log_softmax(teacher_outputs.logits / temperature, dim=-1)
    soft_student_log_probs = F.log_softmax(student_outputs.logits / temperature, dim=-1)

    # KL(teacher || student). F.kl_div treats its target as plain probabilities
    # unless log_target=True; passing log-probabilities without the flag yields
    # NaN. The T**2 factor keeps the soft-target gradient scale comparable to
    # the hard-label term as the temperature changes.
    distill_loss = F.kl_div(
        soft_student_log_probs,
        soft_teacher_log_probs,
        reduction="batchmean",
        log_target=True,
    ) * (temperature ** 2)

    # Cross-entropy with ground truth labels (on the unsoftened logits).
    hard_loss = F.cross_entropy(student_outputs.logits, labels)

    return alpha * distill_loss + (1 - alpha) * hard_loss

# Define a custom Trainer for distillation
class DistillationTrainer(Trainer):
    """Trainer that distils the module-level ``teacher_model`` into ``model``."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the distillation loss for one batch.

        Args:
            model: the student being trained.
            inputs: batch dict prepared by the Trainer; its ``"labels"`` entry
                is removed and the rest is fed to both student and teacher.
            return_outputs: when True, also return the student's outputs.
            **kwargs: extra arguments passed by newer ``transformers`` releases
                (e.g. ``num_items_in_batch``). They are accepted for signature
                compatibility and not used; the loss is already a batch mean.

        Returns:
            ``loss`` or ``(loss, student_outputs)``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)

        # The Trainer only places the student on the training device; keep the
        # teacher next to the batch, and in eval mode so dropout does not
        # perturb the soft targets.
        teacher = teacher_model.to(labels.device)
        teacher.eval()
        with torch.no_grad():
            teacher_outputs = teacher(**inputs)

        loss = distillation_loss(outputs, teacher_outputs, labels)
        return (loss, outputs) if return_outputs else loss

# Set up Trainer with distillation loss.
# No eval_dataset is passed to the Trainer, so no evaluation strategy is set:
# requesting per-epoch evaluation without an eval_dataset makes the Trainer
# fail (the older `evaluation_strategy` name is also deprecated in favour of
# `eval_strategy`).
training_args = TrainingArguments(
    output_dir="./distilled_student",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
)
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()


In [None]:
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """Permanently zero the smallest-magnitude weights of every ``nn.Linear``.

    Each Linear layer is pruned on its own: ``amount`` of its weight entries
    with the lowest L1 magnitude are set to zero. Per ``torch.nn.utils.prune``,
    a float is a fraction in [0, 1] and an int is an absolute count. The
    pruning re-parametrisation is then removed, so the zeros are baked into
    ``weight`` and no mask or hook is left behind. The model is modified in
    place; nothing is returned.
    """
    linear_layers = [m for m in model.modules() if isinstance(m, torch.nn.Linear)]
    for layer in linear_layers:
        prune.l1_unstructured(layer, name="weight", amount=amount)
        # Fold weight_orig * weight_mask back into a plain `weight` parameter.
        prune.remove(layer, "weight")

# Apply pruning: zero 30% of each Linear layer's weights in the student, in place.
# NOTE(review): no fine-tuning step follows the pruning in this notebook.
prune_model(model=student_model, amount=0.3)


In [None]:
# Apply dynamic int8 quantization to the student's Linear layers.
# PyTorch's eager-mode dynamic quantization targets CPU backends, and the
# Trainer may have moved the student onto a GPU, so move it back to CPU
# (in place) and switch to inference mode before converting.
student_model = student_model.cpu().eval()
quantized_model = torch.quantization.quantize_dynamic(
    student_model, {torch.nn.Linear}, dtype=torch.qint8
)


In [None]:
def evaluate(model, dataset, tokenizer=None):
    """Print and return the classification accuracy of ``model`` on ``dataset``.

    Examples are scored one at a time.

    Args:
        model: a sequence-classification model whose output exposes a
            ``logits`` tensor; its last dimension is treated as the classes.
        dataset: an iterable of examples, each a mapping with a ``"text"``
            string and an integer ``"label"``.
        tokenizer: callable used to encode ``example["text"]``. When omitted,
            the tokenizer for ``model.config.name_or_path`` is loaded with
            ``transformers.AutoTokenizer``.

    Returns:
        Accuracy as a fraction in [0, 1].

    Raises:
        ValueError: if ``dataset`` yields no examples.
    """
    if tokenizer is None:
        # No tokenizer was supplied; derive one from the model's checkpoint.
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model.config.name_or_path)

    model.eval()
    # Feed inputs to wherever the model's weights live; fall back to CPU if
    # the model exposes no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct, total = 0, 0
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for example in dataset:
            encoded = tokenizer(
                example["text"], return_tensors="pt", padding=True, truncation=True
            )
            inputs = {key: value.to(device) for key, value in encoded.items()}
            labels = torch.tensor([example["label"]], device=device)
            predicted = model(**inputs).logits.argmax(dim=-1)
            correct += (predicted == labels).sum().item()
            total += labels.numel()

    if total == 0:
        raise ValueError("evaluate() received an empty dataset")

    accuracy = correct / total
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Test the quantized model on held-out data. `dataset` is the training subset
# the student was distilled on, so scoring it would overstate accuracy.
# Shuffle before subsampling so the subset does not depend on the split's row order.
eval_dataset = load_dataset("banking77", split="test").shuffle(seed=42).select(range(500))
evaluate(quantized_model, eval_dataset)
