In [1]:
import torch
from sentence_transformers import SentenceTransformer
# Use Apple's Metal backend (MPS) when PyTorch reports it as available;
# otherwise fall back to the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Load the sentence-embedding model directly onto the chosen device.
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
print("Model is loaded on:", model.device)

  from tqdm.autonotebook import tqdm, trange


Model is loaded on: mps:0


In [3]:
import numpy as np
import pandas as pd
from sklearn.utils import shuffle

# Training split: only the description text and the class label are used.
filename = 'AG_news_dataset/train.csv'
df = pd.read_csv(filename)

# Subtracting 1 turns the 1-based 'Class Index' column into labels 0..N-1.
# Texts and labels are shuffled together (fixed seed) so they stay aligned.
descriptions_train, categories_train = shuffle(
    np.array(df['Description']),
    np.array(df['Class Index'] - 1),
    random_state=0,
)

In [4]:
# Embed every training description with the sentence-transformer model loaded earlier.
# NOTE(review): this runs the model over the whole training split, so it is
# presumably the slowest cell -- consider caching the result to disk.
embeddings_train = model.encode(descriptions_train)

In [9]:
# Test split: same columns and same preprocessing as the training split.
filename_test = 'AG_news_dataset/test.csv'
df_test = pd.read_csv(filename_test)

# Subtracting 1 turns the 1-based 'Class Index' column into labels 0..N-1.
# Texts and labels are shuffled together (fixed seed) so they stay aligned.
descriptions_test, categories_test = shuffle(
    np.array(df_test['Description']),
    np.array(df_test['Class Index'] - 1),
    random_state=0,
)

In [11]:
# Embed every test description with the same model used for the training split,
# so both sets of vectors live in the same embedding space.
embeddings_test = model.encode(descriptions_test)

In [14]:
# Pairwise similarity of every test embedding against every training embedding
# (one row per test item, one column per training item).
# NOTE(review): the metric is whatever the model is configured with --
# presumably cosine similarity; confirm against the sentence-transformers docs.
similarities = model.similarity(embeddings_test, embeddings_train)

In [15]:
# Sanity check: expect (number of test descriptions, number of training descriptions).
similarities.shape

torch.Size([7600, 120000])

In [49]:
from collections import Counter

def vote(categories, weights):
    """Return the category with the largest total weight.

    Each category is paired positionally with a weight; weights belonging to
    the same category are summed and the category with the highest sum wins.
    When several categories tie, the one that appeared first in `categories`
    is returned.

    Args:
        categories: Iterable of hashable category labels.
        weights: Iterable of numeric weights, paired with `categories`.

    Returns:
        The winning category label.

    Raises:
        ValueError: If there are no (category, weight) pairs to vote on.
    """
    totals = {}
    for label, score in zip(categories, weights):
        totals[label] = totals.get(label, 0) + score
    # Dicts keep insertion order and max() returns the first maximal element,
    # so ties resolve to the earliest-seen category.
    return max(totals, key=totals.get)

# Similarity-weighted k-nearest-neighbour classification of the test set.
k = 10  # number of most-similar training items that vote on each test item
categories_test_predicted = np.zeros_like(categories_test)
for row_index, row_similarity in enumerate(similarities):
    # argpartition leaves the indices of the k largest similarities in the
    # last k slots (in no particular order) without fully sorting the row.
    neighbour_indices = np.argpartition(row_similarity, -k)[-k:]
    # Each neighbour votes for its own category, weighted by its similarity.
    categories_test_predicted[row_index] = vote(
        categories_train[neighbour_indices],
        row_similarity[neighbour_indices],
    )

In [47]:
# Accuracy = fraction of test items whose predicted category matches the true one.
# np.count_nonzero counts the matches in a single vectorised pass; the built-in
# sum() would iterate over the NumPy boolean array one element at a time.
accuracy = np.count_nonzero(categories_test_predicted == categories_test) / len(categories_test)

In [48]:
# Report the fraction of test descriptions that were classified correctly.
print(accuracy)

0.9067105263157895
