In [3]:
from tqdm import tqdm

In [4]:
import regex as re

In [5]:
# Prompt template text for the LLM. `{query}` is the placeholder that is filled
# with the user's search query when the chain is invoked. The template asks
# three short-answer questions: dataset size (and sort order), missing values,
# and the kind of classification task.
prompt = """User Query : {query}
Based on the query, answer the following questions one by one in one or two words only and a maximum of two with commas only if asked for. Use only the information given and do not make up answers - 
Does the user care about the size of the dataset? Yes/No and if yes, ascending/descending.
Does the user care about missing values? Yes/No.
If it seems like the user wants a classification dataset, is it binary/multi-class/multi-label. If not, say none.
"""

In [6]:
# LangChain supports many other chat models. Here, we're using Ollama
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# supports many more optional parameters. Hover on your `ChatOllama(...)`
# class to view the latest available supported parameters
llm = ChatOllama(model="llama3", temperature=0)

# `prompt` starts out as the raw template string and is rebound to a
# ChatPromptTemplate here. Only convert while it is still a string, so that
# re-running this cell does not feed an already-built template back into
# `from_template`.
if isinstance(prompt, str):
    prompt = ChatPromptTemplate.from_template(prompt)

# using LangChain Expression Language chain syntax
# learn more about the LCEL on
# /docs/concepts/#langchain-expression-language-lcel
chain = prompt | llm | StrOutputParser()

# for brevity, response is printed in terminal
# You can use LangServe to deploy your application for
# production

In [7]:
def check_response_type(response):
    """Check that every answer line of an LLM response has an accepted format.

    A line counts as an answer line when it contains a "?"; the answer is the
    text that follows the first "?" (up to a second "?", if any), stripped of
    surrounding whitespace. An answer is accepted when it starts with "Yes",
    "No" (which also covers "None"), "ascending" or "descending". The check is
    case-sensitive.

    Args:
        response: The raw response text, one question/answer pair per line.

    Returns:
        True only if the response contains at least one answer line and every
        answer line is accepted; False otherwise.
    """
    accepted_prefixes = ("Yes", "No", "ascending", "descending")
    found_answer = False
    for line in response.split("\n"):
        if "?" not in line:
            continue
        found_answer = True
        answer = line.split("?")[1].strip()
        # Fail on the first malformed answer: a later well-formed line must not
        # be able to hide an earlier bad one.
        if not answer.startswith(accepted_prefixes):
            return False
    return found_answer

In [23]:
# Alternative queries tried during development (kept for reference):
# query = "Find me a dataset about banking that has a large number of features"
# query = "Find me a dataset with big mushrooms and sort by size"
query = "Find me a large mushroom dataset"

## Check consistency of answers
- Run the same prompt 30 times and check that every response follows the expected answer format.

In [24]:
# Single call to eyeball the raw LLM answer for the current query.
chain_inputs = {"query": query}
response = chain.invoke(chain_inputs)
print(response)

Here are the answers:

1. Does the user care about the size of the dataset? Yes, ascending
2. Does the user want to sort by number of downloads? No
3. Does the user care about missing values? No
4. Is it a classification dataset? None


In [22]:
def parse_answers_initial(response):
    """Collect the recognised answers from an LLM response, lower-cased.

    The response is lower-cased and split on newlines. For every line that
    contains a "?", the text after the first "?" (up to a second "?", if any)
    is stripped and kept when it starts with "yes", "no" (which also covers
    "none"), "ascending" or "descending". Lines whose answer does not match
    are dropped.

    Args:
        response: The raw response text, one question/answer pair per line.

    Returns:
        The kept answers as a list of strings, in order of appearance.
    """
    accepted_starts = ("yes", "no", "ascending", "descending")
    candidates = (
        line.split("?")[1].strip()
        for line in response.lower().split("\n")
        if "?" in line
    )
    return [text for text in candidates if text.startswith(accepted_starts)]

In [20]:
# Hand-written sample text laid out like the LLM output above (numbered
# "question? answer" lines); presumably a fixed input for trying out
# `parse_answers_initial` without calling the model — TODO confirm.
response = """Here are the answers:

1. Does the user care about the size of the dataset? No
2. Does the user want to sort by number of downloads? No
3. Does the user care about missing values? No, ascending
4. If it seems like the user wants a classification dataset, is it binary/multi-class/multi-label? none"""

In [25]:
parse_answers_initial(response)

['yes, ascending', 'no', 'no', 'none']

In [55]:
# Invoke the chain repeatedly with the same query and verify the format of
# every response; stop at the first response that fails the check.
run_times = 30
check_flag = True
for run in tqdm(range(run_times)):
    response = chain.invoke({"query": query})
    if check_response_type(response):
        continue
    # Show the offending response so it can be inspected, then stop early.
    check_flag = False
    print(response)
    break
print(f"Response type check : {check_flag}")

100%|██████████| 30/30 [01:16<00:00,  2.54s/it]

Response type check : True





In [61]:
import os

# Add the parent directory to the path so we can import the modules
import sys

# `os.path.dirname(".")` is the empty string, so this evaluates to
# "../backend/", resolved relative to the working directory of the kernel.
backend_path = os.path.join(os.path.dirname("."), "../backend/")
# Guard the append so re-running this cell does not pile up duplicate
# entries on sys.path.
if backend_path not in sys.path:
    sys.path.append(backend_path)

In [62]:
from modules.metadata_utils import *
from modules.utils import *

In [59]:
# Load the backend configuration from the relative `../backend/` directory,
# then point its data and persistence directories at that backend's `data`
# folder and set the data type to "dataset".
new_path = Path("../backend/")
backend_data_dir = new_path / "data"

config = load_config_and_device(str(new_path / "config.json"), training=True)
config["data_dir"] = str(backend_data_dir)
config["persist_dir"] = str(backend_data_dir / "chroma_db")
config["type_of_data"] = "dataset"

[INFO] Finding device.
[INFO] Device found: mps


In [60]:
# Call `get_all_metadata_from_openml` with the notebook's `config` and unpack
# its four return values; all four are passed on to
# `create_metadata_dataframe` in the following cell.
openml_data_object, data_id, all_metadata, handler = get_all_metadata_from_openml(
    config=config
)

[INFO] Loading metadata from file.


In [63]:
# Build `metadata_df` from the values returned by
# `get_all_metadata_from_openml`. Note that `all_metadata` is rebound here to
# the second value returned by `create_metadata_dataframe`.
metadata_df, all_metadata = create_metadata_dataframe(
    handler, openml_data_object, data_id, all_metadata, config=config
)

In [64]:
metadata_df.columns

Index(['Unnamed: 0', 'did', 'name', 'version', 'uploader', 'status', 'format',
       'MajorityClassSize', 'MaxNominalAttDistinctValues', 'MinorityClassSize',
       'NumberOfClasses', 'NumberOfFeatures', 'NumberOfInstances',
       'NumberOfInstancesWithMissingValues', 'NumberOfMissingValues',
       'NumberOfNumericFeatures', 'NumberOfSymbolicFeatures', 'description',
       'qualities', 'features', 'Combined_information'],
      dtype='object')

In [65]:
metadata_df["NumberOfClasses"]

0        5.0
1        2.0
2        2.0
3       13.0
4       26.0
        ... 
5685     NaN
5686     NaN
5687     NaN
5688     NaN
5689     NaN
Name: NumberOfClasses, Length: 5690, dtype: float64

In [66]:
import pickle

In [67]:
# NOTE(review): absolute, machine-specific path — this cell will only run where
# this exact file exists. Consider deriving it from a configurable data
# directory instead.
ALL_DATASET_METADATA_PKL = "/Users/smukherjee/Documents/CODE/Github/ai_search/backend/data/all_dataset_metadata.pkl"

# NOTE(review): unpickling can execute arbitrary code; only load files from a
# trusted source.
# Use a context manager so the file handle is closed after loading.
with open(ALL_DATASET_METADATA_PKL, "rb") as metadata_file:
    df = pickle.load(metadata_file)

In [70]:
df[2]

Unnamed: 0,did,name,version,uploader,status,format,MajorityClassSize,MaxNominalAttDistinctValues,MinorityClassSize,NumberOfClasses,NumberOfFeatures,NumberOfInstances,NumberOfInstancesWithMissingValues,NumberOfMissingValues,NumberOfNumericFeatures,NumberOfSymbolicFeatures
2,2,anneal,1,1,active,ARFF,684.0,7.0,8.0,5.0,39.0,898.0,898.0,22175.0,6.0,33.0
3,3,kr-vs-kp,1,1,active,ARFF,1669.0,3.0,1527.0,2.0,37.0,3196.0,0.0,0.0,0.0,37.0
4,4,labor,1,1,active,ARFF,37.0,3.0,20.0,2.0,17.0,57.0,56.0,326.0,8.0,9.0
5,5,arrhythmia,1,1,active,ARFF,245.0,13.0,2.0,13.0,280.0,452.0,384.0,408.0,206.0,74.0
6,6,letter,1,1,active,ARFF,813.0,26.0,734.0,26.0,17.0,20000.0,0.0,0.0,16.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46254,46254,Diabetes_Dataset,2,39999,active,arff,,,,,9.0,768.0,0.0,0.0,9.0,0.0
46255,46255,Student_Performance_Dataset,1,39999,active,arff,,,,,15.0,2392.0,0.0,0.0,15.0,0.0
46258,46258,sonar,2,43180,active,arff,,,,,61.0,207.0,0.0,0.0,60.0,0.0
46259,46259,Electricity-hourly,1,30703,active,arff,,,,,319.0,26305.0,0.0,0.0,317.0,1.0
