In [47]:
import os
import torch
import pandas as pd
from IPython.display import display
from transformers import RobertaModel, DistilBertModel, AutoModel

In [48]:
# Checking if GPU is available
print(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

cuda


In [49]:
def load_model(model_type, model_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    if model_path and os.path.exists(model_path):
        # Try loading as a generic PyTorch model
        try:
            model = torch.load(model_path, map_location=device)
        except Exception as e:
            raise IOError(f"Error loading model from {model_path}: {e}")
    elif model_type == 'distilroberta-base':
        model = DistilBertModel.from_pretrained('distilroberta-base')
    elif model_type == 'roberta-base':
        model = RobertaModel.from_pretrained('roberta-base')
    else:
        raise ValueError("Invalid model type or path")

    return model.to(device)

def check_sparsity(model):
    total_params = 0
    nonzero_params = 0
    layer_sparsity = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:  # exclude non-trainable parameters
            continue
        layer_size = param.numel()
        layer_nonzero = torch.count_nonzero(param)
        layer_sparsity[name] = 1 - layer_nonzero.item() / layer_size
        total_params += layer_size
        nonzero_params += layer_nonzero.item()
    overall_sparsity = 1 - nonzero_params / total_params
    print(f"Overall Sparsity: {overall_sparsity:.4%}")
    layer_sparsity_df = pd.DataFrame(layer_sparsity.items(), columns=['Layer Name', 'Sparsity'])
    # layer_sparsity_df.sort_values(by='Sparsity', ascending=False, inplace=True)
    display(layer_sparsity_df)

def compute_global_threshold(model, pruning_rate, batch_size, device):
    all_weights = [param.view(-1) for param in model.parameters() if param.requires_grad and param.dim() > 1]
    all_weights = torch.cat(all_weights).to(device)
    batched_quantiles = [torch.quantile(batch.abs(), pruning_rate) for batch in all_weights.split(batch_size)]

    return torch.tensor(batched_quantiles).mean().to(device)

def mpruner_global(model, pruning_rate, batch_size, device):
    threshold = compute_global_threshold(model, pruning_rate, batch_size, device)
    print("Global threshold:", threshold)
    for name, param in model.named_parameters():
        if param.requires_grad and param.dim() > 1:
            mask = param.abs() > threshold
            param.data.mul_(mask.to(torch.float32))

    return model

In [50]:
model_type = 'roberta-base'     # Can be 'roberta-base', 'distilroberta-base', or a custom model path
model_path = None               # Set this to None if you want to use pre-trained models

# Loading model
model = load_model(model_type, model_path)
print(type(model))

Using device: cuda


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<class 'transformers.models.roberta.modeling_roberta.RobertaModel'>


In [51]:
# Checking sparsity before pruning
check_sparsity(model)

Overall Sparsity: 0.0019%


Unnamed: 0,Layer Name,Sparsity
0,embeddings.word_embeddings.weight,3.108525e-07
1,embeddings.position_embeddings.weight,1.945525e-03
2,embeddings.token_type_embeddings.weight,1.000000e+00
3,embeddings.LayerNorm.weight,0.000000e+00
4,embeddings.LayerNorm.bias,0.000000e+00
...,...,...
194,encoder.layer.11.output.dense.bias,0.000000e+00
195,encoder.layer.11.output.LayerNorm.weight,0.000000e+00
196,encoder.layer.11.output.LayerNorm.bias,0.000000e+00
197,pooler.dense.weight,0.000000e+00


In [52]:
# Pruning model
pruning_rate = 0.2              # Between 0 and 1
batch_size = 5000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pruned_model = mpruner_global(model, pruning_rate, batch_size, device)

Global threshold: tensor(0.0167, device='cuda:0')


In [53]:
# Checking sparsity after pruning
check_sparsity(pruned_model)

Overall Sparsity: 24.8828%


Unnamed: 0,Layer Name,Sparsity
0,embeddings.word_embeddings.weight,0.118609
1,embeddings.position_embeddings.weight,0.293711
2,embeddings.token_type_embeddings.weight,1.000000
3,embeddings.LayerNorm.weight,0.000000
4,embeddings.LayerNorm.bias,0.000000
...,...,...
194,encoder.layer.11.output.dense.bias,0.000000
195,encoder.layer.11.output.LayerNorm.weight,0.000000
196,encoder.layer.11.output.LayerNorm.bias,0.000000
197,pooler.dense.weight,0.597107


In [46]:
# Saving the model
filename = f"{model_type}-mpruned-global-{pruning_rate:.2f}.pt"
print(f"Saving model to {filename}")
torch.save(pruned_model, filename)
torch.cuda.empty_cache()

Saving model to roberta-base-mpruned-global-0.20.pt
