### Sentence Transformers SVC Multilabel
This notebook demonstrates the efficacy and simplicity of Sentence Transformers embeddings for multilabel text classification when paired with a One Vs Rest Linear Support Vector Classifier.

In [1]:
import pandas as pd
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer, util
from tqdm.notebook import tqdm
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import label_ranking_average_precision_score

In [2]:
df = pd.read_csv('res/toxic_comments.csv')

In [3]:
sample = df.sample(6_000)
texts = sample.comment_text.values

In [4]:
topics = df.columns[2:]

In [5]:
text_embeddings = SentenceTransformer(
    "T-Systems-onsite/cross-en-de-roberta-sentence-transformer"
).encode(
    texts.astype(str), 
    convert_to_tensor=True
).cpu().detach().numpy()

In [6]:
class SVC:
    def __init__(self, labels):
        self.labels = labels
        self.svc = OneVsRestClassifier(LinearSVC())
        
    def train(self, X, y):
        self.svc.fit(X, y)
    
    def predict(self, X):
        return self.svc.decision_function(X)
    
    def strings_to_multihot(self, strings):
        encodings = [
            np.array([1 if l in st else 0 for l in self.labels]) for st in strings
        ]
        return np.array(encodings)
    
    def multihot_to_string(self, multihots):
        decodings = [
            *filter(lambda x: x[0] == 1, zip(multihot, self.labels))[1]
        ]
        return decodings

In [7]:
with open('res/embeddings.pkl', 'wb') as f:
    pickle.dump(text_embeddings, f)

In [8]:
with open('res/embeddings.pkl', 'rb') as f:
    text_embeddings = pickle.load(f)

In [9]:
def get_labels(row):
    labels = []
    for t in topics:
        if row[t] == 1:
            labels.append(t)
            
    return labels

In [10]:
tqdm.pandas()

In [11]:
labels = sample.progress_apply(get_labels, axis=1)

  0%|          | 0/6000 [00:00<?, ?it/s]

In [12]:
labels = labels.values

In [13]:
classifier = SVC(labels=topics)

In [14]:
labels_multihot = classifier.strings_to_multihot(labels.tolist())

In [15]:
X_train, X_test, y_train, y_test = train_test_split(text_embeddings, labels_multihot, shuffle=True)

In [16]:
classifier.train(
    X=X_train,
    y=y_train
)



In [17]:
pred = classifier.predict(
    X=X_test,
)

In [18]:
label_ranking_average_precision_score(y_test, pred)

0.9818694444444446