<a href="https://colab.research.google.com/github/joachimasare/tinyml-ondevice-sentiment/blob/main/Channel_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup and Dependencies

In [None]:
# imports needed for pytorch tinyBERT project

!pip install scikit-learn
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
from torch.nn import CrossEntropyLoss, MSELoss
from sklearn.metrics import accuracy_score
import copy
from typing import Union, List

import csv
import logging
import os
import random
import sys



## Resources and Links


*   [TinyBERT github](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT)
*   [BERT-base code](https://github.com/google-research/bert)
*   [TinyBERT pretrained model](https://huggingface.co/huawei-noah/TinyBERT_General_4L_312D)
*   [GLUE dataset download scripts (used here for SST-2)](https://github.com/nyu-mll/GLUE-baselines)







In [None]:
# Download the GLUE SST-2 sentiment dataset.
# The cwd is /content at this point, so the repo is cloned to /content/GLUE-baselines and the
# data is extracted under /content (later cells read it from data_dir='/content/SST-2').

!git clone https://github.com/nyu-mll/GLUE-baselines.git
!python GLUE-baselines/download_glue_data.py --data_dir /content --tasks SST

Cloning into 'GLUE-baselines'...
remote: Enumerating objects: 891, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 891 (delta 1), reused 3 (delta 1), pack-reused 886 (from 1)[K
Receiving objects: 100% (891/891), 1.48 MiB | 6.55 MiB/s, done.
Resolving deltas: 100% (610/610), done.
Downloading and extracting SST...
	Completed!


In [None]:
# Clone the TinyBERT reference implementation (huawei-noah/Pretrained-Language-Model).
# Its TinyBERT/ folder provides the `transformer` package and `task_distill.py` that are
# imported later; its requirements are installed in the next cell.

!git clone https://github.com/huawei-noah/Pretrained-Language-Model.git

Cloning into 'Pretrained-Language-Model'...
remote: Enumerating objects: 1253, done.[K
remote: Counting objects: 100% (280/280), done.[K
remote: Compressing objects: 100% (161/161), done.[K
remote: Total 1253 (delta 173), reused 120 (delta 119), pack-reused 973 (from 1)[K
Receiving objects: 100% (1253/1253), 29.72 MiB | 17.21 MiB/s, done.
Resolving deltas: 100% (540/540), done.


In [None]:
# Work from inside the TinyBERT folder so `transformer.modeling` and `task_distill` can be
# imported from the cwd. NOTE: every later relative path (e.g. the BERT-base and TinyBERT
# downloads) therefore lands under /content/Pretrained-Language-Model/TinyBERT.
%cd Pretrained-Language-Model/TinyBERT
!pip install -r requirements.txt

/content/Pretrained-Language-Model/TinyBERT
Collecting boto3 (from -r requirements.txt (line 4))
  Downloading boto3-1.35.81-py3-none-any.whl.metadata (6.7 kB)
Collecting botocore<1.36.0,>=1.35.81 (from boto3->-r requirements.txt (line 4))
  Downloading botocore-1.35.81-py3-none-any.whl.metadata (5.7 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from boto3->-r requirements.txt (line 4))
  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)
Collecting s3transfer<0.11.0,>=0.10.0 (from boto3->-r requirements.txt (line 4))
  Downloading s3transfer-0.10.4-py3-none-any.whl.metadata (1.7 kB)
Downloading boto3-1.35.81-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.2/139.2 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading botocore-1.35.81-py3-none-any.whl (13.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m81.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jmespath-1.0.1-py3-none-any.whl (2

In [None]:
# Fix the random seeds used by Python, NumPy and PyTorch so runs are reproducible.
SEED = 42

# Ask cuDNN for deterministic kernels (only matters on GPU).
torch.backends.cudnn.deterministic = True

for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)

# Kept as the last expression: displays the seeded default torch.Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7a333e587150>

In [None]:
# Install extra tooling; pip's stdout is discarded so only errors (stderr) are shown.
# torchprofile provides `profile_macs`, imported in the next cell.
# NOTE(review): fast-pytorch-kmeans is not referenced in the cells visible here — confirm it is still needed.
print('Installing torchprofile...')
!pip install torchprofile 1>/dev/null
print('Installing fast-pytorch-kmeans...')
! pip install fast-pytorch-kmeans 1>/dev/null
print('All required packages have been successfully installed!')

Installing torchprofile...
Installing fast-pytorch-kmeans...
All required packages have been successfully installed!


In [None]:
from torchprofile import profile_macs
from torch import nn

In [None]:
# Download Google's BERT-base cased checkpoint (L-12_H-768_A-12: 12 layers, hidden size 768,
# 12 attention heads); its config is used to build the teacher model below.
# Because of the earlier %cd, it is extracted under /content/Pretrained-Language-Model/TinyBERT.
# The config is copied to config.json because the TinyBERT loader looks for that file name.

!wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip
!unzip cased_L-12_H-768_A-12.zip
!cp cased_L-12_H-768_A-12/bert_config.json cased_L-12_H-768_A-12/config.json # must rename bert_config to config

BERT_BASE_DIR = 'cased_L-12_H-768_A-12'

--2024-12-14 20:05:44--  https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.69.207, 64.233.181.207, 142.251.183.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.69.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 404261442 (386M) [application/zip]
Saving to: ‘cased_L-12_H-768_A-12.zip’


2024-12-14 20:05:48 (88.7 MB/s) - ‘cased_L-12_H-768_A-12.zip’ saved [404261442/404261442]

Archive:  cased_L-12_H-768_A-12.zip
   creating: cased_L-12_H-768_A-12/
  inflating: cased_L-12_H-768_A-12/bert_model.ckpt.meta  
  inflating: cased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001  
  inflating: cased_L-12_H-768_A-12/vocab.txt  
  inflating: cased_L-12_H-768_A-12/bert_model.ckpt.index  
  inflating: cased_L-12_H-768_A-12/bert_config.json  


In [None]:
# Clone the general-distillation TinyBERT checkpoint (4 layers, hidden size 312) from Hugging Face.
# Because of the earlier %cd it lands inside the TinyBERT folder, which is the absolute path below.

!git clone https://huggingface.co/huawei-noah/TinyBERT_General_4L_312D

STUDENT_CONFIG_DIR = '/content/Pretrained-Language-Model/TinyBERT/TinyBERT_General_4L_312D'

Cloning into 'TinyBERT_General_4L_312D'...
remote: Enumerating objects: 24, done.[K
remote: Total 24 (delta 0), reused 0 (delta 0), pack-reused 24 (from 1)[K
Unpacking objects: 100% (24/24), 111.20 KiB | 3.83 MiB/s, done.
Filtering content: 100% (2/2), 114.58 MiB | 23.45 MiB/s, done.


In [None]:
def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    Calculate the model size in bits.

    Same signature as the redefinition in the Pruning section, so whichever definition is
    active, callers get consistent behaviour.

    :param model: module whose parameters are counted
    :param data_width: #bits per element (e.g. 32 for fp32)
    :param count_nonzero_only: only count nonzero weights (useful after pruning)
    :return: (#counted parameter elements) * data_width, as a Python int
    """
    num_elements = 0
    for param in model.parameters():
        # int(...) keeps the running total a Python int rather than a tensor.
        num_elements += int(param.count_nonzero()) if count_nonzero_only else param.numel()
    return num_elements * data_width


# Unit constants for converting the bit counts returned above.
Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [None]:
from transformer.modeling import TinyBertForPreTraining, BertModel, TinyBertForSequenceClassification

In [None]:
# Set up the student and teacher models whose sizes are compared in the next cell.

STUDENT_CONFIG_DIR = '/content/Pretrained-Language-Model/TinyBERT/TinyBERT_General_4L_312D'
BERT_BASE_DIR = '/content/Pretrained-Language-Model/TinyBERT/cased_L-12_H-768_A-12'

# Teacher: BERT-base architecture built from its config.json.
# NOTE(review): `from_scratch` presumably initialises weights from the config only (no checkpoint
# is loaded) — fine for the parameter-count comparison, but confirm before distilling from it.
teacher_model = BertModel.from_scratch(BERT_BASE_DIR)

# Student: pre-trained TinyBERT with a classification head for the binary SST-2 task.
num_labels = 2
student_model = TinyBertForSequenceClassification.from_pretrained(STUDENT_CONFIG_DIR, num_labels=num_labels)

In [None]:
# Compare the parameter-based sizes (fp32 by default) of student and teacher, reported in MiB.

student_model_size, teacher_model_size = (get_model_size(m) for m in (student_model, teacher_model))

for model_label, size_in_bits in (("Student", student_model_size), ("Teacher", teacher_model_size)):
    print(f"{model_label} model size: ", size_in_bits / MiB, "MiB")

Student model size:  55.661231994628906 MiB
Teacher model size:  413.1708984375 MiB


In [None]:
# Display the student's module hierarchy to see which layers are available for pruning.
student_model

TinyBertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=312, out_features=312, bias=True)
              (key): Linear(in_features=312, out_features=312, bias=True)
              (value): Linear(in_features=312, out_features=312, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=312, out_features=312, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=

In [None]:
# List every nn.Linear in the student together with its weight shape (out_features, in_features).
linear_layers = [
    (module_name, module)
    for module_name, module in student_model.named_modules()
    if isinstance(module, nn.Linear)
]
for module_name, module in linear_layers:
    print(module_name)
    print(module.weight.data.shape)

bert.encoder.layer.0.attention.self.query
torch.Size([312, 312])
bert.encoder.layer.0.attention.self.key
torch.Size([312, 312])
bert.encoder.layer.0.attention.self.value
torch.Size([312, 312])
bert.encoder.layer.0.attention.output.dense
torch.Size([312, 312])
bert.encoder.layer.0.intermediate.dense
torch.Size([1200, 312])
bert.encoder.layer.0.output.dense
torch.Size([312, 1200])
bert.encoder.layer.1.attention.self.query
torch.Size([312, 312])
bert.encoder.layer.1.attention.self.key
torch.Size([312, 312])
bert.encoder.layer.1.attention.self.value
torch.Size([312, 312])
bert.encoder.layer.1.attention.output.dense
torch.Size([312, 312])
bert.encoder.layer.1.intermediate.dense
torch.Size([1200, 312])
bert.encoder.layer.1.output.dense
torch.Size([312, 1200])
bert.encoder.layer.2.attention.self.query
torch.Size([312, 312])
bert.encoder.layer.2.attention.self.key
torch.Size([312, 312])
bert.encoder.layer.2.attention.self.value
torch.Size([312, 312])
bert.encoder.layer.2.attention.output.dense

In [None]:
# Dump every learnable parameter tensor of the student with its shape.
for param_name, tensor in student_model.named_parameters():
    print(f"{param_name} {tensor.shape}")

bert.embeddings.word_embeddings.weight torch.Size([30522, 312])
bert.embeddings.position_embeddings.weight torch.Size([512, 312])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 312])
bert.embeddings.LayerNorm.weight torch.Size([312])
bert.embeddings.LayerNorm.bias torch.Size([312])
bert.encoder.layer.0.attention.self.query.weight torch.Size([312, 312])
bert.encoder.layer.0.attention.self.query.bias torch.Size([312])
bert.encoder.layer.0.attention.self.key.weight torch.Size([312, 312])
bert.encoder.layer.0.attention.self.key.bias torch.Size([312])
bert.encoder.layer.0.attention.self.value.weight torch.Size([312, 312])
bert.encoder.layer.0.attention.self.value.bias torch.Size([312])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([312, 312])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([312])
bert.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([312])
bert.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([312])
bert.encoder

## Evaluation Setup

In [None]:
# Plain data containers used while fine-tuning and evaluating.

class InputExample:
    """One raw (untokenised) example for simple sequence classification.

    Args:
        guid: unique id for the example.
        text_a: untokenised text of the first sequence; for single-sequence tasks this is
            the only sequence that must be given.
        text_b: optional untokenised second sequence; only needed for sequence-pair tasks.
        label: optional label string; given for train and dev examples, not for test examples.
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label


class InputFeatures:
    """Model-ready features for a single example.

    Stores the given values unchanged as attributes of the same names; ``seq_length`` is
    optional and defaults to None.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id, seq_length=None):
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        self.seq_length = seq_length
        self.label_id = label_id

In [None]:
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the ``get_*`` hooks; ``_read_tsv`` is a shared reading helper.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a UTF-8 tab-separated file into a list of rows (each a list of str).

        With the default ``quotechar=None`` the csv module disables quoting, so quote
        characters inside sentences are kept verbatim instead of being parsed.
        The file is opened with ``newline=""`` as the csv module documentation requires.
        """
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

class Sst2Processor(DataProcessor):
    """Reads the GLUE version of the SST-2 TSV files into ``InputExample`` objects."""

    def get_train_examples(self, data_dir):
        """Examples from ``<data_dir>/train.tsv`` (guids prefixed with ``train``)."""
        train_path = os.path.join(data_dir, "train.tsv")
        return self._create_examples(self._read_tsv(train_path), "train")

    def get_dev_examples(self, data_dir):
        """Examples from ``<data_dir>/dev.tsv`` (guids prefixed with ``dev``)."""
        dev_path = os.path.join(data_dir, "dev.tsv")
        return self._create_examples(self._read_tsv(dev_path), "dev")

    def get_aug_examples(self, data_dir):
        """Examples from ``<data_dir>/train_aug.tsv`` (guids prefixed with ``aug``)."""
        aug_path = os.path.join(data_dir, "train_aug.tsv")
        return self._create_examples(self._read_tsv(aug_path), "aug")

    def get_labels(self):
        """The two label strings that appear in the TSV label column."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn TSV rows into InputExamples.

        Row 0 is the header and is skipped; every other row is read as
        ``[sentence, label]``. The guid is ``"<set_type>-<row index>"``.
        """
        return [
            InputExample(guid="%s-%s" % (set_type, row_idx), text_a=row[0], text_b=None, label=row[1])
            for row_idx, row in enumerate(lines)
            if row_idx > 0
        ]


In [None]:
def simple_accuracy(preds, labels):
    """Fraction of positions where ``preds`` equals ``labels`` (element-wise comparison)."""
    matches = preds == labels
    return matches.mean()


def compute_metrics(task_name, preds, labels):
    """Metric dict for ``task_name``; every task is currently scored by accuracy only."""
    assert len(preds) == len(labels)
    accuracy = simple_accuracy(preds, labels)
    return {"acc": accuracy}

# evaluation function based on task_distill.py --do_eval
def evaluate_tinybert(model, task_name, eval_dataloader,
            device, output_mode, eval_labels, num_labels):
    """Evaluate a TinyBERT sequence model on a labelled dataloader.

    Mirrors the ``--do_eval`` path of TinyBERT's ``task_distill.py``.

    Args:
        model: model called as ``model(input_ids, segment_ids, input_mask)`` that returns a
            3-tuple whose first element is the logits.
        task_name: forwarded to ``compute_metrics``.
        eval_dataloader: yields ``(input_ids, input_mask, segment_ids, label_ids, seq_lengths)``
            batches, in the same order as ``eval_labels``.
        device: device the batch tensors are moved to.
        output_mode: ``"classification"`` or ``"regression"``.
        eval_labels: tensor of gold labels for the whole evaluation set.
        num_labels: number of classes (used to reshape logits for classification).

    Returns:
        The ``compute_metrics`` dict plus ``"eval_loss"`` (mean of the per-batch losses).

    Raises:
        ValueError: if ``output_mode`` is unknown or the dataloader yields no batches.
    """
    # Validate up front: previously an unknown mode surfaced as a NameError on tmp_eval_loss.
    if output_mode == "classification":
        loss_fct = CrossEntropyLoss()
    elif output_mode == "regression":
        loss_fct = MSELoss()
    else:
        raise ValueError(f"Unknown output mode: {output_mode}")

    eval_loss = 0.0
    nb_eval_steps = 0
    batch_logits = []  # collected per batch, concatenated once (avoids O(n^2) np.append)

    model.eval()
    for batch_ in tqdm(eval_dataloader, desc="Evaluating"):
        batch_ = tuple(t.to(device) for t in batch_)
        input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch_
        with torch.no_grad():
            logits, _, _ = model(input_ids, segment_ids, input_mask)

            if output_mode == "classification":
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
            else:
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        batch_logits.append(logits.detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError("eval_dataloader yielded no batches")

    eval_loss = eval_loss / nb_eval_steps

    preds = np.concatenate(batch_logits, axis=0)
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    else:
        preds = np.squeeze(preds)
    result = compute_metrics(task_name, preds, eval_labels.cpu().numpy())
    result['eval_loss'] = eval_loss

    return result


# Output mode passed to evaluate_tinybert / train_tinybert; SST-2 is a binary classification task.
output_mode = "classification"


In [None]:
from task_distill import convert_examples_to_features, get_tensor_data
from torch.utils.data import SequentialSampler

In [None]:
# Build the SST-2 dev-set dataloader and load a fresh (not yet fine-tuned) student model.

# The general TinyBERT checkpoint has a 30522-entry WordPiece vocabulary (see the
# word_embeddings shape printed above), which is the size of the *uncased* BERT vocab,
# so text must be lower-cased for its tokens to be found in that vocabulary.
do_lower_case = True
data_dir = '/content/SST-2'
processor = Sst2Processor()
label_list = processor.get_labels()
num_labels = len(label_list)
max_seq_length = 128
eval_batch_size = 32
task_name = "sst2"
output_mode = "classification"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = TinyBertForSequenceClassification.from_pretrained(STUDENT_CONFIG_DIR, num_labels=num_labels)
student_model.to(device)

tokenizer = BertTokenizer.from_pretrained(STUDENT_CONFIG_DIR, do_lower_case=do_lower_case)
eval_examples = processor.get_dev_examples(data_dir)
eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, output_mode)
eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
# Keep the dev order fixed: evaluate_tinybert compares predictions to eval_labels positionally.
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)


In [None]:
def evaluate_model(model, dataloader, device):
    """Compute the classification accuracy of ``model`` over ``dataloader``.

    Args:
        model: TinyBERT-style classifier called as
            ``model(input_ids, segment_ids, attention_mask)`` — the same argument order used by
            ``evaluate_tinybert`` and ``train_tinybert``. It may return either a logits tensor
            or a tuple/list whose first element is the logits.
        dataloader: yields ``(input_ids, attention_mask, segment_ids, labels, ...)`` batches;
            elements after the fourth (e.g. sequence lengths) are ignored.
        device: device the inputs are moved to.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if the dataloader yields no examples.
    """
    model.eval()  # disable dropout for evaluation
    predicted_batches = []
    label_batches = []

    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in dataloader:
            # Ignore anything after the labels (e.g. seq_lengths from get_tensor_data).
            input_ids, attention_masks, segment_ids, labels = batch[:4]

            input_ids = input_ids.to(device)
            attention_masks = attention_masks.to(device)
            segment_ids = segment_ids.to(device)

            # Token types go second and the attention mask third, matching the model calls
            # elsewhere in this notebook.
            outputs = model(input_ids, segment_ids, attention_masks)
            logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs

            predicted_batches.append(torch.argmax(logits, dim=1).cpu())
            label_batches.append(labels.cpu())

    if not predicted_batches:
        raise ValueError("dataloader yielded no batches")

    predictions = torch.cat(predicted_batches)
    true_labels = torch.cat(label_batches)
    num_examples = true_labels.numel()
    if num_examples == 0:
        raise ValueError("dataloader yielded no examples")

    # Integer count / integer total gives an exact float64 ratio.
    num_correct = int((predictions == true_labels).sum().item())
    return num_correct / num_examples

## Finetuning

Note: run the Evaluation Setup section first — this section uses the functions and processors defined there.

In [None]:
def train_tinybert(
    model,
    task_name,
    train_dataloader,
    eval_dataloader,
    device,
    output_mode,
    num_labels,
    eval_labels,
    optimizer=None,
    scheduler=None,
    epochs=3,
    scheduler_step_per_batch=True,
):
    """
    Fine-tune a TinyBERT model, recording training and validation metrics per epoch.

    After every epoch the model is evaluated with ``evaluate_tinybert`` and the results are
    appended to ``model.history`` (created on first use, extended on later calls).

    Args:
        model: TinyBERT model called as ``model(input_ids, segment_ids, input_mask)`` and
            returning a 3-tuple whose first element is the logits.
        task_name: Name of the task (forwarded to metric computation).
        train_dataloader: yields ``(input_ids, input_mask, segment_ids, label_ids, seq_lengths)``.
        eval_dataloader: DataLoader for evaluation data (same batch layout).
        device: Device to train on (e.g., 'cpu' or 'cuda').
        output_mode: Output mode for the task ('classification' or 'regression').
        num_labels: Number of labels for classification tasks.
        eval_labels: Ground truth labels for the evaluation set.
        optimizer: Optimizer for training. Defaults to
            ``torch.optim.AdamW(lr=5e-5, eps=1e-6, weight_decay=0.0)``, which mirrors the
            defaults of the deprecated ``transformers.AdamW``.
        scheduler: Learning rate scheduler (optional).
        epochs: Number of training epochs.
        scheduler_step_per_batch: If True (default), ``scheduler.step()`` is called after every
            optimizer step, as step-based schedules such as
            ``get_linear_schedule_with_warmup`` expect. Pass False for epoch-based schedulers.

    Returns:
        model: The fine-tuned model with a `history` attribute.

    Raises:
        ValueError: if ``output_mode`` is unknown or ``train_dataloader`` yields no batches.
    """
    # Validate the mode and build the loss once, instead of on every batch.
    if output_mode == "classification":
        loss_fct = CrossEntropyLoss()
    elif output_mode == "regression":
        loss_fct = MSELoss()
    else:
        raise ValueError(f"Unknown output mode: {output_mode}")

    # Move the model before building a default optimizer over its parameters.
    model.to(device)

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

    # Initialize or extend model's history attribute
    if not hasattr(model, 'history'):
        model.history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        nb_train_steps = 0
        correct_predictions = 0
        total_predictions = 0

        print(f"Epoch {epoch + 1}/{epochs}")

        for batch in tqdm(train_dataloader, desc="Training"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, _seq_lengths = batch

            optimizer.zero_grad()  # gradients would otherwise accumulate across steps

            logits, _, _ = model(input_ids, segment_ids, input_mask)

            if output_mode == "classification":
                loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                preds = torch.argmax(logits, dim=1)
                correct_predictions += (preds == label_ids).sum().item()
                total_predictions += label_ids.size(0)
            else:
                loss = loss_fct(logits.view(-1), label_ids.view(-1))

            loss.backward()
            optimizer.step()
            if scheduler is not None and scheduler_step_per_batch:
                scheduler.step()

            total_loss += loss.item()
            nb_train_steps += 1

        if nb_train_steps == 0:
            raise ValueError("train_dataloader yielded no batches")

        if scheduler is not None and not scheduler_step_per_batch:
            scheduler.step()

        # Compute training metrics
        avg_train_loss = total_loss / nb_train_steps
        train_accuracy = correct_predictions / total_predictions if output_mode == "classification" else None
        print(f"Training loss: {avg_train_loss:.4f}")
        if train_accuracy is not None:
            print(f"Training accuracy: {train_accuracy:.4f}")

        # Evaluate the model
        eval_result = evaluate_tinybert(
            model, task_name, eval_dataloader, device, output_mode, eval_labels, num_labels
        )
        avg_val_loss = eval_result['eval_loss']
        val_accuracy = eval_result['acc']

        print(f"Validation loss: {avg_val_loss:.4f}")
        print(f"Validation accuracy: {val_accuracy:.4f}")

        # Update the model's history
        model.history['train_loss'].append(avg_train_loss)
        model.history['val_loss'].append(avg_val_loss)
        if train_accuracy is not None:
            model.history['train_acc'].append(train_accuracy)
        model.history['val_acc'].append(val_accuracy)

    return model

In [None]:
# Build the SST-2 train and dev dataloaders, fine-tune the student, and save its weights.
from torch.utils.data import RandomSampler

# The general TinyBERT checkpoint has a 30522-entry WordPiece vocabulary (see the
# word_embeddings shape printed above), which is the size of the *uncased* BERT vocab,
# so text must be lower-cased for its tokens to be found in that vocabulary.
do_lower_case = True
data_dir = '/content/SST-2'
processor = Sst2Processor()
label_list = processor.get_labels()
num_labels = len(label_list)
max_seq_length = 128
eval_batch_size = 32
train_batch_size = 32
task_name = "sst2"
output_mode = "classification"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model = TinyBertForSequenceClassification.from_pretrained(STUDENT_CONFIG_DIR, num_labels=num_labels)
student_model.to(device)

tokenizer = BertTokenizer.from_pretrained(STUDENT_CONFIG_DIR, do_lower_case=do_lower_case)
eval_examples = processor.get_dev_examples(data_dir)
eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, output_mode)
eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
# Dev order stays fixed: evaluate_tinybert compares predictions to eval_labels positionally.
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)

train_examples = processor.get_train_examples(data_dir)
train_features = convert_examples_to_features(train_examples, label_list, max_seq_length, tokenizer, output_mode)
train_data, train_labels = get_tensor_data(output_mode, train_features)
# Shuffle training batches (seeded via torch.manual_seed above) instead of using file order.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)

# only fine tune for 1 epoch - overfits fast
student_model = train_tinybert(
    student_model,
    task_name,
    train_dataloader,
    eval_dataloader,
    device,
    output_mode,
    num_labels,
    eval_labels,
    optimizer=None,
    scheduler=None,
    epochs=1
)

# SAVE MODEL
model_path = '/content/tinyBert_sst2.pt'
torch.save(student_model.state_dict(), model_path)



Epoch 1/1


Training: 100%|██████████| 2105/2105 [02:35<00:00, 13.51it/s]


Training loss: 0.2796
Training accuracy: 0.8867


Evaluating: 100%|██████████| 28/28 [00:00<00:00, 42.09it/s]


Validation loss: 0.2928
Validation accuracy: 0.8945


# Pruning

## Pruning Functions

In [None]:
def get_sparsity(tensor: torch.Tensor) -> float:
    """
    Fraction of zero entries in ``tensor``:
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    nonzero_fraction = float(tensor.count_nonzero()) / tensor.numel()
    return 1 - nonzero_fraction


def get_model_sparsity(model: nn.Module) -> float:
    """
    Calculate the sparsity of all parameters of ``model`` taken together:
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements

    Counts are accumulated as Python ints rather than tensors. A model without any
    parameters is reported as 0.0 instead of raising ZeroDivisionError.
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += int(param.count_nonzero())
        num_elements += param.numel()
    if num_elements == 0:
        return 0.0
    return 1 - num_nonzeros / num_elements

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    Calculate the total number of parameters of ``model``.

    :param count_nonzero_only: only count nonzero weights
    :return: the parameter count, always as a Python int (matching the annotation)
    """
    if count_nonzero_only:
        # count_nonzero() yields a 0-d tensor; convert so the total is a plain int.
        return sum(int(param.count_nonzero()) for param in model.parameters())
    return sum(param.numel() for param in model.parameters())


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    Model size in bits: (number of counted parameters) x (bits per element).

    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    num_params = get_num_parameters(model, count_nonzero_only)
    return data_width * num_params

## Structured Pruning

In [None]:
from sklearn.decomposition import PCA

def prune_embeddings(model: nn.Module, prune_ratios: Union[float, List[float]]) -> nn.Module:
    """
    Shrink the hidden dimension of the embedding block of a BERT-style model.

    Replaces ``word_embeddings``, ``position_embeddings`` and ``token_type_embeddings`` with
    narrower ``nn.Embedding`` layers and slices the embedding ``LayerNorm`` to the same width.
    The input model is not modified; a deep copy is pruned and returned.

    NOTE: ``prune_ratios`` is the fraction of the hidden dimension that is KEPT
    (new_dim = round(dim * prune_ratios)). This is the opposite of ``get_num_channels_to_keep``,
    which keeps ``1 - prune_ratio``. The convention is left unchanged for existing callers.

    NOTE(review): word and position embeddings are reduced with two independent PCA fits, the
    token-type table by column slicing, and LayerNorm by top-|gamma| selection. The four results
    therefore do not share one coordinate system, yet the model sums them. Confirm this is
    intended, or project all tables with a single shared basis.

    Args:
        model: model exposing ``model.bert.embeddings`` (e.g. TinyBERT).
        prune_ratios: a single number in (0, 1], the fraction of the embedding dimension to keep.
            Per-layer lists are not supported for the embedding block.

    Returns:
        The pruned copy of the model. Its new embedding layers stay trainable and live on
        the same device/dtype as the originals.

    Raises:
        TypeError: if ``prune_ratios`` is a list/tuple.
        ValueError: if ``prune_ratios`` is not in (0, 1].
    """
    if isinstance(prune_ratios, (list, tuple)):
        raise TypeError("prune_embeddings expects a single float ratio; per-layer lists are not supported")
    keep_ratio = float(prune_ratios)
    if not 0.0 < keep_ratio <= 1.0:
        raise ValueError(f"prune_ratios must be in (0, 1], got {prune_ratios}")

    model = copy.deepcopy(model)  # Prevent overwriting the original model
    embeddings = model.bert.embeddings

    def _new_dim(dim: int) -> int:
        # Width kept after pruning; same rounding for every table so the widths agree.
        return int(round(dim * keep_ratio))

    def _trainable_embedding(weights, like: nn.Embedding) -> nn.Embedding:
        # from_pretrained freezes by default; keep the layer trainable like the original, and
        # preserve its device, dtype and padding_idx.
        return nn.Embedding.from_pretrained(
            torch.tensor(weights, dtype=like.weight.dtype, device=like.weight.device),
            freeze=False,
            padding_idx=like.padding_idx,
        )

    word_layer = embeddings.word_embeddings
    position_layer = embeddings.position_embeddings
    token_type_layer = embeddings.token_type_embeddings

    word_weights = word_layer.weight.detach().cpu().numpy()
    position_weights = position_layer.weight.detach().cpu().numpy()
    token_type_weights = token_type_layer.weight.detach().cpu().numpy()

    # Reduce the word and position tables with PCA.
    reduced_word = PCA(n_components=_new_dim(word_weights.shape[1])).fit_transform(word_weights)
    reduced_position = PCA(n_components=_new_dim(position_weights.shape[1])).fit_transform(position_weights)
    # The token-type table has very few rows, which caps the number of PCA components,
    # so keep its leading columns instead.
    reduced_token_type = token_type_weights[:, :_new_dim(token_type_weights.shape[1])]

    embeddings.word_embeddings = _trainable_embedding(reduced_word, word_layer)
    embeddings.position_embeddings = _trainable_embedding(reduced_position, position_layer)
    embeddings.token_type_embeddings = _trainable_embedding(reduced_token_type, token_type_layer)

    # LayerNorm: keep the channels with the largest |weight|, in their original order.
    layer_norm = embeddings.LayerNorm
    new_layernorm_dim = _new_dim(layer_norm.weight.shape[0])
    _, idx_to_keep = torch.topk(layer_norm.weight.detach().abs(), k=new_layernorm_dim)
    idx_to_keep, _ = torch.sort(idx_to_keep)

    layer_norm.weight = nn.Parameter(layer_norm.weight.detach().index_select(0, idx_to_keep))
    layer_norm.bias = nn.Parameter(layer_norm.bias.detach().index_select(0, idx_to_keep))
    # Keep normalized_shape consistent with the pruned width.
    layer_norm.normalized_shape = (new_layernorm_dim,)

    return model


In [None]:
def get_num_channels_to_keep(num_channels: int, prune_ratio: float) -> int:
    """Number of channels that survive pruning a ``prune_ratio`` fraction of ``num_channels``.

    The kept fraction is ``1 - prune_ratio``; the product is rounded with Python's ``round``
    (exact .5 ties go to the nearest even integer).
    """
    keep_fraction = 1 - prune_ratio
    return int(round(keep_fraction * num_channels))

@torch.no_grad()
def prune_attention_channels(model: nn.Module, prune_ratios: Union[float, List[float]]) -> nn.Module:
    """
    Prune attention channels in the Transformer layers of the model.

    Args:
        model: The language model to prune (e.g., TinyBERT).
        prune_ratios: A single float or a list of floats specifying the prune ratio per layer.

    Returns:
        The pruned model.
    """
    model = copy.deepcopy(model)  # Prevent overwriting the original model

    transformer_layers = model.bert.encoder.layer  # Assuming the model follows BERT's architecture
    n_layers = len(transformer_layers)
    print("Transformer Layers")
    print(transformer_layers)

    # Ensure prune_ratios is a list
    if isinstance(prune_ratios, float):
        prune_ratios = [prune_ratios] * n_layers
    else:
        assert len(prune_ratios) == n_layers, "Length of prune_ratios must match number of layers"


    # Prune channels by selecting the corresponding weights and biases
    def prune_linear_layer(layer, idx_to_keep, dim=0):
        new_weight = torch.index_select(layer.weight.data, 0, idx_to_keep)
        new_weight = torch.index_select(new_weight, 1, idx_to_keep)
        if layer.bias is not None:
            # print("bias in prune_linear_layer: ", layer.bias.shape)
            new_bias = torch.index_select(layer.bias.data, 0, idx_to_keep)

        else:
            new_bias = None
        new_layer = nn.Linear(new_weight.size(1), new_weight.size(0), bias=layer.bias is not None)
        new_layer.weight.data = new_weight.clone().detach()
        if new_bias is not None:
            new_layer.bias.data = new_bias.clone().detach()
        return new_layer

    def prune_LayerNorm(layer, idx_to_keep, dim=0):
        # print("IDX linear: ", idx_to_keep)
        # print("weight in prune_linear_layer: ", layer.weight.shape)

        # Check if weight has more than one dimension before trying to select along dim 1
        if layer.weight.dim() > 1:
            new_weight = torch.index_select(layer.weight.data, 0, idx_to_keep)
            new_weight = torch.index_select(new_weight, 1, idx_to_keep)
        else:  # If it's 1-dimensional (like in LayerNorm), select only along dim 0
            new_weight = torch.index_select(layer.weight.data, 0, idx_to_keep)

        if layer.bias is not None:
            # print("bias in prune_linear_layer: ", layer.bias.shape)
            new_bias = torch.index_select(layer.bias.data, 0, idx_to_keep)
        else:
            new_bias = None

        # Determine input and output features based on pruned weight dimensions
        in_features = new_weight.size(1) if new_weight.dim() > 1 else new_weight.size(0) # Handle 1D case
        out_features = new_weight.size(0)

        new_layer = nn.Linear(in_features, out_features, bias=layer.bias is not None)
        new_layer.weight.data = new_weight.clone().detach()
        if new_bias is not None:
            new_layer.bias.data = new_bias.clone().detach()
        return new_layer

    def prune_intermediate_layer(layer, idx_to_keep, dim=0):
        new_weight = torch.index_select(layer.dense.weight.data, dim, idx_to_keep)

        if layer.dense.bias is not None:
            new_bias = torch.index_select(layer.dense.bias.data, 0, idx_to_keep)
        else:
            new_bias = None
        # for output
        new_layer = nn.Linear(new_weight.size(1), new_weight.size(0), bias=layer.dense.bias is not None)
        new_layer.weight.data = new_weight.clone().detach()
        if new_bias is not None:
            new_layer.bias.data = new_bias.clone().detach()
        return new_layer

    def prune_out(layer, n_keep):
      old_input_dim = layer.in_features
      old_output_dim = layer.out_features
      new_layer = nn.Linear(n_keep, old_output_dim)
      with torch.no_grad():
        if n_keep <= old_input_dim:
          new_layer.weight[:, :n_keep] = layer.weight[:, :n_keep]
          new_layer.bias = layer.bias
        else:
          print("New input dimension is larger than the old one. Reinitializing weights.")
      return new_layer


    ## Prune Encoder Layers
    for layer_idx, prune_ratio in enumerate(prune_ratios):
        layer = transformer_layers[layer_idx]
        attention = layer.attention.self
        print("Layer", layer_idx)

        # Get the number of attention channels, number of heads, and head dimensions
        num_channels = attention.query.weight.shape[0]
        num_heads = attention.num_attention_heads
        head_dim = attention.attention_head_size

        # Get number of channels to keep after pruning
        n_keep = get_num_channels_to_keep(num_channels, prune_ratio)
        assert n_keep > 0, "After pruning, at least one attention channel must remain"

        # Compute importance of each head (e.g., using the norm of the weights)
        q_weight = attention.query.weight.view(num_channels, num_channels, -1)
        k_weight = attention.key.weight.view(num_channels, num_channels, -1)
        v_weight = attention.value.weight.view(num_channels, num_channels, -1)

        # Sum norms across Q, K, V weights for each head
        head_importance = (q_weight.norm(dim=(1, 2)) + k_weight.norm(dim=(1, 2)) + v_weight.norm(dim=(1, 2)))

        # Get indices of heads to keep
        _, idx = torch.sort(head_importance, descending=True)
        idx_to_keep = idx[:n_keep]
        idx_to_keep_sorted, _ = torch.sort(idx_to_keep)

        # Prune Query, Key, Value linear layers
        attention.query = prune_linear_layer(attention.query, idx_to_keep_sorted, dim=0)
        attention.key = prune_linear_layer(attention.key, idx_to_keep_sorted, dim=0)
        attention.value = prune_linear_layer(attention.value, idx_to_keep_sorted, dim=0)

        # Prune Self-Output layer
        layer.attention.output.dense = prune_linear_layer(layer.attention.output.dense, idx_to_keep_sorted, dim=1)

        # Prune LayerNorm layer
        layer.attention.output.LayerNorm = prune_LayerNorm(layer.attention.output.LayerNorm, idx_to_keep_sorted, dim=1)

        # Prune Intermediate layer
        layer.intermediate = prune_intermediate_layer(layer.intermediate, idx_to_keep_sorted, dim=1)

        # Prune Output layer
        layer.output = prune_intermediate_layer(layer.output, idx_to_keep_sorted, dim=0)

        # Update attention_head_size
        attention.attention_head_size = n_keep

    ## Prune Pooler layer
    model.bert.pooler.dense = prune_linear_layer(model.bert.pooler.dense, idx_to_keep_sorted, dim=1) # dim isn't used here actually

    ## Prune Classifier layer
    model.classifier = prune_out(model.classifier, n_keep)

    ## Prune Fit Dense layer
    model.fit_dense = prune_out(model.fit_dense, n_keep)

    return model


In [None]:
# Pruning configuration: fractions to REMOVE (20% of attention capacity / FFN neurons).
# ffn_prune_ratio only feeds the disabled FFN step further down.
attention_prune_ratio = ffn_prune_ratio = 0.2

# Step 1: shrink the embeddings. prune_embeddings takes the fraction to KEEP
# (0.8 turns the 312-wide embeddings into 250-wide ones).
# NOTE(review): this narrows the embedding output; confirm the encoder layers accept that width.
pruned_model = prune_embeddings(student_model, 0.8)
# Step 2: prune channels inside the attention blocks of the embedding-pruned model.
pruned_model = prune_attention_channels(pruned_model, attention_prune_ratio)

# Disabled steps, kept for reference:
# pruned_model = prune_ffn_neurons(pruned_model, ffn_prune_ratio)
# pruned_model = monkey_patch_attention_output(pruned_model, device)

# Move to the target device; the returned module is displayed as the cell output.
pruned_model.to(device)
# outputs = pruned_model(input_ids.to(device), attention_mask=attention_mask.to(device), token_type_ids=segment_ids.to(device))

Transformer Layers
ModuleList(
  (0-3): 4 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=312, out_features=312, bias=True)
        (key): Linear(in_features=312, out_features=312, bias=True)
        (value): Linear(in_features=312, out_features=312, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=312, out_features=312, bias=True)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=312, out_features=1200, bias=True)
    )
    (output): BertOutput(
      (dense): Linear(in_features=1200, out_features=312, bias=True)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
Layer 0
Layer 1
Layer 2
Layer 3


TinyBertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 250)
      (position_embeddings): Embedding(512, 250)
      (token_type_embeddings): Embedding(2, 250)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=250, out_features=250, bias=True)
              (key): Linear(in_features=250, out_features=250, bias=True)
              (value): Linear(in_features=250, out_features=250, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=250, out_features=250, bias=True)
              (LayerNorm): Linear(in_features=250, out_features=250, bias=True)
              (dropout): D

In [None]:
# Display the pruned model's architecture to check the new layer widths.
pruned_model

TinyBertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 250)
      (position_embeddings): Embedding(512, 250)
      (token_type_embeddings): Embedding(2, 250)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=250, out_features=250, bias=True)
              (key): Linear(in_features=250, out_features=250, bias=True)
              (value): Linear(in_features=250, out_features=250, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=250, out_features=250, bias=True)
              (LayerNorm): Linear(in_features=250, out_features=250, bias=True)
              (dropout): D

In [None]:
# List every parameter with its shape to confirm the pruned widths line up across modules.
for name, param in pruned_model.named_parameters():
    print(f"{name}: {param.shape}")

bert.embeddings.word_embeddings.weight: torch.Size([30522, 250])
bert.embeddings.position_embeddings.weight: torch.Size([512, 250])
bert.embeddings.token_type_embeddings.weight: torch.Size([2, 250])
bert.embeddings.LayerNorm.weight: torch.Size([250])
bert.embeddings.LayerNorm.bias: torch.Size([250])
bert.encoder.layer.0.attention.self.query.weight: torch.Size([250, 250])
bert.encoder.layer.0.attention.self.query.bias: torch.Size([250])
bert.encoder.layer.0.attention.self.key.weight: torch.Size([250, 250])
bert.encoder.layer.0.attention.self.key.bias: torch.Size([250])
bert.encoder.layer.0.attention.self.value.weight: torch.Size([250, 250])
bert.encoder.layer.0.attention.self.value.bias: torch.Size([250])
bert.encoder.layer.0.attention.output.dense.weight: torch.Size([250, 250])
bert.encoder.layer.0.attention.output.dense.bias: torch.Size([250])
bert.encoder.layer.0.attention.output.LayerNorm.weight: torch.Size([250])
bert.encoder.layer.0.attention.output.LayerNorm.bias: torch.Size([250

In [None]:
# Evaluate the pruned model on the batches in eval_dataloader.
# NOTE(review): the saved output shows a RuntimeError from the per-head reshape
# (shape [32, 128, 12, 250]): 250 kept channels cannot be viewed as 12 heads of 250.
# The pruning step must keep num_attention_heads * attention_head_size equal to the Q/K/V width.
evaluate_model(pruned_model, eval_dataloader, device)

torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32])


RuntimeError: shape '[32, 128, 12, 250]' is invalid for input of size 1024000

# Appendix: Earlier Pruning Attempts (kept for reference)

In [None]:
import torch.nn as nn

def monkey_patch_attention_output(model, device=None):
    """Bridge a width mismatch between each layer's self-attention context and its output dense layer.

    For every encoder layer, the self-attention context width is
    ``num_attention_heads * attention_head_size``. If that no longer equals
    ``attention.output.dense.in_features`` (e.g. heads were pruned but the output
    projection was not), the function does two things:

    - attaches a bias-free ``proj_back_to_hidden`` Linear mapping the context width
      to ``dense.in_features``;
    - replaces the output module's ``forward`` so the projection runs before ``dense``.

    NOTE: the projection is freshly (randomly) initialised. Fine-tune the model afterwards
    for it to carry useful information.

    Args:
        model: BERT-style model exposing ``model.bert.encoder.layer``; modified in place.
        device: Device for the new projection layers. Defaults to the device of the
            existing ``dense`` weight instead of relying on a notebook-global ``device``.

    Returns:
        The same ``model`` object.
    """

    def patched_forward(self, hidden_states, input_tensor):
        # Project the narrower context back to the width ``dense`` was built for.
        hidden_states = self.proj_back_to_hidden(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)

    for i, layer in enumerate(model.bert.encoder.layer):
        self_attention = layer.attention.self
        attention_output = layer.attention.output
        context_size = self_attention.num_attention_heads * self_attention.attention_head_size
        dense_in_features = attention_output.dense.in_features

        # The projection feeds ``dense``, so compare against its *input* width. Comparing
        # against ``out_features`` patched layers whose shapes already agreed, and then
        # fed ``dense`` a tensor of the wrong width.
        if context_size == dense_in_features:
            continue

        print(f"Monkey-patching layer {i}: {context_size} -> {dense_in_features}")
        dense_weight = attention_output.dense.weight
        target_device = dense_weight.device if device is None else device
        attention_output.proj_back_to_hidden = nn.Linear(
            context_size, dense_in_features, bias=False
        ).to(device=target_device, dtype=dense_weight.dtype)
        attention_output.forward = patched_forward.__get__(attention_output, type(attention_output))
    return model

# Attach width-bridging projections where the attention context no longer matches the
# output dense layer, then move the whole model (including any new projections) to `device`.
# NOTE(review): the saved output is a NameError; this cell needs `pruned_model` from an earlier cell.
pruned_model = monkey_patch_attention_output(pruned_model)
pruned_model.to(device)

NameError: name 'pruned_model' is not defined

In [None]:
import torch
import torch.nn as nn
def monkey_patch_attention_output(model, device=None):
    """Bridge a width mismatch between each layer's self-attention context and its output dense layer.

    For every encoder layer, the self-attention context width is
    ``num_attention_heads * attention_head_size``. If that no longer equals
    ``attention.output.dense.in_features`` (e.g. heads were pruned but the output
    projection was not), the function does two things:

    - attaches a bias-free ``proj_back_to_hidden`` Linear mapping the context width
      to ``dense.in_features``;
    - replaces the output module's ``forward`` so the projection runs before ``dense``.

    NOTE: the projection is freshly (randomly) initialised. Fine-tune the model afterwards
    for it to carry useful information.

    Args:
        model: BERT-style model exposing ``model.bert.encoder.layer``; modified in place.
        device: Device for the new projection layers. Defaults to the device of the
            existing ``dense`` weight.

    Returns:
        The same ``model`` object.
    """

    def patched_forward(self, hidden_states, input_tensor):
        # Project the narrower context back to the width ``dense`` was built for.
        hidden_states = self.proj_back_to_hidden(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)

    for i, layer in enumerate(model.bert.encoder.layer):
        self_attention = layer.attention.self
        attention_output = layer.attention.output
        context_size = self_attention.num_attention_heads * self_attention.attention_head_size
        dense_in_features = attention_output.dense.in_features

        # The projection feeds ``dense``, so compare against its *input* width. Comparing
        # against ``out_features`` patched layers whose shapes already agreed, and then
        # fed ``dense`` a tensor of the wrong width.
        if context_size == dense_in_features:
            continue

        print(f"Monkey-patching layer {i}: {context_size} -> {dense_in_features}")
        dense_weight = attention_output.dense.weight
        target_device = dense_weight.device if device is None else device
        attention_output.proj_back_to_hidden = nn.Linear(
            context_size, dense_in_features, bias=False
        ).to(device=target_device, dtype=dense_weight.dtype)
        # Monkey-patch the forward method
        attention_output.forward = patched_forward.__get__(attention_output, type(attention_output))

    return model

# Pruning configuration: fractions to REMOVE (20% of attention heads / FFN neurons).
# ffn_prune_ratio only feeds the disabled FFN step below.
attention_prune_ratio = ffn_prune_ratio = 0.2

# Earlier attempt: prune the student model's attention heads
# (prune_attention_heads is not defined in this section of the notebook).
pruned_model = prune_attention_heads(student_model, attention_prune_ratio)

# Disabled follow-up steps, kept for reference:
# pruned_model = prune_ffn_neurons(pruned_model, ffn_prune_ratio)
# pruned_model = monkey_patch_attention_output(pruned_model, device)
# outputs = pruned_model(input_ids.to(device), attention_mask=attention_mask.to(device), token_type_ids=segment_ids.to(device))

# Move to the target device; the returned module is displayed as the cell output.
pruned_model.to(device)


Transformer Layers
ModuleList(
  (0-3): 4 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=312, out_features=312, bias=True)
        (key): Linear(in_features=312, out_features=312, bias=True)
        (value): Linear(in_features=312, out_features=312, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=312, out_features=312, bias=True)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=312, out_features=1200, bias=True)
    )
    (output): BertOutput(
      (dense): Linear(in_features=1200, out_features=312, bias=True)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
Layer 0
Attention BertSelfAttention(
  (query): Linear(in_features=312, out_features=312, bias=True)
  (key): L

TinyBertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=260, out_features=260, bias=True)
              (key): Linear(in_features=260, out_features=260, bias=True)
              (value): Linear(in_features=260, out_features=260, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=260, out_features=312, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=

In [None]:
# Display the model produced by the head-pruning attempt above to inspect its layer widths.
pruned_model

TinyBertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 312, padding_idx=0)
      (position_embeddings): Embedding(512, 312)
      (token_type_embeddings): Embedding(2, 312)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=260, out_features=260, bias=True)
              (key): Linear(in_features=260, out_features=260, bias=True)
              (value): Linear(in_features=260, out_features=260, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=260, out_features=312, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=

In [None]:
def monkey_patch_ffn_output(model):
    """Bridge a width mismatch between each layer's FFN intermediate output and ``output.dense``.

    If ``layer.intermediate.dense.out_features`` (e.g. after FFN-neuron pruning) differs
    from ``layer.output.dense.in_features``, the function does two things:

    - attaches a bias-free ``proj_back_to_hidden`` Linear mapping the intermediate width
      to ``dense.in_features``;
    - replaces ``layer.output.forward`` so the projection runs *before* ``dense``.

    The projection is placed on the dense layer's device/dtype.

    NOTE: the projection is freshly (randomly) initialised. Fine-tune the model afterwards.

    Args:
        model: BERT-style model exposing ``model.bert.encoder.layer``; modified in place.

    Returns:
        The same ``model`` object.
    """

    def patched_forward(self, hidden_states, input_tensor):
        # The projection must run before ``dense``: it adapts the intermediate width
        # to what ``dense`` expects. Applying it after ``dense`` (as before) fed the
        # projection the hidden width instead.
        hidden_states = self.proj_back_to_hidden(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)

    for layer in model.bert.encoder.layer:
        # After pruning the FFN, the intermediate output may be narrower than output.dense expects.
        intermediate_size = layer.intermediate.dense.out_features
        dense = layer.output.dense
        if intermediate_size == dense.in_features:
            continue

        layer.output.proj_back_to_hidden = nn.Linear(
            intermediate_size, dense.in_features, bias=False
        ).to(device=dense.weight.device, dtype=dense.weight.dtype)
        layer.output.forward = patched_forward.__get__(layer.output, type(layer.output))
    return model


In [None]:

@torch.no_grad()
def prune_ffn_neurons(model: nn.Module, prune_ratios: Union[float, List[float]]) -> nn.Module:
    """
    Prune neurons in the feed-forward networks of the Transformer layers.

    For each layer, the intermediate neurons (rows of ``intermediate.dense``) with the
    largest L2 weight norm are kept. The matching input columns of ``output.dense`` are
    kept as well, so the two Linear layers stay consistent. The hidden (residual) width
    does not change, so the residual connection in the output block needs no projection.

    Args:
        model: The language model to prune. It is not modified; a deep copy is pruned.
        prune_ratios: Fraction of neurons to remove, either one number for all layers
            or a list with one entry per layer.

    Returns:
        The pruned copy of the model.

    Raises:
        AssertionError: If ``prune_ratios`` has the wrong length, or a ratio leaves no
            neuron (or more neurons than exist) in a layer.
    """
    model = copy.deepcopy(model)
    transformer_layers = model.bert.encoder.layer
    n_layers = len(transformer_layers)

    # Ensure prune_ratios is a list (ints such as 0 are accepted too)
    if isinstance(prune_ratios, (int, float)):
        prune_ratios = [float(prune_ratios)] * n_layers
    else:
        prune_ratios = list(prune_ratios)
        assert len(prune_ratios) == n_layers, "Length of prune_ratios must match number of layers"

    for layer, prune_ratio in zip(transformer_layers, prune_ratios):
        ffn_dense = layer.intermediate.dense
        output_dense = layer.output.dense
        hidden_dim = ffn_dense.out_features

        # Same rounding rule as get_num_channels_to_keep. It is computed inline because the
        # previously called ``get_num_units_to_keep`` is not defined in the visible cells.
        n_keep = int(round(hidden_dim * (1 - prune_ratio)))
        assert 0 < n_keep <= hidden_dim, "After pruning, at least one neuron must remain in FFN"

        # Importance of each neuron = L2 norm of its incoming weights.
        neuron_importance = ffn_dense.weight.norm(dim=1)
        _, idx = torch.sort(neuron_importance, descending=True)
        idx_to_keep_sorted, _ = torch.sort(idx[:n_keep])  # preserve original neuron order

        # Intermediate dense: keep the selected output neurons (rows / bias entries).
        new_ffn_dense = nn.Linear(ffn_dense.in_features, n_keep, bias=ffn_dense.bias is not None)
        new_ffn_dense = new_ffn_dense.to(device=ffn_dense.weight.device, dtype=ffn_dense.weight.dtype)
        new_ffn_dense.weight.copy_(ffn_dense.weight.index_select(0, idx_to_keep_sorted))
        if ffn_dense.bias is not None:
            new_ffn_dense.bias.copy_(ffn_dense.bias.index_select(0, idx_to_keep_sorted))
        layer.intermediate.dense = new_ffn_dense

        # Output dense: keep the matching input columns; its bias (hidden width) is unchanged.
        new_output_dense = nn.Linear(n_keep, output_dense.out_features, bias=output_dense.bias is not None)
        new_output_dense = new_output_dense.to(device=output_dense.weight.device, dtype=output_dense.weight.dtype)
        new_output_dense.weight.copy_(output_dense.weight.index_select(1, idx_to_keep_sorted))
        if output_dense.bias is not None:
            new_output_dense.bias.copy_(output_dense.bias)
        layer.output.dense = new_output_dense

        # No residual projection is added. Earlier versions attached an untrained
        # ``residual_proj`` (n_keep -> hidden) that no visible forward reads. It only
        # inflated the parameter count (and hence the reported model size).

    return model


In [None]:
# Compare accuracy and size of the pruned model against the unpruned student,
# then briefly fine-tune the pruned model.
# NOTE(review): `evaluate`, `get_model_size`, `MiB`, `train_tinybert` and the data/config
# globals (task_name, dataloaders, output_mode, num_labels, eval_labels) come from cells
# outside this section. `pruned_model` is whichever model the last executed pruning cell produced.
print(" * Without sorting...")
# `evaluate` returns a mapping with an 'acc' entry; the *100 below assumes it is a fraction -- confirm.
pruned_model_accuracy = evaluate(pruned_model)
unpruned_model_accuracy = evaluate(student_model)
print(f"Unpruned model has accuracy={unpruned_model_accuracy['acc']*100:.2f}%")
print(f"pruned model has accuracy={pruned_model_accuracy['acc']*100:.2f}%")
# Sizes are divided by MiB for display; presumably get_model_size returns bytes -- confirm.
pruned_model_size = get_model_size(pruned_model)
student_model_size = get_model_size(student_model)
print(f"Unpruned model has size={student_model_size/MiB:.2f} MiB")
print(f"Pruned model has size={pruned_model_size/MiB:.2f} MiB")

# only finetune for 1 epoch - overfits fast
# optimizer/scheduler are passed as None (presumably train_tinybert then builds its own -- verify).
# The result is stored in fine_tuned_pruned_model but is not evaluated in this cell.
fine_tuned_pruned_model = train_tinybert(
    pruned_model,
    task_name,
    train_dataloader,
    eval_dataloader,
    device,
    output_mode,
    num_labels,
    eval_labels,
    optimizer=None,
    scheduler=None,
    epochs=1
)

# Disabled variant: prune after channel sorting.
# NOTE(review): if re-enabled, `evaluate` returns a mapping, so the last line should format
# pruned_model_accuracy['acc']*100 like the prints above.
# print(" * With sorting...")
# sorted_model = apply_channel_sorting(student_model)
# pruned_model = channel_prune(sorted_model, channel_pruning_ratio)
# pruned_model_accuracy = evaluate(pruned_model)
# print(f"pruned model has accuracy={pruned_model_accuracy:.2f}%")