In [1]:
import chromadb

from modules.llm import *
from modules.utils import *
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
# Register a SQLite-backed LangChain LLM cache (file: ./data/.langchain.db) as the
# global cache, so an LLM call repeated with the same prompt can be answered from
# disk instead of invoking the model again.
set_llm_cache(SQLiteCache(database_path="./data/.langchain.db"))

In [6]:
# Read config.json; per the helper's name and its [INFO] log it also picks the device.
config = load_config_and_device("config.json")

# Open the persistent ChromaDB store located at the configured directory.
client = chromadb.PersistentClient(
    path=config["persist_dir"],
)

print(config)

# Set up the vector DB, retriever and LLM chain (Retrieval QA) for dataset metadata.
qa_dataset = setup_vector_db_and_qa(
    config=config,
    data_type="dataset",
    client=client,
)

[INFO] Finding device.
[INFO] Device found: cpu
{'rqa_prompt_template': 'This database is a list of dataset metadata. Use the following pieces of context to find the relevant document. Answer only from the context given using the {question} given. If you do not know the answer, say you do not know. {context}', 'num_return_documents': 50, 'embedding_model': 'BAAI/bge-base-en-v1.5', 'persist_dir': './data/chroma_db/', 'data_download_n_jobs': 20, 'training': False, 'search_type': 'similarity', 'temperature': 0.95, 'top_p': 0.95, 'reranking': False, 'hybrid_search': True, 'device': 'cpu'}
[INFO] Loading metadata from file.
[INFO] Metadata loaded.
[INFO] Loading model...




[INFO] Model loaded.


In [7]:
# Fetch the OpenML metadata objects (the log below shows them being loaded from file).
openml_data_object, data_id, all_metadata = get_all_metadata_from_openml(config=config)

# Combine everything into a single metadata dataframe.
# `all_metadata` is rebound to the value returned by the helper.
metadata_df, all_metadata = create_metadata_dataframe(
    openml_data_object,
    data_id,
    all_metadata,
    config=config,
)

# Build the vector store from that dataframe, using the shared Chroma client.
vectordb = load_document_and_create_vector_store(
    metadata_df,
    config=config,
    chroma_client=client,
)

[INFO] Loading metadata from file.
[INFO] Metadata loaded.
[INFO] Loading model...




[INFO] Model loaded.


In [None]:
# Build the LLM chain on top of the vector store created in the previous cell.
# NOTE(review): no later cell visible in this notebook uses `qa` (the query cells
# use `qa_dataset`), and this cell has no execution count — confirm it is still needed.
qa = initialize_llm_chain(vectordb=vectordb, config=config)

## Get the most frequently used words in the dataset descriptions

In [3]:
# Read the dataset-description CSV into memory as one raw string (it is not parsed
# into rows here). The encoding is passed explicitly so the decoded text does not
# depend on the platform's default locale encoding.
# NOTE(review): assumes the file was written as UTF-8 — confirm against the code
# that exports data/all_dataset_description.csv.
with open("data/all_dataset_description.csv", "r", encoding="utf-8") as f:
    dataset_descriptions = f.read()

In [4]:
# Preview the first 100 characters: enough to see the CSV header row and the
# beginning of the first record.
dataset_descriptions[:100]

'did,description,qualities,features\n2,"**Author**: Unknown. Donated by David Sterling and Wray Buntin'

## Aggregate and test multiple queries

In [3]:
# Several phrasings of a COVID-19 dataset search, from generic to country-specific.
queries = [
    "Find datasets related to COVID-19",
    "Find datasets related to COVID-19 and India",
    "COVID-19 dataset",
    "COVID-19 dataset India",
    "Mexico historical covid",
]

# Run every query through the dataset QA object and aggregate the results,
# grouped by dataset id/name and sorted on the "query" column.
combined_df = aggregate_multiple_queries_and_count(
    queries,
    qa_dataset=qa_dataset,
    config=config,
    group_cols=["id", "name"],
    sort_by="query",
)

In [4]:
# Show the first 10 rows of the aggregated result.
# NOTE(review): the `query` column presumably counts how many of the queries
# returned each dataset — confirm in aggregate_multiple_queries_and_count.
combined_df.head(10)

Unnamed: 0,id,name,query
56,43733,Covid-19--historical-data,5
35,43400,COVID-19-community-mobility-reports,4
47,43509,COVID-19-Rio-de-Janeiro-(City),4
33,43365,Covid-19-Case-Surveillance-Public-Use-Dataset,4
34,43367,COVID-19-Indonesia-Dataset,4
37,43405,Covid-19-Turkey-Daily-Details-Dataset,4
38,43410,Coronavirus-Disease-(COVID-19),4
39,43412,COVID-19-Visualisation-and-Epidemic-Analysis-Data,4
41,43428,Mexico-COVID-19-clinical-data,4
42,43457,COVID19-Dataset-with-100-World-Countries,4


In [6]:
# Mushroom/fungi search phrasings, including informal and constraint-style wording.
queries = [
    "Find me datasets related to mushrooms",
    "Fungi dataset",
    "Mushroom dataset",
    "shroom data",
    "types of mushroom",
    "earth fungus",
    "low features mushroom dataset",
]

# Aggregate the results of all queries, grouped by dataset id/name and sorted on "query".
combined_df = aggregate_multiple_queries_and_count(
    queries,
    qa_dataset=qa_dataset,
    config=config,
    group_cols=["id", "name"],
    sort_by="query",
)

# Display the first 10 rows.
combined_df.head(10)

Unnamed: 0,id,name,query
80,44272,Meta_Album_FNG_Micro,6
97,44302,Meta_Album_FNG_Mini,6
1,24,mushroom,6
113,44335,Meta_Album_FNG_Extended,6
98,44303,Meta_Album_PLT_DOC_Mini,5
71,44242,Meta_Album_PLT_VIL_Micro,5
108,44321,Meta_Album_PLT_VIL_Extended,5
81,44273,Meta_Album_PLT_DOC_Micro,5
114,44336,Meta_Album_PLT_DOC_Extended,5
67,44237,Meta_Album_BCT_Micro,5


In [9]:
# Plant-dataset phrasings that also ask for a small number of features.
queries = [
    "plant datasets, low features",
    "plant, less number of features",
    "plant dataset, tiny",
]

# Aggregate the results of all queries, grouped by dataset id/name and sorted on "query".
combined_df = aggregate_multiple_queries_and_count(
    queries,
    qa_dataset=qa_dataset,
    config=config,
    group_cols=["id", "name"],
    sort_by="query",
)

# Display the first 10 rows.
combined_df.head(10)

Unnamed: 0,id,name,query
22,44154,iris_reproduced,3
34,44299,Meta_Album_MED_LF_Mini,3
29,44273,Meta_Album_PLT_DOC_Micro,3
24,44242,Meta_Album_PLT_VIL_Micro,3
20,40983,wilt,3
32,44286,Meta_Album_PLT_VIL_Mini,3
33,44293,Meta_Album_PLT_NET_Mini,3
16,1493,one-hundred-plants-texture,3
15,1492,one-hundred-plants-shape,3
14,1491,one-hundred-plants-margin,3
