In [22]:
import os
import glob
import pymilvus
from typing import List

from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PyMuPDFLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Milvus
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document

In [16]:
# Connection / model settings; each can be overridden via an environment variable.
db_url = os.environ.get('DB_URL', 'http://127.0.0.1:19530')
source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME', 'all-MiniLM-L6-v2')
# Text-splitting parameters for RecursiveCharacterTextSplitter (chunk length and
# overlap between neighbouring chunks), overridable like the settings above.
chunk_size = int(os.environ.get('CHUNK_SIZE', 500))
chunk_overlap = int(os.environ.get('CHUNK_OVERLAP', 50))
# Milvus collection that get_existing_sources() reads. It must be the same
# collection the Milvus vector store writes to ('LangChainCollection' is
# LangChain's default name).
collection_name = 'LangChainCollection'

In [17]:
# File extension (lower-case, with leading dot) -> (loader class, extra keyword
# arguments passed to that loader's constructor together with the file path).
LOADER_MAPPING = dict(
    [
        (".csv", (CSVLoader, {})),
        (".doc", (UnstructuredWordDocumentLoader, {})),
        (".docx", (UnstructuredWordDocumentLoader, {})),
        (".enex", (EverNoteLoader, {})),
        (".epub", (UnstructuredEPubLoader, {})),
        (".html", (UnstructuredHTMLLoader, {})),
        (".md", (UnstructuredMarkdownLoader, {})),
        (".odt", (UnstructuredODTLoader, {})),
        (".pdf", (PyMuPDFLoader, {})),
        (".ppt", (UnstructuredPowerPointLoader, {})),
        (".pptx", (UnstructuredPowerPointLoader, {})),
        # Plain-text files are decoded as UTF-8.
        (".txt", (TextLoader, {"encoding": "utf8"})),
        # Add more mappings for other file extensions and loaders as needed
    ]
)

In [18]:
# Embedding model used to vectorise chunks, and the LangChain Milvus store that
# `add_documents` writes to. The collection is named explicitly so the store
# uses the same collection that get_existing_sources() reads.
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
db = Milvus(
    connection_args={"uri": db_url},
    collection_name=collection_name,
    embedding_function=embeddings,
)

In [21]:
def get_db_connection(embeddings: HuggingFaceEmbeddings) -> Milvus:
    """Create a LangChain ``Milvus`` store for ``collection_name`` at ``db_url``.

    The ``http://`` / ``https://`` scheme is stripped from ``db_url`` to form the
    ``address`` connection argument. Stripping the scheme alone would lose the
    request for TLS, so an ``https://`` URL also sets ``secure=True``.

    Args:
        embeddings: Embedding function the store uses to vectorise documents.

    Returns:
        A ``Milvus`` vector store bound to ``collection_name``.
    """
    connection_args = {}
    address = db_url
    if address.startswith('https://'):
        address = address[len('https://'):]
        # Keep TLS: the scheme is gone from `address`, so request it explicitly.
        connection_args['secure'] = True
    elif address.startswith('http://'):
        address = address[len('http://'):]
    connection_args['address'] = address
    return Milvus(
        connection_args=connection_args,
        collection_name=collection_name,
        embedding_function=embeddings,
    )

def get_existing_sources() -> List[str]:
    """Return the distinct ``source`` values already stored in ``collection_name``.

    Connects to Milvus at ``db_url``, loads the collection, and iterates over all
    of its entities, reading only the ``source`` field.

    Returns:
        Distinct source values in first-seen order. Returns an empty list when
        the collection does not exist yet.
    """
    pymilvus.connections.connect(uri=db_url)
    try:
        collection = pymilvus.Collection(collection_name)
    except pymilvus.exceptions.SchemaNotReadyException:
        # The collection has not been created yet, so nothing has been ingested.
        return []
    collection.load()

    # A dict keeps insertion order and gives O(1) membership checks. This
    # replaces an O(n) `not in list` scan per entity.
    seen_sources = {}
    query_iterator = collection.query_iterator(
        batch_size=100,
        # -1 means no cap. A fixed limit silently ignored entities beyond it,
        # which would make already-ingested sources look new.
        limit=-1,
        expr='',
        output_fields=['source'],
    )
    try:
        while True:
            batch = query_iterator.next()
            if not batch:
                break
            for doc in batch:
                seen_sources.setdefault(doc.get("source"), None)
    finally:
        # Always release the iterator, even if a batch fetch fails.
        query_iterator.close()
    return list(seen_sources)

In [None]:
# Collect every file under `source_directory` whose extension has a registered loader.
all_files = []
for ext in LOADER_MAPPING:
    all_files.extend(
        glob.glob(os.path.join(source_directory, f"**/*{ext}"), recursive=True)
    )

# Skip files that an earlier run already ingested, so that re-running this cell
# and the `add_documents` cell does not insert duplicate chunks. This relies on
# the loaders storing the file path as each document's `source` metadata.
# NOTE(review): verify this for any loader added to LOADER_MAPPING.
existing_sources = set(get_existing_sources())

documents = []
for file_path in all_files:
    if file_path in existing_sources:
        print(f"Skipping already ingested {file_path}")
        continue
    # Lower-case the extension so that a name such as "REPORT.PDF" (which glob
    # can return on case-insensitive filesystems) still finds its loader.
    ext = os.path.splitext(file_path)[1].lower()
    if ext not in LOADER_MAPPING:
        print(f"Unsupported file extension '{ext}'")
        continue
    print(f"Loading {file_path}")
    loader_class, loader_args = LOADER_MAPPING[ext]
    loader = loader_class(file_path, **loader_args)
    documents.extend(loader.load())

texts = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size, chunk_overlap=chunk_overlap
).split_documents(documents)
# Preview a few chunks instead of echoing the whole corpus into the notebook.
texts[:3]

In [None]:
# Embed the new chunks and insert them into Milvus. Show only how many were
# inserted rather than echoing every generated primary key.
inserted_ids = db.add_documents(texts) if texts else []
len(inserted_ids)