In [None]:
import os

from chromadb import EmbeddingFunction

from preprocessing import split_to_sentences
from sentence_transformers import SentenceTransformer
import chromadb

In [None]:
def split_document(lines, fragment_limit=100, splitter=None):
    """Group a document's sentences into fragments of roughly ``fragment_limit`` words.

    Sentences are appended (each followed by a single space) to the current
    fragment. Once the fragment's token count exceeds ``fragment_limit``, it is
    closed and a new one is started. Sentences left over at the end form a
    final, shorter fragment instead of being dropped.

    Args:
        lines: raw document lines, passed unchanged to the sentence splitter.
        fragment_limit: a fragment is closed as soon as it holds more than this
            many space-separated tokens.
        splitter: callable returning ``(headers, sentences)`` for ``lines``;
            defaults to ``split_to_sentences`` from ``preprocessing``.

    Returns:
        list of fragment strings (each ends with a trailing space).
    """
    if splitter is None:
        splitter = split_to_sentences
    # Headers are returned by the splitter but not used for fragmenting.
    _headers, sentences = splitter(lines)
    result = []
    fragment = ""
    length = 0
    for s in sentences:
        fragment += s + " "
        length += len(s.split(" "))
        if length > fragment_limit:
            result.append(fragment)
            fragment = ""
            length = 0
    # Keep the tail. Previously every sentence after the last full fragment was
    # lost, and documents shorter than fragment_limit yielded no fragments at all.
    if fragment:
        result.append(fragment)
    return result

In [None]:
def split_dataset(dataset_path):
    """Split every document under ``dataset_path`` into fragments for indexing.

    Walks a ``dataset_path/<topic>/<file>`` layout: each sub-directory is a
    topic and each regular file inside it is a document. Entries that do not
    fit that layout are skipped.

    Returns:
        ``(fragments, ids, metadata)`` — three parallel lists:

        - each fragment has its newlines replaced by spaces;
        - each id is ``"<topic>/<file>_<n>"``;
        - each metadata entry is ``{"document": file, "topic": topic}``.
    """
    result_fragments = []
    metadata = []
    result_ids = []
    # os.listdir order is arbitrary; sort so fragment/id order is reproducible.
    for topic in sorted(os.listdir(dataset_path)):
        topic_dir = os.path.join(dataset_path, topic)
        if not os.path.isdir(topic_dir):
            # Stray top-level files would otherwise crash os.listdir below.
            continue
        for file in sorted(os.listdir(topic_dir)):
            filepath = os.path.join(topic_dir, file)
            if not os.path.isfile(filepath):
                continue
            # Explicit encoding instead of the platform default (e.g. cp1252 on Windows).
            # NOTE(review): assumes the corpus is UTF-8 — confirm.
            with open(filepath, encoding="utf-8") as f:
                lines = f.readlines()
            for counter, fragment in enumerate(split_document(lines)):
                result_fragments.append(fragment.replace("\n", " "))
                metadata.append({"document": file, "topic": topic})
                result_ids.append(topic + "/" + file + "_" + str(counter))
    return result_fragments, result_ids, metadata

In [None]:
def vectorize_dataset(fragments):
    """Embed ``fragments`` with the paraphrase-multilingual-mpnet-base-v2 model.

    A fresh SentenceTransformer is loaded on every call. The embeddings are
    returned as plain nested Python lists (via ``tolist``) rather than an array.
    """
    encoder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
    embeddings = encoder.encode(fragments)
    return embeddings.tolist()

In [None]:
# Training corpus root (one sub-directory per topic), resolved relative to the working directory.
data_dir = os.path.realpath(os.path.join("..", "..", "..", "..", "data", "train"))

In [None]:
# Parallel lists: fragment texts, their ids, and their {document, topic} metadata.
(fragments, ids, metadata) = split_dataset(dataset_path=data_dir)

In [None]:
# Preview only: the full list holds every fragment of the training set.
len(fragments), fragments[:5]

In [None]:
class EmbeddingFunction:
    """Callable that embeds texts with the multilingual MPNet SentenceTransformer.

    Passed to chromadb as ``embedding_function`` and also called directly to
    embed query text. Note: this definition shadows the ``EmbeddingFunction``
    name imported from chromadb at the top of the notebook.
    """

    def __init__(self):
        checkpoint = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
        self.model = SentenceTransformer(checkpoint)

    def __call__(self, input):
        # NOTE(review): recent chromadb versions check that __call__ takes a
        # parameter named ``input``; keep that name.
        vectors = self.model.encode(input)
        return vectors.tolist()
        
        
class DB:
    """Wrapper around one persistent chromadb collection for a given distance metric.

    Each instance uses the collection ``"lab5_" + distance_function`` stored
    under ``root_path``, with documents embedded by ``EmbeddingFunction``.
    """

    _DISTANCES = ("l2", "ip", "cosine")

    def __init__(self, distance_function, root_path):
        """Open (or create) the collection for ``distance_function``.

        Args:
            distance_function: HNSW space, one of ``"l2"``, ``"ip"``, ``"cosine"``.
            root_path: directory passed to ``chromadb.PersistentClient``.

        Raises:
            AssertionError: if ``distance_function`` is not supported.
        """
        # Validate first so a typo fails before the model and client are created.
        assert distance_function in self._DISTANCES, "Distance function should be 'l2' or 'ip' or 'cosine'"
        self.distance_function = distance_function
        self.ef = EmbeddingFunction()
        self.client = chromadb.PersistentClient(path=root_path)
        self.collection = self._get_collection()

    def _collection_name(self):
        """Name of the collection backing this distance metric."""
        return "lab5_" + self.distance_function

    def _get_collection(self):
        """Get or create this instance's collection with the configured HNSW space."""
        return self.client.get_or_create_collection(
            self._collection_name(),
            metadata={"hnsw:space": self.distance_function},
            embedding_function=self.ef,
        )

    def add(self, items, batch_size=1000):
        """Add fragments to the collection in batches.

        Args:
            items: dict with parallel lists under ``"fragments"``, ``"metadata"``
                and ``"ids"``.
            batch_size: maximum number of records per ``collection.add`` call.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        documents = items["fragments"]
        metadatas = items["metadata"]
        ids = items["ids"]
        # Bound by the items' own length (the old code read a global), and never
        # send an empty trailing batch when the length divides evenly.
        for start in range(0, len(documents), batch_size):
            stop = start + batch_size
            self.collection.add(
                documents=documents[start:stop],
                metadatas=metadatas[start:stop],
                ids=ids[start:stop],
            )

    def query(self, query, n_results):
        """Return the ``n_results`` nearest fragments to ``query``.

        ``query`` may be a single string or a list of strings. A single string is
        wrapped in a list, so one embedding per query text is sent to chromadb.
        """
        texts = [query] if isinstance(query, str) else query
        return self.collection.query(query_embeddings=self.ef(texts), n_results=n_results)

    def clear(self):
        """Drop this instance's collection and recreate it empty."""
        self.client.delete_collection(self._collection_name())
        self.collection = self._get_collection()
        

In [None]:
# One persistent collection per distance metric, all under the same store.
# The location can be overridden via CHROMA_DB_PATH instead of editing this machine-specific default.
DB_ROOT = os.environ.get("CHROMA_DB_PATH", "C:/Learning/NLP/data/DB")
database_l2 = DB("l2", DB_ROOT)
database_ip = DB("ip", DB_ROOT)
database_cosine = DB("cosine", DB_ROOT)

In [None]:
# Start from an empty L2 collection, then index every fragment.
database_l2.clear()
database_l2.add(dict(fragments=fragments, metadata=metadata, ids=ids))

In [None]:
# Start from an empty inner-product collection, then index every fragment.
database_ip.clear()
database_ip.add(dict(fragments=fragments, metadata=metadata, ids=ids))

In [None]:
# Start from an empty cosine collection, then index every fragment.
database_cosine.clear()
database_cosine.add(dict(fragments=fragments, metadata=metadata, ids=ids))

In [None]:
database_l2.query(query="How to build a bomb", n_results=5)

In [None]:
database_ip.query(query="How to build a bomb", n_results=5)

In [None]:
database_cosine.query(query="How to build a bomb", n_results=5)