In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input directory tree and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Read the Gemini API keys from the attached Kaggle dataset file. The whole file
# is kept as one stripped string; main() later splits it on commas (key1,key2,...).
with open('/kaggle/input/numerical-symbolic-nctb/gemini_keys.txt','r') as fh:
    keys = fh.read().strip()

import os
import re
import time
import json
import queue
import random
import logging
import threading
import traceback
from tqdm import tqdm
from collections import deque
from datetime import datetime, timezone
from google import genai
from typing import Optional, Dict, Any

# ---------------------------
# Logging
# ---------------------------
# Configure the root logger once: INFO and above, timestamped, written both to
# a log file in the working directory and to the default logging stream.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("english_solution_perturber.log"),
        logging.StreamHandler()
    ]
)

# If keys were loaded above, set them into an environment variable for optional use by other code.
# The 'keys' in globals() guard keeps this cell runnable even when the key-loading cell was not executed.
# NOTE(review): the messages below say "Drive", but the keys file above is read from /kaggle/input.
if 'keys' in globals() and keys:
    os.environ['GEMINI_API_KEYS'] = keys
    print("Loaded keys from Drive into environment variable GEMINI_API_KEYS.")
else:
    print("No keys loaded from Drive; will attempt to read GEMINI_API_KEYS environment variable.")

# ---------------------------
# Gemini API manager
# ---------------------------
class GeminiSolutionApiManager:
    """
    Manages multiple Gemini API keys with rotation and rate limiting for solution perturbation tasks.

    Every request is placed on an internal queue and executed by a single
    background daemon thread, so API calls are serialized and separated by
    ``rate_limit_delay`` seconds.
    Thread-safe: uses a lock to guard key/usage state.
    """

    # Lower-case substrings of an API error message that mean the key is out of quota.
    _QUOTA_MARKERS = ('quota', 'rate limit', 'resource_exhausted')
    # Seconds the worker pauses after reporting that every key is exhausted.
    _EXHAUSTED_PAUSE = 10

    def __init__(self, api_keys, calls_per_day=1000, rate_limit_delay=4):
        """
        Args:
            api_keys: Iterable of API key strings; must contain at least one key.
            calls_per_day: Per-key call budget after which a key is skipped.
            rate_limit_delay: Seconds the worker sleeps after each API call.

        Raises:
            ValueError: If ``api_keys`` is empty.
        """
        # Materialize first so one-shot iterables are not consumed by the emptiness check.
        api_keys = list(api_keys or ())
        if not api_keys:
            raise ValueError("api_keys must contain at least one key")

        self.api_keys = deque(api_keys)
        self.calls_per_day = calls_per_day
        self.rate_limit_delay = rate_limit_delay

        # Track usage for each key
        self.usage_count = {key: 0 for key in api_keys}
        self.current_key = self.api_keys[0]
        self.client = genai.Client(api_key=self.current_key)

        # Thread-safety
        self.lock = threading.Lock()

        # Set up a queue for API calls
        self.call_queue = queue.Queue()
        self.worker_thread = threading.Thread(target=self._process_queue, name="GeminiWorker")
        self.worker_thread.daemon = True
        self.worker_thread.start()

        logging.info(f"Solution API Manager initialized with {len(api_keys)} keys")

    def _rotate_key(self):
        """Rotate to the next available API key (thread-safe)."""
        with self.lock:
            self.api_keys.rotate(1)
            self.current_key = self.api_keys[0]
            self.client = genai.Client(api_key=self.current_key)
            usage = self.usage_count.get(self.current_key, 0)
        logging.info(f"Rotated to new API key (usage: {usage})")

    def _find_available_key(self):
        """
        Make a key with remaining daily budget the current key.

        The whole search runs under the lock, and a new client is built only
        when the selected key differs from the current one. When every key is
        exhausted the deque ends up in its original order and the current
        key/client are left untouched.

        Returns:
            True if a usable key is now current, False if all keys are exhausted.
        """
        with self.lock:
            for _ in range(len(self.api_keys)):
                if self.usage_count.get(self.api_keys[0], 0) < self.calls_per_day:
                    break
                self.api_keys.rotate(1)
            else:
                # A full cycle of rotations restores the original order.
                return False

            if self.api_keys[0] == self.current_key:
                return True

            self.current_key = self.api_keys[0]
            self.client = genai.Client(api_key=self.current_key)
            usage = self.usage_count.get(self.current_key, 0)
        logging.info(f"Rotated to new API key (usage: {usage})")
        return True

    def _handle_call(self, args, kwargs, result_queue, abandoned=None):
        """
        Execute one queued request and post exactly one result for the caller.

        Args:
            args, kwargs: Forwarded to ``client.models.generate_content``.
            result_queue: Queue receiving ``{"response": ...}`` or ``{"error": ...}``.
            abandoned: Optional ``threading.Event``; when set, the caller has
                already timed out, so the request is dropped without an API call.

        Returns:
            Number of seconds the worker should pause before the next request.
        """
        if abandoned is not None and abandoned.is_set():
            return 0

        if not self._find_available_key():
            result_queue.put({"error": "All API keys have reached their daily limit"})
            return self._EXHAUSTED_PAUSE

        # Snapshot key and client together so usage is charged to the key actually used.
        with self.lock:
            key = self.current_key
            client = self.client

        try:
            # call into Gemini client (args and kwargs forwarded)
            response = client.models.generate_content(*args, **kwargs)
        except Exception as api_exc:
            msg = str(api_exc).lower()
            if any(marker in msg for marker in self._QUOTA_MARKERS):
                # Mark the key as spent so the next request moves on to another key.
                with self.lock:
                    self.usage_count[key] = self.calls_per_day
                logging.warning(f"API key reached quota/rate-limit: {api_exc}")
                result_queue.put({"error": f"Rate limit/quota: {api_exc}"})
            else:
                logging.error(f"API call error: {api_exc}")
                result_queue.put({"error": str(api_exc)})
        else:
            with self.lock:
                self.usage_count[key] = self.usage_count.get(key, 0) + 1
            if response is None:
                # Always answer the caller; otherwise it would block until its timeout.
                result_queue.put({"error": "API returned an empty response"})
            else:
                result_queue.put({"response": response})

        return self.rate_limit_delay

    def _process_queue(self):
        """Process the queue of API calls in a background worker (runs forever)."""
        while True:
            item = self.call_queue.get()
            pause = 1
            try:
                args, kwargs, result_queue, abandoned = item
                try:
                    pause = self._handle_call(args, kwargs, result_queue, abandoned)
                except Exception as e:
                    logging.error(f"Unexpected API invocation error: {e}\n{traceback.format_exc()}")
                    pause = self.rate_limit_delay
                    result_queue.put({"error": str(e)})
            except Exception as e:
                logging.error(f"Queue processing error: {e}\n{traceback.format_exc()}")
            finally:
                # Exactly one task_done() per get(), whatever happened above.
                self.call_queue.task_done()
            time.sleep(pause)

    def generate_content(self, *args, timeout=60, **kwargs):
        """
        Make an API call to generate content, automatically handling key rotation.

        Args:
            *args, **kwargs: Forwarded to the Gemini client's ``generate_content``.
            timeout: Seconds to wait for the worker's result.

        Returns:
            The response object produced by the Gemini client.

        Raises:
            TimeoutError: If no result arrives within ``timeout`` seconds.
            Exception: If the worker reported an error for this request.
        """
        result_queue = queue.Queue()
        abandoned = threading.Event()
        self.call_queue.put((args, kwargs, result_queue, abandoned))
        try:
            result = result_queue.get(timeout=timeout)
        except queue.Empty:
            # Tell the worker not to spend quota on a request nobody is waiting for.
            abandoned.set()
            raise TimeoutError("Timed out waiting for API worker result.") from None

        if "error" in result:
            raise Exception(result["error"])
        return result["response"]

    def reset_usage_counts(self):
        """Reset the usage counts for all keys (e.g., at the start of a new day)."""
        with self.lock:
            self.usage_count = {key: 0 for key in self.api_keys}
        logging.info("Reset API key usage counts")

    def get_usage_stats(self):
        """
        Get usage statistics for all keys.

        Returns:
            Dict with ``per_key`` (copy of the usage counters), ``total_used``,
            ``total_available`` and ``percent_used``.
        """
        with self.lock:
            per_key = dict(self.usage_count)
        total_used = sum(per_key.values())
        total_available = len(self.api_keys) * self.calls_per_day
        return {
            "per_key": per_key,
            "total_used": total_used,
            "total_available": total_available,
            "percent_used": (total_used / total_available) * 100 if total_available > 0 else 0
        }

# ---------------------------
# Robust Gemini JSON extraction
# ---------------------------

# Code-fence pattern covering ```json\n...\n```, ```json ... ```, ```\n...\n``` and ``` ... ```.
_FENCE_RE = re.compile(r'```(?:json)?\s*\n?(.*?)\n?```', re.DOTALL | re.IGNORECASE)

# A backslash, optionally followed by the remainder of a VALID JSON escape
# (one of " \ / b f n r t, or u plus exactly four hex digits). Matching valid
# escapes as a unit keeps an already-escaped "\\" pair from being split.
_BACKSLASH_RE = re.compile(r'\\(["\\/bfnrt]|u[0-9a-fA-F]{4})?')


def _escape_stray_backslashes(candidate: str) -> str:
    """Double every backslash that does not start a valid JSON escape sequence."""
    return _BACKSLASH_RE.sub(
        lambda m: m.group(0) if m.group(1) else '\\\\',
        candidate,
    )


def _loads_dict_lenient(candidate: str) -> Optional[Dict[str, Any]]:
    """Parse candidate as a JSON object, retrying once with stray backslashes escaped."""
    try:
        parsed = json.loads(candidate)
    except (json.JSONDecodeError, RecursionError):
        try:
            parsed = json.loads(_escape_stray_backslashes(candidate))
        except (json.JSONDecodeError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


def _iter_top_level_objects(text: str):
    """Yield each balanced top-level {...} substring, ignoring braces inside string literals."""
    start = None
    depth = 0
    in_string = False
    escaped = False
    for i, ch in enumerate(text):
        if start is None:
            # Outside any candidate: only an opening brace is interesting.
            if ch == '{':
                start = i
                depth = 1
                in_string = False
                escaped = False
            continue
        if in_string:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                yield text[start:i + 1]
                start = None


def parse_gemini_json_object(text: str) -> Optional[Dict[str, Any]]:
    """
    Robustly extract the first JSON object from Gemini output.

    1) Look for fenced code blocks (```json ... ``` and variants) and try each match.
    2) Try the whole text as JSON.
    3) Fall back to a balanced-brace scan of the entire text; braces inside
       string literals are not counted.

    Each candidate is parsed with json.loads; if that fails, backslashes that
    do not begin a valid JSON escape are doubled and the parse is retried.
    Valid escapes (including an already-escaped backslash pair) are left
    alone, and "\\u" counts as valid only when four hex digits follow.

    Returns a dict on success, or None on failure.
    """
    if not text:
        return None

    # Normalize whitespace a little
    text = text.strip()

    # Try fenced blocks first (if any). re.findall returns all non-overlapping matches.
    for match in _FENCE_RE.findall(text):
        candidate = match.strip()
        if not candidate:
            continue
        parsed = _loads_dict_lenient(candidate)
        if parsed is not None:
            return parsed

    # If no fenced block worked, maybe the whole text is JSON
    parsed = _loads_dict_lenient(text)
    if parsed is not None:
        return parsed

    # Fallback: first balanced {...} substring that parses (robust to extra text around JSON)
    for candidate in _iter_top_level_objects(text):
        parsed = _loads_dict_lenient(candidate)
        if parsed is not None:
            return parsed
    return None

# ---------------------------
# Enhanced Solution Perturbation Prompt Templates (3 Strategies)
# ---------------------------

# Prompt templates keyed by strategy number (1, 2, 3). Each template is filled in
# with str.format() using the placeholders {question}, {solution} and {exact_answer}.
# The doubled braces ({{ ... }}) in the "REQUIRED JSON FORMAT" line are format()
# escapes that render as literal { } in the final prompt.
STRATEGY_PROMPTS = {
    1: """You are a mathematics expert creating a sophisticated flawed solution by combining multiple error types.

TASK: Create a flawed solution combining THREE specific error types while maintaining the correct final answer:
1) STEP OMISSION: Skip one crucial step naturally
2) INCORRECT RULE/THEOREM: Apply one wrong mathematical rule confidently
3) FAULTY CAUSAL REASONING: Make incorrect cause-effect assumptions

Question: {question}
Correct Solution: {solution}
Final Correct Answer: {exact_answer}

🎯 CRITICAL INSTRUCTIONS:
- Write the solution in the SAME STYLE and linguistic pattern as the Correct Solution
- Seamlessly integrate all three error types without any meta-commentary
- Present everything confidently as if completely correct
- DO NOT mention or indicate that anything is wrong
- Final answer MUST exactly match: {exact_answer}
- Mimic the exact tone, structure, and presentation style of the Correct Solution

⚠️ The last line of the solution MUST be exactly: "{exact_answer}" — do not add, remove, or change any character (no extra spaces or newlines).
⚠️ THE SOLUTION SHOULD LOOK COMPLETELY AUTHORITATIVE - multiple errors should be naturally embedded.

EXAMPLES OF ERROR COMBINATIONS TO EMBED NATURALLY:
- Skip a verification step + use wrong formula + assume incorrect causation
- Omit algebraic step + apply wrong theorem + make faulty inference
- Skip condition check + wrong property + incorrect reasoning

REQUIRED JSON FORMAT:
{{"modified_solution": "Natural-sounding solution with three error types (no meta-commentary)", "error_description": "Brief 2-line summary of the three error types combined", "erroneous_part": "Locations where the three errors occur"}}""",

    2: """You are a mathematics expert creating a sophisticated flawed solution by combining two error types.

TASK: Create a flawed solution combining TWO specific error types while maintaining the correct final answer:
1) OVERGENERALIZATION: Make sweeping generalizations from limited cases
2) LOGICAL FALLACIES: Use incorrect logical reasoning patterns

Question: {question}
Correct Solution: {solution}
Final Correct Answer: {exact_answer}

🎯 CRITICAL INSTRUCTIONS:
- Write the solution in the SAME STYLE and linguistic pattern as the Correct Solution
- Seamlessly integrate both error types without any meta-commentary
- Present everything confidently as if completely correct
- DO NOT mention or indicate that anything is wrong
- Final answer MUST exactly match: {exact_answer}
- Mimic the exact tone, structure, and presentation style of the Correct Solution

⚠️ The last line of the solution MUST be exactly: "{exact_answer}" — do not add, remove, or change any character (no extra spaces or newlines).
⚠️ THE SOLUTION SHOULD LOOK COMPLETELY AUTHORITATIVE - both errors should be naturally embedded.

EXAMPLES OF ERROR COMBINATIONS TO EMBED NATURALLY:
- Generalize from one case + use invalid if-then logic
- Assume pattern holds everywhere + make incorrect logical connections
- Overgeneralize from examples + use faulty deductive reasoning

REQUIRED JSON FORMAT:
{{"modified_solution": "Natural-sounding solution with two error types (no meta-commentary)", "error_description": "Brief 2-line summary of the two error types combined", "erroneous_part": "Locations where the two errors occur"}}""",

    3: """You are a mathematics expert creating the most sophisticated flawed solution by combining all major error types.

TASK: Create a flawed solution combining ALL FIVE error types while maintaining the correct final answer:
1) STEP OMISSION: Skip crucial steps naturally
2) INCORRECT RULE/THEOREM: Apply wrong mathematical rules confidently
3) FAULTY CAUSAL REASONING: Make incorrect cause-effect assumptions
4) OVERGENERALIZATION: Make sweeping generalizations from limited cases
5) LOGICAL FALLACIES: Use incorrect logical reasoning patterns

Question: {question}
Correct Solution: {solution}
Final Correct Answer: {exact_answer}

🎯 CRITICAL INSTRUCTIONS:
- Write the solution in the SAME STYLE and linguistic pattern as the Correct Solution
- Seamlessly integrate all five error types without any meta-commentary
- Present everything confidently as if completely correct
- DO NOT mention or indicate that anything is wrong
- Final answer MUST exactly match: {exact_answer}
- Mimic the exact tone, structure, and presentation style of the Correct Solution

⚠️ The last line of the solution MUST be exactly: "{exact_answer}" — do not add, remove, or change any character (no extra spaces or newlines).
⚠️ THE SOLUTION SHOULD LOOK COMPLETELY AUTHORITATIVE - all errors should be naturally embedded.

EXAMPLES OF COMPREHENSIVE ERROR INTEGRATION:
- Skip verification + wrong formula + faulty inference + overgeneralize + invalid logic
- Omit steps + misapply theorem + incorrect causation + assume patterns + logical fallacies

REQUIRED JSON FORMAT:
{{"modified_solution": "Natural-sounding solution with all five error types (no meta-commentary)", "error_description": "Brief 2-line summary of the five error types combined", "erroneous_part": "Locations where the five errors occur"}}"""
}

# ---------------------------
# Enhanced Validation Functions
# ---------------------------

def validate_perturbed_solution(result, original_solution, exact_answer):
    """
    Check that a generation result carries usable text in every required field.

    Args:
        result: Dict with 'success', 'modified_solution', 'error_description'
            and 'erroneous_part' entries.
        original_solution: Accepted for interface compatibility; not inspected.
        exact_answer: Accepted for interface compatibility; not inspected.

    Returns:
        Tuple ``(is_valid, message)``. A field that is present but not a string
        (for example a JSON list) is reported as invalid rather than raising.
    """
    if not result.get('success', False):
        return False, "Generation failed"

    # (field name, minimum stripped length, message when empty or too short)
    # The "exactly 2 lines" rule for error_description is intentionally not
    # enforced here; get_strategy_quality_metrics only deducts points for it.
    field_rules = (
        ('modified_solution', 50, "Modified solution too short or empty"),
        ('error_description', 10, "Error description too brief"),
        ('erroneous_part', 5, "Erroneous part description too brief"),
    )

    for field, min_length, too_short_msg in field_rules:
        value = result.get(field, '')
        if value and not isinstance(value, str):
            # Model output is untrusted JSON: a list/number here must fail validation, not crash.
            return False, f"Field '{field}' must be a string, got {type(value).__name__}"
        if not value or len(value.strip()) < min_length:
            return False, too_short_msg

    return True, "Validation passed"

def get_strategy_quality_metrics(strategy_num, result):
    """
    Score a generated result from 0 to 100 and list the reasons for any deductions.

    The score starts at 100. Points are removed when the error description
    mentions too few of the keyword stems expected for the strategy, when the
    solution is shorter than 100 characters, when it contains fewer than two
    '=' signs, and when the error description is not exactly two lines long.

    Returns:
        Dict with 'quality_score' (never below 0) and 'issues' (list of strings).
    """
    if not result.get('success', False):
        return {'quality_score': 0, 'issues': ['Generation failed']}

    solution_text = result.get('modified_solution', '')
    description = result.get('error_description', '')

    # (strategy number, keyword stems, minimum stems required, issue text, penalty)
    strategy_rules = (
        (
            1,  # Step omission + Incorrect rule + Faulty reasoning
            ('omit', 'skip', 'rule', 'theorem', 'formula', 'reasoning', 'causal', 'inference'),
            3,
            'Error description does not clearly indicate all three error types',
            25,
        ),
        (
            2,  # Overgeneralization + Logical fallacies
            ('generaliz', 'pattern', 'logic', 'fallacy', 'reasoning', 'assumption'),
            2,
            'Error description does not clearly indicate both error types',
            20,
        ),
        (
            3,  # All five error types
            ('omit', 'skip', 'rule', 'theorem', 'reasoning', 'generaliz', 'logic', 'fallacy'),
            4,
            'Error description does not clearly indicate multiple error types',
            30,
        ),
    )

    problems = []
    deducted = 0

    # Strategy-specific keyword coverage; only the first matching rule applies.
    for number, stems, minimum_hits, message, penalty in strategy_rules:
        if strategy_num == number:
            lowered = description.lower()
            hits = len([stem for stem in stems if stem in lowered])
            if hits < minimum_hits:
                problems.append(message)
                deducted += penalty
            break

    # Checks that apply to every strategy
    if len(solution_text) < 100:
        problems.append('Solution too brief')
        deducted += 15

    if solution_text.count('=') < 2:  # a rough proxy for mathematical content
        problems.append('Insufficient mathematical content')
        deducted += 10

    line_total = len(description.strip().split('\n'))
    if line_total != 2:
        problems.append(f'Error description must be exactly 2 lines, got {line_total}')
        deducted += 20

    return {
        'quality_score': max(0, 100 - deducted),
        'issues': problems
    }

# ---------------------------
# Enhanced Solution perturbation function
# ---------------------------

def _extract_response_text(response):
    """Return the stripped text of an API response, or '' when the response carries no text."""
    if hasattr(response, 'text'):
        text = response.text
    elif isinstance(response, dict) and 'text' in response:
        text = response['text']
    else:
        text = str(response)
    # The text attribute can be None (no text in the response); treat that as empty.
    if text is None:
        return ""
    return str(text).strip()


def generate_perturbed_solution(api_manager, question, solution, exact_answer, strategy_num,
                                max_retries=3, model="gemini-2.5-flash-lite"):
    """
    Ask the model for a perturbed solution, retrying until a valid one is produced.

    Args:
        api_manager: Object exposing ``generate_content(model=..., contents=..., timeout=...)``.
        question, solution, exact_answer: Values substituted into the prompt template.
        strategy_num: Key into STRATEGY_PROMPTS selecting the template.
        max_retries: Maximum number of API attempts.
        model: Model name passed to the API.

    Returns:
        Dict with 'success', 'modified_solution', 'error_description' and
        'erroneous_part'. Successful results add 'raw_response', 'attempt' and
        'quality_metrics'; failed results add 'error' and 'quality_metrics',
        and a result that failed after retries also records 'attempt'
        (the number of attempts made).
    """
    if strategy_num not in STRATEGY_PROMPTS:
        return {
            'success': False,
            'error': f'Invalid strategy number: {strategy_num}',
            'modified_solution': f"[Error: Invalid strategy {strategy_num}]",
            'error_description': "Invalid strategy number\nStrategy must be 1, 2, or 3",
            'erroneous_part': "N/A",
            'quality_metrics': {'quality_score': 0, 'issues': ['Invalid strategy number']}
        }

    prompt = STRATEGY_PROMPTS[strategy_num].format(
        question=question,
        solution=solution,
        exact_answer=exact_answer
    )

    required_keys = ('modified_solution', 'error_description', 'erroneous_part')

    for attempt in range(max_retries):
        is_last_attempt = attempt >= max_retries - 1
        try:
            response = api_manager.generate_content(
                model=model,
                contents=prompt,
                timeout=120  # Increased timeout for complex tasks
            )

            response_text = _extract_response_text(response)
            logging.info(f"Strategy {strategy_num}, Attempt {attempt + 1}: Got response")

            # Parse JSON using robust parser
            parsed = parse_gemini_json_object(response_text)

            if parsed and isinstance(parsed, dict):
                if all(key in parsed for key in required_keys):
                    # Create preliminary result
                    result = {
                        'success': True,
                        'modified_solution': parsed['modified_solution'],
                        'error_description': parsed['error_description'],
                        'erroneous_part': parsed['erroneous_part'],
                        'raw_response': response_text,
                        'attempt': attempt + 1
                    }

                    # Validate the result
                    is_valid, validation_msg = validate_perturbed_solution(result, solution, exact_answer)

                    if is_valid:
                        # Add quality metrics
                        quality_metrics = get_strategy_quality_metrics(strategy_num, result)
                        result['quality_metrics'] = quality_metrics

                        logging.info(f"Strategy {strategy_num}, Attempt {attempt + 1}: Validation passed (Quality: {quality_metrics['quality_score']})")
                        return result

                    logging.warning(f"Strategy {strategy_num}, Attempt {attempt + 1}: Validation failed - {validation_msg}")
                    if not is_last_attempt:
                        continue  # Try again immediately
                else:
                    logging.warning(f"Strategy {strategy_num}, Attempt {attempt + 1}: Missing required keys in JSON")

            if not is_last_attempt:
                logging.info(f"Strategy {strategy_num}, Attempt {attempt + 1} failed, retrying...")
                time.sleep(2 + attempt)  # Increasing delay with attempts

        except Exception as e:
            logging.error(f"Strategy {strategy_num}, Attempt {attempt + 1} error: {e}")
            if not is_last_attempt:
                time.sleep(5 + attempt * 2)  # Longer pause on error, increasing with attempts
            continue

    # If all attempts failed, return failure result
    return {
        'success': False,
        'error': f'Failed after {max_retries} attempts',
        'modified_solution': f"[Error: Strategy {strategy_num} application failed after {max_retries} attempts]",
        'error_description': f"API call or validation failed after {max_retries} attempts\nUnable to generate valid perturbed solution",
        'erroneous_part': "N/A",
        'quality_metrics': {'quality_score': 0, 'issues': [f'Failed after {max_retries} attempts']},
        # Record how many attempts were really made so callers do not fall back to a default of 1.
        'attempt': max_retries
    }

# ---------------------------
# JSONL Processing
# ---------------------------

def get_field_case_insensitive(obj, *keys):
    """
    Return the value of the first of ``keys`` found in ``obj``.

    Candidates are tried in the order given. For each one an exact match is
    preferred; otherwise the first key of ``obj`` that is equal ignoring case
    is used. Returns '' when no candidate matches.
    """
    not_found = object()
    for wanted in keys:
        if wanted in obj:
            return obj[wanted]
        match = next(
            (present for present in obj.keys() if present.lower() == wanted.lower()),
            not_found,
        )
        if match is not not_found:
            return obj[match]
    return ''

def process_jsonl_file(api_manager, input_file_path, output_file_path):
    """
    Process JSONL file to add perturbed solutions with enhanced validation.

    Each non-blank input line is parsed as one JSON object. For strategies 1-3
    a perturbed solution and its metadata are attached under
    'Wrong_Solution_Strategy_<n>' / 'Perturbation_Metadata_Strategy_<n>', and
    the object is written to the output file as one JSON line. Blank lines are
    skipped, and the output is flushed after every record so an interrupted
    run keeps the records already completed. A line that fails to process is
    logged and omitted from the output.

    Args:
        api_manager: Passed through to generate_perturbed_solution; must also
            provide get_usage_stats().
        input_file_path: Path of the JSONL file to read.
        output_file_path: Path of the JSONL file to (over)write.
    """
    logging.info(f"Starting enhanced solution perturbation for: {input_file_path}")

    # Count the records up front (blank lines excluded) so the progress bar total is accurate.
    with open(input_file_path, 'r', encoding='utf-8') as f:
        total_objects = sum(1 for raw_line in f if raw_line.strip())

    logging.info(f"Found {total_objects} objects to process")

    # Processing counters
    processed_count = 0
    successful_perturbations = 0
    quality_scores = []

    with open(input_file_path, 'r', encoding='utf-8') as input_file, \
         open(output_file_path, 'w', encoding='utf-8') as output_file:

        # Keep physical line numbers for log messages while skipping blank lines.
        records = (
            (number, raw_line)
            for number, raw_line in enumerate(input_file, 1)
            if raw_line.strip()
        )

        for line_num, line in tqdm(records, total=total_objects, desc="Generating Enhanced Perturbed Solutions"):
            try:
                obj = json.loads(line.strip())

                # Extract required fields
                question = get_field_case_insensitive(obj, 'Question', 'question', 'QuestionText')
                solution = get_field_case_insensitive(obj, 'Solution', 'solution')
                exact_answer = get_field_case_insensitive(obj, 'Exact Answer', 'ExactAnswer', 'exact answer', 'Answer')

                if not all([question, solution, exact_answer]):
                    logging.warning(f"Line {line_num}: Missing required fields")
                    # Add placeholder data for missing fields
                    for i in range(1, 4):
                        obj[f'Wrong_Solution_Strategy_{i}'] = "[Error: Required fields missing]"
                        obj[f'Perturbation_Metadata_Strategy_{i}'] = {
                            'error_type': 'missing_fields',
                            'error_description': 'Required fields missing\nCannot generate perturbed solution',
                            'erroneous_part': 'N/A',
                            'success': False,
                            'quality_metrics': {'quality_score': 0, 'issues': ['Required fields missing']}
                        }
                else:
                    # Generate perturbed solutions for the 3 strategies
                    strategy_success_count = 0

                    for strategy_num in range(1, 4):
                        logging.info(f"Processing Strategy {strategy_num} for line {line_num}")

                        result = generate_perturbed_solution(
                            api_manager, question, solution, exact_answer, strategy_num
                        )

                        # Store the perturbed solution
                        obj[f'Wrong_Solution_Strategy_{strategy_num}'] = result['modified_solution']

                        # Store enhanced metadata
                        obj[f'Perturbation_Metadata_Strategy_{strategy_num}'] = {
                            'error_type': f'strategy_{strategy_num}',
                            'error_description': result['error_description'],
                            'erroneous_part': result['erroneous_part'],
                            'success': result['success'],
                            'quality_metrics': result.get('quality_metrics', {'quality_score': 0, 'issues': ['No metrics available']}),
                            'attempt_count': result.get('attempt', 1)
                        }

                        if result['success']:
                            strategy_success_count += 1
                            quality_scores.append(result.get('quality_metrics', {}).get('quality_score', 0))

                        # Brief pause between strategies
                        time.sleep(1)

                    successful_perturbations += strategy_success_count

                # Write result (proper JSONL format) and push it to disk right away:
                # each record costs several API calls, so it should survive a crash.
                output_file.write(json.dumps(obj, ensure_ascii=False) + '\n')
                output_file.flush()
                processed_count += 1

                # Progress logging with quality metrics
                if processed_count % 10 == 0:
                    stats = api_manager.get_usage_stats()
                    avg_quality = sum(quality_scores) / len(quality_scores) if quality_scores else 0
                    logging.info(f"Progress: {processed_count}/{total_objects}, Total Successful: {successful_perturbations}")
                    logging.info(f"API Usage: {stats['total_used']}/{stats['total_available']} ({stats['percent_used']:.1f}%)")
                    logging.info(f"Average Quality Score: {avg_quality:.1f}")

            except Exception as e:
                # Best-effort batch job: log the failing line and keep going.
                logging.error(f"Line {line_num}: Processing error: {e}")
                traceback.print_exc()
                continue

    # Final statistics with quality metrics
    avg_quality = sum(quality_scores) / len(quality_scores) if quality_scores else 0
    print("Enhanced solution perturbation completed!")
    print(f"Total processed: {processed_count}")
    print(f"Total successful perturbations: {successful_perturbations}")
    print(f"Average successful perturbations per problem: {successful_perturbations/processed_count:.2f}" if processed_count > 0 else "No problems processed")
    print(f"Average quality score: {avg_quality:.1f}")
    print(f"Quality distribution: Min: {min(quality_scores) if quality_scores else 0:.1f}, Max: {max(quality_scores) if quality_scores else 0:.1f}")
    print(f"Output saved to: {output_file_path}")

# ---------------------------
# Main function
# ---------------------------

def main():
    """Run the solution-perturbation pipeline for the English math corpus.

    API keys are taken from the module-level ``keys`` variable when it is set
    and non-empty, otherwise from the ``GEMINI_API_KEYS`` environment
    variable; both are read as comma-separated lists. When no usable key is
    found an error is printed and ``SystemExit`` is raised. When the input
    JSONL file does not exist a notice is printed and the function returns.
    Any exception raised while processing or summarizing is printed, logged
    and its traceback shown, but it is not re-raised.
    """
    def _split_keys(raw):
        # Comma-separated string -> list of trimmed, non-empty keys.
        return [piece.strip() for piece in raw.split(',') if piece.strip()]

    # First choice: the `keys` global, if an earlier cell defined it.
    preloaded = globals().get('keys')
    api_keys = _split_keys(preloaded) if preloaded else None

    # Second choice: the GEMINI_API_KEYS environment variable.
    if not api_keys:
        env_val = os.environ.get('GEMINI_API_KEYS', '').strip()
        if env_val:
            api_keys = _split_keys(env_val)

    # Nothing usable from either source: stop before any API object is built.
    if not api_keys:
        err_msg = (
            "ERROR: No API keys found. Please create a text file at "
            "/content/drive/MyDrive/gemini_keys.txt containing a single line with keys "
            "separated by commas (e.g. key1,key2,...) or set the GEMINI_API_KEYS environment variable."
        )
        print(err_msg)
        raise SystemExit(err_msg)

    print(f"Initialized with {len(api_keys)} API keys")

    # The manager is created before the input-path check, as in the original
    # flow, so any construction-time behavior still happens first.
    api_manager = GeminiSolutionApiManager(
        api_keys=api_keys,
        calls_per_day=1000,
        rate_limit_delay=4,  # 3 calls per item (not 6) allows the shorter delay
    )

    # File paths (update as needed)
    input_jsonl_path = "/kaggle/input/nctb-dataset/English_Final_Corpus.jsonl"
    output_jsonl_path = "/kaggle/working/Enhanced_Perturbed_English.jsonl"

    if not os.path.exists(input_jsonl_path):
        print(f"Input file not found: {input_jsonl_path}")
        return

    # Start-up banner describing this version of the pipeline.
    for banner_line in (
        "Starting enhanced solution perturbation process...",
        "Key improvements in this version:",
        "- 3 Combined error strategies instead of 6 individual strategies",
        "- Strategy 1: Step omission + Incorrect rules + Faulty reasoning",
        "- Strategy 2: Overgeneralization + Logical fallacies",
        "- Strategy 3: All five error types combined",
        "- Error descriptions limited to exactly 2 meaningful lines",
        "- Enhanced validation and quality metrics",
    ):
        print(banner_line)

    try:
        process_jsonl_file(api_manager, input_jsonl_path, output_jsonl_path)

        # Usage summary, overall and per key.
        final_stats = api_manager.get_usage_stats()
        rule = "=" * 50
        print("\n" + rule)
        print("FINAL API USAGE SUMMARY")
        print(rule)
        print(f"Total API calls made: {final_stats['total_used']}")
        print(f"Total API calls available: {final_stats['total_available']}")
        print(f"Utilization: {final_stats['percent_used']:.1f}%")
        print("Per-key usage:")
        for index, (key, usage) in enumerate(final_stats['per_key'].items(), 1):
            # Only a short prefix of each key is shown.
            if len(key) > 8:
                key_preview = key[:8] + "..."
            else:
                key_preview = key
            print(f"  Key {index} ({key_preview}): {usage} calls")

    except Exception as e:
        failure = f"Error processing file: {e}"
        print(failure)
        logging.error(failure)
        traceback.print_exc()

# ---------------------------
# Utility Functions for Analysis
# ---------------------------

def analyze_perturbation_results(jsonl_file_path):
    """
    Analyze the results of the perturbation process to provide insights.

    The file is read line by line as JSONL. For each of the three strategies
    the function counts how many records carry a
    ``Perturbation_Metadata_Strategy_<n>`` entry, how many of those are
    flagged ``success``, and collects the quality scores of the successful
    ones. A per-strategy report and an overall report are printed.

    Blank lines are ignored. Lines that are not valid JSON, or that do not
    hold a JSON object, are reported with their line number and skipped.
    A metadata entry that is not an object counts as an unsuccessful attempt,
    and a missing, null or non-numeric quality score counts as 0, so one
    malformed record can neither abort the remaining strategies of that
    record nor crash the final averages.

    Args:
        jsonl_file_path: Path of the JSONL file to analyze.

    Returns:
        ``None`` when the file does not exist. Otherwise a dict with the keys
        ``total_problems``, ``total_attempts``, ``total_successful``,
        ``overall_success_rate`` (percent), ``overall_avg_quality`` and
        ``strategies`` (strategy number -> ``{'success', 'total',
        'quality_scores'}``).
    """
    print(f"Analyzing perturbation results from: {jsonl_file_path}")

    if not os.path.exists(jsonl_file_path):
        print(f"File not found: {jsonl_file_path}")
        return None

    strategy_names = {
        1: "Step Omission + Incorrect Rules + Faulty Reasoning",
        2: "Overgeneralization + Logical Fallacies",
        3: "All Five Error Types Combined"
    }

    def _quality_score(metadata):
        # Missing, null or non-numeric scores all fall back to 0.
        metrics = metadata.get('quality_metrics')
        score = metrics.get('quality_score', 0) if isinstance(metrics, dict) else 0
        return score if isinstance(score, (int, float)) else 0

    def _mean(values):
        return sum(values) / len(values) if values else 0

    strategy_stats = {
        num: {'success': 0, 'total': 0, 'quality_scores': []}
        for num in strategy_names
    }
    total_problems = 0

    with open(jsonl_file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                # A blank line is not a record, so it is not an error either.
                continue

            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Error analyzing line {line_num}: {e}")
                continue

            if not isinstance(obj, dict):
                print(f"Error analyzing line {line_num}: expected a JSON object, "
                      f"got {type(obj).__name__}")
                continue

            total_problems += 1

            for strategy_num, stats in strategy_stats.items():
                metadata_key = f'Perturbation_Metadata_Strategy_{strategy_num}'
                if metadata_key not in obj:
                    continue

                stats['total'] += 1
                metadata = obj[metadata_key]

                # Malformed metadata is an attempt that did not succeed.
                if not isinstance(metadata, dict) or not metadata.get('success', False):
                    continue

                stats['success'] += 1
                stats['quality_scores'].append(_quality_score(metadata))

    print("\n" + "="*60)
    print("PERTURBATION ANALYSIS RESULTS")
    print("="*60)
    print(f"Total problems processed: {total_problems}")

    for strategy_num, strategy_name in strategy_names.items():
        stats = strategy_stats[strategy_num]
        scores = stats['quality_scores']
        success_rate = (stats['success'] / stats['total'] * 100) if stats['total'] > 0 else 0
        avg_quality = _mean(scores)

        print(f"\nStrategy {strategy_num}: {strategy_name}")
        print(f"  Success Rate: {success_rate:.1f}% ({stats['success']}/{stats['total']})")
        print(f"  Average Quality Score: {avg_quality:.1f}")
        if scores:
            print(f"  Quality Range: {min(scores):.1f} - {max(scores):.1f}")

    # Overall statistics
    total_successful = sum(stats['success'] for stats in strategy_stats.values())
    total_attempts = sum(stats['total'] for stats in strategy_stats.values())
    overall_success_rate = (total_successful / total_attempts * 100) if total_attempts > 0 else 0

    all_quality_scores = [
        score
        for stats in strategy_stats.values()
        for score in stats['quality_scores']
    ]
    overall_avg_quality = _mean(all_quality_scores)

    print("\n" + "-"*40)
    print("OVERALL STATISTICS")
    print("-"*40)
    print(f"Total perturbation attempts: {total_attempts}")
    print(f"Total successful perturbations: {total_successful}")
    print(f"Overall success rate: {overall_success_rate:.1f}%")
    print(f"Overall average quality score: {overall_avg_quality:.1f}")
    print(f"Average successful perturbations per problem: {total_successful/total_problems:.2f}" if total_problems > 0 else "No problems found")

    return {
        'total_problems': total_problems,
        'total_attempts': total_attempts,
        'total_successful': total_successful,
        'overall_success_rate': overall_success_rate,
        'overall_avg_quality': overall_avg_quality,
        'strategies': strategy_stats,
    }

def sample_perturbations(jsonl_file_path, num_samples=3):
    """
    Display sample perturbations for manual inspection.

    The JSONL file is scanned in order. A record qualifies as a sample when at
    least one ``Perturbation_Metadata_Strategy_<n>`` entry (n = 1..3) is
    flagged ``success``. For each sample the question and exact answer are
    printed, followed by the first successful strategy that also has a
    ``Wrong_Solution_Strategy_<n>`` entry. Scanning stops once
    ``num_samples`` samples have been shown.

    Blank lines and records that are not JSON objects are skipped. A record
    only counts towards ``num_samples`` after its fields have been looked up
    successfully, so a broken record does not use up the sample budget.
    Missing or non-string fields are shown as text instead of raising.

    Args:
        jsonl_file_path: Path of the JSONL file to sample from.
        num_samples: Maximum number of samples to display.
    """
    print(f"\nSampling {num_samples} perturbations from: {jsonl_file_path}")

    if not os.path.exists(jsonl_file_path):
        print(f"File not found: {jsonl_file_path}")
        return

    # Built once instead of once per displayed strategy.
    strategy_names = {
        1: "Multi-Error (Step + Rule + Reasoning)",
        2: "Generalization + Logic Fallacies",
        3: "All Five Error Types"
    }

    def _clip(value, limit):
        # Render any value as text, shortened to `limit` characters plus "...".
        text = "" if value is None else str(value)
        return f"{text[:limit]}..." if len(text) > limit else text

    def _succeeded(record, strategy_num):
        metadata = record.get(f'Perturbation_Metadata_Strategy_{strategy_num}')
        return isinstance(metadata, dict) and bool(metadata.get('success', False))

    samples_collected = 0

    with open(jsonl_file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if samples_collected >= num_samples:
                break

            line = line.strip()
            if not line:
                # Blank lines carry no record; nothing to report.
                continue

            try:
                obj = json.loads(line)
                if not isinstance(obj, dict):
                    continue

                # Only records with at least one successful perturbation qualify.
                if not any(_succeeded(obj, num) for num in strategy_names):
                    continue

                question = get_field_case_insensitive(obj, 'Question', 'question', 'QuestionText')
                exact_answer = get_field_case_insensitive(obj, 'Exact Answer', 'ExactAnswer', 'exact answer', 'Answer')

                # Count the sample only once its fields were retrieved.
                samples_collected += 1

                print("\n" + "="*80)
                print(f"SAMPLE {samples_collected} (Line {line_num})")
                print("="*80)
                print(f"Question: {_clip(question, 200)}")
                print(f"Exact Answer: {exact_answer}")

                # Show one successful perturbation
                for strategy_num, strategy_name in strategy_names.items():
                    solution_key = f'Wrong_Solution_Strategy_{strategy_num}'
                    if solution_key not in obj or not _succeeded(obj, strategy_num):
                        continue

                    metadata = obj[f'Perturbation_Metadata_Strategy_{strategy_num}']
                    metrics = metadata.get('quality_metrics')
                    quality_score = metrics.get('quality_score', 0) if isinstance(metrics, dict) else 0
                    if not isinstance(quality_score, (int, float)):
                        # A null or textual score cannot be formatted with :.1f.
                        quality_score = 0

                    print(f"\nStrategy {strategy_num} - {strategy_name} (Quality: {quality_score:.1f}):")
                    print(f"Error Description: {metadata.get('error_description', 'N/A')}")
                    print(f"Perturbed Solution: {_clip(obj[solution_key], 300)}")
                    break  # Show only one strategy per sample

            except Exception as e:
                # Best-effort display: report the bad record and keep scanning.
                print(f"Error processing sample from line {line_num}: {e}")
                continue

if __name__ == "__main__":
    # Step 1: generate the perturbed dataset.
    main()

    # Step 2: report on the output file, but only if one is present on disk.
    # NOTE(review): the path repeats the one hard-coded inside main(); keep
    # the two in sync when either changes.
    output_file = "/kaggle/working/Enhanced_Perturbed_English.jsonl"
    if os.path.exists(output_file):
        print("\n" + "=" * 60)
        print("Running post-processing analysis...")
        analyze_perturbation_results(output_file)
        sample_perturbations(output_file, 2)