In [1]:
import sys

# Extend the module search path with the project source directory so that
# imports in later cells can be resolved from there.
SRC_DIR = '/workspace/src/'
sys.path.append(SRC_DIR)


In [28]:
# First, fetch the DOIs and their corresponding full text.


from urllib.parse import quote_plus

import ir_datasets
import langchain_core.documents
from dotenv import dotenv_values, load_dotenv
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from tqdm import tqdm

from database.model import Base, Document, Table
from database.chunk_model import Chunk_Base, Chunk

# Settings read from the project's .env file; later cells index this mapping
# with the keys USER, PASSWORD, ADDRESS, PORT and DB.
ENV_FILE = "/workspace/src/.env"
db_vals = dotenv_values(ENV_FILE)
# No path is given here, so python-dotenv locates a .env file on its own.
# NOTE(review): that may not be the same file as ENV_FILE -- confirm.
load_dotenv()


True

In [3]:
# Connect to the main PostgreSQL database described by the .env settings.
# The user name and password are percent-encoded before being placed in the
# URL: raw characters such as "@", ":" or "/" would otherwise be read as URL
# delimiters and produce a wrong host or wrong credentials.
engine = create_engine(
    "postgresql+psycopg2://"
    f"{quote_plus(db_vals['USER'])}:{quote_plus(db_vals['PASSWORD'])}"
    f"@{db_vals['ADDRESS']}:{db_vals['PORT']}/{db_vals['DB']}",
    echo=False,
)
# The session stays open; later cells run their queries through it.
session = Session(engine)
# Emit CREATE TABLE for the models registered on Base.
Base.metadata.create_all(engine)

In [4]:
# Unique DOIs of all stored documents. Each query row holds a single column,
# so it is unpacked directly; the list order is arbitrary (it comes from a set).
all_dois = list({doi for (doi,) in session.query(Document.doi).all()})

In [None]:
BATCH_SIZE = 1000  # number of DOIs placed in each IN (...) filter
doi_full_text = {}

# Load the documents slice by slice instead of with one query over all DOIs,
# and keep only (title, full_text) per DOI.
batch_starts = range(0, len(all_dois), BATCH_SIZE)
for start in tqdm(batch_starts):
    doi_batch = all_dois[start:start + BATCH_SIZE]
    batch_query = session.query(Document).filter(Document.doi.in_(doi_batch))
    # If several rows share a DOI, the last one returned wins.
    doi_full_text.update(
        (doc.doi, (doc.title, doc.full_text)) for doc in batch_query.all()
    )

In [6]:
#dataset = ir_datasets.load("cord19/fulltext/trec-covid")
#docs = {}
#
#for doc in tqdm(dataset.docs_iter()):
#    if doc.doi in all_dois:
#        docs[doc.doi] = {"title": doc.title, "abstract": doc.abstract, "body" : doc.body}

In [7]:
def gen_full_text(doc):
    """Flatten a document record into a single string.

    Args:
        doc: Mapping with "title", "abstract" and "body" entries. "body" is
            iterated, and each item must expose ``title`` and ``text``
            attributes.

    Returns:
        The title, the abstract, then every section's title and text in
        order; each piece is followed by a space and a newline.
    """
    parts = [f"{doc['title']} \n", f"{doc['abstract']} \n"]
    for section in doc["body"]:
        parts.append(f"{section.title} \n")
        parts.append(f"{section.text} \n")
    return "".join(parts)

In [8]:
def gen_full_text_docling(row):
    """Build the text to chunk for one document.

    Args:
        row: Indexable whose item 0 is the title and item 1 the full text
            (the ``(title, full_text)`` tuples stored in ``doi_full_text``).

    Returns:
        A string made of the "Title: " prefix and the title, a blank line,
        then the full text.
    """
    title, body = row[0], row[1]
    return f"Title: {title} \n\n {body}"

In [None]:
# Map every DOI to the text that will be chunked: "Title: ..." followed by
# the document's full text.
full_texts = {
    doi: gen_full_text_docling(title_and_text)
    for doi, title_and_text in tqdm(doi_full_text.items())
}

In [None]:
# DOIs in sorted order and, aligned with them, the matching full texts.
metadatas = sorted(full_texts)
texts = [full_texts[doi] for doi in metadatas]
print(len(metadatas), len(texts))

In [12]:
# Splitter used for document text; both sizes are measured in characters.
text_splitter_kwargs = dict(
    chunk_size=512,
    chunk_overlap=100,
    add_start_index=True,  # track each chunk's index in the original document
)
text_splitter = RecursiveCharacterTextSplitter(**text_splitter_kwargs)

In [13]:
# Wrap each full text in a LangChain document that carries its DOI as metadata.
# The class is aliased because `Document` already names the database model.
LangchainDocument = langchain_core.documents.Document
lang_docs = [
    LangchainDocument(page_content=text, metadata={"doi": doi})
    for doi, text in full_texts.items()
]

In [None]:
# Split every document in lang_docs into chunks. The insert cell below reads
# each chunk's `page_content` and its `metadata["doi"]`.
all_splits = text_splitter.split_documents(lang_docs)

In [31]:
# Connect to the separate "cord19chunks" database that stores the chunks,
# reusing the host and credentials from the .env settings.
# The user name and password are percent-encoded before being placed in the
# URL: raw characters such as "@", ":" or "/" would otherwise be read as URL
# delimiters and produce a wrong host or wrong credentials.
chunk_engine = create_engine(
    "postgresql+psycopg2://"
    f"{quote_plus(db_vals['USER'])}:{quote_plus(db_vals['PASSWORD'])}"
    f"@{db_vals['ADDRESS']}:{db_vals['PORT']}/cord19chunks",
    echo=False,
)
# The session stays open; the insert cells below write through it.
chunk_session = Session(chunk_engine)
# Emit CREATE TABLE for the models registered on Chunk_Base.
Chunk_Base.metadata.create_all(chunk_engine)

In [None]:
# Add the text splits to the chunk database, committing every 1000 rows so
# that no single transaction holds the whole corpus.
# NOTE: batches that were already committed stay in the table if a later one
# fails, so re-running this cell after a failure inserts those rows again.
cnt = 0
try:
    for cnt, split in enumerate(tqdm(all_splits), start=1):
        chunk = Chunk(
            doi=split.metadata["doi"],
            chunk_text=split.page_content,
            # Presumably mirrors the splitter settings (chunk_size=512,
            # chunk_overlap=100) -- keep in sync if those change.
            chunk_type="RCTS_512_100",
            modality_type="text",
        )
        chunk_session.add(chunk)
        if cnt % 1000 == 0:
            chunk_session.commit()
    chunk_session.commit()  # persist the final, partially filled batch
except BaseException:
    # After a failed flush/commit (or a kernel interrupt) the session still
    # holds a broken or half-filled transaction; discard it so later cells can
    # keep using chunk_session, then surface the original error.
    chunk_session.rollback()
    raise


In [None]:
# Chunk the tables.

In [5]:
# Unique table identifiers. Despite the variable name these are
# Table.ir_tab_id values, not DOIs. Each query row holds a single column, so
# it is unpacked directly; the list order is arbitrary (it comes from a set).
all_table_dois = list({tab_id for (tab_id,) in session.query(Table.ir_tab_id).all()})

In [7]:
BATCH_SIZE = 1000  # number of table ids placed in each IN (...) filter
doi_table = {}

# Load the tables slice by slice and keep, per ir_tab_id, the five fields
# that gen_table() renders.
batch_starts = range(0, len(all_table_dois), BATCH_SIZE)
for start in tqdm(batch_starts):
    id_batch = all_table_dois[start:start + BATCH_SIZE]
    batch_query = session.query(Table).filter(Table.ir_tab_id.in_(id_batch))
    doi_table.update(
        (
            tab.ir_tab_id,
            (tab.table_name, tab.header, tab.content, tab.caption, tab.references),
        )
        for tab in batch_query.all()
    )

100%|██████████| 145/145 [01:10<00:00,  2.05it/s]


In [12]:
def gen_table(row):
    """Render one table record as labelled lines of text.

    Args:
        row: Indexable with at least five items, in order: table name,
            header, content, caption, references.

    Returns:
        Five "<label>: <value>" lines in that order, each ending with a
        space and a newline.
    """
    lines = (
        f"Table Name: {row[0]} \n",
        f"Header: {row[1]} \n",
        f"Content: {row[2]} \n",
        f"Caption: {row[3]} \n",
        f"References: {row[4]} \n",
    )
    return "".join(lines)

In [19]:
# Map every table id to its rendered text form.
full_tables = {
    tab_id: gen_table(fields)
    for tab_id, fields in tqdm(doi_table.items())
}

100%|██████████| 144202/144202 [00:02<00:00, 54034.14it/s]


In [35]:
# Splitter used for rendered tables; both sizes are measured in characters.
table_splitter_kwargs = dict(
    chunk_size=8192,
    chunk_overlap=1000,
    add_start_index=True,  # track each chunk's index in the original document
)
table_splitter = RecursiveCharacterTextSplitter(**table_splitter_kwargs)

In [36]:
# Table ids in sorted order and, aligned with them, the rendered tables.
metadatas_t = sorted(full_tables)
texts_t = [full_tables[tab_id] for tab_id in metadatas_t]
print(len(metadatas_t), len(texts_t))

144202 144202


In [37]:
# Wrap each rendered table in a LangChain document. The "doi" metadata key
# carries the table's ir_tab_id here, since that is what full_tables is keyed by.
# The class is aliased because `Document` already names the database model.
LangchainDocument = langchain_core.documents.Document
lang_tables = [
    LangchainDocument(page_content=text, metadata={"doi": tab_id})
    for tab_id, text in full_tables.items()
]

In [38]:
# Split every rendered table in lang_tables into chunks. The insert cell below
# reads each chunk's `page_content` and its `metadata["doi"]`.
all_table_splits = table_splitter.split_documents(lang_tables)

144202it [00:20, 7157.12it/s]


In [39]:
# Add the table splits to the chunk database, committing every 1000 rows so
# that no single transaction holds all of them.
# NOTE: batches that were already committed stay in the table if a later one
# fails, so re-running this cell after a failure inserts those rows again.
# NOTE(review): for table chunks the `doi` column receives the table's
# ir_tab_id (that is what the "doi" metadata key holds) -- confirm intended.
cnt = 0
try:
    for cnt, split in enumerate(tqdm(all_table_splits), start=1):
        chunk = Chunk(
            doi=split.metadata["doi"],
            chunk_text=split.page_content,
            # Presumably mirrors the splitter settings (chunk_size=8192,
            # chunk_overlap=1000) -- keep in sync if those change.
            chunk_type="RCTS_8192_1000",
            modality_type="table",
        )
        chunk_session.add(chunk)
        if cnt % 1000 == 0:
            chunk_session.commit()
    chunk_session.commit()  # persist the final, partially filled batch
except BaseException:
    # After a failed flush/commit (or a kernel interrupt) the session still
    # holds a broken or half-filled transaction; discard it so later cells can
    # keep using chunk_session, then surface the original error.
    chunk_session.rollback()
    raise

100%|██████████| 152861/152861 [01:36<00:00, 1576.53it/s]


In [34]:
#chunk_session.query(Chunk).filter(Chunk.modality_type == "table").delete(synchronize_session=False)
#chunk_session.commit()