In [2]:
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from sklearn.svm import SVC
import torch

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses, SentenceTransformerTrainingArguments
# Run on Apple's Metal (MPS) backend when torch reports it as available;
# otherwise fall back to the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Base sentence-embedding model, created directly on the selected device.
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
print("Model is loaded on:", model.device)

  from tqdm.autonotebook import tqdm, trange


Model is loaded on: mps:0


In [6]:
def combine_title_and_description(title, description):
    """Merge an article's title and description into one labelled text block.

    The result has two newline-terminated lines: ``TITLE: <title>`` followed
    by ``DESCRIPTION: <description>``.
    """
    template = 'TITLE: {}\nDESCRIPTION: {}\n'
    return template.format(title, description)

In [7]:
# Read the training CSV; the 'Class Index', 'Title' and 'Description' columns
# are the ones used below.
filename = 'AG_news_dataset/train.csv'
df = pd.read_csv(filename)

# Fixed seed so the shuffled order -- and therefore the train/test slices cut
# from it later in the notebook -- is the same on every run. Without it, a
# fine-tuned checkpoint loaded from an earlier session may already have been
# trained on rows that a fresh shuffle puts into the test slice.
RANDOM_SEED = 42

# Shuffle the three columns in unison so that position i still describes the
# same article in every array. Class indices are shifted down by one
# (presumably 1-based in the CSV, giving 0-based labels -- verify against the data).
categories, titles, descriptions = shuffle(
    np.array(df['Class Index']) - 1,
    np.array(df['Title']),
    np.array(df['Description']),
    random_state=RANDOM_SEED,
)

# One formatted "TITLE: ... / DESCRIPTION: ..." string per article.
sentences = np.vectorize(combine_title_and_description)(titles, descriptions)

In [13]:
# Work with roughly a third of the AG_news_dataset examples for training so
# that the notebook stays runnable on a laptop.

train_count = 40_000
test_count = 10_000

# The training rows come first; the evaluation rows follow directly after them.
train_rows = slice(0, train_count)
test_rows = slice(train_count, train_count + test_count)

sentences_train, categories_train = sentences[train_rows], categories[train_rows]
sentences_test, categories_test = sentences[test_rows], categories[test_rows]

In [20]:
from datasets import Dataset

# Wrap the training split as a Dataset with two columns: the formatted article
# text ('sentence') and its class id ('label').
finetune_dataset = Dataset.from_dict({
    'sentence': sentences_train,
    'label': categories_train,
})

# Triplet loss computed from the labelled examples within each batch.
loss = losses.BatchAllTripletLoss(model)

# output_dir is the root folder of the run: the trainer writes its periodic
# checkpoints into `checkpoint-<step>` sub-folders beneath it. Pointing it at
# '.../checkpoint-5000' itself would nest the final checkpoint one level too
# deep ('.../checkpoint-5000/checkpoint-5000'), which is not the path the
# fine-tuned model is loaded from later in this notebook.
args = SentenceTransformerTrainingArguments(
    output_dir='models/all-MiniLM-L6-v2-triplet-AG-news',
    num_train_epochs=1,
)

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=finetune_dataset,
    loss=loss,
    args=args,
)

# Author's note: training on the entire dataset would take about 3 hours and
# the computer would run out of memory, hence the reduced training split.
trainer.train()

  0%|          | 0/5000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [21]:
# Load the model stored in the checkpoint-5000 directory onto the same device
# as the base model.
checkpoint_path = 'models/all-MiniLM-L6-v2-triplet-AG-news/checkpoint-5000'
checkpoint_model = SentenceTransformer(checkpoint_path, device=device)

In [23]:
# Embed both splits with the checkpoint model; these vectors are the features
# the classifier is trained and evaluated on.
encode_fn = checkpoint_model.encode
embeddings_train = encode_fn(sentences_train)
embeddings_test = encode_fn(sentences_test)

In [24]:
# Fit an SVM with scikit-learn's default settings on the training embeddings.
#
# Timing notes from earlier experiments (fit time does not scale linearly with
# the number of examples):
#   * first 10_000 examples: 8.4 seconds
#   * first 40_000 examples: 2m52s; after fine-tuning the embeddings are less
#     difficult to fit (fewer support vectors), so about 1m10s
#   * all 120_000 examples: more than 9 minutes, whereas a linear
#     extrapolation from the 10_000 run (x12) would have predicted ~1m40s
svm = SVC()
svm.fit(X=embeddings_train, y=categories_train)


In [25]:
# Predict a class for every embedding in the held-out test split.
rest_categories_predicted = svm.predict(X=embeddings_test)

In [26]:
# Accuracy on the test split: the fraction of predictions that equal the true
# label. np.mean over the boolean mask is vectorized, unlike the built-in
# sum()/len() pair, which walks the array one element at a time in Python.
np.mean(rest_categories_predicted == categories_test)

0.9185