# 01 - Collect Responses

This notebook runs inference for all prompts, languages, and models using Ollama.

## Prerequisites
- Ollama must be installed and running locally
- Required models must be pulled: `ollama pull gemma3:1b` and `ollama pull llama3.2:1b`

In [1]:
# Add parent directory to path for imports
import sys
sys.path.insert(0, '..')

In [2]:
from src.load_prompts import load_prompts, get_task_types, get_languages
from src.infer_ollama import (
    run_inference, 
    run_full_inference,
    load_responses, 
    get_models, 
    get_inference_params,
)
from pathlib import Path
import pandas as pd

## 1. Verify Configuration

In [3]:
# Check configured models
print("Configured Models:")
for model in get_models():
    print(f"  - {model['id']}: {model['name']}")

print("\nInference Parameters:")
params = get_inference_params()
for k, v in params.items():
    print(f"  {k}: {v}")

Configured Models:
  - gemma3:1b: Gemma 3 1B
  - llama3.2:1b: Llama 3.2 1B

Inference Parameters:
  temperature: 0.3
  num_predict: 256
  runs_per_prompt: 2


## 2. Verify Prompts Dataset

In [4]:
# Load and inspect prompts
prompts_df = load_prompts(prepend_control_line=False)
print(f"Total prompts: {len(prompts_df)}")
print(f"\nTask types: {get_task_types()}")
print(f"Languages: {get_languages()}")
print(f"\nPrompts per task type:")
print(prompts_df.groupby('task_type').size())

Total prompts: 60

Task types: ['classification', 'creative', 'factual', 'reasoning', 'summarization']
Languages: ['DE', 'EN', 'TR']

Prompts per task type:
task_type
classification    12
creative          12
factual           12
reasoning         12
summarization     12
dtype: int64


## 3. Check Ollama Connection

In [5]:
import ollama

try:
    models_list = ollama.list()
    print("Available Ollama models:")
    for model in models_list.get('models', []):
        print(f"  - {model['name']}")
except Exception as e:
    print(f"Error connecting to Ollama: {e}")
    print("\nMake sure Ollama is running: `ollama serve`")

Available Ollama models:
Error connecting to Ollama: 'name'

Make sure Ollama is running: `ollama serve`


## 4. Run Inference

This will generate responses for all prompts, languages, runs, and models.

**Expected calls:** 20 prompts × 3 languages × 2 runs × 2 models = **240 calls**

In [17]:
# Check existing responses
existing = load_responses()
print(f"Existing responses: {len(existing)}")

if len(existing) > 0:
    existing_df = pd.DataFrame(existing)
    print("\nBy model:")
    print(existing_df.groupby('model_id').size())
    print("\nBy language:")
    print(existing_df.groupby('language').size())

Existing responses: 360

By model:
model_id
gemma3:1b      120
gemma3:4b      120
llama3.2:1b    120
dtype: int64

By language:
language
DE    120
EN    120
TR    120
dtype: int64


In [18]:
# Run full inference (skips existing by default)
# This may take 10-30 minutes depending on hardware

count = run_full_inference(skip_existing=True)
print(f"\nGenerated {count} new responses")

Total inference calls needed: 360 (skipped 360 existing)


Generating responses: 100%|██████████| 360/360 [07:35<00:00,  1.26s/it]


Generated 360 new responses





## 5. Verify Results

In [19]:
# Load and verify all responses
responses = load_responses()
responses_df = pd.DataFrame(responses)

print(f"Total responses: {len(responses_df)}")
print(f"\nExpected: {20 * 3 * 2 * len(get_models())} (20 prompts × 3 langs × 2 runs × {len(get_models())} models)")

# Check coverage
print("\nCoverage by model and language:")
coverage = responses_df.groupby(['model_id', 'language']).size().unstack(fill_value=0)
print(coverage)

Total responses: 720

Expected: 720 (20 prompts × 3 langs × 2 runs × 6 models)

Coverage by model and language:
language        DE  EN  TR
model_id                  
gemma3:1b       40  40  40
gemma3:4b       40  40  40
llama3.2:1b     40  40  40
llama3.2:3b     40  40  40
phi3:latest     40  40  40
phi4-mini:3.8b  40  40  40


In [20]:
# Check for non-compliant responses
non_compliant = responses_df[responses_df['non_compliant'] == True]
print(f"Non-compliant responses: {len(non_compliant)}")

if len(non_compliant) > 0:
    print("\nNon-compliant by task type:")
    print(non_compliant.groupby('task_type').size())

Non-compliant responses: 53

Non-compliant by task type:
task_type
classification    32
factual           11
reasoning         10
dtype: int64


## 6. Retry Non-Compliant Responses

Re-generate responses that violated format constraints with a stricter prompt wrapper.

In [50]:
import json
from datetime import datetime, timezone
from pathlib import Path
from tqdm import tqdm
import ollama

def get_stricter_control_line(language: str) -> str:
    """Get stricter control line for retry attempts."""
    lines = {
        "EN": "IMPORTANT: Output ONLY the answer. No explanations, no reasoning, no extra text. Just the answer.",
        "DE": "WICHTIG: Gib NUR die Antwort aus. Keine Erklärungen, keine Begründung, kein zusätzlicher Text. Nur die Antwort.",
        "TR": "ÖNEMLİ: SADECE cevabı yaz. Açıklama yok, gerekçe yok, fazladan metin yok. Sadece cevap."
    }
    return lines.get(language, lines["EN"])

def retry_non_compliant_responses(responses_df, responses_file: Path):
    """Retry non-compliant responses with stricter prompts and overwrite."""
    
    non_compliant = responses_df[responses_df['non_compliant'] == True]
    
    if len(non_compliant) == 0:
        print("No non-compliant responses to retry.")
        return 0
    
    print(f"Retrying {len(non_compliant)} non-compliant responses...")
    
    # Load all responses
    all_responses = []
    with open(responses_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                all_responses.append(json.loads(line))
    
    # Create index for quick lookup
    response_index = {}
    for i, r in enumerate(all_responses):
        key = (r['model_id'], r['prompt_id'], r['language'], r['run_id'])
        response_index[key] = i
    
    # Retry each non-compliant response
    retried_count = 0
    for _, row in tqdm(non_compliant.iterrows(), total=len(non_compliant), desc="Retrying"):
        key = (row['model_id'], row['prompt_id'], row['language'], row['run_id'])
        
        if key not in response_index:
            continue
        
        idx = response_index[key]
        original = all_responses[idx]
        
        # Create stricter prompt
        stricter_control = get_stricter_control_line(row['language'])
        original_prompt = original['prompt_text']
        stricter_prompt = f"{stricter_control}\n\n{original_prompt}"
        
        try:
            # Generate new response
            response = ollama.generate(
                model=row['model_id'],
                prompt=stricter_prompt,
                options={
                    "temperature": original['temperature'],
                    "num_predict": original['max_new_tokens']
                }
            )
            
            new_response_text = response["response"]
            
            # Check if new response is compliant
            from src.infer_ollama import detect_non_compliance
            new_non_compliant = detect_non_compliance(
                new_response_text, 
                row['task_type'], 
                row['prompt_id']
            )
            
            # Update the response
            all_responses[idx]['response_text'] = new_response_text
            all_responses[idx]['non_compliant'] = new_non_compliant
            all_responses[idx]['timestamp_utc'] = datetime.now(timezone.utc).isoformat()
            all_responses[idx]['retry'] = True
            
            retried_count += 1
            
        except Exception as e:
            print(f"Error retrying {key}: {e}")
    
    # Write back all responses
    with open(responses_file, 'w', encoding='utf-8') as f:
        for r in all_responses:
            f.write(json.dumps(r, ensure_ascii=False) + '\n')
    
    print(f"\nRetried {retried_count} responses.")
    
    # Check how many are still non-compliant
    still_non_compliant = sum(1 for r in all_responses if r.get('non_compliant', False))
    print(f"Still non-compliant after retry: {still_non_compliant}")
    
    return retried_count

# Run retry
responses_file = Path('../data/responses.jsonl')
retry_non_compliant_responses(responses_df, responses_file)

# Reload and show updated stats
responses = load_responses()
responses_df = pd.DataFrame(responses)
non_compliant_after = responses_df[responses_df['non_compliant'] == True]
print(f"\nFinal non-compliant count: {len(non_compliant_after)}")
if len(non_compliant_after) > 0:
    print("Remaining non-compliant by task type:")
    print(non_compliant_after.groupby('task_type').size())

Retrying 1 non-compliant responses...


Retrying:   0%|          | 0/1 [00:00<?, ?it/s]

Retrying: 100%|██████████| 1/1 [00:09<00:00,  9.72s/it]


Retried 1 responses.
Still non-compliant after retry: 1

Final non-compliant count: 1
Remaining non-compliant by task type:
task_type
classification    1
dtype: int64





## Next Steps

Proceed to `02_metrics_and_plots.ipynb` to:
1. Compute LaBSE embeddings for open-text responses
2. Calculate cross-lingual similarity metrics
3. Run task-aware checks for discrete-answer tasks (non-compliant responses will be handled automatically)
4. Generate visualizations

## 7. Re-evaluate All Responses with Updated Rules

Re-check all responses with the updated (more lenient) non-compliance detection.

In [51]:
# Reload the module to get updated detect_non_compliance function
import importlib
import src.infer_ollama
importlib.reload(src.infer_ollama)
from src.infer_ollama import detect_non_compliance, load_responses

# Re-evaluate all responses
responses_file = Path('../data/responses.jsonl')

all_responses = []
with open(responses_file, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():
            all_responses.append(json.loads(line))

# Re-check non-compliance with updated rules
updated_count = 0
for r in all_responses:
    old_status = r.get('non_compliant', False)
    new_status = detect_non_compliance(r['response_text'], r['task_type'], r['prompt_id'])
    
    if old_status != new_status:
        r['non_compliant'] = new_status
        updated_count += 1

# Save updated responses
with open(responses_file, 'w', encoding='utf-8') as f:
    for r in all_responses:
        f.write(json.dumps(r, ensure_ascii=False) + '\n')

print(f"Updated {updated_count} response statuses")

# Show new stats
non_compliant_count = sum(1 for r in all_responses if r.get('non_compliant', False))
print(f"\nNon-compliant responses after re-evaluation: {non_compliant_count}")

# Show by task type
responses_df = pd.DataFrame(all_responses)
non_compliant_df = responses_df[responses_df['non_compliant'] == True]
if len(non_compliant_df) > 0:
    print("\nNon-compliant by task type:")
    print(non_compliant_df.groupby('task_type').size())
else:
    print("\nAll responses are now compliant!")

Updated 0 response statuses

Non-compliant responses after re-evaluation: 1

Non-compliant by task type:
task_type
classification    1
dtype: int64
