NOTE: This was originally a notebook by Andrej Karpathy. Check it out here: https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb

# kNN vs. SVM

A very common workflow is to index some data by its embeddings and then, given a new query embedding, retrieve the most similar examples with k-Nearest Neighbor search. For example, you could embed a large collection of papers by their abstracts and then, given a new paper of interest, retrieve the papers most similar to it.

TLDR in my experience it ~always works better to use an SVM instead of kNN, if you can afford the slight computational hit. Example below:

In [None]:
import numpy as np
np.random.seed(42)

embeddings = np.random.randn(1000, 1536) # 1000 documents, 1536-dimensional embeddings
embeddings = embeddings / np.sqrt((embeddings**2).sum(1, keepdims=True)) # L2 normalize the rows, as is common

query = np.random.randn(1536) # the query vector
query = query / np.sqrt((query**2).sum())

In [None]:
# Tired: use kNN
similarities = embeddings.dot(query)
sorted_ix = np.argsort(-similarities)
print("top 10 results:")
for k in sorted_ix[:10]:
  print(f"row {k}, similarity {similarities[k]}")

top 10 results:
row 545, similarity 0.07956628031855817
row 790, similarity 0.07109372365891174
row 973, similarity 0.06920799481214632
row 597, similarity 0.06474824575503951
row 479, similarity 0.06350781255023313
row 229, similarity 0.0614321834997024
row 976, similarity 0.061222853526241586
row 568, similarity 0.060888722805113274
row 800, similarity 0.06007081261453451
row 654, similarity 0.058158824328240384


In [None]:
query.shape

(1536,)

In [None]:
# Wired: use an SVM
from sklearn import svm

# create the "Dataset"
x = np.concatenate([query[None,...], embeddings]) # x is (1001, 1536) array, with query now as the first row
y = np.zeros(1001)
y[0] = 1 # we have a single positive example, mark it as such

# train our (Exemplar) SVM
# docs: https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y) # train

# infer on whatever data you wish, e.g. the original data
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
print("top 10 results:")
for k in sorted_ix[:10]:
  print(f"row {k}, similarity {similarities[k]}")

top 10 results:
row 0, similarity 0.9797112617216351
row 546, similarity -0.8360649738915675
row 791, similarity -0.8519226181122038
row 974, similarity -0.8585435504683989
row 480, similarity -0.8620392370633861
row 598, similarity -0.8653315003700203
row 230, similarity -0.8671983886478062
row 569, similarity -0.8674761579346135
row 977, similarity -0.8705646065664832
row 801, similarity -0.8728033782558365




In practice you will find that this ordering:

- is of higher quality
- is slower: we have to train an SVM
- can easily accommodate a number of positives, not just one, so it is more flexible

Two more notes:

- don't be scared of having a single positive and everything else being negative. This is totally fine!
- if you have way, way too many negatives, consider subsampling and only using a portion of them.
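
The last points (several positives, subsampled negatives) can be sketched as follows. This is a minimal illustration on synthetic data, not part of the original notebook: the array sizes, the `n_neg` budget of 500, and the variable names are assumptions chosen for the example.

```python
import numpy as np
from sklearn import svm

rng = np.random.default_rng(0)

# Synthetic stand-ins for real data (assumption): 2000 unit-norm documents, 3 positives
docs = rng.standard_normal((2000, 64))
docs /= np.linalg.norm(docs, axis=1, keepdims=True)
positives = rng.standard_normal((3, 64))
positives /= np.linalg.norm(positives, axis=1, keepdims=True)

# Subsample the negatives used for training (hypothetical budget of 500)
n_neg = 500
neg_ix = rng.choice(len(docs), size=n_neg, replace=False)

# Positives first, then the sampled negatives
x = np.concatenate([positives, docs[neg_ix]])
y = np.concatenate([np.ones(len(positives)), np.zeros(n_neg)])

clf = svm.LinearSVC(class_weight='balanced', max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y)

# Rank the full collection, not only the negatives that were sampled for training
scores = clf.decision_function(docs)
top10 = np.argsort(-scores)[:10]
print(top10)
```

Training only sees the sampled negatives, but `decision_function` can still score every document, so the ranking covers the whole collection.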

**Value of C**: You'll want to tune C. You'll most likely find the best setting to be between 0.01 and 10. Values like 10 very severely penalize the classifier for any mispredictions on your data, so it will make sure to fit your data. Values like 0.01 incur less penalty and make the classifier more regularized; usually this is what you want. I find that in practice a value like 0.1 works well if you only have a few examples that you don't trust too much. If you have more examples and they are very noise-free, try something more like 1.0.
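
One rough way to see what C does is to sweep it and watch how far the SVM ranking drifts from the plain kNN ranking. The sketch below uses synthetic data, and the overlap count is only a diagnostic I am assuming for illustration; it is not a quality metric, and real tuning needs relevance judgments on your own data.

```python
import numpy as np
from sklearn import svm

rng = np.random.default_rng(0)

# Synthetic stand-ins (assumption): 500 unit-norm documents and one unit-norm query
docs = rng.standard_normal((500, 64))
docs /= np.linalg.norm(docs, axis=1, keepdims=True)
query = rng.standard_normal(64)
query /= np.linalg.norm(query)

# Single positive (the query) against all documents as negatives
x = np.concatenate([query[None, :], docs])
y = np.zeros(len(x))
y[0] = 1

# kNN baseline: top-10 documents by dot product
knn_top = set(np.argsort(-docs.dot(query))[:10].tolist())

overlaps = {}
for C in [0.01, 0.1, 1.0, 10.0]:
    clf = svm.LinearSVC(class_weight='balanced', max_iter=10000, tol=1e-6, C=C)
    clf.fit(x, y)
    svm_top = set(np.argsort(-clf.decision_function(docs))[:10].tolist())
    overlaps[C] = len(knn_top & svm_top)
    print(f"C={C}: {overlaps[C]}/10 of the SVM top-10 are also in the kNN top-10")
```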

**Why does this work?** In simple terms, because the SVM considers the entire cloud of data as it optimizes for the hyperplane that "pulls apart" your positives from your negatives. In comparison, the kNN approach doesn't consider the global manifold structure of your entire dataset and "values" every dimension equally. The SVM basically finds the way in which your positive example is unique in the dataset, and then only considers its unique qualities when ranking all the other examples.
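
To make this a bit more concrete: kNN scores a document `d` as `query · d`, while a linear SVM scores it as `w · d + b`, where `w` and `b` are learned from the whole dataset. The sketch below (synthetic data, illustrative names) checks that the SVM score really is this linear form and prints how closely the learned direction `w` aligns with the raw query direction.

```python
import numpy as np
from sklearn import svm

rng = np.random.default_rng(0)

# Synthetic stand-ins (assumption): 500 unit-norm documents and one unit-norm query
docs = rng.standard_normal((500, 64))
docs /= np.linalg.norm(docs, axis=1, keepdims=True)
query = rng.standard_normal(64)
query /= np.linalg.norm(query)

x = np.concatenate([query[None, :], docs])
y = np.zeros(len(x))
y[0] = 1

clf = svm.LinearSVC(class_weight='balanced', max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y)

w = clf.coef_[0]       # learned weight vector, one weight per embedding dimension
b = clf.intercept_[0]  # learned bias

# The SVM ranking score is just a linear function of the embedding
manual_scores = docs.dot(w) + b
print("matches decision_function:", np.allclose(manual_scores, clf.decision_function(docs)))

# Cosine between the learned direction and the query direction kNN would use
cos = w.dot(query) / (np.linalg.norm(w) * np.linalg.norm(query))
print(f"cosine(w, query) = {cos:.3f}")
```

If the cosine is below 1, the SVM is ranking along a direction that differs from the query itself, which is the sense in which it reweights dimensions using the rest of the data.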

Ok cool try it out.

In [None]:
!pip install datasets sentence_transformers

In [21]:
import numpy as np
import pandas as pd
import textwrap

In [5]:
# load IMDB dataset from huggingface
from datasets import load_dataset
imdb = load_dataset('imdb')




In [6]:
# load a pretrained sentence transformer model, and embed the IMDB reviews
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('paraphrase-distilroberta-base-v1')
embeddings = model.encode(imdb['train']['text'][:1000], show_progress_bar=True)



In [23]:
# use the first review of the test split as the query
query = model.encode([imdb['test']['text'][0]], show_progress_bar=True)[0]
print(textwrap.fill(f"Query, label {imdb['test']['label'][0]}, text: {imdb['test']['text'][0]}", width=100))


Query, label 0, text: I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are
usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it
is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard
sets, stilted dialogues, CG that doesn't match the background, and painfully one-dimensional
characters cannot be overcome with a 'sci-fi' setting. (I'm sure there are those of you out there
who think Babylon 5 is good sci-fi TV. It's not. It's clichéd and uninspiring.) While US viewers
might like emotion and character development, sci-fi is a genre that does not take itself seriously
(cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It's really
difficult to care about the characters here as they are not simply foolish, just missing a spark of
life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of
Earth KNOW it's rubbish

In [20]:
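# Tired: use kNN
# note: these embeddings are not L2-normalized, so this is a raw dot product, not a cosine similarity
similarities = embeddings.dot(query)
sorted_ix = np.argsort(-similarities)
print("top 10 results:")

# Collect the top-10 rows, then build the DataFrame in one go
# (DataFrame.append is deprecated and was removed in pandas 2.0)
rows = []
for k in sorted_ix[:10]:
    k = int(k)  # plain Python int for indexing the dataset columns
    rows.append({
        'Row': k,
        'Similarity': similarities[k],
        'Label': imdb['train']['label'][k],
        'Text': imdb['train']['text'][k]
    })
df = pd.DataFrame(rows, columns=['Row', 'Similarity', 'Label', 'Text'])

# Display the DataFrame in the notebook
df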


top 10 results:




| | Row | Similarity | Label | Text |
|---|---|---|---|---|
| 0 | 368 | 39.398106 | 0 | I cannot believe how popular this show is. I c... |
| 1 | 219 | 33.196064 | 0 | I'm not a huge Star Trek fan, but I was lookin... |
| 2 | 644 | 30.747005 | 0 | I love Columbo and have seen pretty much all o... |
| 3 | 940 | 30.543634 | 0 | Anyone who has a remote interest in science fi... |
| 4 | 223 | 30.486486 | 0 | I really wanted to like this western, being a ... |
| 5 | 794 | 30.141048 | 0 | Where to start... If this movie had been a dar... |
| 6 | 11 | 29.86915 | 0 | I can't believe that those praising this movie... |
| 7 | 488 | 29.799562 | 0 | I love special effects and witnessing new tech... |
| 8 | 577 | 29.78463 | 0 | This movie was messed up. A sequel to "John Ca... |
| 9 | 513 | 29.598888 | 0 | The Good Earth is perhaps the most boring film... |
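
The similarities in this output are around 30 to 39, which can only happen if the embeddings are not unit-norm, so the dot product here is not a cosine similarity (unlike the first example, where the rows were L2-normalized). If cosine similarity is what you want, one option is to normalize first. The sketch below uses synthetic arrays and a hypothetical helper `l2_normalize`; it does not reproduce the IMDB numbers.

```python
import numpy as np

def l2_normalize(a, eps=1e-12):
    """Scale vectors along the last axis to unit L2 norm (eps guards against zero vectors)."""
    norms = np.linalg.norm(a, axis=-1, keepdims=True)
    return a / np.maximum(norms, eps)

rng = np.random.default_rng(0)

# Synthetic stand-ins (assumption): unnormalized embeddings and query
embeddings = rng.standard_normal((1000, 768)) * 3.0
query = rng.standard_normal(768) * 3.0

raw = embeddings.dot(query)                                   # raw dot product, unbounded
cosine = l2_normalize(embeddings).dot(l2_normalize(query))    # cosine similarity, in [-1, 1]

print("raw dot product range:", raw.min(), raw.max())
print("cosine range:", cosine.min(), cosine.max())
```

The two scores can rank documents differently, because the raw dot product also rewards embeddings with a large norm.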


In [24]:
# Wired: use an SVM
from sklearn import svm

# create the "Dataset"
x = np.concatenate([query[None,...], embeddings]) # x is (1001, 768) array (this model outputs 768-dim embeddings), with query now as the first row
y = np.zeros(1001)
y[0] = 1 # we have a single positive example, mark it as such

# train our (Exemplar) SVM
# docs: https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y) # train

# infer on whatever data you wish, e.g. the original data
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
print("top 10 results:")
for k in sorted_ix[:10]:
  # need to adjust k index since rows of x are one offset from imdb data
  if k != 0:
    print(f"row {k-1}, similarity {similarities[k]}, label {imdb['train']['label'][k-1]}, text: {imdb['train']['text'][k-1]}")
  else:
    print(f"row {k}, similarity {similarities[k]}, label {imdb['test']['label'][0]}, text: {imdb['test']['text'][0]}")

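# Collect the top-10 rows, then build the DataFrame in one go
# (DataFrame.append is deprecated and was removed in pandas 2.0)
rows = []
for k in sorted_ix[:10]:
    k = int(k)  # plain Python int for indexing the dataset columns
    if k != 0:
        # rows of x are offset by one from the IMDB train rows
        rows.append({
            'Row': k-1,
            'Similarity': similarities[k],
            'Label': imdb['train']['label'][k-1],
            'Text': imdb['train']['text'][k-1]
        })
    else:
        # row 0 of x is the query itself (first review of the test split)
        rows.append({
            'Row': k,
            'Similarity': similarities[k],
            'Label': imdb['test']['label'][0],
            'Text': imdb['test']['text'][0]
        })
df = pd.DataFrame(rows, columns=['Row', 'Similarity', 'Label', 'Text'])

# Display the DataFrame in the notebook
df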

top 10 results:
row 0, similarity 0.9989627820856244, label 0, text: I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn't match the background, and painfully one-dimensional characters cannot be overcome with a 'sci-fi' setting. (I'm sure there are those of you out there who think Babylon 5 is good sci-fi TV. It's not. It's clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It's really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful 



| | Row | Similarity | Label | Text |
|---|---|---|---|---|
| 0 | 0 | 0.998963 | 0 | I love sci-fi and am willing to put up with a ... |
| 1 | 368 | -0.73205 | 0 | I cannot believe how popular this show is. I c... |
| 2 | 70 | -0.876357 | 0 | What does the " Executive producer " do in a m... |
| 3 | 257 | -0.897535 | 0 | Hail Bollywood and men Directors !<br /><br />... |
| 4 | 728 | -0.902951 | 0 | This has an interesting, albeit somewhat fanci... |
| 5 | 371 | -0.905119 | 0 | 1 How is it that everyone can understand each ... |
| 6 | 858 | -0.910527 | 0 | Robert Wagner is the evil boss of Digicron, a ... |
| 7 | 497 | -0.91608 | 0 | C'mon guys some previous reviewers have nearly... |
| 8 | 940 | -0.920194 | 0 | Anyone who has a remote interest in science fi... |
| 9 | 219 | -0.930657 | 0 | I'm not a huge Star Trek fan, but I was lookin... |


**What about the computational cost of SVM?** You might be worried about the fact that SVM is more expensive to run than kNN. This is true, but not by much. For example, in the above example, the 1000x1536 matrix was indexed in 0.5 seconds by the kNN approach. The SVM approach took 0.6 seconds to train, and then 0.4 seconds to infer on the same 1000 examples. So in total it took 1.0 seconds, which is only 2x slower than kNN. This is a small price to pay for the quality gains.
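
If you want to check the cost on your own hardware, a timing harness along these lines is one option. It is a sketch on synthetic data of the same 1000x1536 shape; it is not the code that produced the figures quoted above, and your timings will differ.

```python
import time
import numpy as np
from sklearn import svm

rng = np.random.default_rng(42)

# Synthetic stand-ins (assumption): 1000 unit-norm documents, 1536 dimensions
embeddings = rng.standard_normal((1000, 1536))
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
query = rng.standard_normal(1536)
query /= np.linalg.norm(query)

# kNN: one matrix-vector product plus a sort
t0 = time.perf_counter()
knn_order = np.argsort(-embeddings.dot(query))
knn_time = time.perf_counter() - t0

# SVM: build the dataset and train
x = np.concatenate([query[None, :], embeddings])
y = np.zeros(len(x))
y[0] = 1
t0 = time.perf_counter()
clf = svm.LinearSVC(class_weight='balanced', max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y)
fit_time = time.perf_counter() - t0

# SVM: score and sort the same documents
t0 = time.perf_counter()
svm_order = np.argsort(-clf.decision_function(embeddings))
infer_time = time.perf_counter() - t0

print(f"kNN: {knn_time:.4f}s | SVM fit: {fit_time:.4f}s | SVM infer: {infer_time:.4f}s")
```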

**What about the memory cost of SVM?** You might also be worried about the memory cost of SVM. This is also true, but again not by much. In the above example