<a href="https://colab.research.google.com/github/dkisselev-zz/mmc-pipeline/blob/main/Microbiome_Vector_Graphs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Microbiome Vector Graph

## Authenticate and Configure

In [None]:
# ==============================================================================
# Step 0: Install Libraries
# ==============================================================================
!pip install gensim pyvis

Collecting gensim
  Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting pyvis
  Downloading pyvis-0.3.2-py3-none-any.whl.metadata (1.7 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Collecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting jedi>=0.16 (from ipython>=5.3.0->pyvis)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.

In [None]:
# ==============================================================================
# Step 1: Import Necessary Libraries
# ==============================================================================
# We need gensim for Word2Vec, networkx for graph manipulation,
# and pyvis for beautiful interactive visualizations.
# !pip install pandas gensim networkx pyvis beautifulsoup4 requests

import os
import re
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Any, Set

import json
import requests
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
import numpy as np

import tarfile
import xml.etree.ElementTree as ET

import nltk
from nltk.corpus import stopwords

from gensim.models import Word2Vec
from gensim.models.phrases import Phrases, Phraser

import networkx as nx
from pyvis.network import Network

from google.colab import auth
from google.colab import userdata
from google.colab.data_table import DataTable
from google.auth import default
import google.generativeai as genai
import gspread

print("✅ Libraries imported successfully!")

✅ Libraries imported successfully!


In [None]:
# ==============================================================================
# Step 2: Authenticate and Configure variables
# ==============================================================================

# Grab variables from the form
spreadsheet_url = "https://docs.google.com/spreadsheets/d/1tBjpV_GoXIjx4_3o73Qml0h-BzFBZTD5bc3QHx1l1_4" # @param {"type":"string"}
worksheet_name = "Main Data Sheet" # @param {"type":"string"}
header_indx = 1 # @param {"type":"integer"}

# Get secrets from Colab environment
# Authenticate to access Google Sheet
try:
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    print("Authentication successful.")
except Exception as e:
    print(f"Authentication failed. Please ensure you are in a Google Colab environment. Error: {e}")

# Configure Gemini API
try:
    API_KEY = userdata.get('GOOGLE_API_KEY')
    genai.configure(api_key=API_KEY)
    print("Gemini API configured successfully.")
except Exception as e:
    print(f"Could not configure Gemini API. Please add GOOGLE_API_KEY to your Colab secrets. Error: {e}")

try:
    EMAIL = userdata.get('EMAIL')
except Exception as e:
    # userdata.get raises Colab-specific errors (e.g. userdata.SecretNotFoundError), not ValueError
    raise ValueError("EMAIL not found in Colab secrets. Please add it.") from e

# Load ICD-11 WHO API Keys
try:
    ICD11_CLIENT_ID = userdata.get('ICD11_CLIENT_ID')
    ICD11_CLIENT_SECRET = userdata.get('ICD11_CLIENT_SECRET')
    print("ICD-11 API configured successfully.")
except Exception as e:
    print(f"Could not configure ICD-11 API. Please add ICD11_CLIENT_ID and ICD11_CLIENT_SECRET to your Colab secrets. Error: {e}")

# Load NCBI API Key if it exists
try:
    NCBI_API_KEY = userdata.get('NCBI_API_KEY')
    print("NCBI API Key loaded successfully.")
except Exception:
    NCBI_API_KEY = None
    print("NCBI API Key not found in Colab secrets. Proceeding with lower rate limits.")

# Microbial dictionary
microbe_dict = 'microbe_dictionary_hierarchical.json'

# Disease dictionary
disease_dict = 'disease_dictionary_hierarchical.json'

ABSTRACT_DICT_PATH = 'abstract_dictionary.json'

print("✅ Authenticated and Configured variables")

Authentication successful.
Gemini API configured successfully.
ICD-11 API configured successfully.
NCBI API Key loaded successfully.
✅ Authenticated and Configured variables


## Build Utility Classes

In [None]:
# ==============================================================================
# Step 3.1: Abstracting the Taxonomy Interface
# ==============================================================================

class TaxonomyProvider(ABC):
    """
    Abstract base class defining the standard interface for a taxonomy provider.
    This ensures that both the NCBI file-based and ICD-11 API-based data
    sources can be used interchangeably by the main application logic.
    """

    @abstractmethod
    def get_node(self, node_id: str) -> Dict[str, Any]:
        """Retrieves all available information for a given node ID."""
        pass

    @abstractmethod
    def get_parents(self, node_id: str) -> List[str]:
        """Retrieves a list of parent IDs for a given node ID."""
        pass

    @abstractmethod
    def get_children(self, node_id: str) -> List[str]:
        """Retrieves a list of child IDs for a given node ID."""
        pass

    @abstractmethod
    def get_lineage(self, node_id: str) -> Set[str]:
        """Retrieves the set of all ancestor IDs for a given node ID."""
        pass

    @abstractmethod
    def get_name(self, node_id: str) -> str:
        """Retrieves the primary name for a given node ID."""
        pass

    @abstractmethod
    def get_synonyms(self, node_id: str) -> List[str]:
        """Retrieves a list of synonyms for a given node ID."""
        pass
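To illustrate why a shared interface pays off, here is a minimal sketch that is not part of the pipeline: a hypothetical `DictTaxonomy`, backed by a toy parent mapping, implements a cut-down, one-method version of the interface. It also shows one possible design choice, writing `get_lineage` once in the base class on top of `get_parents`, which works for poly-hierarchies such as ICD-11 where a node can have several parents. The class names and node IDs are illustrative assumptions.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Set


class MiniTaxonomy(ABC):
    """Hypothetical, cut-down stand-in for TaxonomyProvider."""

    @abstractmethod
    def get_parents(self, node_id: str) -> List[str]:
        """Return the direct parent IDs of `node_id`."""

    def get_lineage(self, node_id: str) -> Set[str]:
        # Generic ancestor walk written once against the abstract method;
        # the `lineage` set doubles as the visited set, so shared ancestors
        # are expanded only once.
        lineage: Set[str] = set()
        stack = list(self.get_parents(node_id))
        while stack:
            parent = stack.pop()
            if parent not in lineage:
                lineage.add(parent)
                stack.extend(self.get_parents(parent))
        return lineage


class DictTaxonomy(MiniTaxonomy):
    """Toy provider backed by an in-memory child -> parents mapping."""

    def __init__(self, parents: Dict[str, List[str]]):
        self._parents = parents

    def get_parents(self, node_id: str) -> List[str]:
        return self._parents.get(node_id, [])


# Toy poly-hierarchy: "ibd" has two parents (all IDs are made up)
toy = DictTaxonomy({
    "crohn": ["ibd"],
    "ibd": ["digestive", "immune"],
    "digestive": ["root"],
    "immune": ["root"],
})
lineage = toy.get_lineage("crohn")
```

Here `lineage` contains `ibd`, `digestive`, `immune` and `root`; any provider that only knows how to return direct parents inherits the same traversal.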

In [None]:
# ==============================================================================
# Step 3.2: NCBI Taxonomy Provider
# ==============================================================================
class NCBITaxonomy(TaxonomyProvider):
    """
    A concrete implementation of TaxonomyProvider for the NCBI Taxonomy database.
    This class handles downloading, parsing, and providing access to the data
    from the taxdump files (names.dmp, nodes.dmp).
    """
    def __init__(self, taxdump_dir: str = '.'):
        self.taxdump_dir = taxdump_dir
        self.nodes_file = os.path.join(self.taxdump_dir, 'nodes.dmp')
        self.names_file = os.path.join(self.taxdump_dir, 'names.dmp')

        self._ensure_taxdump_files_exist()

        print("🧠 Loading and processing NCBI taxonomy data...")
        self._load_data()
        print("✅ NCBITaxonomy provider initialized.")


    def _ensure_taxdump_files_exist(self):
        """Downloads and extracts NCBI taxdump files if they don't exist."""
        if os.path.exists(self.names_file) and os.path.exists(self.nodes_file):
            print("✅ NCBI Taxonomy files already exist. Skipping download.")
            return

        url = "https://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdump.tar.gz"
        gz_filename = "taxdump.tar.gz"
        print(f"🌎 Downloading NCBI Taxonomy database from {url}...")
        try:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                with open(gz_filename, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            print("✅ Download complete.")

            print("📦 Extracting files...")
            with tarfile.open(gz_filename, "r:gz") as tar:
                # filter='data' rejects archive members that would land outside the target directory
                tar.extractall(path=self.taxdump_dir, filter='data')
            print("✅ Extraction complete.")
            os.remove(gz_filename)
        except Exception as e:
            raise RuntimeError(f"Failed to download or extract NCBI taxdump: {e}")

    def _load_data(self):
        """Parses the .dmp files and populates internal data structures."""
        # --- Part A: Read nodes.dmp for taxonomic structure ---
        nodes_df = pd.read_csv(self.nodes_file, sep='|', header=None, engine='python',
                               usecols=[0, 1, 2], names=['tax_id', 'parent_tax_id', 'rank'])
        nodes_df = nodes_df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)
        self.parent_map = nodes_df.set_index('tax_id')['parent_tax_id'].to_dict()
        self.rank_map = nodes_df.set_index('tax_id')['rank'].to_dict()

        # --- Part B: Read names.dmp to collect all microbe names ---
        names_df = pd.read_csv(self.names_file, sep='|', header=None, engine='python',
                               usecols=[0, 1, 3], names=['tax_id', 'name_txt', 'name_class'])
        names_df = names_df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)

        # Cache names and synonyms
        scientific_names = names_df[names_df['name_class'] == 'scientific name']
        self.taxid_to_name = scientific_names.set_index('tax_id')['name_txt'].to_dict()

        valid_alias_classes = ['synonym', 'equivalent name', 'acronym', 'genbank acronym', 'common name', 'genbank common name']
        self.taxid_to_synonyms = names_df[names_df['name_class'].isin(valid_alias_classes)]\
            .groupby('tax_id')['name_txt'].apply(list).to_dict()

        # Build children map by reversing the parent map
        self.children_map = defaultdict(list)
        for child, parent in self.parent_map.items():
            self.children_map[parent].append(child)


    # --- Implementation of Abstract Methods ---
    def get_node(self, node_id: int) -> Dict[str, Any]:
        node_id = int(node_id)
        parents = self.get_parents(node_id)
        return {
            "tax_id": node_id,
            "parent_tax_id": parents[0] if parents else None,
            "name": self.get_name(node_id),
            "rank": self.rank_map.get(node_id),
            "synonyms": self.get_synonyms(node_id)
        }

    def get_parents(self, node_id: int) -> List[int]:
        node_id = int(node_id)
        parent = self.parent_map.get(node_id)
        # The root (tax_id 1) is its own parent in nodes.dmp; don't report it as a parent
        return [parent] if parent is not None and parent != node_id else []

    def get_children(self, node_id: int) -> List[int]:
        return self.children_map.get(int(node_id), [])

    def get_lineage(self, node_id: int) -> Set[int]:
        lineage = set()
        current_id = int(node_id)
        while current_id in self.parent_map and current_id != self.parent_map[current_id]:
            parent_id = self.parent_map[current_id]
            lineage.add(parent_id)
            current_id = parent_id
        return lineage

    def get_name(self, node_id: int) -> str:
        return self.taxid_to_name.get(int(node_id))

    def get_synonyms(self, node_id: int) -> List[str]:
        return self.taxid_to_synonyms.get(int(node_id), [])
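The `.dmp` files use a tab-pipe-tab field separator and end every line with a tab and a pipe, which is why `_load_data` splits on `'|'` and then strips whitespace. A minimal pure-Python sketch of that format and of the parent walk behind `get_lineage` follows; `parse_dmp_line`, `ancestors` and the toy rows are hypothetical, not real NCBI records. It also shows why the walk terminates: the root node lists itself as its own parent.

```python
from typing import Dict, List, Set


def parse_dmp_line(line: str) -> List[str]:
    """Split one taxdump line: fields are separated by TAB|TAB and each line ends with TAB|."""
    line = line.rstrip("\n")
    if line.endswith("\t|"):
        line = line[:-2]
    return line.split("\t|\t")


def ancestors(parent_map: Dict[int, int], tax_id: int) -> Set[int]:
    """Collect all ancestors; the root is its own parent, which ends the walk."""
    lineage: Set[int] = set()
    current = tax_id
    while current in parent_map and parent_map[current] != current:
        current = parent_map[current]
        lineage.add(current)
    return lineage


# Toy rows in the nodes.dmp layout (tax_id, parent tax_id, rank, then an empty field).
# IDs and ranks are illustrative only.
NODES_SAMPLE = [
    "1\t|\t1\t|\tno rank\t|\t\t|",
    "2\t|\t1\t|\tno rank\t|\t\t|",
    "10\t|\t2\t|\tgenus\t|\t\t|",
    "11\t|\t10\t|\tspecies\t|\t\t|",
]

parent_map: Dict[int, int] = {}
rank_map: Dict[int, str] = {}
for row in NODES_SAMPLE:
    fields = parse_dmp_line(row)
    tax_id, parent_id = int(fields[0]), int(fields[1])
    parent_map[tax_id] = parent_id
    rank_map[tax_id] = fields[2]

species_lineage = ancestors(parent_map, 11)
```

As in `NCBITaxonomy.get_lineage`, the root ID itself ends up in the lineage of every other node, while the root's own lineage is empty.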



In [None]:
# ==============================================================================
# Step 3.3: Caching Mechanism
# A file-based cache that persists between sessions.
# ==============================================================================
class JsonFileCache:
    """A simple file-based JSON cache."""
    def __init__(self, cache_path='api_cache.json'):
        self.cache_path = cache_path
        self._cache = self._load_cache()

    def _load_cache(self):
        if os.path.exists(self.cache_path):
            try:
                with open(self.cache_path, 'r') as f:
                    print(f"💾 Loading API cache from '{self.cache_path}'.")
                    return json.load(f)
            except json.JSONDecodeError:
                print("⚠️ Cache file is corrupted. Starting with an empty cache.")
                return {}
        return {}

    def get(self, key: str):
        return self._cache.get(key)

    def set(self, key: str, value: Any):
        self._cache[key] = value

    def save(self):
        with open(self.cache_path, 'w') as f:
            json.dump(self._cache, f, indent=2)
        print(f"💾 API cache saved to '{self.cache_path}'.")

# ==============================================================================
# Step 3.4: ICD-11 API Handler
# This class handles all direct interaction with the WHO ICD-11 API, including
# authentication and caching.
# ==============================================================================
class ICD11Handler:
    """Handles authentication and data fetching from the WHO ICD-11 API."""
    def __init__(self, client_id: str, client_secret: str, cache: JsonFileCache):
        self.token_url = "https://icdaccessmanagement.who.int/connect/token"
        self.base_url = "https://id.who.int/icd"
        self.client_id = client_id
        self.client_secret = client_secret
        self.access_token = None
        self.token_expires_at = None
        self.headers = {'Accept': 'application/json', 'API-Version': 'v2', 'Accept-Language': 'en'}
        self.cache = cache
        self.request_counter = 0

    def _get_access_token(self):
        # OAuth2 client-credentials flow against the WHO token endpoint
        print("\n🔑 Requesting new ICD-11 API access token...")
        try:
            token_data = {'grant_type': 'client_credentials', 'client_id': self.client_id, 'client_secret': self.client_secret, 'scope': 'icdapi_access'}
            response = requests.post(self.token_url, data=token_data)
            response.raise_for_status()
            token_info = response.json()
            self.access_token = token_info['access_token']
            self.token_expires_at = time.time() + token_info.get('expires_in', 3600) - 300
            self.headers['Authorization'] = f'Bearer {self.access_token}'
            print("✅ Successfully obtained access token.")
        except Exception as e:
            raise RuntimeError(f"Could not obtain ICD-11 access token: {e}")

    def _ensure_valid_token(self):
        if not self.access_token or time.time() >= self.token_expires_at:
            self._get_access_token()

    def get_entity(self, entity_uri: str, use_cache: bool = True) -> Optional[Dict]:
        """Central data retrieval method with caching."""
        entity_id = entity_uri.split('/')[-1]
        if use_cache and (cached_data := self.cache.get(entity_id)):
            return cached_data

        self.request_counter += 1
        self._ensure_valid_token()

        try:
            # We need both the entity URI (for parents/children) and the linearization URI (for the code)
            entity_res = requests.get(entity_uri, headers=self.headers)
            linearization_url = f"{self.base_url}/release/11/2025-01/mms/{entity_id}"
            linearization_res = requests.get(linearization_url, headers=self.headers)

            if entity_res.status_code == 200:
                entity_data = entity_res.json()
                entity_data['code'] = linearization_res.json().get('code', 'N/A') if linearization_res.status_code == 200 else 'N/A'
                self.cache.set(entity_id, entity_data)
                return entity_data
            return None
        except requests.RequestException as e:
            print(f"⚠️ API request failed for entity {entity_id}: {e}")
            return None

print("✅ ICD-11 Caching and API Handler classes defined.")

✅ ICD-11 Caching and API Handler classes defined.
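`_ensure_valid_token` checks and refreshes the token without any locking, so when `get_entity` runs inside a `ThreadPoolExecutor`, several threads can request a new token at the same moment. A minimal sketch of a lock-guarded token holder follows; `TokenCache` and its injected `fetch_token` and `clock` callables are hypothetical names, and the 300-second margin mirrors the buffer subtracted in `_get_access_token`. The demo uses a fake endpoint and a controllable clock rather than the WHO API.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional, Tuple


class TokenCache:
    """Thread-safe holder for a bearer token with early refresh (hypothetical)."""

    def __init__(self, fetch_token: Callable[[], Tuple[str, float]],
                 margin: float = 300.0,
                 clock: Callable[[], float] = time.time):
        self._fetch_token = fetch_token  # returns (token, expires_in_seconds)
        self._margin = margin            # refresh this many seconds early
        self._clock = clock
        self._lock = threading.Lock()
        self._token: Optional[str] = None
        self._expires_at = 0.0

    def get(self) -> str:
        # The check and the refresh happen under one lock, so concurrent
        # callers never trigger more than one token request at a time.
        with self._lock:
            if self._token is None or self._clock() >= self._expires_at:
                token, expires_in = self._fetch_token()
                self._token = token
                self._expires_at = self._clock() + expires_in - self._margin
            return self._token


# --- Demo with a fake token endpoint and a controllable clock ---
fetch_calls = []

def fake_fetch() -> Tuple[str, float]:
    fetch_calls.append(1)
    return f"token-{len(fetch_calls)}", 3600.0

fake_now = [1000.0]
tokens = TokenCache(fake_fetch, clock=lambda: fake_now[0])

with ThreadPoolExecutor(max_workers=8) as pool:
    first_round = list(pool.map(lambda _: tokens.get(), range(50)))

fake_now[0] += 3600.0 - 300.0   # advance the clock to the refresh point
refreshed = tokens.get()
```

Wiring this into `ICD11Handler` would presumably mean building the `Authorization` header from `tokens.get()` per request instead of mutating the shared `self.headers` dict; that integration is an assumption, not part of the original code.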


In [None]:
# ==============================================================================
# Step 3.5: ICD-11 Taxonomy Provider
# It uses the `ICD11Handler` to fetch data and adapts the responses.
# ==============================================================================
class ICD11Taxonomy(TaxonomyProvider):
    """Concrete implementation of TaxonomyProvider for the WHO ICD-11 API."""
    def __init__(self, client_id: str, client_secret: str, cache_path: str = 'icd11_api_cache.json', num_threads: int = 10):
        cache = JsonFileCache(cache_path)
        self.handler = ICD11Handler(client_id, client_secret, cache)
        self.id_to_title_map = {eid: data.get('title', {}).get('@value', 'Unknown')
                                for eid, data in self.handler.cache._cache.items()}
        self.num_threads = num_threads  # Store the number of threads
        print(f"✅ ICD11Taxonomy provider initialized with {num_threads} concurrent threads.")

    def _normalize_list(self, value: Any) -> List[str]:
        """Ensures the API response for parents/children is always a list."""
        if not value: return []
        if isinstance(value, list): return value
        return [value]

    def get_node(self, node_uri: str) -> Dict[str, Any]:
        return self.handler.get_entity(node_uri)

    def get_parents(self, node_uri: str) -> List[str]:
        data = self.handler.get_entity(node_uri)
        return self._normalize_list(data.get('parent')) if data else []

    def get_children(self, node_uri: str) -> List[str]:
        data = self.handler.get_entity(node_uri)
        return self._normalize_list(data.get('child')) if data else []

    def get_name(self, node_uri: str) -> str:
        # Use local map for speed if available, otherwise fetch
        node_id = node_uri.split('/')[-1]
        if name := self.id_to_title_map.get(node_id):
            return name
        if data := self.handler.get_entity(node_uri):
            name = data.get('title', {}).get('@value', 'Unknown')
            self.id_to_title_map[node_id] = name
            return name
        return "Unknown"


    def get_synonyms(self, node_uri: str) -> List[str]:
        data = self.handler.get_entity(node_uri)
        aliases = set()
        if not data: return []
        if 'synonym' in data:
            aliases.update(s['label']['@value'] for s in data['synonym'])
        if 'fullySpecifiedName' in data:
            aliases.add(data['fullySpecifiedName']['@value'])
        if 'inclusion' in data:
            aliases.update(inc['label']['@value'] for inc in data.get('inclusion', []) if 'label' in inc and '@value' in inc['label'])
        return sorted(list(aliases))

    def get_lineage(self, node_uri: str) -> Set[str]:
        # Implements a graph traversal to find all ancestors
        queue = [node_uri]
        visited = set()
        lineage = set()
        while queue:
            current_uri = queue.pop(0)
            if current_uri in visited: continue
            visited.add(current_uri)
            parents = self.get_parents(current_uri)
            for parent_uri in parents:
                lineage.add(parent_uri)
                queue.append(parent_uri)
        return lineage

    def build_disease_dictionary(self, root_chapters: Dict[str, str], disease_dict_path: str):
        """
        Builds the full disease dictionary by traversing the ICD-11 graph using a
        concurrent, multi-threaded approach to fetch API data.
        """
        print(f"\n🌲 Starting concurrent graph traversal from {len(root_chapters)} root chapters...")
        self.handler.request_counter = 0

        queue = [f"{self.handler.base_url}/entity/{root_id}" for root_id in root_chapters.values()]
        visited = set(queue) # Pre-populate visited set to avoid duplicate processing

        # Use a ThreadPoolExecutor to manage a pool of worker threads
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            while queue:
                # Define a batch size, a multiple of the thread count is often efficient
                batch_size = self.num_threads * 5
                # Take the first `batch_size` items from the queue for processing
                current_batch = [queue.pop(0) for _ in range(min(batch_size, len(queue)))]

                if not current_batch:
                    continue

                # Concurrently fetch data for the entire batch. The executor's `map` function
                # applies `self.handler.get_entity` to each URI in the batch across multiple threads.
                # The list() call forces the execution and waits for all threads to complete.
                results = list(executor.map(self.handler.get_entity, current_batch))

                # Process the results synchronously to find the next set of children
                for entity_data in results:
                    if entity_data:
                        # Extract child URIs from the fetched data
                        children = self._normalize_list(entity_data.get('child'))
                        for child_uri in children:
                            # Add new, unvisited children to the end of the queue
                            if child_uri not in visited:
                                visited.add(child_uri)
                                queue.append(child_uri)

                print(f"\r   - API Requests: {self.handler.request_counter}, Entities Processed: {len(visited)}, Queue Size: {len(queue)}  ", end='', flush=True)

        print(f"\n\n✅ Traversal complete. Total unique entities processed: {len(visited)}.")
        self.handler.cache.save()

        # Convert the cached entities into the final dictionary
        print("🛠️  Processing cached data into final dictionary...")
        final_dict = {}
        cached_entities = self.handler.cache._cache
        for entity_id, data in cached_entities.items():
            canonical_name = data.get('title', {}).get('@value')
            if not canonical_name: continue
            parent_uris = self._normalize_list(data.get('parent'))
            parent_id = parent_uris[0].split('/')[-1] if parent_uris else None
            parent_data = cached_entities.get(parent_id) if parent_id else None

            if not parent_data:
                # No cached parent (e.g. chapters, whose parent was never fetched): use "root" defaults
                parent_name = "ICD-11 Root"
                parent_code = "N/A"
            else:
                parent_name = parent_data.get('title', {}).get('@value', "ICD-11 Root")
                parent_code = parent_data.get('code', 'N/A')
            final_dict[canonical_name] = {
                "icd11_code": data.get('code', 'N/A'),
                "parent_name": parent_name,
                "parent_code": parent_code,
                "aliases": self.get_synonyms(data.get('@id'))
            }
        with open(disease_dict_path, 'w') as f:
            json.dump(final_dict, f, indent=2)
        print(f"✅ Final dictionary with {len(final_dict)} entries saved to '{disease_dict_path}'.")
        return final_dict

print("✅ ICD11Taxonomy provider class defined.")

✅ ICD11Taxonomy provider class defined.
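`build_disease_dictionary` takes each batch with repeated `queue.pop(0)`, which shifts the whole list on every call. A sketch of the same batched breadth-first pattern using `collections.deque` (constant-time `popleft`) follows. `batched_bfs` and the toy graph are hypothetical; `fetch` stands in for `ICD11Handler.get_entity`, returning a dict whose `'child'` entry may be a single URI or a list, or `None` on failure.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable, Optional, Set


def batched_bfs(roots: Iterable[str],
                fetch: Callable[[str], Optional[Dict]],
                max_workers: int = 4,
                batch_size: int = 20) -> Set[str]:
    """Breadth-first traversal that fetches each batch of nodes concurrently."""
    queue = deque(dict.fromkeys(roots))  # de-duplicated, order-preserving
    visited: Set[str] = set(queue)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        while queue:
            # popleft() is O(1), unlike list.pop(0)
            batch = [queue.popleft() for _ in range(min(batch_size, len(queue)))]
            # executor.map yields results in input order as the fetches complete
            for data in executor.map(fetch, batch):
                if not data:
                    continue  # failed fetch or empty payload
                children = data.get("child") or []
                if isinstance(children, str):
                    children = [children]
                for child in children:
                    if child not in visited:
                        visited.add(child)
                        queue.append(child)
    return visited


# Toy graph in place of the ICD-11 API (URIs shortened to plain names)
toy_graph = {
    "root": {"child": ["a", "b"]},
    "a": {"child": "c"},          # single child given as a string
    "b": {"child": ["c", "d"]},   # "c" is reachable twice
    "c": {},
    "d": None,                    # simulates a failed request
}
reached = batched_bfs(["root"], toy_graph.get, max_workers=2, batch_size=2)
```

Marking a child as visited when it is enqueued, as the original also does, is what keeps `c` from being fetched twice even though two parents point to it.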


In [None]:
# ==============================================================================
# Step 3.6: NCBI Publication Metadata fetcher
# ==============================================================================
class PublicationFetcher:
    """Handles the fetching of publication abstracts from NCBI."""
    def __init__(self, email: str, api_key: Optional[str] = None):
        self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/"
        self.api_key = api_key
        self.email = email
        print("✅ PublicationFetcher initialized.")

    def _make_ncbi_request(self, base_url, params=None, data=None, retries=3, delay=2):
        """Makes a request to the NCBI API with retries and backoff."""
        # Add a small delay to respect NCBI API rate limits
        request_delay = 0.1 if self.api_key else 0.4
        time.sleep(request_delay)

        # Add API key and email to every request for tracking
        if self.api_key:
            if params: params['api_key'] = self.api_key
            if data: data['api_key'] = self.api_key
        if params: params['email'] = self.email
        if data: data['email'] = self.email

        for attempt in range(retries):
            try:
                if data:
                    response = requests.post(base_url, data=data, timeout=45)
                else:
                    response = requests.get(base_url, params=params, timeout=45)
                response.raise_for_status()
                return response.content
            except requests.exceptions.RequestException as e:
                print(f"  > WARNING: Request failed on attempt {attempt + 1}/{retries}: {e}")
                if attempt < retries - 1:
                    wait_time = delay * (2 ** attempt)
                    print(f"  > Retrying in {wait_time} seconds...")
                    time.sleep(wait_time)
        return None

    def get_abstract_by_doi(self, doi: str) -> Optional[str]:
        """
        Fetches a publication's abstract from PubMed using its DOI.

        Args:
            doi: The Digital Object Identifier of the article.

        Returns:
            The abstract text as a string, or None if not found.
        """
        if not isinstance(doi, str) or not doi:
            return None

        # 1. Use esearch to find the PubMed ID (PMID) for the DOI
        search_params = {'db': 'pubmed', 'term': f'"{doi}"[aid]', 'retmode': 'xml'}
        search_response = self._make_ncbi_request(f"{self.base_url}esearch.fcgi", params=search_params)
        if not search_response:
            return None

        try:
            pmid_root = ET.fromstring(search_response)
            pmid = pmid_root.findtext('.//Id')
            if not pmid:
                return None
        except ET.ParseError:
            return None

        # 2. Use efetch to get article details with the PMID
        fetch_params = {'db': 'pubmed', 'id': pmid, 'retmode': 'xml'}
        fetch_response = self._make_ncbi_request(f"{self.base_url}efetch.fcgi", params=fetch_params)
        if not fetch_response:
            return None

        # 3. Parse the XML to find the abstract text
        try:
            article_root = ET.fromstring(fetch_response)
            abstract_text_elements = article_root.findall('.//Abstract/AbstractText')
            if abstract_text_elements:
                # Join text from all AbstractText elements, handling structured abstracts
                full_abstract = " ".join(ET.tostring(elem, method='text', encoding='unicode').strip() for elem in abstract_text_elements)
                return full_abstract.strip()
            else:
                return None # Abstract not found
        except ET.ParseError:
            return None
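`get_abstract_by_doi` joins the text of every `<AbstractText>` element, so the section labels of a structured abstract are dropped. The sketch below, a hypothetical `parse_abstract` helper, shows one way to keep them by reading each element's `Label` attribute, and uses `itertext()` so inline markup such as `<i>` is flattened. The XML is a hand-made, heavily trimmed fragment in the PubMed `efetch` layout, not a real record.

```python
import xml.etree.ElementTree as ET
from typing import Optional


def parse_abstract(efetch_xml: str) -> Optional[str]:
    """Join every <AbstractText> of an efetch record, keeping section labels."""
    root = ET.fromstring(efetch_xml)
    parts = []
    for elem in root.findall(".//Abstract/AbstractText"):
        # itertext() yields text inside nested tags (e.g. <i>) but not the element's own tail
        text = "".join(elem.itertext()).strip()
        if not text:
            continue
        label = elem.get("Label")  # typically present only on structured abstracts
        parts.append(f"{label}: {text}" if label else text)
    return " ".join(parts) or None


# Hand-made, trimmed fragment in the efetch layout (not a real PubMed record)
SAMPLE_XML = """<PubmedArticleSet><PubmedArticle><MedlineCitation><Article>
<Abstract>
<AbstractText Label="BACKGROUND">Gut <i>Bacteroides</i> levels vary.</AbstractText>
<AbstractText Label="RESULTS">Levels were lower in cases.</AbstractText>
</Abstract>
</Article></MedlineCitation></PubmedArticle></PubmedArticleSet>"""

abstract = parse_abstract(SAMPLE_XML)
```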


## Build and Load Data

In [None]:
# ==============================================================================
# Step 4.1: Build Microbial Dictionary from NCBI Taxonomy Data
# ==============================================================================
def is_alias_valid_generalized(alias):
    """Heuristic filter that rejects strain codes, citations, placeholders and other non-name aliases."""
    alias_lower = alias.lower()

    # Rule 1: Reject if it looks like a strain/culture collection code (e.g., ATCC, DSM, NCTC)
    # This also catches things like 'strain ABC' or 'isolate 123'.
    if re.search(r'\b(atcc|nrcc|dsm|nctc|ukmcc|ccug|cip|jcm|lmg|strain|isolate)\b', alias_lower):
        return False

    # Rule 2: Reject if it contains a year (likely a citation)
    if re.search(r'\b(18|19|20)\d{2}\b', alias_lower):
        return False

    # Rule 3: Reject if it contains a generic placeholder
    if any(placeholder in alias_lower for placeholder in ['bacterium ', 'unidentified ', 'unclassified ', 'endosymbiont of']):
        return False

    # Rule 4: Reject if it ends with a sequence of letters and numbers that looks like a code
    if re.search(r'\s[A-Z0-9\-_]{5,}$', alias):
        return False

    # Rule 5: Reject if the name is just a short code
    if len(alias) < 4 and '.' not in alias:
        return False

    # Rule 6: Reject if it contains certain keywords that indicate it's not a standard name
    if any(keyword in alias_lower for keyword in ['subgroup', 'serovar', 'genomosp.', ' genomovar']):
        return False

    return True

def build_microbe_dictionary_from_provider(provider: NCBITaxonomy, microbe_dict_path: str):
    """
    Builds the hierarchical microbe dictionary (bacterial genera and species)
    from an NCBITaxonomy provider and saves it as JSON.
    """
    print("\n🔬 Building hierarchical microbe dictionary using the new NCBITaxonomy provider...")
    final_dict = {}

    def is_bacterial(tax_id, p_map):
        # Walk up at most 30 levels looking for tax_id 2 (Bacteria)
        curr_id = tax_id
        for _ in range(30):
            if curr_id == 2: return True
            if curr_id == 1 or curr_id not in p_map: break
            curr_id = p_map[curr_id]
        return False

    target_tax_ids = {tid for tid, rank in provider.rank_map.items() if rank in ['genus', 'species'] and is_bacterial(tid, provider.parent_map)}
    print(f"   - Identified {len(target_tax_ids)} bacterial genus/species tax IDs.")

    for tax_id in target_tax_ids:
        canonical_name = provider.get_name(tax_id)
        if not canonical_name: continue

        aliases = set(provider.get_synonyms(tax_id))
        rank = provider.rank_map.get(tax_id)

        if rank == 'species':
            parts = canonical_name.split()
            if len(parts) >= 2: aliases.add(f"{parts[0][0]}. {parts[1]}")

        filtered_aliases = {alias for alias in aliases if is_alias_valid_generalized(alias)}
        if not filtered_aliases: continue

        genus_name = None
        if rank == 'species':
            parent_id = provider.get_parents(tax_id)[0] if provider.get_parents(tax_id) else None
            if parent_id and provider.rank_map.get(parent_id) == 'genus':
                genus_name = provider.get_name(parent_id)
        elif rank == 'genus':
            genus_name = canonical_name

        final_dict[canonical_name] = {
            "rank": rank, "genus": genus_name, "aliases": sorted(list(filtered_aliases))
        }

    print(f"   - Final dictionary created with {len(final_dict)} canonical entries.")
    with open(microbe_dict_path, "w") as f:
        json.dump(final_dict, f, indent=2)
    print(f"\n✅ Hierarchical microbe dictionary saved to '{microbe_dict_path}'.")
    return final_dict

# --- Execute and Validate ---
ncbi_provider = NCBITaxonomy()
microbe_dictionary = build_microbe_dictionary_from_provider(ncbi_provider, microbe_dict)

# Display a sample entry to validate the output
print("\n--- Sample Entries (Validation) ---")
sample_key = "Escherichia coli"
if sample_key in microbe_dictionary:
    print(json.dumps({sample_key: microbe_dictionary[sample_key]}, indent=2))

✅ NCBI Taxonomy files already exist. Skipping download.
🧠 Loading and processing NCBI taxonomy data...
✅ NCBITaxonomy provider initialized.

🔬 Building hierarchical microbe dictionary using the new NCBITaxonomy provider...
   - Identified 531722 bacterial genus/species tax IDs.
   - Final dictionary created with 512540 canonical entries.

✅ Hierarchical microbe dictionary saved to 'microbe_dictionary_hierarchical.json'.

--- Sample Entries (Validation) ---
{
  "Escherichia coli": {
    "rank": "species",
    "genus": "Escherichia",
    "aliases": [
      "Bacillus coli",
      "E. coli",
      "Enterococcus coli",
      "Escherichia/Shigella coli"
    ]
  }
}
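`is_bacterial` re-walks the parent chain from scratch for every genus and species, although most of them share the same upper lineage. A possible alternative is sketched below, assuming `parent_map` is the `{tax_id: parent_tax_id}` dictionary built by `NCBITaxonomy`: it memoises every ID it visits so shared ancestors are resolved once. `make_descendant_checker` is a hypothetical helper and the toy IDs are made up. Unlike the original it has no 30-level cap; it stops at the self-parented root, at unknown IDs, or on a repeated ID.

```python
from typing import Callable, Dict, List, Set


def make_descendant_checker(parent_map: Dict[int, int], target: int) -> Callable[[int], bool]:
    """Return a function that reports whether a tax ID descends from `target`."""
    memo: Dict[int, bool] = {target: True}

    def is_descendant(tax_id: int) -> bool:
        path: List[int] = []
        seen: Set[int] = set()
        current = tax_id
        while current not in memo:
            parent = parent_map.get(current)
            # Root (self-parent), unknown ID, or a cycle: not under `target`
            if parent is None or parent == current or parent in seen:
                memo[current] = False
                break
            seen.add(current)
            path.append(current)
            current = parent
        result = memo[current]
        # Cache the answer for every node on the walked path
        for node in path:
            memo[node] = result
        return result

    return is_descendant


# Toy tree: 1 is the root, 2 plays the role of "Bacteria"
toy_parents = {1: 1, 2: 1, 10: 2, 11: 10, 20: 1, 21: 20}
is_bacterial_fast = make_descendant_checker(toy_parents, target=2)
flags = {tid: is_bacterial_fast(tid) for tid in (11, 10, 21, 1, 999)}
```

If adopted, the checker would be created once, e.g. `make_descendant_checker(provider.parent_map, 2)`, and called in place of `is_bacterial(tid, provider.parent_map)` inside the set comprehension.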


In [None]:
# ==============================================================================
# Step 4.2: Build ICD-11 Disease Dictionary
# ==============================================================================

try:
    # Top-level ICD-11 chapter entity IDs (Foundation URIs of the form .../icd/entity/<id>)
    root_chapters = {
        "Certain infectious or parasitic diseases": "1435254666",
        "Neoplasms": "1630407678",
        "Diseases of the blood or blood-forming organs": "1766440644",
        "Diseases of the immune system": "1954798891",
        "Endocrine, nutritional or metabolic diseases": "21500692",
        "Mental, behavioural or neurodevelopmental disorders": "334423054",
        "Sleep-wake disorders": "274880002",
        "Diseases of the nervous system": "1296093776",
        "Diseases of the visual system": "868865918",
        "Diseases of the ear or mastoid process": "1218729044",
        "Diseases of the circulatory system": "426429380",
        "Diseases of the respiratory system": "197934298",
        "Diseases of the digestive system": "1256772020",
        "Diseases of the skin": "1639304259",
        "Diseases of the musculoskeletal system or connective tissue": "1473673350",
        "Diseases of the genitourinary system": "30659757",
    }

    # --- Execute the build process using the new provider ---
    icd_provider = ICD11Taxonomy(ICD11_CLIENT_ID, ICD11_CLIENT_SECRET, cache_path='icd11_api_cache.json', num_threads=4)

    disease_dictionary = icd_provider.build_disease_dictionary(root_chapters, disease_dict)

    # --- Display a sample ---
    print("\n--- Sample Entries ---")
    if "Multiple sclerosis" in disease_dictionary:
        print(json.dumps({"Multiple sclerosis": disease_dictionary["Multiple sclerosis"]}, indent=2))
    if "Crohn disease" in disease_dictionary:
        print(json.dumps({"Crohn disease": disease_dictionary["Crohn disease"]}, indent=2))

except NameError as e:
    print(f"❌ ERROR: {e}. Make sure ICD11_CLIENT_ID, ICD11_CLIENT_SECRET and ICD11Taxonomy are defined.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

💾 Loading API cache from 'icd11_api_cache.json'.
✅ ICD11Taxonomy provider initialized with 4 concurrent threads.

🌲 Starting concurrent graph traversal from 16 root chapters...
   - API Requests: 0, Entities Processed: 37555, Queue Size: 0

✅ Traversal complete. Total unique entities processed: 37555.
💾 API cache saved to 'icd11_api_cache.json'.
🛠️  Processing cached data into final dictionary...
✅ Final dictionary with 37554 entries saved to 'disease_dictionary_hierarchical.json'.

--- Sample Entries ---
{
  "Multiple sclerosis": {
    "icd11_code": "8A40",
    "parent_name": "Multiple sclerosis or other white matter disorders",
    "parent_code": "",
    "aliases": [
      "MS - [multiple sclerosis]",
      "Multiple sclerosis generalised",
      "Multiple sclerosis of brain stem",
      "Multiple sclerosis of cord",
      "cerebrospinal sclerosis",
      "disseminated brain sclerosis",
      "disseminated cerebrospinal sclerosis",
      "disseminated multiple s

In [None]:
# ==============================================================================
# Step 4.3: Define Function to Build Abstract Dictionary
# ==============================================================================
def build_abstract_dictionary(fetcher, df, output_path):
    """Fetches an abstract for each unique DOI in `df` and saves a {DOI: abstract} JSON file.

    Missing abstracts are stored as empty strings. If `output_path` already
    exists, the whole build is skipped.
    """
    if os.path.exists(output_path):
        print(f"✅ Abstract dictionary '{output_path}' already exists. Skipping build process.")
        return

    print("\n📚 Building Abstract Dictionary... This may take a while.")
    abstract_dict = {}
    unique_dois = df['DOI'].unique()  # compute once and reuse
    total_dois = len(unique_dois)

    for i, doi in enumerate(unique_dois):
        if not doi:  # skip blank DOI cells
            continue

        print(f"\r   - Processing DOI {i+1}/{total_dois}: {doi}", end="")
        abstract = fetcher.get_abstract_by_doi(doi)
        abstract_dict[doi] = abstract if abstract else ""

    print("\n✅ Abstract fetching complete.")

    with open(output_path, 'w') as f:
        json.dump(abstract_dict, f, indent=2)
    print(f"💾 Abstract dictionary saved to '{output_path}'.")
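`build_abstract_dictionary` returns immediately if the JSON file already exists, so DOIs added to the sheet later never get abstracts, and an interrupted run saves nothing. A minimal sketch of a resumable variant, assuming the fetch function takes a DOI and returns the abstract or `None` (as `fetcher.get_abstract_by_doi` appears to). `build_abstracts_resumable` and `checkpoint_every` are hypothetical names, and the demo uses a fake fetcher instead of the network:

```python
import json
import os
import tempfile
from typing import Callable, Dict, Iterable, Optional

def build_abstracts_resumable(dois: Iterable[str],
                              fetch_abstract: Callable[[str], Optional[str]],
                              output_path: str,
                              checkpoint_every: int = 50) -> Dict[str, str]:
    """Fetch abstracts for DOIs not already cached in output_path, checkpointing as it goes."""
    cache: Dict[str, str] = {}
    if os.path.exists(output_path):
        with open(output_path) as f:
            cache = json.load(f)

    def save() -> None:
        with open(output_path, "w") as f:
            json.dump(cache, f, indent=2)

    new_fetches = 0
    for doi in dict.fromkeys(dois):      # de-duplicate, keeping first-seen order
        if not doi or doi in cache:      # skip blanks and DOIs fetched in an earlier run
            continue
        cache[doi] = fetch_abstract(doi) or ""
        new_fetches += 1
        if new_fetches % checkpoint_every == 0:
            save()                       # periodic checkpoint limits lost work on interruption
    save()
    return cache

# Demo with a fake fetcher (hypothetical DOIs); no network access
calls = []

def fake_fetch(doi: str) -> Optional[str]:
    calls.append(doi)
    return None if doi.endswith("x") else f"Abstract for {doi}"

path = os.path.join(tempfile.mkdtemp(), "abstracts.json")
first = build_abstracts_resumable(["10.1/a", "10.1/x", "10.1/a", ""], fake_fetch, path)
second = build_abstracts_resumable(["10.1/a", "10.1/b"], fake_fetch, path)
print(calls)  # ['10.1/a', '10.1/x', '10.1/b']: '10.1/a' is not fetched twice
```

In the notebook this would be called as `build_abstracts_resumable(rows_to_process['DOI'], fetcher.get_abstract_by_doi, ABSTRACT_DICT_PATH)`, so rerunning the cell only fetches new DOIs.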


In [None]:
# ==============================================================================
# Step 5: Loading Data
# Step 5.1 Download stopwords from NLTK
# ==============================================================================
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [None]:
# ==============================================================================
# STEP 5.2: Load Data from Google Sheet
# ==============================================================================
print("\n--- Loading data from Google Sheet ---")
try:
    spreadsheet = gc.open_by_url(spreadsheet_url)
    worksheet = spreadsheet.worksheet(worksheet_name)
    all_values = worksheet.get_all_values()
    header = all_values[header_indx]

    # Make column names unique if there are duplicates
    cols = pd.Series(header)
    for dup in cols[cols.duplicated()].unique():
        cols[cols[cols == dup].index.values.tolist()] = [dup + '.' + str(i) if i != 0 else dup for i in range(sum(cols == dup))]
    header = list(cols)

    data_rows = all_values[header_indx+1:]
    input_df = pd.DataFrame(data_rows, columns=header)
    input_df.reset_index(inplace=True)
    input_df.rename(columns={'index': 'row_index'}, inplace=True)

    # Select the first 'StudyTitle' column by its index if there are multiple
    study_title_cols = [col for col in input_df.columns if 'StudyTitle' in col]
    if study_title_cols:
        input_df['StudyTitle'] = input_df[study_title_cols[0]]
        # Drop other StudyTitle columns if they exist, keeping only the first
        other_study_title_cols = study_title_cols[1:]
        if other_study_title_cols:
            input_df.drop(columns=other_study_title_cols, inplace=True)


    if 'Processed' not in input_df.columns:
        input_df['Processed'] = ''


    # Keep rows that are not yet processed and have both a StudyTitle and an ICD-11 description
    rows_to_process = input_df[
        (input_df['Processed'] == '') &
        (input_df['icd11_description'] != '') &
        (input_df['icd11_description'] != 'None') &
        (input_df['StudyTitle'] != '')
    ].copy()

    rows_to_process = rows_to_process[['DOI', 'StudyTitle', 'icd11_description']]

    print(f"Loaded {len(input_df)} total records.")
    print(f"Found {len(rows_to_process)} new records to process.")
    print("--- Successfully loaded data from Google Sheet ---")
    display(DataTable(rows_to_process.head()))
except Exception as e:
    print(f"Could not load Google Sheet. Error: {e}")
    rows_to_process = pd.DataFrame()


--- Loading data from Google Sheet ---
Loaded 2329 total records.
Found 836 new records to process.
--- Successfully loaded data from Google Sheet ---


| | DOI | StudyTitle | icd11_description |
|---|---|---|---|
| 0 | 10.3389/fcimb.2019.00476 | The Oral Microbiota May Have Influence on Oral... | Malignant neoplasms of other or ill-defined si... |
| 2 | 10.1038/s41467-024-53013-x | Effects of iron supplements and iron-containin... | Anaemias or other erythrocyte disorders |
| 3 | 10.1038/s41586-022-04427-4 | The lung microbiome regulates brain autoimmunity | Multiple sclerosis |
| 4 | 10.1128/spectrum.01901-21 | Insights into the Unique Lung Microbiota Profi... | Multiple sclerosis |
| 5 | 10.1038/s41598-022-07995-7 | 16S rRNA and metagenomic shotgun sequencing da... | Ulcerative colitis |
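The duplicate-column handling in Step 5.2 renames repeated headers so that the first occurrence keeps its name and later ones get `.1`, `.2`, and so on; this is what lets the `'StudyTitle' in col` filter find every copy afterwards. A plain-Python sketch of the same renaming rule (`make_unique_headers` is a hypothetical name), handy for checking the behavior on a sample header row:

```python
from collections import Counter
from typing import List

def make_unique_headers(header: List[str]) -> List[str]:
    """Suffix repeated column names with .1, .2, ... (first occurrence unchanged)."""
    seen = Counter()
    unique = []
    for name in header:
        unique.append(name if seen[name] == 0 else f"{name}.{seen[name]}")
        seen[name] += 1
    return unique

print(make_unique_headers(["DOI", "StudyTitle", "Disease", "StudyTitle", "StudyTitle"]))
# ['DOI', 'StudyTitle', 'Disease', 'StudyTitle.1', 'StudyTitle.2']
```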


In [None]:
# ==============================================================================
# Step 5.3: Load the Microbe and Disease Dictionaries
# ==============================================================================
def load_and_prepare_dictionary(filepath, entity_type='microbe'):
    """Loads a hierarchical JSON and prepares it for entity recognition."""
    try:
        with open(filepath, 'r') as f:
            structured_dict = json.load(f)

        name_to_canonical_map = {}
        for canonical, data in structured_dict.items():
            name_to_canonical_map[canonical.lower()] = canonical # Map lowercase to canonical
            for alias in data.get('aliases', []):
                name_to_canonical_map[alias.lower()] = canonical

        gazetteer = set(name_to_canonical_map.keys())

        print(f"✅ {entity_type.capitalize()} dictionary loaded from '{filepath}'.")
        print(f"   - {len(gazetteer)} total names/aliases mapped to {len(structured_dict)} canonical entities.")
        return name_to_canonical_map, gazetteer
    except FileNotFoundError:
        print(f"❌ Error: The dictionary file '{filepath}' was not found.")
        return None, None

# --- Load Microbe Dictionary ---
microbe_map, microbe_gazetteer = load_and_prepare_dictionary(
    microbe_dict, 'microbe'
)

# --- Load Disease Dictionary ---
disease_map, disease_gazetteer = load_and_prepare_dictionary(
    disease_dict, 'disease'
)

✅ Microbe dictionary loaded from 'microbe_dictionary_hierarchical.json'.
   - 571513 total names/aliases mapped to 512494 canonical entities.
✅ Disease dictionary loaded from 'disease_dictionary_hierarchical.json'.
   - 71156 total names/aliases mapped to 37554 canonical entities.
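In `load_and_prepare_dictionary`, every lowercase name or alias maps to exactly one canonical entry. When two entries share an alias (as can happen with abbreviations), the entry loaded last silently wins, and `MicrobeExtractor` will normalize every mention of that alias to it. The sketch below reproduces the same loop and additionally reports such collisions; `build_alias_map` is a hypothetical helper and the toy entries are made up purely to show a shared abbreviation:

```python
from collections import defaultdict
from typing import Dict, Set, Tuple

def build_alias_map(structured_dict: Dict[str, dict]) -> Tuple[Dict[str, str], Dict[str, Set[str]]]:
    """Lowercase name -> canonical map, plus every lowercase key claimed by 2+ entries."""
    name_map: Dict[str, str] = {}
    claims: Dict[str, Set[str]] = defaultdict(set)
    for canonical, data in structured_dict.items():
        for name in [canonical, *data.get("aliases", [])]:
            key = name.lower()
            name_map[key] = canonical      # later entries overwrite earlier ones, as in the notebook
            claims[key].add(canonical)     # ...but every claimant is remembered here
    collisions = {key: owners for key, owners in claims.items() if len(owners) > 1}
    return name_map, collisions

# Hypothetical toy entries that share the abbreviation "MS"
toy = {
    "Multiple sclerosis": {"aliases": ["MS"]},
    "Mitral stenosis": {"aliases": ["MS"]},
}
name_map, collisions = build_alias_map(toy)
print(name_map["ms"])                                  # Mitral stenosis (last writer wins)
print({k: sorted(v) for k, v in collisions.items()})   # {'ms': ['Mitral stenosis', 'Multiple sclerosis']}
```

Running it over the loaded JSON files and printing `len(collisions)` would show how many aliases are ambiguous before deciding whether to drop or disambiguate them.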


In [None]:
# ==============================================================================
# Step 5.4: Build and Load the Abstract Dictionary
# ==============================================================================
ABSTRACT_DICT_PATH = 'abstract_dictionary.json'

# --- Build Dictionary (if it doesn't exist) ---
fetcher = PublicationFetcher(email=EMAIL, api_key=NCBI_API_KEY)
build_abstract_dictionary(fetcher, rows_to_process, ABSTRACT_DICT_PATH)



✅ PublicationFetcher initialized.

📚 Building Abstract Dictionary... This may take a while.
  > Retrying in 2 seconds...
  > Retrying in 4 seconds...
   - Processing DOI 821/822: 10.1167/tvst.10.2.19
✅ Abstract fetching complete.
💾 Abstract dictionary saved to 'abstract_dictionary.json'.


In [None]:
# --- Load Dictionary ---
try:
    with open(ABSTRACT_DICT_PATH, 'r') as f:
        abstract_dictionary = json.load(f)
    print(f"✅ Successfully loaded abstract dictionary with {len(abstract_dictionary)} entries.")
except FileNotFoundError:
    print(f"❌ ERROR: Abstract dictionary '{ABSTRACT_DICT_PATH}' not found. Please run the build step.")
    abstract_dictionary = {}

✅ Successfully loaded abstract dictionary with 821 entries.


In [None]:
# ==============================================================================
# Step 6: Define the MicrobeExtractor Class
# ==============================================================================

class MicrobeExtractor:
    """
    Extracts dictionary names from text with pre-compiled, batched regular
    expressions; batching keeps each compiled pattern to a manageable size
    even for a gazetteer of ~570k aliases. Despite its name, the class is
    also reused for disease names in Step 7.
    """
    def __init__(self, name_to_canonical_map, gazetteer, batch_size=500):
        self.name_map = name_to_canonical_map
        self.compiled_patterns = []

        print(f"⚙️  Optimizing gazetteer for extraction...")
        # Sort by length, longest first. Regex alternation tries alternatives left to
        # right, so within one batch a longer name wins over a shorter name it contains.
        # This does not hold across batches: e.g. a genus inside a species name may
        # also be matched by a later batch.
        sorted_gazetteer = sorted(gazetteer, key=len, reverse=True)

        print(f"   - Compiling {len(sorted_gazetteer)} aliases into regex batches of size {batch_size}...")
        # Chunk the gazetteer into batches
        for i in range(0, len(sorted_gazetteer), batch_size):
            batch = sorted_gazetteer[i:i + batch_size]
            # Escape each alias to handle special regex characters safely
            escaped_batch = [re.escape(alias) for alias in batch]
            # One large OR pattern per batch. Note: \b is a boundary only next to a
            # letter, digit or underscore, so aliases that start or end with
            # punctuation (e.g. "]") may fail to match.
            pattern_str = r'\b(' + '|'.join(escaped_batch) + r')\b'
            # Compile once and reuse for every call to find_microbes()
            self.compiled_patterns.append(re.compile(pattern_str, re.IGNORECASE))

        print(f"✅ MicrobeExtractor initialized with {len(self.compiled_patterns)} compiled regex patterns.")

    def find_microbes(self, text):
        """Returns the canonical names of all dictionary entries found in `text`."""
        if not isinstance(text, str): return []

        found_microbes = set()

        # Iterate through the compiled regex patterns (batches)
        for pattern in self.compiled_patterns:
            # finditer finds all non-overlapping matches for this batch's pattern
            for match in pattern.finditer(text):
                matched_alias = match.group(0)
                # Keys in name_map are lowercase, so normalize the match before lookup
                canonical_name = self.name_map.get(matched_alias.lower())
                if canonical_name:
                    found_microbes.add(canonical_name)

        return list(found_microbes)
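A small demonstration of the batching caveat, using the same pattern construction as `MicrobeExtractor` on a hypothetical three-alias map: longest-first ordering only protects against shorter names matching inside longer ones when both land in the same batch. `compile_batches` and `find_all` are illustrative helpers, not part of the notebook:

```python
import re
from typing import Dict, List

def compile_batches(aliases: List[str], batch_size: int) -> List[re.Pattern]:
    """Longest-first, escaped, word-boundary-delimited batches, as in MicrobeExtractor."""
    ordered = sorted(aliases, key=len, reverse=True)
    return [
        re.compile(r"\b(" + "|".join(re.escape(a) for a in ordered[i:i + batch_size]) + r")\b",
                   re.IGNORECASE)
        for i in range(0, len(ordered), batch_size)
    ]

def find_all(patterns: List[re.Pattern], name_map: Dict[str, str], text: str) -> List[str]:
    """Canonical names matched by any batch, sorted for stable output."""
    return sorted({name_map[m.group(0).lower()] for p in patterns for m in p.finditer(text)})

# Hypothetical toy map (lowercase keys, as produced by load_and_prepare_dictionary)
name_map = {"escherichia coli": "Escherichia coli",
            "e. coli": "Escherichia coli",
            "escherichia": "Escherichia"}
text = "Escherichia coli (E. coli) was enriched in patients."

print(find_all(compile_batches(list(name_map), batch_size=10), name_map, text))
# ['Escherichia coli']
print(find_all(compile_batches(list(name_map), batch_size=1), name_map, text))
# ['Escherichia', 'Escherichia coli']
```

With 1144 batches of 500 length-sorted aliases, a genus and a species name containing it probably fall into different batches, so both are likely reported for one mention. If only the most specific name is wanted, one option (not implemented in the notebook) is to collect match spans from all batches and discard spans contained in a longer one.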

In [None]:
# ==============================================================================
# Step 7: Corpus Creation with Context
# ==============================================================================
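def preprocess_text_for_phrasing(text):
    """Lowercases text, keeps letters only, and drops stopwords and words of two or fewer characters."""
    text = re.sub('[^a-zA-Z]', ' ', text).lower()
    return [word for word in text.split() if word not in stop_words and len(word) > 2]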

if not rows_to_process.empty and microbe_map and disease_map:
    microbe_extractor = MicrobeExtractor(microbe_map, microbe_gazetteer)
    disease_extractor = MicrobeExtractor(disease_map, disease_gazetteer)

    initial_corpus = []
    associations = []
    print("\nScanning text and creating initial corpus...")

    for _, row in rows_to_process.iterrows():
        doi = row['DOI']
        title = row['StudyTitle']
        canonical_disease = row['icd11_description']

        # --- Get abstract from the pre-loaded dictionary ---
        abstract = abstract_dictionary.get(doi, "")

        # --- Combine title and abstract for a richer context ---
        full_text = title
        if abstract:
            full_text += ' ' + abstract

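        canonical_microbes = microbe_extractor.find_microbes(full_text)
        # find_microbes() returns canonical names, so these are canonical disease names found in the text
        detected_diseases = disease_extractor.find_microbes(full_text)

        if canonical_microbes and canonical_disease:
            context_words = preprocess_text_for_phrasing(full_text)

            # Canonical names (which may contain spaces) are kept as single tokens
            training_sentence = canonical_microbes + context_words + [canonical_disease] + detected_diseases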

            if len(training_sentence) > 1:
                initial_corpus.append(training_sentence)

            for microbe in canonical_microbes:
                associations.append({'microbe': microbe, 'disease': canonical_disease})

    # --- Learn and apply multi-word phrases ---
    print("🗣️  Learning multi-word phrases")
    phrases = Phrases(initial_corpus, min_count=2, threshold=10.0)
    phraser = Phraser(phrases)
    # Apply the phraser to the whole corpus
    final_corpus = [phraser[doc] for doc in initial_corpus]
    print("✅ Phrase detection complete.")

    print(f"🦠 Found {len(associations)} microbe-disease associations.")
    # assoc_df = pd.DataFrame(associations).drop_duplicates()
    # print(f"ü¶† Found {len(assoc_df)} unique microbe-disease associations.")

⚙️  Optimizing gazetteer for extraction...
   - Compiling 571513 aliases into regex batches of size 500...
✅ MicrobeExtractor initialized with 1144 compiled regex patterns.
⚙️  Optimizing gazetteer for extraction...
   - Compiling 71156 aliases into regex batches of size 500...
✅ MicrobeExtractor initialized with 143 compiled regex patterns.

Scanning text and creating initial corpus...
🗣️  Learning multi-word phrases
✅ Phrase detection complete.
🦠 Found 1277 microbe-disease associations.
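gensim's `Phrases` joins two adjacent tokens into one (`a_b`) when their bigram score exceeds `threshold`. With the default `scoring='default'`, the score follows the word2vec phrase formula: (count(a b) - min_count) / (count(a) * count(b)) * vocabulary size. The sketch recomputes that formula on a hypothetical toy corpus; it approximates "vocabulary size" as distinct unigrams plus distinct bigrams and ignores gensim's pruning and connector-word handling, so treat it as an illustration rather than gensim's exact numbers:

```python
from collections import Counter
from typing import List

def phrase_score(corpus: List[List[str]], word_a: str, word_b: str, min_count: int) -> float:
    """Bigram score (count(ab) - min_count) / (count(a) * count(b)) * vocab_size.

    vocab_size is approximated as distinct unigrams plus distinct adjacent bigrams;
    gensim's own bookkeeping (pruning, connector words) is not reproduced here.
    """
    unigrams = Counter(tok for sent in corpus for tok in sent)
    bigrams = Counter((a, b) for sent in corpus for a, b in zip(sent, sent[1:]))
    vocab_size = len(unigrams) + len(bigrams)
    denom = unigrams[word_a] * unigrams[word_b]
    if denom == 0:
        return float("-inf")
    return (bigrams[(word_a, word_b)] - min_count) / denom * vocab_size

# Hypothetical toy corpus after stopword removal
corpus = [["ulcerative", "colitis", "patients"],
          ["ulcerative", "colitis", "gut"],
          ["ulcerative", "colitis"]]
score = phrase_score(corpus, "ulcerative", "colitis", min_count=2)
print(round(score, 3))   # 0.778, i.e. (3 - 2) / (3 * 3) * 7
print(score > 10.0)      # False: not merged at threshold=10.0
```

On a toy corpus a threshold of 10.0 is out of reach; in the real corpus the vocabulary-size factor is much larger, which is largely what makes such thresholds attainable. This also bears on the Step 8 lookup `'_'.join(preprocess_text_for_phrasing(disease))`: a single `Phrases` pass (with the default, empty `connector_words`) joins at most two tokens, so disease names with more than two content words never become one phrased token and are skipped.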


In [None]:
# --- Learn and apply multi-word phrases ---
print("🗣️  Learning multi-word phrases from the richer corpus...")
phrases = Phrases(initial_corpus, min_count=3, threshold=10.0)
phraser = Phraser(phrases)
final_corpus = [phraser[doc] for doc in initial_corpus]
print("✅ Phrase detection complete.")

# --- 📊 AGGREGATE CO-OCCURRENCE COUNTS ---
# Convert the full list of associations to a DataFrame
assoc_df_raw = pd.DataFrame(associations)

if not assoc_df_raw.empty:
    # Group by microbe and disease, then count the size of each group
    assoc_counts_df = assoc_df_raw.groupby(['microbe', 'disease']).size().reset_index(name='count')
    assoc_counts_df = assoc_counts_df.sort_values(by='count', ascending=False)
    print(f"📊 Found {len(assoc_counts_df)} unique microbe-disease associations with frequency counts.")
else:
    assoc_counts_df = pd.DataFrame(columns=['microbe', 'disease', 'count'])
    print("⚠️ No associations found to count.")


🗣️  Learning multi-word phrases from the richer corpus...
✅ Phrase detection complete.
📊 Found 1033 unique microbe-disease associations with frequency counts.
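The `groupby(['microbe', 'disease']).size()` step counts one hit per sheet row and microbe. The load step reported 836 rows to process while the abstract build saw at most 822 unique DOIs, so some papers appear in more than one row; if such a paper is listed twice with the same disease, its pairs are counted twice. A minimal sketch of both counting schemes in plain Python. The per-document variant assumes each association record also stores its DOI, which the notebook's `associations` list does not currently do (`'doi'` is a hypothetical key, e.g. added via `associations.append({'microbe': microbe, 'disease': canonical_disease, 'doi': doi})`), and the records are toy data:

```python
from collections import Counter, defaultdict
from typing import Dict, List, Set, Tuple

def count_pairs(associations: List[Dict[str, str]]) -> Counter:
    """Row-level counts, like groupby(['microbe', 'disease']).size()."""
    return Counter((a["microbe"], a["disease"]) for a in associations)

def count_pairs_by_document(associations: List[Dict[str, str]]) -> Dict[Tuple[str, str], int]:
    """Distinct documents per pair; assumes each record carries a hypothetical 'doi' key."""
    docs: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
    for a in associations:
        docs[(a["microbe"], a["disease"])].add(a["doi"])
    return {pair: len(dois) for pair, dois in docs.items()}

# Hypothetical toy records
records = [
    {"microbe": "Prevotella", "disease": "Multiple sclerosis", "doi": "10.1/a"},
    {"microbe": "Prevotella", "disease": "Multiple sclerosis", "doi": "10.1/a"},  # same paper, repeated row
    {"microbe": "Prevotella", "disease": "Multiple sclerosis", "doi": "10.1/b"},
    {"microbe": "Akkermansia", "disease": "Multiple sclerosis", "doi": "10.1/b"},
]
print(count_pairs(records).most_common())
# [(('Prevotella', 'Multiple sclerosis'), 3), (('Akkermansia', 'Multiple sclerosis'), 1)]
print(count_pairs_by_document(records))
# {('Prevotella', 'Multiple sclerosis'): 2, ('Akkermansia', 'Multiple sclerosis'): 1}
```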


In [None]:
assoc_df_raw.to_csv('microbe_disease_associations_raw.csv', index=False)
assoc_counts_df.to_csv('microbe_disease_associations_counts.csv', index=False)

### **Step 8: Adjusting Model Hyperparameters for a Richer Corpus**

The corpus, enriched with abstracts, is significantly larger and more contextually detailed than one built from titles alone. To get more meaningful vectors from the Word2Vec model, adjust these hyperparameters:

* **`window`**: The original value was `3`. With the longer sentences that abstracts produce, the model needs to relate words that sit further apart. **Increasing the window size to `5` or `7`** is a reasonable starting point.
* **`min_count`**: The original value was `2`, which already drops words seen only once (hapax legomena). A larger corpus contains many more rare words that mostly add noise, and **increasing `min_count` to `3` or `5`** also drops words seen only a handful of times, which should give more robust vectors for the remaining vocabulary. The cutoff applies to microbe and disease tokens too: any seen fewer than `min_count` times get no vector and are therefore skipped when the graph is built.
* **`epochs`**: The original value of `50` is high, which suits a small corpus. With a larger corpus the model sees more data per epoch, so `20-30` epochs will likely suffice; keeping `50` trades extra training time for more thorough training.
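A minimal sketch of how `window` and `min_count` shape what skip-gram actually sees: it lists the (center, context) pairs within a fixed window and the vocabulary that survives a count cutoff. By default gensim, like the original word2vec, samples a random effective window between 1 and `window` for each center word, so these fixed-window pair counts are an upper bound. The sentence and corpus are hypothetical; the sentence is laid out like Step 7's training sentences (microbes first, then context words, then the disease):

```python
from collections import Counter
from typing import List, Set, Tuple

def skipgram_pairs(sentence: List[str], window: int) -> List[Tuple[str, str]]:
    """All (center, context) pairs within a fixed window (no random window shrinking)."""
    pairs = []
    for i, center in enumerate(sentence):
        start, stop = max(0, i - window), min(len(sentence), i + window + 1)
        pairs.extend((center, sentence[j]) for j in range(start, stop) if j != i)
    return pairs

def surviving_vocab(corpus: List[List[str]], min_count: int) -> Set[str]:
    """Tokens kept by a min_count cutoff: those seen at least min_count times."""
    counts = Counter(tok for sent in corpus for tok in sent)
    return {tok for tok, c in counts.items() if c >= min_count}

# Hypothetical sentence: microbe, context words, disease
sentence = ["Prevotella", "gut", "dysbiosis", "immune", "Multiple sclerosis"]
for w in (1, 2, 4):
    pairs = skipgram_pairs(sentence, w)
    print(w, len(pairs), ("Prevotella", "Multiple sclerosis") in pairs)
# 1 8 False
# 2 14 False
# 4 20 True

toy_corpus = [["Prevotella", "gut"], ["Prevotella", "immune"], ["Akkermansia", "gut"]]
print(sorted(surviving_vocab(toy_corpus, min_count=2)))   # ['Prevotella', 'gut']
```

Because Step 7 places all context words between the microbe tokens and the disease token, a real abstract often separates them by far more than 5 or 7 positions. Skip-gram then rarely sees a microbe and its disease as a direct pair, so their similarity is driven mostly by shared context words.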

In [None]:
# ==============================================================================
# Step 8: Train Word2Vec Model and Build the Final Graph
# ==============================================================================
if 'final_corpus' in locals() and final_corpus:
    print("\n🧠 Training Word2Vec model with tuned parameters...")
    # --- Tuned Hyperparameters ---
    model = Word2Vec(
        sentences=final_corpus,
        vector_size=300,
        window=5,      # Wider window (up from 3) for the longer abstract-based sentences
        min_count=3,   # Ignore very rare words
        workers=4,
        sg=1,          # Skip-gram architecture
        epochs=50      # Kept high for thorough training
    )
    print("✨ Model training complete!")

    G = nx.Graph()
    print(f"\n🕸️ Building graph from the found associations...")
    similarity_threshold = 0.5  # Minimum cosine similarity for an edge; adjust as needed

    # `assoc_df` was never defined (its definition is commented out in Step 7),
    # which caused the NameError below; use the aggregated associations instead.
    for _, row in assoc_counts_df.iterrows():
        microbe = row['microbe']
        disease = row['disease']
        # Look up the disease via its phrased form (e.g. 'multiple_sclerosis'); this token is
        # in the vocabulary only if Phraser learned the phrase from the context words.
        disease_phrase = '_'.join(preprocess_text_for_phrasing(disease))

        try:
            if microbe in model.wv and disease_phrase in model.wv:
                score = float(model.wv.similarity(microbe, disease_phrase))
                if score >= similarity_threshold:
                    G.add_node(microbe, size=10, color='skyblue', title=f"Microbe: {microbe}")
                    G.add_node(disease, size=20, color='tomato', title=f"Disease: {disease}")
                    G.add_edge(microbe, disease, weight=score, title=f"{score:.2f}", label=f"{score:.2f}")
        except KeyError:
            continue

    print(f"   - Graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges (threshold >= {similarity_threshold}).")

    if G.number_of_edges() > 0:
        net = Network(notebook=True, cdn_resources='in_line', height='1600px', width='100%', bgcolor='#222222', font_color='white')
        net.from_nx(G)
        net.show_buttons(filter_=['physics'])
        net.save_graph("microbe_disease_graph_final_context.html")
        print("\n🎉 Interactive graph saved as 'microbe_disease_graph_final_context.html'.")
        display(net)
    else:
        print(f"\n⚠️ No relationships found. Try lowering the threshold or expanding the corpus.")


🧠 Training Word2Vec model with tuned parameters...
✨ Model training complete!

🕸️ Building graph from the found associations...


NameError: name 'assoc_df' is not defined

In [None]:
# ==============================================================================
# Step 8: Train Word2Vec Model and Build the Final Graph
# ==============================================================================
if 'final_corpus' in locals() and final_corpus:
    print("\n🧠 Training Word2Vec model with tuned parameters...")
    # --- Tuned Hyperparameters ---
    model = Word2Vec(
        sentences=final_corpus,
        vector_size=300,
        window=5,      # Wider window (up from 3) for the longer abstract-based sentences
        min_count=3,   # Ignore very rare words
        workers=4,
        sg=1,          # Skip-gram architecture
        epochs=50      # Kept high for thorough training
    )
    print("✨ Model training complete!")

    G = nx.Graph()
    print(f"\n🕸️ Building graph with a composite score for edge weights...")
    # Threshold on the composite score (similarity * log(1 + count)), not on raw similarity
    similarity_threshold = 0.4
    for _, row in assoc_counts_df.iterrows():
        microbe = row['microbe']
        disease = row['disease']
        count = row['count']
        disease_phrase = '_'.join(preprocess_text_for_phrasing(disease))

        try:
            if microbe in model.wv and disease_phrase in model.wv:
                similarity_score = float(model.wv.similarity(microbe, disease_phrase))

                # --- ✅ Composite score: similarity scaled by the log-scaled co-occurrence count ---
                # np.log1p(count) computes log(1 + count)
                composite_score = similarity_score * np.log1p(count)
                if composite_score >= similarity_threshold:
                    G.add_node(microbe, size=10, color='skyblue', title=f"Microbe: {microbe}")
                    G.add_node(disease, size=20, color='tomato', title=f"Disease: {disease}")

                    # pyvis's from_nx() turns the 'weight' attribute into the drawn edge width
                    G.add_edge(
                        microbe,
                        disease,
                        weight=composite_score,
                        title=f"Composite Score: {composite_score:.2f}\nSimilarity: {similarity_score:.2f}\nCo-occurrences: {count}",
                        label=f"{composite_score:.2f}"
                    )
        except KeyError:
            continue

    print(f"   - Graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

    if G.number_of_edges() > 0:
        net = Network(notebook=True, cdn_resources='in_line', height='1600px', width='100%', bgcolor='#222222', font_color='white')
        net.from_nx(G)
        net.show_buttons(filter_=['physics', 'edges'])
        net.save_graph("microbe_disease_graph_composite_score.html")
        print("\n🎉 Interactive graph saved as 'microbe_disease_graph_composite_score.html'.")
        display(net)
    else:
        print(f"\n⚠️ No relationships found.")


🧠 Training Word2Vec model with tuned parameters...
✨ Model training complete!

🕸️ Building graph with a composite score for edge weights...
   - Graph built with 41 nodes and 27 edges.

🎉 Interactive graph saved as 'microbe_disease_graph_composite_score.html'.


<class 'pyvis.network.Network'> |N|=41 |E|=27
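The composite score multiplies the cosine similarity returned by `model.wv.similarity()` by ln(1 + count). Because the threshold is applied to the product, the similarity a pair needs depends on how often it co-occurs. A worked example in plain Python (`math.log1p` matches `np.log1p`); the 2-D vectors are hypothetical stand-ins for the 300-D embeddings:

```python
import math
from typing import Sequence

def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity, the quantity KeyedVectors.similarity() reports."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def composite_score(similarity: float, count: int) -> float:
    """similarity * ln(1 + count), the same as similarity_score * np.log1p(count)."""
    return similarity * math.log1p(count)

def min_similarity_for(count: int, threshold: float) -> float:
    """Smallest similarity whose composite score still reaches `threshold` (count >= 1)."""
    return threshold / math.log1p(count)

sim = cosine_similarity([1.0, 0.0], [1.0, 1.0])
print(round(sim, 4))                       # 0.7071
print(round(composite_score(sim, 3), 4))   # 0.9803

# Similarity needed to pass the 0.4 threshold used in the cell above
for count in (1, 3, 10):
    print(count, round(min_similarity_for(count, 0.4), 3))
# 1 0.577
# 3 0.289
# 10 0.167
```

At the 0.4 threshold a pair seen once needs a similarity of about 0.58, while a pair seen ten times passes at about 0.17, so the filter increasingly admits frequent pairs even when their vectors are only weakly similar. The final cell's threshold of 0.3 lowers the single-occurrence bar to about 0.43.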

In [None]:
# ==============================================================================
# Step 8: Word2Vec Model
# ==============================================================================
if 'final_corpus' in locals() and final_corpus:
    print("\n🧠 Training Word2Vec model with tuned parameters...")
    model = Word2Vec(
        sentences=final_corpus,
        vector_size=300,
        window=7,
        min_count=3,
        workers=4,
        sg=1,
        epochs=50
    )
    print("✨ Model training complete!")

    G = nx.Graph()
    print(f"\n🕸️ Building graph with composite score for physics and co-occurrence count for edge thickness...")

    # Threshold applied to the composite score; set once, outside the loop
    similarity_threshold = 0.3

    for _, row in assoc_counts_df.iterrows():
        microbe = row['microbe']
        disease = row['disease']
        count = row['count']
        disease_phrase = '_'.join(preprocess_text_for_phrasing(disease))

        try:
            if microbe in model.wv and disease_phrase in model.wv:
                similarity_score = float(model.wv.similarity(microbe, disease_phrase))
                composite_score = similarity_score * np.log1p(count)
                if composite_score >= similarity_threshold:
                    # --- ✅ Log-scaled count for edge thickness ---
                    # np.log1p(count) is equivalent to np.log(count + 1)
                    edge_thickness = np.log1p(count)

                    G.add_node(microbe, size=10, color='skyblue', title=f"Microbe: {microbe}")
                    G.add_node(disease, size=20, color='tomato', title=f"Disease: {disease}")

                    # --- ✅ vis.js scales the drawn edge width from the 'value' attribute ---
                    G.add_edge(
                        microbe,
                        disease,
                        weight=composite_score,  # composite score, also shown as the edge label
                        value=edge_thickness,    # log-scaled count for visual thickness
                        label=f"{composite_score:.2f}",
                        title=f"Composite Score: {composite_score:.2f}\nSimilarity: {similarity_score:.2f}\nCo-occurrences: {count}"
                    )
        except KeyError:
            continue

    print(f"   - Graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")
    print(f"   - Edge thickness now represents the log-scaled co-occurrence count.")

    if G.number_of_edges() > 0:
        net = Network(notebook=True, cdn_resources='in_line', height='1600px', width='100%', bgcolor='#222222', font_color='white')
        net.from_nx(G)
        net.show_buttons(filter_=['physics', 'edges'])
        net.save_graph("microbe_disease_graph_thickness.html")
        print("\n🎉 Interactive graph saved as 'microbe_disease_graph_thickness.html'.")
        display(net)
    else:
        print(f"\n⚠️ No relationships found.")


🧠 Training Word2Vec model with tuned parameters...
✨ Model training complete!

🕸️ Building graph with composite score for physics and co-occurrence count for edge thickness...
   - Graph built with 54 nodes and 49 edges.
   - Edge thickness now represents the log-scaled co-occurrence count.

🎉 Interactive graph saved as 'microbe_disease_graph_thickness.html'.


<class 'pyvis.network.Network'> |N|=54 |E|=49

## Next Steps and Scaling Up

### Expand the Corpus:

* ✅ Fetch PubMed Abstracts: Use `requests` to fetch the PubMed abstract for each DOI in the sheet.

### Improve Entity Recognition:

* Create Dictionaries:
  * ✅ Compile comprehensive lists of microbe names (at species and genus levels) and their synonyms from resources such as the NCBI Taxonomy.
  * ✅ Create a comprehensive disease dictionary from ICD-11 or a similar taxonomy.

* Use BioBERT: Use a pre-trained language model such as BioBERT for named entity recognition.

### Refine the Graph:

* Node Attributes: Add more metadata to nodes. For a disease node, add its "Disease Group"; for a microbe, add its phylum. These attributes can drive more advanced filtering and coloring in the visualization.

* ✅ Edge Weighting: Experiment with different similarity thresholds and consider weighting edges by how many times two entities co-occur in the corpus.
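The **Node Attributes** idea can start from the hierarchical dictionaries built earlier: microbe entries carry a `genus` and ICD-11 entries a `parent_name`. Phylum is not stored in those dictionaries, so this sketch uses genus as a stand-in. It returns per-node attributes including a `group` key, which vis.js (and hence pyvis) uses to style nodes by group. `node_group_attributes` is a hypothetical helper; disease nodes come from the sheet's `icd11_description` column and may not always match a dictionary key exactly, in which case they are left without a group. The toy entries are copied from the sample outputs above:

```python
from typing import Dict, Iterable

def node_group_attributes(nodes: Iterable[str],
                          microbe_dictionary: Dict[str, dict],
                          disease_dictionary: Dict[str, dict]) -> Dict[str, Dict[str, str]]:
    """Map each graph node to attributes derived from the hierarchical dictionaries."""
    attrs: Dict[str, Dict[str, str]] = {}
    for node in nodes:
        if node in microbe_dictionary:
            genus = microbe_dictionary[node].get("genus") or "unknown genus"
            attrs[node] = {"group": f"microbe: {genus}", "genus": genus}
        elif node in disease_dictionary:
            parent = disease_dictionary[node].get("parent_name") or "unknown parent"
            attrs[node] = {"group": f"disease: {parent}", "disease_group": parent}
    return attrs  # nodes found in neither dictionary get no entry

microbes = {"Escherichia coli": {"rank": "species", "genus": "Escherichia"}}
diseases = {"Multiple sclerosis": {"icd11_code": "8A40",
                                   "parent_name": "Multiple sclerosis or other white matter disorders"}}
attrs = node_group_attributes(["Escherichia coli", "Multiple sclerosis", "Unmapped node"],
                              microbes, diseases)
print(attrs["Escherichia coli"])
# {'group': 'microbe: Escherichia', 'genus': 'Escherichia'}
```

In the notebook, the result could be attached with `nx.set_node_attributes(G, node_group_attributes(G.nodes, microbe_dictionary, disease_dictionary))` before `net.from_nx(G)`. The Step 8 cells also set an explicit `color` on each node, which will likely take precedence over group colors, so drop it when coloring by group.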

### Advanced Analysis:

* Community Detection: Use algorithms such as the Louvain method in NetworkX to find "communities", or clusters of tightly connected nodes. This could reveal, for example, a group of different bacteria all associated with metabolic disorders.

* Link Prediction: Use graph machine learning techniques to predict missing links, suggesting novel microbe-disease relationships that are plausible but not yet explicitly studied.
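For **Link Prediction**, a simple non-learned baseline (a stand-in here, not the graph machine learning approach the bullet describes) is to score each unlinked microbe-disease pair by how many length-3 paths join them in the bipartite graph: microbe, shared disease, other microbe, candidate disease. The sketch assumes edges are given as `(microbe, disease)` tuples, for example `zip(assoc_counts_df['microbe'], assoc_counts_df['disease'])`; `predict_links` and the toy edges are hypothetical:

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set, Tuple

def predict_links(edges: Iterable[Tuple[str, str]], top_k: int = 5) -> List[Tuple[str, str, int]]:
    """Rank unlinked (microbe, disease) pairs by their number of length-3 paths."""
    m2d: Dict[str, Set[str]] = defaultdict(set)   # microbe -> linked diseases
    d2m: Dict[str, Set[str]] = defaultdict(set)   # disease -> linked microbes
    for microbe, disease in edges:
        m2d[microbe].add(disease)
        d2m[disease].add(microbe)

    scores: Dict[Tuple[str, str], int] = defaultdict(int)
    for microbe, linked in m2d.items():
        for shared_disease in linked:
            for other_microbe in d2m[shared_disease]:
                if other_microbe == microbe:
                    continue
                for candidate in m2d[other_microbe]:
                    if candidate not in linked:      # only score pairs with no edge yet
                        scores[(microbe, candidate)] += 1

    ranked = sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))
    return [(m, d, s) for (m, d), s in ranked[:top_k]]

# Hypothetical toy graph: microbes A, B, C and diseases X, Y, Z
edges = [("A", "X"), ("B", "X"), ("B", "Y"), ("C", "X"), ("C", "Y"), ("C", "Z")]
print(predict_links(edges))
# [('A', 'Y', 2), ('B', 'Z', 2), ('A', 'Z', 1)]
```

Raw path counts favor well-connected nodes; normalizing by node degree is a common refinement before treating any ranked pair as a hypothesis worth checking.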