In [None]:
!pip install faiss-cpu
!pip install -U sentence-transformers
import numpy as np
import torch
import os
import json
from itertools import islice
import pandas as pd
import faiss
import time
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from faiss.contrib.ondisk import merge_ondisk

### Constants

In [None]:
# FAISS index hyper-parameters
d = 768      # dimensionality of the stored vectors
nlist = 100  # third argument handed to faiss.IndexIVFPQ in train_index
m = 12       # number of subquantizers
nbits = 8    # each subvector is encoded with nbits bits (bucketing)
k = 10       # default number of neighbours requested by execute_query
probe = 10   # default value assigned to index.nprobe by execute_query

# On-disk locations; file names are appended to these strings directly,
# so both must end with a path separator.
path_to_vectors = "/content/vectors/"
index_dir = "/content/"

# Use the first CUDA device when torch reports one, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### Loading and generating vectors utils

In [None]:
def generate_vectors(path_to_vectors=path_to_vectors, partitions=10, number=1_000_000, size=768):
  """Write random vectors to disk as ``partitions`` .npy blocks plus a manifest.

  Block ``i`` is saved as ``block_{i}.npy`` inside ``path_to_vectors`` and
  holds uniform random values from ``np.random.rand``. A JSON manifest named
  ``meta.txt`` is written next to the blocks with the keys ``file_num``,
  ``number`` and ``file_to_ids`` (file name -> ``[first_id, last_id)`` range).

  Args:
    path_to_vectors: Directory for the blocks and the manifest; created if
      it does not exist.
    partitions: Number of block files to write.
    number: Total number of vectors across all blocks.
    size: Dimensionality of every vector.
  """
  os.makedirs(path_to_vectors, exist_ok=True)

  # Every block gets number // partitions rows; the last block also takes the
  # remainder so the id ranges always add up to the "number" stored in the manifest.
  block_len = number // partitions if partitions else 0
  bounds = [block_len * i for i in range(partitions)] + [number]

  file_to_ids = {}
  for i in range(partitions):
    first_id, last_id = bounds[i], bounds[i + 1]
    filename = f"block_{i}.npy"
    block = np.random.rand(last_id - first_id, size)
    np.save(os.path.join(path_to_vectors, filename), block)
    file_to_ids[filename] = (first_id, last_id)

  meta = {"file_num": partitions, "number": number, "file_to_ids": file_to_ids}
  with open(os.path.join(path_to_vectors, "meta.txt"), "w") as f:
    f.write(json.dumps(meta))


def read_vectors_iterator(path_to_vectors=path_to_vectors):
  """Lazily yield ``(vectors, ids)`` for every block listed in ``meta.txt``.

  The manifest ``meta.txt`` inside ``path_to_vectors`` must be JSON with a
  ``file_to_ids`` mapping of block file name -> id range. Blocks are loaded
  one at a time with ``np.load``.

  If the manifest is missing, unreadable, not valid JSON or has no
  ``file_to_ids`` entry, a message is printed and the generator yields nothing.

  Args:
    path_to_vectors: Directory containing ``meta.txt`` and the block files.

  Yields:
    Tuples ``(array, ids)`` where ``array`` is the loaded block and ``ids``
    is the value stored for that file in the manifest.
  """
  try:
    with open(os.path.join(path_to_vectors, "meta.txt"), "r") as f:
      # The lookup raises when the manifest lacks the expected key.
      file_to_ids = json.load(f)["file_to_ids"]
  except (OSError, ValueError, KeyError, TypeError) as err:
    # ValueError covers json.JSONDecodeError; TypeError covers a non-dict manifest.
    print(f"wrong file format or file is not present in directory: {err}")
    return

  for filename, ids in file_to_ids.items():
    yield np.load(os.path.join(path_to_vectors, filename)), ids

### Model and data loading for text task

In [None]:
def load_data():
  """Download the ABC news headlines CSV and return its headlines as a list."""
  headlines_url = "https://github.com/franciscadias/data/raw/master/abcnews-date-text.csv"
  headlines_frame = pd.read_csv(headlines_url)
  # Only the "headline_text" column is used by the notebook.
  return headlines_frame["headline_text"].to_list()

def load_model():
  """Create the SentenceTransformer used to embed the headlines."""
  model_name = 'distilbert-base-nli-mean-tokens'
  return SentenceTransformer(model_name)

### FAISS utility functions

In [None]:
def train_index(data, d=d, nlist=nlist, m=m, k=k, nbits=nbits, index_dir=index_dir):
  """Train an empty IVF-PQ index on ``data`` and save it as ``trained.index``.

  Args:
    data: Training vectors, one per row, with ``d`` columns. Any numeric
      dtype is accepted; the matrix is converted to contiguous float32.
    d: Dimensionality of the vectors.
    nlist: Passed to ``faiss.IndexIVFPQ`` as its ``nlist`` argument.
    m: Number of subquantizers.
    k: Unused here; kept so existing callers keep working.
    nbits: Number of bits used to encode each subvector.
    index_dir: Prefix for the output file. The file name is appended
      directly, so it must end with a path separator.

  Returns:
    The trained (still empty) index object.
  """
  # FAISS operates on C-contiguous float32 matrices; the text pipeline in this
  # notebook builds a float64 matrix, so normalise before training.
  training_vectors = np.ascontiguousarray(data, dtype=np.float32)

  quantizer = faiss.IndexFlatL2(d)
  index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
  index.train(training_vectors)
  faiss.write_index(index, index_dir + "trained.index")
  return index

def populate_index(data_iterator, index_dir=index_dir):
  """Add every data block to its own index file, then merge them on disk.

  For each ``(vectors, ids)`` pair a fresh copy of ``trained.index`` is
  loaded, filled with that block and written as ``block_{n}.index``. All
  block indexes are then merged with ``merge_ondisk`` into
  ``merged_index.ivfdata`` and the result is saved as ``populated.index``.

  Args:
    data_iterator: Iterable of ``(vectors, ids)`` where ``ids[0]`` and
      ``ids[1]`` are the first (inclusive) and last (exclusive) id of the block.
    index_dir: Prefix for every file read or written. File names are
      appended directly, so it must end with a path separator.
  """
  # Record the files actually written. Deriving the list from the final loop
  # counter (range(partition)) left the last block out of the merge.
  block_fnames = []
  for partition, (vectors, ids) in enumerate(data_iterator):
    print(f"adding part of data {partition} to index")
    # Reload the trained index so every block starts from an empty one.
    index = faiss.read_index(index_dir + "trained.index")
    index.add_with_ids(
        np.ascontiguousarray(vectors, dtype=np.float32),
        np.arange(ids[0], ids[1], dtype=np.int64),
    )
    block_fname = index_dir + f"block_{partition}.index"
    faiss.write_index(index, block_fname)
    block_fnames.append(block_fname)

  # construct the output index
  index = faiss.read_index(index_dir + "trained.index")
  merge_ondisk(index, block_fnames, index_dir + "merged_index.ivfdata")
  faiss.write_index(index, index_dir + "populated.index")

def execute_query(query, k=k, probe=probe, populated_index_path=index_dir + "populated.index"):
  """Search the populated index for the ``k`` nearest neighbours of each query.

  Args:
    query: Query vectors, one per row; converted to contiguous float32.
    k: Number of neighbours to return per query vector.
    probe: Value assigned to ``index.nprobe`` before searching.
    populated_index_path: Path of the index file to load.

  Returns:
    The result of ``index.search(query, k)``; in the recorded notebook output
    this is a ``(distances, ids)`` pair of arrays.
  """
  # Load the index the caller asked for; the argument used to be ignored in
  # favour of the global default location.
  # NOTE: the index is re-read from disk on every call.
  index = faiss.read_index(populated_index_path)
  index.nprobe = probe
  return index.search(np.ascontiguousarray(query, dtype=np.float32), k)

### Text task solution pipeline

In [None]:
def text_generating_index_pipeline(partitions=10, training_part=0.5, path_to_vectors=path_to_vectors, path_to_index=index_dir):
  """Embed the headline corpus, store it in blocks and build the FAISS index.

  Steps:
    1. Load the headlines and the sentence encoder (moved to ``device``).
    2. Encode the corpus in ``partitions`` slices, saving each slice as
       ``block_{i}.npy`` and recording its id range in ``meta.txt``.
    3. Load the first ``int(partitions * training_part)`` blocks (at least
       one) as training data and train the index with ``train_index``.
    4. Fill and merge the index with ``populate_index``.

  Args:
    partitions: Number of blocks the corpus is split into; must be >= 1.
    training_part: Fraction of the blocks used to train the index.
    path_to_vectors: Directory for the vector blocks and ``meta.txt``.
    path_to_index: Forwarded as ``index_dir`` to ``train_index`` and
      ``populate_index``.

  Raises:
    ValueError: If ``partitions`` is smaller than 1.
  """
  # Fail before the expensive download/encoding rather than inside FAISS later.
  if partitions < 1:
    raise ValueError("partitions must be a positive integer")

  data = load_data()
  model = load_model()
  model.to(device)
  os.makedirs(path_to_vectors, exist_ok=True)

  # The message says "decoding", although this step encodes texts into vectors;
  # the wording is kept so logs stay comparable with earlier runs.
  print("decoding texts")
  total = len(data)
  meta = {"file_num": partitions, "number": total, "file_to_ids": {}}
  for idx in tqdm(range(partitions)):
    # Integer arithmetic keeps the boundaries exact; the last slice ends at total.
    i0, i1 = idx * total // partitions, (idx + 1) * total // partitions
    encoded = model.encode(data[i0:i1])
    filename = f"block_{idx}.npy"
    np.save(os.path.join(path_to_vectors, filename), encoded)
    meta["file_to_ids"][filename] = (i0, i1)

  with open(os.path.join(path_to_vectors, "meta.txt"), "w") as f:
    f.write(json.dumps(meta))

  print("generating train data")
  # Always train on at least one block, and never on more than exist.
  n_train_blocks = min(partitions, max(1, int(partitions * training_part)))
  # Load all blocks first and concatenate once: appending inside the loop
  # re-copies the accumulated matrix on every iteration (quadratic time), and
  # seeding it with an empty float64 array upcasts the embeddings.
  training_blocks = [
      np.load(os.path.join(path_to_vectors, f"block_{idx}.npy"))
      for idx in tqdm(range(n_train_blocks))
  ]
  training_data = np.concatenate(training_blocks)

  print("training index")
  train_index(training_data, index_dir=path_to_index)
  populate_index(read_vectors_iterator(path_to_vectors), index_dir=path_to_index)

### Generating index

In [None]:
# Build the index with the corpus split into 1000 blocks.
text_generating_index_pipeline(partitions=1000)

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.99k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/550 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/450 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

decoding texts


100%|██████████| 1000/1000 [07:18<00:00,  2.28it/s]


generating train data


100%|██████████| 500/500 [05:03<00:00,  1.65it/s]


training index
adding part of data 0 to index
adding part of data 1 to index
adding part of data 2 to index
adding part of data 3 to index
adding part of data 4 to index
adding part of data 5 to index
adding part of data 6 to index
adding part of data 7 to index
adding part of data 8 to index
adding part of data 9 to index
adding part of data 10 to index
adding part of data 11 to index
adding part of data 12 to index
adding part of data 13 to index
adding part of data 14 to index
adding part of data 15 to index
adding part of data 16 to index
adding part of data 17 to index
adding part of data 18 to index
adding part of data 19 to index
adding part of data 20 to index
adding part of data 21 to index
adding part of data 22 to index
adding part of data 23 to index
adding part of data 24 to index
adding part of data 25 to index
adding part of data 26 to index
adding part of data 27 to index
adding part of data 28 to index
adding part of data 29 to index
adding part of data 30 to index
add

### Querying index

In [None]:
# Load the headlines and the encoder needed to build queries.
data = load_data()
model = load_model()
model.to(device)

In [None]:
# Use the first five headlines as queries against the populated index.
query_vectors = model.encode(data[:5])
execute_query(query_vectors)

(array([[ 71.38407 ,  84.14787 ,  86.641426,  86.70545 ,  86.98414 ,
          88.393394,  89.011826,  89.27888 ,  90.28634 ,  90.89114 ],
        [ 89.8884  , 108.68803 , 109.99781 , 110.69162 , 111.576546,
         112.152794, 113.09178 , 113.09522 , 113.553474, 113.887436],
        [ 66.44713 ,  81.6371  ,  82.32962 ,  83.49385 ,  85.31445 ,
          85.33592 ,  85.78581 ,  85.80759 ,  86.21858 ,  87.25972 ],
        [ 90.70556 ,  97.824234, 100.00536 , 101.104004, 101.54645 ,
         102.37016 , 104.18228 , 105.69875 , 106.291   , 106.41821 ],
        [ 68.743195,  85.63048 ,  86.62741 ,  87.617516,  88.79246 ,
          89.03182 ,  90.634476,  91.612206,  91.84524 ,  91.846275]],
       dtype=float32),
 array([[      0,  116826,   75117,   33519, 1063832,  181021,   21498,
          849971,  665155, 1062953],
        [      1,  262540,  901689,  621413,   43449,  273522,  105431,
            4572,  667934,  142663],
        [      2,  350162,  390736,    2029,  261048,   11678, 

По полученным результатам, текстовый эквивалент которых приведён ниже, можно заметить, что движок поиска ближайших соседей, основанный на индексе FAISS, действительно работает.

In [None]:
# Headlines behind the ids returned for the first query (data[0]).
neighbour_ids = [0, 116826, 75117, 33519, 1063832, 181021, 21498,849971, 665155, 1062953]
print(*(data[data_index] for data_index in neighbour_ids), sep="\n")

aba decides against community broadcasting licence
council bans bondi butts
commonwealth requests media ban from high court
court overturns muslim centre ban
judge blocks trumps travel ban
lantana to be banned from wa
aba suspends wagin radio licence
judge throws out attempt to limit eastman inquiry
council bans csg projects
brenton kelly dickson appeal denied


In [None]:
# Headlines behind the ids returned for the second query (data[1]).
neighbour_ids = [1, 262540, 901689, 621413, 43449, 273522, 105431, 4572, 667934, 142663]
print(*(data[data_index] for data_index in neighbour_ids), sep="\n")

act fire witnesses must be aware of defamation
accused abuser threatens defamation against
cfmeu officials respond to allegations of intimidation
bashing victim pleads for witnesses
carr stands by blackmail claims
prouds accused of misleading advertising
ethicist under fire for human rights views
warne and lee blackmail witness statements
albanese accused of plagiarising speech
lawyers criticise surveillance of habib


In [None]:
# Headlines behind a list of neighbour ids starting with data[3];
# presumably the result row for the fourth query (that row is cut off in the output above).
neighbour_ids = [3, 408801, 28128, 4805, 512029, 33887, 289143, 169987, 387149, 999341]
print(*(data[data_index] for data_index in neighbour_ids), sep="\n")

air nz staff in aust strike for pay rise
tafe lecturers call for pay rise
childcare workers apply to airc for pay rise
workers union set to campaign for pay rise
maritime workers strike for big pay rise
salt works staff seek pay rise
bridgestone workers to strike to get pay rise
councillors push for pay rise
union seeks pay rise for academic staff
call for mackay council workers to get interim pay rise
