In [None]:
!pip install sentence-transformers datasets

## Pruning

In [None]:
from sentence_transformers import SentenceTransformer

# Pretrained sentence-embedding model under test. Later cells prune it in place
# (load_state_dict) and then rebind this same name to a quantized copy.
distilroberta = SentenceTransformer(model_name_or_path="stsb-distilroberta-base-v2")

In [None]:
from datasets import load_metric, load_dataset

# GLUE evaluation data: load each task's dataset right next to its metric.
# NOTE(review): `mrpc` / `mrpc_metric` are not used by any later cell here.
stsb = load_dataset("glue", "stsb")
stsb_metric = load_metric("glue", "stsb")

mrpc = load_dataset("glue", "mrpc")
mrpc_metric = load_metric("glue", "mrpc")

In [None]:
import math
import tensorflow as tf


def roberta_sts_benchmark(batch, model=None):
    """Score sentence pairs by the angular similarity of their embeddings.

    Each pair ``(batch["sentence1"][i], batch["sentence2"][i])`` is embedded
    with ``model.encode``. Both embeddings are L2-normalised, and their cosine
    similarity ``c`` is mapped to ``1 - arccos(c) / pi``, which lies in
    [0, 1] (1 = same direction, 0 = opposite direction).

    Args:
        batch: Mapping (e.g. a ``datasets.Dataset`` or a plain dict) whose
            "sentence1" and "sentence2" entries are equal-length sequences
            of strings.
        model: Object with an ``encode(list_of_str)`` method returning one
            embedding row per sentence. Defaults to the global
            ``distilroberta``, looked up at call time, so later cells that
            prune or rebind that global are still picked up.

    Returns:
        1-D NumPy array with one score per sentence pair (empty for an empty
        batch).

    Raises:
        ValueError: if the two sentence columns have different lengths.
    """
    import numpy as np

    if model is None:
        model = distilroberta

    sentences1 = list(batch["sentence1"])
    sentences2 = list(batch["sentence2"])
    if len(sentences1) != len(sentences2):
        raise ValueError(
            f"sentence1 and sentence2 differ in length: "
            f"{len(sentences1)} != {len(sentences2)}"
        )
    if not sentences1:
        return np.empty(0, dtype=np.float32)

    def _l2_normalize(embeddings):
        # Same formula as tf.nn.l2_normalize: x / sqrt(max(sum(x**2), 1e-12)).
        embeddings = np.asarray(embeddings, dtype=np.float32)
        squared_norm = np.sum(np.square(embeddings), axis=1, keepdims=True)
        return embeddings / np.sqrt(np.maximum(squared_norm, 1e-12))

    embeddings1 = _l2_normalize(model.encode(sentences1))
    embeddings2 = _l2_normalize(model.encode(sentences2))
    cosine = np.sum(embeddings1 * embeddings2, axis=1)
    # Rounding can push |cosine| slightly above 1, where arccos would give NaN.
    cosine = np.clip(cosine, -1.0, 1.0)
    return 1.0 - np.arccos(cosine) / np.pi

In [None]:
# Gold labels for every STS-B validation pair, in dataset order. `[:]` materialises
# the split as a dict of columns, and "label" picks the gold-score column. The
# benchmark cells score the same `stsb["validation"]` split, so predictions align.
references = stsb["validation"][:]["label"]

In [None]:
# Baseline scores from the model before any pruning or quantization.
distilroberta_results = roberta_sts_benchmark(batch=stsb["validation"])

In [None]:
from torch.nn.utils import prune

# Magnitude (L1) pruning method. A float amount is the fraction of each tensor's
# entries to zero (here 20%), chosen by smallest absolute value.
pruner = prune.L1Unstructured(0.2)

In [None]:
# Build pruned replacement tensors inside a state dict. `pruner.prune` returns a
# new masked tensor, so the live model only changes when the next cell loads it.
# Only multi-dimensional "weight" tensors (weight / embedding matrices) are pruned.
# The 1-D "weight" entries are presumably LayerNorm scales; zeroing 20% of those
# would silence entire hidden features instead of sparsifying a weight matrix.
# Biases (no "weight" in the key) are left untouched.
state_dict = distilroberta.state_dict()

# Iterate over a snapshot so entries can be replaced safely inside the loop.
for key, tensor in list(state_dict.items()):
    if "weight" in key and tensor.dim() > 1:
        state_dict[key] = pruner.prune(tensor)

In [None]:
# Copy the pruned tensors back into the live model (strict key matching, the default).
distilroberta.load_state_dict(state_dict=state_dict)

In [None]:
# Scores from the pruned model (the global `distilroberta` now holds pruned weights).
distilroberta_results_p = roberta_sts_benchmark(batch=stsb["validation"])

In [None]:
import pandas as pd

# One column per model variant. Rows are the STS-B metric's outputs, i.e. the
# correlation of each variant's scores with the gold labels.
pd.DataFrame(
    {
        column: stsb_metric.compute(predictions=predictions, references=references)
        for column, predictions in (
            ("DistillRoberta", distilroberta_results),
            ("DistillRobertaPruned", distilroberta_results_p),
        )
    }
)

## Quantization

In [None]:
import torch

# Dynamic int8 quantization of every nn.Linear layer. With the default
# inplace=False a quantized copy is returned. That copy is rebound to
# `distilroberta`, so later benchmark calls score the pruned + quantized model.
# NOTE(review): dynamic-quantized Linear kernels are CPU-only in PyTorch. If the
# model was loaded onto a GPU, it likely needs moving to CPU first — confirm.
distilroberta = torch.quantization.quantize_dynamic(
    distilroberta,
    {torch.nn.Linear: torch.quantization.default_dynamic_qconfig},
    torch.qint8,
)

In [None]:
# Scores from the pruned and int8-quantized model.
distilroberta_results_pq = roberta_sts_benchmark(batch=stsb["validation"])

In [None]:
# Side-by-side STS-B metric results for the baseline, pruned, and pruned+int8
# variants. Each column is one variant's metric dict.
pd.DataFrame(
    {
        column: stsb_metric.compute(predictions=predictions, references=references)
        for column, predictions in (
            ("DistillRoberta", distilroberta_results),
            ("DistillRobertaPruned", distilroberta_results_p),
            ("DistillRobertaPrunedQINT8", distilroberta_results_pq),
        )
    }
)