In [4]:
import os
# Add the parent directory to the path so we can import the modules
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
import uuid

from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
# from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores.chroma import Chroma
from tqdm import tqdm

from modules.metadata_utils import *

In [5]:
import os
import pickle
# from pqdm.processes import pqdm
from typing import List, Union

import openml
import pandas as pd
from pqdm.threads import pqdm

In [6]:

from modules import *
from modules.utils import *
from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
set_llm_cache(SQLiteCache(database_path="./data/.langchain.db"))
from langchain_community.vectorstores import Milvus

In [7]:
from langchain_text_splitters import CharacterTextSplitter

In [8]:
config = load_config_and_device("config.json")
config["training"] = False
config["type_of_data"] = "dataset"
config["device"] = "mps"
# config["embedding_model"] = "BAAI/bge-small-en-v1.5"
config["embedding_model"] = "BAAI/bge-base-en-v1.5"

[INFO] Finding device.
[INFO] Device found: cpu


In [9]:

def load_and_process_data(metadata_df, page_content_column):
    """
    Description: Load and process the data for the vector store. Split the documents into chunks of 1000 characters.

    Input: metadata_df (pd.DataFrame), page_content_column (str)

    Returns: chunked documents (list)
    """
    # Load data
    loader = DataFrameLoader(metadata_df, page_content_column=page_content_column)
    documents = loader.load()

    # Split documents
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
    documents = text_splitter.split_documents(documents)

    return documents

In [14]:
# model_kwargs = {"device": "cpu"}
model_kwargs = {"device": config["device"]}
encode_kwargs = {"normalize_embeddings": True}
embeddings = HuggingFaceEmbeddings(
    model_name=config["embedding_model"],
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
    show_progress = True,
)



In [15]:
dict_collection_names = {"dataset": "datasets", "flow": "flows"}
collection_name = dict_collection_names[config["type_of_data"]]

In [16]:
# config["training"] = True

In [17]:
openml_data_object, data_id, all_metadata = get_all_metadata_from_openml(
        config=config
    )

[INFO] Loading metadata from file.
[INFO] Metadata loaded.


In [18]:
metadata_df, all_metadata = create_metadata_dataframe(
        openml_data_object, data_id, all_metadata, config=config
    )

In [19]:
# from pymilvus import MilvusClient
# client = MilvusClient(uri = "./data/milvus_db.db")
# # client.drop_collection(collection_name = config["type_of_data"])
# client.create_collection(config["type_of_data"], 4, id_type = "str", max_length = 2000)

In [20]:
metadata_df = metadata_df.rename(columns={"Combined_information": "page_content", "did":"id"})

In [21]:
documents = load_and_process_data(metadata_df, "page_content")

In [23]:
new_document_ids = set([str(x.metadata["id"]) for x in documents])

In [24]:
len(new_document_ids)

5679

In [25]:
from langchain_community.vectorstores.chroma import Chroma
import chromadb

In [26]:
from langchain.retrievers import BM25Retriever, EnsembleRetriever

In [27]:
client = chromadb.PersistentClient(path=config["persist_dir"])

In [28]:
db = Chroma(
            client=client,
            embedding_function=embeddings,
            persist_directory=config["persist_dir"],
            collection_name="datasets",
        )

In [89]:
ret_vec = db.as_retriever(search_type = "similarity", search_kwargs = {"k": 5, })

In [28]:
ret_key = BM25Retriever.from_documents(documents)

In [29]:
ret_key.k = 4

In [50]:
ensemble_retriever = EnsembleRetriever(retrievers=[ret_vec, ret_key], weights=[0.8, 0.2])

In [101]:
[x.metadata["name"] for x in ret_vec.invoke("titanic")]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

['The-Estonia-Disaster-Passenger-List',
 'The-Estonia-Disaster-Passenger-List',
 'Dota2-Games-Results-Data-Set',
 'Internet-Advertisements',
 'Dota2-Games-Results-Data-Set']

In [102]:
[x.metadata["name"] for x in ensemble_retriever.invoke("titanic")]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

['The-Estonia-Disaster-Passenger-List',
 'The-Estonia-Disaster-Passenger-List',
 'Dota2-Games-Results-Data-Set',
 'Internet-Advertisements',
 'titanic_1',
 'titanic']

In [29]:
test_db = db.get()

In [31]:
old_dids = set([str(x["did"]) for x in test_db["metadatas"]])

In [33]:
# new documents
new_dids = new_document_ids - old_dids

In [36]:
new_documents = [x for x in documents if str(x.metadata["id"]) in new_dids]

In [78]:
db.add_documents(documents[:100], ids= document_ids)

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

DuplicateIDError: Expected IDs to be unique, found 15 duplicated IDs: 3, 4, 8, 11, 13, ..., 10, 12, 15, 2, 7

In [68]:
db.add_documents(documents[:100])

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

['220ada12-6eba-413a-affd-d76a13b72f33',
 '657be025-7b8c-40d1-8a23-f4172544a6e8',
 '989d3786-e7d1-4bf9-80b9-d5574e669429',
 '67d77c70-0fe9-4d45-a304-bc0805f357d8',
 'adeeefbf-fa21-40e1-922d-438c0767f853',
 'a8abbda9-403b-4f2d-8afb-c85f81000f4b',
 'db2db5e7-371d-4337-922a-78f549493b86',
 '3dbcd1c8-f93c-4b4b-a463-6ce61f86bc85',
 'e0ef4b8c-6d63-408e-b1e2-b2801267629d',
 '2809eb11-a2dd-4cad-8d13-061c4d05546d',
 '86987207-a3e3-4507-93e8-9e50f73930ea',
 '72dd2a43-b912-43ac-8aac-e7a33f301564',
 '185a13ba-5dd2-4017-8849-91d589f0bc79',
 'ac4f2971-f947-4a58-986b-2766a37a3670',
 '5c9b5f57-cc0b-42b6-98fe-0ca9b93d3784',
 '599a8830-a7a0-45be-9716-e5d84e1e9e96',
 '96557f4c-78c9-4869-a18b-842870850694',
 '6dd53553-3c0e-45ad-b388-f27ac5e5bc17',
 'b0197119-7702-4e67-a682-dc0519e3c316',
 '1cf53651-03e4-4ae1-b99c-b2a1ba7f062b',
 '70fd5f6d-261e-4e26-83a0-e7ce82a0dc8a',
 '2fdd808a-5944-49eb-999e-2dde2108e09a',
 '77cfc17a-17ce-4175-a9d0-50018aaed8e9',
 '0a90f85f-6a0a-48a3-b906-75837f430e73',
 '28a78281-0c4b-

In [67]:
db.as_retriever().invoke("ship disaster")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[Document(page_content='### Relevant papers', metadata={'id': 11, 'name': 'balance-scale'}),
 Document(page_content='10. Exceptions from format instructions: no commas between attribute values., qualities - AutoCorrelation : 0.75, CfsSubsetEval_DecisionStumpAUC : 0.7297297297297297, CfsSubsetEval_DecisionStumpErrRate : 0.24561403508771928, CfsSubsetEval_DecisionStumpKappa : 0.4608108108108108, CfsSubsetEval_NaiveBayesAUC : 0.7297297297297297, CfsSubsetEval_NaiveBayesErrRate : 0.24561403508771928, CfsSubsetEval_NaiveBayesKappa : 0.4608108108108108, CfsSubsetEval_kNN1NAUC : 0.7297297297297297, CfsSubsetEval_kNN1NErrRate : 0.24561403508771928, CfsSubsetEval_kNN1NKappa : 0.4608108108108108, ClassEntropy : 0.9348490242345945, DecisionStumpAUC : 0.7378378378378379, DecisionStumpErrRate : 0.3157894736842105, DecisionStumpKappa : 0.32232496697490093, Dimensionality : 0.2982456140350877, EquivalentNumberOfAtts : 9.28828714469122, J48.00001.AUC : 0.8054054054054054, J48.00001.ErrRate : 0.2807017

In [60]:
for doc_number in tqdm(range(len(documents))):
    try:
        documents[doc_number].metadata.__delitem__("Unnamed: 0")
    except KeyError:
        pass

100%|██████████| 694432/694432 [00:07<00:00, 91972.36it/s] 


In [21]:
# from langchain_milvus import Milvus
from pymilvus import MilvusClient, model

In [22]:
vectorb = Milvus.from_documents(documents, embeddings, connection_args = {
    "uri": "http://127.0.0.1:19530",
    # "uri": "./data/milvus_db.db",
    "collection_name": "datasets",
    "partition_name": "metadata_store"
})

Batches:   0%|          | 0/21701 [00:00<?, ?it/s]

In [92]:
ret = vectorb.as_retriever()

In [93]:
[x.metadata["name"] for x in ret.invoke("salaryman ")]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

['labor', 'labor', 'labor', 'labor']

In [66]:
vectorb.similarity_search(query="business salary", k = 5)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[Document(page_content='RandomTreeDepth1Kappa : 0.3376623376623376, RandomTreeDepth2AUC : 0.7500711237553342, RandomTreeDepth2ErrRate : 0.2982456140350877, RandomTreeDepth2Kappa : 0.3376623376623376, RandomTreeDepth3AUC : 0.7500711237553342, RandomTreeDepth3ErrRate : 0.2982456140350877, RandomTreeDepth3Kappa : 0.3376623376623376, StdvNominalAttDistinctValues : 0.5270462766947299, kNN1NAUC : 0.7675675675675676, kNN1NErrRate : 0.21052631578947367, kNN1NKappa : 0.5581395348837209,, features - 0 : [0 - duration (numeric)], 1 : [1 - wage-increase-first-year (numeric)], 2 : [2 - wage-increase-second-year (numeric)], 3 : [3 - wage-increase-third-year (numeric)], 4 : [4 - cost-of-living-adjustment (nominal)], 5 : [5 - working-hours (numeric)], 6 : [6 - pension (nominal)], 7 : [7 - standby-pay (numeric)], 8 : [8 - shift-differential (numeric)], 9 : [9 - education-allowance (nominal)], 10 : [10 - statutory-holidays (numeric)], 11 : [11 - vacation (nominal)], 12 : [12 - longterm-disability-assist

In [37]:
qa = vectordb.as_retriever(
    search_type = "similarity",
    search_kwargs = {"k": config["num_return_documents"]}
)

In [38]:
qa.invoke(input = "anealling")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

ERROR:pymilvus.decorators:RPC error: [search], <MilvusException: (code=2000, message=vector dimension mismatch, expected vector size(byte) 3072, actual 1536.: segcore error)>, <Time:{'RPC start': '2024-06-14 11:28:20.212134', 'RPC error': '2024-06-14 11:28:20.215885'}>


MilvusException: <MilvusException: (code=2000, message=vector dimension mismatch, expected vector size(byte) 3072, actual 1536.: segcore error)>

In [48]:
res = client.insert(collection_name=config["type_of_data"], data=dict_metadata_df[0])

ERROR:pymilvus.decorators:RPC error: [insert_rows], <ParamError: (code=1, message=Field vector don't match in entities[0])>, <Time:{'RPC start': '2024-06-14 10:36:33.011920', 'RPC error': '2024-06-14 10:36:33.018961'}>


ParamError: <ParamError: (code=1, message=Field vector don't match in entities[0])>