In [None]:
# Install libraries
!pip install -q datasets -U sentence_transformers chromadb

In [None]:
import torch

# Check for GPU availability
def try_gpu():
  """Pick the compute device: CUDA when torch can see a GPU, otherwise the CPU."""
  if not torch.cuda.is_available():
    return torch.device('cpu')
  return torch.device('cuda')

In [None]:
# Download dataset from the hub
from datasets import load_dataset, Dataset

docs = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train")

# Extract only the columns [id, title, text]. Wrapping in list() guarantees plain
# Python lists (newer `datasets` releases may return a lazy column object), which
# is what the slicing below and chromadb's add() are given.
ids, titles, passages = list(docs["id"]), list(docs["title"]), list(docs["text"])

In [None]:
# Download sentence transformer embeddings model
from sentence_transformers import SentenceTransformer, util

device = try_gpu()
# Load the passage-embedding model, then move its weights onto `device`.
model = SentenceTransformer('multi-qa-mpnet-base-dot-v1')
model = model.to(device)

In [None]:
# Create Vector database with ChromaDB
import os
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Embedding function attached to the chromadb collection. chromadb uses it to
# embed query texts and any documents added without precomputed embeddings, so
# it must be the same model used to embed the passages.
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="multi-qa-mpnet-base-dot-v1",
)

In [None]:
def check_connection(client, db):
    """Return True if `client` already holds a collection named `db`.

    Every collection returned by `client.list_collections()` is checked, not just
    the first one, so the answer is correct when several collections exist.

    Args:
        client: chromadb client to inspect.
        db: collection name to look for.

    Returns:
        bool: True if a collection with that name exists, else False.
    """
    def _name(col):
        # Older chromadb returns Collection objects (with `.name`); newer
        # releases may return plain name strings. Accept both.
        return col if isinstance(col, str) else col.name

    return any(_name(col) == db for col in client.list_collections())

In [None]:
def create_db(client, db, docs, ids, embds=None):
  """Store `docs` in the chromadb collection `db`, creating the collection if needed.

  An existing collection is reused and extended. Otherwise a new one is created
  with cosine HNSW space. Either way it is bound to the module-level `ef`.

  Args:
    client: chromadb client holding the collection.
    db: collection name.
    docs: passage strings to store.
    ids: identifiers parallel to `docs`; stored as strings.
    embds: optional precomputed embeddings, one row per doc. Anything with a
      `.tolist()` (numpy array, torch tensor) or a plain sequence of vectors.
      When None, chromadb embeds `docs` itself via `ef`.

  Raises:
    ValueError: if `docs` and `ids` differ in length.
  """
  if len(docs) != len(ids):
    raise ValueError(f"docs and ids must have the same length, got {len(docs)} and {len(ids)}")

  if check_connection(client, db):
    collection = client.get_collection(name=db, embedding_function=ef)
  else:
    collection = client.create_collection(name=db, metadata={"hnsw:space":"cosine"}, embedding_function=ef)

  add_kwargs = {"documents": docs, "ids": [str(id_) for id_ in ids]}
  if embds is not None:
    # Convert to plain Python lists of floats before handing them to chromadb.
    if hasattr(embds, "tolist"):
      add_kwargs["embeddings"] = embds.tolist()
    else:
      add_kwargs["embeddings"] = [[float(v) for v in row] for row in embds]

  print("Embeddings generation will take some time depending on number of documents. Processing...")

  # Skip empty batches: there is nothing to add, and chromadb may reject empty id lists.
  if add_kwargs["ids"]:
    collection.add(**add_kwargs)

  print(f"Total docs stored in db: {collection.count()}")

  # The legacy duckdb+parquet client needs an explicit persist(). Newer clients
  # persist automatically and may not expose this method.
  persist = getattr(client, "persist", None)
  if callable(persist):
    persist()

In [None]:
# Mount Google Drive so the database can be saved there (Colab only).
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
# SETUP DATABASE
db = "<DB NAME>"          # Replace with the preferred name for the database (collection).
db_path = "<PATH/TO/DB>"  # Replace with the preferred storage path for the database (e.g., a location in Google Drive).

# Fail fast if the placeholders above were not replaced; otherwise a directory
# literally named "<PATH" would be created and used.
if db.startswith("<") or db_path.startswith("<"):
    raise ValueError("Replace the `db` and `db_path` placeholders before running this cell.")

path_to_db = os.path.join(db_path, "Database")
# makedirs (unlike mkdir) also creates missing parent directories; idempotent on re-run.
os.makedirs(path_to_db, exist_ok=True)
client = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=path_to_db))

db_exists = check_connection(client, db)

if db_exists:
    collection = client.get_collection(name=db, embedding_function=ef)
else:
    # Create the vector database from the first 10 passages.
    create_db(client, db, docs=passages[:10], ids=ids[:10])
    collection = client.get_collection(name=db, embedding_function=ef)

In [None]:
# Get the number of rows (stored passages) in the database collection;
# the value is displayed as this cell's output.
collection.count()

In [None]:
# Generate and save embeddings for new passages
# Due to resource (memory) constraints, we add
# new entries to the database in chunks of 10,000
%timeit
from tqdm import tqdm

step = 10000
for x in tqdm(range(10, len(ids), step)):
  start = x
  end = x + step
  corpus_embds = model.encode(passages[start:end], batch_size=128, show_progress_bar=True)
  corpus_embds = torch.from_numpy(corpus_embds)
  create_db(client, db, passages[start:end], ids[start:end], corpus_embds)

In [None]:
def query_db(query, n_results=1, db_collection=None):
  """Retrieve the passages closest to `query` together with their mean distance.

  Args:
    query: query text; chromadb embeds it with the collection's embedding function.
    n_results: maximum number of passages to retrieve.
    db_collection: collection to search. Defaults to the notebook-level `collection`.

  Returns:
    (context, score): `context` concatenates the retrieved passages, each followed
    by " | ". `score` is the mean of the returned distances rounded to 2 decimals
    (with the cosine space this notebook creates collections with, lower means
    closer), or 0.0 when nothing is returned.
  """
  target = collection if db_collection is None else db_collection
  results = target.query(query_texts=[query], n_results=n_results)

  passages_found = results["documents"][0]
  distances = results["distances"][0]
  context = "".join(doc + " | " for doc in passages_found)
  # Average over the distances actually returned: the collection may hold fewer
  # than `n_results` entries, and dividing by n_results would understate the mean.
  score = round(sum(distances) / len(distances), 2) if distances else 0.0

  return context, score