In [1]:
import torch

from modules.llm import *
from modules.utils import *

## Config

In [6]:
# Central run configuration: every knob a reader may want to tune lives here.
config = {
    # Prompt for the retrieval-QA chain; {question} and {context} are template
    # slots, presumably filled in by create_llm_chain_and_query — confirm there.
    "rqa_prompt_template" : "This database is a list of dataset metadata. Use the following pieces of context to find the relevant document. Answer only from the context given using the {question} given. If you do not know the answer, say you do not know. {context}",
    # Number of documents the retriever returns per query.
    "num_return_documents" : 50,
    # Sentence-embedding model for the vector store.
    # Alternative tried earlier: "Intel/bge-small-en-v1.5-rag-int8-static".
    "embedding_model": "BAAI/bge-base-en-v1.5",
    "llm_model": "HuggingFaceH4/zephyr-7b-beta",
    "persist_dir": "./chroma_db/",
    # True rebuilds the vector store from metadata_df on every run (slow);
    # False reuses the store already persisted in persist_dir.
    "recreate_chroma": True,
    # Passed as recreate_cache to get_all_dataset_metadata_from_openml;
    # presumably True forces a fresh OpenML download.
    "recreate_data_cache" : False,
    "data_download_n_jobs" : 20,
    # Best available accelerator: CUDA GPU, then Apple-silicon MPS, then CPU.
    # A hard-coded "mps" breaks on non-Mac machines; overwrite this value
    # with "cuda", "mps" or "cpu" to force a specific device.
    "device" : (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    ),
}

## Data Processing

In [3]:
# Fetch OpenML dataset metadata; caching and parallelism come from the config cell.
openml_data_object, data_id, all_dataset_metadata = get_all_dataset_metadata_from_openml(
    recreate_cache=config["recreate_data_cache"],
    n_jobs=config["data_download_n_jobs"],
)

In [4]:
# Turn the downloaded metadata into a DataFrame, then clean it in a second step.
metadata_df, all_dataset_metadata = create_metadata_dataframe(
    openml_data_object,
    data_id,
    all_dataset_metadata,
)
metadata_df = clean_metadata_dataframe(metadata_df)

In [10]:
def _deduplicate_documents(documents):
    """Keep the first document for each distinct ``page_content``.

    Each document gets a deterministic id, ``uuid5(NAMESPACE_DNS, page_content)``,
    so identical chunks map to the same id and only the first one is kept.

    Returns:
        (docs, ids): two index-aligned lists; ``ids[i]`` is the id of ``docs[i]``.
    """
    unique_docs, unique_ids = [], []
    seen_ids = set()
    for doc in documents:
        doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, doc.page_content))
        if doc_id in seen_ids:
            continue
        seen_ids.add(doc_id)
        unique_docs.append(doc)
        unique_ids.append(doc_id)
    return unique_docs, unique_ids


def load_document_and_create_vector_store(
    metadata_df,
    persist_directory="./chroma_db/",
    model_name="Intel/bge-small-en-v1.5-rag-int8-static",
    device="cpu",
    normalize_embeddings=True,
    recreate_chroma=False,
    batch_size=100,
) -> Chroma:
    """Load (or build) a Chroma vector store over the dataset metadata.

    If ``persist_directory`` exists and ``recreate_chroma`` is False, the
    persisted store is opened as-is. Otherwise the existing collection (if any)
    is dropped, and the "Combined_information" column of ``metadata_df`` is split
    into 1000-character chunks (150 overlap), de-duplicated, embedded and
    inserted in batches.

    Args:
        metadata_df: DataFrame with a "Combined_information" text column.
        persist_directory: Directory where Chroma persists its data.
        model_name: Hugging Face sentence-embedding model id.
        device: Torch device for the embedding model ("cpu", "cuda", "mps").
        normalize_embeddings: Whether embeddings are L2-normalised by the encoder.
        recreate_chroma: Rebuild the store even if ``persist_directory`` exists.
        batch_size: Number of chunks sent to Chroma per ``add_documents`` call.

    Returns:
        The Chroma vector store.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Only settings HuggingFaceEmbeddings forwards to SentenceTransformer.encode;
    # "quantized" belongs to QuantizedBiEncoderEmbeddings and is not passed here.
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )

    store_exists = os.path.exists(persist_directory)
    if store_exists and not recreate_chroma:
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    if store_exists:
        # Reopening the persisted collection would keep stale chunks from a
        # previous build, so drop it before rebuilding.
        Chroma(
            persist_directory=persist_directory, embedding_function=embeddings
        ).delete_collection()

    loader = DataFrameLoader(metadata_df, page_content_column="Combined_information")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    chunks = text_splitter.split_documents(loader.load())
    unique_docs, unique_ids = _deduplicate_documents(chunks)

    db = Chroma(embedding_function=embeddings, persist_directory=persist_directory)
    # Batched inserts keep each embedding/insert call bounded in size; ids stay
    # aligned with their documents because both lists are sliced identically.
    for start in tqdm(range(0, len(unique_docs), batch_size)):
        stop = start + batch_size
        db.add_documents(unique_docs[start:stop], ids=unique_ids[start:stop])
    return db

In [11]:
# Open the persisted vector store, or build it when recreate_chroma is set.
vectordb = load_document_and_create_vector_store(
    metadata_df,
    persist_directory=config["persist_dir"],
    model_name=config["embedding_model"],
    device=config["device"],
    recreate_chroma=config["recreate_chroma"],
)



In [9]:
# Wire the retriever and LLM from the config into a retrieval-QA chain.
retriever, llm = create_retriever_and_llm(
    vectordb,
    num_return_documents=config["num_return_documents"],
    model_repo_id=config["llm_model"],
)
qa = create_llm_chain_and_query(
    vectordb=vectordb,
    retriever=retriever,
    llm=llm,
    prompt_template=config["rqa_prompt_template"],
)

  warn_deprecated(


## Getting results

In [10]:
# Other example queries to try (replace the string below):
#   "Which datasets would be useful for heart disease"
#   "Which datasets would be useful for flowers"
#   "Which datasets would be useful for image classification"
#   "My supervisor wants me to work on cloud cover, which datasets can I use"
#   "Are there any datasets from the netherlands?"
#   "Are there any datasets about farm animals?"
#   "Find chinese authors"
# Prefix the cell with %%time to measure query latency.
query = "Which datasets would be useful for stock market support?"
results = create_result_dataframe(query, qa, all_dataset_metadata)

In [11]:
# Display the result table produced by create_result_dataframe for `query`.
results

Unnamed: 0,did,name,version,uploader,status,format,MajorityClassSize,MaxNominalAttDistinctValues,MinorityClassSize,NumberOfClasses,NumberOfFeatures,NumberOfInstances,NumberOfInstancesWithMissingValues,NumberOfMissingValues,NumberOfNumericFeatures,NumberOfSymbolicFeatures,description
0,150,covertype,3,1,active,ARFF,283301.0,7.0,2747.0,7.0,55.0,581012.0,0.0,0.0,10.0,45.0,**Author**: Albert Bifet \n**Source**: [MOA](...
1,40992,sylva_agnostic,2,4265,active,ARFF,,2.0,,,217.0,14395.0,0.0,0.0,40.0,177.0,This dataset contains samples from both traini...
2,41523,test_dataset,3,8229,active,ARFF,,,,0.0,61.0,15547.0,14.0,280.0,61.0,0.0,Test dataset
3,41533,Domainome,1,8231,active,arff,1059.0,,14.0,3.0,9839.0,1637.0,1637.0,13231887.0,9838.0,0.0,Domain dataset
4,42183,dataset_sales,1,10333,active,ARFF,,,,0.0,15.0,10738.0,0.0,0.0,15.0,0.0,Dataset sales
5,42223,dataset-autoHorse_fixed,1,10792,active,ARFF,2.0,,1.0,186.0,69.0,201.0,0.0,0.0,68.0,1.0,Fixed dataset for autoHorse.csv I suggest...
6,42460,Dataset,1,12366,active,ARFF,,,,,12.0,63.0,52.0,52.0,0.0,12.0,dataset for bme
7,42680,RelevantImagesDatasetTEST,1,17002,active,ARFF,129149.0,,5582.0,2.0,31.0,134731.0,0.0,0.0,27.0,4.0,Testing dataset
8,42689,MY_DB,1,17516,active,ARFF,,,,,7.0,120.0,0.0,0.0,0.0,7.0,MY Dataset
9,42874,Laboratorio_dataset_car,1,25966,active,ARFF,,,,,7.0,1750.0,0.0,0.0,0.0,7.0,Laboratory dataset


In [9]:
# Peek at the first ten dataset descriptions.
results["description"].head(10).values

array(['**Author**: Robert P.W. Duin, Department of Applied Physics, Delft University of Technology  \n**Source**: [UCI](https://archive.ics.uci.edu/ml/datasets/Multiple+Features) - 1998  \n**Please cite**: [UCI](https://archive.ics.uci.edu/ml/citation_policy.html)   \n\n**Multiple Features Dataset: Factors**  \nOne of a set of 6 datasets describing features of handwritten numerals (0 - 9) extracted from a collection of Dutch utility maps. Corresponding patterns in different datasets correspond to the same original character. 200 instances per class (for a total of 2,000 instances) have been digitized in binary images. \n\n### Attribute Information  \nThe attributes represent 216 profile correlations. No more information is known.',
       "**Author**: Robert P.W. Duin, Department of Applied Physics, Delft University of Technology  \n**Source**: [UCI](https://archive.ics.uci.edu/ml/datasets/Multiple+Features) - 1998  \n**Please cite**: [UCI](https://archive.ics.uci.edu/ml/citation_poli

In [11]:
# Peek at the first ten dataset names.
results["name"].head(10).values

array(['covertype', 'sylva_agnostic', 'test_dataset', 'Domainome',
       'dataset_sales', 'dataset-autoHorse_fixed', 'Dataset',
       'RelevantImagesDatasetTEST', 'MY_DB', 'Laboratorio_dataset_car'],
      dtype=object)