In [1]:
import copy 
from pathlib import Path
from typing import Optional

import torch
from torch import nn
from torch import Tensor, device

import pandas as pd
import numpy as np

from tqdm.notebook import tqdm, trange
from fastprogress.fastprogress import master_bar, progress_bar

from sentence_transformers import SentenceTransformer, util

import config
from faiss_indexer import FAISS

  warn("Couldn't import ipywidgets properly, progress bar will use console behavior")


In [2]:
# Show the name of the full-size model that a later cell loads into `model`.
config.MODEL

'paraphrase-mpnet-base-v2'

In [19]:
# Load the small pre-trained model used as the comparison point for the pruned model.
# NOTE(review): this cell's execution count is out of order with its position in the
# notebook — confirm the notebook still works with Restart Kernel -> Run All.
distill_model = SentenceTransformer(config.MODEL_SMALL)
# Echo the model name so the saved output records which checkpoint was used.
config.MODEL_SMALL

'paraphrase-MiniLM-L3-v2'

In [3]:
# Load the full-size model that gets pruned and benchmarked below.
model = SentenceTransformer(config.MODEL)
# Count the encoder layers on the loaded model instead of hard-coding 12, so the
# value stays correct if config.MODEL is pointed at a model of a different depth.
# This is the same module path that deleteEncodingLayers walks.
LAYERS = len(model._modules['0']._modules['auto_model']._modules['encoder']._modules['layer'])

In [4]:
# Embedding width of the full model, read off the last axis of a one-sentence encode.
dim = model.encode(["hello"]).shape[-1]
dim

768

In [5]:
# Build a random float tensor with one row per fake embedding and `dim` columns.
# NOTE(review): `a` is not referenced by any later cell visible here, and the RNG is
# not seeded — confirm whether this scratch cell is still needed.
n_emb = 10000
emb_size = dim
a = torch.rand(n_emb, emb_size)
a.dtype, a.shape

(torch.float32, torch.Size([10000, 768]))

## Prune top layers from sentence-transformer

In [6]:
# encoder = model._modules['0']._modules['auto_model']._modules['encoder']._modules['layer']
# encoder

In [7]:
def deleteEncodingLayers(model, layers_to_keep):  # must pass in the full model
    """Return an independent copy of ``model`` that keeps only selected encoder layers.

    The encoder layers are looked up at
    ``model._modules['0']._modules['auto_model']._modules['encoder']._modules['layer']``,
    so ``model`` must expose that module path.

    Args:
        model: The full model. It is not modified.
        layers_to_keep: Iterable of indices into the encoder layer list, in the
            order the layers should appear in the returned model.

    Returns:
        A deep copy of ``model`` whose encoder layer list holds only the requested
        layers. The copy shares no parameters with ``model``.

    Raises:
        IndexError: If an index is out of range for the encoder layer list.
    """
    # Copy first, then select layers from the copy. Selecting from the original and
    # attaching them to the copy would make both models share the same layer
    # objects, so training or mutating one would silently change the other.
    copyOfModel = copy.deepcopy(model)
    autoModel = copyOfModel._modules['0']._modules['auto_model']
    encoder = autoModel._modules['encoder']
    copiedLayers = encoder._modules['layer']

    # Only keep the relevant layers, in the order requested.
    keptLayers = nn.ModuleList([copiedLayers[i] for i in layers_to_keep])
    encoder._modules['layer'] = keptLayers

    # If the wrapped model carries a config that records the layer count, keep it
    # in sync with the pruned encoder so a save/reload round trip does not
    # re-create the removed layers. Models without such a config are left alone.
    modelConfig = getattr(autoModel, 'config', None)
    if hasattr(modelConfig, 'num_hidden_layers'):
        modelConfig.num_hidden_layers = len(keptLayers)

    return copyOfModel

In [8]:
# Reference embedding from the untouched model. It does not depend on the loop
# variable, so it is computed once instead of on every iteration.
encodings = model.encode('hello')

for remove_n_layers in range(3):
    # Keep the first (LAYERS - remove_n_layers) encoder layers, i.e. drop from the top.
    layers_to_keep = list(range(0, LAYERS-remove_n_layers))
    print(layers_to_keep)
    small_model = deleteEncodingLayers(model, layers_to_keep)
    new_encodings = small_model.encode('hello')
    # Cosine similarity between the full and the pruned embedding of the same text.
    cos = util.cos_sim(encodings, new_encodings)
    print(cos)

# small_model stays bound to the last variant built above (two top layers removed);
# that is the pruned model the later timing, encoding and search cells use.

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
tensor([[1.0000]])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
tensor([[0.6800]])
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
tensor([[0.5433]])


## Check time reduction

In [9]:
# Baseline: time a single-sentence encode with the full model.
%timeit model.encode('hello how are you?')

11.6 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# Same measurement for the pruned model; small_model is whatever the layer-removal
# loop bound last.
%timeit small_model.encode('hello how are you?')

9.58 ms ± 139 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Check quality

In [11]:
# Load the evaluation data for the dataset named in config.
dataset_name = config.DATASET
save_path = Path(f"data/{dataset_name}")
# NOTE(review): read_pickle unpickles the file, which can execute arbitrary code —
# only load data.pkl from a trusted source.
df = pd.read_pickle(save_path/"data.pkl")
df.head()

Unnamed: 0,query_text,doc_text,relevance
0,How does Quora look to a moderator?,What does the Quora website look like to members of Quora moderation?,1
1,How do I refuse to chose between different things to do in my life?,Is it possible to pursue many different things in life?,1
2,Did Ben Affleck shine more than Christian Bale as Batman?,"According to you, whose Batman performance was best: Christian Bale or Ben Affleck?",1
3,Did Ben Affleck shine more than Christian Bale as Batman?,"No fanboys please, but who was the true batman, Christian Bale or Ben Affleck?",1
4,Did Ben Affleck shine more than Christian Bale as Batman?,Who do you think portrayed Batman better: Christian Bale or Ben Affleck?,1


In [12]:
# Deduplicated query strings; these are the texts that get embedded and indexed.
texts = df.query_text.unique().tolist()

In [24]:
# Embed the same texts with all three models so their search results can be compared.
vectors = model.encode(texts)
# small_model is the pruned copy left bound by the layer-removal loop.
small_vectors = small_model.encode(texts)
distill_vectors = distill_model.encode(texts)
# Show the three shapes side by side (row counts should match; widths may differ).
vectors.shape, small_vectors.shape, distill_vectors.shape

((5000, 768), (5000, 768), (5000, 384))

In [25]:
# Size each index from the embeddings it will hold rather than hard-coding 768/384,
# so the cell keeps working if config.MODEL or config.MODEL_SMALL changes.
# The first FAISS argument is taken to be the vector width, matching the literals
# this replaces (768, 768, 384) and the embedding shapes printed above.
index = FAISS(vectors.shape[-1], gpu=False)
small_index = FAISS(small_vectors.shape[-1], gpu=False)
distill_index = FAISS(distill_vectors.shape[-1], gpu=False)

In [26]:
# Register every text in all three indexes, each with the embedding produced by
# the corresponding model. Row i of each embedding array belongs to texts[i].
for i, text in enumerate(tqdm(texts)):
    index.add(text, [vectors[i]])
    small_index.add(text, [small_vectors[i]])
    distill_index.add(text, [distill_vectors[i]])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5000.0), HTML(value='')))




## paraphrase-mpnet-base-v2 results

In [27]:
# Single query string reused by the three search cells.
query = "how to get better health"

# Embed the query with the full model and search the index built from its vectors.
query_vector = model.encode([query])
print(query_vector.shape)
# NOTE(review): k=20 is requested but the saved output lists only 10 rows — confirm
# whether the output was truncated or search caps the result count.
index.search(query_vector, dataframe=True, k=20)

(1, 768)


Unnamed: 0,text,cosine_sim
0,Is long distance running healthy?,0.55
1,What are some good tips to live to be 100?,0.55
2,What is the healthiest food?,0.53
3,How do you stay energetic?,0.53
4,What is the best way to improve stamina?,0.53
5,What is the best supplement to use if I need more energy?,0.52
6,What does a healthy diet consist of?,0.52
7,How can I gain weight and develop fitness?,0.51
8,Is eating bread good for health?,0.47
9,What are the health benefits of doing pushups everyday?,0.47


## cropped paraphrase-mpnet-base-v2 results

In [28]:
# Same query (defined in the full-model search cell), embedded with the pruned
# model and searched against the index built from the pruned model's vectors.
# query_vector is rebound here, replacing the full-model embedding.
query_vector = small_model.encode([query])
print(query_vector.shape)
small_index.search(query_vector, dataframe=True, k=20)

(1, 768)


Unnamed: 0,text,cosine_sim
0,How do I make myself more productive and happy?,0.73
1,How can we lead a better life?,0.72
2,What is the best way to improve stamina?,0.72
3,How can I change my baby fat in a healthy way？?,0.71
4,How can I take excellent care of my teeth?,0.71
5,How do I maintain our face clean and oily less?,0.7
6,How does one become more strategic?,0.7
7,How do you stay energetic?,0.7
8,What are the best ways to improve your body language?,0.7
9,How can I gain weight and develop fitness?,0.69


## paraphrase-MiniLM-L3-v2 results

In [31]:
# Same query (defined in the full-model search cell), embedded with the distilled
# model and searched against the index built from the distilled model's vectors.
# query_vector is rebound again and now has the distilled model's width.
query_vector = distill_model.encode([query])
print(query_vector.shape)
distill_index.search(query_vector, dataframe=True, k=20)

(1, 384)


Unnamed: 0,text,cosine_sim
0,Which is the best health insurance policy?,0.52
1,What EHR/EMR is best for public health departments?,0.48
2,How do I get better at developing a new skill?,0.47
3,Is eating bread good for health?,0.44
4,What are the health benefits of doing pushups everyday?,0.43
5,Is it better for our health to shower at night or in the morning?,0.42
6,Why do good people suffer more in life?,0.41
7,Is M.D quite a sufficient degree to treat the patient?,0.41
8,What is the best medicine for sex?,0.41
9,How can we lead a better life?,0.4
