In [3]:
import os
from pymongo import MongoClient
import pymongo
from tqdm import tqdm
import re
import pickle
import json
import torch
from nltk.tokenize import sent_tokenize

from langchain.docstore.document import Document
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.embeddings import SentenceTransformerEmbeddings


from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
from langchain.vectorstores import Chroma
import json
from bson import ObjectId


In [4]:
def json_serialize(obj):
    """
    Fallback JSON encoder for values the standard ``json`` module rejects.

    Presumably meant to be passed as the ``default=`` callback of
    ``json.dump``/``json.dumps`` (no call site is visible in this notebook).

    Args:
        obj: The value that could not be serialized natively.

    Returns:
        The string form of ``obj`` when it is a BSON ``ObjectId``.

    Raises:
        TypeError: For any other type; the message names the offending type.
    """
    # Guard clause: anything that is not an ObjectId is rejected up front.
    if not isinstance(obj, ObjectId):
        raise TypeError("Type %s not serializable" % type(obj))
    return str(obj)
    
def json_deserialize(obj):
    """
    Placeholder ``object_hook`` for ``json.load``: returns its argument unchanged.

    Extend this if loaded values ever need converting, e.g. turning a string
    back into a ``datetime``.
    """
    return obj

In [5]:
# Read the pre-filtered records from disk.
# The encoding is set explicitly: without it, open() falls back to the platform's
# locale encoding, which can mis-decode non-ASCII text in a UTF-8 JSON file.
with open('filtered_records.json', 'r', encoding='utf-8') as file:
    filtered_records = json.load(file, object_hook=json_deserialize)

In [7]:
def _append_unique_sentences(text, section, doi, seen, documents, min_words):
    """Split ``text`` into sentences and append each new, long-enough one to ``documents``."""
    for sentence in sent_tokenize(text):
        # Skip short fragments and sentences already emitted.
        if len(sentence.split()) < min_words or sentence in seen:
            continue
        seen.add(sentence)
        documents.append(Document(page_content=sentence,
                                  metadata={"section": section, "doi": doi}))


def process_dataset(records, min_words=10):
    """
    Turn paper records into one ``Document`` per unique sentence.

    Records whose title contains 'review' or 'survey' (case-insensitive) are
    skipped. Each remaining record contributes the sentences of its abstract
    and of every paragraph whose section name does not contain 'introduction'.
    A sentence is kept only if it has at least ``min_words`` whitespace-separated
    words and has not been seen before. De-duplication is global across all
    records, so a repeated sentence is attributed to the first record it
    appears in.

    Args:
        records: Iterable of dict-like records. The keys read are 'title',
            'abstract', 'doi' and 'paragraphs' (a list of dicts with
            'section_name' and 'text').
        min_words: Minimum number of words a sentence needs to be kept.
            Defaults to 10, the threshold this function always used.

    Returns:
        list of ``Document`` objects with ``page_content`` set to the sentence
        and ``metadata`` of the form ``{"section": ..., "doi": ...}``. Abstract
        sentences use section "abstract"; paragraph sentences use the
        lower-cased section name.
    """
    all_documents = []
    seen_page_contents = set()  # sentences already emitted, across all records

    for record in tqdm(records, desc="Filtering records"):
        # Drop review/survey papers; a missing or non-string title is not filtered.
        title = record.get('title')
        if isinstance(title, str):
            lowered_title = title.lower()
            if 'review' in lowered_title or 'survey' in lowered_title:
                continue

        # NOTE(review): doi is None when the record has no 'doi' key; confirm the
        # downstream vector store accepts null metadata values.
        doi = record.get('doi')

        abstract = record.get('abstract')
        if isinstance(abstract, str):
            _append_unique_sentences(abstract, "abstract", doi,
                                     seen_page_contents, all_documents, min_words)

        # `or []` also covers an explicit null value, which .get(key, []) does not.
        for paragraph in record.get('paragraphs') or []:
            raw_section = paragraph.get('section_name')
            # A null or non-string section name is treated as an empty name.
            section_name = raw_section.lower() if isinstance(raw_section, str) else ''

            if 'introduction' in section_name:
                continue

            text = paragraph.get('text')
            if not isinstance(text, str):
                continue  # nothing to tokenize

            _append_unique_sentences(text, section_name, doi,
                                     seen_page_contents, all_documents, min_words)

    return all_documents

In [8]:
# Build the sentence-level Document list from the loaded records
# (progress is reported per record by the tqdm bar inside process_dataset).
processed_documents = process_dataset(filtered_records)

Filtering records: 100%|██████████| 36857/36857 [02:54<00:00, 211.63it/s]


In [9]:
# Number of sentence Documents that survived the length filter and de-duplication.
len(processed_documents)

4541573

In [10]:
# Describe the two metadata keys attached to every Document ("section" and "doi").
# Both are string-typed, so only the name and description vary per entry.
metadata_field_info = [
    AttributeInfo(name=field_name, description=field_description, type="string")
    for field_name, field_description in (
        ("section", "Section of the paper"),
        ("doi", "Digital Object Identifier of the document"),
    )
]

In [11]:
# Directory where the Chroma collection is stored.
persist_directory = 'data/chroma/sminilm'

# Embedding function handed to the Chroma store, backed by the
# "all-MiniLM-L6-v2" sentence-transformers model.
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

In [12]:
# Index the documents in fixed-size batches instead of one from_documents() call.
# A single call embeds and inserts every document at once, which holds all
# embeddings in memory and can exceed chromadb's per-call batch limit.
# NOTE(review): 5_000 is a conservative choice, not a measured optimum; tune it
# for the installed chromadb version and the available memory.
INSERT_BATCH_SIZE = 5_000
svdb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
for batch_start in tqdm(range(0, len(processed_documents), INSERT_BATCH_SIZE),
                        desc="Indexing documents"):
    svdb.add_documents(processed_documents[batch_start:batch_start + INSERT_BATCH_SIZE])

In [13]:
# Explicitly write the collection to persist_directory.
# NOTE(review): newer chromadb releases reportedly persist automatically and may
# not need this call; confirm against the installed version.
svdb.persist()

In [14]:
# Re-open the persisted collection from disk, using the same embedding function
# so queries are embedded the same way as the stored sentences.
svdb = Chroma(
    embedding_function=embeddings,
    persist_directory=persist_directory,
)
# Sanity check: print how many entries the collection holds. `_collection` is a
# private attribute of the wrapper, so this may break across library versions.
print(svdb._collection.count())

4541573
