In [None]:
from __future__ import annotations
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
import os
import sys
import chromadb

# Make the backend package (`modules.*`) importable from this notebook.
# The backend lives two directories up from the notebook's working directory.
# Resolve it to an absolute path once, so later changes of the working directory
# do not break imports, and only append it when missing, so re-running this cell
# does not keep adding duplicate entries to sys.path.
_backend_dir = os.path.abspath(os.path.join("..", "..", "backend"))
if _backend_dir not in sys.path:
    sys.path.append(_backend_dir)

In [None]:
from modules.utils import load_config_and_device
from modules.llm import setup_vector_db_and_qa
from modules.results_gen import aggregate_multiple_queries_and_count

In [None]:
# Config
#
# Root of the backend folder relative to this notebook. Every backend path below
# is derived from it, so moving the notebook only requires changing this one value.
BACKEND_DIR = "../../backend/"

# Load the backend configuration (the helper also resolves the device, per its name).
config = load_config_and_device(f"{BACKEND_DIR}config.json")
# Re-point the stored directories at the backend folder. The config's own paths
# are presumably relative to the backend directory, not to this notebook (verify).
config["persist_dir"] = f"{BACKEND_DIR}data/chroma_db/"
config["data_dir"] = f"{BACKEND_DIR}data/"

In [None]:
# Open the on-disk ChromaDB store located at config["persist_dir"].
client = chromadb.PersistentClient(
    path=config["persist_dir"],
)

In [None]:
# Build the retrieval-QA chain for the "dataset" data type on top of the
# Chroma client opened above. This initializes the retriever and the LLM and
# wires them into a Retrieval QA chain.
qa_dataset = setup_vector_db_and_qa(
    client=client,
    config=config,
    data_type="dataset",
)

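## Aggregate results

Run several COVID-19–related queries through the dataset QA chain and combine their results, grouped by `id` and `name`, with counting enabled (`count=True`).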

In [None]:
# Several phrasings of a COVID-19 dataset search; the results of all of them are
# merged into a single frame below.
queries = [
    "Find datasets related to COVID-19",
    "Find datasets related to COVID-19 and India",
    "COVID-19 dataset",
    "COVID-19 dataset India",
    "Mexico historical covid",
]

# count=True: aggregate over the (id, name) group columns with counting enabled;
# sort_by="query" presumably orders rows by the originating query (verify).
combined_df = aggregate_multiple_queries_and_count(
    queries,
    count=True,
    sort_by="query",
    group_cols=["id", "name"],
    config=config,
    qa_dataset=qa_dataset,
)

In [None]:
# Preview the first rows of the counted aggregation (assumes combined_df is a pandas DataFrame).
combined_df.head()

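## Collate only (no counts)

Run the same queries and grouping as above, but with `count=False`, so the results are combined without counting.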

In [None]:
# Re-run the same queries and grouping as the aggregation above, but with
# count=False so the results are only collated. This reuses `queries` from the
# previous section instead of a second hand-maintained copy, which keeps the two
# runs comparable if the query list is edited.
# NOTE: this overwrites the counted `combined_df` produced above.
combined_df = aggregate_multiple_queries_and_count(
    queries,
    qa_dataset=qa_dataset,
    config=config,
    group_cols=["id", "name"],
    sort_by="query",
    count=False,
)

In [None]:
# Preview the first rows of the collated (count=False) results (assumes combined_df is a pandas DataFrame).
combined_df.head()