In [4]:
import json
import codecs
import math
import numpy as np
from collections import defaultdict, Counter
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import nltk
from multiprocessing import Pool, cpu_count
import itertools
import os # Needed for file operations
import pickle # Needed for saving/loading state

In [2]:
# Count the lines of the corpus dump. The file is streamed in binary mode,
# so nothing is decoded and the whole file never has to sit in memory.
with open('data/allMeSH_2022.json', "rb") as f:
    num_lines = sum(map(lambda raw_line: 1, f))

num_lines

16218839

In [3]:
# Peek at the tail of the corpus: echo every line whose 0-based index is
# greater than 16218829 (a hard-coded offset), minus its last two characters.
with codecs.open('data/allMeSH_2022.json', 'r', encoding='utf-8', errors='ignore') as corpus:
    line_nr = 0
    # line_nr is the 1-based position here, hence the threshold shifted by one.
    for line_nr, line in enumerate(corpus, start=1):
        if line_nr > 16218830:
            print(line[:-2])

{"journal":"Journal of bacteriology","meshMajor":["Cephalothin","Chloramphenicol","Drug Resistance, Microbial","Drug Stability","Enterobacter","Escherichia","Klebsiella","Pharmaceutical Preparations","Proteus","Salmonella","Shigella","Tetracycline"],"year":"1964","abstractText":"Wick, Warren E. (The Lilly Research Laboratories, Indianapolis, Ind.). Influence of antibiotic stability on the results of in vitro testing procedures. J. Bacteriol. 87:1162-1170. 1964.-Certain antibiotics undergo at least partial degradation under the conditions of in vitro testing procedures. With cephalothin used as an example, experimental evidence is presented to indicate the necessity for re-evaluation of results obtained from in vitro sensitivity testing methods for some antibiotics. The in vitro activity of cephalothin, tetracycline, and chloramphenicol against a variety of gram-negative bacteria is described. Plate counts demonstrate changes in the viable cell population over a 48-hr period in tubes of

In [None]:
# Shared text-normalisation resources, built once and reused by preprocess().
stemmer = PorterStemmer()
stop_words = set(stopwords.words('english'))


def preprocess(text):
    """Convert raw text into a list of stemmed content tokens.

    The text is lower-cased and tokenized with NLTK's word_tokenize. Tokens
    that are not purely alphanumeric (punctuation, hyphenated words, ...) or
    that are English stop words are dropped; every remaining token is reduced
    to its Porter stem.

    Args:
        text: The string to normalise.

    Returns:
        A list of stems, in the order the tokens occurred in `text`.
    """
    stems = []
    for token in word_tokenize(text.lower()):
        # Skip anything that is not a plain word/number or is a stop word.
        if not token.isalnum() or token in stop_words:
            continue
        stems.append(stemmer.stem(token))
    return stems

def process_line(line):
    """Parse one corpus line into per-document term statistics.

    Each line is expected to hold a single JSON object, optionally followed
    by the comma that separates it from the next array element. The last
    element of the array presumably also carries the closing "]}" of the
    enclosing document (TODO confirm against the file); such a line is
    retried without that suffix instead of being dropped.

    This function is deliberately best-effort: it is mapped over millions of
    lines in worker processes, so any line that cannot be turned into a
    document yields None rather than raising.

    Args:
        line: One raw line of the corpus file.

    Returns:
        A ``(pmid, term_counts, n_tokens)`` tuple, where ``term_counts`` is a
        Counter over the preprocessed abstract tokens and ``n_tokens`` is the
        number of tokens, or None when the line is blank, is not a JSON
        object, or lacks a truthy "pmid" or "abstractText".
    """
    candidate = line.strip()
    if candidate.endswith(','):
        candidate = candidate[:-1]
    if not candidate:
        return None

    try:
        try:
            record = json.loads(candidate)
        except json.JSONDecodeError:
            # Only the array's final record should end with "]}" and still
            # fail to parse; anything else (e.g. the header line) is skipped.
            if not candidate.endswith(']}'):
                return None
            record = json.loads(candidate[:-2])

        # A bare list/string/number is valid JSON but not a document.
        if not isinstance(record, dict):
            return None

        pmid = record.get("pmid")
        abstract = record.get("abstractText", "")
        if pmid and abstract:
            tokens = preprocess(abstract)
            return pmid, Counter(tokens), len(tokens)
    except Exception:
        # Best-effort by design: a malformed record or a tokenizer failure
        # must not abort the whole parallel run.
        return None
    return None

In [6]:
json_file = 'data/allMeSH_2022.json'
list_counter = []

# showcase preprocessing: skip the first line of the file, then print the
# pmid and the preprocessed abstract of each of the next ten lines.
with codecs.open(json_file, 'r', encoding='utf-8', errors='ignore') as corpus:
    line_nr = 0
    for line_nr, line in enumerate(itertools.islice(corpus, 11), start=1):
        if line_nr == 1:
            continue
        # Drop the last two characters of the line before parsing it.
        line_json = json.loads(line[:-2])
        print(line_json["pmid"])
        abstract_tokens = preprocess(line_json["abstractText"])
        print(abstract_tokens)
        list_counter.append(Counter(abstract_tokens))

34823483
['background', 'worldwid', 'hypertens', 'disord', 'pregnanc', 'hdp', 'fetal', 'growth', 'restrict', 'fgr', 'preterm', 'birth', 'remain', 'lead', 'caus', 'matern', 'fetal', 'mortal', 'morbid', 'fetal', 'cardiac', 'deform', 'chang', 'first', 'sign', 'placent', 'dysfunct', 'associ', 'hdp', 'fgr', 'preterm', 'birth', 'addit', 'preterm', 'birth', 'like', 'associ', 'chang', 'electr', 'activ', 'across', 'uterin', 'muscl', 'therefor', 'fetal', 'cardiac', 'function', 'uterin', 'activ', 'use', 'earli', 'detect', 'complic', 'pregnanc', 'fetal', 'cardiac', 'function', 'uterin', 'activ', 'assess', 'echocardiographi', 'fetal', 'electrocardiographi', 'electrohysterographi', 'ehg', 'studi', 'aim', 'gener', 'refer', 'valu', 'ehg', 'paramet', 'second', 'trimest', 'pregnanc', 'investig', 'diagnost', 'potenti', 'paramet', 'earli', 'detect', 'hdp', 'fgr', 'preterm', 'longitudin', 'prospect', 'cohort', 'studi', 'elig', 'women', 'recruit', 'tertiari', 'care', 'hospit', 'primari', 'midwiferi', 'pract

In [None]:
def _save_checkpoint(state, checkpoint_path, label):
    """Pickle `state` to `checkpoint_path` through a temporary file.

    The state is written to ``checkpoint_path + ".tmp"`` first and only then
    moved over the real file with os.replace, so a failed write leaves the
    previous checkpoint in place. Failures are printed (tagged with `label`)
    instead of raised, and the temporary file is removed.

    Returns:
        True when the checkpoint was written, False otherwise.
    """
    temp_checkpoint_path = checkpoint_path + ".tmp"
    try:
        with open(temp_checkpoint_path, 'wb') as f_temp_checkpoint:
            pickle.dump(state, f_temp_checkpoint, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(temp_checkpoint_path, checkpoint_path)
        return True
    except Exception as e:
        print(f"!!! Error saving {label}: {e} !!!")
        if os.path.exists(temp_checkpoint_path):
            try:
                os.remove(temp_checkpoint_path)
            except OSError:
                pass
        return False


def build_index_parallel(corpus_path,
                         checkpoint_path="checkpoint.pkl",
                         save_every_n_docs=100000,
                         chunk_size=1000,
                         num_workers=None):
    """Build an inverted index over the corpus, resumable via checkpoints.

    The corpus is read in batches of `chunk_size` lines; each batch is parsed
    by `process_line` in a multiprocessing pool and merged into the index in
    the parent process. Documents get consecutive integer ids in file order.
    The complete state is pickled to `checkpoint_path` whenever at least
    `save_every_n_docs` documents were added since the last save, and once
    more at the end. If `checkpoint_path` already exists, the state is
    restored from it and the lines it already covers are skipped.

    Args:
        corpus_path: Path of the line-oriented corpus file.
        checkpoint_path: Pickle file used for saving/resuming state.
        save_every_n_docs: Minimum number of new documents between saves.
        chunk_size: Number of lines handed to the pool per batch.
        num_workers: Worker process count; defaults to cpu_count().

    Returns:
        ``(index, doc_metadata, tf, idf, avg_dl)`` where
        ``index`` maps token -> list of doc ids containing it,
        ``doc_metadata`` maps doc id -> {"pmid": ..., "length": ...},
        ``tf`` maps doc id -> Counter of token frequencies,
        ``idf`` maps token -> log((N + 1) / (df + 1)) + 1, and
        ``avg_dl`` is the mean document length in tokens.
        When no document was processed, ``(index, {}, {}, {}, 0)``.

    Raises:
        ValueError: if the checkpoint claims more processed lines than the
            corpus contains (the checkpoint belongs to a different file).
        Exception: any other processing error is printed and re-raised.
    """
    if num_workers is None:
        num_workers = cpu_count()

    index = defaultdict(list)
    doc_metadata = {}
    tf = {}
    df = Counter()
    total_docs = 0
    total_dl = 0
    processed_lines_count = 0

    # --- Load from checkpoint if exists ---
    if os.path.exists(checkpoint_path):
        print(f"--- Checkpoint found at {checkpoint_path}. Resuming... ---")
        try:
            with open(checkpoint_path, 'rb') as f_checkpoint:
                # SECURITY: unpickling executes arbitrary code. Only resume
                # from checkpoint files written by this function.
                saved_state = pickle.load(f_checkpoint)
            # Look up every field before assigning any of them, so that a
            # damaged checkpoint cannot leave a half-restored state behind.
            restored = tuple(saved_state[key] for key in (
                'index', 'doc_metadata', 'tf', 'df',
                'total_docs', 'total_dl', 'processed_lines_count'))
            (index, doc_metadata, tf, df,
             total_docs, total_dl, processed_lines_count) = restored
            print(f"--- Resumed state: {total_docs} documents processed from {processed_lines_count} lines. ---")
        except Exception as e:
            print(f"Error loading checkpoint: {e}. Starting from scratch.")
            index, doc_metadata, tf, df = defaultdict(list), {}, {}, Counter()
            total_docs, total_dl, processed_lines_count = 0, 0, 0
    else:
        print("--- No checkpoint found. Starting from scratch. ---")

    last_save_doc_count = total_docs
    # Work-unit size handed to each worker; loop-invariant.
    map_chunksize = max(1, chunk_size // (num_workers * 2))

    try:
        with codecs.open(corpus_path, 'r', encoding='utf-8', errors='ignore') as corpus, \
             Pool(processes=num_workers) as pool:

            if processed_lines_count > 0:
                print(f"Skipping first {processed_lines_count} lines from input file...")
                skipped = sum(1 for _ in itertools.islice(corpus, processed_lines_count))
                if skipped != processed_lines_count:
                    # Previously this surfaced as a bare StopIteration.
                    raise ValueError(
                        f"Checkpoint covers {processed_lines_count} lines but "
                        f"{corpus_path} only has {skipped}; the checkpoint does "
                        f"not match this corpus.")
                print("Skipping complete.")

            current_line_nr = processed_lines_count
            doc_id_counter = total_docs

            while True:
                lines_chunk = list(itertools.islice(corpus, chunk_size))
                if not lines_chunk:
                    break

                chunk_results = pool.map(process_line, lines_chunk, chunksize=map_chunksize)

                lines_in_chunk = len(lines_chunk)
                del lines_chunk  # Free memory before merging

                for result in chunk_results:
                    if not result:
                        continue  # line did not yield a document
                    pmid, doc_tf_counter, doc_length = result
                    current_doc_id = doc_id_counter

                    doc_metadata[current_doc_id] = {"pmid": pmid, "length": doc_length}
                    tf[current_doc_id] = doc_tf_counter
                    total_dl += doc_length
                    total_docs += 1

                    # Each distinct token counts once towards document
                    # frequency and gets one posting for this document.
                    for token in doc_tf_counter:
                        df[token] += 1
                        index[token].append(current_doc_id)

                    doc_id_counter += 1

                current_line_nr += lines_in_chunk
                print(f"Processed up to line ~{current_line_nr}. Total documents: {total_docs}")

                if total_docs >= last_save_doc_count + save_every_n_docs:
                    print(f"\n--- Saving checkpoint at {total_docs} documents (processed ~{current_line_nr} lines)... ---")
                    checkpoint_state = {
                        'index': index,
                        'doc_metadata': doc_metadata,
                        'tf': tf,
                        'df': df,
                        'total_docs': total_docs,
                        'total_dl': total_dl,
                        'processed_lines_count': current_line_nr
                    }
                    if _save_checkpoint(checkpoint_state, checkpoint_path, "checkpoint"):
                        print(f"--- Checkpoint saved successfully to {checkpoint_path} ---")
                        # Only advance the marker on success so that a failed
                        # save is retried after the next batch.
                        last_save_doc_count = total_docs
                    del checkpoint_state

    except Exception as e:
        print(f"\n!!! An error occurred during processing: {e} !!!")
        print("Please check the error message. Try resuming the script to continue from the last checkpoint.")
        raise

    # --- Final calculations (after loop finishes) ---
    if total_docs == 0:
        print("Warning: No documents processed.")
        return index, {}, {}, {}, 0

    avg_dl = total_dl / total_docs
    # Using +1 smoothing for IDF
    idf = {token: math.log((total_docs + 1) / (freq + 1)) + 1 for token, freq in df.items()}

    print(f"\n--- Finished building index. Total documents: {total_docs} ---")

    # --- Final save after successful completion ---
    print("--- Saving final state... ---")
    final_state = {
        'index': index, 'doc_metadata': doc_metadata, 'tf': tf, 'df': df,
        'total_docs': total_docs, 'total_dl': total_dl,
        'processed_lines_count': current_line_nr  # Save final line count
    }
    if _save_checkpoint(final_state, checkpoint_path, "final state"):
        print(f"--- Final state saved successfully to {checkpoint_path} ---")

    return index, doc_metadata, tf, idf, avg_dl

In [None]:
# index, doc_ids, tf, idf, avg_dl = build_index_parallel(json_file)

json_file = 'data/allMeSH_2022.json'
checkpoint_file = 'mesh_index_checkpoint.pkl'

print("Building index (with checkpointing)...")
# Lines go to 10 worker processes in batches of 4000; the state is
# checkpointed to `checkpoint_file` every 200,000 documents.
index, doc_metadata, tf, idf, avg_dl = build_index_parallel(
    corpus_path=json_file,
    checkpoint_path=checkpoint_file,
    num_workers=10,
    chunk_size=4000,
    save_every_n_docs=200000,
)

Building index (with checkpointing)...
--- Checkpoint found at mesh_index_checkpoint.pkl. Resuming... ---
--- Resumed state: 2003999 documents processed from 2004000 lines. ---
Skipping first 2004000 lines from input file...
Skipping complete.


In [89]:
# Spot check: IDF weight of the stem "hirschsprung"; .get falls back to 0
# when the stem is not a key of `idf`.
print(idf.get("hirschsprung", 0))

9.210340371976184


In [90]:
# Spot check: frequency of the stem "hirschsprung" in the document with
# internal id 0.
# NOTE(review): this relies on tf[0] returning 0 for absent keys (as a
# Counter does) instead of raising KeyError — confirm tf holds Counters.
print(tf[0]["hirschsprung"])

0


In [None]:
def calc_scores(query, index, doc_metadata, tf, idf, avg_dl, k1=1.5, b=0.75):
    """Rank documents against `query` with BM25-style scoring.

    The query is normalised with `preprocess`; for every resulting token the
    posting list in `index` is walked and each listed document receives

        idf[token] * f * (k1 + 1) / (f + k1 * (1 - b + b * dl / avg_dl))

    where ``f`` is the token's frequency in the document and ``dl`` the
    document length. A token that occurs several times in the query
    contributes once per occurrence.

    Args:
        query: Free-text query string.
        index: Mapping token -> list of internal doc ids containing it.
        doc_metadata: Mapping doc id -> {"pmid": ..., "length": ...}.
        tf: Mapping doc id -> mapping token -> frequency.
        idf: Mapping token -> inverse document frequency weight.
        avg_dl: Average document length of the collection.
        k1: Term-frequency saturation parameter.
        b: Length-normalisation parameter.

    Returns:
        A list of ``(pmid, score)`` tuples for *all* matching documents,
        sorted by descending score (callers slice it to get a top-N). The
        list is empty when the collection is empty or nothing matches.
    """
    # Always return a list: callers slice the result (e.g. results[:10]),
    # which would raise TypeError on the dict this path used to return.
    if len(doc_metadata) == 0:
        return []

    scores = defaultdict(float)  # sparse accumulator: doc id -> score

    for token in preprocess(query):
        # .get() does not insert missing keys, even when index is a defaultdict.
        postings = index.get(token)
        if not postings:
            continue  # token does not occur in the corpus
        token_idf = idf.get(token, 0)

        for doc_id in postings:
            doc_len = doc_metadata[doc_id]["length"]
            term_freq_in_doc = tf[doc_id].get(token, 0)

            denominator = term_freq_in_doc + k1 * (1 - b + b * (doc_len / avg_dl))
            if denominator != 0:  # Avoid division by zero
                scores[doc_id] += token_idf * term_freq_in_doc * (k1 + 1) / denominator

    # Highest score first; ties keep the order in which documents were first scored.
    sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)

    # Map internal doc ids back to PMIDs for the final result.
    return [(doc_metadata[doc_id]["pmid"], score) for doc_id, score in sorted_scores]


# --- Example Usage ---
json_file = 'data/allMeSH_2022.json' # Make sure this path is correct

print("Building index...")
# Adjust chunk_size and num_workers based on your system.
# NOTE(review): no checkpoint_path is passed here, so this call uses the
# function's default "checkpoint.pkl" rather than the
# 'mesh_index_checkpoint.pkl' file used by the earlier build cell — confirm
# that a separate build is intended.
index, doc_metadata, tf, idf, avg_dl = build_index_parallel(json_file, chunk_size=2000, num_workers=4)

print("\nCalculating scores...")
query = "Is Hirschsprung disease a mendelian or a multifactorial disorder?"
results = calc_scores(query, index, doc_metadata, tf, idf, avg_dl)

print(f"\nTop 10 results for query: '{query}'")
for rank, (pmid, score) in enumerate(results[:10], start=1):
    print(f"{rank}. PMID: {pmid}, Score: {score:.4f}")

# Keep only the identifiers of the 100 best-scoring hits.
top_pmids = [hit[0] for hit in results[:100]]
print("\nTop 100 PMIDs:", top_pmids)

In [102]:
# NOTE(review): this call passes (query, tf, idf, doc_ids, avg_dl), which does
# not match the calc_scores(query, index, doc_metadata, tf, idf, avg_dl, ...)
# signature defined in this notebook, and `doc_ids` appears only in a
# commented-out assignment. Presumably left over from an earlier version of
# the function — confirm and update before re-running on a fresh kernel.
scores = calc_scores("Is Hirschsprung disease a mendelian or a multifactorial disorder?", tf, idf, doc_ids, avg_dl)

In [103]:
# Positions of the 100 largest entries of `scores`, best first
# (argsort is ascending, so the result is reversed before slicing).
# NOTE(review): this treats `scores` as a flat per-document score array;
# calc_scores as defined in this notebook returns an already sorted list of
# (pmid, score) tuples — confirm which version this cell is meant for.
np.array(scores).argsort()[::-1][:100]

array([ 4393, 32519, 56329, 28015,  8112, 32430, 20705, 99219, 86489,
       12550,  1078,  8650, 18828, 24153,  5523,   675, 27617, 72717,
       30584, 14829,  7850, 17477, 51264, 78976, 93669,  4674, 27835,
       75720, 92379, 28803, 31679, 13747, 62475, 22590, 72817, 66776,
       34441, 73005, 89728, 52068,    46, 50776, 30623,  5406, 15981,
       35537, 11168, 10936, 19899, 31942, 81135, 76997, 72234, 49841,
        8006, 41076, 49050, 28025, 91050, 96974, 17654, 55831, 76223,
       67767, 43003, 14019, 89595, 49117, 73332, 75469, 96825, 27252,
       12992,   952, 97717, 43085, 51369, 24268, 77000, 31488, 26972,
        9523, 23727, 36188, 79966, 22853, 53691, 18348, 82363, 93782,
       40055, 76704, 28942, 13076, 96771, 20326, 13435, 23763, 83888,
       82954])

In [104]:
# Translate position 4393 (the first entry of the ranking above) into the
# identifier stored for it in `doc_ids` — presumably a PMID; verify.
# NOTE(review): `doc_ids` appears only in a commented-out assignment in this
# notebook, so this cell depends on leftover kernel state — confirm.
doc_ids[4393]

'34634250'