# How to add a faiss index to your datasets
https://huggingface.co/docs/datasets/faiss_es#faiss

In [None]:
# Install the libraries used in this notebook; stdout is redirected to /dev/null to keep the cell output short.
!pip install datasets transformers faiss-cpu >> /dev/null
# Note: the faiss-cpu pip package is not an official faiss release.

In [None]:
from datasets import load_dataset

# Request a single split explicitly: without `split=`, load_dataset returns a
# DatasetDict keyed by split name, which cannot be indexed by row number and
# does not offer the FAISS helpers (add_faiss_index / get_nearest_examples)
# used in the cells below.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")  # any dataset you want

In [None]:
# Swap in any model you like (the feature-extraction pipelines are worth a look too).
# Images are used here, but anything can be embedded: just pick a suitable model.
# NOTE(review): recent transformers releases reportedly deprecate BeitFeatureExtractor
# in favour of BeitImageProcessor — confirm against the installed version.
from transformers import BeitFeatureExtractor, BeitModel

feature_extractor = BeitFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224")
model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224")

In [None]:
# Writing the function this way we can use it also during query time
def embed(images):
    """Embed one image or a batch of images with the loaded model.

    Written as a plain function so the same code can be used both while
    building the index and at query time.

    Args:
        images: Image or list of images accepted by ``feature_extractor``.

    Returns:
        NumPy array holding the model's pooled output for the input.
    """
    import torch  # local import keeps this cell runnable on its own

    inputs = feature_extractor(images=images, return_tensors="pt")
    # Inference only: skip autograd bookkeeping, and do not request the
    # per-layer hidden states because only the pooled output is used.
    with torch.no_grad():
        outputs = model(**inputs)
    final_emb = outputs.pooler_output.numpy()  # this line depends on the model you are using
    return final_emb

In [None]:
# Compute an embedding for every image and store it in a new column.
dataset_emb = dataset.map(
    lambda batch: {"beit_embeddings": embed(batch["image"])},
    batched=True,
    batch_size=20,
)
# Build a FAISS index over that column.
dataset_emb.add_faiss_index(column="beit_embeddings")
# (optional) save to disk
dataset_emb.save_faiss_index("beit_embeddings", "beit_index.faiss")

In [None]:
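# Or skip the .map cell above and load a previously saved index from disk:
# dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
# dataset_emb = dataset  # the query cell below looks the index up through `dataset_emb`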

In [None]:
# Query: retrieve the 5 nearest neighbours of the first image in the dataset.
query_image = dataset[0]["image"]
scores, result_k = dataset_emb.get_nearest_examples(
    "beit_embeddings",
    embed(query_image),
    k=5,
)

In [None]:
# Show the retrieved images one after another.
from IPython.display import display

for image in result_k["image"]:
    display(image)

# If you weren't using the datasets FAISS support, you would have to do it as in the cells below:
There may be better ways, but this is what I was doing.

In [None]:
import faiss

def embed(ex, idx=None, add_index=True):
    """Embed images, then either add them to the global ``index`` or search it.

    One function serves both purposes so that indexing and querying share the
    exact same embedding code.

    Args:
        ex: With ``add_index=True``, a batch with an ``"image"`` entry (as
            passed by ``dataset.map``); otherwise the image(s) to query with.
        idx: With ``add_index=True``, the indices passed along by
            ``with_indices=True``; otherwise the number of neighbours to
            retrieve.
        add_index: ``True`` to add the embeddings to ``index``, ``False`` to
            search ``index`` with them.

    Returns:
        ``None`` when indexing; the ``(distances, ids)`` pair produced by
        ``index.search`` when searching.

    Raises:
        ValueError: If ``add_index`` is false and ``idx`` is ``None``.
    """
    import torch  # local import keeps this cell runnable on its own

    if not add_index and idx is None:
        # In search mode ``idx`` doubles as k; fail early with a clear message.
        raise ValueError("idx must be the number of neighbours to retrieve when add_index=False")

    images = ex["image"] if add_index else ex
    inputs = feature_extractor(images=images, return_tensors="pt")
    # Inference only: skip autograd bookkeeping, and do not request the
    # per-layer hidden states because only the pooled output is used.
    with torch.no_grad():
        outputs = model(**inputs)
    final_emb = outputs.pooler_output.numpy()

    if add_index:
        # NOTE(review): nothing is returned in this mode, so the assignment
        # below presumably does not end up as a column of the mapped dataset
        # — confirm. The None return is kept deliberately so that the map
        # call is used purely for its side effect on ``index``.
        ex['idx'] = idx
        index.add(final_emb)                  # add vectors to the index
        return None

    distances, neighbour_ids = index.search(final_emb, idx)
    return distances, neighbour_ids

In [None]:
# Width of the vectors stored in the index; it has to match what embed() produces
# (presumably 768 for this checkpoint — confirm if you change the model).
d = 768
index = faiss.IndexFlatL2(d)  # build the index
print(index.is_trained)
# embed() adds every batch of vectors to `index` while the dataset is mapped.
dataset = dataset.map(embed, batched=True, batch_size=20, with_indices=True)
print(index.ntotal)

In [None]:
# Search the hand-built index with the image of row 3, asking for 5 neighbours.
query_image = dataset[3]
distances, res_ids = embed(query_image["image"], 5, add_index=False)