In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
import torch
import json
from langchain_huggingface import HuggingFaceEmbeddings
import argparse
import pandas as pd
from loguru import logger
from tqdm import tqdm

# Make the local project packages imported below (pipeline, retrieval, summarization)
# importable by appending the parent of the working directory, and that directory's
# parent, to sys.path.
parent_folder = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
grandparent_dir = os.path.dirname(parent_folder)

sys.path.append(parent_folder)
print(parent_folder)
sys.path.append(grandparent_dir)

#local imports
# from pipeline.qa_base_pipeline import QABasePipeline
from pipeline.qa_pipeline import QAPipeline
from retrieval.reranking import rerank_documents
from retrieval.keyword_db import KeywordDB
from retrieval.vector_db import VectorDB
from summarization.summarization import QueryBasedTextRankSummarizer

/home/mikhail/diploma_work/medrag


  from .autonotebook import tqdm as notebook_tqdm


# Load artifacts (retrieval indexes, embedding model, QA pipeline)

In [3]:
# Retrieval indexes, embedding model and the QA pipeline used by every evaluation below.
# The data root can be overridden with MEDRAG_DATA_DIR when running on another machine.
DATA_DIR = os.environ.get('MEDRAG_DATA_DIR', '/s3/misha/data_dir/PMC_patients')

query_based_summarizer = QueryBasedTextRankSummarizer()
keyword_database = KeywordDB(db_path=os.path.join(DATA_DIR, 'db_bm25s'))
vector_database = VectorDB(db_path=os.path.join(DATA_DIR, 'db_faiss'))

keyword_database.load_db()
vector_database.load_db()

# Fall back to CPU so the notebook still runs on hosts without a CUDA device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = HuggingFaceEmbeddings(model_name="neuml/pubmedbert-base-embeddings",
                              model_kwargs={'device': device},
                              encode_kwargs={"normalize_embeddings": True})

# NOTE(review): k=20 is presumably the number of retrieved documents per query — confirm in QAPipeline.
qa_pipeline = QAPipeline(vector_database, keyword_database, query_based_summarizer, embedding_model=model, k=20)
logger.info('loaded all artifacts!')

[32m2025-04-05 11:59:39.266[0m | [1mINFO    [0m | [36mretrieval.vector_db[0m:[36mload_db[0m:[36m80[0m - [1mvector database loaded from the local file successfully![0m
[32m2025-04-05 11:59:42.056[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m16[0m - [1mloaded all artifacts![0m


## Get the test data (Wikipedia search results)

In [4]:
import wikipedia

In [10]:
# Evaluation set: the top-10 Wikipedia search hits for a single query.
# NOTE(review): the saved outputs further down (dataset table, generated articles) list
# stroke-related pages, so they were produced with a different query than this one.
search_query = 'COVID-19'
en_wiki_titles = wikipedia.search(search_query, results=10, suggestion=False)
en_wiki_titles

['COVID-19',
 'COVID-19 pandemic',
 'COVID-19 pandemic in the United States',
 'COVID-19 vaccine',
 'COVID-19 lockdowns',
 'COVID-19 testing',
 'COVID-19 misinformation',
 'List of deaths due to COVID-19',
 'Timeline of the COVID-19 pandemic',
 'COVID-19 pandemic in the United Kingdom']

In [5]:
def get_page(title):
    """Fetch a Wikipedia page by exact title, resolving disambiguation pages.

    Args:
        title (str): Page title, typically one returned by ``wikipedia.search``.

    Returns:
        The object returned by ``wikipedia.page`` for ``title``, or for the first
        option listed when ``title`` is a disambiguation page.

    Raises:
        wikipedia.DisambiguationError: if ``title`` is a disambiguation page that
            lists no options to fall back to.
    """
    page_kwargs = dict(auto_suggest=False, redirect=True, preload=False)
    try:
        return wikipedia.page(title, **page_kwargs)
    except wikipedia.DisambiguationError as e:
        if not e.options:
            # Nothing to fall back to: surface the real error instead of an IndexError.
            raise
        # Always take the first option (not a random one) so evaluation runs are reproducible.
        return wikipedia.page(e.options[0], **page_kwargs)

In [15]:
# Download every search hit once, keeping the search title, the full article text and
# the page's summary in three parallel lists.
title, content, summary = [], [], []

for wiki_title in tqdm(en_wiki_titles):
    wiki_page = get_page(wiki_title)
    title.append(wiki_title)
    content.append(wiki_page.content)
    summary.append(wiki_page.summary)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:02<00:00,  3.92it/s]


In [16]:
import pandas as pd

# One row per scraped article; columns in the order title / content / summary.
ds = pd.DataFrame({'title': title, 'content': content, 'summary': summary})

In [17]:
# Display the scraped articles.
# NOTE(review): the saved output lists stroke-related titles, not hits for the 'COVID-19'
# query used in the search cell — it is stale from an earlier run; re-run to refresh.
ds

Unnamed: 0,title,content,summary
0,Stroke,Stroke is a medical condition in which poor bl...,Stroke is a medical condition in which poor bl...
1,Transient ischemic attack,"A transient ischemic attack (TIA), commonly kn...","A transient ischemic attack (TIA), commonly kn..."
2,Ischemia,Ischemia or ischaemia is a restriction in bloo...,Ischemia or ischaemia is a restriction in bloo...
3,Cerebral edema,Cerebral edema is excess accumulation of fluid...,Cerebral edema is excess accumulation of fluid...
4,Cerebral hypoxia,Cerebral hypoxia is a form of hypoxia (reduced...,Cerebral hypoxia is a form of hypoxia (reduced...
5,Tissue-type plasminogen activator,"Tissue-type plasminogen activator, short name ...","Tissue-type plasminogen activator, short name ..."
6,Alteplase,"Alteplase, sold under the brand name Activase ...","Alteplase, sold under the brand name Activase ..."
7,Cerebral infarction,"Cerebral infarction, also known as an ischemic...","Cerebral infarction, also known as an ischemic..."
8,Cerebrovascular disease,Cerebrovascular disease includes a variety of ...,Cerebrovascular disease includes a variety of ...
9,Embolic stroke of undetermined source,Embolic stroke of undetermined source (ESUS) i...,Embolic stroke of undetermined source (ESUS) i...


In [18]:
# Generate one wiki-style article per scraped title with the QA pipeline
# (slow: roughly 13 s per title in the recorded run).
w_arts = list(map(qa_pipeline.generate_wiki_response, tqdm(title)))

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [02:10<00:00, 13.04s/it]


In [19]:
# Show the generated articles (the saved output below is cut off mid-list).
w_arts

['**Stroke**\n\nA stroke is a medical condition that occurs when the blood supply to the brain is interrupted or reduced, depriving the brain of oxygen and nutrients. This can lead to damage to the brain tissue, resulting in various neurological symptoms.\n\n**Causes and Risk Factors**\n\nThe causes of stroke are varied, but common risk factors include:\n\n*   High blood pressure\n*   Diabetes\n*   Smoking\n*   Obesity\n*   Physical inactivity\n*   Family history of stroke or cardiovascular disease\n\nAccording to the American Heart Association/American Stroke Association, continuous ECG monitoring and assessment of LV function (ECHO) and cardiac injury biomarkers are crucial for cardiac monitoring, especially during the first 24 hours. This helps screen for any myocardial damage to prevent sudden cardiac death.\n\n**Diagnosis**\n\nThe diagnosis of stroke involves a combination of clinical evaluation, imaging tests, and laboratory studies. The American Stroke Association recommends the

In [8]:
#ROUGE
from rouge_score import rouge_scorer

from functools import lru_cache

# Metrics computed by default; callers that need fewer can pass a subset.
_DEFAULT_ROUGE_TYPES = ('rouge1', 'rouge2', 'rougeL')


@lru_cache(maxsize=None)
def _get_rouge_scorer(rouge_types, use_stemmer):
    """Return a RougeScorer for this configuration, built once and then reused."""
    return rouge_scorer.RougeScorer(list(rouge_types), use_stemmer=use_stemmer)


def calculate_rouge(reference_summary, generated_summary,
                    rouge_types=_DEFAULT_ROUGE_TYPES, use_stemmer=True):
    """
    Calculates ROUGE scores between a reference summary and a generated summary.

    The underlying scorer is cached per (rouge_types, use_stemmer), so calling this in
    a loop does not rebuild it on every call.

    Args:
        reference_summary (str): The ground-truth summary.
        generated_summary (str): The summary to evaluate.
        rouge_types (str | Iterable[str]): rouge_score metric name(s) to compute.
            Defaults to ('rouge1', 'rouge2', 'rougeL'). Pass e.g. ('rouge1',) when only
            ROUGE-1 is needed; this skips the LCS-based ROUGE-L on long texts.
        use_stemmer (bool): Whether the scorer stems tokens before matching (default True).

    Returns:
        dict: One entry per requested metric, keyed by the metric name with the "rouge"
        prefix replaced by "ROUGE-" (e.g. "ROUGE-1", "ROUGE-2", "ROUGE-L"). Each value
        is a dict with "precision", "recall" and "f1".
    """
    if isinstance(rouge_types, str):
        rouge_types = (rouge_types,)
    rouge_types = tuple(rouge_types)  # hashable, as required by the scorer cache
    scores = _get_rouge_scorer(rouge_types, use_stemmer).score(reference_summary, generated_summary)

    return {
        "ROUGE-" + rouge_type[len("rouge"):]: {
            "precision": scores[rouge_type].precision,
            "recall": scores[rouge_type].recall,
            "f1": scores[rouge_type].fmeasure,
        }
        for rouge_type in rouge_types
    }

# Sanity check on two sentences that differ in a single word ("jumps" vs "leaps").
reference, generated = (
    "The quick brown fox jumps over the lazy dog.",
    "The quick brown fox leaps over the lazy dog.",
)

print(f"Reference: {reference}")
print(f"Generated: {generated}")

rouge_scores = calculate_rouge(reference, generated)

print("ROUGE Scores:")
for metric_name, metric_scores in rouge_scores.items():
    print(f"{metric_name}:")
    print(f"  Precision: {metric_scores['precision']:.4f}")
    print(f"  Recall:    {metric_scores['recall']:.4f}")
    print(f"  F1 Score:  {metric_scores['f1']:.4f}")

Reference: The quick brown fox jumps over the lazy dog.
Generated: The quick brown fox leaps over the lazy dog.
ROUGE Scores:
ROUGE-1:
  Precision: 0.8889
  Recall:    0.8889
  F1 Score:  0.8889
ROUGE-2:
  Precision: 0.7500
  Recall:    0.7500
  F1 Score:  0.7500
ROUGE-L:
  Precision: 0.8889
  Recall:    0.8889
  F1 Score:  0.8889


In [21]:
# ROUGE-1 F1 of the sanity-check sentence pair above.
rouge_scores["ROUGE-1"]['f1']

0.8888888888888888

In [22]:
# ROUGE-1 of each generated article against the FULL Wikipedia article text,
# averaged over all articles.
wiki_r1_res = [calculate_rouge(reference_text, generated_text)["ROUGE-1"]
               for reference_text, generated_text in zip(tqdm(content), w_arts)]

precision, recall, f1 = (
    [scores[metric] for scores in wiki_r1_res] for metric in ('precision', 'recall', 'f1')
)

for metric_name, metric_values in (('precision', precision), ('recall', recall), ('f1', f1)):
    print(metric_name, sum(metric_values) / len(metric_values))

100%|██████████| 10/10 [00:05<00:00,  1.80it/s]

precision 0.7210948457812407
recall 0.16015247592625204
f1 0.24419790426804674





In [23]:
# ROUGE-1 of each generated article against the Wikipedia page SUMMARY,
# averaged over all articles.
wiki_r1_res = [calculate_rouge(reference_text, generated_text)["ROUGE-1"]
               for reference_text, generated_text in zip(tqdm(summary), w_arts)]

precision, recall, f1 = (
    [scores[metric] for scores in wiki_r1_res] for metric in ('precision', 'recall', 'f1')
)

for metric_name, metric_values in (('precision', precision), ('recall', recall), ('f1', f1)):
    print(metric_name, sum(metric_values) / len(metric_values))

100%|██████████| 10/10 [00:00<00:00, 24.31it/s]

precision 0.2544338498384745
recall 0.559145387011817
f1 0.3367494229872742





In [24]:
# Mean character length of the reference summaries vs. the generated articles.
# Divide by the actual counts rather than a hard-coded 10, so the numbers stay right
# when the search returns a different number of pages.
print(sum(len(s) for s in summary) / len(summary))
print(sum(len(s) for s in w_arts) / len(w_arts))

1358.1
3135.1


# Wrap the evaluation into a reusable function (ROUGE-1 vs. Wikipedia page summaries)

In [None]:
import wikipedia
import pandas as pd

def evaluate_wiki_article_rouge(topic: str, n_results: int = 10, pipeline=None) -> dict:
    """Score generated wiki articles against Wikipedia summaries for one search topic.

    Searches Wikipedia for ``topic`` and downloads up to ``n_results`` matching pages.
    It then generates an article for each page title with the QA pipeline and compares
    each generated article with that page's summary using ROUGE-1. Mean precision,
    recall and F1 are printed and returned.

    Args:
        topic: Wikipedia search query.
        n_results: Maximum number of search hits to evaluate (default 10).
        pipeline: Object with a ``generate_wiki_response(title)`` method; defaults to
            the notebook-level ``qa_pipeline``.

    Returns:
        dict: ``{'precision': ..., 'recall': ..., 'f1': ...}`` with the mean ROUGE-1
        scores, or NaN for each metric when the search returns no pages.
    """
    if pipeline is None:
        pipeline = qa_pipeline

    en_wiki_titles = wikipedia.search(topic, results=n_results, suggestion=False)

    # Only the summaries are compared, so the full page content is not fetched here.
    title, summary = [], []
    for t in tqdm(en_wiki_titles):
        page = get_page(t)
        title.append(t)
        summary.append(page.summary)

    w_arts = [pipeline.generate_wiki_response(t) for t in tqdm(title)]

    # ROUGE-1 between each original Wikipedia summary and the corresponding generated article
    wiki_r1_res = [calculate_rouge(r, g)["ROUGE-1"] for r, g in zip(tqdm(summary), w_arts)]

    averages = {}
    for metric in ('precision', 'recall', 'f1'):
        values = [n[metric] for n in wiki_r1_res]
        # An empty search would otherwise raise ZeroDivisionError; report NaN instead.
        averages[metric] = sum(values) / len(values) if values else float('nan')
        print(metric, averages[metric])
    return averages

In [None]:
# # rouge between content and generated summary
# wiki_r1_res = [calculate_rouge(r, g)["ROUGE-1"] for r, g in zip(tqdm(content), w_arts)] 

# precision = [n['precision'] for n in wiki_r1_res]
# recall = [n['recall'] for n in wiki_r1_res]
# f1 = [n['f1'] for n in wiki_r1_res]

# print('precision', sum(precision)/len(precision))
# print('recall', sum(recall)/len(recall))
# print('f1', sum(f1)/len(f1))

In [15]:
# ROUGE-1 evaluation on the top Wikipedia hits for "Stroke".
evaluate_wiki_article_rouge(topic='Stroke')

100%|██████████| 10/10 [00:03<00:00,  2.94it/s]
100%|██████████| 10/10 [01:42<00:00, 10.26s/it]
100%|██████████| 10/10 [00:00<00:00, 30.03it/s]

precision 0.24241147534253576
recall 0.38041487963657794
f1 0.2557280987866649





In [11]:
# ROUGE-1 evaluation on the top Wikipedia hits for "COVID-19".
# NOTE(review): the saved output below prints two sets of scores, so it came from an
# earlier version of evaluate_wiki_article_rouge (the first set looks like the
# full-content comparison that is now commented out); re-run to refresh.
evaluate_wiki_article_rouge(topic='COVID-19')

100%|██████████| 10/10 [00:03<00:00,  2.60it/s]
100%|██████████| 10/10 [02:35<00:00, 15.57s/it]
100%|██████████| 10/10 [00:15<00:00,  1.52s/it]


precision 0.6920262453997174
recall 0.1596007221803261
f1 0.1559219299780082


100%|██████████| 10/10 [00:00<00:00, 13.31it/s]

precision 0.2673066364548939
recall 0.5299296329639346
f1 0.333410043861453





In [16]:
# ROUGE-1 evaluation on the top Wikipedia hits for "Tuberculosis".
evaluate_wiki_article_rouge(topic='Tuberculosis')

100%|██████████| 10/10 [00:02<00:00,  3.78it/s]
100%|██████████| 10/10 [01:47<00:00, 10.73s/it]
100%|██████████| 10/10 [00:00<00:00, 42.06it/s]

precision 0.20961811999856286
recall 0.5236591352251082
f1 0.2841865800635807





In [21]:
# ROUGE-1 evaluation on the top Wikipedia hits for "Lung cancer".
evaluate_wiki_article_rouge(topic='Lung cancer')



  lis = BeautifulSoup(html).find_all('li')
100%|██████████| 10/10 [00:03<00:00,  3.12it/s]
100%|██████████| 10/10 [01:54<00:00, 11.50s/it]
100%|██████████| 10/10 [00:00<00:00, 22.47it/s]

precision 0.27174402641171164
recall 0.4160381851343645
f1 0.2908950871310097





In [22]:
# ROUGE-1 evaluation on the top Wikipedia hits for "HIV".
evaluate_wiki_article_rouge(topic='HIV')



  lis = BeautifulSoup(html).find_all('li')
100%|██████████| 10/10 [00:03<00:00,  3.08it/s]
100%|██████████| 10/10 [01:57<00:00, 11.77s/it]
100%|██████████| 10/10 [00:00<00:00, 21.66it/s]

precision 0.31933602985028553
recall 0.5009543386833873
f1 0.35110719320179323





In [23]:
# ROUGE-1 evaluation on the top Wikipedia hits for "Diabetes".
evaluate_wiki_article_rouge(topic='Diabetes')



  lis = BeautifulSoup(html).find_all('li')
100%|██████████| 10/10 [00:02<00:00,  3.42it/s]
100%|██████████| 10/10 [02:19<00:00, 13.92s/it]
100%|██████████| 10/10 [00:00<00:00, 19.87it/s]

precision 0.3354016133633472
recall 0.5182177290266532
f1 0.38911428662283193



