In [88]:
import os, tempfile
from pathlib import Path
from glob import glob

from langchain.chains import RetrievalQA, ConversationalRetrievalChain, LLMChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.chroma import Chroma

# from langchain.llms import VertexAI
from langchain_community.llms import HuggingFaceHub
from langchain_community.document_loaders import (
    DirectoryLoader,
    PyPDFLoader,
    PyPDFDirectoryLoader,
    DirectoryLoader,
)
from langchain.text_splitter import (
    CharacterTextSplitter,
    TextSplitter,
    RecursiveCharacterTextSplitter,
)
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.indexes import VectorstoreIndexCreator
from tqdm import tqdm
import pickle
from pqdm.threads import pqdm

import pandas as pd

from typing import List, Optional, Tuple, Union

import openml


In [89]:
def get_dataset_description(dataset_name) -> Optional[openml.datasets.dataset.OpenMLDataset]:
    """Fetch the metadata object of one OpenML dataset without downloading its data.

    Args:
        dataset_name: Dataset identifier forwarded as-is to
            ``openml.datasets.get_dataset`` (an id or a name).

    Returns:
        The ``OpenMLDataset`` object, or ``None`` if the request failed. The error
        is printed rather than raised, so one bad dataset does not abort a bulk,
        parallel fetch.
    """
    try:
        return openml.datasets.get_dataset(
            dataset_name,
            download_data=False,
            download_qualities=False,
            download_features_meta_data=False,
        )
    except Exception as e:
        # Best-effort: report and skip. Returning None explicitly avoids the
        # UnboundLocalError that returning an unassigned variable would raise.
        print(e)
        return None

In [90]:
def get_all_dataset_metadata_from_openml(save_filename = "all_dataset_metadata.pkl", n_jobs = 10) -> Tuple[List, List]:
    """Return metadata for every OpenML dataset together with the dataset ids.

    The slow per-dataset fetch is cached in ``save_filename`` with pickle. Delete
    that file to force a refetch. Only load cache files you created yourself,
    because unpickling can execute arbitrary code.

    Args:
        save_filename: Path of the pickle cache.
        n_jobs: Number of parallel threads used for the fetch.

    Returns:
        ``(all_data_descriptions, data_id)``: two lists aligned by position.
        ``all_data_descriptions[i]`` is what ``get_dataset_description`` returned
        for ``data_id[i]``. Failed fetches do not yield an ``OpenMLDataset``; expect
        ``None`` or, depending on pqdm's exception handling, an exception object
        (TODO confirm for the installed pqdm version).
    """
    # The listing is cheap and is always needed to recover the ids.
    all_datasets = openml.datasets.list_datasets(output_format="dataframe")

    # Plain Python ints: openml.datasets.get_dataset accepts int or str ids, not numpy ints.
    data_id = [int(did) for did in all_datasets["did"]]

    if os.path.exists(save_filename):
        with open(save_filename, "rb") as f:
            all_data_descriptions = pickle.load(f)
        if len(all_data_descriptions) != len(data_id):
            import warnings

            # The cache is positionally zipped with data_id downstream. A length
            # change means OpenML's listing moved on and the pairing may be off.
            warnings.warn(
                f"Cached metadata in {save_filename!r} has {len(all_data_descriptions)} entries "
                f"but OpenML currently lists {len(data_id)} datasets; delete the cache to refetch."
            )
        return all_data_descriptions, data_id

    # Fetch by id, not by name. OpenML dataset names are not unique (versions of a
    # dataset share a name), so a name lookup can return a different dataset's
    # metadata than the `did` it is later paired with.
    all_data_descriptions = pqdm(data_id, get_dataset_description, n_jobs=n_jobs)

    with open(save_filename, "wb") as f:
        pickle.dump(all_data_descriptions, f)

    return all_data_descriptions, data_id

In [91]:
def create_metadata_dataframe(all_data_descriptions, data_id) -> pd.DataFrame:
    """Build a ``did`` / ``description`` table from fetched dataset metadata.

    Args:
        all_data_descriptions: Metadata objects with a ``description`` attribute,
            position-aligned with ``data_id``. Entries without that attribute become
            a missing description instead of raising. This covers ``None`` for a
            failed fetch, or an exception object.
        data_id: Dataset ids. ``data_id[i]`` belongs to ``all_data_descriptions[i]``.

    Returns:
        DataFrame with columns ``did`` and ``description``. If an id repeats, its
        last description wins. If the inputs differ in length, the surplus is
        dropped (``zip`` semantics).
    """
    # getattr with a default tolerates failed fetches instead of raising AttributeError.
    descriptions = [getattr(meta, "description", None) for meta in all_data_descriptions]

    all_data_description = dict(zip(data_id, descriptions))

    return pd.DataFrame(list(all_data_description.items()), columns=["did", "description"])

In [92]:
def clean_metadata_dataframe(metadata_df) -> pd.DataFrame:
    """Drop rows whose ``description`` is missing or blank.

    Index labels are kept, not reset, and a new frame is returned, so the input is
    left untouched.

    Args:
        metadata_df: DataFrame with a ``description`` column.

    Returns:
        Copy of the rows whose description contains non-whitespace text.
    """
    # Map missing values to "" so NaN/None and whitespace-only strings are dropped together.
    has_text = metadata_df["description"].fillna("").astype(str).str.strip().ne("")
    return metadata_df[has_text].copy()

In [93]:
# Fetch (or load the cached) OpenML metadata, tabulate it, then drop rows lacking a description.
metadata_df = clean_metadata_dataframe(
    create_metadata_dataframe(*get_all_dataset_metadata_from_openml())
)

In [94]:
# Peek at one raw description (index label 20) to see what the embedded text looks like.
print(metadata_df.loc[20, "description"])

**Author**: Mary McLeish & Matt Cecile, University of Guelph  
Donor: Will Taylor (taylor@pluto.arc.nasa.gov)   
**Source**: [UCI](https://archive.ics.uci.edu/ml/datasets/Horse+Colic) - 8/6/89   

**Horse Colic database**  
Database of surgeries on horses. Possible class attributes: 24 (whether lesion is surgical), others include: 23, 25, 26, and 27

Notes:
* Hospital_Number is an identifier and should be ignored when modelling

Attribute Information:
> 
   1:  surgery?
           1 = Yes, it had surgery
           2 = It was treated without surgery  
   2:  Age 
           1 = Adult horse
           2 = Young (< 6 months)  
   3:  Hospital Number 
           - numeric id
           - the case number assigned to the horse
             (may not be unique if the horse is treated > 1 time)  
   4:  rectal temperature
           - linear
           - in degrees celsius.
           - An elevated temp may occur due to infection.
           - temperature may be reduced when the animal is in l

## RAG: index the dataset descriptions and query them

In [95]:
from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import langchain_core

from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import uuid
import chromadb

In [96]:
def _split_metadata_documents(metadata_df, chunk_size=1000, chunk_overlap=150):
    """Turn each ``description`` row into a Document and split long ones into chunks."""
    loader = DataFrameLoader(metadata_df, page_content_column="description")
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return splitter.split_documents(loader.load())


def _deduplicate_documents(documents):
    """Return ``(docs, ids)``: deterministic ids, one doc per id, lists kept aligned.

    The id is a UUID5 of the dataset id plus the chunk text. Re-adding the same chunk
    therefore reuses the same id, while different datasets that happen to share
    description text each keep their own entry.
    """
    unique_docs, unique_ids, seen = [], [], set()
    for doc in documents:
        key = f"{doc.metadata.get('did')}:{doc.page_content}"
        doc_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, key))
        if doc_id in seen:
            continue
        seen.add(doc_id)
        unique_docs.append(doc)
        unique_ids.append(doc_id)
    return unique_docs, unique_ids


def load_document_and_create_vector_store(
    metadata_df,
    persist_directory = "./chroma_db/",
    model_name = "BAAI/bge-base-en-v1.5",
    device = "cpu",
    normalize_embeddings = True,
    recreate_chroma = False,
    batch_size = 100,
) -> Chroma:
    """Load the persisted Chroma store, or embed ``metadata_df`` and build it.

    Args:
        metadata_df: DataFrame whose ``description`` column is embedded. The other
            columns (e.g. ``did``) become document metadata.
        persist_directory: Directory Chroma loads from and persists to.
        model_name: HuggingFace embedding model name.
        device: Device for the embedding model, e.g. ``"cpu"``.
        normalize_embeddings: Passed to the encoder as ``normalize_embeddings``.
        recreate_chroma: Rebuild even if ``persist_directory`` already exists.
        batch_size: Number of chunks embedded and added per ``add_documents`` call.

    Returns:
        The Chroma vector store.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": normalize_embeddings},
    )

    if os.path.exists(persist_directory) and not recreate_chroma:
        return Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    unique_docs, unique_ids = _deduplicate_documents(_split_metadata_documents(metadata_df))

    # persist_directory must be given here. Without it the store is in-memory only,
    # the directory is never written, and every run re-embeds everything.
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    # Add in batches with explicit ids (ids[i] belongs to docs[i]).
    # https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
    for start in tqdm(range(0, len(unique_docs), batch_size)):
        end = start + batch_size
        db.add_documents(unique_docs[start:end], ids=unique_ids[start:end])
    db.persist()
    return db

In [97]:
# Reuses ./chroma_db/ if it exists; pass recreate_chroma=True to rebuild it from metadata_df.
vectordb = load_document_and_create_vector_store(metadata_df=metadata_df)

  8%|▊         | 5/59 [01:21<14:40, 16.31s/it]


KeyboardInterrupt: 

In [79]:
def create_retriever_and_llm(vectordb, model_repo_id = "HuggingFaceH4/zephyr-7b-beta", num_return_documents = 50, search_type = "similarity"):
    """Build a retriever over ``vectordb`` plus a HuggingFace Hub LLM client.

    The Hub token comes from the ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
    Set it in the shell with ``export HUGGINGFACEHUB_API_TOKEN=your_token_here``.

    Args:
        vectordb: Vector store exposing ``as_retriever``.
        model_repo_id: Hub repository id of the generation model.
        num_return_documents: ``k``, the number of documents the retriever returns.
        search_type: Retriever search type passed to ``as_retriever``.

    Returns:
        ``(retriever, llm)``.
    """
    hub_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

    retriever = vectordb.as_retriever(
        search_type=search_type,
        search_kwargs={"k": num_return_documents},
    )
    llm = HuggingFaceHub(
        repo_id=model_repo_id,
        model_kwargs={"temperature": 0.1, "max_length": 512},
        huggingfacehub_api_token=hub_token,
    )
    return retriever, llm

In [80]:
# Prompt for the RetrievalQA chain; {question} and {context} are filled in by the chain.
rqa_prompt_template = (
    "This database is a list of dataset metadata. "
    "Use the following pieces of context to find the relevant document. "
    "Answer only from the context given using the {question} given. "
    "If you do not know the answer, say you do not know. "
    "{context}"
)

In [81]:
def create_llm_chain_and_query(vectordb, retriever, llm, prompt_template = "Answer {question} from the following context: {context}"):
    """Wrap ``llm`` and ``retriever`` in a RetrievalQA chain driven by ``prompt_template``.

    Despite the name, no query is run here. Call ``.invoke({"query": ...})`` on the
    returned chain. ``vectordb`` is accepted for call compatibility but is unused.

    Args:
        vectordb: Unused.
        retriever: Retriever supplying the ``{context}`` documents.
        llm: Language model that produces the answer.
        prompt_template: Template containing ``{context}`` and ``{question}``.

    Returns:
        A RetrievalQA chain that also returns its source documents.
    """
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    chain = RetrievalQA.from_chain_type(
        llm,
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
        verbose=False,
    )
    return chain

In [82]:
def get_relevant_documents_from_results(results, metadata_df):
    """Return the ``did`` of every source document in a RetrievalQA result.

    Args:
        results: Output of a chain built with ``return_source_documents=True``;
            must contain a ``"source_documents"`` list whose items have ``metadata["did"]``.
        metadata_df: Currently unused; kept for call compatibility.

    Returns:
        List of dataset ids in retrieval order (may contain repeats when several
        chunks of the same dataset were retrieved).
    """
    # Key was misspelled "source_douments", which made every call raise KeyError.
    return [doc.metadata["did"] for doc in results["source_documents"]]

In [83]:
# Build the retriever + LLM, then wire them into a QA chain that uses the custom prompt.
retriever, llm = create_retriever_and_llm(vectordb=vectordb)
qa = create_llm_chain_and_query(vectordb, retriever, llm, rqa_prompt_template)

In [84]:
def create_result_dataframe(query, qa) -> pd.DataFrame:
    """Run ``query`` through ``qa`` and tabulate the retrieved datasets.

    Args:
        query: Natural-language question.
        qa: Chain whose ``invoke({"query": ...})`` result contains
            ``"source_documents"``; each item carries ``metadata["did"]`` and ``page_content``.

    Returns:
        DataFrame with columns ``did`` and ``description``, one row per dataset in
        retrieval order. When several chunks of one dataset are retrieved, the text
        of its first (highest-ranked) chunk is kept.
    """
    results = qa.invoke({"query": query})

    best_chunk_per_did = {}
    for doc in results["source_documents"]:
        # setdefault keeps the first hit; similarity search lists the closest matches
        # first, so later chunks of the same dataset are less relevant.
        best_chunk_per_did.setdefault(doc.metadata["did"], doc.page_content)

    return pd.DataFrame(list(best_chunk_per_did.items()), columns=["did", "description"])

In [85]:
# Natural-language question sent to the retrieval chain in the next cell.
query = (
    "Which datasets would be useful for stock market information?"
)

In [86]:
# One row per retrieved dataset; the bare name on the last line renders the table.
results = create_result_dataframe(query=query, qa=qa)
results

Unnamed: 0,did,description


In [56]:
# print the first 300 characters of each description
# Preview each retrieved dataset: its id and the first 300 characters of its description.
for did, description in zip(results["did"], results["description"]):
    print(f"Dataset ID: {did}")
    print(f"Description: {description[:300]}")
    print("\n")

Dataset ID: 43006
Description: This datasets covers features from various categories of technical indicators, futures contracts, price of commodities, important indices of markets around the world, price of major companies in the U.S. market, and treasury bill rates. Sources and thorough description of features have been mentione


Dataset ID: 43004
Description: This datasets covers features from various categories of technical indicators, futures contracts, price of commodities, important indices of markets around the world, price of major companies in the U.S. market, and treasury bill rates. Sources and thorough description of features have been mentione


Dataset ID: 43003
Description: This datasets covers features from various categories of technical indicators, futures contracts, price of commodities, important indices of markets around the world, price of major companies in the U.S. market, and treasury bill rates. Sources and thorough description of features have been mentione
