In [None]:
from google.colab import drive

# Mount Google Drive so the notebook can read the input workbook and write results back.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Every input/output workbook for this experiment lives in one Drive folder.
WC_FINAL_DIR = "/content/drive/MyDrive/worker_comp_work/WC_Final"
INPUT_EXCEL_FILE1 = f"{WC_FINAL_DIR}/research_ready_facts_AI_sampled.xlsx"
OUTPUT_EXCEL_FILE1 = f"{WC_FINAL_DIR}/research_ready_facts_AI_sampled_gemini_1.5_flash_002.xlsx"

In [None]:
!pip install -q -U google-generativeai google-api-core pandas==2.2.2

In [None]:
import pandas as pd
import google.generativeai as genai
import os
import time
from google.api_core import exceptions # Import for specific exceptions
from google.colab import userdata

In [None]:
# Read the Gemini API key from Colab's Secrets panel instead of hard-coding it in the notebook.
GEMINI_API_KEY_SECRET_NAME = "GOOGLE_API_KEY"
api_key = userdata.get(GEMINI_API_KEY_SECRET_NAME)

In [None]:
import json

In [None]:
# Fail fast: with a missing key every API call fails and every row would be recorded as "Error".
if not api_key:
    raise ValueError("GOOGLE_API_KEY is empty; add it in Colab's Secrets panel and grant this notebook access.")
genai.configure(api_key=api_key)

# --- Model selection ---
# Each entry is a Gemini model identifier; the main loop runs every listed model over all rows.
# NOTE(review): all models write the same prediction columns and the same OUTPUT_EXCEL_FILE,
# so with more than one entry only the last model's predictions are kept.
MODEL_NAME_LIST = ["gemini-1.5-flash-002"]  # Or "gemini-1.5-flash-latest"

# --- Excel file paths ---
INPUT_EXCEL_FILE = INPUT_EXCEL_FILE1
OUTPUT_EXCEL_FILE = OUTPUT_EXCEL_FILE1

# --- Column names (adjust if different in your Excel) ---
# The spelling ("Annonymized") must match the input workbook's header exactly.
FACTS_COLUMN = "Annonymized_Facts"

# --- Time delays (in seconds); increase these if you hit rate limits ---
MAIN_LOOP_DELAY_SECONDS = 0  # pause between rows
API_CALL_DELAY_SECONDS = 0   # pause before each API call

# --- Generation settings shared by every API call ---
# response_mime_type="application/json" puts Gemini in JSON mode so response.text is JSON text.
# ("top_p": None was removed: it configured nothing and only suggested top_p was being set.)
generation_config = {
    "temperature": 0,  # minimise sampling randomness; not a strict guarantee of identical outputs
    "top_k": 1,        # consider only the single most likely token
    "response_mime_type": "application/json",
}

# Block responses rated medium-or-higher in each harm category.
# NOTE(review): injury descriptions may trip HARM_CATEGORY_DANGEROUS_CONTENT; a blocked response
# presumably ends up as an "Error" row - review error rows before loosening these thresholds.
_SAFETY_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)
safety_settings = [
    {"category": category, "threshold": "BLOCK_MEDIUM_AND_ABOVE"}
    for category in _SAFETY_CATEGORIES
]

# --- Function to get predictions from Gemini ---

# Maps each key requested from the model to the output column it fills.
# "Predicted_Defendent" (sic) is the existing output column name and is kept for compatibility.
_RESPONSE_KEY_TO_COLUMN = {
    "Case_ID": "Predicted_ID",
    "Year": "Predicted_Year",
    "Plaintiff_Name": "Predicted_Plaintiff",
    "Defendant_Name": "Predicted_Defendent",
}


def _filled_predictions(value):
    """Return a predictions dict with every output column set to ``value``."""
    return {column: value for column in _RESPONSE_KEY_TO_COLUMN.values()}


def _select_case_object(parsed):
    """Return the JSON object holding the case fields, or None if there is none.

    JSON mode without a response schema only guarantees *some* JSON value, so the model may
    wrap the object in an array (``[{"Case_ID": ...}]``); the first object element is used then.
    """
    if isinstance(parsed, dict):
        return parsed
    if isinstance(parsed, list):
        return next((item for item in parsed if isinstance(item, dict)), None)
    return None


def _generate_with_retry(model, contents, max_retries, retry_backoff_seconds):
    """Call ``model.generate_content``, retrying ResourceExhausted (429) and ServiceUnavailable (503).

    Waits ``retry_backoff_seconds * 2**attempt`` between attempts and re-raises the last error
    once ``max_retries`` retries are used up. Any other exception propagates immediately.
    """
    attempts = max(0, max_retries) + 1
    for attempt in range(attempts):
        try:
            return model.generate_content(
                contents,
                generation_config=generation_config,
                safety_settings=safety_settings,
            )
        except (exceptions.ResourceExhausted, exceptions.ServiceUnavailable) as e:
            if attempt == attempts - 1:
                raise
            wait_seconds = retry_backoff_seconds * (2 ** attempt)
            print(f"Transient API error ({type(e).__name__}); retrying in {wait_seconds} second(s)...")
            time.sleep(wait_seconds)


def get_gemini_predictions(facts_text, MODEL_NAME, max_retries=3, retry_backoff_seconds=5.0):
    """
    Send anonymized case facts to Google Gemini and ask it to predict the Case ID, Year,
    Plaintiff Name and Defendant Name.

    Args:
        facts_text (str): The text from the 'Annonymized_Facts' column.
        MODEL_NAME (str): The Google Gemini model to use for the prediction.
        max_retries (int): How many times to retry a call rejected with ResourceExhausted
            or ServiceUnavailable before giving up.
        retry_backoff_seconds (float): Wait before the first retry; doubles on each retry.

    Returns:
        dict: Keys 'Predicted_ID', 'Predicted_Year', 'Predicted_Plaintiff',
        'Predicted_Defendent'. Values are the model's answers (a key the model omitted
        becomes "Unknown"), or all "N/A" for empty/non-string input, all "Error" if the
        API call failed, or all "Parsing Error" if the reply was not a usable JSON object.
    """
    if not isinstance(facts_text, str) or not facts_text.strip():
        print("Warning: Empty or invalid facts text provided. Returning default predictions.")
        return _filled_predictions("N/A")

    # Keep the prompt wording stable: changing it changes what the experiment measures.
    system_prompt = """
    You are an expert at extracting case information from legal summaries. Based solely on your internal training data and knowledge (no web search), identify the following for each anonymized workers' compensation case:
    "Case_ID": (string, predicted Case ID, or "Unknown" if not found)
    "Year": (string, predicted Year the case was heard, or "Unknown" if not found)
    "Plaintiff_Name": (string, predicted Plaintiff's Name, or "Unknown" if not found)
    "Defendant_Name": (string, predicted Defendant's Name, or "Unknown" if not found)

    If a piece of information is explicitly stated as anonymized or cannot be confidently extracted, use "Unknown" for that specific key.
    Your response MUST be a JSON object and contain ONLY the JSON object. Do NOT include any other text or explanation.
    """

    user_prompt = f"""
    You are given the following anonymized workers compensation case. Now predict Case ID, Year of this Case heard, Plaintiff Name and Defendant Name.

    Anonymized Case Facts:
    ---
    {facts_text}
    ---
    """

    try:
        if API_CALL_DELAY_SECONDS > 0:
            print(f"Waiting for {API_CALL_DELAY_SECONDS} second(s) before API call...")
            time.sleep(API_CALL_DELAY_SECONDS)

        print(f"Making API call to {MODEL_NAME} for extraction...")
        model = genai.GenerativeModel(MODEL_NAME)
        # Both prompts go in a single user turn. JSON output is requested via
        # response_mime_type in generation_config, not only by the prompt wording.
        response = _generate_with_retry(
            model,
            [{"role": "user", "parts": [system_prompt, user_prompt]}],
            max_retries,
            retry_backoff_seconds,
        )
        print("API call complete.")
        # NOTE(review): response.text presumably raises when the candidate has no text
        # (e.g. blocked by a safety filter); the handler below records that as "Error".
        response_content = response.text.strip()
    except Exception as e:  # best-effort per row: one failed call must not stop the whole sheet
        print(f"An unexpected error occurred during Google Gemini API call for facts: '{facts_text[:100]}...'. Error: {e}")
        return _filled_predictions("Error")

    try:
        parsed_json = json.loads(response_content)
    except json.JSONDecodeError:
        print(f"Warning: Could not parse JSON from response: '{response_content}' for facts: '{facts_text[:100]}...'")
        return _filled_predictions("Parsing Error")

    case_object = _select_case_object(parsed_json)
    if case_object is None:
        print(f"Warning: JSON response is not an object: '{response_content}' for facts: '{facts_text[:100]}...'")
        return _filled_predictions("Parsing Error")

    return {
        column: case_object.get(key, "Unknown")
        for key, column in _RESPONSE_KEY_TO_COLUMN.items()
    }

# --- Main Processing Logic ---

def _to_cell_value(value):
    """Collapse one predicted value into something that fits in a single DataFrame cell.

    The model may answer with a list (e.g. several plaintiffs); only the first element is
    kept. An empty list/tuple becomes "Unknown" rather than an unassignable [].
    """
    if isinstance(value, (list, tuple)):
        return value[0] if value else "Unknown"
    return value


print(f"\nReading Excel file: {INPUT_EXCEL_FILE}")
# Re-raise after reporting: in a notebook, exit() does not stop the cell (IPython's exit only
# asks the kernel to exit), so execution would continue with a missing or stale `df`.
try:
    df = pd.read_excel(INPUT_EXCEL_FILE)
except FileNotFoundError:
    print(f"Error: Input file not found at {INPUT_EXCEL_FILE}")
    raise
except Exception as e:
    print(f"Error reading Excel file: {e}")
    raise
print(f"Successfully read {len(df)} rows.")

if FACTS_COLUMN not in df.columns:
    raise KeyError(f"Column '{FACTS_COLUMN}' not found in the Excel file.")

# Prediction columns written for every row ("Predicted_Defendent" spelling kept for compatibility).
output_columns = ["Predicted_ID", "Predicted_Year", "Predicted_Plaintiff", "Predicted_Defendent"]
for col in output_columns:
    if col in df.columns:
        # Columns carried over from an earlier run may have been read back as numeric;
        # object dtype accepts the string predictions without incompatible-dtype warnings.
        df[col] = df[col].astype(object)
    else:
        df[col] = "Not Processed"  # placeholder until the row is processed

total_rows = len(df)
for selected_model in MODEL_NAME_LIST:
    print(f"\n===== Starting processing for Model: {selected_model} =====")
    # NOTE(review): every model writes the same columns and OUTPUT_EXCEL_FILE, so with more than
    # one model only the last model's predictions survive.
    for run_number in range(1, 2):  # Just one run for prediction
        print(f"\n--- Model: {selected_model}, Run: {run_number} ---")
        print(f"Predicting details into columns: {', '.join(output_columns)}")

        print(f"\nProcessing {total_rows} rows using Google Gemini model: {selected_model} (Run {run_number})...")

        # enumerate gives a 1-based position independent of the DataFrame's index labels.
        for position, (index, row) in enumerate(df.iterrows(), start=1):
            print(f"\n--- Processing row {position} of {total_rows} (Model: {selected_model}, Run: {run_number}) ---")
            raw_facts = row[FACTS_COLUMN]
            facts = str(raw_facts) if pd.notna(raw_facts) else ""

            predictions = get_gemini_predictions(facts, selected_model)

            # Missing keys fall back to "Error"; list answers are reduced to a single cell value.
            for key in output_columns:
                df.loc[index, key] = _to_cell_value(predictions.get(key, "Error"))

            print(f"Row {position}: Predicted ID: {predictions.get('Predicted_ID', 'Error')}, "
                  f"Year: {predictions.get('Predicted_Year', 'Error')}, "
                  f"Plaintiff: {predictions.get('Predicted_Plaintiff', 'Error')}, "
                  f"Defendant: {predictions.get('Predicted_Defendent', 'Error')}")

            if position < total_rows and MAIN_LOOP_DELAY_SECONDS > 0:
                print(f"Waiting for {MAIN_LOOP_DELAY_SECONDS} second(s) before next row...")
                time.sleep(MAIN_LOOP_DELAY_SECONDS)

        print(f"\n--- Finished Run {run_number} for Model: {selected_model} ---")

        # Save after each run; on failure `df` is still in memory and can be saved manually.
        print(f"Saving results to {OUTPUT_EXCEL_FILE}...")
        try:
            df.to_excel(OUTPUT_EXCEL_FILE, index=False)
            print(f"Successfully saved results to: {OUTPUT_EXCEL_FILE}")
        except Exception as e:
            print(f"Error saving results to Excel file: {e}")

    print(f"\n===== Finished all runs for Model: {selected_model} =====")

print("\nAll processing complete. Script finished.")