In [82]:
import autorootcwd
import chromadb
import requests
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from chromadb import Documents, EmbeddingFunction, Embeddings
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
import pandas as pd
import numpy as np
import uuid
import os

# Prepare data from taxonomy for migration to ChromaDB

Our final solution includes creating a vector database for storing IPTC category embeddings. Therefore, we migrate the embeddings of our IPTC taxonomy to ChromaDB.

First, we load the old taxonomy along with the expanded taxonomy, which was created with the extended IPTC category descriptions.

In [31]:
# Old IPTC taxonomy; the first CSV column becomes the index.
taxonomy = pd.read_csv(
    'data/taxonomy/taxonomy.csv',
    index_col=0,
)
# Expanded taxonomy, also indexed by its first CSV column.
taxonomy_expanded = pd.read_csv(
    'data/taxonomy/taxonomy_expanded.csv',
    index_col=0,
)

For each IPTC category, our vector database keeps its name, its hierarchy and its extended description (categories without an extended description keep their name and original description).

In [98]:
# Assemble the table that gets migrated to ChromaDB: one row per category in
# `taxonomy` (same index) with its name, hierarchy and the text to embed.
taxonomy_chroma = taxonomy[['name', 'hierarchy']].copy()

# Default document text: the category name combined with its original description.
taxonomy_chroma['final_description'] = (
    'Name: ' + taxonomy['name'] + '\nDescription: ' + taxonomy['description']
)

# Categories present in `taxonomy_expanded` use their extended description instead.
# DataFrame.update aligns on the index and overwrites in a single vectorised step,
# replacing the former row-by-row iterrows() loop. Unlike per-row `.loc` assignment,
# it never appends rows for codes that are missing from `taxonomy` (those rows would
# have had NaN name/hierarchy), and NaN values in `taxonomy_expanded` do not wipe
# out the default text.
taxonomy_chroma.update(taxonomy_expanded[['final_description']])

In [100]:
# Persist the prepared table (index included) as CSV.
taxonomy_chroma.to_csv(
    path_or_buf='data/taxonomy/taxonomy_chroma.csv',
)

We prepare a metadata list (which will be used when migrating data from taxonomy to ChromaDB) out of all gathered information about our IPTC categories

In [111]:
# Column-wise views of the prepared table; the index is stored as `name_code`.
name_code = taxonomy_chroma.index.tolist()
name = taxonomy_chroma['name'].tolist()
hierarchy = taxonomy_chroma['hierarchy'].tolist()
documents = taxonomy_chroma['final_description'].tolist()

# One random UUID string per category, used as the record id in the collection.
ids = [str(uuid.uuid4()) for _ in name]

# Parallel lists, keyed by the metadata field they provide.
metadata = {
    'name': name,
    'name_code': name_code,
    'hierarchy': hierarchy,
}

# Transpose the dict of parallel lists into one metadata dict per category.
metadatas = [dict(zip(metadata, fields)) for fields in zip(*metadata.values())]


# Initialize chroma database

Set up a local Chroma server with the command `chroma run --path ./chroma`. The client below connects to it at `localhost:8000`.

In [103]:
# HTTP client for the locally running Chroma server.
client = chromadb.HttpClient(
    host='localhost',
    port=8000,
)

INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.
INFO:chromadb.telemetry.product.posthog:Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.


Get the embedding function from HuggingFace (you need to provide your HuggingFace API key in the `HUGGING_FACE_KEY` environment variable)

In [84]:
# The API key is read from the environment so it never appears in the notebook;
# a missing HUGGING_FACE_KEY variable raises KeyError here.
hugging_face_key = os.environ['HUGGING_FACE_KEY']

# Chroma embedding function for the all-MiniLM-L6-v2 sentence-transformers model,
# authenticated with the HuggingFace key.
huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    api_key=hugging_face_key,
)

Create vector database with IPTC top hierarchy embeddings

In [114]:
# Create the collection on the first run and reuse it afterwards.
# `create_collection` raises when "vector_db" already exists, so re-running the
# notebook against the same Chroma server used to fail at this cell.
collection = client.get_or_create_collection(
    "vector_db",
    embedding_function=huggingface_ef,
)

# Populate only an empty collection: `ids` are freshly generated UUIDs, so adding
# again on a re-run would store every category a second time under new ids.
if collection.count() == 0:
    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
    )

Example query

In [None]:
# Single best-matching category for a free-text query.
collection.query(
    n_results=1,
    query_texts=["covid-19 - people die in hospital"],
)

Another example query with filtering over hierarchy = 1

In [122]:
# Two best matches, restricted to records whose `hierarchy` metadata equals 1.
collection.query(
    where={"hierarchy": 1},
    n_results=2,
    query_texts=["covid-19 - people die in hospital"],
)

{'ids': [['970caa29-7d80-46c0-8334-81baa7d322cf',
   'deff917c-84d5-4d76-abb5-a16057f6f657']],
 'distances': [[1.7081502676010132, 1.7603156566619873]],
 'embeddings': None,
 'metadatas': [[{'hierarchy': 1,
    'name': 'disaster and accident',
    'name_code': 'subj:03000000'},
   {'hierarchy': 1, 'name': 'health', 'name_code': 'subj:07000000'}]],
 'documents': [['Name: disaster and accident\nDescription: Man made and natural events resulting in loss of life or injury to living creatures and/or damage to inanimate objects or property. \nKeywords: earthquake, flood, hurricane, tsunami, tornado, wildfire, drought, volcanic eruption, avalanche, landslide, industrial accident, nuclear disaster, chemical spill, oil spill, radiation leak, explosion, structural collapse, mining disaster, transportation accident, airplane crash, train derailment, shipwreck, traffic collision, bridge failure, pipeline burst, electrical fire, gas leak explosion, terrorism, armed conflict, riot, building fire, da

# Optional: use LangChain

In [71]:
# Local sentence-transformers model that LangChain uses to embed queries.
embedding_function = SentenceTransformerEmbeddings(model_name='all-MiniLM-L6-v2')

# Wrap the "vector_db" collection that this notebook populated with the taxonomy.
# The previous name, "all-MiniLM-L6-v2_DB", is never created or filled anywhere in
# this notebook, so searches through LangChain would not see the migrated categories.
# NOTE(review): if a separately populated "all-MiniLM-L6-v2_DB" collection is
# intended here, restore that name.
db_langchain = Chroma(
    client=client,
    collection_name="vector_db",
    embedding_function=embedding_function,
)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO:sentence_transformers.SentenceTransformer:Use pytorch device: cpu


In [74]:
# Free-text query; the matching documents are kept in `docs`.
query = "John doe killed 2 people in a car crash"
docs = db_langchain.similarity_search(query=query)

Batches: 100%|██████████| 1/1 [00:00<00:00, 29.83it/s]
