# In [None]:
import pandas as pd
import time
import random
import numpy as np
import os
import re
from collections import Counter
from openai import OpenAI, RateLimitError, APIError

# --- CONFIGURATION ---
# Absolute paths: the input table, then the two result CSVs this script appends to.
INPUT_DATA_PATH = "/data/users_data/mli13/LLMvalidationAD112025/drug_disease_long_format_20250325_190703.csv"
ABSTRACT_OUTPUT_PATH = "/data/users_data/mli13/LLMvalidationAD112025/result112025/realtime_abstract_analysis.csv"
SUMMARY_OUTPUT_PATH = "/data/users_data/mli13/LLMvalidationAD112025/result112025/realtime_summary_analysis.csv"
# File whose (whitespace-stripped) text content is the OpenAI API key.
API_KEY_PATH = '/data/users_data/mli13/LLMvalidationAD112025/openai_api.txt'


# Build the shared OpenAI client used by call_gpt_with_retry.
with open(API_KEY_PATH, 'r') as key_file:
    api_key = key_file.read().strip()
client = OpenAI(api_key=api_key)

# Module-level default only; the functions below take `mode` as an explicit parameter.
mode = "response"

# --- HELPER FUNCTIONS ---

def _collect_synonyms(row, prefix, count):
    """Return the usable values of columns ``{prefix}1`` .. ``{prefix}{count}`` as strings.

    A column is skipped when it is absent from *row*, null (per ``pd.notnull``),
    or otherwise falsy (e.g. an empty string).
    """
    synonyms = []
    for i in range(1, count + 1):
        col = f"{prefix}{i}"
        if col in row and pd.notnull(row[col]) and row[col]:
            synonyms.append(str(row[col]))
    return synonyms


def construct_comprehensive_prompt(row, abstract, max_drug_synonyms=10, max_disease_synonyms=10):
    """Build the single-abstract efficacy-classification prompt for one drug–disease row.

    The prompt tells the model to judge, using ONLY the delimited text (drug info,
    disease info and the abstract), whether the drug is effective against the
    disease, and to answer as ``Result: Positive|Neutral|Negative`` followed by an
    ``Explanation:``.

    Args:
        row: Mapping-like record (e.g. a pandas Series) with ``Drug_name`` and
            ``Disease_name``. The optional drug fields ``description``,
            ``mechanism_of_action``, ``protein_binding``, ``pharmacodynamics``,
            ``category`` and the ``drug_synonymN`` / ``disease_synonymN`` columns
            are included only when present and non-empty.
        abstract: Text placed in the ``[Abstract]`` section.
        max_drug_synonyms: Highest N checked for ``drug_synonymN`` columns (default 10).
        max_disease_synonyms: Highest N checked for ``disease_synonymN`` columns (default 10).

    Returns:
        The complete prompt string.
    """
    drug_name = row['Drug_name']
    disease_name = row['Disease_name']

    # Task statement and the instruction to rely solely on the delimited text.
    parts = [
        f"You are an expert biomedical researcher. Your task is to determine whether {drug_name} is effective "
        f"against {disease_name}.\n\n"
        "CRITICAL INSTRUCTION: Answer strictly based ONLY on the provided text below. "
        "Do not use outside knowledge or prior training data regarding this drug or disease. "
        "If the abstract does not explicitly state the drug is effective for this specific disease, classify as Neutral.\n"
        "--- BEGIN PROVIDED TEXT ---\n"
        f"[Drug Info]: {drug_name}\n"
    ]

    # Optional drug fields, rendered as "Title Case Name: value".
    for field in ('description', 'mechanism_of_action', 'protein_binding', 'pharmacodynamics', 'category'):
        if field in row and pd.notnull(row[field]) and row[field]:
            parts.append(f"{field.replace('_', ' ').title()}: {str(row[field])}\n")

    drug_synonyms = _collect_synonyms(row, "drug_synonym", max_drug_synonyms)
    if drug_synonyms:
        parts.append(f"Drug Synonyms: {', '.join(drug_synonyms)}\n")

    parts.append("\nDISEASE INFORMATION:\n")
    parts.append(f"Disease: {disease_name}\n")

    disease_synonyms = _collect_synonyms(row, "disease_synonym", max_disease_synonyms)
    if disease_synonyms:
        parts.append(f"Disease Synonyms: {', '.join(disease_synonyms)}\n")

    parts.append(f"[Abstract]: {abstract}\n")
    parts.append("--- END PROVIDED TEXT ---\n\n")

    # Required answer format; parse_assessment_output looks for the
    # "Result:" and "Explanation:" labels requested here.
    parts.append(
        "Based solely on the text delimited above, provide your assessment in the following format:\n\n"
        "Result: [Positive/Neutral/Negative] \n"
        "(Definitions: \n"
        "- Positive: The text explicitly indicates the drug is effective.\n"
        "- Neutral: The text is uncertain, provides insufficient evidence, or describes a study with no clear conclusion.\n"
        "- Negative: The text indicates the drug is harmful or ineffective.)\n\n"
        "Explanation: Provide a brief explanation in 2-3 sentences citing specific evidence strictly from the provided abstract."
    )
    return "".join(parts)


# Modes accepted by call_gpt_with_retry.
_GPT_MODES = ("chat", "response", "response_web")


def _extract_response_text(response):
    """Return the stripped text of the first text-bearing part of a ``message`` output item.

    Scans ``response.output`` in order, skipping non-message items (e.g. tool calls).
    Returns ``"No output found in response structure."`` when no such part exists.
    """
    for item in response.output:
        if getattr(item, "type", None) != "message":
            continue
        for content_item in getattr(item, "content", None) or []:
            if hasattr(content_item, "text"):
                return content_item.text.strip()
    return "No output found in response structure."


def call_gpt_with_retry(prompt, mode, max_retries=5):
    """Send *prompt* to OpenAI and return the model's text, retrying failed attempts.

    Modes:
        ``"chat"``         - Chat Completions API, model ``gpt-4o``, temperature 0.1.
        ``"response"``     - Responses API, model ``gpt-4.1``, temperature 0.2.
        ``"response_web"`` - Responses API, model ``gpt-4o`` with the
                             ``web_search_preview`` tool, temperature 0.1.

    After a failed attempt the call sleeps ``2**attempt`` seconds plus up to 1 s of
    random jitter, plus an extra 20 s for ``RateLimitError`` or 5 s for other
    ``APIError``s, then tries again.

    Returns:
        The stripped response text; ``"No output found in response structure."``
        when a Responses API result contains no message text; a string starting
        with ``"API call error after"`` once all *max_retries* attempts failed; or
        ``"Failed after maximum retry attempts"`` when *max_retries* < 1.

    Raises:
        ValueError: if *mode* is not a supported mode. Checked before any request
            so a typo fails immediately instead of being retried with backoff.
    """
    if mode not in _GPT_MODES:
        raise ValueError("Invalid mode selected.")

    for attempt in range(max_retries):
        try:
            if mode == "chat":
                completion = client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.1,  # Lower temperature for more consistent results
                )
                return completion.choices[0].message.content.strip()
            if mode == "response":
                response = client.responses.create(
                    model="gpt-4.1",
                    input=prompt,
                    temperature=0.2,  # Lower temperature for more consistent results
                )
            else:  # "response_web"
                response = client.responses.create(
                    model="gpt-4o",
                    tools=[{"type": "web_search_preview"}],
                    input=prompt,
                    temperature=0.1,  # Lower temperature for more consistent results
                )
            return _extract_response_text(response)
        # RateLimitError must precede APIError: it is the more specific error here.
        except RateLimitError as e:
            label, extra_wait, error = "Rate limit error", 20, e  # longer wait for rate limits
        except APIError as e:
            label, extra_wait, error = "API error", 5, e
        except Exception as e:
            label, extra_wait, error = "Unexpected error", 0, e

        print(f"{label} (attempt {attempt+1}/{max_retries}): {error}")

        # Out of attempts: report the last error instead of raising.
        if attempt == max_retries - 1:
            return f"API call error after {max_retries} attempts: {str(error)}"

        # Exponential backoff with jitter.
        wait_time = (2 ** attempt) + random.uniform(0, 1) + extra_wait
        print(f"Waiting {wait_time:.2f} seconds before retrying...")
        time.sleep(wait_time)

    return "Failed after maximum retry attempts"

# Compiled once. Both label patterns tolerate Markdown emphasis around the label
# ("**Result:** Positive", "**Result**: Positive"), and the Result pattern also
# tolerates square brackets copied from the prompt template ("Result: [Positive]").
_RESULT_PATTERN = re.compile(
    r'result\**\s*[-:]\s*[*\[\s]*(positive|neutral|negative)', re.IGNORECASE
)
_EXPLANATION_PATTERN = re.compile(
    r'explanation\**\s*:\s*\**\s*(.*?)(?:\n\n|\n*$)', re.IGNORECASE | re.DOTALL
)
# Fallback: any text following a bare label such as "Positive. ..." or "neutral: ...".
_LABEL_THEN_TEXT_PATTERN = re.compile(
    r'(positive|neutral|negative)[.:]\s*(.*?)(?:\n\n|\n*$)', re.IGNORECASE | re.DOTALL
)


def parse_assessment_output(output):
    """Extract ``(result, explanation)`` from a model answer.

    Expected format::

        Result: Positive|Neutral|Negative
        Explanation: <text>

    Separators ``:`` or ``-``, any letter case, Markdown bold and bracketed labels
    are accepted. The explanation runs up to the first blank line.

    Returns:
        ``result`` is ``"Positive"``, ``"Neutral"``, ``"Negative"`` or
        ``"Unknown"``. ``explanation`` is the parsed explanation. When neither a
        result nor an explanation could be parsed, the whole *output* is returned
        as the explanation, and the result falls back to a keyword guess
        (Positive/Negative if only one of those words occurs, otherwise Neutral if
        "neutral" or "insufficient evidence" occurs). ``None`` or empty output
        gives ``("Unknown", "")``. Output starting with ``"API call error"`` gives
        ``("Unknown", output)`` and is never keyword-classified.
    """
    if not output:
        return "Unknown", ""

    if output.startswith("API call error"):
        # A failed request carries no assessment; keep the error text for inspection.
        return "Unknown", output

    result = "Unknown"
    explanation = ""

    result_match = _RESULT_PATTERN.search(output)
    if result_match:
        result = result_match.group(1).capitalize()

    explanation_match = _EXPLANATION_PATTERN.search(output)
    if explanation_match:
        explanation = explanation_match.group(1).strip()
    else:
        after_label = _LABEL_THEN_TEXT_PATTERN.search(output)
        if after_label:
            explanation = after_label.group(2).strip()

    # Nothing parsed: keep the raw text and make a last keyword-based guess.
    if result == "Unknown" and not explanation:
        explanation = output
        lowered = output.lower()
        has_positive = "positive" in lowered
        has_negative = "negative" in lowered
        if has_positive and not has_negative:
            result = "Positive"
        elif has_negative and not has_positive:
            result = "Negative"
        elif "neutral" in lowered or "insufficient evidence" in lowered:
            result = "Neutral"

    return result, explanation

# --- MAIN ANALYSIS FUNCTION ---

# Column layouts of the two output CSVs; appended rows follow the same order.
_ABSTRACT_COLUMNS = ["Drug_ID", "Drug_name", "Disease_ID", "Disease_name", "PubMed_ID", "Title", "Model", "Result", "Explanation", "Raw_Output"]
_SUMMARY_COLUMNS = ["Drug_ID", "Drug_name", "Disease_ID", "Disease_name", "Model", "Total_Abstracts", "Analyzed", "Positive", "Neutral", "Negative", "Overall_Assessment"]
# Columns that together identify one drug-disease-model group.
_GROUP_KEYS = ['Drug_ID', 'Disease_ID', 'Model']


def _standardize_key(val):
    """Return 'Missing_Value' for NaN/None or the strings 'none'/'null'/'' (any case); else *val* unchanged."""
    if pd.isna(val) or (isinstance(val, str) and val.lower() in ('none', 'null', '')):
        return 'Missing_Value'
    return val


def _load_processed_cache(path):
    """Load rows already saved in *path* into ``{(Drug_name, Disease_name, PubMed_ID): row_dict}``.

    Key parts are ``str()``-converted to match the keys built during processing.
    On a read error a message is printed and whatever was loaded so far (possibly
    nothing) is returned, so the run continues best-effort.
    """
    cache = {}
    print("Loading existing results to resume...")
    try:
        existing_df = pd.read_csv(path, on_bad_lines='skip')
        for _, row in existing_df.iterrows():
            key = (str(row['Drug_name']), str(row['Disease_name']), str(row['PubMed_ID']))
            cache[key] = row.to_dict()
        print(f"Loaded {len(cache)} previously processed abstracts.")
    except Exception as e:
        print(f"Could not read existing file: {e}. Starting fresh.")
    return cache


def _read_existing_lines(path):
    """Return the set of whitespace-stripped lines in *path*, or an empty set if it cannot be read."""
    try:
        with open(path, encoding='utf-8') as fh:
            return {line.strip() for line in fh}
    except (OSError, UnicodeDecodeError):
        return set()


def _filter_by_name(df, column, name_filter):
    """Return the rows of *df* whose *column* matches *name_filter*.

    - falsy filter (None, "", []): *df* itself is returned unfiltered;
    - str: case-insensitive ``Series.str.contains`` match (regex by default);
      rows with a missing name do not match;
    - list/tuple/set: exact membership via ``Series.isin``.

    Raises:
        TypeError: for any other filter type.
    """
    if not name_filter:
        return df
    if isinstance(name_filter, str):
        # na=False: a missing name is a non-match instead of a NaN that breaks boolean indexing.
        return df[df[column].str.contains(name_filter, case=False, na=False)]
    if isinstance(name_filter, (list, tuple, set)):
        return df[df[column].isin(name_filter)]
    raise TypeError(
        f"Filter for {column!r} must be a str or a list/tuple/set of names, "
        f"got {type(name_filter).__name__}"
    )


def _overall_assessment(pos_pct, neg_pct, neu_pct):
    """Collapse per-label percentages into one verdict.

    A label with at least 60 % wins (checked Positive, Negative, Neutral in that
    order). Otherwise "Likely Positive"/"Likely Negative" applies when that label
    leads both others by more than 20 points, else "Inconclusive".
    """
    if pos_pct >= 60:
        return "Positive"
    if neg_pct >= 60:
        return "Negative"
    if neu_pct >= 60:
        return "Neutral"
    if pos_pct > neu_pct + 20 and pos_pct > neg_pct + 20:
        return "Likely Positive"
    if neg_pct > neu_pct + 20 and neg_pct > pos_pct + 20:
        return "Likely Negative"
    return "Inconclusive"


def analyze_drug_disease_abstracts(df, mode="response", filter_disease=None, filter_drug=None, batch_size=5, max_abstracts_per_pair=50):
    """Classify each abstract of each drug-disease-model group with the LLM and save results to CSV.

    Steps:
      1. Resume: if ABSTRACT_OUTPUT_PATH exists, its rows are loaded into a cache
         keyed by (Drug_name, Disease_name, PubMed_ID). Otherwise both output CSVs
         are created with header rows only.
      2. Filter *df* by disease name, then drug name (see ``_filter_by_name``).
      3. Group rows by (Drug_ID, Disease_ID, Model) after mapping missing-looking
         IDs (NaN, 'none', 'null', '') to 'Missing_Value'. Such rows form their own
         group instead of being dropped. Groups are processed in order of
         Disease_name, then Drug_name.
      4. For the first *max_abstracts_per_pair* rows of each group (input order),
         reuse a cached result or build a prompt, call the model, parse the answer
         and append the row to ABSTRACT_OUTPUT_PATH immediately. A missing abstract,
         or one shorter than 50 characters, is replaced by ``"TITLE ONLY: <title>"``.
      5. Append one summary row per group to SUMMARY_OUTPUT_PATH, unless an
         identical row is already in that file (e.g. from an earlier, resumed run).

    Args:
        df: Input rows. Must contain Drug_ID, Disease_ID, Model, Drug_name,
            Disease_name, pubmed_id and pubmed_id_count columns; the optional prompt
            fields are used when present.
        mode: Passed through to ``call_gpt_with_retry``.
        filter_disease, filter_drug: Optional name filters (str, or list/tuple/set).
        batch_size: Rows per batch. A 1-second pause follows a batch that made at
            least one API call when more rows remain in the group. Values below 1
            are treated as 1.
        max_abstracts_per_pair: Maximum number of rows analyzed per group.

    Note:
        The cache key ignores Drug_ID/Disease_ID/Model, so groups that share drug
        name, disease name and PMID reuse a single LLM result.
    """
    # 1. OUTPUT FILES + RESUME CACHE
    for out_path in (ABSTRACT_OUTPUT_PATH, SUMMARY_OUTPUT_PATH):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    if os.path.exists(ABSTRACT_OUTPUT_PATH):
        processed_cache = _load_processed_cache(ABSTRACT_OUTPUT_PATH)
    else:
        processed_cache = {}
        pd.DataFrame(columns=_ABSTRACT_COLUMNS).to_csv(ABSTRACT_OUTPUT_PATH, index=False)
        pd.DataFrame(columns=_SUMMARY_COLUMNS).to_csv(SUMMARY_OUTPUT_PATH, index=False)

    # CSV lines already in the summary file, used to avoid appending duplicates on resume.
    written_summary_lines = _read_existing_lines(SUMMARY_OUTPUT_PATH)

    # 2. FILTERING
    df_filtered = _filter_by_name(df, 'Disease_name', filter_disease)
    df_filtered = _filter_by_name(df_filtered, 'Drug_name', filter_drug)
    print(f"Processing {len(df_filtered)} rows after filtering")

    # 3. GROUPING on standardized keys. groupby drops NaN keys by default, so without
    # this mapping rows with a missing ID would never be analyzed.
    df_for_grouping = df_filtered.copy()
    for key_col in _GROUP_KEYS:
        df_for_grouping[key_col] = df_for_grouping[key_col].apply(_standardize_key)

    # sort=False: keys may mix 'Missing_Value' strings with numbers; order is set below.
    grouped = df_for_grouping.groupby(_GROUP_KEYS, sort=False)
    pairs_with_count = grouped.agg({
        'Drug_name': 'first',
        'Disease_name': 'first',
        'pubmed_id_count': 'first',
        'pubmed_id': 'count'
    }).reset_index()
    pairs_with_count.rename(columns={'pubmed_id': 'available_abstracts'}, inplace=True)
    pairs_with_count = pairs_with_count.sort_values(['Disease_name', 'Drug_name'])

    print(f"Found {len(pairs_with_count)} unique drug-disease-model pairs")

    batch_size = max(1, int(batch_size))

    # 4. PROCESSING LOOP
    for _, pair_row in pairs_with_count.iterrows():
        drug_id = pair_row['Drug_ID']
        disease_id = pair_row['Disease_ID']
        model = pair_row['Model']
        drug_name = pair_row['Drug_name']
        disease_name = pair_row['Disease_name']

        # Direct group lookup instead of re-scanning the whole table for every pair.
        group = grouped.get_group((drug_id, disease_id, model)).head(max_abstracts_per_pair)

        print(f"\nAnalyzing {drug_name} for {disease_name} ({len(group)} abstracts)")

        pair_results = []

        for start in range(0, len(group), batch_size):
            batch_df = group.iloc[start:start + batch_size]
            api_calls = 0
            for _, row in batch_df.iterrows():
                pubmed_id = str(row.get('pubmed_id', 'Unknown'))
                title = row.get('title', 'No title')

                cache_key = (str(drug_name), str(disease_name), pubmed_id)
                if cache_key in processed_cache:
                    # Already analyzed (this run or a previous one): reuse for the stats.
                    pair_results.append(processed_cache[cache_key])
                    continue

                print(f"  Processing new abstract: {pubmed_id}")

                abstract_text = row.get('abstract', '')
                if pd.isnull(abstract_text) or len(str(abstract_text)) < 50:
                    title_text = row.get('title', '')
                    abstract_text = f"TITLE ONLY: {title_text}"

                prompt = construct_comprehensive_prompt(row, abstract_text)
                output = call_gpt_with_retry(prompt, mode)
                api_calls += 1
                result, explanation = parse_assessment_output(output)

                abstract_result = {
                    "Drug_ID": drug_id,
                    "Drug_name": drug_name,
                    "Disease_ID": disease_id,
                    "Disease_name": disease_name,
                    "PubMed_ID": pubmed_id,
                    "Title": title,
                    "Model": model,
                    "Result": result,
                    "Explanation": explanation,
                    "Raw_Output": output
                }

                # Append immediately so completed results survive a crash.
                pd.DataFrame([abstract_result]).to_csv(ABSTRACT_OUTPUT_PATH, mode='a', header=False, index=False)
                pair_results.append(abstract_result)
                # Cache it so the same key later in this run is not sent again.
                processed_cache[cache_key] = abstract_result

            # Pause between batches only if this batch actually called the API.
            if api_calls and start + batch_size < len(group):
                time.sleep(1)

        # 5. STATISTICS + SUMMARY (once the group is done)
        if not pair_results:
            continue

        counts = Counter(str(r.get("Result", "Unknown")).capitalize() for r in pair_results)
        total = len(pair_results)

        pos_pct = (counts.get("Positive", 0) / total) * 100
        neg_pct = (counts.get("Negative", 0) / total) * 100
        neu_pct = (counts.get("Neutral", 0) / total) * 100
        overall = _overall_assessment(pos_pct, neg_pct, neu_pct)

        summary_row = {
            "Drug_ID": drug_id,
            "Drug_name": drug_name,
            "Disease_ID": disease_id,
            "Disease_name": disease_name,
            "Model": model,
            # NOTE(review): post-truncation count, so this always equals "Analyzed";
            # pair_row['available_abstracts'] may be what was intended - confirm.
            "Total_Abstracts": len(group),
            "Analyzed": total,
            "Positive": counts.get("Positive", 0),
            "Neutral": counts.get("Neutral", 0),
            "Negative": counts.get("Negative", 0),
            "Overall_Assessment": overall
        }

        summary_frame = pd.DataFrame([summary_row])
        summary_line = summary_frame.to_csv(header=False, index=False).strip()
        if summary_line in written_summary_lines:
            print("  -> Identical summary row already saved; not appending a duplicate.")
        else:
            summary_frame.to_csv(SUMMARY_OUTPUT_PATH, mode='a', header=not os.path.exists(SUMMARY_OUTPUT_PATH), index=False)
            written_summary_lines.add(summary_line)

        print(f"  -> Stats: Pos: {pos_pct:.1f}% | Neu: {neu_pct:.1f}% | Neg: {neg_pct:.1f}% -> {overall}")

# --- EXECUTION ---
if __name__ == "__main__":
    # Guard clause first: without the input table there is nothing to analyze.
    if not os.path.exists(INPUT_DATA_PATH):
        print(f"Error: Input file not found at {INPUT_DATA_PATH}")
    else:
        print(f"Loading data from {INPUT_DATA_PATH}...")
        df_full = pd.read_csv(INPUT_DATA_PATH)

        # Full run: no name filters, Responses API mode, up to 200 abstracts per group.
        analyze_drug_disease_abstracts(
            df_full,
            mode="response",
            batch_size=5,
            max_abstracts_per_pair=200,
        )
        print("Analysis completed.")