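# Language Model compression with SVD

Replace every `nn.Linear` layer of a pretrained masked language model with a truncated-SVD factorization (two smaller linear layers), then compare the model's masked-token predictions before and after compression.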

In [1]:
from transformers import DistilBertForMaskedLM, DistilBertTokenizer, AutoModelForMaskedLM, AutoTokenizer
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


## Define model name

In [2]:
# Hugging Face checkpoint to compress. Alternative checkpoint: 'distilbert-base-cased'.
model_name = "bert-base-cased"

## Load Model and Tokenizer

In [3]:
# Both objects are loaded from the same checkpoint name so vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Put the model in inference (evaluation) mode before querying it.
# `train(False)` is the same call `eval()` makes; `;` suppresses the module repr.
model.train(False);

## Simple function to visually evaluate the model

In [5]:
def evaluate_model(model, prompt=None, top_k=5, tok=None):
    """
    Print a masked-LM's top-k predictions for the mask token in a prompt.

    Args:
        model: masked-LM model. ``model(input_ids)[0]`` is indexed as
            ``[batch, position]`` to obtain per-token scores over the vocabulary.
        prompt: text containing a mask token (only the first one is scored).
            Defaults to "The quick brown <mask> jumps over the lazy dog.", built
            with the tokenizer's own mask token so non-BERT tokenizers also work.
        top_k: number of candidates to print.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.

    The printed scores are raw model outputs (logits), not probabilities.

    Raises:
        ValueError: if the encoded prompt contains no mask token.
    """
    if tok is None:
        tok = tokenizer
    if prompt is None:
        prompt = f"The quick brown {tok.mask_token} jumps over the lazy dog."

    # Tokenize the prompt into a (1, seq_len) tensor of ids.
    input_ids = tok.encode(prompt, add_special_tokens=True, return_tensors='pt')

    # Column indices of every mask token in the single encoded sequence.
    mask_positions = torch.where(input_ids == tok.mask_token_id)[1]
    if mask_positions.numel() == 0:
        raise ValueError(f"prompt contains no mask token {tok.mask_token!r}: {prompt!r}")
    mask_token_index = mask_positions[0]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids)[0]
        predictions = logits[0, mask_token_index].topk(top_k)

    for rank, (word_idx, score) in enumerate(zip(predictions.indices, predictions.values), start=1):
        word = tok.decode([word_idx])
        print(f'{rank}. {word} ({score:.2f})')

## Truncated SVD for matrix X

In [6]:
def svd(X, k: float = 1.0):
    """
    Truncated SVD of a 2-D matrix X of shape (m, n).

    Args:
        X: matrix to factorize.
        k: fraction of singular values to keep, nominally in (0.0, 1.0]; 1.0 keeps
            all of them and reproduces X up to floating-point error. The kept rank
            is ``round(min(m, n) * k)``, clamped to ``[1, min(m, n)]``. A very
            small (or zero) ``k`` therefore yields a rank-1 approximation instead
            of silently returning the full decomposition.

    Returns:
        ``(U, S, V)`` with shapes ``(m, r)``, ``(r,)``, ``(n, r)`` such that
        ``X ≈ U @ S.diag() @ V.T``.

    Storing the two factors only saves parameters when ``r * (m + n) < m * n``,
    e.g. ``k < 0.5`` for a square matrix.
    """
    # torch.svd is deprecated; linalg.svd returns V^H, so take its conjugate transpose.
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    V = Vh.mH
    full_rank = S.shape[0]
    # Clamp so rank 0 never happens (the old `if k:` check fell back to full rank).
    rank = min(full_rank, max(1, round(full_rank * k)))
    return U[:, :rank], S[:rank], V[:, :rank]


def eval_svd(X, U, S, V):
    """Return the Frobenius norm of X minus its reconstruction U @ diag(S) @ V.T."""
    reconstruction = U @ torch.diag(S) @ V.T
    residual = X - reconstruction
    return residual.norm()

## Sanity check: truncated SVD reconstruction error on a random matrix

In [22]:
# Seed the RNG so the reconstruction errors printed below are reproducible.
torch.manual_seed(0)
X = torch.rand(3000, 100)

In [23]:
X.size()

torch.Size([3000, 100])

In [24]:
# Reconstruction error for increasing fractions of kept singular values.
for k in (0.1, 0.2, 0.3, 0.5, 0.7, 1):
    U, S, V = svd(X, k=k)
    residual_norm = eval_svd(X, U, S, V)
    print(f"k={k}, Norm(X-SVD(X))={residual_norm}")

k=0.1, Norm(X-SVD(X))=147.70260620117188
k=0.2, Norm(X-SVD(X))=136.99066162109375
k=0.3, Norm(X-SVD(X))=126.0494155883789
k=0.5, Norm(X-SVD(X))=103.02754211425781
k=0.7, Norm(X-SVD(X))=76.99295806884766
k=1, Norm(X-SVD(X))=0.000343029125360772


## New Compressed Linear layer definition

In [25]:
class CompressedLinear(nn.Module):
    """
    Replacement for ``nn.Linear`` built from a truncated SVD of its weight.

    With ``W ≈ U @ diag(S) @ V.T``, the original ``x @ W.T + b`` is computed as
    two smaller maps. ``lin1`` (weight ``V.T``) projects the input to rank r, and
    ``lin2`` (weight ``U @ diag(S)``, bias ``b``) maps back to the output space.

    Args:
        U: left singular vectors, shape (out_features, r).
        S: singular values, shape (r,).
        V: right singular vectors, shape (in_features, r).
        b: bias of shape (out_features,), or None for a bias-free layer.
    """

    def __init__(self, U, S, V, b):
        super().__init__()
        out_features, rank = U.shape
        in_features = V.shape[0]
        # Declare the real in -> rank -> out dimensions so in_features/out_features
        # (and the repr) match the weights assigned below, also for non-square layers.
        self.lin1 = nn.Linear(in_features, rank, bias=False)
        self.lin2 = nn.Linear(rank, out_features, bias=b is not None)

        # Copy V.T into fresh contiguous storage: a truncated V is a view of the
        # full, untruncated V, which would otherwise stay alive and defeat compression.
        self.lin1.weight = nn.Parameter(V.T.clone(memory_format=torch.contiguous_format))
        self.lin2.weight = nn.Parameter(U @ S.diag())
        if b is not None:
            self.lin2.bias = nn.Parameter(b)

    def forward(self, x):
        return self.lin2(self.lin1(x))

## Code for model compression 
Walk recursively through all submodules and replace every `nn.Linear` with a `CompressedLinear`.

In [26]:
def compress_linear(module, k=0.4):
    """
    Build a CompressedLinear approximating an nn.Linear via truncated SVD.

    Args:
        module: the nn.Linear to compress. Its weight is factorized and its bias
            (possibly None) is passed on to CompressedLinear.
        k: fraction of singular values to keep, forwarded to ``svd``.

    Returns:
        A new CompressedLinear module. ``module`` is only read, not modified.
    """
    # Weight surgery is not a training step: run the SVD outside autograd so no
    # backward graph is recorded on the (requires_grad) weight parameter.
    with torch.no_grad():
        U, S, V = svd(module.weight, k=k)
    return CompressedLinear(U, S, V, module.bias)

In [27]:
def compress_model(model: nn.Module, k=0.4):
    """
    Recursively replace every nn.Linear inside ``model`` with a CompressedLinear.

    The model is modified in place and also returned for convenience.

    Args:
        model: root module to walk.
        k: fraction of singular values kept per layer, forwarded to compress_linear.
    """
    for child_name, submodule in list(model.named_children()):
        if not isinstance(submodule, nn.Linear):
            # Not a Linear: descend into its own children.
            compress_model(submodule, k=k)
            continue
        setattr(model, child_name, compress_linear(submodule, k=k))
    return model

In [34]:
# A rank-r factorization of an m x n layer stores r * (m + n) weights instead of
# m * n, so it only shrinks the layer when r < m * n / (m + n). For a square
# layer that means k < 0.5, so k=0.6 enlarges any square layer.
new_model = compress_model(
    AutoModelForMaskedLM.from_pretrained(model_name), k=0.6
)
# The CompressedLinear modules inserted above are new nn.Modules (training mode
# by default); switch the whole model to inference mode so its flags agree.
new_model.eval();

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Compressed model results

In [35]:
evaluate_model(model=new_model)

1. dog (7.49)
2. cat (6.67)
3. horse (6.29)
4. man (6.24)
5. bird (5.79)


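## Original model results

Both models rank "dog" first. The compressed model's scores are noticeably lower, and its remaining candidates differ from the original's.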

In [36]:
evaluate_model(model=model)

1. dog (12.20)
2. ##ie (11.23)
3. cat (10.60)
4. bear (10.13)
5. puppy (10.01)
