# The Optimal BERT Surgeon: Scalable and Accurate Second-Order Pruning for Large Language Models (oBERT)

### Paper: [https://arxiv.org/abs/2203.07259](https://arxiv.org/abs/2203.07259)

The oBERT implementation is integrated with the SparseML library in the form of [OBSPruningModifier](https://github.com/neuralmagic/sparseml/blob/main/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py), which makes it straightforward to run experiments, reproduce the results from the paper, or compress new models.
We also provide the [bash scripts](https://github.com/neuralmagic/sparseml/tree/main/research/optimal_BERT_surgeon_oBERT/scripts) and [recipes](https://github.com/neuralmagic/sparseml/tree/main/research/optimal_BERT_surgeon_oBERT/recipes) used to produce the results in the paper; both can be adapted to new models and datasets.

Here, we extract the algorithmic part of oBERT unstructured pruning from the OBSPruningModifier to showcase the main operations involved in the pruning process.

In [1]:
import math
import torch
from torch import Tensor

The following `EmpiricalBlockFisherInverse` class implements and holds the block-wise approximation of the inverse Hessian. The approximation is in the form of a dampened empirical Fisher information matrix:
$$
H_{\mathcal{L}}(\mathbf{w}) \simeq \widehat{\mathbf{F}} (\mathbf{w}) = \lambda \mathbf{I}_d + \frac{1}{m} \sum_{i=1}^{m} \nabla \mathcal{L}_i(\mathbf{w}) \nabla \mathcal{L}^\top_i(\mathbf{w})
$$
Because this is a scaled identity plus a sum of rank-1 matrices, the Woodbury/Sherman-Morrison inversion formula can be applied once per gradient to compute the Fisher inverse exactly. Unrolling the recursion from $\widehat{\mathbf{F}}^{-1}_0(\mathbf{w}) = \frac{1}{\lambda} \mathbf{I}_d$ gives an iterative formula for the inverse of the empirical Fisher matrix:
$$
\widehat{\mathbf{F}}^{-1}(\mathbf{w}) = \widehat{\mathbf{F}}^{-1}_m(\mathbf{w}) = \frac{1}{\lambda} \mathbf{I}_d - \sum_{i=1}^{m} \frac{\left(\widehat{\mathbf{F}}^{-1}_{i-1}(\mathbf{w}) \nabla \mathcal{L}_i(\mathbf{w})\right)\left(\widehat{\mathbf{F}}^{-1}_{i-1}(\mathbf{w}) \nabla \mathcal{L}_i(\mathbf{w})\right)^\top}{m + \nabla \mathcal{L}_i^\top(\mathbf{w}) \widehat{\mathbf{F}}^{-1}_{i-1}(\mathbf{w}) \nabla \mathcal{L}_i(\mathbf{w})}
$$

This is implemented via the `add_grad` method, which efficiently updates the inverse with a new gradient.
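
As a sanity check, the recursion can be compared against a direct matrix inversion on a small dense problem. The snippet below is a minimal sketch, not part of the SparseML code: it uses a single dense block, `float64` precision, and made-up sizes (`d = 6`, `m = 4`, `damp = 1e-2`) chosen only to keep the comparison numerically clean.

```python
import torch

torch.manual_seed(0)
d, m, damp = 6, 4, 1e-2  # illustrative sizes, not the values used later in this notebook
grads = torch.rand(m, d, dtype=torch.float64)  # dummy gradients, one per row

# Direct route: build the dampened empirical Fisher matrix and invert it.
F = damp * torch.eye(d, dtype=torch.float64) + grads.T @ grads / m
F_inv_direct = torch.linalg.inv(F)

# Recursive route: start from (1 / damp) * I and apply one rank-1 update per gradient.
F_inv = torch.eye(d, dtype=torch.float64) / damp
for g in grads:
    Finv_g = F_inv @ g
    F_inv = F_inv - torch.outer(Finv_g, Finv_g) / (m + g @ Finv_g)

max_abs_diff = (F_inv - F_inv_direct).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two results should agree up to floating-point rounding, which is what "exact" means for the recursive formula.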

`diag` fetches the diagonal of the inverse Fisher, which is used in calculations of the saliency score $\rho$ and of the optimal weight update $\delta \mathbf{w}$.
`mul` efficiently computes matrix-vector products between a given vector `v` and the block-wise inverse Fisher matrix, which is used to calculate the optimal weight update $\delta \mathbf{w}$.
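
To see why a batched product over blocks is enough, the following sketch compares it with an explicit block-diagonal matrix. The block values and the sizes (`num_blocks = 3`, `B = 4`) are random placeholders assumed for illustration; only the indexing pattern mirrors `diag` and `mul`.

```python
import torch

torch.manual_seed(0)
num_blocks, B = 3, 4  # illustrative sizes
blocks = torch.rand(num_blocks, B, B)  # stand-in for the stored (num_blocks, B, B) inverse
v = torch.rand(num_blocks * B)

# Batched route: one (B x B) @ (B x 1) product per block.
batched = torch.bmm(blocks, v.view(num_blocks, B).unsqueeze(2)).flatten()

# Dense route: materialize the full block-diagonal matrix, then multiply.
dense_matrix = torch.block_diag(*blocks)
dense = dense_matrix @ v

# The diagonal can also be read block by block without building the dense matrix.
batched_diag = blocks.diagonal(dim1=1, dim2=2).flatten()

print(torch.allclose(batched, dense), torch.equal(batched_diag, dense_matrix.diagonal()))
```

The batched route never allocates the `d x d` matrix, which is what keeps the memory footprint at `O(Bd)`.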

In [2]:
class EmpiricalBlockFisherInverse:
    def __init__(
        self,
        num_grads: int,
        fisher_block_size: int,
        num_weights: int,
        damp: float,
        device: torch.device,
    ):
        self.m = num_grads
        self.B = fisher_block_size
        self.d = num_weights
        self.damp = damp
        self.dev = device

        self.num_blocks = math.ceil(self.d / self.B)
        self.F_inv = (
            (1.0 / self.damp * torch.eye(n=self.B, device=self.dev))
            .unsqueeze(0)
            .repeat(self.num_blocks, 1, 1)
        )  # takes O(d x B) memory on a device

    def add_grad(self, g: Tensor):
        """
        Updates empirical Fisher inverse with a new gradient
        :param g: a collected gradient
        """
        # if 'd / B' is not integer, pad with zeros for batch calculations
        if g.numel() < self.num_blocks * self.B:
            g = torch.cat(
                [g, torch.zeros(self.num_blocks * self.B - g.numel(), device=g.device)]
            )

        # prepare grad for batch calculations
        g = g.view(self.num_blocks, self.B)

        # batched F_inv x g: (batch, B, B) x (batch, B) -> (batch, B)
        Finv_g = torch.einsum("bij,bj->bi", self.F_inv, g)

        # scalar denominator for each batch: (batch)
        alpha = (self.m + torch.einsum("bi,bi->b", g, Finv_g)).sqrt().unsqueeze(1)
        Finv_g /= alpha

        # update F_inv with new outer product: (batch, B) x (batch, B) -> (batch, B, B)
        self.F_inv.baddbmm_(Finv_g.unsqueeze(2), Finv_g.unsqueeze(1), alpha=-1)

    def diag(self) -> Tensor:
        """
        :return: diagonal of the Fisher inverse matrix
        """
        return self.F_inv.diagonal(dim1=1, dim2=2).flatten()[: self.d]

    def mul(self, v: Tensor) -> Tensor:
        """
        Computes matrix-vector product of the Fisher inverse matrix and a vector
        :param v: a vector to compute matrix-vector product with
        :return: result of the matrix-vector multiplication
        """
        if v.numel() < self.num_blocks * self.B:
            v = torch.cat(
                [v, torch.zeros(self.num_blocks * self.B - v.numel(), device=v.device)]
            )
        return torch.bmm(
            self.F_inv, v.view(self.num_blocks, self.B).unsqueeze_(2)
        ).flatten()[: self.d]

Now, we define a dummy neural-network model:

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present
d = 1000                          # number of prunable weights
w = torch.rand(d, device=device)  # dummy weights
target_sparsity = 0.7             # [0, 1.] range

Now, we specify oBERT pruning hyper-parameters:

In [4]:
m = 100           # number of gradients
B = 50            # block size
lambd = 1e-7      # dampening

# initialize Fisher inverse, occupies O(Bd) memory
# for example: d=85_000_000, B=50 -> 85_000_000 * 50 * 4 / 1024^3 ≈ 15.8 GiB (float32)
fisher_inv = EmpiricalBlockFisherInverse(m, B, d, lambd, device)
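
The memory estimate in the comment above can be reproduced with a few lines of arithmetic. The helper name `fisher_inverse_memory_gib` is hypothetical, and the sketch assumes 4-byte (`float32`) elements and counts the zero-padding of the last block.

```python
import math


def fisher_inverse_memory_gib(d: int, B: int, bytes_per_element: int = 4) -> float:
    """Approximate size of the (num_blocks, B, B) inverse Fisher tensor in GiB."""
    num_blocks = math.ceil(d / B)  # the last block is padded when d is not a multiple of B
    return num_blocks * B * B * bytes_per_element / 1024**3


estimate = fisher_inverse_memory_gib(85_000_000, 50)
smaller_blocks = fisher_inverse_memory_gib(85_000_000, 10)
print(f"B=50: {estimate:.1f} GiB, B=10: {smaller_blocks:.1f} GiB")
```

Memory grows linearly with the block size, so `B` trades approximation quality against the footprint on the device.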

Now, we collect `m` gradients used to approximate the Fisher inverse:

In [5]:
for i in range(m):
    grad = torch.rand(d, device=device)  # a dummy gradient
    fisher_inv.add_grad(grad)
    print(f"Fisher inverse updated with {i+1} gradients", end="\r")
print('\n')

Fisher inverse updated with 100 gradients
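
The dummy gradients above stand in for real ones. As a rough sketch of how flattened gradients might be collected from an actual PyTorch model, consider the snippet below. The tiny `torch.nn.Linear` model, the random batches, the MSE loss, and the helper `flat_grad` are all illustrative assumptions and are not taken from the SparseML integration.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)  # hypothetical stand-in for a real network
params = [p for p in model.parameters() if p.requires_grad]
d = sum(p.numel() for p in params)  # number of prunable weights in this toy setup


def flat_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the loss gradient w.r.t. all parameters as one flat vector."""
    loss = torch.nn.functional.mse_loss(model(x), y)
    grads = torch.autograd.grad(loss, params)
    return torch.cat([g.flatten() for g in grads])


m = 5  # number of gradients to collect
collected = []
for _ in range(m):
    x, y = torch.rand(4, 8), torch.rand(4, 1)  # dummy batch
    collected.append(flat_grad(x, y))

print(len(collected), collected[0].shape)
```

Each flat vector has `d` entries and could be passed to `add_grad` in place of the random tensor used in the loop above.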

Now, we calculate saliency scores for each weight $j \in \{1, 2, 3, \dots, d\}$ in the form:
$$
\rho_j = \frac{w_j^2}{2 \widehat{\mathbf{F}}^{-1}_{j,j}}
$$

In [6]:
scores = (w**2) / (2.0 * fisher_inv.diag())
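
A tiny hand-made example shows how these scores can differ from plain magnitude pruning. The weights and the inverse Fisher diagonal below are made-up numbers chosen for illustration only.

```python
import torch

w = torch.tensor([0.5, 0.5, 0.1])
f_inv_diag = torch.tensor([10.0, 0.1, 0.1])  # made-up diagonal of the inverse Fisher

scores = (w**2) / (2.0 * f_inv_diag)
prune_order = torch.argsort(scores)  # lowest saliency first

print(scores.tolist())
print(prune_order.tolist())  # [0, 2, 1]
```

Weights 0 and 1 have the same magnitude, yet weight 0 is pruned first, even before the much smaller weight 2, because its large inverse Fisher diagonal entry marks it as cheap to remove. Magnitude pruning would have removed weight 2 first.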

Now, we prune `target_sparsity * d` weights:

In [7]:
# find pruning threshold
kth_score = torch.kthvalue(scores, round(target_sparsity * d))[0]

# prune (i.e. set masks)
mask = scores > kth_score
print(f"Pruned model's sparsity = {1 - torch.sum(mask)/mask.numel()}")

Pruned model's sparsity = 0.699999988079071
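
The thresholding step can be traced on a handful of hand-picked scores. The values below are illustrative assumptions; the second half shows a caveat of the strict `>` comparison when several scores tie at the threshold.

```python
import torch

scores = torch.tensor([0.9, 0.1, 0.5, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6, 1.0])
target_sparsity = 0.7
k = round(target_sparsity * scores.numel())  # number of weights to prune

kth_score = torch.kthvalue(scores, k).values  # k-th smallest score
mask = scores > kth_score  # True means the weight is kept
num_kept = int(mask.sum())
print(f"kept {num_kept} of {scores.numel()} weights")  # kept 3 of 10 weights

# Caveat: with ties at the threshold, every tied weight is pruned.
tied = torch.tensor([0.1, 0.2, 0.2, 0.2, 0.9])
tied_mask = tied > torch.kthvalue(tied, 2).values
num_kept_tied = int(tied_mask.sum())
print(f"kept {num_kept_tied} of {tied.numel()} weights")  # kept 1 of 5 weights
```

With continuous-valued scores, exact ties are unlikely, so the achieved sparsity normally matches the target.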


Besides pruning a weight $w_j$, the OBS framework updates the unpruned weights to compensate for the loss incurred by pruning. The optimal weight update, which prunes the weight $w_j$ and updates the remaining ones, is given by:
$$
\delta\mathbf{w}_j = -\frac{w_j}{\widehat{\mathbf{F}}^{-1}_{j,j}}\widehat{\mathbf{F}}^{-1} \mathbf{e}_j
$$
As described in the paper, jointly pruning multiple weights is combinatorially intractable, so the single-weight updates are applied independently. Because the update coming from one pruned weight can perturb another pruned weight to a non-zero value, we have to manually zero out the pruned weights afterwards.

In [8]:
w -= fisher_inv.mul(w * (mask == 0) / fisher_inv.diag())
w[mask == 0] = 0.0
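
The need for the explicit zeroing step can be seen on a small dense example. This is a sketch under stated assumptions: the matrix `F_inv` is a random symmetric positive-definite placeholder rather than a real inverse Fisher, and the pruned positions are picked by hand.

```python
import torch

torch.manual_seed(0)
d = 5
A = torch.rand(d, d, dtype=torch.float64)
F_inv = A @ A.T + torch.eye(d, dtype=torch.float64)  # dummy symmetric positive-definite matrix
w = torch.rand(d, dtype=torch.float64)

# Single pruned weight j: delta_w = -(w_j / F_inv[j, j]) * F_inv[:, j]
j = 2
w_single = w - (w[j] / F_inv[j, j]) * F_inv[:, j]
single_residual = abs(w_single[j].item())  # zero up to rounding

# Several pruned weights: the single-weight updates are summed,
# mirroring fisher_inv.mul(w * (mask == 0) / fisher_inv.diag())
pruned = torch.tensor([True, False, True, False, False])
w_multi = w - F_inv @ (w * pruned / F_inv.diagonal())
residual = w_multi[pruned].abs().max().item()  # generally non-zero

print(f"single: {single_residual:.2e}, multiple: {residual:.2e}")

w_multi[pruned] = 0.0  # manual zeroing, as in the cell above
```

For a single pruned weight the update lands exactly on zero, while the summed updates leave the pruned positions slightly off zero because each update also shifts the other pruned weights.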

The 4-block oBERT pruning follows the same procedure, except that it uses slightly different equations for the saliency score and the optimal weight update, which can be found in the paper and in the [SparseML integration](https://github.com/neuralmagic/sparseml/blob/main/src/sparseml/pytorch/sparsification/pruning/modifier_pruning_obs.py).