In [1]:
!pip install ratelimit
!pip install trafilatura
import trafilatura
import csv
import json
import time
import logging
import pickle
import gzip
import glob
import os
import re
from typing import Optional, Dict, Any
from tqdm import tqdm
import pandas as pd
from bs4 import BeautifulSoup
from ratelimit import limits, sleep_and_retry
from pydantic import BaseModel, Field
import google.generativeai as genai




In [2]:
import csv
import json
import time
import logging
import pickle
import gzip
import glob
import os
import re
from typing import Optional, Dict, Any, List # Added List
from tqdm.notebook import tqdm
import pandas as pd
from bs4 import BeautifulSoup
from ratelimit import limits, sleep_and_retry
from pydantic import BaseModel, Field
import google.generativeai as genai
import concurrent.futures # For parallelization

In [3]:
import csv
import json
import time
import logging
import pickle
import gzip
import glob
import os
import re
from typing import Optional, Dict, Any, List

from tqdm import tqdm
import pandas as pd
from bs4 import BeautifulSoup
from ratelimit import limits, sleep_and_retry
from pydantic import BaseModel, Field, ValidationError
import google.generativeai as genai
import concurrent.futures
import trafilatura

# --- Logging and API Configuration ---
# WARNING: Hardcoding API keys is a significant security risk.
GEMINI_API_KEY_HARDCODED = "AIzaSyADS35cYhRoTNH6OCE2TpH4CMCJkeyTAMc"  # Replace with your actual key

logging.basicConfig(
    filename='gemini_processing_errors.log',
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

model = None
gemini_config_settings = {}

try:
    # More robust check for placeholder/example key
    is_placeholder_key = (
        not GEMINI_API_KEY_HARDCODED or
        GEMINI_API_KEY_HARDCODED.startswith("YOUR_") or
        GEMINI_API_KEY_HARDCODED == "AIzaSyADS35cYhRoTNH6OCE2TpH4CMCJkeyTAMc" # Example placeholder
    )
    if is_placeholder_key:
        logger.critical("Using a placeholder or example Gemini API key. Please update it in the script.")
        print("🚨 CRITICAL WARNING: Using a placeholder or example Gemini API key. Please update it in the script.")

    genai.configure(api_key=GEMINI_API_KEY_HARDCODED)
    logger.info("Gemini API configured with hardcoded key.")

    gemini_config_settings = {
        'model_name': 'gemini-1.5-flash-latest',
        'temperature': 0.1,
        'max_output_tokens': 8192,
        'request_timeout': 400,
        'calls_per_period': 15,
        'period_seconds': 60,
        'max_gemini_workers': 1, # Default, can be changed for row-level parallelism
        'max_content_chars': 150000
    }
    # If loading from a config file, update gemini_config_settings here.

    model = genai.GenerativeModel(gemini_config_settings['model_name'])
    logger.info(f"Gemini model '{model.model_name}' initialized.")
    print(f"Gemini model '{model.model_name}' initialized.")

except Exception as e:
    logger.critical(f"Failed to configure Gemini API or initialize model: {e}", exc_info=True)
    print(f"🚨 CRITICAL ERROR: Failed to configure Gemini API or initialize model. Error: {e}")

# --- Data Models ---
class ExtractedData(BaseModel):
    findings: Optional[str] = Field(default=None)
    p_value: Optional[str] = Field(default=None)
    population: Optional[str] = Field(default=None)
    variant_details: Optional[str] = Field(default=None)
    gene_association: Optional[str] = Field(default=None)

# --- Helper Functions ---
@sleep_and_retry
@limits(calls=gemini_config_settings.get('calls_per_period', 15),
        period=gemini_config_settings.get('period_seconds', 60))
def call_gemini_api_with_retries(
    current_logger: logging.Logger,
    current_model_instance: genai.GenerativeModel, # Renamed for clarity
    prompt: str,
    current_gemini_config: dict,
    max_retries=2,
    base_delay=10
):
    if not current_model_instance:
        current_logger.error("Gemini model is not initialized. Cannot make API call.")
        return None

    for attempt in range(max_retries + 1):
        try:
            current_logger.debug(f"Gemini API call attempt {attempt + 1} for prompt (first 100 chars): {prompt[:100]}")
            response = current_model_instance.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(
                    temperature=current_gemini_config.get('temperature', 0.1),
                    max_output_tokens=current_gemini_config.get('max_output_tokens', 8192),
                    response_mime_type='application/json'
                ),
                request_options={'timeout': current_gemini_config.get('request_timeout', 400)}
            )
            current_logger.debug(f"Gemini API response received. Safety: {response.prompt_feedback if response.prompt_feedback else 'N/A'}, Finish Reason: {response.candidates[0].finish_reason if response.candidates else 'N/A'}")

            if response.prompt_feedback and response.prompt_feedback.block_reason:
                block_msg = response.prompt_feedback.block_reason_message or "No specific block message"
                current_logger.error(f"Prompt blocked by Gemini API. Reason: {response.prompt_feedback.block_reason} - {block_msg}. Prompt: {prompt[:200]}")
                return None

            if not response.candidates:
                current_logger.warning(f"No candidates in Gemini response for prompt: {prompt[:200]}")
                if attempt < max_retries:
                    delay_val = base_delay * (2 ** attempt)
                    current_logger.info(f"Retrying due to no candidates in {delay_val}s...")
                    time.sleep(delay_val)
                    continue
                return None

            if response.candidates[0].finish_reason != 1: # 1 is "STOP"
                 current_logger.warning(f"Gemini generation finished with reason: {response.candidates[0].finish_reason}. Content may be incomplete. Safety: {response.candidates[0].safety_ratings}")

            return response
        except Exception as e:
            delay_val = base_delay * (2 ** attempt)
            if "429" in str(e) or "rate limit" in str(e).lower():
                current_logger.warning(f"Gemini API rate limit error (429) on attempt {attempt + 1}: {e}. Applying longer backoff.")
                try:
                    error_response = getattr(e, 'response', None)
                    if error_response and hasattr(error_response, 'headers'):
                        retry_after_seconds = int(error_response.headers.get("Retry-After", current_gemini_config.get('period_seconds', 60)))
                        delay_val = retry_after_seconds + 5
                    else:
                         delay_val = current_gemini_config.get('period_seconds', 60) + (base_delay * (2 ** attempt))
                except:
                     delay_val = current_gemini_config.get('period_seconds', 60) + (base_delay * (2 ** attempt))

            current_logger.error(f"Gemini API error on attempt {attempt + 1}: {e} for prompt: {prompt[:200]}", exc_info=True)
            if attempt < max_retries:
                current_logger.info(f"Retrying Gemini API call (exception) in {delay_val} seconds...")
                time.sleep(delay_val)
            else:
                current_logger.error(f"Gemini API call failed after {max_retries + 1} attempts for prompt: {prompt[:200]}")
                return None
    return None

def clean_json_response(text: str) -> str:
    text = text.strip()
    if text.startswith("```json"): text = text[7:]
    if text.endswith("```"): text = text[:-3]
    if text.lower().startswith("json") and \
       (text[4:].strip().startswith('{') or text[4:].strip().startswith('[')):
        text = text[4:].strip()
    return text

def extract_text_from_html(current_logger: logging.Logger, html_content: str, pmid: str) -> str:
    if not html_content or not isinstance(html_content, str) or len(html_content.strip()) < 50:
        current_logger.warning(f"Insufficient or invalid HTML content for PMID {pmid} for any text extraction.")
        return ""
    try:
        # FIX: Removed output_format, relying on trafilatura's default.
        trafilatura_text = trafilatura.extract(
            html_content, include_comments=False, include_tables=False, favor_recall=True
        )
        if trafilatura_text and len(trafilatura_text.strip()) > 200:
            current_logger.info(f"Trafilatura successfully extracted text for PMID {pmid}. Length: {len(trafilatura_text.strip())}")
            return trafilatura_text.strip()
        else:
            current_logger.info(f"Trafilatura extracted insufficient text for PMID {pmid} (Length: {len(trafilatura_text.strip()) if trafilatura_text else 0}). Falling back to BeautifulSoup.")
    except Exception as e_traf:
        current_logger.warning(f"Trafilatura extraction failed for PMID {pmid}: {e_traf}. Falling back to BeautifulSoup.", exc_info=False)

    try:
        soup = BeautifulSoup(html_content, 'html.parser')
        for tag_name in ['script', 'style', 'header', 'footer', 'nav', 'aside', 'form', 'button', 'input', 'figure', 'figcaption', 'meta', 'link']:
            for tag in soup.find_all(tag_name):
                tag.decompose()
        noisy_selectors = ['.sidebar', '#sidebar', '.related-posts', '.comments-area', '.site-sidebar',
                           '.post-navigation', '.post-meta', '.breadcrumbs', '.site-header', '.site-footer',
                           '.widget', 'div[class*="share"]', 'div[class*="popup"]', 'div[class*="cookie"]',
                           'div[id*="comment"]', 'div[class*="banner"]', 'div[class*="ad"]']
        for selector in noisy_selectors:
            for el in soup.select(selector):
                el.decompose()

        main_content_selectors = ['article', 'main', '[role="main"]', '.main-content', '.article-body',
                                  '.entry-content', '#content', '.td-post-content', '.post-content', '.content',
                                  'div[class*="article__body"]', 'div[class*="story-content"]', 'div.abstract', 'section.abstract']
        for selector in main_content_selectors:
            content_area = soup.select_one(selector)
            if content_area:
                text = content_area.get_text(separator=' ', strip=True)
                if len(text) > 200:
                    current_logger.info(f"BeautifulSoup extracted text for PMID {pmid} using selector '{selector}'. Length: {len(text)}")
                    return text

        current_logger.info(f"No specific main content selector yielded substantial text (BeautifulSoup) for HTML of PMID {pmid}. Using stripped_strings from body or root.")
        body = soup.find('body')
        target_element = body if body else soup

        text_from_stripped = ' '.join(target_element.stripped_strings) if target_element else ""

        if len(text_from_stripped) > 100:
            current_logger.info(f"BeautifulSoup stripped_strings (HTML) succeeded for PMID {pmid}. Length: {len(text_from_stripped)}")
            return text_from_stripped
        else:
            current_logger.warning(f"BeautifulSoup stripped_strings (HTML) also yielded insufficient text for PMID {pmid}. Length: {len(text_from_stripped)}")
            return ""
    except Exception as e_bs:
        current_logger.error(f"Error during BeautifulSoup HTML parsing for PMID {pmid} (after trafilatura): {e_bs}", exc_info=True)
        return ""

def extract_text_from_xml(current_logger: logging.Logger, xml_content: str, pmid: str) -> str:
    if not xml_content or not isinstance(xml_content, str) or len(xml_content.strip()) < 50:
        current_logger.warning(f"Insufficient or invalid XML content for PMID {pmid}.")
        return ""
    try:
        soup = BeautifulSoup(xml_content, 'xml')
        tags_to_remove = ['table-wrap', 'fig', 'graphic', 'media', 'xref', 'disp-formula', 'inline-formula',
                          'contrib-group', 'aff', 'author-notes', 'pub-history', 'permissions', 'license',
                          'funding-group', 'related-article', 'front-stub', 'article-categories',
                          'article-meta', 'journal-meta', 'ref-list', 'back', 'notes', 'app', 'app-group', 'ack']
        for tag_name in tags_to_remove:
            for tag in soup.find_all(tag_name):
                tag.decompose()

        body = soup.find('body')
        if body:
            text_content = body.get_text(separator=' ', strip=True)
            if len(text_content) > 200:
                current_logger.info(f"Extracted text from XML 'body' for PMID {pmid}. Length: {len(text_content)}")
                return text_content

        abstracts = soup.find_all('abstract')
        abstract_text = ""
        for abstract_tag in abstracts:
            abstract_text += abstract_tag.get_text(separator=' ', strip=True) + " "
        abstract_text = abstract_text.strip()

        article_title_tag = soup.find('article-title')
        title_text = article_title_tag.get_text(separator=' ', strip=True) if article_title_tag else ""

        combined_text = (title_text + " " + abstract_text).strip()

        if len(combined_text) > 100:
             current_logger.info(f"Extracted text from XML 'title' and/or 'abstract' for PMID {pmid}. Length: {len(combined_text)}")
             return combined_text

        all_text = ' '.join(soup.stripped_strings)
        if len(all_text) > 100:
            current_logger.info(f"Extracted text from XML (all stripped strings) for PMID {pmid}. Length: {len(all_text)}")
            return all_text

        current_logger.warning(f"Could not extract substantial text from XML for PMID {pmid} (final length: {len(all_text)}).")
        return ""
    except Exception as e_xml:
        current_logger.error(f"Error during BeautifulSoup XML parsing for PMID {pmid}: {e_xml}", exc_info=True)
        return ""

def get_text_from_content_item(current_logger: logging.Logger, pmid: str, content_item: Dict[str, Any]) -> str:
    if not content_item or not isinstance(content_item, dict):
        current_logger.warning(f"Invalid content_item for PMID {pmid}.")
        return ""
    content_type = content_item.get('type')
    content_text = content_item.get('content')
    final_url = content_item.get('final_url', 'N/A')

    if not content_text or not isinstance(content_text, str):
        current_logger.warning(f"No content string or invalid content for PMID {pmid} from {final_url}. Type: {type(content_text)}")
        return ""
    if content_type == 'html':
        current_logger.info(f"Attempting HTML text extraction for PMID {pmid} from {final_url}.")
        return extract_text_from_html(current_logger, content_text, pmid)
    if content_type == 'xml':
        current_logger.info(f"Attempting XML text extraction for PMID {pmid} from {final_url}.")
        return extract_text_from_xml(current_logger, content_text, pmid)
    if content_type == 'error' or content_text == "Failed to retrieve":
        current_logger.warning(f"Content for PMID {pmid} from {final_url} marked as error/failed.")
        return ""
    current_logger.warning(f"Unknown content type '{content_type}' for PMID {pmid} from {final_url}. Attempting HTML then XML extraction as fallback.")
    text = extract_text_from_html(current_logger, content_text, pmid)
    if not text or len(text) < 100 :
        text = extract_text_from_xml(current_logger, content_text, pmid)
    return text

# --- NEW: Process a single CSV row ---
def process_single_csv_row_with_gemini(
    current_logger: logging.Logger,
    current_model_instance: genai.GenerativeModel,
    pmid: str,
    original_row_data: Dict[str, Any], # Single row from CSV as a dictionary
    content_dict: Dict[str, Dict[str, Any]],
    content_cache: Dict[str, str],
    current_gemini_config: dict
) -> Dict[str, Any]: # Returns a single augmented row dictionary

    augmented_row = original_row_data.copy()
    augmented_row.update({"findings": None, "p_value": None, "population": None,
                          "variant_details": None, "gene_association": None})

    paper_content_item = content_dict.get(pmid)
    if not paper_content_item:
        current_logger.warning(f"No content item found in content_dict for PMID {pmid} (target row: {original_row_data.get('Gene')}/{original_row_data.get('Variant')}). Skipping Gemini.")
        return augmented_row

    if pmid not in content_cache:
        text_for_llm = get_text_from_content_item(current_logger, pmid, paper_content_item)
        content_cache[pmid] = text_for_llm
    else:
        text_for_llm = content_cache[pmid]

    title_for_prompt = str(original_row_data.get('Title', 'No Title Provided')).strip()
    current_text_length = len(text_for_llm) if isinstance(text_for_llm, str) else 0

    if not text_for_llm or current_text_length < 100:
        current_logger.warning(f"Insufficient extracted text for PMID {pmid} (length: {current_text_length}). Using title for analysis of target: {original_row_data.get('Gene')}/{original_row_data.get('Variant')}.")
        text_for_llm = title_for_prompt
        current_title_length = len(text_for_llm) if isinstance(text_for_llm, str) else 0
        if not text_for_llm or current_title_length < 50:
            current_logger.error(f"Both extracted text and title are insufficient for PMID {pmid} (target: {original_row_data.get('Gene')}/{original_row_data.get('Variant')}). Skipping Gemini.")
            return augmented_row

    if not isinstance(text_for_llm, str):
        current_logger.error(f"text_for_llm is not a string for PMID {pmid}. Type: {type(text_for_llm)}. Skipping Gemini.")
        return augmented_row

    max_chars = current_gemini_config.get('max_content_chars', 150000)
    if len(text_for_llm) > max_chars:
        current_logger.info(f"Truncating content for PMID {pmid} from {len(text_for_llm)} to {max_chars} chars for prompt.")
        text_for_llm = text_for_llm[:max_chars]

    current_logger.info(f"Content length for PMID {pmid} (target: {original_row_data.get('Gene')}/{original_row_data.get('Variant')}) to Gemini: {len(text_for_llm)}")

    # Construct prompt for a single target item
    gene = str(original_row_data.get('Gene', 'N/A')).strip()
    variant_val = original_row_data.get('Variant')
    variant_str = str(variant_val).strip() if pd.notna(variant_val) and str(variant_val).strip() else "not specified"
    alleles_val = original_row_data.get('Alleles')
    alleles_str = str(alleles_val).strip() if pd.notna(alleles_val) and str(alleles_val).strip() else "not specified"

    item_to_analyze_prompt = f"""The specific item to analyze from this paper is:
- Gene: {gene}
- Variant: {variant_str}
- Alleles: {alleles_str}
"""

    prompt = f"""Analyze the scientific paper (PMID: {pmid}, Title: "{title_for_prompt}") provided below, focusing ONLY on the specific item listed.

{item_to_analyze_prompt}

Output a single JSON object. The JSON object MUST contain these exact keys: "findings" (main conclusion about the item), "p_value" (e.g., "0.03", "<0.001", "NS", or null if not found), "population" (study population description, or null if not found), "variant_details" (string summarizing variant info like type, location, predicted effect, if applicable; otherwise use the original variant string if provided and relevant, or null), "gene_association" (strength or type of association, e.g., "strong association", "suggestive link", "no association", "functional impact noted", or null if not found).
If information for a key is not found for the specific item, use a JSON null value for that key.
The "variant_details" field MUST be a string. Do NOT use markdown for the JSON output.

Paper Content:
{text_for_llm}
"""
    gemini_response = call_gemini_api_with_retries(current_logger, current_model_instance, prompt, current_gemini_config)

    if not gemini_response or not hasattr(gemini_response, 'text') or not gemini_response.text:
        current_logger.warning(f"No valid response or text from Gemini for PMID {pmid}, target {gene}/{variant_str}.")
        return augmented_row

    try:
        cleaned_json_text = clean_json_response(gemini_response.text)
        extracted_data_object = json.loads(cleaned_json_text) # Expect a single object now

        if not isinstance(extracted_data_object, dict):
            current_logger.error(f"Gemini did not return a JSON object for PMID {pmid}, target {gene}/{variant_str}. Response: {cleaned_json_text[:500]}")
            return augmented_row

        # Validate and update the single augmented_row
        try:
            if isinstance(extracted_data_object.get('variant_details'), dict):
                current_logger.info(f"Gemini returned dict for variant_details (PMID {pmid}, target {gene}/{variant_str}). Converting to string.")
                extracted_data_object['variant_details'] = json.dumps(extracted_data_object['variant_details'])

            validated_item_data = ExtractedData(**extracted_data_object)

            current_variant_details = validated_item_data.variant_details
            if current_variant_details is None: # Fallback to original variant if Gemini provides null
                original_variant_from_row = augmented_row.get('Variant')
                if pd.notna(original_variant_from_row) and str(original_variant_from_row).strip() and str(original_variant_from_row).strip().lower() != "not specified":
                    current_variant_details = str(original_variant_from_row).strip()

            augmented_row.update({
                "findings": validated_item_data.findings,
                "p_value": validated_item_data.p_value,
                "population": validated_item_data.population,
                "variant_details": current_variant_details,
                "gene_association": validated_item_data.gene_association
            })
            current_logger.info(f"Successfully processed PMID {pmid}, target {gene}/{variant_str} with Gemini.")
        except ValidationError as e_val:
            current_logger.error(f"Pydantic validation error for PMID {pmid}, target {gene}/{variant_str}: {e_val}. JSON object: {extracted_data_object}", exc_info=True)
        except Exception as e_other:
            current_logger.error(f"Error updating row for PMID {pmid}, target {gene}/{variant_str}: {e_other}. JSON object: {extracted_data_object}", exc_info=True)

    except json.JSONDecodeError as e_json:
        raw_response_text = gemini_response.text if gemini_response and hasattr(gemini_response, 'text') else 'Raw response text not available'
        cleaned_text_for_error = locals().get('cleaned_json_text', raw_response_text)
        current_logger.error(f"JSON parsing error for PMID {pmid}, target {gene}/{variant_str}: {e_json}. Snippet: {cleaned_text_for_error[:500]}", exc_info=True)
    except Exception as e_gen:
        current_logger.error(f"General error processing Gemini response for PMID {pmid}, target {gene}/{variant_str}: {e_gen}", exc_info=True)

    return augmented_row


def main(input_csv_path: Optional[str] = None):
    global model
    global gemini_config_settings
    global logger # Use the globally configured logger

    if not model:
        logger.critical("Gemini model was not initialized globally. Terminating main function.")
        print("🚨 CRITICAL ERROR: Gemini model not initialized. Cannot proceed.")
        return

    logger.info("--- Starting Step 3: Gemini Insight Extraction (Row-by-Row) ---")

    if not input_csv_path:
        script_dir = os.path.dirname(os.path.abspath(__file__)) if '__file__' in locals() else os.getcwd()
        potential_dirs = [os.getcwd(), os.path.join(script_dir, "data/output"), os.path.join(script_dir, "..", "data/output")]
        csv_files = []
        for d_path in potential_dirs: # Renamed loop variable
            if os.path.isdir(d_path):
                csv_files.extend(glob.glob(os.path.join(d_path, 'pubmed_genetic_results_*.csv')))

        if not csv_files:
            logger.error("No pubmed_genetic_results_*.csv files found in standard locations.")
            print("Error: No input CSV (pubmed_genetic_results_*.csv) found.")
            return
        input_csv_path = max(csv_files, key=os.path.getctime)

    logger.info(f"Using input CSV file: {input_csv_path}")
    print(f"Using input CSV file: {input_csv_path}")

    content_dict_filename = "content_dict.pkl.gz"
    content_dict_full_path = content_dict_filename
    if not os.path.exists(content_dict_full_path):
         script_dir_cd = os.path.dirname(os.path.abspath(__file__)) if '__file__' in locals() else os.getcwd() # Renamed variable
         potential_paths_cd = [ # Renamed variable
             os.path.join(script_dir_cd, "data/output", content_dict_filename),
             os.path.join(script_dir_cd, "..", "data/output", content_dict_filename)
         ]
         for p_path_cd in potential_paths_cd: # Renamed loop variable
             if os.path.exists(p_path_cd):
                 content_dict_full_path = p_path_cd
                 break

    if not os.path.exists(content_dict_full_path):
        logger.error(f"Content dictionary '{content_dict_filename}' not found in standard locations.")
        print(f"Error: Content dictionary '{content_dict_filename}' not found.")
        return
    logger.info(f"Using content dictionary: {content_dict_full_path}")

    try:
        df = pd.read_csv(input_csv_path)
        if 'pmid' in df.columns and 'PMID' not in df.columns:
            df.rename(columns={'pmid': 'PMID'}, inplace=True)
        if 'PMID' not in df.columns:
            logger.error(f"CSV {input_csv_path} must contain 'PMID' or 'pmid' column.")
            print(f"Error: CSV {input_csv_path} needs 'PMID' or 'pmid' column.")
            return
        df['PMID'] = df['PMID'].astype(str)
    except Exception as e:
        logger.error(f"Error reading CSV {input_csv_path}: {e}", exc_info=True)
        print(f"Error reading CSV: {e}")
        return

    try:
        with gzip.open(content_dict_full_path, 'rb') as f_gz:
            content_dict = {str(k): v for k, v in pickle.load(f_gz).items()}
        logger.info(f"Loaded {len(content_dict)} entries from {content_dict_full_path}")
        print(f"Loaded {len(content_dict)} from {content_dict_full_path}")
    except Exception as e:
        logger.error(f"Error loading {content_dict_full_path}: {e}", exc_info=True)
        print(f"Error loading {content_dict_full_path}: {e}")
        return

    valid_pmids_with_content = set()
    for pmid_key, item_data in content_dict.items():
        if isinstance(item_data, dict) and \
           item_data.get('type') != 'error' and \
           isinstance(item_data.get('content'), str) and \
           item_data.get('content').strip():
            valid_pmids_with_content.add(str(pmid_key))

    original_row_count_df = len(df)
    df_filtered = df[df['PMID'].isin(valid_pmids_with_content)].copy()
    filtered_row_count_df = len(df_filtered) # Renamed variable

    if filtered_row_count_df < original_row_count_df:
        msg = (f"Filtered CSV from {original_row_count_df} to {filtered_row_count_df} rows. "
               f"{original_row_count_df - filtered_row_count_df} rows were removed because their PMIDs "
               f"lacked valid, non-empty content in '{content_dict_full_path}'.")
        print(f"INFO: {msg}")
        logger.info(msg)

    if df_filtered.empty:
        msg = "No rows remain to process after filtering for PMIDs with valid content. Exiting."
        print(msg)
        logging.info(msg)
        return

    # --- Testing Slice: Process only the first 120 rows ---
    num_rows_to_test = 120
    if len(df_filtered) > num_rows_to_test:
        logger.info(f"TESTING MODE: Processing only the first {num_rows_to_test} rows of the {len(df_filtered)} filtered rows.")
        print(f"INFO: TESTING MODE - Will process only the first {num_rows_to_test} rows.")
        df_to_process = df_filtered.head(num_rows_to_test).copy()
    else:
        df_to_process = df_filtered.copy()
    logger.info(f"Number of CSV rows to process: {len(df_to_process)}")
    print(f"Number of CSV rows to process: {len(df_to_process)}")


    max_gemini_workers = gemini_config_settings.get('max_gemini_workers', 1)

    logger.info(f"Starting Gemini processing for {len(df_to_process)} CSV rows. Using {max_gemini_workers} worker(s).")
    print(f"Starting Gemini processing for {len(df_to_process)} CSV rows. Using {max_gemini_workers} worker(s).")

    all_processed_results_final: List[Dict[str, Any]] = []
    text_extraction_cache: Dict[str, str] = {}
    run_timestamp = time.strftime("%Y%m%d_%H%M%S")

    tasks = []
    logger.info("Creating tasks for individual CSV row processing...")
    # Each task now represents a single original CSV row.
    # If multiple CSV rows share the same PMID, the paper's text will be fetched once (and cached),
    # but a separate Gemini query will be made for each CSV row's specific gene/variant target.
    for index, row_data_series in df_to_process.iterrows():
        row_data_dict = row_data_series.to_dict()
        pmid_val = str(row_data_dict['PMID']) # Renamed to avoid conflict
        tasks.append((logger, model, pmid_val, row_data_dict, content_dict, text_extraction_cache, gemini_config_settings))

    if max_gemini_workers > 1 and len(tasks) > 0 :
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_gemini_workers) as executor:
            # Each task is now for a single row, so future_to_original_row might be more fitting
            future_to_original_row = {
                executor.submit(process_single_csv_row_with_gemini, *task_args): task_args[3] # task_args[3] is original_row_data (a dict)
                for task_args in tasks
            }

            for future in tqdm(concurrent.futures.as_completed(future_to_original_row), total=len(tasks), desc="Processing CSV rows (Gemini Parallel)", unit="row"):
                original_row_if_error = future_to_original_row[future]
                pmid_processed = str(original_row_if_error.get('PMID', "UnknownPMID"))
                gene_processed = str(original_row_if_error.get('Gene', "UnknownGene"))

                try:
                    augmented_row_dict = future.result() # Expecting a single dict
                    all_processed_results_final.append(augmented_row_dict)
                except Exception as exc:
                    logger.error(f"PMID {pmid_processed}, Gene {gene_processed} generated an exception in ThreadPool: {exc}", exc_info=True)
                    # Append the original row with None for Gemini fields on error
                    error_row = original_row_if_error.copy()
                    error_row.update({"findings": None, "p_value": None, "population": None, "variant_details": None, "gene_association": None})
                    all_processed_results_final.append(error_row)
    else:
        for task_args in tqdm(tasks, total=len(tasks), desc="Processing CSV rows (Gemini Sequential)", unit="row"):
            augmented_row_dict = process_single_csv_row_with_gemini(*task_args)
            all_processed_results_final.append(augmented_row_dict)

    script_dir_for_output_main = os.path.dirname(os.path.abspath(__file__)) if '__file__' in locals() else os.getcwd() # Renamed variable
    base_output_dir = os.path.join(script_dir_for_output_main, "data/output") if os.path.isdir(os.path.join(script_dir_for_output_main, "data")) else script_dir_for_output_main
    os.makedirs(base_output_dir, exist_ok=True)

    final_json_output_path = os.path.join(base_output_dir, f'gemini_extracted_row_by_row_{run_timestamp}.json')
    try:
        with open(final_json_output_path, 'w', encoding='utf-8') as f_final_json:
            json.dump(all_processed_results_final, f_final_json, indent=4, ensure_ascii=False)
        logger.info(f"Saved final {len(all_processed_results_final)} augmented rows to {final_json_output_path}")
        print(f"\nSaved final {len(all_processed_results_final)} augmented rows to {final_json_output_path}")
    except Exception as e_save_final_json:
        logger.error(f"Error saving final JSON results to {final_json_output_path}: {e_save_final_json}", exc_info=True)
        print(f"\nError saving final JSON results: {e_save_final_json}")


    final_csv_output_path = os.path.join(base_output_dir, f'gemini_extracted_row_by_row_{run_timestamp}.csv')
    if all_processed_results_final:
        try:
            final_df = pd.DataFrame(all_processed_results_final)
            original_cols_from_input_df = list(df.columns)
            gemini_added_cols = ["findings", "p_value", "population", "variant_details", "gene_association"]

            ordered_cols = [col for col in original_cols_from_input_df if col in final_df.columns]
            for col in gemini_added_cols:
                if col in final_df.columns and col not in ordered_cols:
                    ordered_cols.append(col)
            for col in final_df.columns:
                if col not in ordered_cols:
                    ordered_cols.append(col)

            if ordered_cols:
                final_df = final_df[ordered_cols]
            else:
                logger.warning("Could not determine column order for final CSV. Using default DataFrame order.")

            final_df.to_csv(final_csv_output_path, index=False, encoding='utf-8')
            logger.info(f"Saved final {len(all_processed_results_final)} augmented rows to {final_csv_output_path}")
            print(f"Saved final {len(all_processed_results_final)} augmented rows to {final_csv_output_path}")
        except Exception as e_csv:
            logger.error(f"Could not save final results to CSV {final_csv_output_path}: {e_csv}", exc_info=True)
            print(f"Error saving final results to CSV: {e_csv}. JSON available at {final_json_output_path}")
    else:
        logger.info("No results processed to save to CSV.")
        print("No results processed to save to CSV.")

    logger.info("--- Step 3: Gemini Insight Extraction (Row-by-Row) Finished ---")

if __name__ == "__main__":
    main()

CRITICAL:__main__:Using a placeholder or example Gemini API key. Please update it in the script.


Gemini model 'models/gemini-1.5-flash-latest' initialized.
Using input CSV file: /content/pubmed_genetic_results_68a3f3d2.csv
Loaded 923 from content_dict.pkl.gz
INFO: Filtered CSV from 4695 to 3499 rows. 1196 rows were removed because their PMIDs lacked valid, non-empty content in 'content_dict.pkl.gz'.
INFO: TESTING MODE - Will process only the first 120 rows.
Number of CSV rows to process: 120
Starting Gemini processing for 120 CSV rows. Using 1 worker(s).


Processing CSV rows (Gemini Sequential):  48%|████▊     | 57/120 [03:24<02:08,  2.04s/row]ERROR:tornado.access:503 POST /v1beta/models/gemini-1.5-flash-latest:generateContent?%24alt=json%3Benum-encoding%3Dint (127.0.0.1) 533.38ms
Processing CSV rows (Gemini Sequential): 100%|██████████| 120/120 [07:34<00:00,  3.79s/row]


Saved final 120 augmented rows to /content/gemini_extracted_row_by_row_20250531_174830.json
Saved final 120 augmented rows to /content/gemini_extracted_row_by_row_20250531_174830.csv



