# In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file under the Kaggle input mount (read-only),
# so the datasets attached to this notebook are visible in the cell output.
# NOTE: the loop variables (dirname, filenames, filename) remain defined as
# notebook globals after this cell runs.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# In [None]:
#!/usr/bin/env python3
"""
Refactored Gemini Language Classification Tool (single-file)
- Uses a single API key path / env variable (GEMINI_API_KEY)
- More comprehensive, strict, flat-JSON prompt (no arrays/nested objects)
- Parses semicolon-separated phrase fields returned by the model
- Preserves API key rotation + background worker behavior
- Detects Chinese in addition to Bangla and English, and analyzes which languages co-occur
- Strips markdown formatting and common LLM output patterns before analysis

Usage: edit paths in `main()` and run.
"""

import os
import re
import time
import json
import queue
import logging
import threading
import traceback
from tqdm import tqdm
from collections import deque, defaultdict
from datetime import datetime
from typing import Optional, Dict, Any, List, Tuple
from string import Template

try:
    from google import genai  # type: ignore
except Exception:
    genai = None

# ---------------------------
# Defaults / Configuration
# ---------------------------
DEFAULT_MODEL = "gemini-2.5-flash-lite"   # model name used when callers don't pass one
DEFAULT_RATE_LIMIT = 5.0                  # seconds the API worker sleeps after each request
DEFAULT_CALLS_PER_DAY = 1000              # successful calls allowed per key before rotation

# ---------------------------
# Logging Configuration
# ---------------------------
# Root logger: INFO and above, written both to a log file in the working
# directory and to the console (stderr).
logging.basicConfig(
    handlers=[
        logging.FileHandler('language_classification_gemini.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# ---------------------------
# Load API Keys (single path)
# ---------------------------

def load_api_keys(key_path: Optional[str] = None) -> List[str]:
    """Return a one-element list holding the Gemini API key to use.

    Resolution order:
      1. ``key_path``, when given: the file is read as UTF-8 and split on
         newlines/commas; the first non-blank entry is used (extra entries are
         logged and ignored). Any failure while reading or parsing the file is
         logged as a warning and resolution falls through to step 2.
      2. The ``GEMINI_API_KEY`` environment variable (whitespace-stripped).

    Raises:
        ValueError: if neither source yields a key.
    """
    if key_path and not os.path.exists(key_path):
        logging.warning(f"Provided key path does not exist: {key_path}")
    elif key_path:
        try:
            with open(key_path, 'r', encoding='utf-8') as fh:
                contents = fh.read().strip()
            if not contents:
                raise ValueError(f"Key file {key_path} is empty")
            candidates = [piece.strip() for piece in re.split(r'[\n,]+', contents) if piece.strip()]
            if not candidates:
                raise ValueError(f"No usable keys found in {key_path}")
            if len(candidates) > 1:
                logging.warning(f"Multiple keys found in {key_path}; using the first one")
            logging.info(f"Loaded API key from {key_path}")
            return candidates[:1]
        except Exception as exc:
            # Best-effort: a broken key file falls back to the environment variable.
            logging.warning(f"Failed to load key from {key_path}: {exc}")

    from_env = os.environ.get('GEMINI_API_KEY', '').strip()
    if not from_env:
        raise ValueError("No API key found. Provide a key file path or set GEMINI_API_KEY")
    logging.info("Loaded API key from GEMINI_API_KEY environment variable")
    return [from_env]

# ---------------------------
# Gemini API Manager
# ---------------------------
class GeminiLanguageApiManager:
    """Serialize Gemini ``generate_content`` calls through one background worker.

    Callers use :meth:`generate_content`, which enqueues the request and blocks
    for the reply. A single daemon thread ("GeminiWorker") executes requests one
    at a time and sleeps ``rate_limit_delay`` seconds after each one. Successful
    calls are counted per key in ``usage_count``. A key counts as exhausted once
    it reaches ``calls_per_day``; an API error whose message mentions a quota or
    rate limit marks the current key exhausted immediately. The worker then
    rotates to the next key that still has budget.
    """

    # Lower-cased substrings of an API error message that mark the key exhausted.
    _QUOTA_ERROR_MARKERS = ('quota', 'rate limit')

    def __init__(self, api_keys: List[str], calls_per_day: int = DEFAULT_CALLS_PER_DAY, rate_limit_delay: float = DEFAULT_RATE_LIMIT):
        """Store the keys, create a client for the first key and start the worker.

        Args:
            api_keys: Keys to rotate through (at least one).
            calls_per_day: Successful calls allowed per key before it is skipped.
            rate_limit_delay: Seconds the worker sleeps after serving each request.

        Raises:
            ValueError: if ``api_keys`` is empty.
        """
        if not api_keys:
            raise ValueError("api_keys must contain at least one key")

        self.api_keys = deque(api_keys)
        self.calls_per_day = calls_per_day
        self.rate_limit_delay = rate_limit_delay

        self.usage_count: Dict[str, int] = {k: 0 for k in api_keys}
        self.current_key = self.api_keys[0]

        if genai is None:
            logging.warning("google.genai package not available. API calls will fail until genai is installed.")
            self.client = None
        else:
            self.client = genai.Client(api_key=self.current_key)

        # Guards usage_count / current_key / client against concurrent readers
        # (e.g. get_usage_stats called from the main thread).
        self.lock = threading.Lock()

        self.call_queue: queue.Queue = queue.Queue()
        self.worker_thread = threading.Thread(target=self._process_queue, name="GeminiWorker")
        self.worker_thread.daemon = True
        self.worker_thread.start()

        logging.info(f"Language API Manager initialized with {len(api_keys)} key(s)")

    def _rotate_key(self) -> None:
        """Advance to the next key in the deque and rebuild the client for it."""
        with self.lock:
            self.api_keys.rotate(1)
            self.current_key = self.api_keys[0]
            if genai is not None:
                self.client = genai.Client(api_key=self.current_key)
            usage = self.usage_count.get(self.current_key, 0)
        logging.info(f"Rotated to new API key (usage: {usage})")

    def _find_available_key(self) -> bool:
        """Ensure the current key has budget left, rotating if needed.

        Returns:
            True if a key under ``calls_per_day`` is now current, False if every
            key is exhausted.
        """
        with self.lock:
            if self.usage_count.get(self.current_key, 0) < self.calls_per_day:
                return True

        initial = self.current_key
        for _ in range(len(self.api_keys)):
            self._rotate_key()
            with self.lock:
                if self.usage_count.get(self.current_key, 0) < self.calls_per_day:
                    return True
            if self.current_key == initial:
                return False
        return False

    def _call_api(self, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Perform one SDK call with the current client and build the caller's reply.

        Returns:
            ``{"response": resp}`` on success (the call is counted against the
            current key before returning), otherwise ``{"error": message}``.

        Raises:
            RuntimeError: if no client exists (genai missing).
        """
        if self.client is None:
            raise RuntimeError("genai.Client not initialized (missing genai package or client)")

        try:
            response = self.client.models.generate_content(*args, **kwargs)
        except Exception as api_exc:
            msg = str(api_exc).lower()
            if any(marker in msg for marker in self._QUOTA_ERROR_MARKERS):
                with self.lock:
                    self.usage_count[self.current_key] = self.calls_per_day
                logging.warning(f"API key reached quota/rate-limit: {api_exc}")
                return {"error": f"Rate limit/quota: {api_exc}"}
            logging.error(f"API call error: {api_exc}")
            return {"error": str(api_exc)}

        if response is None:
            # Reply with an error rather than leaving the caller to wait out its timeout.
            return {"error": "Gemini API returned no response"}

        with self.lock:
            self.usage_count[self.current_key] = self.usage_count.get(self.current_key, 0) + 1
        return {"response": response}

    def _process_queue(self) -> None:
        """Worker loop: serve ``call_queue`` items forever, replying to each exactly once.

        Each item is ``(args, kwargs, result_queue)``. Whatever happens while
        serving it, one reply dict is put on its ``result_queue`` and
        ``call_queue.task_done()`` is called. A worker-side failure therefore
        never leaves a caller waiting for its timeout, and ``call_queue.join()``
        cannot hang.
        """
        while True:
            item = self.call_queue.get()
            try:
                args, kwargs, result_queue = item
            except (TypeError, ValueError):
                logging.error(f"Discarding malformed queue item: {item!r}")
                self.call_queue.task_done()
                continue

            reply: Dict[str, Any] = {"error": "API worker failed before producing a result"}
            pause = self.rate_limit_delay
            try:
                if self._find_available_key():
                    reply = self._call_api(args, kwargs)
                else:
                    reply = {"error": "All API keys have reached their daily limit"}
                    pause = 10  # back off longer while every key is exhausted
            except Exception as e:
                logging.error(f"Unexpected API invocation error: {e}\n{traceback.format_exc()}")
                reply = {"error": str(e)}
            finally:
                try:
                    result_queue.put(reply)
                except Exception:
                    logging.exception("Failed to send result to caller")
                self.call_queue.task_done()

            time.sleep(pause)

    def generate_content(self, *args, timeout: float = 60.0, **kwargs) -> Any:
        """Queue a ``client.models.generate_content(*args, **kwargs)`` call and wait for it.

        Args:
            *args, **kwargs: Forwarded unchanged to the SDK call.
            timeout: Seconds to wait for the worker's reply. This includes time
                spent queued behind earlier requests, each followed by
                ``rate_limit_delay``.

        Returns:
            The SDK response object.

        Raises:
            TimeoutError: no reply within ``timeout``. The request stays queued
                and is still executed later by the worker.
            Exception: the worker reported an error (message is the error text).
        """
        result_queue: queue.Queue = queue.Queue()
        self.call_queue.put((args, kwargs, result_queue))
        try:
            result = result_queue.get(timeout=timeout)
        except queue.Empty:
            raise TimeoutError("Timed out waiting for API worker result.") from None

        if "error" in result:
            raise Exception(result["error"])
        return result["response"]

    def get_usage_stats(self) -> Dict[str, Any]:
        """Return a snapshot of per-key usage and overall budget consumption."""
        with self.lock:
            per_key = dict(self.usage_count)
            key_count = len(self.api_keys)
        total_used = sum(per_key.values())
        total_available = key_count * self.calls_per_day
        return {
            "per_key": per_key,
            "total_used": total_used,
            "total_available": total_available,
            "percent_used": (total_used / total_available) * 100 if total_available > 0 else 0,
        }

# ---------------------------
# Text Cleaning & Parsing
# ---------------------------
# Chain-of-thought / solution-scaffolding words that carry no language signal.
# Used two ways:
#   * clean_text_for_language_analysis removes each phrase (case-insensitively,
#     anchored at a word start) in LIST ORDER, so a longer phrase must come
#     before any phrase that is its prefix ("final answer" before "final",
#     "step-by-step" before "step").
#   * extract_language_classification drops model-reported phrases whose
#     normalized form equals a normalized entry here.
# Mixed-case duplicates are redundant under both uses (matching is
# case-insensitive / lower-cased). Digits are stripped from the text before
# this list is applied, so the "step 1".."step 10" entries never match there.
# NOTE(review): very generic words ("find", "value", "then", ...) are removed
# from prose too, which thins the English signal — confirm this is intended.
COT_PHRASES = [
    "step 1", "step 2", "step 3", "step 4", "step 5", "step 6", "step 7", "step 8", "step 9", "step 10",
    "step-by-step", "step-by-step solution", "step by step", "step",
    "solution", "final answer", "final", "proved",
    "verification", "verify",
    "problem understanding", "Problem Understanding", "PROBLEM UNDERSTANDING", 
    "problem statement", "Problem statement", "Problem Statement", "PROBLEM STATEMENT",
    "mathematical analysis", "Mathematical Analysis", "MATHEMATICAL ANALYSIS",
    "in bangla", "In Bangla",
    "analysis", "problem", "understanding", "approach", "method", "calculation", "result",
    "conclusion", "proof", "given", "find", "solve", "answer", "check",
    "first", "second", "third", "next", "then", "finally", "lastly",
    "mathematical", "formula", "equation", "expression", "value", "number",
    "let", "assume", "suppose", "consider", "note", "observe", "recall",
    "definition", "theorem", "lemma", "corollary", "proposition",
]


def clean_text_for_language_analysis(text: str, cot_phrases: Optional[List[str]] = None) -> str:
    """Strip non-linguistic content from ``text`` so mainly natural-language words remain.

    Removal order (the order matters; see the inline comments):
      1. ``<think>…</think>`` and ``<final>…</final>`` blocks.
      2. Fenced code blocks.
      3. LaTeX math: ``$$…$$``, ``\\[…\\]``, ``\\(…\\)``, ``$…$``,
         ``\\begin{…}…\\end{…}`` and ``\\command{arg}`` / bare ``\\command``.
      4. Markdown/HTML markup (emphasis, strikethrough, inline code, headers,
         images, links, tags, blockquotes, rules, list markers). The visible
         text of emphasis, inline code and links is kept.
      5. Common section headers, ASCII and Bengali digits, math symbols, brackets.
      6. Each phrase in ``cot_phrases`` (case-insensitive, whole word).
    Whitespace is then collapsed to single spaces.

    Args:
        text: Raw model output. Falsy input returns "".
        cot_phrases: Phrases to remove. ``None`` means the module's COT_PHRASES.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""

    if cot_phrases is None:
        cot_phrases = COT_PHRASES

    # 1. Reasoning / answer wrappers.
    text = re.sub(r'<think>.*?</think>', ' ', text, flags=re.DOTALL | re.IGNORECASE)
    text = re.sub(r'<final>.*?</final>', ' ', text, flags=re.DOTALL | re.IGNORECASE)

    # 2. Fenced code blocks. This must run before the inline-code rule, which
    #    would otherwise consume the fence backticks and keep the code body.
    text = re.sub(r'```.*?```', ' ', text, flags=re.DOTALL)

    # 3. LaTeX math. This runs before markup removal: '<', '>', '*' and '_' inside
    #    formulas would otherwise be read as HTML tags or emphasis markers and
    #    drag surrounding prose along with them.
    text = re.sub(r'\$\$.*?\$\$', ' ', text, flags=re.DOTALL)          # $$ display $$
    text = re.sub(r'\\\[.*?\\\]', ' ', text, flags=re.DOTALL)          # \[ display \]
    text = re.sub(r'\\\(.*?\\\)', ' ', text)                           # \( inline \)
    text = re.sub(r'\$.*?\$', ' ', text)                               # $ inline $
    text = re.sub(r'\\begin\{.*?\}.*?\\end\{.*?\}', ' ', text, flags=re.DOTALL)
    text = re.sub(r'\\[a-zA-Z]+\{[^}]*\}', ' ', text)                 # \cmd{arg}
    text = re.sub(r'\\[a-zA-Z]+', ' ', text)                           # bare \cmd

    # 4a. Inline formatting: keep the inner text.
    text = re.sub(r'\*\*\*([^*]+)\*\*\*', r'\1', text)  # ***bold+italic***
    text = re.sub(r'\*\*([^*]+)\*\*', r'\1', text)      # **bold**
    text = re.sub(r'\*([^*]+)\*', r'\1', text)          # *italic*
    text = re.sub(r'___([^_]+)___', r'\1', text)        # ___bold+italic___
    text = re.sub(r'__([^_]+)__', r'\1', text)          # __bold__
    text = re.sub(r'_([^_]+)_', r'\1', text)            # _italic_
    text = re.sub(r'~~([^~]+)~~', r'\1', text)          # ~~strikethrough~~
    text = re.sub(r'`([^`]+)`', r'\1', text)            # `code`

    # 4b. Block structure, links, images and tags.
    text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)          # headers
    # Images before links: the link rule would turn ![alt](url) into "!alt".
    text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', '', text)                 # ![alt](url)
    text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text)              # [text](url)
    # Tag-shaped spans only (a letter right after '<' or '</'), so prose
    # comparisons such as "x < 5 and y > 3" are not swallowed.
    text = re.sub(r'</?[A-Za-z][^<>]*>', ' ', text)
    text = re.sub(r'^>\s+', '', text, flags=re.MULTILINE)              # blockquotes
    text = re.sub(r'^[-*_]{3,}\s*$', '', text, flags=re.MULTILINE)      # horizontal rules
    text = re.sub(r'^\s*[\*\-\+]\s+', '', text, flags=re.MULTILINE)     # bullet lists
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)        # numbered lists

    # 5. Common section headers (case-insensitive, optional colon/parenthesis).
    section_headers = [
        r'problem\s+statement',
        r'problem\s+understanding',
        r'mathematical\s+analysis',
        r'step[-\s]by[-\s]step\s+solution',
        r'verification',
        r'final\s+answer',
        r'conclusion',
        r'approach',
        r'solution',
    ]
    for header in section_headers:
        # (?!\w) keeps "solution" from matching the start of "solutions".
        pattern = r'\b' + header + r'(?!\w)\s*[:\(\)]?\s*'
        text = re.sub(pattern, ' ', text, flags=re.IGNORECASE)

    # Numbers (ASCII and Bengali digits), math symbols and brackets.
    text = re.sub(r'[0-9০-৯]+', ' ', text)
    text = re.sub(r'[+\-*/=<>≤≥≠∑∏∫∂∇±×÷√∞θπ]', ' ', text)
    text = re.sub(r'[()\[\]{}|]', ' ', text)

    # 6. CoT phrases, whole-word only so e.g. "let" does not cut "letter" to "ter".
    for phrase in cot_phrases:
        if not phrase:
            continue  # an empty phrase would match at every word boundary
        pattern = r'\b' + re.escape(phrase) + r'(?!\w)[:,.]?\s*'
        text = re.sub(pattern, ' ', text, flags=re.IGNORECASE)

    # Normalize whitespace
    return re.sub(r'\s+', ' ', text).strip()


def parse_gemini_json_response(text: str) -> Optional[Dict[str, Any]]:
    """Extract the first JSON object (``dict``) from a model response.

    Attempts, in order:
      1. the contents of each ```json / ``` fenced block;
      2. decoding from each '{' in the response in turn. This uses json's own
         decoder, so braces inside string values (e.g. ``"a}b"``) do not
         confuse the match. The first position is the whole response when it
         starts with '{'.
    Every attempt that fails strict parsing is retried once after doubling
    stray backslashes. LaTeX such as ``\\alpha`` is a common source of invalid
    JSON escapes.

    Returns:
        The first decoded ``dict``, or ``None`` if none could be decoded.
    """
    if not text:
        return None
    text = text.strip()
    decoder = json.JSONDecoder()

    def _repair_escapes(s: str) -> str:
        # Valid JSON escapes are matched as whole pairs (so an already-escaped
        # "\\\\" is never split); every other backslash is doubled.
        return re.sub(
            r'(\\["\\/bfnrt]|\\u[0-9a-fA-F]{4})|\\',
            lambda m: m.group(1) or '\\\\',
            s,
        )

    def _decode_dict(s: str) -> Optional[Dict[str, Any]]:
        # Decode one JSON value at the start of s; trailing text is ignored.
        try:
            value, _ = decoder.raw_decode(s)
        except json.JSONDecodeError:
            if '\\' not in s:
                return None
            try:
                value, _ = decoder.raw_decode(_repair_escapes(s))
            except json.JSONDecodeError:
                return None
        return value if isinstance(value, dict) else None

    # 1. JSON inside ``` fences.
    fence_re = re.compile(r'```(?:json)?\s*\n?(.*?)\n?```', re.DOTALL | re.IGNORECASE)
    for block in fence_re.findall(text):
        candidate = block.strip()
        if not candidate:
            continue
        parsed = _decode_dict(candidate)
        if parsed is not None:
            return parsed

    # 2. Decode from each '{' (covers the "response is pure JSON" case first).
    for match in re.finditer(r'\{', text):
        parsed = _decode_dict(text[match.start():])
        if parsed is not None:
            return parsed

    return None


# Accepted (lower-cased) language labels and the canonical value each maps to.
_LANGUAGE_LABELS: Dict[str, str] = {
    'bangla': 'bangla',
    'bengali': 'bangla',
    'english': 'english',
    'chinese': 'chinese',
    'mandarin': 'chinese',
    'mixed': 'mixed',
}


def _language_from_label(value: Any) -> Optional[str]:
    """Map a model-reported language label to 'bangla'/'english'/'chinese'/'mixed'.

    Exact labels and the aliases in _LANGUAGE_LABELS map directly. A free-form
    label (e.g. "Bangla and English") is scanned for whole language words:
    the word "mixed" or two or more distinct languages give 'mixed', and exactly
    one language gives that language. Anything else gives None.
    """
    label = str(value).strip().lower()
    if label in _LANGUAGE_LABELS:
        return _LANGUAGE_LABELS[label]
    if re.search(r'\bmixed\b', label):
        return 'mixed'
    named = {
        canon for word, canon in _LANGUAGE_LABELS.items()
        if canon != 'mixed' and re.search(r'\b' + word + r'\b', label)
    }
    if len(named) > 1:
        return 'mixed'
    if len(named) == 1:
        return named.pop()
    return None


def _coerce_confidence(value: Any) -> Optional[float]:
    """Convert a reported confidence to the 0.0-1.0 range.

    Values in [0, 1] are returned unchanged. Values in (1, 100] are treated as
    percentages and divided by 100. Anything else (negative, > 100, NaN,
    non-numeric) gives None.
    """
    try:
        conf = float(value.strip() if isinstance(value, str) else value)
    except (TypeError, ValueError, OverflowError):
        return None
    if 0.0 <= conf <= 1.0:
        return conf
    if 0.0 <= conf <= 100.0:
        return conf / 100.0
    return None


def _phrase_list(value: Any) -> List[str]:
    """Turn a phrase field (';'-separated string, list, or scalar) into non-empty strings."""
    if value is None:
        return []
    if isinstance(value, list):
        return [str(x).strip() for x in value if str(x).strip()]
    if isinstance(value, str):
        return [p.strip() for p in value.split(';') if p.strip()]
    return [str(value).strip()]


def _normalize_phrase(phrase: str) -> str:
    """Lower-case, keep only ASCII letters/digits and collapse whitespace (for CoT matching)."""
    phrase = (phrase or '').lower()
    phrase = re.sub(r'[^a-z0-9\s]', ' ', phrase)
    return re.sub(r'[\s\-_]+', ' ', phrase).strip()


def _keyword_language_guess(text: str) -> Optional[str]:
    """Last-resort guess from language names mentioned anywhere in a free-text reply."""
    upper = text.upper()
    if "CHINESE" in upper and "BANGLA" not in upper and "ENGLISH" not in upper:
        return "chinese"
    if "BANGLA" in upper and "ENGLISH" not in upper and "CHINESE" not in upper:
        return "bangla"
    if "ENGLISH" in upper and "BANGLA" not in upper and "CHINESE" not in upper:
        return "english"
    if "MIXED" in upper:
        return "mixed"
    return None


def extract_language_classification(response_text: str) -> Tuple[Optional[str], Optional[float], List[str], List[str], List[str], List[str], str]:
    """Extract classification fields from a Gemini reply.

    The reply is first parsed as JSON (see parse_gemini_json_response). If that
    yields a usable language label, the structured fields are returned. Otherwise
    "Classification: X"-style regexes run on the raw text. A keyword scan of the
    raw text is used only when no JSON object was found.

    Returns:
        ``(language, confidence, bangla_phrases, english_phrases,
        chinese_phrases, excluded_cot_phrases, mixed_languages_present)``.
        ``language`` is 'bangla'/'english'/'chinese'/'mixed' or None;
        ``confidence`` is in [0, 1] or None. Phrase lists omit entries that match
        COT_PHRASES after normalization. ``mixed_languages_present`` is a
        ';'-joined string. Phrase fields are only filled on the JSON path.
    """
    if not response_text:
        return None, None, [], [], [], [], ""

    obj = parse_gemini_json_response(response_text)
    if obj:
        language: Optional[str] = None
        for key in ("language", "Language", "LANGUAGE", "classification", "Classification"):
            if key in obj:
                language = _language_from_label(obj[key])
                break

        conf: Optional[float] = None
        for key in ("confidence", "conf", "confidence_score"):
            if key in obj:
                conf = _coerce_confidence(obj[key])
                break

        mixed_languages_present = ""
        for key in ("mixed_languages_present", "mixedLanguagesPresent", "languages_present"):
            if key in obj:
                mixed_val = obj[key]
                if isinstance(mixed_val, str):
                    mixed_languages_present = mixed_val.strip()
                elif isinstance(mixed_val, list):
                    mixed_languages_present = ";".join(str(x).strip() for x in mixed_val if str(x).strip())
                break

        if language:
            cot_norms = {_normalize_phrase(p) for p in COT_PHRASES}

            def _drop_cot(phrases: List[str]) -> List[str]:
                return [ph for ph in phrases if ph and _normalize_phrase(ph) not in cot_norms]

            return (
                language,
                conf,
                _drop_cot(_phrase_list(obj.get('bangla_phrases') or obj.get('Bangla_phrases') or obj.get('banglaPhrases'))),
                _drop_cot(_phrase_list(obj.get('english_phrases') or obj.get('English_phrases') or obj.get('englishPhrases'))),
                _drop_cot(_phrase_list(obj.get('chinese_phrases') or obj.get('Chinese_phrases') or obj.get('chinesePhrases'))),
                _drop_cot(_phrase_list(obj.get('excluded_cot_phrases') or obj.get('excludedCOT') or obj.get('excluded_cot'))),
                mixed_languages_present,
            )

    # Free-text "Classification: X" style answers.
    patterns = [
        r'Classification:\s*(bangla|english|chinese|mixed)',
        r'Language:\s*(bangla|english|chinese|mixed)',
        r'Result:\s*(bangla|english|chinese|mixed)',
        r'(bangla|english|chinese|mixed)\s*\(confidence:\s*([\d.]+)\)',
    ]
    for pattern in patterns:
        match = re.search(pattern, response_text, re.IGNORECASE)
        if match:
            conf = _coerce_confidence(match.group(2)) if len(match.groups()) > 1 else None
            return match.group(1).lower(), conf, [], [], [], [], ""

    # The keyword scan runs only when no JSON object was found. A JSON reply
    # typically contains field names such as "bangla_phrases" and
    # "mixed_languages_present", which would make the scan answer "mixed".
    if not obj:
        guess = _keyword_language_guess(response_text)
        if guess:
            return guess, None, [], [], [], [], ""

    return None, None, [], [], [], [], ""


def get_heuristic_classification(text: str, cot_phrases: Optional[List[str]] = None) -> str:
    """Classify text as 'bangla', 'english', 'chinese' or 'mixed' by script presence.

    The text is cleaned with clean_text_for_language_analysis first. More than
    one script present gives 'mixed'. Otherwise Chinese wins over Bangla, and
    anything else, including no detectable script, is 'english'. Any exception
    also yields 'english'.
    """
    try:
        stripped = clean_text_for_language_analysis(text, cot_phrases)
        presence = {
            'bangla': re.search(r'[\u0980-\u09FF]', stripped) is not None,
            'english': re.search(r'[a-zA-Z]{2,}', stripped) is not None,
            'chinese': re.search(r'[\u4e00-\u9fff]', stripped) is not None,
        }
        if sum(presence.values()) > 1:
            return 'mixed'
        if presence['chinese']:
            return 'chinese'
        if presence['bangla']:
            return 'bangla'
        return 'english'
    except Exception:
        return 'english'


def detect_all_languages_in_text(text: str) -> List[str]:
    """List the script/language names found in ``text`` after cleaning.

    Empty input returns []. Otherwise the result lists, in a fixed order,
    every label whose Unicode range matches the cleaned text, or ['Unknown']
    if none match. CJK ideographs shared with Japanese kanji are reported as
    'Chinese'. 'Cyrillic' is a script label, not a specific language.
    """
    if not text:
        return []

    stripped = clean_text_for_language_analysis(text, COT_PHRASES)

    # (label, character-class pattern); the tuple order fixes the output order.
    script_checks = (
        ('Bangla', r'[\u0980-\u09FF]'),
        ('English', r'[a-zA-Z]{2,}'),
        ('Chinese', r'[\u4e00-\u9fff]'),
        ('Arabic', r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF]'),
        ('Japanese', r'[\u3040-\u309F\u30A0-\u30FF]'),
        ('Korean', r'[\uAC00-\uD7AF\u1100-\u11FF\u3130-\u318F]'),
        ('Thai', r'[\u0E00-\u0E7F]'),
        ('Cyrillic', r'[\u0400-\u04FF]'),
        ('Greek', r'[\u0370-\u03FF]'),
    )
    found = [label for label, pattern in script_checks if re.search(pattern, stripped)]
    return found or ['Unknown']


def detect_actual_languages(text: str) -> Dict[str, bool]:
    """Report whether Bangla, English and Chinese script appear in the cleaned text.

    Returns:
        ``{'bangla': bool, 'english': bool, 'chinese': bool}`` in that key order.
    """
    stripped = clean_text_for_language_analysis(text, COT_PHRASES)
    script_patterns = {
        'bangla': r'[\u0980-\u09FF]',
        'english': r'[a-zA-Z]{2,}',
        'chinese': r'[\u4e00-\u9fff]',
    }
    return {lang: re.search(pat, stripped) is not None for lang, pat in script_patterns.items()}


def get_language_combination_label(langs: Dict[str, bool]) -> str:
    """Build a label such as "Bangla+English" from a language -> present mapping.

    Present languages are capitalized and ordered alphabetically by key. None
    present gives "Unknown"; exactly one gives "<Name> Only".
    """
    names = [lang.capitalize() for lang in sorted(langs) if langs[lang]]
    if not names:
        return "Unknown"
    if len(names) == 1:
        return f"{names[0]} Only"
    return "+".join(names)


def analyze_mixed_compositions(classification_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarize which scripts actually appear in results classified as 'mixed'.

    Args:
        classification_results: Result records. Each may carry
            ``language_classification`` (a dict with ``language`` and
            ``confidence``), ``generated_answer`` and ``problem_index``. Missing
            or ``None`` values are tolerated.

    Returns:
        A dict with the same keys whether or not any mixed result exists:
        ``total_mixed``, ``combinations`` (sorted "A+B" label -> count),
        ``all_languages_found``, ``chinese_in_mixed``, ``chinese_percentage``
        (0-100, rounded to 2 places) and ``details`` (one entry per mixed result).
    """
    from collections import Counter

    def _answer_text(answer: Any) -> str:
        # Mirror classify_language_with_gemini: a list answer is judged by its
        # first element, and None/empty answers become "".
        if isinstance(answer, list):
            return str(answer[0]) if answer else ""
        return str(answer or "")

    mixed_problems = [
        r for r in classification_results
        if (r.get('language_classification') or {}).get('language') == 'mixed'
    ]

    if not mixed_problems:
        return {
            'total_mixed': 0,
            'combinations': {},
            'chinese_in_mixed': 0,
            'chinese_percentage': 0,
            'all_languages_found': [],
            'details': [],
        }

    combination_counter: Counter = Counter()
    all_languages_set = set()
    chinese_count = 0
    mixed_details: List[Dict[str, Any]] = []

    for problem in mixed_problems:
        answer_text = _answer_text(problem.get('generated_answer', ''))
        lang_class = problem.get('language_classification') or {}

        languages_list = detect_all_languages_in_text(answer_text)
        all_languages_set.update(languages_list)

        # detect_all_languages_in_text returns [] for empty text.
        combination = "+".join(sorted(languages_list)) if languages_list else "Unknown"
        combination_counter[combination] += 1

        if 'Chinese' in languages_list:
            chinese_count += 1

        mixed_details.append({
            'problem_index': problem.get('problem_index', 'unknown'),
            'languages': languages_list,
            'combination': combination,
            'confidence': lang_class.get('confidence'),
            'text_preview': answer_text[:200],
        })

    return {
        'total_mixed': len(mixed_problems),
        'combinations': dict(combination_counter),
        'all_languages_found': sorted(all_languages_set),
        'chinese_in_mixed': chinese_count,
        'chinese_percentage': round(chinese_count / len(mixed_problems) * 100, 2),
        'details': mixed_details,
    }


def analyze_all_language_presence(classification_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Cross-tabulate each result's classified label against the scripts actually present.

    Args:
        classification_results: Result records carrying ``language_classification``
            (dict with ``language``) and ``generated_answer``. Missing or ``None``
            values are tolerated.

    Returns:
        ``{label: {'bangla': n, 'english': n, 'chinese': n, 'total': n}}`` for each
        label in 'bangla', 'english', 'chinese', 'mixed'. ``total`` counts results
        with that label. Each language count is how many of those contain that
        script according to detect_actual_languages. Results with any other or
        missing label are skipped.
    """
    classification_analysis = {
        label: {'bangla': 0, 'english': 0, 'chinese': 0, 'total': 0}
        for label in ('bangla', 'english', 'chinese', 'mixed')
    }

    for problem in classification_results:
        lang_class = problem.get('language_classification') or {}
        classified_as = lang_class.get('language', '')
        if classified_as not in classification_analysis:
            continue

        # Same convention as classify_language_with_gemini: a list answer is
        # judged by its first element, and None/empty answers become "".
        answer = problem.get('generated_answer', '')
        if isinstance(answer, list):
            answer_text = str(answer[0]) if answer else ""
        else:
            answer_text = str(answer or "")

        bucket = classification_analysis[classified_as]
        bucket['total'] += 1
        for lang, present in detect_actual_languages(answer_text).items():
            if present:
                bucket[lang] += 1

    return classification_analysis


# ---------------------------
# Prompt Template
# ---------------------------
# Prompt sent to Gemini for each answer. classify_language_with_gemini fills it
# via Template.safe_substitute() with two placeholders: $original_text and
# $cleaned_text. safe_substitute leaves any other "$name" text untouched, and
# substituted values are not re-scanned for placeholders, so "$" characters in
# the solution text are safe.
# The JSON keys requested below should stay in sync with the keys that
# extract_language_classification reads.
LANGUAGE_CLASSIFICATION_PROMPT = Template(r"""You are an expert language classifier specializing in distinguishing between Bangla, English, and Chinese text in mathematical solutions.

TASK: Decide whether the provided solution text is written in Bangla, English, Chinese, or Mixed.

IMPORTANT RULES (follow exactly):
- Respond with ONLY one plain JSON object (no Markdown, no code fences, no extra text).
- The JSON must be FLAT (no arrays/lists and no nested objects/dicts).
- For phrase fields (bangla_phrases, english_phrases, chinese_phrases, excluded_cot_phrases) return a single string containing phrases separated by semicolons (`;`). If none, return an empty string.
- Allowed keys (exact): "language", "confidence", "bangla_phrases", "english_phrases", "chinese_phrases", "excluded_cot_phrases", "mixed_languages_present". Do NOT include extra keys.
- "language" must be one of: "bangla", "english", "chinese", "mixed" (lowercase).
- "confidence" must be a number between 0.0 and 1.0 (use 0.0-1.0). If unknown, return 0.0.
- "mixed_languages_present" should be a semicolon-separated string of languages found ONLY if language is "mixed". For single-language classifications, return empty string. Examples: "Bangla;English" or "Chinese;English;Bangla" or "".
- Do not include explanatory text or commentary.

LANGUAGE DETECTION GUIDELINES:
- Bangla: Uses Bengali script (Unicode range U+0980-U+09FF)
- English: Uses Latin alphabet (a-z, A-Z)
- Chinese: Uses Chinese characters (CJK Unified Ideographs, e.g., 这是中文)
- Mixed: Contains 2 or more languages - YOU MUST specify which languages in "mixed_languages_present"

EXCLUSIONS (ignore while deciding language): mathematical notation, digits, CoT phrases, and <think> content.

ORIGINAL SOLUTION TEXT:
$original_text

CLEANED TEXT FOR ANALYSIS:
$cleaned_text

Example valid responses:

For single language:
{"language":"bangla","confidence":0.95,"bangla_phrases":"প্রথমে আমরা","english_phrases":"","chinese_phrases":"","excluded_cot_phrases":"","mixed_languages_present":""}

For mixed language:
{"language":"mixed","confidence":0.88,"bangla_phrases":"আমরা এই সমস্যা","english_phrases":"We need to find","chinese_phrases":"首先分析","excluded_cot_phrases":"step 1","mixed_languages_present":"Bangla;English;Chinese"}

JSON ONLY.""")

# ---------------------------
# Classification Logic
# ---------------------------

def _answer_to_text(generated_answer: Any) -> str:
    """Normalise a ``generated_answer`` field to plain text.

    Lists contribute only their first element. Falsy values (``None``, ``""``,
    an empty list, or a falsy first list element) become ``""`` so that they
    fail input validation instead of being classified as the text ``"None"``.
    """
    if isinstance(generated_answer, list):
        generated_answer = generated_answer[0] if generated_answer else ""
    return str(generated_answer or "")


def classify_language_with_gemini(api_manager: GeminiLanguageApiManager, generated_answer: Any, problem_index: int, model_name: str = DEFAULT_MODEL) -> Dict[str, Any]:
    """Classify the language of one generated answer using the Gemini API.

    The answer is normalised to text (first element for lists) and cleaned with
    ``clean_text_for_language_analysis``. It is substituted into
    ``LANGUAGE_CLASSIFICATION_PROMPT``, with the original text truncated to 4000
    chars and the cleaned text to 2000, and sent to ``model_name`` through
    ``api_manager.generate_content``. The reply is parsed by
    ``extract_language_classification``. If the parsed label is not one of
    bangla/english/chinese/mixed, ``get_heuristic_classification`` supplies
    the label instead.

    Args:
        api_manager: Manager whose ``generate_content`` performs the API call.
        generated_answer: Answer text, or a list whose first element is used.
        problem_index: Identifier used only in log messages.
        model_name: Gemini model to query.

    Returns:
        Dict with keys ``response_text``, ``language``, ``confidence``,
        ``bangla_phrases``, ``english_phrases``, ``chinese_phrases``,
        ``excluded_cot_phrases`` and ``mixed_languages_present``.
        This function does not raise:

        - Empty input yields ``language='english'`` with
          ``response_text='Input validation failed'``.
        - Any other error yields a heuristic label with
          ``response_text='Error: ...'``.
    """
    try:
        generated_answer_text = _answer_to_text(generated_answer)

        if not generated_answer_text:
            return {
                'response_text': 'Input validation failed',
                'language': 'english',
                'confidence': None,
                'bangla_phrases': [],
                'english_phrases': [],
                'chinese_phrases': [],
                'excluded_cot_phrases': [],
                'mixed_languages_present': '',
            }

        cleaned_text = clean_text_for_language_analysis(generated_answer_text, COT_PHRASES)

        try:
            prompt = LANGUAGE_CLASSIFICATION_PROMPT.safe_substitute(
                original_text=generated_answer_text[:4000],
                cleaned_text=cleaned_text[:2000]
            )
        except Exception as e:
            # Last-resort prompt without instructions; should practically never be hit
            # because safe_substitute does not raise on missing placeholders.
            logging.error(f"Prompt substitution failed for problem {problem_index}: {e}")
            prompt = f"ORIGINAL:\n{generated_answer_text[:2000]}\n\nCLEANED:\n{cleaned_text[:1000]}"

        # send prompt to model
        response = api_manager.generate_content(model=model_name, contents=prompt)

        # Robustly extract the text field from either an SDK object or a plain dict.
        response_text = ''
        if hasattr(response, 'text'):
            response_text = getattr(response, 'text') or ''
        elif isinstance(response, dict) and 'text' in response:
            response_text = response.get('text') or ''
        else:
            response_text = str(response)

        response_text = response_text.strip()

        language, confidence, bangla_phrases, english_phrases, chinese_phrases, excluded_cot_phrases, mixed_languages_present = extract_language_classification(response_text)

        if language not in ('bangla', 'english', 'chinese', 'mixed'):
            language = get_heuristic_classification(generated_answer_text, COT_PHRASES)
            logging.info(f"Problem {problem_index}: heuristic fallback -> {language}")

        return {
            'response_text': response_text,
            'language': language,
            'confidence': confidence,
            'bangla_phrases': bangla_phrases,
            'english_phrases': english_phrases,
            'chinese_phrases': chinese_phrases,
            'excluded_cot_phrases': excluded_cot_phrases,
            'mixed_languages_present': mixed_languages_present,
        }

    except Exception as e:
        logging.error(f"Error classifying language for problem {problem_index}: {e}")
        try:
            # Use the same normalised text as the main path; str(list) would feed
            # brackets/quotes (and every list element) to the heuristic.
            heuristic_lang = get_heuristic_classification(_answer_to_text(generated_answer), COT_PHRASES)
        except Exception:
            heuristic_lang = 'english'
        return {
            'response_text': f'Error: {e}',
            'language': heuristic_lang,
            'confidence': None,
            'bangla_phrases': [],
            'english_phrases': [],
            'chinese_phrases': [],
            'excluded_cot_phrases': [],
            'mixed_languages_present': '',
        }

# ---------------------------
# I/O Helpers
# ---------------------------

def save_json_atomic(obj: Any, path: str) -> None:
    """Write ``obj`` as pretty-printed UTF-8 JSON to ``path`` atomically.

    The data is written to ``path + '.tmp'``, flushed and fsynced, then moved
    over ``path`` with ``os.replace``. Readers therefore see either the old file
    or the complete new one. If serialisation or writing fails, the temporary
    file is removed, ``path`` is left untouched, and the exception propagates.

    Args:
        obj: Any JSON-serialisable value (dicts for reports, lists for checkpoints).
        path: Destination file path.
    """
    tmp = path + '.tmp'
    try:
        with open(tmp, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=2, ensure_ascii=False)
            # Make sure the bytes are on disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        try:
            os.remove(tmp)
        except OSError:
            pass  # best effort: tmp may never have been created
        raise


def _load_input_data(path: str) -> List[Dict[str, Any]]:
    """Load input data from JSON or JSONL file."""
    if path.endswith('.jsonl') or path.endswith('.ndjson'):
        data: List[Dict[str, Any]] = []
        with open(path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as je:
                    logging.warning(f"Skipping invalid JSON on line {line_no} in {path}: {je}")
        return data
    else:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)


def calculate_confidence_stats(classification_results: List[Dict]) -> Dict:
    """Summarise the confidence scores attached to classified problems.

    Reads ``r['language_classification']['confidence']`` from each result and
    ignores entries where that value is one of:

    - missing or ``None``;
    - a boolean;
    - not convertible to ``float`` (e.g. free text);
    - NaN, which would otherwise poison mean, min and max.

    Returns:
        ``{"count": 0}`` when no usable scores exist. Otherwise a dict with
        ``count``, ``mean``, ``min``, ``max`` and ``median``; for an even count
        the median is the mean of the two middle values.
    """
    confidences: List[float] = []
    for r in classification_results:
        lc = r.get('language_classification') or {}
        conf = lc.get('confidence')
        if conf is None or isinstance(conf, bool):
            continue
        try:
            value = float(conf)
        except (TypeError, ValueError):
            continue
        if value != value:  # NaN check without importing math
            continue
        confidences.append(value)
    if not confidences:
        return {"count": 0}
    confidences.sort()
    n = len(confidences)
    mid = n // 2
    median = confidences[mid] if n % 2 else (confidences[mid - 1] + confidences[mid]) / 2
    return {
        "count": n,
        "mean": sum(confidences) / n,
        "min": confidences[0],
        "max": confidences[-1],
        "median": median,
    }


# ---------------------------
# Main processing function
# ---------------------------
def _fallback_classification(language: Optional[str], reason: str) -> Dict[str, Any]:
    """Build the ``language_classification`` record used when Gemini gave no result.

    Carries the same keys as a successful record (so downstream consumers see a
    uniform schema) plus ``classification_reason``.
    """
    return {
        'language': language,
        'confidence': None,
        'bangla_phrases': [],
        'english_phrases': [],
        'chinese_phrases': [],
        'excluded_cot_phrases': [],
        'mixed_languages_present': '',
        'classification_reason': reason,
        'response_text': reason,
    }


def _checkpoint_path(output_path: str, processed: int) -> str:
    """Derive ``<stem>_checkpoint_<processed><ext>`` from ``output_path``.

    Splitting on the real extension (instead of ``str.replace('.json', ...)``)
    never rewrites directory names and never yields ``output_path`` itself, so
    a checkpoint cannot overwrite the final report.
    """
    root, ext = os.path.splitext(output_path)
    return f"{root}_checkpoint_{processed}{ext or '.json'}"


def process_json_file(api_manager: GeminiLanguageApiManager, input_path: str, output_path: str, model_name: str = DEFAULT_MODEL, checkpoint_interval: int = 50) -> None:
    """Classify the language of every problem in ``input_path`` and write a report.

    Each input record is copied and given a ``language_classification`` dict.
    Records without ``generated_answer`` are kept with ``language=None``.
    Records whose processing raises are kept with a heuristic label.

    Every ``checkpoint_interval`` successful classifications, the partial list
    of records is saved as ``<stem>_checkpoint_<n><ext>`` next to
    ``output_path``. A checkpoint failure is logged and does not stop the run.

    At the end, a report with distribution and confidence statistics, the
    mixed-language analysis, and the per-problem records is written atomically
    to ``output_path``, and a summary is printed.

    Args:
        api_manager: Manager used for Gemini calls; its ``get_usage_stats()``
            result is read for ``total_used``/``total_available``/``percent_used``.
        input_path: JSON (list of problems) or JSONL/NDJSON input file.
        output_path: Destination for the final report.
        model_name: Gemini model to use.
        checkpoint_interval: Successful classifications between checkpoints;
            ``0`` disables checkpointing.
    """
    json_data = _load_input_data(input_path)
    logging.info(f"Loaded {len(json_data)} problems from {input_path}")

    classification_results: List[Dict[str, Any]] = []
    total_problems = len(json_data)
    processed_problems = 0

    bangla_count = english_count = chinese_count = mixed_count = 0
    problems_with_confidence = 0

    for idx, problem in enumerate(tqdm(json_data, desc='Classifying languages')):
        problem_index = idx
        # Only the classification itself is guarded. Anything that fails after a
        # record has been appended (printing, checkpointing) must not cause a
        # second, duplicate record to be appended for the same problem.
        try:
            problem_index = problem.get('problem_index', idx)
            generated_answer = problem.get('generated_answer', '')

            if not generated_answer:
                logging.warning(f"Skipping problem {problem_index}: missing generated answer")
                entry = problem.copy()
                entry['language_classification'] = _fallback_classification(None, 'Missing generated answer')
                classification_results.append(entry)
                continue

            result = classify_language_with_gemini(api_manager, generated_answer, problem_index, model_name)
        except Exception as e:
            logging.error(f"Error processing problem {problem_index}: {e}")
            try:
                heuristic_lang = get_heuristic_classification(str(problem.get('generated_answer', '')), COT_PHRASES)
            except Exception:
                # Mirror classify_language_with_gemini's last-resort default rather
                # than letting one bad record abort the whole run.
                heuristic_lang = 'english'
            entry = problem.copy()
            entry['language_classification'] = _fallback_classification(heuristic_lang, f'Error: {e}')
            classification_results.append(entry)
            continue

        language = result['language']
        confidence = result.get('confidence')

        entry = problem.copy()
        entry['language_classification'] = {
            'language': language,
            'confidence': confidence,
            'bangla_phrases': result.get('bangla_phrases', []),
            'english_phrases': result.get('english_phrases', []),
            'chinese_phrases': result.get('chinese_phrases', []),
            'excluded_cot_phrases': result.get('excluded_cot_phrases', []),
            'mixed_languages_present': result.get('mixed_languages_present', ''),
            'response_text': result.get('response_text', ''),
        }
        classification_results.append(entry)

        processed_problems += 1
        if language == 'bangla':
            bangla_count += 1
        elif language == 'english':
            english_count += 1
        elif language == 'chinese':
            chinese_count += 1
        elif language == 'mixed':
            mixed_count += 1

        if confidence is not None:
            problems_with_confidence += 1

        # Only format numerically when the value really is a number; a string
        # confidence must not crash the loop.
        if isinstance(confidence, (int, float)) and not isinstance(confidence, bool):
            print(f"Problem {problem_index}: {str(language).upper()} (conf: {confidence:.2f})")
        else:
            print(f"Problem {problem_index}: {str(language).upper()}")

        if checkpoint_interval and processed_problems % checkpoint_interval == 0:
            checkpoint_file = _checkpoint_path(output_path, processed_problems)
            try:
                save_json_atomic(classification_results, checkpoint_file)
                stats = api_manager.get_usage_stats()
                logging.info(f"Progress: {processed_problems}/{total_problems}")
                logging.info(f"Distribution - Bangla: {bangla_count}, English: {english_count}, Chinese: {chinese_count}, Mixed: {mixed_count}")
                logging.info(f"API Usage: {stats['total_used']}/{stats['total_available']} ({stats['percent_used']:.1f}%)")
            except Exception as e:
                # A failed checkpoint is not fatal; the final report is still written.
                logging.error(f"Checkpoint at {processed_problems} problems failed ({checkpoint_file}): {e}")

    confidence_rate = (problems_with_confidence / processed_problems) * 100 if processed_problems else 0
    confidence_stats = calculate_confidence_stats(classification_results)

    # Analyze mixed language compositions
    logging.info("Analyzing mixed language compositions...")
    mixed_analysis = analyze_mixed_compositions(classification_results)

    # Analyze language presence across all classifications
    logging.info("Analyzing language presence across all classifications...")
    language_presence_analysis = analyze_all_language_presence(classification_results)

    final_results = {
        'classification_metadata': {
            'model_used': model_name,
            'total_problems': total_problems,
            'successfully_processed': processed_problems,
            'language_distribution': {
                'bangla': bangla_count,
                'english': english_count,
                'chinese': chinese_count,
                'mixed': mixed_count
            },
            'language_percentages': {
                'bangla': round((bangla_count / processed_problems) * 100, 2) if processed_problems else 0,
                'english': round((english_count / processed_problems) * 100, 2) if processed_problems else 0,
                'chinese': round((chinese_count / processed_problems) * 100, 2) if processed_problems else 0,
                'mixed': round((mixed_count / processed_problems) * 100, 2) if processed_problems else 0,
            },
            'problems_with_confidence': problems_with_confidence,
            'confidence_rate_percentage': round(confidence_rate, 2),
            'confidence_statistics': confidence_stats,
            'mixed_language_analysis': {
                'total_mixed_problems': mixed_analysis['total_mixed'],
                'all_languages_found': mixed_analysis.get('all_languages_found', []),
                'language_combinations': mixed_analysis['combinations'],
                'chinese_in_mixed_count': mixed_analysis['chinese_in_mixed'],
                'chinese_in_mixed_percentage': mixed_analysis['chinese_percentage'],
            },
            'language_presence_by_classification': language_presence_analysis,
            'excluded_cot_phrases': COT_PHRASES,
            'classification_timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        },
        'problem_classifications': classification_results,
        'mixed_problem_details': mixed_analysis['details']
    }

    save_json_atomic(final_results, output_path)

    # Summary print
    print('\n' + '='*60)
    print('LANGUAGE CLASSIFICATION COMPLETED!')
    print('='*60)
    print(f"Model Used: {model_name}")
    print(f"Total Problems: {total_problems}")
    print(f"Successfully Processed: {processed_problems}")
    if processed_problems:
        print(f"  - Bangla: {bangla_count} ({(bangla_count/processed_problems)*100:.1f}%)")
        print(f"  - English: {english_count} ({(english_count/processed_problems)*100:.1f}%)")
        print(f"  - Chinese: {chinese_count} ({(chinese_count/processed_problems)*100:.1f}%)")
        print(f"  - Mixed: {mixed_count} ({(mixed_count/processed_problems)*100:.1f}%)")
    print(f"Problems with Confidence: {problems_with_confidence} ({confidence_rate:.1f}%)")
    if confidence_stats.get('count', 0) > 0:
        print(f"Average Confidence: {confidence_stats['mean']:.3f}")
        print(f"Confidence Range: {confidence_stats['min']:.3f} - {confidence_stats['max']:.3f}")

    # Print mixed language analysis
    if mixed_count > 0:
        print('\n' + '='*60)
        print('MIXED LANGUAGE COMPOSITION ANALYSIS')
        print('='*60)
        print(f"Total Mixed Problems: {mixed_analysis['total_mixed']}")

        # Show all unique languages found
        all_langs = mixed_analysis.get('all_languages_found', [])
        if all_langs:
            print(f"\nAll Languages Found in Mixed: {', '.join(all_langs)}")

        print(f"\nChinese in Mixed: {mixed_analysis['chinese_in_mixed']} ({mixed_analysis['chinese_percentage']:.1f}%)")
        print("\nLanguage Combinations:")
        for combination, count in sorted(mixed_analysis['combinations'].items(), key=lambda x: x[1], reverse=True):
            percentage = (count / mixed_analysis['total_mixed'] * 100) if mixed_analysis['total_mixed'] > 0 else 0
            print(f"  {combination:40s}: {count:4d} ({percentage:5.1f}%)")

        # Show sample problems for each combination
        print("\n" + "-"*60)
        print("Sample Problems by Combination:")
        print("-"*60)

        # Group details by combination
        combo_examples = defaultdict(list)
        for detail in mixed_analysis['details']:
            combo_examples[detail['combination']].append(detail)

        # Show top 3 combinations with examples
        top_combinations = sorted(mixed_analysis['combinations'].items(), key=lambda x: x[1], reverse=True)[:3]
        for combination, count in top_combinations:
            print(f"\n{combination} ({count} problems):")
            examples = combo_examples[combination][:2]  # Show 2 examples
            for i, ex in enumerate(examples, 1):
                print(f"  Example {i} - Problem {ex['problem_index']}:")
                print(f"    Languages: {', '.join(ex['languages'])}")
                ex_conf = ex.get('confidence')
                # Show a genuine 0.0 too; only skip missing/non-numeric values.
                if isinstance(ex_conf, (int, float)) and not isinstance(ex_conf, bool):
                    print(f"    Confidence: {ex_conf:.2f}")
                preview = ex.get('text_preview', '')[:100].replace('\n', ' ')
                if preview:
                    print(f"    Preview: {preview}...")

    # Print language presence analysis
    print('\n' + '='*60)
    print('LANGUAGE PRESENCE BY CLASSIFICATION')
    print('='*60)
    print(f"{'Classified As':<15} | {'Total':>6} | {'Bangla':>7} | {'English':>7} | {'Chinese':>7}")
    print('-'*60)
    for classification in ['bangla', 'english', 'chinese', 'mixed']:
        analysis = language_presence_analysis.get(classification, {})
        total = analysis.get('total', 0)
        if total > 0:
            print(f"{classification.capitalize():<15} | {total:>6d} | {analysis.get('bangla', 0):>7d} | {analysis.get('english', 0):>7d} | {analysis.get('chinese', 0):>7d}")

    print(f"\nCoT Phrases Excluded: {len(COT_PHRASES)} phrases")
    print(f"Output saved to: {output_path}")

    final_stats = api_manager.get_usage_stats()
    print(f"Final API Usage: {final_stats['total_used']}/{final_stats['total_available']} ({final_stats['percent_used']:.1f}%)")
    print('='*60)

# ---------------------------
# Small utilities / validators
# ---------------------------

def validate_json_structure(file_path: str) -> bool:
    """Sanity-check an input file before processing it.

    Loads ``file_path`` with ``_load_input_data``. It then checks that the
    result is a non-empty list and that each of its first three entries is a
    dict containing ``generated_answer``. The first failed check is reported
    with ``print``; nothing is raised.

    Returns:
        True when every check passes, False otherwise (including load errors).
    """
    try:
        data = _load_input_data(file_path)

        failure = None
        if not isinstance(data, list):
            failure = f"Warning: {file_path} should contain a list of problems"
        elif not data:
            failure = f"Warning: {file_path} is empty"
        else:
            # Only a small sample is inspected; later records are not checked.
            for position, record in enumerate(data[:3]):
                if not isinstance(record, dict):
                    failure = f"Warning: Problem {position} is not a dictionary"
                elif 'generated_answer' not in record:
                    failure = f"Warning: Problem {position}: missing 'generated_answer'"
                if failure is not None:
                    break

        if failure is not None:
            print(failure)
            return False

        print(f"✅ {file_path} structure validation passed")
        return True
    except Exception as e:
        print(f"Error validating {file_path}: {e}")
        return False

# ---------------------------
# Main Execution
# ---------------------------

def main(
    gemini_api_path: str = "/kaggle/input/part_6_api_key.txt",
    input_path: str = "/kaggle/input/Code Switching Files/Mathstral_7B_Q(E)_CoT(B).json",
    output_path: str = "/kaggle/working/Mathstral_7B_Q(E)_CoT(B)_Language_Classification.json",
    model_name: str = DEFAULT_MODEL,
    calls_per_day=None,
    rate_limit_delay: float = DEFAULT_RATE_LIMIT,
    checkpoint_interval: int = 50,
) -> None:
    """Run the language-classification pipeline end to end.

    Every argument defaults to the value that used to be hard-coded here.
    Calling ``main()`` therefore processes the same files; pass arguments to
    run on other inputs without editing the source.

    The input file is validated *before* API keys are loaded. If
    ``validate_json_structure`` reports a problem, the run is aborted rather
    than started on data that failed the check.

    Args:
        gemini_api_path: Key file passed to ``load_api_keys``.
        input_path: JSON/JSONL file of problems.
        output_path: Where the final report is written.
        model_name: Gemini model name.
        calls_per_day: Passed to ``GeminiLanguageApiManager``; ``None`` uses
            ``DEFAULT_CALLS_PER_DAY``.
        rate_limit_delay: Passed to ``GeminiLanguageApiManager``.
        checkpoint_interval: Successful classifications between checkpoints.
    """
    # Resolved at call time, exactly as the original body referenced it.
    if calls_per_day is None:
        calls_per_day = DEFAULT_CALLS_PER_DAY

    print("="*60)
    print("GEMINI LANGUAGE CLASSIFICATION TOOL")
    print("WITH CHINESE DETECTION")
    print("="*60)

    # Check the input first; its result used to be ignored, so runs went ahead
    # on files that had already been reported as malformed.
    if not validate_json_structure(input_path):
        print(f"Aborting: {input_path} failed structure validation.")
        return

    api_keys = load_api_keys(key_path=gemini_api_path)
    api_manager = GeminiLanguageApiManager(api_keys, calls_per_day=calls_per_day, rate_limit_delay=rate_limit_delay)

    process_json_file(api_manager, input_path, output_path, model_name=model_name, checkpoint_interval=checkpoint_interval)


if __name__ == '__main__':
    main()