In [16]:
import os
import pandas as pd
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer

# Sentence-embedding model; passed to get_embeddings in the pipeline cell to encode chunks.
model = SentenceTransformer(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")

In [158]:
def get_files(dir_path: str) -> pd.DataFrame:
    """List the regular files directly inside ``dir_path``.

    Args:
        dir_path: Directory to scan. A trailing path separator is optional.

    Returns:
        A DataFrame with one row per file and columns ``file_name`` (the bare
        name) and ``file_path`` (``dir_path`` joined with that name). Rows are
        sorted by file name so repeated runs produce the same order.
    """
    # Skip sub-directories: opening them as text files later would fail.
    files = sorted(
        name
        for name in os.listdir(dir_path)
        if os.path.isfile(os.path.join(dir_path, name))
    )
    # os.path.join instead of string concatenation, so callers need not append
    # a platform-specific separator to dir_path.
    file_paths = [os.path.join(dir_path, name) for name in files]
    return pd.DataFrame({"file_name": files, "file_path": file_paths})


def read_file(file_path, encoding: str = "utf-8"):
    """Return the full contents of a text file as a string.

    Args:
        file_path: Path of the file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8
            instead of relying on ``open()``'s locale-dependent default, so the
            result does not change with the machine the notebook runs on.

    Returns:
        The decoded file contents.
    """
    with open(file_path, "r", encoding=encoding) as file:
        return file.read()


def get_text(db: pd.DataFrame) -> pd.DataFrame:
    """Load each file's contents into a new ``text`` column.

    Every path in ``db["file_path"]`` is read with ``read_file``. The frame is
    modified in place and also returned so calls can be chained.
    """
    contents = db["file_path"].apply(read_file)
    db["text"] = contents
    return db


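# NOTE: superseded draft of get_context, kept for reference only; it would not
# run as written: DataFrame.iterrows() yields (index, row) tuples, and skipping
# rows would make `context` shorter than `db`. Use get_context below instead.
# def get_context(db: pd.DataFrame, overlap: int = 300) -> pd.DataFrame:
#     context = []
#     for row in db.iterrows():
#         i = row["chunk"].metadata["start_index"]
#         if i - overlap < 0 or i + overlap > len(row["text"]):
#             continue

#         context.append(row["text"][i - overlap : i + overlap])

#     db["context"] = context
#     return db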


def get_context_for_row(row, chunk_size, overlap):
    """Return the text window surrounding one chunk of a document.

    The window runs from ``overlap`` characters before the chunk's start offset
    to ``chunk_size + overlap`` characters after it, clamped to the bounds of
    the full document text.

    Args:
        row: A row (Series or mapping) whose ``chunk`` has
            ``metadata["start_index"]`` (the chunk's character offset into the
            document) and whose ``text`` is the full document.
        chunk_size: Chunk length in characters the splitter was configured
            with; if a chunk is shorter, the window reaches past its end.
        overlap: Extra characters of context to include on each side.

    Returns:
        The context substring of ``row["text"]``.
    """
    text = row["text"]
    start_index = row["chunk"].metadata["start_index"]
    # Clamp at 0 instead of starting at the chunk itself, so chunks near the
    # beginning of the document still get whatever left context exists.
    window_start = max(0, start_index - overlap)
    window_end = min(len(text), start_index + chunk_size + overlap)
    return text[window_start:window_end]


def get_context(db: pd.DataFrame, chunk_size: int, overlap: int = 100) -> pd.DataFrame:
    """Add a ``context`` column holding the text surrounding each chunk.

    Args:
        db: Chunk-level frame with ``chunk`` and ``text`` columns, as produced
            by ``get_chunks``.
        chunk_size: Chunk length the splitter was configured with.
        overlap: Characters of extra context taken on each side of a chunk.

    Returns:
        ``db`` itself, modified in place with the new ``context`` column.
    """
    # result_type="reduce" makes apply return a Series even when db has no
    # rows; otherwise pandas returns an empty DataFrame copy, which cannot be
    # assigned to a single column.
    db["context"] = db.apply(
        get_context_for_row,
        axis=1,
        result_type="reduce",
        chunk_size=chunk_size,
        overlap=overlap,
    )
    return db


def get_chunks(db: pd.DataFrame, text_splitter) -> pd.DataFrame:
    """Split each document's text into chunks, one chunk per output row.

    Args:
        db: File-level frame with a ``text`` column. A ``chunk`` column holding
            each row's list of chunks is added to it in place.
        text_splitter: Object with a ``create_documents(texts)`` method, such
            as a LangChain text splitter.

    Returns:
        A new frame with one row per chunk. File-level columns are repeated for
        every chunk of the same file, and the original index is kept. Files for
        which the splitter returns no chunks are dropped, because later steps
        access attributes of each chunk.
    """
    db["chunk"] = db["text"].apply(lambda s: text_splitter.create_documents([s]))
    # explode() turns an empty list into a single NaN row; drop those so the
    # context/embedding steps never receive a missing chunk.
    return db.explode("chunk").dropna(subset=["chunk"])


def get_embeddings(db: pd.DataFrame, model) -> pd.DataFrame:
    """Store one embedding per chunk in a new ``embedding`` column.

    Each chunk's ``page_content`` is passed individually to ``model.encode``.
    The frame is modified in place and also returned.
    """
    def _encode_chunk(chunk):
        return model.encode(chunk.page_content)

    db["embedding"] = db["chunk"].apply(_encode_chunk)
    return db


def separate_tables(db):
    """Split the chunk-level frame into a per-file table and a per-chunk table.

    Returns:
        A ``(text_table, chunk_table)`` tuple:
        - ``text_table``: ``file_name``, ``file_path`` and ``text``, keeping the
          first row for each ``file_name``.
        - ``chunk_table``: every column except ``file_path`` and ``text`` (in
          this notebook: ``file_name``, ``chunk``, ``context``, ``embedding``).
    """
    file_columns = ["file_name", "file_path", "text"]
    text_table = db[file_columns].drop_duplicates(subset="file_name")
    chunk_table = db.drop(columns=["file_path", "text"])
    return text_table, chunk_table


def to_postgres(db=None):
    """Placeholder for writing the tables to PostgreSQL; not implemented yet.

    Args:
        db: Frame to persist. Accepted now so the intended call
            ``to_postgres(db)`` already matches the signature.

    Raises:
        NotImplementedError: always, until an implementation is written.
    """
    # Raise rather than silently returning None, so an accidental call is noticed.
    raise NotImplementedError("to_postgres is not implemented yet")


CHUNK_SIZE = 512
CHUNK_OVERLAP = 20  # passed to the splitter as chunk_overlap
CONTEXT_OVERLAP = 100  # extra characters of context taken on each side of a chunk
# Cross-platform replacement for os.getcwd() + "\data\\" (which relied on the
# invalid escape "\d" and only worked on Windows). The trailing separator is
# kept so the path also works when file names are appended by concatenation.
DATA_DIR = os.path.join(os.getcwd(), "data", "")

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    length_function=len,
    add_start_index=True,
)
db = get_files(DATA_DIR)
db = get_text(db)
db = get_chunks(db, text_splitter)
db = get_context(db, CHUNK_SIZE, overlap=CONTEXT_OVERLAP)
db = get_embeddings(db, model)
text_db, vector_db = separate_tables(db)
# to_postgres(db)  # TODO: enable once to_postgres is implemented