In [1]:
import os
import json
import logging
from typing import Optional
from pathlib import Path

import pandas as pd
import numpy as np
from tqdm import tqdm
from file_processing import File
from file_processing.tools.errors import EmptySelection

import faiss
from sentence_transformers import SentenceTransformer

# Preprocess the files to embeddings

In [2]:
from file_processing import Directory

directory = Directory('./tests/resources/similarity_test_files/')

Preprocess the files to extract the text

In [3]:
data = [file.processor.__dict__ for file in directory._file_generator()]
data = pd.json_normalize(data, max_level=1, sep='_')
df = pd.DataFrame(data)

if df.empty:
    raise EmptySelection(f'Filtered selection of files is empty')
elif not df.empty:
    df = df.get(['size', 'extension', 'file_name',
                'metadata_text', 'absolute_path'])
    
df['metadata_text'] = df['metadata_text'].str.strip()
df['metadata_text'] = df['metadata_text'].str.replace('\n', '')

# Only keeping pdf/docx/txt files with sufficiently long 'text' metadata
df = df[(df['extension'].isin(['.pdf', '.docx', '.txt'])) &
        (df['metadata_text'].str.len() > 10) &
        (df['metadata_text'].notnull())]
df = df.reset_index(drop=True)
file_names = df.file_name

# Encoding
encoder = SentenceTransformer("paraphrase-MiniLM-L3-v2")
vectors = encoder.encode(df['metadata_text'])

df['metadata_text']

Processing files: 20 files completed [00:00, 79.66 files completed/s]


0     Aviation safety in CanadaFrom: Transport Canad...
1     The Canadian ConstitutionA constitution provid...
2     Causes of climate changeWhat is the most impor...
3     COVID-19: Symptoms, treatment, what to do if y...
4     Canada Pension Plan disability benefitsOvervie...
5     CPP Retirement pensionOverviewThe Canada Pensi...
6     Documents for Express EntryYou need certain do...
7     EI regular benefitsHow much you could receiveF...
8     How Express Entry worksExpress Entry is an onl...
9     Funding - Culture, history and sportCOVID-19: ...
10    Canada's health care systemLearn about Canada'...
11    History of CanadaCanadian history does not beg...
12    How the Courts are OrganizedPrevious Page Tabl...
13    Our Security, Our RightsOn June 21, 2019, an A...
14    Net-zero emissions by 2050The transition to a ...
15    Origin of the name "Canada"Today, it seems imp...
16    Personal income taxGet ready to do your taxesC...
17    Starting a businessTable of contentsBefore

In [4]:
vectors.shape

(20, 384)

In [5]:
np.save("tests/resources/faiss_test_files/sample_embeddings.npy", vectors)

In [6]:
np.load("tests/resources/faiss_test_files/sample_embeddings.npy").shape

(20, 384)

Do the FAISS Development

In [7]:
d = vectors.shape[1]

index = faiss.IndexFlatIP(d)
index.add(vectors)
index.ntotal

20

In [8]:
query_vec = encoder.encode(["Define data science", "what is the meaning of life?"])
# query_vec = query_vec[np.newaxis,:]
query_vec.shape

(2, 384)

In [9]:
np.save("tests/resources/faiss_test_files/sample_query_vector.npy", query_vec)

In [10]:
D, I = index.search(query_vec, 3)

In [11]:
print(D)

[[2.786271   2.458615   1.839469  ]
 [1.2800186  1.0237789  0.94802046]]


In [12]:
print(I)

[[14  2 10]
 [15  1 14]]


In [13]:
from importlib import reload
from file_processing.faiss_index import faiss_strategy
from file_processing.faiss_index import flat_index
reload(flat_index)
reload(faiss_strategy)

flat = flat_index.FlatIndex(vectors)
print(flat.query(query_vec, 3))
flat.save_index("file_processing/faiss_index/test_index.faiss")

(array([[37.001266, 37.346497, 38.138252],
       [26.461792, 27.62289 , 28.612665]], dtype=float32), array([[14, 12, 10],
       [15, 12, 19]], dtype=int64))


In [14]:
from file_processing import faiss_index
reload(faiss_index)

flat_index = faiss_index.create_flat_index(vectors, "file_processing/faiss_index/flat.faiss")
flat_index.query(query_vec, 3)

(array([[37.001266, 37.346497, 38.138252],
        [26.461792, 27.62289 , 28.612665]], dtype=float32),
 array([[14, 12, 10],
        [15, 12, 19]], dtype=int64))

Create some files for testing

In [15]:
flat = faiss_index.create_flat_index(vectors, "tests/resources/faiss_test_files/flat.faiss")
ivf = faiss_index.create_IVF_flat_index(vectors, file_path="tests/resources/faiss_test_files/ivf.faiss")
# hnsw = faiss_index.create_HNSW_index(vectors, file_path="tests/resources/faiss_test_files/hnsw.faiss")

In [16]:
hnsw = faiss_index.create_HNSW_index(vectors, file_path="tests/resources/faiss_test_files/hnsw.faiss")

In [25]:
type(hnsw)

file_processing.faiss_index.HNSW_index.HNSWIndex

In [37]:
ivf2 = faiss_index.create_IVF_flat_index(vectors, nlist=10)
ivf2.query(query_vec, k=8, nprobe=1)

(array([[3.7346497e+01, 3.8138252e+01, 3.4028235e+38, 3.4028235e+38,
         3.4028235e+38, 3.4028235e+38, 3.4028235e+38, 3.4028235e+38],
        [2.6461792e+01, 2.8612665e+01, 2.9678551e+01, 2.9906456e+01,
         3.1259750e+01, 3.2736320e+01, 3.4028235e+38, 3.4028235e+38]],
       dtype=float32),
 array([[12, 10, -1, -1, -1, -1, -1, -1],
        [15, 19,  6, 11,  9, 18, -1, -1]], dtype=int64))

In [44]:
# try creating new index
d = vectors.shape[1]
quantizer = faiss.IndexFlatL2(d)
ivfpq = faiss.IndexIVFPQ(quantizer, d, 4, 8, 3)
ivfpq.train(vectors)
ivfpq.add(vectors)
faiss.write_index(ivfpq, "tests/resources/faiss_test_files/ivfpq.faiss")