In [None]:
"""
PubMed Metadata Retriever

This script enriches a dataset containing PMIDs by retrieving the following metadata:
- Title and abstract
- Authors
- MeSH terms
- Publication date [edat]

It uses the Bio.Entrez module from Biopython to interact with NCBI's E-utilities.
"""

# Install required packages if not already installed
try:
    import pandas as pd
    from Bio import Entrez
except ImportError:
    !pip install biopython pandas
    import pandas as pd
    from Bio import Entrez

import time
import json
import os
from google.colab import files
import io

# Configuration settings - MODIFY THESE AS NEEDED
# ----------------------------------------------
# IMPORTANT: Always provide your email when using Entrez.
# The string below is only a placeholder; it is assigned to Entrez.email
# unchanged, so edit it before running.
EMAIL = "put email address here"  # Replace with your email address
Entrez.email = EMAIL

# Optional: Provide your API key if you have one (increases rate limits).
# Uncomment both lines below to enable it.
# API_KEY = "your_api_key_here"
# Entrez.api_key = API_KEY

# Processing parameters
# NOTE: process_pmids() fetches PMIDs one request at a time; BATCH_SIZE only
# controls how they are grouped for progress reporting.
BATCH_SIZE = 10  # Number of PMIDs to process per batch
RATE_LIMIT_DELAY = 0.5  # Seconds to sleep after each single-PMID request
LIMIT = None  # Set to a number to limit processing (or None to process all)
# ----------------------------------------------

# Function to load PMIDs from a CSV file
def load_pmids(file_path=None):
    """
    Load PMIDs from a CSV file either from a local path or uploaded by the user.

    The PMID column is looked up by name ('pmid', ignoring case and surrounding
    whitespace); if no such column exists, the first column is used. Values are
    normalised to integer strings (e.g. 123.0 -> "123"). Empty cells are
    dropped, and entries that are not positive numbers are skipped with a
    warning instead of invalidating the whole file.

    Args:
        file_path: Path to the CSV file (optional). When it is None or does not
            exist, the Colab upload dialog is opened instead.

    Returns:
        List of PMIDs as strings. A single sample PMID is returned when nothing
        is uploaded or the file cannot be read.
    """
    if file_path is None or not os.path.exists(file_path):
        print("Please upload your CSV file containing PMIDs...")
        uploaded = files.upload()

        if not uploaded:
            print("No file was uploaded. Using a sample PMID for demonstration.")
            return ["21383081"]  # Example PMID

        file_path = next(iter(uploaded))

    try:
        df = pd.read_csv(file_path)

        # Prefer an exact 'pmid' header, then accept variants such as
        # 'PMID' or ' pmid ' before falling back to the first column.
        if 'pmid' in df.columns:
            pmid_column = 'pmid'
        else:
            pmid_column = next(
                (col for col in df.columns if str(col).strip().lower() == 'pmid'),
                None,
            )

        if pmid_column is None:
            print(f"Warning: 'pmid' column not found. Available columns: {df.columns.tolist()}")
            print("Using the first column as PMIDs...")
            raw_values = df.iloc[:, 0]
        else:
            raw_values = df[pmid_column]

        # Coerce instead of raising: a single malformed cell must not discard
        # every other PMID in the file.
        numeric = pd.to_numeric(raw_values, errors="coerce")
        is_valid = numeric.notna() & (numeric > 0) & (numeric != float("inf"))

        skipped = int(raw_values.notna().sum()) - int(is_valid.sum())
        if skipped:
            print(f"Warning: skipped {skipped} value(s) that are not valid PMIDs")

        # Going through int64 turns float-parsed values such as 123.0 into "123"
        pmids = numeric[is_valid].astype("int64").astype(str).tolist()

        print(f"Successfully loaded {len(pmids)} PMIDs")
        return pmids

    except Exception as e:
        # Deliberate best-effort: an unreadable file falls back to the demo PMID
        print(f"Error reading CSV file: {e}")
        print("Using a sample PMID for demonstration.")
        return ["21383081"]  # Example PMID

def _extract_abstract(article_data):
    """Return the abstract as one string, prefixing sections with their NlmCategory."""
    if 'Abstract' not in article_data or 'AbstractText' not in article_data['Abstract']:
        return ""

    abstract_parts = article_data['Abstract']['AbstractText']
    if not isinstance(abstract_parts, list):
        return str(abstract_parts)

    # Structured abstracts arrive as a list of sections
    sections = []
    for part in abstract_parts:
        attributes = getattr(part, 'attributes', None) or {}
        if 'NlmCategory' in attributes:
            sections.append(f"{attributes['NlmCategory']}: {part}")
        else:
            sections.append(str(part))
    return " ".join(sections)


def _extract_authors(article_data):
    """Return one dict per author with the name parts and affiliation that are present."""
    authors = []
    for author in article_data.get('AuthorList', []):
        if not isinstance(author, dict):  # Ensure author is a dictionary
            continue

        author_info = {}
        if 'LastName' in author:
            author_info["last_name"] = author['LastName']
        if 'ForeName' in author:
            author_info["fore_name"] = author['ForeName']
        elif 'Initials' in author:  # Use initials if ForeName is not available
            author_info["fore_name"] = author['Initials']

        # Group/consortium authors carry a CollectiveName instead of personal
        # name parts; without this they would be dropped as "empty" below.
        if 'CollectiveName' in author:
            author_info["collective_name"] = author['CollectiveName']

        if 'AffiliationInfo' in author and author['AffiliationInfo']:
            affiliations = [
                affil['Affiliation']
                for affil in author['AffiliationInfo']
                if 'Affiliation' in affil
            ]
            if affiliations:
                author_info["affiliation"] = "; ".join(affiliations)
        elif 'Affiliation' in author:  # For older format
            author_info["affiliation"] = author['Affiliation']

        if author_info:  # Only add if we have some data
            authors.append(author_info)
    return authors


def _extract_mesh_terms(citation):
    """Return the MeSH descriptor names listed in a MedlineCitation."""
    mesh_terms = []
    for mesh in citation.get('MeshHeadingList', []):
        if 'DescriptorName' not in mesh:
            continue
        descriptor = mesh['DescriptorName']
        if isinstance(descriptor, str):
            # NOTE(review): parsed descriptors appear to be str subclasses, so
            # this branch is the one normally taken and the UI is not included.
            mesh_terms.append(descriptor)
        else:
            attributes = getattr(descriptor, 'attributes', None) or {}
            ui = attributes.get('UI', '')
            mesh_terms.append(f"{ui}: {descriptor}" if ui else descriptor)
    return mesh_terms


def _extract_publication_date(article_data):
    """Return the journal issue date as 'Year [Month [Day]]' or the MedlineDate text."""
    if 'Journal' not in article_data or 'JournalIssue' not in article_data['Journal']:
        return ""
    journal_issue = article_data['Journal']['JournalIssue']
    if 'PubDate' not in journal_issue:
        return ""

    pub_date = journal_issue['PubDate']
    date_parts = []
    if 'Year' in pub_date:
        date_parts.append(pub_date['Year'])
        # Month is only used with a Year, Day only with a Month
        if 'Month' in pub_date:
            date_parts.append(pub_date['Month'])
            if 'Day' in pub_date:
                date_parts.append(pub_date['Day'])
    elif 'MedlineDate' in pub_date:
        date_parts.append(pub_date['MedlineDate'])
    return " ".join(date_parts)


def _extract_edat(article):
    """Return the History date whose PubStatus is 'pubmed' ('' when absent)."""
    edat = ""
    if 'PubmedData' not in article or 'History' not in article['PubmedData']:
        return edat

    for date_item in article['PubmedData']['History']:
        if not isinstance(date_item, dict):
            continue
        # getattr guards against entries without parser attributes, which
        # previously raised and turned the whole record into an error.
        attributes = getattr(date_item, 'attributes', None) or {}
        if attributes.get('PubStatus') != 'pubmed':
            continue
        year = date_item.get('Year', '')
        month = date_item.get('Month', '')
        day = date_item.get('Day', '')
        if year:
            edat = f"{year}-{month}-{day}" if month and day else year
    return edat


# Function to fetch metadata for a single PMID
def fetch_pubmed_metadata(pmid):
    """
    Fetch metadata for a single PMID using NCBI's E-utilities

    Only the first PubmedArticle of the response is used. Any exception raised
    while fetching or parsing is printed and reported in the returned
    dictionary rather than propagated, so one bad record does not stop a run.

    Args:
        pmid: PubMed ID as a string

    Returns:
        Dictionary containing article metadata with the keys "pmid", "title",
        "abstract", "authors", "mesh_terms", "publication_date" and "edat";
        or {"pmid": ..., "error": ...} when no article was returned or an
        error occurred.
    """
    try:
        # Fetch the article data
        handle = Entrez.efetch(db="pubmed", id=pmid, retmode="xml")
        try:
            record = Entrez.read(handle)
        finally:
            # Close the handle even when parsing fails
            handle.close()

        if not record.get('PubmedArticle'):
            return {"pmid": pmid, "error": "No data found"}

        article = record['PubmedArticle'][0]
        citation = article['MedlineCitation']
        article_data = citation['Article']

        return {
            "pmid": pmid,
            "title": article_data.get('ArticleTitle', ''),
            "abstract": _extract_abstract(article_data),
            "authors": _extract_authors(article_data),
            "mesh_terms": _extract_mesh_terms(citation),
            "publication_date": _extract_publication_date(article_data),
            "edat": _extract_edat(article),
        }

    except Exception as e:
        error_msg = str(e)
        print(f"Error fetching metadata for PMID {pmid}: {error_msg}")
        return {"pmid": pmid, "error": error_msg}

# Function to process PMIDs in batches
def process_pmids(pmids, batch_size=BATCH_SIZE, limit=LIMIT):
    """
    Fetch metadata for every PMID, reporting progress batch by batch.

    Each PMID is fetched with its own call to fetch_pubmed_metadata, followed
    by a pause of RATE_LIMIT_DELAY seconds; batches only group the progress
    output.

    Args:
        pmids: List of PMIDs to process
        batch_size: Number of PMIDs per progress batch
        limit: Maximum number of PMIDs to process (None for all)

    Returns:
        List of metadata dictionaries, one per processed PMID, in input order
    """
    if limit is not None and limit < len(pmids):
        pmids = pmids[:limit]
        print(f"Processing the first {limit} PMIDs...")

    total_pmids = len(pmids)
    print(f"Starting to process {total_pmids} PMIDs in batches of {batch_size}")

    batch_starts = range(0, total_pmids, batch_size)
    batch_count = len(batch_starts)
    results = []

    for batch_number, start in enumerate(batch_starts, start=1):
        stop = min(start + batch_size, total_pmids)
        chunk = pmids[start:stop]

        print(f"\nProcessing batch {batch_number}/{batch_count} (PMIDs {start+1}-{stop}/{total_pmids})")

        chunk_results = []
        for pmid in chunk:
            print(f"  Fetching metadata for PMID: {pmid}")
            chunk_results.append(fetch_pubmed_metadata(pmid))

            # Pause after every request to stay within NCBI's rate limits
            time.sleep(RATE_LIMIT_DELAY)

        results.extend(chunk_results)

        # A record counts as successful when it carries no non-empty "error"
        succeeded = sum(not entry.get("error") for entry in chunk_results)
        print(f"  Batch completed: {succeeded}/{len(chunk)} successful")

    print(f"\nProcessing complete! Retrieved metadata for {len(results)} PMIDs")
    return results

# Function to create a flattened dataframe for CSV output
def create_flattened_dataframe(metadata_list):
    """
    Flatten metadata dictionaries into one row per PMID for CSV export

    Besides the scalar fields, each row gets columns for up to five authors
    (author<N>_last, author<N>_fore and, when present, author<N>_affiliation)
    and up to ten MeSH terms (mesh<N>).

    Args:
        metadata_list: List of metadata dictionaries; each must contain "pmid"

    Returns:
        Pandas DataFrame with one row per metadata dictionary
    """
    rows = []

    for record in metadata_list:
        authors = record.get("authors", [])
        mesh_terms = record.get("mesh_terms", [])

        row = {
            "pmid": record["pmid"],
            "title": record.get("title", ""),
            "abstract": record.get("abstract", ""),
            "publication_date": record.get("publication_date", ""),
            "edat": record.get("edat", ""),
            "authors_count": len(authors),
            "mesh_terms_count": len(mesh_terms),
            "error": record.get("error", ""),
        }

        # Only the first five authors get their own columns
        for position, author in enumerate(authors[:5], start=1):
            row[f"author{position}_last"] = author.get("last_name", "")
            row[f"author{position}_fore"] = author.get("fore_name", "")
            if "affiliation" in author:
                row[f"author{position}_affiliation"] = author["affiliation"]

        # Only the first ten MeSH terms get their own columns
        row.update(
            {f"mesh{position}": term for position, term in enumerate(mesh_terms[:10], start=1)}
        )

        rows.append(row)

    return pd.DataFrame(rows)

# Function to save results and provide download links
def save_and_download_results(metadata, base_filename="pubmed_metadata"):
    """
    Write the metadata to disk and print the commands needed to download it

    Two files are created in the current working directory:
    "<base_filename>.csv" (flattened table) and "<base_filename>_full.json"
    (the metadata list as given).

    Args:
        metadata: List of metadata dictionaries
        base_filename: Base filename for the output files

    Returns:
        None
    """
    csv_filename = f"{base_filename}.csv"
    json_filename = f"{base_filename}_full.json"

    # Flattened table for spreadsheet use
    create_flattened_dataframe(metadata).to_csv(csv_filename, index=False)
    print(f"Saved CSV file: {csv_filename}")

    # Complete nested structure, including all authors and MeSH terms
    with open(json_filename, 'w') as json_file:
        json.dump(metadata, json_file, indent=2)
    print(f"Saved full JSON file: {json_filename}")

    # The download itself is left to the user; only the commands are printed
    download_instructions = (
        "\nDownload the results by running the following commands in separate cells:",
        "from google.colab import files",
        f"files.download('{csv_filename}')",
        f"files.download('{json_filename}')",
    )
    for line in download_instructions:
        print(line)

# Main function to run the entire process
def main():
    """
    Run the full pipeline: print the settings, load PMIDs, fetch, save

    Stops early with a message when no PMIDs could be loaded.
    """
    separator = "=" * 50
    limit_label = 'All' if LIMIT is None else LIMIT
    banner = (
        "PubMed Metadata Retriever for Google Colab",
        separator,
        f"Using email: {EMAIL}",
        f"Batch size: {BATCH_SIZE}",
        f"Rate limit delay: {RATE_LIMIT_DELAY} seconds",
        f"Processing limit: {limit_label} PMIDs",
        separator,
    )
    for line in banner:
        print(line)

    # Step 1: Load PMIDs (prompts for an upload because no path is passed)
    pmids = load_pmids()
    if not pmids:
        print("No PMIDs found. Exiting.")
        return

    # Steps 2 and 3: fetch the metadata, then write it out
    save_and_download_results(process_pmids(pmids, BATCH_SIZE, LIMIT))

    print("\nProcess completed successfully!")

# Run the main function only when this file is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()