In [20]:
import os
import textwrap
from typing import List, Dict, Set, Any
from Bio import Entrez
from time import sleep
import json
import yaml  
import pandas as pd
import re

from dotenv import load_dotenv
load_dotenv()

True

### get list of drug names

In [15]:
def load_config(yaml_path="P2-config.yaml"):
    with open(yaml_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
    
config = load_config()

all_drugs_path = config["all_drugs"]
all_drugs = pd.read_csv(all_drugs_path)
drug_names = all_drugs['DRUGNAME'].dropna().astype(str).unique().tolist()

drug_names_file = config["drug_names_file"]

with open(drug_names_file, "w", encoding="utf-8") as f:
    for name in drug_names:
        f.write(name + "\n")



In [16]:
class NCBIMeSHScraper:
    def __init__(self, email: str, config: Dict[str, Any]):
        self.email = email
        self.config = config

        Entrez.email = self.email
        Entrez.tool = "DrugPathwayAnalyzer"

        self.api_key = os.getenv("NCBI_API_KEY")
        if self.api_key:
            Entrez.api_key = self.api_key
            print("[INFO] NCBI API key loaded.")
        else:
            print("[WARN] NCBI_API_KEY not found. Continuing without key (stricter rate limits).")

    def _make_or_block(self, terms: List[str]) -> str:
        """
        ["foo", "bar baz"] -> '("foo" OR "bar baz")'
        Assumes terms are already quoted if needed.
        """
        clean_terms = [t.strip() for t in terms if t.strip()]
        if not clean_terms:
            return ""
        if len(clean_terms) == 1:
            return clean_terms[0]
        return "(" + " OR ".join(clean_terms) + ")"

    def search_pubmed(self, query: str, retmax: int = 50) -> List[str]:
        handle = Entrez.esearch(
            db="pubmed",
            term=query,
            retmax=retmax,
        )
        rec = Entrez.read(handle)
        handle.close()
        return rec.get("IdList", [])

    def fetch_pubmed_mesh(self, pmids: List[str]) -> Dict[str, List[str]]:
        """
        Returns { pmid: [mesh_term1, mesh_term2, ...], ... }
        """
        mesh_by_pmid: Dict[str, List[str]] = {}
        if not pmids:
            return mesh_by_pmid

        BATCH = 50
        for i in range(0, len(pmids), BATCH):
            batch_ids = pmids[i:i+BATCH]
            handle = Entrez.efetch(
                db="pubmed",
                id=",".join(batch_ids),
                retmode="xml"
            )
            records = Entrez.read(handle)
            handle.close()

            for article in records.get("PubmedArticle", []):
                # PMID
                pmid = None
                try:
                    pmid_val = article["MedlineCitation"]["PMID"]
                    pmid = (
                        pmid_val.get("#text", pmid_val)
                        if isinstance(pmid_val, dict) else pmid_val
                    )
                except KeyError:
                    pmid = None

                # MeSH terms list
                mesh_terms = []
                try:
                    mh_list = article["MedlineCitation"]["MeshHeadingList"]
                    for mh in mh_list:
                        desc = mh.get("DescriptorName")
                        if isinstance(desc, dict):
                            term_text = desc.get("#text", "")
                        else:
                            term_text = str(desc)
                        term_text = term_text.strip()
                        if term_text:
                            mesh_terms.append(term_text)
                except KeyError:
                    pass

                if pmid:
                    mesh_by_pmid[pmid] = mesh_terms

            # be polite if no key
            sleep(0.34)

        return mesh_by_pmid

    def get_joint_mesh_for_drug(
        self,
        drug_name: str,
        disease_blocks: List[Dict[str, Any]],
        therapy_terms: List[str],
        retmax: int = 50,
    ) -> Dict[str, Any]:
        """
        Generic version:
        - disease_blocks: list of disease configs, each with:
            - id (str)
            - label (str)
            - query_terms (List[str])
            - mesh_match_substrings (List[str])
        - therapy_terms: extra PubMed filter terms (e.g. ["\"drug therapy\""])
        """

        # Build per-disease OR blocks for PubMed
        disease_query_blocks = []
        for d in disease_blocks:
            q_terms = d.get("query_terms", [])
            disease_query_blocks.append(self._make_or_block(q_terms))

        therapy_block = self._make_or_block(therapy_terms)

        # Build the combined PubMed query
        query_parts = [f"\"{drug_name}\"[Title/Abstract]"]
        for dq in disease_query_blocks:
            if dq:
                query_parts.append(dq)
        if therapy_block:
            query_parts.append(therapy_block)

        query = " AND ".join(query_parts)

        pmids = self.search_pubmed(query, retmax=retmax)
        mesh_map = self.fetch_pubmed_mesh(pmids)

        # all MeSH terms across joint PMIDs
        all_mesh_terms: Set[str] = set(
            term for terms in mesh_map.values() for term in terms
        )

        # classify MeSH into per-disease buckets based on substrings
        disease_results = []
        for d in disease_blocks:
            d_id = d.get("id", "unknown")
            d_label = d.get("label", d_id)
            match_subs = d.get("mesh_match_substrings", [])

            hits = sorted({
                t for t in all_mesh_terms
                if any(sub in t for sub in match_subs)
            }) if match_subs else []

            disease_results.append({
                "id": d_id,
                "label": d_label,
                "query_terms": d.get("query_terms", []),
                "mesh_match_substrings": match_subs,
                "mesh_hits": hits,
            })

        # joint signal requires at least one hit for every disease
        has_joint_signal = bool(pmids) and all(
            len(d["mesh_hits"]) > 0 for d in disease_results
        )

        return {
            "drug": drug_name,
            "query": query,
            "pmids_joint": pmids,
            "mesh_terms_joint": sorted(all_mesh_terms),
            "diseases": disease_results,
            "has_joint_signal": has_joint_signal,
        }

    # ------------------------------------------------------------------
    # Pretty report generator
    # ------------------------------------------------------------------

    def _wrap_line(self, s: str, width: int = 100) -> str:
        return '\n'.join(textwrap.fill(
            s,
            width=width,
            replace_whitespace=False,
            break_long_words=False,
            break_on_hyphens=False
        ).splitlines())

    def generate_user_friendly_report(self, analysis: Dict[str, Any]) -> str:
        """
        Turn a single drug analysis dict (from get_joint_mesh_for_drug)
        into a human-readable multi-section text block.
        Supports an arbitrary number of diseases (from config).
        """

        drug_name = analysis["drug"]
        pmids = analysis["pmids_joint"]
        all_mesh = analysis["mesh_terms_joint"]
        diseases = analysis.get("diseases", [])
        has_joint = analysis.get("has_joint_signal", False)

        # If zero PMIDs, just say nothing found.
        if not pmids:
            disease_labels = [d["label"] for d in diseases] if diseases else ["the configured diseases"]
            disease_str = " and ".join(disease_labels)
            return (
                f"DRUG PATHWAY ANALYSIS REPORT: {drug_name.upper()}\n"
                f"----------------------------------------------------------------------------\n"
                f"No PubMed articles were found that mention {drug_name} together with BOTH\n"
                f"{disease_str} in the same query.\n"
            )

        out_lines: List[str] = []

        out_lines.append("=" * 80)
        out_lines.append(f"DRUG PATHWAY ANALYSIS REPORT: {drug_name.upper()}")
        out_lines.append("=" * 80)
        out_lines.append("")

        # 1. SUMMARY
        out_lines.append("SUMMARY")
        out_lines.append("-" * 80)
        out_lines.append(f"Total joint-hit PMIDs: {len(pmids)}")
        out_lines.append(f"Dual-indication signal detected? {'YES' if has_joint else 'NO'}")
        out_lines.append("")
        out_lines.append("PubMed Boolean query used:")
        out_lines.append(self._wrap_line(analysis["query"]))
        out_lines.append("")

        # 2. MESH EVIDENCE
        out_lines.append("MESH EVIDENCE (DISEASE AREAS)")
        out_lines.append("-" * 80)
        for d in diseases:
            out_lines.append(f"{d['label']} MeSH terms observed in these same PMIDs:")
            hits = d.get("mesh_hits", [])
            if hits:
                for term in hits:
                    out_lines.append(f"  - {term}")
            else:
                out_lines.append("  (none)")
            out_lines.append("")

        # 3. ALL MESH TERMS
        out_lines.append("ALL UNIQUE MeSH TERMS FROM JOINT-HIT ARTICLES")
        out_lines.append("-" * 80)
        for term in all_mesh:
            out_lines.append("  - " + term)

        out_lines.append("")

        # 4. PMIDs
        out_lines.append("SUPPORTING PMIDs")
        out_lines.append("-" * 80)
        pmid_chunks = [pmids[i:i+10] for i in range(0, len(pmids), 10)]
        for chunk in pmid_chunks:
            line = ", ".join(chunk)
            out_lines.append("  " + line)

        out_lines.append("")
        out_lines.append("INTERPRETATION")
        out_lines.append("-" * 80)

        disease_labels = [d["label"] for d in diseases]
        disease_str = " and ".join(disease_labels) if disease_labels else "the configured diseases"

        if has_joint:
            out_lines.append(
                self._wrap_line(
                    f"These PubMed articles simultaneously reference the drug and biological "
                    f"contexts related to {disease_str}. The MeSH labeling suggests that "
                    f"researchers are discussing this drug in a setting that touches multiple "
                    f"disease areas. This is a strong multi-indication signal, but does not "
                    f"prove therapeutic efficacy."
                )
            )
        else:
            out_lines.append(
                self._wrap_line(
                    f"Although PubMed returned articles where the drug and the configured "
                    f"disease-related terms co-occur in the same query, the MeSH annotations "
                    f"did not clearly include all disease areas together. This weakens confidence "
                    f"that the drug is being studied directly in a multi-indication context."
                )
            )

        return "\n".join(out_lines)







In [19]:
if __name__ == "__main__":
    # --- load config from YAML ---

    config = load_config()


    ncbi_cfg = config.get("ncbi", {})
    email = os.getenv("NCBI_EMAIL")

    scraper = NCBIMeSHScraper(email=email, config=config)
    drug_names_file = config["drug_names_file"]

    # Iterate over all disease pairs defined in YAML
    for pair_cfg in config.get("disease_pairs", []):
        pair_name = pair_cfg.get("name", "unnamed_pair")
        print(f"\n=== Running disease pair: {pair_name} ===")

        
        reports_dir = pair_cfg["reports_dir"]
        results_json = pair_cfg["results_json"]
        retmax = pair_cfg.get("retmax", 50)
        sleep_between_drugs = pair_cfg.get("sleep_between_drugs", 0.5)
        therapy_terms = pair_cfg.get("therapy_terms", [])
        disease_blocks = pair_cfg["diseases"]

        # load the drug list
        with open(drug_names_file, "r", encoding="utf-8") as f:
            drug_names = [line.strip() for line in f if line.strip()]

        # os.makedirs(reports_dir, exist_ok=True)

        results = []
        for drug in drug_names:
            print(f"[CHECK] {drug}")
            analysis = scraper.get_joint_mesh_for_drug(
                drug_name=drug,
                disease_blocks=disease_blocks,
                therapy_terms=therapy_terms,
                retmax=retmax,
            )
            results.append(analysis)
            sleep(sleep_between_drugs)


 
            report_txt = scraper.generate_user_friendly_report(analysis)
            safe_name = drug.replace("/", "_").replace(" ", "_")
            print(safe_name)
            report_path = os.path.join(reports_dir, f"{safe_name}_report.txt")
            with open(report_path, "w", encoding="utf-8") as rf:
                rf.write(report_txt)

        # console summary of which drugs look promising
        print("\n=== DRUGS WITH DIRECT JOINT SIGNAL (MESH CONFIRMED) ===")
        for r in results:
            if r["has_joint_signal"]:
                print(f"- {r['drug']} ({len(r['pmids_joint'])} PMIDs)")

        # dump raw JSON for downstream ML
        with open(results_json, "w", encoding="utf-8") as fh:
            json.dump(results, fh, indent=2)
        print(f"[OK] wrote {results_json} and individual text reports in {reports_dir}/")

[INFO] NCBI API key loaded.

=== Running disease pair: diabetes_crc ===
[CHECK] Pemigatinib
Pemigatinib
[CHECK] Intedanib
Intedanib
[CHECK] Romiplostim
Romiplostim
[CHECK] Necitumumab
Necitumumab
[CHECK] HEGF
HEGF
[CHECK] Erlotinib
Erlotinib
[CHECK] Gefitinib
Gefitinib
[CHECK] Panitumumab
Panitumumab
[CHECK] Cetuximab
Cetuximab
[CHECK] BIBW 2992
BIBW_2992
[CHECK] Nitroglycerin
Nitroglycerin
[CHECK] Epidermal growth factor
Epidermal_growth_factor
[CHECK] Osimertinib
Osimertinib
[CHECK] NERATINIB MALEATE
NERATINIB_MALEATE
[CHECK] Amivantamab
Amivantamab
[CHECK] SKI-758
SKI-758
[CHECK] Dacomitinib
Dacomitinib
[CHECK] Lapatinib
Lapatinib
[CHECK] Merimepodib
Merimepodib
[CHECK] Vandetanib
Vandetanib
[CHECK] Sorafenib
Sorafenib
[CHECK] Adenosine triphosphate
Adenosine_triphosphate
[CHECK] Bosutinib
Bosutinib
[CHECK] Ponatinib
Ponatinib
[CHECK] Estrone
Estrone
[CHECK] Conjugated estrogens b
Conjugated_estrogens_b
[CHECK] Raloxifene
Raloxifene
[CHECK] Ospemifene
Ospemifene
[CHECK] Estropipate


### extract pmid and put into file

In [25]:
config = load_config()
for pair_cfg in config.get("disease_pairs", []):
    pair_name = pair_cfg.get("name", "unnamed_pair")
    print(f"\n=== Running disease pair: {pair_name} ===")

    
    reports_dir = pair_cfg["reports_dir"]
    qualified_drugs = []
    for filename in os.listdir(reports_dir):
        if not filename.endswith("_report.txt"):
            continue

        filepath = os.path.join(reports_dir, filename)
        with open(filepath, "r", encoding="utf-8") as f:
            content = f.read()

    

        # extract drug name and PMID count
        drug_match = re.search(r"DRUG PATHWAY ANALYSIS REPORT:\s*(.+)", content)
        pmid_match = re.search(r"Total joint-hit PMIDs:\s*(\d+)", content)

        # extract all 6â€“9 digit numeric sequences (likely PMIDs)
        pmids = re.findall(r"\b\d{6,9}\b", content)

        if drug_match and pmid_match:
            drug = drug_match.group(1).strip()
            pmid_count = int(pmid_match.group(1))
            if pmid_count >= 0:
                qualified_drugs.append((drug, pmid_count, pmids))

    # # sort by PMID count descending
    qualified_drugs.sort(key=lambda x: x[1], reverse=True)

    # print results
    print("=== DRUGS WITH JOINT-HIT PMIDs ===")
    for drug, count, pmids in qualified_drugs:
        print(f"- {drug}: {count} PMIDs")
        print(f"  PMIDs: {', '.join(pmids[:10])}{'...' if len(pmids) > 10 else ''}")

    print(f"\nTotal qualifying drugs for {pair_name}: {len(qualified_drugs)}")

    df = pd.DataFrame(qualified_drugs, columns=["Drug", "PMID Count", "PMIDs"])
    df["PMIDs"] = df["PMIDs"].apply(lambda x: ", ".join(x) if isinstance(x, list) else "")
    initial_pmid_list_path = pair_cfg["initial_pmid_list"] 
    df.to_csv(initial_pmid_list_path, index=False)
    print(f"Qualified drugs saved to {initial_pmid_list_path}")
    



=== Running disease pair: diabetes_crc ===
=== DRUGS WITH JOINT-HIT PMIDs ===
- METFORMIN: 50 PMIDs
  PMIDs: 40670504, 40649818, 40572709, 40471238, 40353629, 40299707, 40136342, 40124270, 39994354, 39600254...
- CETUXIMAB: 14 PMIDs
  PMIDs: 38763818, 35467766, 35279470, 33741057, 29054804, 27881709, 27681944, 26517691, 26405092, 23921573...
- EPIDERMAL GROWTH FACTOR: 10 PMIDs
  PMIDs: 38992135, 36546770, 35888591, 26850678, 26831715, 26517691, 26405092, 22439666, 22210091, 20713879
- BEVACIZUMAB: 7 PMIDs
  PMIDs: 41019039, 35888591, 29066701, 26972374, 26517691, 21609927, 18056916
- PIOGLITAZONE: 7 PMIDs
  PMIDs: 38378472, 32170185, 29637913, 23345544, 23098518, 22135104, 19139117
- ADENOSINE TRIPHOSPHATE: 5 PMIDs
  PMIDs: 37522672, 37442756, 30155759, 29463225, 19383331
- THIAZOLIDINEDIONE: 5 PMIDs
  PMIDs: 31887708, 23648711, 19908241, 18930061, 17192841
- ESTROGEN: 4 PMIDs
  PMIDs: 39132768, 34529197, 33359221, 32859627
- EXENATIDE: 4 PMIDs
  PMIDs: 40437949, 39777709, 27470345, 2