# Data2vec vs. SBERT
https://www.kaggle.com/datasets/shivamkushwaha/bbc-full-text-document-classification

In [None]:
# !pip install transformers
# !pip install sentence_transformers
# !pip install scikit-learn-intelex
# !pip3 install memory_profiler
# %load_ext memory_profiler

In [25]:
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
from sklearnex import patch_sklearn
patch_sklearn()
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [14]:
df = pd.read_csv("/content/bbc-text.csv")

In [15]:
X = df.text.copy()
y = df.category.copy()
y = pd.factorize(y)[0]

In [16]:
print(df.category.unique())
print(pd.factorize(y)[1])

['tech' 'business' 'sport' 'entertainment' 'politics']
[0 1 2 3 4]


## SBERT

In [17]:
%%time
%%memit
model = SentenceTransformer('all-mpnet-base-v2',device='cuda')
model.max_seq_length = 128


#Sentences are encoded by calling model.encode()
sentence_embeddings = X.apply(model.encode)
sentence_embeddings = pd.DataFrame(sentence_embeddings.tolist())

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

peak memory: 4598.75 MiB, increment: 3695.29 MiB
CPU times: user 47.8 s, sys: 4.28 s, total: 52 s
Wall time: 1min 12s


In [18]:
%%time
X_train, X_test, y_train, y_test = train_test_split(sentence_embeddings, y, test_size=0.4)

CPU times: user 11.1 ms, sys: 3.17 ms, total: 14.3 ms
Wall time: 17.6 ms


In [19]:
rfc = RandomForestClassifier().fit(X_train, y_train)

In [20]:
prediction = rfc.predict(X_test)

In [30]:
print(classification_report(y_test, prediction, target_names=df.category.unique()))

               precision    recall  f1-score   support

         tech       0.93      0.99      0.96       167
     business       0.96      0.96      0.96       196
        sport       1.00      1.00      1.00       216
entertainment       1.00      0.94      0.97       148
     politics       0.97      0.96      0.96       163

     accuracy                           0.97       890
    macro avg       0.97      0.97      0.97       890
 weighted avg       0.97      0.97      0.97       890



## Data2vec

In [31]:
%%time
%%memit
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("facebook/data2vec-text-base")
model = AutoModel.from_pretrained("facebook/data2vec-text-base")

#Tokenize sentences
encoded_input = tokenizer(list(X), padding=True, truncation=True, max_length=128, return_tensors='pt')

#Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

#Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
sentence_embeddings = pd.DataFrame(sentence_embeddings.tolist())

Downloading:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.01M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/714 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/476M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/data2vec-text-base were not used when initializing Data2VecTextModel: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing Data2VecTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Data2VecTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Data2VecTextModel were not initialized from the model checkpoint at facebook/data2vec-text-base and are newly initialized: ['data2vec_text.pooler.dense.bias', 'data2vec_text.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it f

peak memory: 16574.69 MiB, increment: 12018.17 MiB
CPU times: user 14min 15s, sys: 13.3 s, total: 14min 28s
Wall time: 7min 27s


In [32]:
%%time
X_train, X_test, y_train, y_test = train_test_split(sentence_embeddings, y, test_size=0.4)

CPU times: user 13.8 ms, sys: 997 µs, total: 14.8 ms
Wall time: 14.1 ms


In [33]:
rfc = RandomForestClassifier().fit(X_train, y_train)

In [34]:
prediction = rfc.predict(X_test)

In [35]:
print(classification_report(y_test, prediction, target_names=df.category.unique()))

               precision    recall  f1-score   support

         tech       0.90      0.88      0.89       164
     business       0.84      0.88      0.86       199
        sport       0.93      0.97      0.95       203
entertainment       0.96      0.85      0.90       157
     politics       0.80      0.81      0.80       167

     accuracy                           0.88       890
    macro avg       0.89      0.88      0.88       890
 weighted avg       0.88      0.88      0.88       890

