In [29]:
import json
import csv
import pandas as pd
import numpy as np
import time
import os
import uuid
import asyncio
from litellm import acompletion
from aiolimiter import AsyncLimiter


def generate_response_id():
    """Return a newly generated random UUID (version 4) as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)

In [30]:
from dotenv import load_dotenv

# Load environment variables from the .env file so the API keys below are
# available through os.environ / os.getenv for the rest of the notebook.
#
# LiteLLM automatically reads API keys from environment variables:
# - OPENAI_API_KEY for OpenAI models
# - GEMINI_API_KEY for Google Gemini models
# - ANTHROPIC_API_KEY for Anthropic models
#
# In addition, create_response_with_ai_gateway reads AI_GATEWAY_KEY directly
# with os.getenv for requests sent to the Cornell AI Gateway.
# NOTE(review): create_response only routes the "openai" provider through
# LiteLLM; all other providers go through the gateway, so the Gemini and
# Anthropic keys look unused by this notebook - confirm before relying on them.
load_dotenv()

True

In [31]:
from openai import AsyncOpenAI

async def create_response_with_ai_gateway(model, user_input, response_schema, system_prompt, reasoning_effort=None):
    """
    Create a structured-output chat completion through the Cornell AI Gateway (async).

    The gateway is addressed with the OpenAI-compatible client, authenticated
    with the AI_GATEWAY_KEY environment variable.

    Args:
        model: Model identifier (e.g., "openai.gpt-4o", "deepseek.r1")
        user_input: The user's input message
        response_schema: JSON schema for structured output
        system_prompt: System prompt for the LLM
        reasoning_effort: Reasoning effort level ("low", "medium", "high") - only for
            reasoning models. Falsy values (None, "") omit the parameter entirely.

    Returns:
        The chat-completion response object returned by the client.

    Raises:
        RuntimeError: If AI_GATEWAY_KEY is not set.
    """
    # Fail fast on a missing key. Passing api_key=None would let the OpenAI
    # client fall back to OPENAI_API_KEY and send that credential to the gateway.
    api_key = os.getenv("AI_GATEWAY_KEY")
    if not api_key:
        raise RuntimeError(
            "AI_GATEWAY_KEY is not set; add it to the environment or the .env file"
        )

    # Build the request parameters
    params = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]
    }

    if reasoning_effort:
        params["reasoning_effort"] = reasoning_effort

    params["response_format"] = {
        "type": "json_schema",
        "json_schema": {
            "name": "batch_statement_response",
            "strict": True,
            "schema": response_schema
        }
    }

    # A client is created per call; the context manager closes its HTTP
    # connection pool once the response has been received.
    async with AsyncOpenAI(api_key=api_key, base_url="https://api.ai.it.cornell.edu") as client:
        return await client.chat.completions.create(**params)

In [32]:
async def create_response_for_openai(model, user_input, response_schema, system_prompt, reasoning_config=None):
    """Request a structured-output completion for an OpenAI model via LiteLLM.

    Args:
        model: Model identifier (e.g., "openai/responses/gpt-5.2")
        user_input: The user's input message
        response_schema: JSON schema the reply must conform to
        system_prompt: System prompt for the LLM
        reasoning_config: Reasoning configuration. None or "none" is sent as
            the string "none"; any other value (for example
            {"effort": "high", "summary": "detailed"}) is forwarded unchanged
            as the `reasoning_effort` argument.

    Returns:
        The response object returned by LiteLLM's `acompletion`.
    """
    reasoning_disabled = reasoning_config is None or reasoning_config == "none"
    effort_setting = "none" if reasoning_disabled else reasoning_config

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]

    # Strict JSON-schema output so the reply can be parsed and validated.
    structured_output = {
        "type": "json_schema",
        "json_schema": {
            "name": "batch_statement_response",
            "strict": True,
            "schema": response_schema,
        },
    }

    return await acompletion(
        model=model,
        messages=chat_messages,
        reasoning_effort=effort_setting,
        response_format=structured_output,
    )

In [33]:
# Providers and model variants to query. Each provider entry lists the model
# identifiers to run together with the reasoning setting used for that variant
# (either a plain effort string, "none", or an {"effort", "summary"} dict).
models = [
    {
        "provider": "google",
        "models": [
            {"name": "google.gemini-3-flash-preview", "reasoning": "high"},
            {"name": "google.gemini-3-flash-preview", "reasoning": "none"},
        ],
    },
    {
        "provider": "openai",
        "models": [
            {
                "name": "openai.gpt-5.2",
                "reasoning": {"effort": "high", "summary": "detailed"},
            },
            {"name": "openai.gpt-5.2", "reasoning": "none"},
        ],
    },
    {
        "provider": "anthropic",
        "models": [
            {"name": "anthropic.claude-4.5-haiku", "reasoning": "high"},
            {"name": "anthropic.claude-4.5-haiku", "reasoning": "low"},
        ],
    },
    {
        "provider": "xai",
        "models": [
            {"name": "xai.grok-4-fast-reasoning", "reasoning": "high"},
            {"name": "xai.grok-4-fast-non-reasoning", "reasoning": "none"},
        ],
    },
]

In [34]:
async def create_response(provider, model_name, user_input, response_schema, system_prompt, reasoning_config=None):
    """Dispatch a request to the backend that serves the given provider.

    The "openai" provider is sent through LiteLLM; every other provider is
    sent through the Cornell AI Gateway.

    Args:
        provider: Provider name ('openai', 'google', 'anthropic', 'xai', etc.)
        model_name: Full model identifier
        user_input: The user's input message
        response_schema: JSON schema for structured output
        system_prompt: System prompt for the LLM
        reasoning_config: Reasoning configuration from model config

    Returns:
        Response object from the appropriate API
    """
    if provider != "openai":
        # The gateway takes a plain effort string (or nothing), so reduce the
        # config: falsy / "none" -> None, str -> itself, dict -> its "effort".
        if not reasoning_config or reasoning_config == "none":
            effort = None
        elif isinstance(reasoning_config, str):
            effort = reasoning_config
        elif isinstance(reasoning_config, dict):
            effort = reasoning_config.get("effort")
        else:
            effort = None

        return await create_response_with_ai_gateway(
            model_name, user_input, response_schema, system_prompt, effort
        )

    # OpenAI: drop the "openai." marker and address the model through
    # LiteLLM's openai/responses/ route.
    bare_model_name = model_name.replace('openai.', '')
    litellm_model = f"openai/responses/{bare_model_name}"
    return await create_response_for_openai(
        litellm_model, user_input, response_schema, system_prompt, reasoning_config
    )

In [35]:
def construct_batch_response_schema(statements, response_options):
    """Build the JSON schema describing one batch reply covering every statement.

    The schema is an object with a single required "responses" array. Each
    array item must carry a "question_id", the "input_statement" text and a
    "response", each restricted by an enum to the known values.

    Args:
        statements: List of statement dicts with 'id' and 'prompt' keys
        response_options: List of valid response options

    Returns:
        JSON schema (as a dict) for structured output
    """
    known_ids = [entry['id'] for entry in statements]
    known_texts = [entry['prompt'] for entry in statements]

    def enum_string(description, allowed_values):
        # A string-typed property limited to an explicit set of values.
        return {
            "type": "string",
            "description": description,
            "enum": allowed_values,
        }

    single_response = {
        "type": "object",
        "properties": {
            "question_id": enum_string(
                "The ID of the statement being responded to", known_ids
            ),
            "input_statement": enum_string(
                "The exact statement text being responded to", known_texts
            ),
            "response": enum_string(
                "The response to the statement", response_options
            ),
        },
        "required": ["question_id", "input_statement", "response"],
        "additionalProperties": False,
    }

    response_list = {
        "type": "array",
        "description": "Array of responses for each statement",
        "items": single_response,
    }

    return {
        "type": "object",
        "properties": {"responses": response_list},
        "required": ["responses"],
        "additionalProperties": False,
    }

In [36]:
def verify_batch_responses(responses, statement_ids, response_options):
    """Check that a batch reply answers every expected statement exactly once.

    Args:
        responses: List of response dicts with 'question_id' and 'response' keys
        statement_ids: List of expected statement IDs
        response_options: List of valid response options

    Returns:
        tuple: (True, None) when the batch is valid, otherwise
        (False, error_message) describing the first problem found.
    """
    if not responses:
        return False, "No responses received"

    expected_count = len(statement_ids)
    received_count = len(responses)
    if received_count != expected_count:
        return False, f"Expected {expected_count} responses, got {received_count}"

    # Walk the entries once, rejecting unknown IDs, repeats and invalid answers.
    seen_ids = set()
    for entry in responses:
        question_id = entry.get('question_id')
        choice = entry.get('response')

        if question_id not in statement_ids:
            return False, f"Unexpected question_id: {question_id}"
        if question_id in seen_ids:
            return False, f"Duplicate question_id: {question_id}"
        if choice not in response_options:
            return False, f"Invalid response '{choice}' for {question_id}"

        seen_ids.add(question_id)

    # Final guard: every expected ID must have been answered.
    unanswered = set(statement_ids) - seen_ids
    if unanswered:
        return False, f"Missing responses for: {unanswered}"

    return True, None

In [37]:
def construct_batch_user_input(statements):
    """Render every statement into one user message for batch processing.

    Args:
        statements: List of statement dicts with 'id' and 'prompt' keys

    Returns:
        A string made of a fixed instruction header followed by one
        "[id]: prompt" line per statement.
    """
    header = "Please respond to each of the following statements:\n\n"
    statement_lines = [f"[{entry['id']}]: {entry['prompt']}\n" for entry in statements]
    return header + "".join(statement_lines)


In [38]:
def save_response(folder, id, data):
    """Write `data` as pretty-printed UTF-8 JSON to "{folder}/{id}.json".

    Args:
        folder: Destination directory. It is created (including parents) when
            it does not exist yet, so a save does not fail on a missing folder.
        id: Identifier used as the file name (without the .json extension).
            The name shadows the builtin but is kept for caller compatibility.
        data: JSON-serializable object to store (e.g. the full response dict).

    Returns:
        The `id` that was passed in, unchanged.
    """
    # An empty folder string is left alone: os.makedirs("") would raise.
    if folder:
        os.makedirs(folder, exist_ok=True)

    filepath = f"{folder}/{id}.json"
    with open(filepath, 'w', encoding='utf-8') as f:
        # ensure_ascii=False keeps non-ASCII text readable in the saved file.
        json.dump(data, f, indent=2, ensure_ascii=False)
    return id

In [39]:
def append_to_csv(filepath, row_data, statement_ids):
    """Append a single row to a CSV file, writing the header row first when needed.

    The header is written when the file does not exist yet or exists but is
    empty (zero bytes), so a pre-created empty file still gets column names.

    Args:
        filepath: Path to the CSV file
        row_data: Dictionary containing row data with question IDs as keys.
            Keys missing from the row are written as empty cells; keys not in
            the column list make csv.DictWriter raise ValueError.
        statement_ids: Sequence of statement IDs for column ordering
    """
    # Build fieldnames: metadata columns + question ID columns + response ID.
    # list() lets callers pass any sequence (e.g. a tuple), not just a list.
    fieldnames = (
        ['country', 'model', 'reasoning', 'run_number', 'attempt']
        + list(statement_ids)
        + ['llm_response_id']
    )
    needs_header = not os.path.exists(filepath) or os.path.getsize(filepath) == 0

    with open(filepath, 'a', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(row_data)

In [40]:
async def process_all_statements_batch(statements, run_number, response_options, 
                                        response_schema, system_prompt, csv_filepath, 
                                        provider, model_name, reasoning_config, max_attempts,
                                        country=None, response_folder=None):
    """Send all statements as one request and record the answers, retrying on failure.

    Each attempt calls the model, parses the JSON content of the first choice,
    optionally saves the full response, and validates the answers. The first
    valid batch is appended to the CSV as a single row. Invalid batches and
    exceptions are printed and followed by the next attempt.

    Args:
        statements: List of statement dicts with 'id' and 'prompt'
        run_number: Current run iteration number
        response_options: List of valid response options
        response_schema: JSON schema for response validation
        system_prompt: System prompt for the LLM
        csv_filepath: Path to CSV file for logging
        provider: Provider name for API routing
        model_name: The model identifier
        reasoning_config: Reasoning configuration for the model
        max_attempts: Maximum retry attempts for invalid responses
        country: Country name for country-based processing (optional)
        response_folder: Folder to save response JSON files (optional)

    Returns:
        bool: True if successful, False if all attempts failed
    """
    expected_ids = [item['id'] for item in statements]
    batch_prompt = construct_batch_user_input(statements)

    for attempt in range(1, max_attempts + 1):
        try:
            api_response = await create_response(
                provider, model_name, batch_prompt, response_schema, system_prompt, reasoning_config
            )

            # Normalise the response object (LiteLLM or OpenAI client) to a dict.
            if hasattr(api_response, 'model_dump'):
                raw_payload = api_response.model_dump()
            elif hasattr(api_response, 'to_dict'):
                raw_payload = api_response.to_dict()
            else:
                raw_payload = dict(api_response)

            # The structured answer is JSON text inside the first choice.
            message_text = raw_payload["choices"][0]["message"]["content"]
            parsed_output = json.loads(message_text)

            # Every parseable response is stored, valid or not.
            response_id = generate_response_id()
            if response_folder:
                save_response(response_folder, response_id, raw_payload)

            answers = parsed_output.get('responses', [])
            batch_ok, problem = verify_batch_responses(answers, expected_ids, response_options)

            if not batch_ok:
                print(f"Validation failed for {country}/{model_name}, attempt {attempt}/{max_attempts}: {problem}")
                continue

            # Reasoning label for the CSV: falsy -> 'none', dict -> its effort, else str().
            if not reasoning_config:
                reasoning_label = 'none'
            else:
                reasoning_label = (
                    reasoning_config.get('effort', 'none')
                    if isinstance(reasoning_config, dict)
                    else str(reasoning_config)
                )

            # Metadata columns first, then one column per answered question.
            csv_row = {
                'country': country,
                'model': model_name,
                'reasoning': reasoning_label,
                'run_number': run_number,
                'attempt': attempt,
                'llm_response_id': response_id,
            }
            csv_row.update({answer['question_id']: answer['response'] for answer in answers})

            append_to_csv(csv_filepath, csv_row, expected_ids)
            return True

        except Exception as exc:
            print(f"Error processing batch for {country}/{model_name}, attempt {attempt}/{max_attempts}: {str(exc)}")

    print(f"FAILED: {country}/{model_name} after {max_attempts} attempts")
    return False

In [41]:
async def process_batch_with_rate_limit(rate_limiter, statements, run_number, 
                                         response_options, response_schema, system_prompt,
                                         csv_filepath, provider, model_name, reasoning_config,
                                         max_attempts, country=None, response_folder=None):
    """Enter the rate limiter, then run one batch.

    The limiter is entered once per call; the whole batch, including any
    retries made inside process_all_statements_batch, runs within that entry.

    Args:
        rate_limiter: AsyncLimiter instance for rate limiting
        (all other args same as process_all_statements_batch)

    Returns:
        bool: True if successful, False if all attempts failed
    """
    async with rate_limiter:
        succeeded = await process_all_statements_batch(
            statements=statements,
            run_number=run_number,
            response_options=response_options,
            response_schema=response_schema,
            system_prompt=system_prompt,
            csv_filepath=csv_filepath,
            provider=provider,
            model_name=model_name,
            reasoning_config=reasoning_config,
            max_attempts=max_attempts,
            country=country,
            response_folder=response_folder,
        )
        return succeeded

In [42]:
def construct_country_system_prompt(country):
    """Build the system prompt that asks the model to answer as a teacher from `country`.

    Args:
        country: Value interpolated into the prompt after "teacher from".

    Returns:
        str: The prompt text. The f-string is returned as written, so it keeps
        its leading newline and the source indentation on every line.
    """
    # NOTE(review): the response options are spelled out in the prompt text
    # itself, not taken from a parameter - keep them in sync with the
    # `response_options` list used for schema construction and validation.
    return f"""
        You are a primary or secondary school teacher from {country} responding to a survey about the use of artificial intelligence (AI) in education.
        Please indicate your level of agreement with the list of statements presented in the survey.
        In the survey, AI is defined as below:
        'Artificial intelligence' is the capacity for computers to perform tasks traditionally thought to involve human intelligence. This can include making predictions, suggesting decisions, or generating text.

        Please respond with only one of the response options for each statement.
        Response options:
        ['Strongly disagree', 'Disagree', 'Agree', 'Strongly agree', 'I don't know']
    """

In [43]:
async def process_all_countries(countries, statements, response_options,
                                 csv_filepath, response_folder, num_runs, model_configs, max_attempts,
                                 requests_per_minute=60):
    """Process all statements as a batch for all countries and all models with rate limiting.

    Countries are handled one after another. Within a country, one task per
    (model, run) combination is created and all of them are awaited together
    with asyncio.gather, sharing a single rate limiter.

    Note: the limiter is entered once per task by process_batch_with_rate_limit,
    so retries made inside a task are not individually rate limited.

    Args:
        countries: List of country names
        statements: List of statement dicts with 'id' and 'prompt'
        response_options: List of valid response options
        csv_filepath: Path to CSV file for logging
        response_folder: Folder to save response JSON files
        num_runs: Number of times to run for each country
        model_configs: List of provider configs with format:
            [{"provider": "google", "models": [{"name": "...", "reasoning": "..."}]}]
        max_attempts: Maximum retry attempts for invalid responses
        requests_per_minute: Rate limit for API requests (default 60)

    Returns:
        None. Progress and per-country success counts are printed.
    """
    # One limiter shared by every provider: X requests per 60 seconds
    rate_limiter = AsyncLimiter(requests_per_minute, 60)

    total_countries = len(countries)

    # The schema depends only on the statements and options, so build it once.
    response_schema = construct_batch_response_schema(statements, response_options)

    # Flatten model configs to list of (provider, model_name, reasoning_config) tuples
    all_models = [
        (pc['provider'], m['name'], m.get('reasoning'))
        for pc in model_configs
        for m in pc['models']
    ]
    total_models = len(all_models)
    total_tasks_per_country = total_models * num_runs

    print(f"Configuration: {total_models} models x {num_runs} runs = {total_tasks_per_country} tasks per country")
    print(f"Total countries: {total_countries}")
    print(f"Rate limit: {requests_per_minute} requests/minute")

    for country_idx, country in enumerate(countries, 1):
        print(f"\n{'='*60}")
        print(f"[{country_idx}/{total_countries}] Processing country: {country}")
        print(f"{'='*60}")

        system_prompt = construct_country_system_prompt(country)

        print(f"  Processing {total_tasks_per_country} tasks in parallel ({total_models} models x {num_runs} runs)...")

        # Create tasks for ALL (model, run) combinations to process in parallel
        tasks = [
            process_batch_with_rate_limit(
                rate_limiter, statements, run_number, response_options,
                response_schema, system_prompt, csv_filepath,
                provider, model_name, reasoning_config, max_attempts,
                country, response_folder
            )
            for provider, model_name, reasoning_config in all_models
            for run_number in range(1, num_runs + 1)
        ]

        # Execute all tasks in parallel (rate limiter controls throughput)
        results = await asyncio.gather(*tasks)

        # Each task returns a bool, so summing counts the successes.
        success_count = sum(results)
        print(f"  Completed: {success_count}/{total_tasks_per_country} tasks successful")

    print(f"\nCompleted processing {total_countries} countries x {total_models} models x {num_runs} runs -> {csv_filepath}")

In [None]:
def save_gen_config_details_to_file(output_folder, countries, statements, response_schema, num_runs, models, max_attempts, prompt, user_input):
    """Write the run configuration to "{output_folder}/gen_config_details.txt".

    Each argument is written on its own labelled line using its str() form.
    The file is written as UTF-8, matching the other files this notebook
    produces, so non-ASCII country names or statements do not depend on the
    platform's default encoding.

    Args:
        output_folder: Existing directory in which to create the file
        countries: Countries included in the run
        statements: Statement dicts sent to the models
        response_schema: JSON schema used for structured output
        num_runs: Number of runs per country and model
        models: Model configuration used for the run
        max_attempts: Maximum attempts per batch
        prompt: System prompt (or prompt template) used
        user_input: User message sent with the statements
    """
    labelled_values = [
        ("Countries", countries),
        ("Statements", statements),
        ("Response schema", response_schema),
        ("Num runs", num_runs),
        ("Models", models),
        ("Max attempts", max_attempts),
        ("System Prompt", prompt),
        ("User input", user_input),
    ]
    with open(f"{output_folder}/gen_config_details.txt", "w", encoding="utf-8") as f:
        for label, value in labelled_values:
            f.write(f"{label}: {value}\n")

        

In [45]:
# CSV file (relative to the notebook's working directory) that is read for
# the list of countries.
countries_file = "country_language_list.csv"

countries_df = pd.read_csv(countries_file)

# Distinct values of the CNTRY_FULL column as a plain Python list. A slice of
# this list is later used as the `countries` passed to process_all_countries.
# NOTE(review): presumably CNTRY_FULL holds full country names - confirm
# against the CSV.
unique_list_of_countries = countries_df["CNTRY_FULL"].unique().tolist()

# Show how many distinct countries were found.
print(len(unique_list_of_countries))


55


In [46]:
from datetime import datetime

# =============================================================================
# CONFIGURATION
# =============================================================================

# Model configurations with provider and reasoning settings
# Uses the 'models' variable defined earlier with format:
# [{"provider": "...", "models": [{"name": "...", "reasoning": "..."}]}]
MODEL_CONFIGS = models  # Reference the models config defined earlier

MAX_ATTEMPTS = 3  # Maximum retry attempts for invalid responses
NUM_RUNS = 10      

# Rate limiting - requests per minute
# Adjust based on your API tier:
#   - OpenAI: 500-10000 RPM depending on tier
#   - Gemini: 60-1000 RPM depending on tier  
#   - Anthropic: 50-4000 RPM depending on tier
REQUESTS_PER_MINUTE = 100

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

response_options = ["Strongly disagree", "Disagree", "Agree", "Strongly agree", "I don't know"]

list_of_statements = json.load(open("questions/usa_english.json", encoding='utf-8'))["questions"][0]["questions"]

# Get statement IDs for schema construction
statement_ids = [stmt['id'] for stmt in list_of_statements]

output_folder = f"output/batch_processing/{timestamp}"
output_csv_file = f"{output_folder}/all.csv"
response_folder = f"{output_folder}/responses"

countries = unique_list_of_countries[:2]

# Create output folder if not exists
os.makedirs(output_folder, exist_ok=True)

# Create response folder if not exists
os.makedirs(response_folder, exist_ok=True)

sample_user_input = construct_batch_user_input(list_of_statements)

# Save configuration details
save_gen_config_details_to_file(output_folder, countries, list_of_statements, 
                                 construct_batch_response_schema(list_of_statements, response_options), 
                                 NUM_RUNS, MODEL_CONFIGS, MAX_ATTEMPTS, construct_country_system_prompt("{country_name}"), sample_user_input)

# Run batch processing for all countries and models
await process_all_countries(
    countries, list_of_statements, response_options, 
    output_csv_file, response_folder, NUM_RUNS, MODEL_CONFIGS, MAX_ATTEMPTS,
    REQUESTS_PER_MINUTE
)


Configuration: 8 models x 10 runs = 80 tasks per country
Total countries: 2
Rate limit: 100 requests/minute

[1/2] Processing country: United Arab Emirates
  Processing 80 tasks in parallel (8 models x 10 runs)...


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mProvider List: https://docs.litellm.ai/d