Meta

In [None]:
#!/usr/bin/env python3
"""
BigQuery Pipeline with Gemini 1.5 Flash Judge - Meta Dataset
Uses Gemini 1.5 Flash to evaluate Gemini Flash decisions on Meta dataset
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from google.cloud import bigquery
from google.cloud import storage
from google.oauth2 import service_account
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
import json
from datetime import datetime
from typing import Dict, List, Any, Optional
import traceback
import time
import os
from tqdm import tqdm
from dotenv import load_dotenv
import re
import io
from PIL import Image

# Load environment variables from a local `.env` file so that the
# os.getenv() lookups below can see values defined there.
load_dotenv('.env')

# Configuration for Meta dataset with Gemini 1.5 Flash.
# PROJECT_ID / DATASET_ID identify the BigQuery dataset that holds both the
# input table and the results table.
PROJECT_ID = "scope3-dev"
DATASET_ID = "research_bs_monitoring"
META_TABLE = "BA_Meta_Ground_Truth"  # Input: Meta ground truth table
RESULTS_TABLE = "BA_Meta_Gemini_15_Flash_Judge_Results"  # Output: judge results are appended here
# GCS bucket name, overridable through the RESEARCH_BUCKET environment variable.
# NOTE(review): not referenced by the code visible in this section - confirm it is still needed.
RESEARCH_BUCKET = os.getenv('RESEARCH_BUCKET', 'classification-research')
GEMINI_MODEL_NAME = "gemini-1.5-flash"  # Model used as the judge
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # None when the variable is not set

# Register the API key with the Gemini SDK, or stop when no key is available.
# The if/else form is kept on purpose: in some interactive shells `exit()` may
# return instead of raising, and the configure call must not run in that case.
if not GEMINI_API_KEY:
    print("❌ ERROR: Please set your GEMINI_API_KEY environment variable")
    exit(1)
else:
    genai.configure(api_key=GEMINI_API_KEY)
    print("✅ Gemini 1.5 Flash configured successfully!")

# Initialize BigQuery client
def initialize_bigquery_client() -> bigquery.Client:
    """Create a BigQuery client for ``PROJECT_ID`` and verify dataset access.

    The dataset ``PROJECT_ID.DATASET_ID`` is fetched right away so that
    credential or permission problems surface here rather than on the first
    query.

    Returns:
        The connected client.

    Raises:
        Exception: Whatever the client construction or dataset lookup raised;
            troubleshooting hints are printed before it is re-raised.
    """
    try:
        bq_client = bigquery.Client(project=PROJECT_ID)

        # Probe the connection immediately by looking up the dataset.
        print(f"🔧 Testing BigQuery connection to {PROJECT_ID}...")
        dataset_info = bq_client.get_dataset(f"{PROJECT_ID}.{DATASET_ID}")

        print(f"✅ BigQuery connection successful!")
        print(f"   Project: {PROJECT_ID}")
        print(f"   Dataset: {dataset_info.dataset_id}")
        print(f"   Location: {dataset_info.location}")
    except Exception as exc:
        print(f"❌ BigQuery initialization failed: {type(exc).__name__}: {exc}")
        troubleshooting = (
            f"\n💡 Troubleshooting steps:",
            f"   1. Run: gcloud auth application-default login",
            f"   2. Verify project ID: {PROJECT_ID}",
            f"   3. Verify dataset exists: {DATASET_ID}",
            f"   4. Check BigQuery permissions (Data Editor, Job User)",
        )
        for hint in troubleshooting:
            print(hint)
        raise
    else:
        return bq_client

# Initialize clients globally.
# These module-level objects are read directly by the pipeline methods below.
# A BigQuery failure propagates from initialize_bigquery_client().
client = initialize_bigquery_client()

# The storage client is optional: on failure it is set to None and
# load_artifact_content() falls back to demo content.
try:
    storage_client = storage.Client(project=PROJECT_ID)
    print("✅ Storage client initialized for GCS access")
except Exception as e:
    print(f"⚠️ Storage client failed: {e}")
    storage_client = None

# Initialize Gemini model (relies on genai.configure() having been called above).
gemini_model = genai.GenerativeModel(GEMINI_MODEL_NAME)
print(f"✅ {GEMINI_MODEL_NAME} model initialized")

class Gemini15FlashJudgePipeline:
    """Pipeline using Gemini 1.5 Flash to judge Gemini Flash decisions for Meta dataset"""
    
    def __init__(self):
        self.api_call_count = 0
        self.total_api_time = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.successful_evaluations = 0
        self.failed_evaluations = 0

    def check_available_tables(self):
        """List the tables of ``PROJECT_ID.DATASET_ID`` and print them.

        Returns:
            The table names in alphabetical order, or an empty list when the
            lookup fails for any reason (the error is printed, not raised).
        """
        try:
            print(f"🔍 Checking available tables in {PROJECT_ID}.{DATASET_ID}...")

            listing_sql = f"""
            SELECT table_name, table_type, creation_time 
            FROM `{PROJECT_ID}.{DATASET_ID}.INFORMATION_SCHEMA.TABLES`
            ORDER BY table_name
            """
            tables_df = client.query(listing_sql).to_dataframe()

            print(f"📊 Available tables:")
            for name, kind in zip(tables_df['table_name'], tables_df['table_type']):
                print(f"   - {name} ({kind})")

            return tables_df['table_name'].tolist()
        except Exception as exc:
            print(f"⚠️ Could not check tables: {exc}")
            return []

    def load_full_meta_dataset(self, limit: Optional[int] = None) -> pd.DataFrame:
        """Load the Meta ground-truth table from BigQuery.

        Args:
            limit: Maximum number of rows to load, ordered by ``artifact_id``.
                ``None`` or ``0`` loads the complete table.

        Returns:
            DataFrame with the columns ``artifact_id``,
            ``artifact_json_gcs_url``, ``model_prompt``,
            ``correct_classification``, ``correct_reasoning``, ``source`` and
            ``data_source`` (the constant table name).

        Raises:
            ValueError: If ``limit`` is negative or cannot be read as an integer.
            TypeError: If ``limit`` is of a type ``int()`` cannot convert.
            Exception: Any error raised while running the query is printed and
                re-raised.
        """
        if limit:
            # The value is interpolated into the SQL text below, so force it to
            # a plain non-negative integer before it gets anywhere near the query.
            limit = int(limit)
            if limit < 0:
                raise ValueError(f"limit must be non-negative, got {limit}")

        if limit:
            print(f"📥 Loading TEST Meta dataset ({limit} records) from {PROJECT_ID}.{DATASET_ID}.{META_TABLE}...")
        else:
            print(f"📥 Loading COMPLETE Meta dataset from {PROJECT_ID}.{DATASET_ID}.{META_TABLE}...")

        # Print the dataset's table listing as a diagnostic aid; the returned
        # names are not needed for the query itself.
        self.check_available_tables()

        # Simplified query using only Meta table
        limit_clause = f"LIMIT {limit}" if limit else ""
        query = f"""
        SELECT 
            artifact_id,
            artifact_json_gcs_url,
            model_prompt,
            correct_classification,
            correct_reasoning,
            source,
            '{META_TABLE}' as data_source
        FROM `{PROJECT_ID}.{DATASET_ID}.{META_TABLE}`
        ORDER BY artifact_id
        {limit_clause}
        """

        mode = 'test' if limit else 'full'
        try:
            print(f"⏳ Executing query to load {mode} Meta dataset...")
            df = client.query(query).to_dataframe()

            if limit:
                print(f"✅ Successfully loaded TEST Meta dataset!")
                print(f"📊 Test records: {len(df):,} (requested: {limit})")
            else:
                print(f"✅ Successfully loaded COMPLETE Meta dataset!")
                print(f"📊 Total records: {len(df):,}")

            print(f"📊 Classification distribution: {df['correct_classification'].value_counts().to_dict()}")
            print(f"📊 Source distribution: {df['source'].value_counts().to_dict()}")
            print(f"📊 Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.1f} MB")

            return df

        except Exception as e:
            print(f"❌ Error loading {mode} Meta dataset: {type(e).__name__}: {e}")
            raise

    def extract_gcs_path(self, url: str) -> tuple:
        """Extract bucket name and path from GCS URL"""
        try:
            if not isinstance(url, str) or url == 'nan' or not url.strip():
                return None, None
            
            # Check for pandas NaN    
            if pd.isna(url):
                return None, None
                
            # Match gs://bucket-name/path format
            match = re.match(r"gs://([^/]+)/(.+)", url.strip())
            if match:
                return match.group(1), match.group(2)
            return None, None
        except Exception as e:
            print(f"Error extracting GCS path from {url}: {e}")
            return None, None

    def read_json_from_gcs(self, bucket_name: str, json_path: str) -> Optional[Dict]:
        """Download and parse one JSON object from GCS.

        Args:
            bucket_name: Name of the bucket holding the object.
            json_path: Object path inside the bucket.

        Returns:
            The decoded JSON, or ``None`` when the object does not exist or any
            error occurs (a message is printed in both cases).
        """
        try:
            blob = storage_client.bucket(bucket_name).blob(json_path)
            if blob.exists():
                return json.loads(blob.download_as_text())

            print(f"JSON file not found: gs://{bucket_name}/{json_path}")
            return None
        except Exception as exc:
            print(f"Error reading JSON from GCS: {exc}")
            return None

    def load_artifact_content(self, row):
        """Fetch the artifact JSON for ``row`` from GCS, retrying up to 3 times.

        Args:
            row: Record providing ``artifact_id`` and ``artifact_json_gcs_url``
                (accessed with ``row[...]`` and ``row.get(...)``).

        Returns:
            The parsed artifact JSON, or the placeholder from
            ``create_demo_content()`` when no storage client exists, the URL
            cannot be parsed, or every attempt fails.
        """
        if not storage_client:
            print(f"⚠️ No storage client, using demo content for {row['artifact_id']}")
            return self.create_demo_content()

        max_retries = 3
        for attempt in range(max_retries):
            first_attempt = attempt == 0
            retries_left = attempt < max_retries - 1
            try:
                # artifact_json_gcs_url is the primary source for the Meta table
                if pd.notna(row.get('artifact_json_gcs_url')):
                    gcs_url = row['artifact_json_gcs_url']
                    if first_attempt:
                        print(f"📥 Loading Meta content from GCS URL: {gcs_url}")

                    bucket_name, json_path = self.extract_gcs_path(gcs_url)
                    if not (bucket_name and json_path):
                        # A malformed URL will not improve on retry.
                        print(f"⚠️ Could not parse Meta GCS URL: {gcs_url}")
                        break

                    content = self.read_json_from_gcs(bucket_name, json_path)
                    if content:
                        if first_attempt:
                            print(f"✅ Successfully loaded Meta content for {row['artifact_id']}")
                        return content

                    if retries_left:
                        print(f"⚠️ Attempt {attempt + 1} failed to load Meta JSON, retrying...")
                        time.sleep(1)  # short pause before the next try
                        continue

                # Reached on the last attempt when nothing could be loaded.
                if not retries_left:
                    print(f"⚠️ All {max_retries} attempts failed for {row['artifact_id']}, using demo content")
                    return self.create_demo_content()

            except Exception as exc:
                if not retries_left:
                    print(f"⚠️ All retries failed for {row['artifact_id']}: {exc}, using demo content")
                    return self.create_demo_content()
                print(f"⚠️ Attempt {attempt + 1} failed for {row['artifact_id']}: {exc}, retrying...")
                time.sleep(1)

        return self.create_demo_content()
    
    def create_demo_content(self):
        """Create demo content when GCS loading fails"""
        return {
            "url": "https://example.com/content",
            "text_content": [
                {
                    "type": "heading",
                    "heading": {"level": 1, "text": "Sample Content"}
                },
                {
                    "type": "paragraph", 
                    "paragraphs": [
                        "This is sample content for demonstration purposes.",
                        "It represents the type of content that would be evaluated for brand safety."
                    ]
                }
            ]
        }

    def extract_content_from_artifact(self, artifact_json: Dict) -> Dict[str, any]:
        """Extract text and image content from artifact JSON"""
        content = {
            'text_content': '',
            'image_paths': [],
            'has_image': False,
            'content_type': 'unknown'
        }
        
        try:
            text_parts = []
            
            # Look for structured JSON content with text_content array (main approach)
            text_content_array = artifact_json.get("text_content", [])
            if text_content_array and isinstance(text_content_array, list):
                # Parse structured content
                for content_item in text_content_array:
                    if isinstance(content_item, dict):
                        content_type = content_item.get("type", "")
                        
                        if content_type == "paragraph":
                            paragraphs = content_item.get("paragraphs", [])
                            if paragraphs:
                                text_parts.extend(paragraphs)
                        
                        elif content_type == "image":
                            image_data = content_item.get("image", {})
                            if isinstance(image_data, dict):
                                image_identifier = image_data.get("filepath", "")
                                if image_identifier:
                                    content['image_paths'].append(image_identifier)
                                    content['has_image'] = True
                                
                                # Also extract alt_text for context
                                alt_text = image_data.get("alt_text", "")
                                if alt_text:
                                    text_parts.append(f"[Image: {alt_text}]")
                        
                        elif content_type == "video":
                            video_data = content_item.get("video", {})
                            if isinstance(video_data, dict):
                                frames = video_data.get("frames", [])
                                if frames and len(frames) > 0:
                                    # Use first frame as representative image
                                    first_frame = frames[0] if isinstance(frames[0], dict) else {}
                                    frame_identifier = first_frame.get("filepath", "")
                                    if frame_identifier:
                                        content['image_paths'].append(frame_identifier)
                                        content['has_image'] = True
                    
                    elif isinstance(content_item, str):
                        # Handle direct string content
                        text_parts.append(content_item)
            
            # Fallback to legacy format if no structured content found
            if not text_parts and not content['has_image']:
                # Common text fields
                for field in ['title', 'description', 'caption', 'text', 'content', 'body']:
                    if field in artifact_json and artifact_json[field]:
                        text_parts.append(str(artifact_json[field]))
                
                # Legacy image fields
                legacy_content = artifact_json.get("text_content", "")
                if isinstance(legacy_content, str):
                    text_parts.append(legacy_content)
                    
                    # Look for image identifiers in the text using regex - FIXED VERSION
                    image_pattern = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\.(?:jpg|jpeg|png|webp|avif)'
                    image_matches = re.findall(image_pattern, legacy_content)
                    if image_matches:
                        # Use the first full match (complete UUID + extension)
                        content['image_paths'].append(image_matches[0])
                        content['has_image'] = True
                
                # Platform-specific extraction
                if 'platform_data' in artifact_json:
                    platform_data = artifact_json['platform_data']
                    if isinstance(platform_data, dict):
                        for key, value in platform_data.items():
                            if isinstance(value, str) and value.strip():
                                text_parts.append(value)
                
                # Legacy image URL fields
                image_fields = ['image_url', 'image_path', 'media_url', 'thumbnail_url']
                for field in image_fields:
                    if field in artifact_json and artifact_json[field]:
                        content['image_paths'].append(artifact_json[field])
                        content['has_image'] = True
            
            # Combine all text content
            content['text_content'] = ' '.join(text_parts).strip()
            
            # Determine content type
            if content['has_image'] and content['text_content']:
                content['content_type'] = 'multimodal'
            elif content['has_image']:
                content['content_type'] = 'image_only'
            elif content['text_content']:
                content['content_type'] = 'text_only'
            else:
                content['content_type'] = 'no_content'
                
        except Exception as e:
            print(f"Error extracting content: {e}")
        
        return content

    def extract_source_context(self, content):
        """Extract text content from artifact"""
        extracted_content = self.extract_content_from_artifact(content)
        return extracted_content['text_content'] if extracted_content['text_content'] else "Sample content for evaluation"

    def query_gemini_judge(self, content, brand_prompt, gemini_flash_verdict):
        """Query Gemini 1.5 Flash with retry logic"""
        max_retries = 3
        base_delay = 1
        
        for attempt in range(max_retries):
            try:
                # Extract text content using the same method as your pipeline
                source_context = self.extract_source_context(content)
                
                # Create judge prompt - Gemini 1.5 Flash evaluating Gemini Flash
                judge_prompt = f"""You are an expert content moderator using Gemini 1.5 Flash to evaluate another AI model's (Gemini Flash) brand safety decision.

Use your reasoning capabilities to provide a thorough analysis of whether Gemini Flash made the right decision.

BRAND GUIDELINES:
{brand_prompt}

CONTENT TO EVALUATE:
{source_context[:2000]}{"..." if len(source_context) > 2000 else ""}

GEMINI FLASH'S VERDICT:
{gemini_flash_verdict}

As Gemini 1.5 Flash, evaluate whether Gemini Flash's verdict is correct and well-reasoned. Consider:
1. Accuracy of the classification given the brand guidelines
2. Quality and completeness of the reasoning provided
3. Any nuances or edge cases that Flash may have missed
4. Whether you would reach the same conclusion

Respond in JSON format:
{{
  "judge_agreement": true/false,
  "confidence": 0.0-1.0,
  "verdict": "Aligned" or "Not-Aligned",
  "reasoning": "Your detailed analysis as Gemini 1.5 Flash of why Flash's decision was correct or incorrect",
  "improvements": ["suggestions for better reasoning if disagreeing"],
  "flash_analysis": "analysis of Flash's performance on this content"
}}"""
                
                # Configure safety settings
                safety_settings = {
                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                }
                
                # Make API call with Gemini 1.5 Flash
                start_time = time.time()
                response = gemini_model.generate_content(
                    judge_prompt,
                    safety_settings=safety_settings
                )
                api_time = time.time() - start_time
                
                # Update tracking
                self.api_call_count += 1
                self.total_api_time += api_time
                
                # Track tokens if available
                if hasattr(response, 'usage_metadata'):
                    self.total_input_tokens += response.usage_metadata.prompt_token_count
                    self.total_output_tokens += response.usage_metadata.candidates_token_count
                
                if response.text:
                    result = self.parse_gemini_response(response.text, api_time)
                    result['model_used'] = GEMINI_MODEL_NAME
                    result['attempts'] = attempt + 1
                    return result
                else:
                    if attempt < max_retries - 1:
                        delay = base_delay * (2 ** attempt)  # Exponential backoff
                        print(f"⚠️ Empty response, retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                        time.sleep(delay)
                        continue
                    else:
                        return {
                            'error': f'Empty response from {GEMINI_MODEL_NAME} after {max_retries} attempts',
                            'model_used': GEMINI_MODEL_NAME,
                            'api_time': api_time,
                            'attempts': max_retries
                        }
                        
            except Exception as e:
                if attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)  # Exponential backoff
                    print(f"⚠️ API error: {str(e)[:100]}..., retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                    continue
                else:
                    return {
                        'error': f'Failed after {max_retries} attempts: {str(e)}',
                        'model_used': GEMINI_MODEL_NAME,
                        'api_time': 0,
                        'attempts': max_retries
                    }
        
        # Should never reach here, but just in case
        return {
            'error': f'Unexpected failure after {max_retries} attempts',
            'model_used': GEMINI_MODEL_NAME,
            'api_time': 0,
            'attempts': max_retries
        }

    def parse_gemini_response(self, response_text, api_time):
        """Parse Gemini 1.5 Flash response"""
        try:
            # Clean the response - remove markdown code blocks if present
            cleaned_text = response_text.strip()
            if cleaned_text.startswith('```json'):
                cleaned_text = cleaned_text.replace('```json', '').replace('```', '').strip()
            elif cleaned_text.startswith('```'):
                cleaned_text = cleaned_text.replace('```', '').strip()
            
            # Try to find JSON in response
            if '{' in cleaned_text and '}' in cleaned_text:
                start_idx = cleaned_text.find('{')
                end_idx = cleaned_text.rfind('}') + 1
                json_str = cleaned_text[start_idx:end_idx]
                
                parsed = json.loads(json_str)
                
                # Ensure would_reach_same_conclusion is present
                would_reach_same = parsed.get('would_reach_same_conclusion')
                if would_reach_same is None:
                    # Try to infer from the text if not explicitly provided
                    full_text = response_text.lower()
                    if any(phrase in full_text for phrase in ['same conclusion', 'reach the same', 'would also classify', 'independent conclusion']):
                        would_reach_same = True
                    elif any(phrase in full_text for phrase in ['different conclusion', 'would not reach', 'disagree with classification']):
                        would_reach_same = False
                    else:
                        # Default to same as judge_agreement if unclear
                        would_reach_same = parsed.get('judge_agreement', None)
                
                return {
                    'judge_agreement': parsed.get('judge_agreement', None),
                    'confidence': parsed.get('confidence', 0.0),
                    'verdict': parsed.get('verdict', 'Unknown'),
                    'would_reach_same_conclusion': would_reach_same,
                    'reasoning': parsed.get('reasoning', ''),
                    'improvements': parsed.get('improvements', []),
                    'flash_analysis': parsed.get('flash_analysis', ''),
                    'raw_response': response_text,
                    'api_time': api_time
                }
            else:
                # Fallback parsing
                agreement = any(word in response_text.lower() for word in ['agree', 'correct', 'accurate'])
                same_conclusion = any(phrase in response_text.lower() for phrase in [
                    'same conclusion', 'reach the same', 'would also', 'independent conclusion'
                ])
                return {
                    'judge_agreement': agreement,
                    'confidence': 0.5,
                    'verdict': 'Aligned' if agreement else 'Not-Aligned', 
                    'would_reach_same_conclusion': same_conclusion if same_conclusion else agreement,
                    'reasoning': response_text,
                    'improvements': [],
                    'flash_analysis': '',
                    'raw_response': response_text,
                    'api_time': api_time
                }
                
        except json.JSONDecodeError as e:
            print(f"⚠️ JSON parsing failed: {e}")
            print(f"Raw text: {response_text[:200]}...")
            
            agreement = any(word in response_text.lower() for word in ['agree', 'reasonable'])
            # Try to infer would_reach_same_conclusion from text
            same_conclusion = any(phrase in response_text.lower() for phrase in [
                'same conclusion', 'reach the same', 'would also', 'agree with the verdict'
            ])
            return {
                'judge_agreement': agreement,
                'would_reach_same_conclusion': same_conclusion if same_conclusion else agreement,
                'confidence': 0.5,
                'verdict': 'Aligned' if agreement else 'Not-Aligned',
                'reasoning': response_text,
                'improvements': [],
                'flash_analysis': '',
                'raw_response': response_text,
                'api_time': api_time
            }

    def run_full_dataset_evaluation(self, meta_df: pd.DataFrame, 
                                   batch_size: int = 100,
                                   save_to_bigquery: bool = True) -> pd.DataFrame:
        """Run the judge over every row of ``meta_df``, batch by batch.

        For each row the artifact is loaded, the judge is queried with the
        row's ``model_prompt`` and ``correct_reasoning``, and one result record
        is built. Each non-empty batch is optionally appended to BigQuery.

        Args:
            meta_df: Rows to evaluate; the columns ``artifact_id``,
                ``data_source``, ``correct_classification``,
                ``correct_reasoning`` and ``model_prompt`` are read
                (``source`` is optional).
            batch_size: Number of rows per batch; must be positive.
            save_to_bigquery: When True, each batch is passed to
                ``save_batch_to_bigquery``.

        Returns:
            DataFrame with one record per row that could be processed. Rows
            whose judge call returned an error are included with
            ``error_message`` set; rows that raised are omitted.

        Raises:
            ValueError: If ``batch_size`` is not positive.

        Notes:
            Agreement rates are computed over records with a non-null
            ``judge_agreement``; average confidence ignores the 0.0
            placeholder used when no confidence is available. Judge calls
            that returned an error count as failed evaluations.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        def to_float(value, default=0.0):
            try:
                return float(value)
            except (TypeError, ValueError):
                return default

        def agreement_rate(records):
            decided = [r['judge_agreement'] for r in records if r['judge_agreement'] is not None]
            if not decided:
                return 0
            return sum(1 for flag in decided if flag is True) / len(decided)

        def mean_confidence(records):
            # Guarding the empty case avoids np.mean([]) returning NaN with a warning.
            scores = [r['confidence'] for r in records if r['confidence'] > 0]
            return float(np.mean(scores)) if scores else 0

        total_records = len(meta_df)
        total_batches = (total_records + batch_size - 1) // batch_size

        print(f"🔄 RUNNING GEMINI 1.5 FLASH EVALUATION ON FULL META DATASET")
        print("=" * 70)
        print(f"📊 Total records to process: {total_records:,}")
        print(f"📦 Batch size: {batch_size:,}")
        print(f"📦 Number of batches: {total_batches}")
        print(f"🧠 Judge model: {GEMINI_MODEL_NAME}")
        print(f"⚡ Flash model: Gemini Flash (from ground truth)")

        all_results = []

        for batch_num in range(0, total_records, batch_size):
            batch_end = min(batch_num + batch_size, total_records)
            batch_df = meta_df.iloc[batch_num:batch_end]
            batch_number = (batch_num // batch_size) + 1

            print(f"\n🔄 Processing Batch {batch_number}/{total_batches}")
            print(f"   Records: {batch_num + 1:,} to {batch_end:,}")
            print(f"   Batch size: {len(batch_df):,}")

            batch_start_time = time.time()
            batch_results = []
            batch_successful = 0
            batch_failed = 0

            for _, row in tqdm(batch_df.iterrows(), total=len(batch_df), 
                               desc=f"Batch {batch_number}", leave=False):

                try:
                    content = self.load_artifact_content(row)

                    judge_result = self.query_gemini_judge(
                        content, 
                        row['model_prompt'], 
                        row['correct_reasoning']
                    )

                    # Accept a list, a single string or None without failing the row.
                    improvements = judge_result.get('improvements') or []
                    if not isinstance(improvements, (list, tuple)):
                        improvements = [improvements]

                    has_error = 'error' in judge_result

                    result = {
                        'artifact_id': str(row['artifact_id']),
                        'data_source': str(row['data_source']),
                        'source': str(row.get('source', 'unknown')),
                        'flash_classification': int(row['correct_classification']),
                        'flash_reasoning': str(row['correct_reasoning'])[:1000],  # Truncate for storage
                        'model_prompt': str(row['model_prompt'])[:500],
                        'judge_agreement': judge_result.get('judge_agreement', None),
                        'verdict': str(judge_result.get('verdict', 'Unknown')),
                        'confidence': to_float(judge_result.get('confidence', 0.0)),
                        'would_reach_same_conclusion': judge_result.get('would_reach_same_conclusion', None),
                        'reasoning': str(judge_result.get('reasoning', ''))[:1000],
                        'flash_analysis': str(judge_result.get('flash_analysis', ''))[:500],
                        'improvements': '; '.join(str(item) for item in improvements)[:500],
                        'api_call_time': to_float(judge_result.get('api_time', 0.0)),
                        'batch_number': int(batch_number),
                        # Timezone-aware so the TIMESTAMP column stores the real
                        # instant rather than local wall-clock time read as UTC.
                        'created_at': pd.Timestamp.now(tz='UTC'),
                        'model_used': str(judge_result.get('model_used', GEMINI_MODEL_NAME)),
                        'error_message': str(judge_result.get('error', '')) if has_error else None,
                        'retry_attempts': int(judge_result.get('attempts', 1))
                    }

                    batch_results.append(result)

                    # The record is kept either way, but a judge call that
                    # ended in an error is not a successful evaluation.
                    if has_error:
                        batch_failed += 1
                        self.failed_evaluations += 1
                    else:
                        batch_successful += 1
                        self.successful_evaluations += 1

                except Exception as e:
                    print(f"❌ Error processing {row['artifact_id']}: {e}")
                    batch_failed += 1
                    self.failed_evaluations += 1
                    continue

            batch_time = time.time() - batch_start_time

            batch_agreement_rate = agreement_rate(batch_results)
            batch_avg_confidence = mean_confidence(batch_results)

            print(f"   ✅ Batch {batch_number} completed in {batch_time:.1f}s")
            print(f"   📊 Successful: {batch_successful}, Failed: {batch_failed}")
            print(f"   📊 Agreement rate: {batch_agreement_rate:.1%}")
            print(f"   📊 Avg confidence: {batch_avg_confidence:.2f}")
            print(f"   📊 API calls: {self.api_call_count}")

            all_results.extend(batch_results)

            if save_to_bigquery and len(batch_results) > 0:
                self.save_batch_to_bigquery(pd.DataFrame(batch_results), batch_number)

        results_df = pd.DataFrame(all_results)

        # Final statistics use the same definitions as the per-batch figures.
        final_agreement_rate = agreement_rate(all_results)
        final_avg_confidence = mean_confidence(all_results)
        total_api_time = sum(r['api_call_time'] for r in all_results)

        print(f"\n✅ GEMINI 1.5 FLASH EVALUATION COMPLETED!")
        print(f"📊 Total records processed: {len(results_df):,}")
        print(f"📊 Successful evaluations: {self.successful_evaluations}")
        print(f"📊 Failed evaluations: {self.failed_evaluations}")
        print(f"📊 Overall agreement rate: {final_agreement_rate:.1%}")
        print(f"📊 Average confidence: {final_avg_confidence:.2f}")
        print(f"📊 Total API calls: {self.api_call_count}")
        print(f"📊 Total API time: {total_api_time:.1f}s ({total_api_time/60:.1f} minutes)")

        # Cost estimation from per-token rates.
        # NOTE(review): the rates below are hard-coded - verify them against
        # the current published Gemini 1.5 Flash pricing.
        input_cost = self.total_input_tokens * 0.00000075
        output_cost = self.total_output_tokens * 0.0000015
        total_cost = input_cost + output_cost
        print(f"💰 Estimated cost: ${total_cost:.4f}")
        print(f"   Input tokens: {self.total_input_tokens:,} (${input_cost:.4f})")
        print(f"   Output tokens: {self.total_output_tokens:,} (${output_cost:.4f})")

        return results_df

    def save_batch_to_bigquery(self, batch_df: pd.DataFrame, batch_number: int):
        """Append one batch of result records to the results table.

        The table ``PROJECT_ID.DATASET_ID.RESULTS_TABLE`` is created on first
        use. Failures are printed and swallowed so that one bad batch does not
        stop the overall run.

        Args:
            batch_df: Result records for the batch.
            batch_number: 1-based batch index, used only in log messages.
        """
        print(f"   💾 Saving batch {batch_number} to BigQuery ({len(batch_df)} records)...")

        # (column name, BigQuery type, mode) for every result column
        column_specs = (
            ("artifact_id", "STRING", "REQUIRED"),
            ("data_source", "STRING", "NULLABLE"),
            ("source", "STRING", "NULLABLE"),
            ("flash_classification", "INTEGER", "NULLABLE"),
            ("flash_reasoning", "STRING", "NULLABLE"),
            ("model_prompt", "STRING", "NULLABLE"),
            ("judge_agreement", "BOOLEAN", "NULLABLE"),
            ("verdict", "STRING", "NULLABLE"),
            ("confidence", "FLOAT", "NULLABLE"),
            ("would_reach_same_conclusion", "BOOLEAN", "NULLABLE"),
            ("reasoning", "STRING", "NULLABLE"),
            ("flash_analysis", "STRING", "NULLABLE"),
            ("improvements", "STRING", "NULLABLE"),
            ("api_call_time", "FLOAT", "NULLABLE"),
            ("batch_number", "INTEGER", "NULLABLE"),
            ("created_at", "TIMESTAMP", "NULLABLE"),
            ("model_used", "STRING", "NULLABLE"),
            ("error_message", "STRING", "NULLABLE"),
            ("retry_attempts", "INTEGER", "NULLABLE"),
        )

        try:
            destination = f"{PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}"

            load_config = bigquery.LoadJobConfig(
                write_disposition="WRITE_APPEND",  # add to existing rows
                create_disposition="CREATE_IF_NEEDED",  # first batch creates the table
                schema=[
                    bigquery.SchemaField(column, field_type, mode=mode)
                    for column, field_type, mode in column_specs
                ]
            )

            load_job = client.load_table_from_dataframe(batch_df, destination, job_config=load_config)
            load_job.result()  # block until the load finishes

            print(f"   ✅ Batch {batch_number} saved successfully!")

        except Exception as exc:
            # Deliberately best effort: report the failure and carry on.
            print(f"   ❌ Error saving batch {batch_number}: {exc}")

def main_gemini_15_flash_evaluation():
    """
    Run the Gemini 1.5 Flash judge evaluation end to end on the Meta dataset.

    Steps: load the Meta table, smoke-test the judge on the first record,
    evaluate every record in batches (with BigQuery saving enabled), then
    write a JSON summary report into the current working directory.

    Returns:
        tuple: ``(results_df, summary_report)`` on success, or
        ``(None, None)`` when a step fails or yields no data. A pair is
        returned on every path so callers can always unpack the result.
    """

    print("🚀 STARTING GEMINI 1.5 FLASH JUDGE EVALUATION - FULL META DATASET")
    print("=" * 80)

    try:
        # Initialize pipeline
        pipeline = Gemini15FlashJudgePipeline()

        # Step 1: Load the complete Meta dataset
        print(f"\n📋 Step 1: Loading complete Meta dataset...")
        meta_df = pipeline.load_full_meta_dataset()

        if meta_df.empty:
            print("❌ No data loaded from Meta table")
            return None, None

        # Step 2: Test with one sample first so a broken setup fails fast
        # instead of after thousands of API calls
        print(f"\n🧪 Step 2: Testing with one sample...")
        test_row = meta_df.iloc[0]
        print(f"   Testing Meta artifact: {test_row['artifact_id']}")

        content = pipeline.load_artifact_content(test_row)
        judge_result = pipeline.query_gemini_judge(
            content,
            test_row['model_prompt'],
            test_row['correct_reasoning']
        )

        if 'error' in judge_result:
            print(f"❌ Test failed: {judge_result['error']}")
            return None, None

        print(f"✅ Test successful!")
        print(f"   Agreement: {judge_result.get('judge_agreement')}")
        print(f"   Confidence: {judge_result.get('confidence')}")
        print(f"   API time: {judge_result.get('api_time', 0.0):.2f}s")

        # Step 3: Run evaluation on dataset
        batch_size = 50  # Optimized batch size for full dataset
        print(f"\n🔄 Step 3: Running evaluation on {len(meta_df):,} records...")
        results_df = pipeline.run_full_dataset_evaluation(
            meta_df,
            batch_size=batch_size,
            save_to_bigquery=True
        )

        if results_df.empty:
            print("❌ No evaluation results generated")
            return None, None

        # Step 4: Create summary report
        print(f"\n📄 Step 4: Creating summary report...")

        summary_report = {
            'timestamp': pd.Timestamp.now().isoformat(),
            'dataset_info': {
                'source_table': f"{PROJECT_ID}.{DATASET_ID}.{META_TABLE}",
                'results_table': f"{PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}",
                'total_records_processed': len(results_df),
                'successful_evaluations': pipeline.successful_evaluations,
                'failed_evaluations': pipeline.failed_evaluations,
                'batch_size': batch_size
            },
            'model_info': {
                'judge_model': GEMINI_MODEL_NAME,
                'flash_model': 'gemini-flash',
                'api_calls': pipeline.api_call_count,
                'total_api_time': pipeline.total_api_time,
                'total_input_tokens': pipeline.total_input_tokens,
                'total_output_tokens': pipeline.total_output_tokens
            },
            'performance_metrics': {
                'overall_agreement_rate': float(results_df['judge_agreement'].mean()),
                'average_confidence': float(results_df['confidence'].mean()),
                # .to_dict() yields native Python numbers, so the counts are
                # written as JSON numbers rather than stringified by default=str
                'classification_breakdown': results_df['flash_classification'].value_counts().to_dict(),
                'agreement_by_classification': (
                    results_df.groupby('flash_classification')['judge_agreement'].mean().to_dict()
                )
            }
        }

        # Save summary to file
        summary_filename = f"gemini_15_flash_meta_summary_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(summary_filename, 'w', encoding='utf-8') as f:
            json.dump(summary_report, f, indent=2, default=str)

        print(f"✅ Summary report saved to: {summary_filename}")

        # Final success message
        print(f"\n🎉 GEMINI 1.5 FLASH JUDGE EVALUATION COMPLETED SUCCESSFULLY!")
        print(f"📊 Total records processed: {len(results_df):,}")
        print(f"📊 Overall agreement rate: {results_df['judge_agreement'].mean():.1%}")
        print(f"📊 API calls made: {pipeline.api_call_count:,}")
        print(f"💾 Results saved to: {PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}")
        print(f"📄 Summary report: {summary_filename}")

        # Show final query to view results
        print(f"\n🔍 QUERY TO VIEW RESULTS:")
        print(f"SELECT * FROM `{PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}` ORDER BY created_at DESC LIMIT 100")

        return results_df, summary_report

    except Exception as e:
        # Top-level boundary: report the failure and hand back the same
        # two-element shape as the success path
        print(f"❌ GEMINI 1.5 FLASH EVALUATION FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        return None, None

if __name__ == "__main__":
    # Script entry point: announce the run, execute the full Meta dataset
    # evaluation with Gemini 1.5 Flash, then report the outcome.
    print("🚀 RUNNING FULL META DATASET EVALUATION WITH GEMINI 1.5 FLASH")
    print("Enhanced with robust retry logic and cleaned schema")
    print("=" * 60)

    results, summary = main_gemini_15_flash_evaluation()

    if results is None:
        print("\n❌ FULL Meta dataset evaluation failed - check error messages above")
    else:
        print("\n✅ FULL Meta dataset evaluation completed successfully!")
        print(f"📊 {len(results):,} evaluations completed")
        print(f"📊 Agreement rate: {results['judge_agreement'].mean():.1%}")
        # Falls back to a one-element Series when the column is absent
        print(f"📊 Average retry attempts: {results.get('retry_attempts', pd.Series([1])).mean():.1f}")

Web data

In [None]:
#!/usr/bin/env python3
"""
BigQuery Pipeline with Gemini 1.5 Flash Judge - Web Dataset
Uses Gemini 1.5 Flash to evaluate Gemini Flash decisions on Web dataset
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from google.cloud import bigquery
from google.cloud import storage
from google.oauth2 import service_account
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold
import json
from datetime import datetime
from typing import Dict, List, Any, Optional
import traceback
import time
import os
from tqdm import tqdm
from dotenv import load_dotenv
import re
import io
from PIL import Image

# Load environment variables from a local .env file (read below via os.getenv)
load_dotenv('.env')

# Configuration for Web dataset with Gemini 1.5 Flash
PROJECT_ID = "scope3-dev"
DATASET_ID = "research_bs_monitoring"
WEB_TABLE = "BA_Web_Ground_Truth"  # Web ground truth table
RESULTS_TABLE = "BA_Web_Gemini_15_Flash_Judge_Results"  # Results table for Gemini 1.5 Flash
RESEARCH_BUCKET = os.getenv('RESEARCH_BUCKET', 'classification-research')
GEMINI_MODEL_NAME = "gemini-1.5-flash"  # Gemini 1.5 Flash model
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Configure Gemini 1.5 Flash; nothing below can work without an API key
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
    print("✅ Gemini 1.5 Flash configured successfully!")
else:
    print("❌ ERROR: Please set your GEMINI_API_KEY environment variable")
    # exit() is a helper the site module adds for interactive sessions and is
    # not meant for programs; raising SystemExit works in every context.
    raise SystemExit(1)

# Initialize BigQuery client
def initialize_bigquery_client() -> bigquery.Client:
    """Create a BigQuery client for PROJECT_ID and confirm DATASET_ID is reachable.

    Returns:
        The connected ``bigquery.Client``.

    Raises:
        Whatever exception client creation or the dataset lookup raised; it is
        re-raised unchanged after troubleshooting hints have been printed.
    """
    try:
        bq_client = bigquery.Client(project=PROJECT_ID)

        print(f"🔧 Testing BigQuery connection to {PROJECT_ID}...")

        # Fetching the dataset metadata doubles as the connectivity check
        dataset_info = bq_client.get_dataset(f"{PROJECT_ID}.{DATASET_ID}")
        print("✅ BigQuery connection successful!")
        print(f"   Project: {PROJECT_ID}")
        print(f"   Dataset: {dataset_info.dataset_id}")
        print(f"   Location: {dataset_info.location}")

        return bq_client

    except Exception as e:
        print(f"❌ BigQuery initialization failed: {type(e).__name__}: {e}")
        print("\n💡 Troubleshooting steps:")
        hints = (
            "Run: gcloud auth application-default login",
            f"Verify project ID: {PROJECT_ID}",
            f"Verify dataset exists: {DATASET_ID}",
            "Check BigQuery permissions (Data Editor, Job User)",
        )
        for step_number, hint in enumerate(hints, start=1):
            print(f"   {step_number}. {hint}")
        raise

# Initialize clients globally
# Module-level BigQuery client shared by the pipeline methods below.
# initialize_bigquery_client() re-raises on failure, so a bad connection
# stops execution here.
client = initialize_bigquery_client()

# Cloud Storage is optional: if the client cannot be created, storage_client
# is set to None and load_artifact_content() returns demo content instead.
try:
    storage_client = storage.Client(project=PROJECT_ID)
    print("✅ Storage client initialized for GCS access")
except Exception as e:
    print(f"⚠️ Storage client failed: {e}")
    storage_client = None

# Initialize Gemini model
# Shared model handle used by query_gemini_judge() for every judge request.
gemini_model = genai.GenerativeModel(GEMINI_MODEL_NAME)
print(f"✅ {GEMINI_MODEL_NAME} model initialized")

class Gemini15FlashJudgePipeline:
    """Pipeline using Gemini 1.5 Flash to judge Gemini Flash decisions for Web dataset"""
    
    def __init__(self):
        self.api_call_count = 0
        self.total_api_time = 0
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.successful_evaluations = 0
        self.failed_evaluations = 0

    def check_available_tables(self):
        """Print and return the table names found in PROJECT_ID.DATASET_ID.

        Queries INFORMATION_SCHEMA.TABLES through the module-level BigQuery
        client. Any failure is reported on stdout and yields an empty list.
        """
        try:
            print(f"🔍 Checking available tables in {PROJECT_ID}.{DATASET_ID}...")

            tables_sql = f"""
            SELECT table_name, table_type, creation_time 
            FROM `{PROJECT_ID}.{DATASET_ID}.INFORMATION_SCHEMA.TABLES`
            ORDER BY table_name
            """

            listing = client.query(tables_sql).to_dataframe()
            print("📊 Available tables:")
            for table_name, table_type in zip(listing['table_name'], listing['table_type']):
                print(f"   - {table_name} ({table_type})")

            return listing['table_name'].tolist()

        except Exception as e:
            # Best-effort diagnostic: never let the listing block the caller
            print(f"⚠️ Could not check tables: {e}")
            return []

    def load_full_web_dataset(self, limit: Optional[int] = None) -> pd.DataFrame:
        """Load the Web ground-truth table from BigQuery.

        Args:
            limit: Maximum number of rows to load (ordered by artifact_id).
                ``None`` or ``0`` loads the whole table.

        Returns:
            DataFrame with artifact_id, artifact_json_gcs_url, model_prompt,
            correct_classification, correct_reasoning, source and a constant
            data_source column holding the table name.

        Raises:
            ValueError: if ``limit`` is negative or is not an integer value.
            Exception: any error raised while running the query is printed
                and re-raised unchanged.
        """
        if limit is not None:
            # int() rejects anything that is not a plain number, so the value
            # interpolated into the LIMIT clause can never carry SQL text
            limit = int(limit)
            if limit < 0:
                raise ValueError(f"limit must be non-negative, got {limit}")

        mode = 'test' if limit else 'full'

        if limit:
            print(f"📥 Loading TEST Web dataset ({limit} records) from {PROJECT_ID}.{DATASET_ID}.{WEB_TABLE}...")
        else:
            print(f"📥 Loading COMPLETE Web dataset from {PROJECT_ID}.{DATASET_ID}.{WEB_TABLE}...")

        # Called only for its printed listing of the dataset's tables
        self.check_available_tables()

        # Simplified query using only Web table
        limit_clause = f"LIMIT {limit}" if limit else ""
        query = f"""
        SELECT 
            artifact_id,
            artifact_json_gcs_url,
            model_prompt,
            correct_classification,
            correct_reasoning,
            source,
            '{WEB_TABLE}' as data_source
        FROM `{PROJECT_ID}.{DATASET_ID}.{WEB_TABLE}`
        ORDER BY artifact_id
        {limit_clause}
        """

        try:
            print(f"⏳ Executing query to load {mode} Web dataset...")
            df = client.query(query).to_dataframe()

            if limit:
                print(f"✅ Successfully loaded TEST Web dataset!")
                print(f"📊 Test records: {len(df):,} (requested: {limit})")
            else:
                print(f"✅ Successfully loaded COMPLETE Web dataset!")
                print(f"📊 Total records: {len(df):,}")

            print(f"📊 Classification distribution: {df['correct_classification'].value_counts().to_dict()}")
            print(f"📊 Source distribution: {df['source'].value_counts().to_dict()}")
            print(f"📊 Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.1f} MB")

            return df

        except Exception as e:
            print(f"❌ Error loading {mode} Web dataset: {type(e).__name__}: {e}")
            raise

    def extract_gcs_path(self, url: str) -> tuple:
        """Split a ``gs://bucket/path`` URL into ``(bucket, path)``.

        Returns ``(None, None)`` when ``url`` is not a string, is blank, is the
        literal text 'nan', or does not match the ``gs://`` format.
        """
        try:
            # Non-string values (None, float NaN, ...) collapse to "" here, so
            # no separate NaN check is needed further down
            trimmed = url.strip() if isinstance(url, str) else ""
            if not trimmed or url == 'nan':
                return None, None

            # Match gs://bucket-name/path format
            gcs_match = re.match(r"gs://([^/]+)/(.+)", trimmed)
            if gcs_match is None:
                return None, None
            return gcs_match.group(1), gcs_match.group(2)
        except Exception as e:
            print(f"Error extracting GCS path from {url}: {e}")
            return None, None

    def read_json_from_gcs(self, bucket_name: str, json_path: str) -> Optional[Dict]:
        """Download ``gs://bucket_name/json_path`` and parse it as JSON.

        Returns the parsed object, or ``None`` when the blob does not exist or
        when anything fails along the way (the reason is printed).
        """
        try:
            blob = storage_client.bucket(bucket_name).blob(json_path)

            if blob.exists():
                return json.loads(blob.download_as_text())

            print(f"JSON file not found: gs://{bucket_name}/{json_path}")
            return None
        except Exception as e:
            print(f"Error reading JSON from GCS: {e}")
            return None

    def load_artifact_content(self, row):
        """Load artifact content with retry logic"""
        if not storage_client:
            print(f"⚠️ No storage client, using demo content for {row['artifact_id']}")
            return self.create_demo_content()
        
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Try to load from artifact_json_gcs_url (primary method for Web table)
                if pd.notna(row.get('artifact_json_gcs_url')):
                    gcs_url = row['artifact_json_gcs_url']
                    if attempt == 0:
                        print(f"📥 Loading Web content from GCS URL: {gcs_url}")
                    
                    bucket_name, json_path = self.extract_gcs_path(gcs_url)
                    if bucket_name and json_path:
                        content = self.read_json_from_gcs(bucket_name, json_path)
                        if content:
                            if attempt == 0:
                                print(f"✅ Successfully loaded Web content for {row['artifact_id']}")
                            return content
                        else:
                            if attempt < max_retries - 1:
                                print(f"⚠️ Attempt {attempt + 1} failed to load Web JSON, retrying...")
                                time.sleep(1)  # Brief pause before retry
                                continue
                    else:
                        print(f"⚠️ Could not parse Web GCS URL: {gcs_url}")
                        break
                
                # If all retries fail, use demo content
                if attempt == max_retries - 1:
                    print(f"⚠️ All {max_retries} attempts failed for {row['artifact_id']}, using demo content")
                    return self.create_demo_content()
                    
            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"⚠️ Attempt {attempt + 1} failed for {row['artifact_id']}: {e}, retrying...")
                    time.sleep(1)
                    continue
                else:
                    print(f"⚠️ All retries failed for {row['artifact_id']}: {e}, using demo content")
                    return self.create_demo_content()
        
        return self.create_demo_content()
    
    def create_demo_content(self):
        """Build the placeholder artifact returned when GCS loading fails.

        The structure is a URL plus a ``text_content`` list holding one
        heading block and one paragraph block.
        """
        heading_block = {
            "type": "heading",
            "heading": {"level": 1, "text": "Sample Content"},
        }
        paragraph_block = {
            "type": "paragraph",
            "paragraphs": [
                "This is sample content for demonstration purposes.",
                "It represents the type of content that would be evaluated for brand safety.",
            ],
        }
        return {
            "url": "https://example.com/content",
            "text_content": [heading_block, paragraph_block],
        }

    def extract_content_from_artifact(self, artifact_json: Dict) -> Dict[str, Any]:
        """Extract text and image references from an artifact JSON document.

        The structured format (a ``text_content`` list of typed items) is
        tried first: paragraph text, image file paths plus alt text, and the
        first frame of each video are collected. If that yields nothing, the
        legacy flat fields are used instead.

        Args:
            artifact_json: Parsed artifact document.

        Returns:
            dict with keys 'text_content' (str), 'image_paths' (list),
            'has_image' (bool) and 'content_type' ('multimodal', 'image_only',
            'text_only', 'no_content', or 'unknown' if extraction raised).

        NOTE(review): items of type "heading" (as produced by
        create_demo_content) are not extracted - confirm that is intended.
        """
        content = {
            'text_content': '',
            'image_paths': [],
            'has_image': False,
            'content_type': 'unknown'
        }

        try:
            text_parts = []

            # Look for structured JSON content with text_content array (main approach)
            text_content_array = artifact_json.get("text_content", [])
            if text_content_array and isinstance(text_content_array, list):
                for content_item in text_content_array:
                    if isinstance(content_item, str):
                        # Handle direct string content
                        text_parts.append(content_item)
                        continue
                    if not isinstance(content_item, dict):
                        continue

                    item_type = content_item.get("type", "")

                    if item_type == "paragraph":
                        paragraphs = content_item.get("paragraphs", [])
                        if isinstance(paragraphs, (list, tuple)):
                            paragraph_items = paragraphs
                        elif paragraphs:
                            # A bare value is one paragraph; extending with a
                            # str would split it into single characters
                            paragraph_items = [paragraphs]
                        else:
                            paragraph_items = []
                        # Coerce to str so one non-string entry cannot make the
                        # final join raise and discard all extracted text
                        text_parts.extend(
                            str(paragraph) for paragraph in paragraph_items
                            if paragraph is not None
                        )

                    elif item_type == "image":
                        image_data = content_item.get("image", {})
                        if isinstance(image_data, dict):
                            image_identifier = image_data.get("filepath", "")
                            if image_identifier:
                                content['image_paths'].append(image_identifier)
                                content['has_image'] = True

                            # Also extract alt_text for context
                            alt_text = image_data.get("alt_text", "")
                            if alt_text:
                                text_parts.append(f"[Image: {alt_text}]")

                    elif item_type == "video":
                        video_data = content_item.get("video", {})
                        if isinstance(video_data, dict):
                            frames = video_data.get("frames", [])
                            if frames and len(frames) > 0:
                                # Use first frame as representative image
                                first_frame = frames[0] if isinstance(frames[0], dict) else {}
                                frame_identifier = first_frame.get("filepath", "")
                                if frame_identifier:
                                    content['image_paths'].append(frame_identifier)
                                    content['has_image'] = True

            # Fallback to legacy format if no structured content found
            if not text_parts and not content['has_image']:
                # Common text fields
                for field in ['title', 'description', 'caption', 'text', 'content', 'body']:
                    if field in artifact_json and artifact_json[field]:
                        text_parts.append(str(artifact_json[field]))

                # Legacy format: text_content is a single string
                legacy_content = artifact_json.get("text_content", "")
                if isinstance(legacy_content, str):
                    text_parts.append(legacy_content)

                    # Look for image identifiers (UUID + extension) in the text
                    image_pattern = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}\.(?:jpg|jpeg|png|webp|avif)'
                    image_matches = re.findall(image_pattern, legacy_content)
                    if image_matches:
                        # Use the first full match (complete UUID + extension)
                        content['image_paths'].append(image_matches[0])
                        content['has_image'] = True

                # Platform-specific extraction
                if 'platform_data' in artifact_json:
                    platform_data = artifact_json['platform_data']
                    if isinstance(platform_data, dict):
                        for value in platform_data.values():
                            if isinstance(value, str) and value.strip():
                                text_parts.append(value)

                # Legacy image URL fields
                image_fields = ['image_url', 'image_path', 'media_url', 'thumbnail_url']
                for field in image_fields:
                    if field in artifact_json and artifact_json[field]:
                        content['image_paths'].append(artifact_json[field])
                        content['has_image'] = True

            # Combine all text content
            content['text_content'] = ' '.join(text_parts).strip()

            # Determine content type
            if content['has_image'] and content['text_content']:
                content['content_type'] = 'multimodal'
            elif content['has_image']:
                content['content_type'] = 'image_only'
            elif content['text_content']:
                content['content_type'] = 'text_only'
            else:
                content['content_type'] = 'no_content'

        except Exception as e:
            # Best-effort: return whatever was collected before the failure
            print(f"Error extracting content: {e}")

        return content

    def extract_source_context(self, content):
        """Extract text content from artifact"""
        extracted_content = self.extract_content_from_artifact(content)
        return extracted_content['text_content'] if extracted_content['text_content'] else "Sample content for evaluation"

    def query_gemini_judge(self, content, brand_prompt, gemini_flash_verdict):
        """Query Gemini 1.5 Flash with retry logic"""
        max_retries = 3
        base_delay = 1
        
        for attempt in range(max_retries):
            try:
                # Extract text content using the same method as your pipeline
                source_context = self.extract_source_context(content)
                
                # Create judge prompt - Gemini 1.5 Flash evaluating Gemini Flash
                judge_prompt = f"""You are an expert content moderator using Gemini 1.5 Flash to evaluate another AI model's (Gemini Flash) brand safety decision.

Use your reasoning capabilities to provide a thorough analysis of whether Gemini Flash made the right decision.

BRAND GUIDELINES:
{brand_prompt}

CONTENT TO EVALUATE:
{source_context[:2000]}{"..." if len(source_context) > 2000 else ""}

GEMINI FLASH'S VERDICT:
{gemini_flash_verdict}

As Gemini 1.5 Flash, evaluate whether Gemini Flash's verdict is correct and well-reasoned. Consider:
1. Accuracy of the classification given the brand guidelines
2. Quality and completeness of the reasoning provided
3. Any nuances or edge cases that Flash may have missed
4. Whether you would reach the same conclusion

Respond in JSON format:
{{
  "judge_agreement": true/false,
  "confidence": 0.0-1.0,
  "verdict": "Aligned" or "Not-Aligned",
  "reasoning": "Your detailed analysis as Gemini 1.5 Flash of why Flash's decision was correct or incorrect",
  "improvements": ["suggestions for better reasoning if disagreeing"],
  "flash_analysis": "analysis of Flash's performance on this content"
}}"""
                
                # Configure safety settings
                safety_settings = {
                    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                }
                
                # Make API call with Gemini 1.5 Flash
                start_time = time.time()
                response = gemini_model.generate_content(
                    judge_prompt,
                    safety_settings=safety_settings
                )
                api_time = time.time() - start_time
                
                # Update tracking
                self.api_call_count += 1
                self.total_api_time += api_time
                
                # Track tokens if available
                if hasattr(response, 'usage_metadata'):
                    self.total_input_tokens += response.usage_metadata.prompt_token_count
                    self.total_output_tokens += response.usage_metadata.candidates_token_count
                
                if response.text:
                    result = self.parse_gemini_response(response.text, api_time)
                    result['model_used'] = GEMINI_MODEL_NAME
                    result['attempts'] = attempt + 1
                    return result
                else:
                    if attempt < max_retries - 1:
                        delay = base_delay * (2 ** attempt)  # Exponential backoff
                        print(f"⚠️ Empty response, retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                        time.sleep(delay)
                        continue
                    else:
                        return {
                            'error': f'Empty response from {GEMINI_MODEL_NAME} after {max_retries} attempts',
                            'model_used': GEMINI_MODEL_NAME,
                            'api_time': api_time,
                            'attempts': max_retries
                        }
                        
            except Exception as e:
                if attempt < max_retries - 1:
                    delay = base_delay * (2 ** attempt)  # Exponential backoff
                    print(f"⚠️ API error: {str(e)[:100]}..., retrying in {delay}s... (attempt {attempt + 1}/{max_retries})")
                    time.sleep(delay)
                    continue
                else:
                    return {
                        'error': f'Failed after {max_retries} attempts: {str(e)}',
                        'model_used': GEMINI_MODEL_NAME,
                        'api_time': 0,
                        'attempts': max_retries
                    }
        
        # Should never reach here, but just in case
        return {
            'error': f'Unexpected failure after {max_retries} attempts',
            'model_used': GEMINI_MODEL_NAME,
            'api_time': 0,
            'attempts': max_retries
        }

    def parse_gemini_response(self, response_text, api_time):
        """Turn the judge model's raw reply into a normalized result dict.

        The text between the first '{' and the last '}' is parsed as JSON
        (surrounding markdown fences are therefore irrelevant). When there is
        no JSON, or it does not parse, agreement is estimated from keywords
        in the reply and confidence is fixed at 0.5.

        Args:
            response_text: Raw text returned by the model.
            api_time: Duration of the API call, copied into the result.

        Returns:
            dict with keys 'judge_agreement', 'confidence' (always a float),
            'verdict', 'would_reach_same_conclusion', 'reasoning',
            'improvements', 'flash_analysis', 'raw_response' and 'api_time'.
        """
        lowered = response_text.lower()
        negative_same_phrases = (
            'different conclusion', 'would not reach', 'disagree with classification'
        )

        def mentions(phrases):
            # The phrase must not be glued to a preceding letter, so "agree",
            # "correct" and "accurate" do not fire inside "disagree",
            # "incorrect" and "inaccurate".
            return any(
                re.search(r'(?<![a-z])' + re.escape(phrase), lowered)
                for phrase in phrases
            )

        def infer_same_conclusion(positive_phrases, default):
            # Negative wording is checked first: "would not reach the same
            # conclusion" also contains the positive phrase "same conclusion".
            if mentions(negative_same_phrases):
                return False
            if mentions(positive_phrases):
                return True
            return default

        def as_float(value, default):
            # null or non-numeric confidence must not break float() downstream
            try:
                return float(value)
            except (TypeError, ValueError):
                return default

        def heuristic_result(agreement_words, positive_phrases):
            # Keyword-based estimate used when no usable JSON is available
            agreement = mentions(agreement_words)
            return {
                'judge_agreement': agreement,
                'confidence': 0.5,
                'verdict': 'Aligned' if agreement else 'Not-Aligned',
                'would_reach_same_conclusion': infer_same_conclusion(positive_phrases, agreement),
                'reasoning': response_text,
                'improvements': [],
                'flash_analysis': '',
                'raw_response': response_text,
                'api_time': api_time
            }

        start_idx = response_text.find('{')
        end_idx = response_text.rfind('}')
        if start_idx == -1 or end_idx == -1:
            # No JSON object at all: fall back to keyword parsing
            return heuristic_result(
                ('agree', 'correct', 'accurate'),
                ('same conclusion', 'reach the same', 'would also', 'independent conclusion'),
            )

        try:
            parsed = json.loads(response_text[start_idx:end_idx + 1])
        except json.JSONDecodeError as e:
            print(f"⚠️ JSON parsing failed: {e}")
            print(f"Raw text: {response_text[:200]}...")
            return heuristic_result(
                ('agree', 'reasonable'),
                ('same conclusion', 'reach the same', 'would also', 'agree with the verdict'),
            )

        judge_agreement = parsed.get('judge_agreement', None)

        # Use the explicit field when present; otherwise infer it from the
        # reply text and finally default to judge_agreement
        would_reach_same = parsed.get('would_reach_same_conclusion')
        if would_reach_same is None:
            would_reach_same = infer_same_conclusion(
                ('same conclusion', 'reach the same', 'would also classify', 'independent conclusion'),
                judge_agreement,
            )

        return {
            'judge_agreement': judge_agreement,
            'confidence': as_float(parsed.get('confidence'), 0.0),
            'verdict': parsed.get('verdict', 'Unknown'),
            'would_reach_same_conclusion': would_reach_same,
            'reasoning': parsed.get('reasoning', ''),
            'improvements': parsed.get('improvements', []),
            'flash_analysis': parsed.get('flash_analysis', ''),
            'raw_response': response_text,
            'api_time': api_time
        }

    def run_full_dataset_evaluation(self, web_df: pd.DataFrame, 
                                   batch_size: int = 100,
                                   save_to_bigquery: bool = True) -> pd.DataFrame:
        """
        Run the Gemini judge over every row of ``web_df``, batch by batch.

        For each row the artifact content is loaded with
        ``self.load_artifact_content`` and judged with
        ``self.query_gemini_judge``; the outcome is flattened into one result
        record. Rows that raise are reported, counted in
        ``self.failed_evaluations`` and skipped, so one bad record never
        aborts the run.

        Args:
            web_df: Rows to evaluate. The code reads the columns
                ``artifact_id``, ``data_source``, ``correct_classification``,
                ``correct_reasoning`` and ``model_prompt`` (``source`` is
                optional and defaults to ``'unknown'``).
            batch_size: Number of rows per batch; must be positive.
            save_to_bigquery: When true, each non-empty batch is passed to
                ``self.save_batch_to_bigquery`` as soon as it finishes.

        Returns:
            A DataFrame with one row per successfully processed record
            (empty when nothing succeeded).

        Raises:
            ValueError: If ``batch_size`` is not a positive integer.
        """
        # Fail fast: a zero batch size would otherwise surface as a
        # ZeroDivisionError halfway through the banner, and a negative one
        # would silently process nothing.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        
        total_records = len(web_df)
        # Ceiling division; invariant for the whole run, so compute it once.
        total_batches = (total_records + batch_size - 1) // batch_size
        
        print(f"🔄 RUNNING GEMINI 1.5 FLASH EVALUATION ON FULL WEB DATASET")
        print("=" * 70)
        print(f"📊 Total records to process: {total_records:,}")
        print(f"📦 Batch size: {batch_size:,}")
        print(f"📦 Number of batches: {total_batches}")
        print(f"🧠 Judge model: {GEMINI_MODEL_NAME}")
        print(f"⚡ Flash model: Gemini Flash (from ground truth)")
        
        # Initialize results storage
        all_results = []
        
        # Process in batches
        for batch_start in range(0, total_records, batch_size):
            batch_end = min(batch_start + batch_size, total_records)
            batch_df = web_df.iloc[batch_start:batch_end].copy()
            batch_number = (batch_start // batch_size) + 1
            
            print(f"\n🔄 Processing Batch {batch_number}/{total_batches}")
            print(f"   Records: {batch_start + 1:,} to {batch_end:,}")
            print(f"   Batch size: {len(batch_df):,}")
            
            batch_start_time = time.time()
            batch_results = []
            batch_successful = 0
            batch_failed = 0
            
            # Process each record in the batch with progress bar
            for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), 
                               desc=f"Batch {batch_number}", leave=False):
                
                try:
                    # Load artifact content using updated method
                    content = self.load_artifact_content(row)
                    
                    # Get Gemini 1.5 Flash judgment of Flash's decision
                    judge_result = self.query_gemini_judge(
                        content, 
                        row['model_prompt'], 
                        row['correct_reasoning']
                    )
                    
                    # The judge may hand back a single string (or non-string
                    # items) instead of a list of strings; normalise so the
                    # join neither splits a string into characters nor raises.
                    raw_improvements = judge_result.get('improvements') or []
                    if isinstance(raw_improvements, str):
                        raw_improvements = [raw_improvements]
                    improvements_text = '; '.join(str(item) for item in raw_improvements)
                    
                    # Create result record
                    result = {
                        'artifact_id': str(row['artifact_id']),
                        'data_source': str(row['data_source']),
                        'source': str(row.get('source', 'unknown')),
                        'flash_classification': int(row['correct_classification']),
                        'flash_reasoning': str(row['correct_reasoning'])[:1000],  # Truncate for storage
                        'model_prompt': str(row['model_prompt'])[:500],
                        'judge_agreement': judge_result.get('judge_agreement', None),
                        'verdict': str(judge_result.get('verdict', 'Unknown')),
                        # `or 0.0` also covers an explicit None, which float() rejects.
                        'confidence': float(judge_result.get('confidence') or 0.0),
                        'would_reach_same_conclusion': judge_result.get('would_reach_same_conclusion', None),
                        'reasoning': str(judge_result.get('reasoning', ''))[:1000],
                        'flash_analysis': str(judge_result.get('flash_analysis', ''))[:500],
                        'improvements': improvements_text[:500],
                        'api_call_time': float(judge_result.get('api_time') or 0.0),
                        'batch_number': int(batch_number),
                        'created_at': pd.Timestamp.now(),
                        'model_used': str(judge_result.get('model_used', GEMINI_MODEL_NAME)),
                        'error_message': str(judge_result.get('error', '')) if 'error' in judge_result else None,
                        'retry_attempts': int(judge_result.get('attempts', 1))
                    }
                    
                    batch_results.append(result)
                    batch_successful += 1
                    self.successful_evaluations += 1
                    
                except Exception as e:
                    # Use .get so that a row lacking artifact_id cannot raise
                    # again inside the handler and abort the whole run.
                    print(f"❌ Error processing {row.get('artifact_id', 'unknown')}: {e}")
                    batch_failed += 1
                    self.failed_evaluations += 1
            
            batch_time = time.time() - batch_start_time
            
            # Calculate batch statistics
            batch_agreement_rate = 0
            batch_avg_confidence = 0
            if batch_results:
                agreed = sum(1 for r in batch_results if r['judge_agreement'] is True)
                batch_agreement_rate = agreed / len(batch_results)
                # Only positive confidences are averaged; guard the empty
                # case because np.mean([]) yields NaN plus a RuntimeWarning.
                positive_confidences = [r['confidence'] for r in batch_results
                                        if r['confidence'] > 0]
                if positive_confidences:
                    batch_avg_confidence = np.mean(positive_confidences)
            
            print(f"   ✅ Batch {batch_number} completed in {batch_time:.1f}s")
            print(f"   📊 Successful: {batch_successful}, Failed: {batch_failed}")
            print(f"   📊 Agreement rate: {batch_agreement_rate:.1%}")
            print(f"   📊 Avg confidence: {batch_avg_confidence:.2f}")
            print(f"   📊 API calls: {self.api_call_count}")
            
            # Add batch results to overall results
            all_results.extend(batch_results)
            
            # Save batch to BigQuery if requested
            if save_to_bigquery and len(batch_results) > 0:
                self.save_batch_to_bigquery(pd.DataFrame(batch_results), batch_number)
        
        # Convert all results to DataFrame
        results_df = pd.DataFrame(all_results)
        
        # Final statistics (an empty frame has no columns, hence the guards)
        has_results = len(results_df) > 0
        final_agreement_rate = results_df['judge_agreement'].mean() if has_results else 0
        final_avg_confidence = results_df['confidence'].mean() if has_results else 0
        total_api_time = results_df['api_call_time'].sum() if has_results else 0
        
        print(f"\n✅ GEMINI 1.5 FLASH EVALUATION COMPLETED!")
        print(f"📊 Total records processed: {len(results_df):,}")
        print(f"📊 Successful evaluations: {self.successful_evaluations}")
        print(f"📊 Failed evaluations: {self.failed_evaluations}")
        print(f"📊 Overall agreement rate: {final_agreement_rate:.1%}")
        print(f"📊 Average confidence: {final_avg_confidence:.2f}")
        print(f"📊 Total API calls: {self.api_call_count}")
        print(f"📊 Total API time: {total_api_time:.1f}s ({total_api_time/60:.1f} minutes)")
        
        # Cost estimation from the token counters kept on the instance.
        # NOTE(review): the per-token rates below are hard-coded; confirm
        # them against current Gemini 1.5 Flash pricing.
        input_cost = self.total_input_tokens * 0.00000075
        output_cost = self.total_output_tokens * 0.0000015
        total_cost = input_cost + output_cost
        print(f"💰 Estimated cost: ${total_cost:.4f}")
        print(f"   Input tokens: {self.total_input_tokens:,} (${input_cost:.4f})")
        print(f"   Output tokens: {self.total_output_tokens:,} (${output_cost:.4f})")
        
        return results_df

    def save_batch_to_bigquery(self, batch_df: pd.DataFrame, batch_number: int):
        """
        Append one batch of judge results to the BigQuery results table.

        The load uses an explicit schema, appends to the table and creates it
        when it does not exist yet. Any failure is reported on stdout and
        swallowed so that a single bad batch does not stop the evaluation.

        Args:
            batch_df: Result records of the batch, one row per record.
            batch_number: 1-based batch index, used only in log messages.
        """
        print(f"   💾 Saving batch {batch_number} to BigQuery ({len(batch_df)} records)...")
        
        # (column name, BigQuery type, mode) for every column of the results table
        column_specs = [
            ("artifact_id", "STRING", "REQUIRED"),
            ("data_source", "STRING", "NULLABLE"),
            ("source", "STRING", "NULLABLE"),
            ("flash_classification", "INTEGER", "NULLABLE"),
            ("flash_reasoning", "STRING", "NULLABLE"),
            ("model_prompt", "STRING", "NULLABLE"),
            ("judge_agreement", "BOOLEAN", "NULLABLE"),
            ("verdict", "STRING", "NULLABLE"),
            ("confidence", "FLOAT", "NULLABLE"),
            ("would_reach_same_conclusion", "BOOLEAN", "NULLABLE"),
            ("reasoning", "STRING", "NULLABLE"),
            ("flash_analysis", "STRING", "NULLABLE"),
            ("improvements", "STRING", "NULLABLE"),
            ("api_call_time", "FLOAT", "NULLABLE"),
            ("batch_number", "INTEGER", "NULLABLE"),
            ("created_at", "TIMESTAMP", "NULLABLE"),
            ("model_used", "STRING", "NULLABLE"),
            ("error_message", "STRING", "NULLABLE"),
            ("retry_attempts", "INTEGER", "NULLABLE"),
        ]
        
        try:
            destination = f"{PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}"
            
            load_config = bigquery.LoadJobConfig(
                write_disposition="WRITE_APPEND",  # keep rows from earlier batches
                create_disposition="CREATE_IF_NEEDED",  # first batch creates the table
                schema=[
                    bigquery.SchemaField(column, field_type, mode=mode)
                    for column, field_type, mode in column_specs
                ]
            )
            
            # Submit the load job and block until it has finished
            load_job = client.load_table_from_dataframe(
                batch_df, destination, job_config=load_config
            )
            load_job.result()
            
            print(f"   ✅ Batch {batch_number} saved successfully!")
            
        except Exception as exc:
            # Best effort: report and carry on with the remaining batches
            print(f"   ❌ Error saving batch {batch_number}: {exc}")

def main_gemini_15_flash_evaluation():
    """
    Run the end-to-end Gemini 1.5 Flash judge evaluation.

    Steps: build the pipeline, load the dataset, smoke-test the judge on the
    first row, evaluate every row in batches (saving each batch to BigQuery),
    then write a JSON summary report to the current working directory.

    Returns:
        A ``(results_df, summary_report)`` pair on success. Every failure
        path - no data loaded, failed smoke test, no results, or any raised
        exception - returns ``(None, None)`` so callers can always unpack
        the result into two names.
    """
    
    print("🚀 STARTING GEMINI 1.5 FLASH JUDGE EVALUATION - FULL WEB DATASET")
    print("=" * 80)
    
    try:
        # Initialize pipeline
        pipeline = Gemini15FlashJudgePipeline()
        
        # Step 1: Load the complete dataset
        print(f"\n📋 Step 1: Loading complete Web dataset...")
        web_df = pipeline.load_full_web_dataset()
        
        if web_df.empty:
            print("❌ No data loaded from Web table")
            return None, None
        
        # Step 2: Test with one sample first, before spending API calls on the full run
        print(f"\n🧪 Step 2: Testing with one sample...")
        test_row = web_df.iloc[0]
        print(f"   Testing Web artifact: {test_row['artifact_id']}")
        
        content = pipeline.load_artifact_content(test_row)
        judge_result = pipeline.query_gemini_judge(
            content, 
            test_row['model_prompt'], 
            test_row['correct_reasoning']
        )
        
        if 'error' in judge_result:
            print(f"❌ Test failed: {judge_result['error']}")
            return None, None
        
        # A missing or None api_time cannot be formatted with :.2f; fall back to 0.0.
        test_api_time = judge_result.get('api_time') or 0.0
        print(f"✅ Test successful!")
        print(f"   Agreement: {judge_result.get('judge_agreement')}")
        print(f"   Confidence: {judge_result.get('confidence')}")
        print(f"   API time: {test_api_time:.2f}s")
        
        # Step 3: Run evaluation on dataset
        batch_size = 50  # Optimized batch size for full dataset
        print(f"\n🔄 Step 3: Running evaluation on {len(web_df):,} records...")
        results_df = pipeline.run_full_dataset_evaluation(
            web_df, 
            batch_size=batch_size,
            save_to_bigquery=True
        )
        
        if results_df.empty:
            print("❌ No evaluation results generated")
            return None, None
        
        # Step 4: Create summary report
        print(f"\n📄 Step 4: Creating summary report...")
        
        overall_agreement_rate = results_df['judge_agreement'].mean()
        
        # Convert numpy scalars to plain Python numbers: json.dump cannot
        # serialise numpy integers and would otherwise fall back to
        # default=str, writing the counts as strings.
        classification_counts = results_df['flash_classification'].value_counts()
        classification_breakdown = {
            int(label): int(count) for label, count in classification_counts.items()
        }
        agreement_means = results_df.groupby('flash_classification')['judge_agreement'].mean()
        agreement_by_classification = {
            int(label): float(rate) for label, rate in agreement_means.items()
        }
        
        summary_report = {
            'timestamp': pd.Timestamp.now().isoformat(),
            'dataset_info': {
                # NOTE(review): the configuration at the top of the file
                # declares META_TABLE; confirm WEB_TABLE is defined for this script.
                'source_table': f"{PROJECT_ID}.{DATASET_ID}.{WEB_TABLE}",
                'results_table': f"{PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}",
                'total_records_processed': len(results_df),
                'successful_evaluations': pipeline.successful_evaluations,
                'failed_evaluations': pipeline.failed_evaluations,
                'batch_size': batch_size
            },
            'model_info': {
                'judge_model': GEMINI_MODEL_NAME,
                'flash_model': 'gemini-flash',
                'api_calls': pipeline.api_call_count,
                'total_api_time': pipeline.total_api_time,
                'total_input_tokens': pipeline.total_input_tokens,
                'total_output_tokens': pipeline.total_output_tokens
            },
            'performance_metrics': {
                'overall_agreement_rate': float(overall_agreement_rate),
                'average_confidence': float(results_df['confidence'].mean()),
                'classification_breakdown': classification_breakdown,
                'agreement_by_classification': agreement_by_classification
            }
        }
        
        # Save summary to a timestamped file
        summary_filename = f"gemini_15_flash_web_summary_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(summary_filename, 'w') as f:
            json.dump(summary_report, f, indent=2, default=str)
        
        print(f"✅ Summary report saved to: {summary_filename}")
        
        # Final success message
        print(f"\n🎉 GEMINI 1.5 FLASH JUDGE EVALUATION COMPLETED SUCCESSFULLY!")
        print(f"📊 Total records processed: {len(results_df):,}")
        print(f"📊 Overall agreement rate: {overall_agreement_rate:.1%}")
        print(f"📊 API calls made: {pipeline.api_call_count:,}")
        print(f"💾 Results saved to: {PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}")
        print(f"📄 Summary report: {summary_filename}")
        
        # Show final query to view results
        print(f"\n🔍 QUERY TO VIEW RESULTS:")
        print(f"SELECT * FROM `{PROJECT_ID}.{DATASET_ID}.{RESULTS_TABLE}` ORDER BY created_at DESC LIMIT 100")
        
        return results_df, summary_report
        
    except Exception as e:
        # Top-level boundary: report the failure and signal it with the same
        # two-element shape as the success path.
        print(f"❌ GEMINI 1.5 FLASH EVALUATION FAILED: {type(e).__name__}: {e}")
        traceback.print_exc()
        return None, None

if __name__ == "__main__":
    # Script entry point: run the full evaluation and print a short recap.
    print("🚀 RUNNING FULL WEB DATASET EVALUATION WITH GEMINI 1.5 FLASH")
    print("Enhanced with robust retry logic and cleaned schema")
    print("=" * 60)
    
    # Run full Web dataset evaluation
    outcome = main_gemini_15_flash_evaluation()
    # Accept a bare None as well as a (results, summary) pair so that
    # unpacking can never raise on a failed run.
    results, summary = outcome if outcome is not None else (None, None)
    
    if results is not None:
        print(f"\n✅ FULL Web dataset evaluation completed successfully!")
        print(f"📊 {len(results):,} evaluations completed")
        print(f"📊 Agreement rate: {results['judge_agreement'].mean():.1%}")
        # Fall back to a single attempt when the column is absent
        print(f"📊 Average retry attempts: {results.get('retry_attempts', pd.Series([1])).mean():.1f}")
    else:
        print(f"\n❌ FULL Web dataset evaluation failed - check error messages above")