<a href="https://colab.research.google.com/github/laurencoetzee001/Beads_Co-detect/blob/main/prompt_optimisation_location_exchange_revisions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Bead Trade Coding - 27,060 Rows Google Colab PRODUCTION
========================================================

COMPLETE IMPLEMENTATION:
✓ All 13 fields from codebook (with proper codes)
✓ Enhanced 4a_exchange rules (73% human-AI agreement)
✓ Two-layer location extraction (spellings + context)
✓ Proper Excel flattening (29 output columns)
✓ AGGRESSIVE AUTO-SAVE (every 100 rows)
✓ BATCH PROCESSING (2,000 rows per session)
✓ Fresh start from Row 0
✓ Real-time progress output
✓ Google Colab integration

All variables needed for analysis included
"""

import subprocess
import sys
import os
import json
import time
from datetime import datetime

# Force line-buffered output so progress lines appear as they are printed.
# The `os` module has no `__stdout__` attribute, so the previous assignment
# raised AttributeError before anything else ran. Replacing sys.stdout with
# sys.__stdout__ would also bypass any replacement stream a notebook front end
# installs, so the current stream is reconfigured in place when it supports it.
_reconfigure_stdout = getattr(sys.stdout, "reconfigure", None)
if callable(_reconfigure_stdout):
    try:
        _reconfigure_stdout(line_buffering=True)
    except (ValueError, OSError):
        # Best effort only: some wrapped streams refuse reconfiguration.
        pass

# === INSTALL DEPENDENCIES ===
# Each third-party package is imported once as a probe; a failed import
# triggers a quiet pip install using the interpreter that runs this script.
print("Checking dependencies...")
for pkg in ("anthropic", "openpyxl", "pandas", "tenacity"):
    try:
        __import__(pkg)
    except ImportError:
        print(f"Installing {pkg}...")
        _pip_command = [sys.executable, "-m", "pip", "install", pkg, "-q"]
        subprocess.check_call(_pip_command)

import pandas as pd
from anthropic import Anthropic
from tenacity import retry, stop_after_attempt, wait_exponential

# === GOOGLE DRIVE SETUP ===
# Outputs go to Google Drive when the Colab helper can be imported and the
# mount succeeds; otherwise they go to a folder under the working directory.
# Only Exception is caught (not a bare except) so Ctrl-C / SystemExit still
# propagate, and any failure other than "not in Colab" is reported instead of
# being silently hidden behind the local fallback.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    GDRIVE_BASE = '/content/drive/MyDrive/bead_coding_27k_production'
    os.makedirs(GDRIVE_BASE, exist_ok=True)
    IN_COLAB = True
    print("✓ Google Drive mounted\n")
except Exception as _drive_error:
    if not isinstance(_drive_error, ImportError):
        # The Colab module exists but Drive could not be used: say why.
        print(f"⚠ Google Drive unavailable: {_drive_error}")
    GDRIVE_BASE = './bead_coding_27k_production'
    os.makedirs(GDRIVE_BASE, exist_ok=True)
    IN_COLAB = False
    print("⚠ Running locally\n")

# === CONFIGURATION ===
# Workbook to code; main() looks for it in the working directory, in /content,
# and finally as a bare relative path.
INPUT_FILE = "All_entries_beads_cleaned.xlsx"
# All outputs (workbooks, log, state) are written under the storage folder
# chosen during the Drive setup above.
OUTPUT_BASE = GDRIVE_BASE
# Text log appended to by log_event().
LOG_FILE = os.path.join(OUTPUT_BASE, "processing.log")
# JSON progress record read and written by SessionState.
STATE_FILE = os.path.join(OUTPUT_BASE, "session_state.json")
AUTOSAVE_EVERY = 100  # Write a checkpoint workbook every 100 rows of a run
# Number of rows assigned to one session/batch by SessionState.get_batch_info().
ROWS_PER_SESSION = 2000

# Model identifier passed to the Anthropic messages API.
MODEL_NAME = "claude-haiku-4-5-20251001"
# API key taken from the environment; main() refuses to run when it is empty.
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
# main() stops starting new rows once a run has lasted longer than this.
MAX_SESSION_MINUTES = 690
# Column of the input workbook holding the text sent to the model.
TEXT_COLUMN = "text_page_gp"

# === COMPREHENSIVE SYSTEM PROMPT (13 fields + enhanced rules) ===
# Sent as the `system` message of every request. The JSON skeleton and the
# "LOCATION EXTRACTION" section must describe the same format for
# 4_location_context: both now use the four-part
# "Location | Geographic | Strategic | Political" layout (the skeleton used to
# advertise a conflicting three-part layout).
SYSTEM_PROMPT = """You are a historian analyzing pre-colonial African bead trade records.

TASK: Extract ALL 13 structured data fields. Be conservative - require explicit evidence.

RESPONSE FORMAT: Return ONLY a JSON object. NO markdown, NO ```json``` tags, NO extra text.

JSON STRUCTURE (13 fields):
{
  "1_price_HUMAN": {
    "status": "yes|no|xo",
    "amount": "number or measurement, or null",
    "currency": "currency/commodity, or null",
    "description": "full price text from source"
  },
  "2_size_HUMAN": {
    "code": 1-6 or null,
    "description": "exact size text from source"
  },
  "3_colour_HUMAN": {
    "codes": [array of 1-14],
    "description": "exact color text, REQUIRED if code=14"
  },
  "4_location_HUMAN": {
    "codes": [array of 1-4],
    "names": "actual location names"
  },
  "4_location_names_original_spellings": "exact historical spellings (e.g. Bontuku, Whydah)",
  "4_location_context": "Location | Geographic | Strategic | Political",
  "5_function_HUMAN": {
    "codes": [array of 1-4],
    "description": "detailed function text"
  },
  "6_origin_of_bead": "geographic origin text or null",
  "7_shape_HUMAN": {
    "codes": [array of 1-12],
    "description": "exact shape text"
  },
  "8_type_bead_HUMAN": {
    "codes": [array of 1-14],
    "description": "exact material text"
  },
  "9_local_name_HUMAN": {
    "exists": "1|2|null",
    "names": ["array of names"] or null
  },
  "10_relationship_": {
    "codes": [array of 1-31],
    "description": "detailed exchange items text"
  },
  "11_units_of_measure": {
    "type": 1-4 or null,
    "description": "exact measurement text"
  },
  "12_bead_ethnic_": ["array of ethnic group names"] or null,
  "13_nature_of_exchange": {
    "code": 1-6 or null,
    "description": "exchange nature text"
  },
  "notes": "additional research context"
}

FIELD CODES:

1_price_HUMAN: status: yes=mentioned, no=not mentioned, xo=exchanged

2_size_HUMAN: 1=large, 2=medium, 3=small, 4=various, 5=thin, 6=thick

3_colour_HUMAN: 1=red, 2=blue, 3=white, 4=pink, 5=coral, 6=amber, 7=copper, 8=green, 9=yellow, 10=transparent, 11=seed glass, 12=black, 13=multicoloured, 14=other

4_location_HUMAN: 1=mountain/hill, 2=lake, 3=river/waterfall, 4=populated place

5_function_HUMAN: 1=jewellery/adornment, 2=currency/exchange, 3=ceremonial/religious, 4=status/gift

6_origin_of_bead: Text describing geographic origin

7_shape_HUMAN: 1=round, 2=tubular, 3=square, 4=oval, 5=oblong, 6=punched, 7=wound, 8=pressed, 9=decorative, 10=faceted, 11=bugle, 12=chevron

8_type_bead_HUMAN: 1=glass, 2=clay, 3=metal, 4=stone, 5=coral, 6=amber, 7=bone, 8=ivory, 9=dried seed, 10=ceramic, 11=wooden, 12=porcelain, 13=shell, 14=eggshell

9_local_name_HUMAN: exists: 1=yes (provide names), 2=unspecified

10_relationship_: 1=wire, 2=cloth, 3=shells, 4=coins, 5=livestock, 6=iron bars, 7=scarabs, 8=precious stones, 9=antiquities, 10=ostrich feathers, 11=ebony/ivory, 12=salt, 13=rubber/gum, 14=medicines, 15=spices/perfumes, 16=wax/seals, 17=leather/hides, 18=weapons, 19=dried food, 20=prints/books, 21=guns/gunpowder, 22=jewellery, 23=textiles, 24=gold/silver, 25=slaves, 26=glass objects, 27=hardware, 28=tobacco, 29=musical instruments, 30=water, 31=alcohol

11_units_of_measure: 1=string, 2=plaited/woven string, 3=necklace/bracelet/waist beads, 4=other

12_bead_ethnic_: Array of ethnic group names

13_nature_of_exchange: 1=consensual, 2=conflictual, 3=unspecified, 4=competitive/bartering, 5=social/gifts, 6=uncommercial

=== ENHANCED RULES FOR 4a_EXCHANGE ===

For 1_price_HUMAN status field:
- "yes" = price mentioned
- "no" = price not mentioned
- "xo" = exchanged (rare, for historical exchange events)

CRITICAL: Only code "xo" for explicit past transactions:
  ✓ "traded", "sold", "bought", "exchanged", "gave", "received"
  ✓ Past tense action
  ✓ Clear parties involved
  ✓ Specific items exchanged

✗ DO NOT code "xo" for:
  ✗ Value descriptions ("beads are worth...")
  ✗ Manufacturing references ("making beads")
  ✗ Hypothetical ("would trade")
  ✗ General statements ("natives trade beads")
  ✗ Observational ("showed beads")
  ✗ Decorative usage

=== LOCATION EXTRACTION (TWO LAYERS) ===

Layer 1: 4_location_names_original_spellings
- Preserve EXACT spellings from historical text
- Include diacritics: Bontúku, Dahomey, Whydah
- Order by first mention in text

Layer 2: 4_location_context
- Geographic location: coastal, inland, near river, etc.
- Strategic significance: trade hub, route junction, etc.
- Political context: territory of X people, colonial region, etc.
- Format: "Location | Geographic | Strategic | Political"

CRITICAL RULES:
1. Return ONLY the JSON object - NO ```json``` tags, NO markdown, NO extra text
2. Use null for missing data (not "unknown")
3. For arrays: [] if no data, null if not applicable
4. ALWAYS include description fields with verbatim text
5. When code=14 (other), description is MANDATORY
6. Base answers ONLY on provided text
7. Preserve exact terminology and details
"""

# Per-row user message; `{text}` is the only placeholder and is filled with
# the entry text via str.format().
USER_PROMPT_TEMPLATE = "\n".join([
    "Analyze this historical text and extract ALL 13 fields:",
    "",
    "TEXT:",
    "{text}",
    "",
    "Return ONLY the JSON object. No ```json``` tags, no other text.",
])

# === SESSION STATE ===
class SessionState:
    """Resume bookkeeping persisted as JSON at ``STATE_FILE``.

    The ``state`` dict holds:
      * ``session_num``        - 1-based number of the batch to work on,
      * ``last_row``           - index of the last row recorded as handled
                                 (-1 before any row),
      * ``sessions_completed`` - number of batches marked complete,
      * ``total_tokens_used`` / ``total_cost`` - running totals,
      * ``batches_completed``  - one summary dict per completed batch.
    """

    def __init__(self):
        self.file_path = STATE_FILE
        self.state = self._load()

    def _load(self):
        """Return the saved state, or a fresh state when no file exists."""
        if os.path.exists(self.file_path):
            with open(self.file_path, 'r') as f:
                return json.load(f)
        # Fresh start from row 0
        return {
            'session_num': 1,
            'last_row': -1,
            'sessions_completed': 0,
            'total_tokens_used': 0,
            'total_cost': 0,
            'batches_completed': []
        }

    def save(self):
        """Write the state to disk without ever leaving a half-written file.

        The JSON is written to a sibling temporary file and then renamed over
        the real one, so an interruption during the write leaves the previous
        state intact instead of a truncated, unreadable file.
        """
        tmp_path = f"{self.file_path}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(self.state, f, indent=2)
        os.replace(tmp_path, self.file_path)

    def get_batch_info(self, total_rows=27060):
        """Describe the batch belonging to the current session number.

        Args:
            total_rows: Number of rows in the dataset; caps the batch end.

        Returns:
            dict with ``session``, ``batch_start`` (inclusive), ``batch_end``
            (exclusive), ``actual_start`` (first row still to do) and
            ``batch_size``.
        """
        session_num = self.state['session_num']
        start_row = (session_num - 1) * ROWS_PER_SESSION
        end_row = min(session_num * ROWS_PER_SESSION, total_rows)

        # If progress was already recorded inside this batch, resume after it.
        if self.state['last_row'] >= start_row:
            actual_start = self.state['last_row'] + 1
        else:
            actual_start = start_row

        return {
            'session': session_num,
            'batch_start': start_row,
            'batch_end': end_row,
            'actual_start': actual_start,
            'batch_size': end_row - start_row
        }

    def update_row(self, row_num):
        """Record ``row_num`` as the last handled row and persist immediately."""
        self.state['last_row'] = row_num
        self.save()

    def mark_batch_complete(self, total_input_tokens, total_output_tokens, cost):
        """Close the current batch, add its usage to the totals and advance.

        Args:
            total_input_tokens: Input tokens consumed by the finishing run.
            total_output_tokens: Output tokens consumed by the finishing run.
            cost: Cost attributed to the finishing run.
        """
        # Take the batch start from the session number *before* advancing it.
        # Deriving it from last_row - ROWS_PER_SESSION + 1 gave a wrong range
        # whenever the batch was shorter than ROWS_PER_SESSION (the last one).
        batch_start = (self.state['session_num'] - 1) * ROWS_PER_SESSION
        tokens = total_input_tokens + total_output_tokens

        self.state['sessions_completed'] += 1
        self.state['session_num'] += 1
        self.state['total_tokens_used'] += tokens
        self.state['total_cost'] += cost
        self.state['batches_completed'].append({
            'session': self.state['sessions_completed'],
            'rows': f"{batch_start}-{self.state['last_row']}",
            'tokens': tokens,
            'cost': cost
        })
        self.save()

# === UTILITIES ===

def strip_markdown_json(text):
    """Extract the first complete top-level JSON object from a model reply.

    Markdown code fences and any prose around the object are discarded. The
    object's end is found by brace matching that ignores braces inside JSON
    strings.

    Args:
        text: Raw reply text.

    Returns:
        The substring running from the first ``{`` to its matching ``}``, or
        None when ``text`` is not a string, contains no ``{``, or the object
        is never closed (for example a truncated reply).
    """
    if not isinstance(text, str):
        return None
    text = text.strip()

    if text.startswith('```'):
        lines = text.split('\n')
        # Drop the fence lines only for a multi-line reply. A single-line
        # reply such as ```json {...} ``` keeps its text: dropping "the first
        # line" there would throw the whole object away, and the brace search
        # below skips the fence characters anyway.
        if len(lines) > 1:
            lines = lines[1:]
            if lines and lines[-1].strip() == '```':
                lines = lines[:-1]
            text = '\n'.join(lines).strip()

    start = text.find('{')
    if start == -1:
        return None
    text = text[start:]

    depth = 0
    in_string = False
    escape_next = False

    for i, char in enumerate(text):
        if in_string:
            # Backslash escapes only have meaning inside a string literal.
            if escape_next:
                escape_next = False
            elif char == '\\':
                escape_next = True
            elif char == '"':
                in_string = False
            continue

        if char == '"':
            in_string = True
        elif char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[:i + 1]

    return None

def validate_json_response(json_obj):
    """Check that all 13 codebook keys are present in a parsed reply.

    Returns:
        ``(True, None)`` when nothing is missing, otherwise ``(False, msg)``
        where ``msg`` names at most the first three missing keys.
    """
    expected_keys = (
        '1_price_HUMAN',
        '2_size_HUMAN',
        '3_colour_HUMAN',
        '4_location_HUMAN',
        '5_function_HUMAN',
        '6_origin_of_bead',
        '7_shape_HUMAN',
        '8_type_bead_HUMAN',
        '9_local_name_HUMAN',
        '10_relationship_',
        '11_units_of_measure',
        '12_bead_ethnic_',
        '13_nature_of_exchange',
    )

    absent = []
    for key in expected_keys:
        if key not in json_obj:
            absent.append(key)

    if not absent:
        return True, None
    return False, "Missing: " + ", ".join(absent[:3])

def _join_cell(value, sep):
    """Collapse a list-like model value into one delimited string, or None when empty."""
    if value is None:
        return None
    # A bare scalar (e.g. "names": "aggrey" or "codes": 2) is treated as a
    # one-item list so it is neither split into characters nor a TypeError.
    items = list(value) if isinstance(value, (list, tuple)) else [value]
    parts = [str(item) for item in items if item is not None and str(item) != '']
    return sep.join(parts) if parts else None


def _scalar_cell(value):
    """Return a value that fits in a single spreadsheet cell (lists joined, dicts serialised)."""
    if isinstance(value, (list, tuple)):
        return _join_cell(value, '; ')
    if isinstance(value, dict):
        return json.dumps(value, ensure_ascii=False)
    return value


def flatten_for_excel(json_obj):
    """Flatten a parsed model reply into the 29 flat spreadsheet columns.

    Nested groups (price, size, colour, ...) become ``<prefix>_<subkey>``
    columns. Code lists are joined with ``,`` and name lists with ``; ``.
    A group whose value is not a dict is skipped, so its columns are absent
    from the result.

    The model does not always honour the requested types, so values are
    normalised: a bare string or number where a list was requested is kept
    whole, and a list or dict where a single value was requested is turned
    into text so that every result fits in one cell.

    Args:
        json_obj: dict parsed from the model's JSON reply.

    Returns:
        dict mapping column name to cell value (str, number or None).
    """
    flat = {}

    def put_group(source_key, prefix, fields):
        # fields: (sub_key, kind) pairs; kind is 'scalar', 'codes' or 'names'.
        group = json_obj.get(source_key)
        if not isinstance(group, dict):
            return
        for sub_key, kind in fields:
            raw = group.get(sub_key)
            if kind == 'codes':
                cell = _join_cell(raw, ',')
            elif kind == 'names':
                cell = _join_cell(raw, '; ')
            else:
                cell = _scalar_cell(raw)
            flat[f"{prefix}_{sub_key}"] = cell

    coded = (('codes', 'codes'), ('description', 'scalar'))

    # 1_price_HUMAN (4 columns)
    put_group('1_price_HUMAN', '1_price', (
        ('status', 'scalar'), ('amount', 'scalar'),
        ('currency', 'scalar'), ('description', 'scalar'),
    ))
    # 2_size_HUMAN (2 columns)
    put_group('2_size_HUMAN', '2_size', (('code', 'scalar'), ('description', 'scalar')))
    # 3_colour_HUMAN (2 columns)
    put_group('3_colour_HUMAN', '3_colour', coded)
    # 4_location_HUMAN (2 columns)
    put_group('4_location_HUMAN', '4_location', (('codes', 'codes'), ('names', 'scalar')))

    # Location enhancements (2 columns)
    flat['4_location_names_original_spellings'] = _scalar_cell(
        json_obj.get('4_location_names_original_spellings'))
    flat['4_location_context'] = _scalar_cell(json_obj.get('4_location_context'))

    # 5_function_HUMAN (2 columns)
    put_group('5_function_HUMAN', '5_function', coded)

    # 6_origin_of_bead (1 column)
    flat['6_origin_of_bead'] = _scalar_cell(json_obj.get('6_origin_of_bead'))

    # 7_shape_HUMAN (2 columns)
    put_group('7_shape_HUMAN', '7_shape', coded)
    # 8_type_bead_HUMAN (2 columns)
    put_group('8_type_bead_HUMAN', '8_type', coded)
    # 9_local_name_HUMAN (2 columns)
    put_group('9_local_name_HUMAN', '9_local_name', (('exists', 'scalar'), ('names', 'names')))
    # 10_relationship_ (2 columns)
    put_group('10_relationship_', '10_relationship', coded)
    # 11_units_of_measure (2 columns)
    put_group('11_units_of_measure', '11_units', (('type', 'scalar'), ('description', 'scalar')))

    # 12_bead_ethnic_ (1 column)
    flat['12_bead_ethnic_'] = _join_cell(json_obj.get('12_bead_ethnic_'), '; ')

    # 13_nature_of_exchange (2 columns)
    put_group('13_nature_of_exchange', '13_nature', (('code', 'scalar'), ('description', 'scalar')))

    # Notes (1 column)
    flat['notes'] = _scalar_cell(json_obj.get('notes'))

    return flat

@retry(
    reraise=True,
    wait=wait_exponential(multiplier=2, min=1, max=10),
    stop=stop_after_attempt(3),
)
def call_claude(client, entry_text):
    """Request the coding of one text entry from the model.

    The retry decorator allows up to three attempts with exponential waits
    between them and re-raises the last error when all attempts fail.

    Args:
        client: Anthropic client whose ``messages.create`` is called.
        entry_text: Source text substituted into ``USER_PROMPT_TEMPLATE``.

    Returns:
        The response object returned by ``client.messages.create``.
    """
    user_message = {
        "role": "user",
        "content": USER_PROMPT_TEMPLATE.format(text=entry_text),
    }
    return client.messages.create(
        model=MODEL_NAME,
        max_tokens=2500,
        temperature=0,
        system=SYSTEM_PROMPT,
        messages=[user_message],
    )

def calculate_cost(input_tokens, output_tokens,
                   input_usd_per_mtok=1.00, output_usd_per_mtok=5.00):
    """Estimate the API cost in US dollars for the given token counts.

    The previous hard-coded rates (0.80 in / 0.24 out) priced output tokens
    below input tokens, which understated every cost figure printed by the
    script. The defaults are intended to be the list prices of the configured
    Claude Haiku 4.5 model - NOTE(review): confirm against the current
    Anthropic pricing page, or pass explicit rates.

    Args:
        input_tokens: Number of prompt tokens billed.
        output_tokens: Number of completion tokens billed.
        input_usd_per_mtok: Price per one million input tokens.
        output_usd_per_mtok: Price per one million output tokens.

    Returns:
        Estimated cost as a float.
    """
    input_cost = (input_tokens / 1_000_000) * input_usd_per_mtok
    output_cost = (output_tokens / 1_000_000) * output_usd_per_mtok
    return input_cost + output_cost

def log_event(message, level="INFO"):
    """Print a timestamped line and append the same line to ``LOG_FILE``.

    Args:
        message: Text to record.
        level: Severity label placed in brackets after the timestamp.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    line = "[{}] [{}] {}".format(stamp, level, message)
    # Flush right away so the line shows up even if the run dies soon after.
    print(line, flush=True)
    with open(LOG_FILE, 'a') as log_handle:
        log_handle.write(line + "\n")

# === MAIN ===

def _locate_input_file():
    """Return the first existing path for INPUT_FILE among the known locations, else None."""
    candidates = (
        os.path.join(os.getcwd(), INPUT_FILE),
        os.path.join('/content', INPUT_FILE),
        INPUT_FILE,
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def _flatten_reply(response_text):
    """Convert one raw model reply into flat columns.

    Returns ``(flat, None)`` on success, or ``(None, reason)`` when no JSON
    object is found, the JSON does not parse, or required fields are missing.
    """
    cleaned = strip_markdown_json(response_text)
    if not cleaned:
        return None, "no JSON object in response"
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError as exc:
        return None, f"invalid JSON - {exc.msg}"
    valid, error = validate_json_response(parsed)
    if not valid:
        return None, error
    return flatten_for_excel(parsed), None


def main():
    """Code one batch of rows with the model and save the results.

    Steps: load the workbook, work out the current batch from the saved
    session state, code each row, write a checkpoint workbook every
    AUTOSAVE_EVERY rows, then write the batch workbook and update the state.

    Resume safety:
      * Progress is persisted only after the matching results are on disk,
        so an interrupted run re-codes at most the rows since its last
        checkpoint instead of silently losing them.
      * A run stopped by the time limit does not mark its batch complete;
        the next run continues inside the same batch.
      * Saved workbooks contain only the rows handled by this run (the range
        named in the batch file name), not the untouched rows before them.
    """
    print("\n" + "="*80)
    print("BEAD TRADE CODING - 27,060 ROWS (PRODUCTION)")
    print("="*80)
    print("✓ All 13 codebook fields")
    print("✓ Enhanced 4a_exchange rules (73% agreement)")
    print("✓ Two-layer location extraction")
    print("✓ 29 output columns (analysis-ready)")
    print("✓ Auto-save every 100 rows")
    print("✓ Batch processing (2,000 rows/session)")
    print("✓ Fresh start from Row 0\n")

    if not ANTHROPIC_API_KEY:
        print("✗ ERROR: ANTHROPIC_API_KEY not set")
        return

    input_path = _locate_input_file()
    if not input_path:
        print(f"✗ ERROR: Could not find {INPUT_FILE}")
        return

    try:
        df = pd.read_excel(input_path)
        print(f"✓ Loaded {len(df)} rows\n")
    except Exception as e:
        print(f"✗ ERROR: {e}")
        return

    if TEXT_COLUMN not in df.columns:
        print(f"✗ ERROR: Column '{TEXT_COLUMN}' not found")
        return

    # Size the batches from the data actually loaded rather than a fixed count.
    total_rows = len(df)
    total_batches = (total_rows + ROWS_PER_SESSION - 1) // ROWS_PER_SESSION

    session = SessionState()
    batch_info = session.get_batch_info(total_rows)
    first_row = batch_info['actual_start']
    batch_end = batch_info['batch_end']

    if batch_info['batch_start'] >= total_rows:
        # Every batch is already done; running on would record a phantom batch.
        print(f"✅ ALL COMPLETE! Nothing left to process ({total_rows} rows).")
        return

    print(f"Session: {batch_info['session']}/{total_batches}")
    print(f"Batch: Rows {batch_info['batch_start']}-{batch_end - 1}")
    print(f"Resume from: Row {first_row}\n")

    log_event(f"Session {batch_info['session']}/{total_batches} started - row {first_row}")

    client = Anthropic(api_key=ANTHROPIC_API_KEY)

    # Output starts as a copy of the input with empty coding columns added.
    output_df = df.copy()
    coding_cols = [
        '1_price_status', '1_price_amount', '1_price_currency', '1_price_description',
        '2_size_code', '2_size_description',
        '3_colour_codes', '3_colour_description',
        '4_location_codes', '4_location_names',
        '4_location_names_original_spellings', '4_location_context',
        '5_function_codes', '5_function_description',
        '6_origin_of_bead',
        '7_shape_codes', '7_shape_description',
        '8_type_codes', '8_type_description',
        '9_local_name_exists', '9_local_name_names',
        '10_relationship_codes', '10_relationship_description',
        '11_units_type', '11_units_description',
        '12_bead_ethnic_',
        '13_nature_code', '13_nature_description',
        'notes'
    ]
    for col in coding_cols:
        if col not in output_df.columns:
            output_df[col] = None

    total_input_tokens = 0
    total_output_tokens = 0
    session_start_time = time.time()
    success_count = 0
    error_count = 0
    skipped_count = 0
    last_done = first_row - 1  # last row handled in memory by this run

    print("PROCESSING:")
    print("-" * 80)

    for row_idx in range(first_row, batch_end):
        elapsed_minutes = (time.time() - session_start_time) / 60
        if elapsed_minutes > MAX_SESSION_MINUTES:
            print(f"\n⏰ TIME LIMIT ({elapsed_minutes:.0f}m)")
            break

        try:
            entry_text = df.iloc[row_idx].get(TEXT_COLUMN)

            if pd.isna(entry_text) or not str(entry_text).strip():
                skipped_count += 1
            else:
                response = call_claude(client, str(entry_text).strip())
                total_input_tokens += response.usage.input_tokens
                total_output_tokens += response.usage.output_tokens

                flat, problem = _flatten_reply(response.content[0].text)
                if flat is None:
                    error_count += 1
                    log_event(f"Row {row_idx}: {problem}", "WARNING")
                else:
                    for key, value in flat.items():
                        output_df.at[row_idx, key] = value
                    success_count += 1
        except Exception as e:
            error_count += 1
            log_event(f"Row {row_idx}: Exception - {str(e)[:50]}", "ERROR")

        last_done = row_idx

        # AUTO-SAVE: evaluated for every row (skipped and failed ones too) so
        # a checkpoint is never missed because its row took an early exit.
        rows_in_batch = row_idx - first_row + 1
        if rows_in_batch % AUTOSAVE_EVERY == 0:
            elapsed = (time.time() - session_start_time) / 60
            cost_so_far = calculate_cost(total_input_tokens, total_output_tokens)
            avg_tokens = (total_input_tokens + total_output_tokens) / rows_in_batch

            autosave_file = os.path.join(
                OUTPUT_BASE,
                f"session_{batch_info['session']:02d}_autosave_row_{row_idx}.xlsx"
            )
            try:
                output_df.iloc[first_row:row_idx + 1].to_excel(autosave_file, index=False)
            except Exception as e:
                # Leave the persisted position where it was: these rows are
                # not safely on disk yet.
                log_event(f"Autosave at row {row_idx} failed - {str(e)[:50]}", "ERROR")
            else:
                session.update_row(row_idx)
                print(f"Checkpoint (Row {row_idx}): Success={success_count}, Errors={error_count}, "
                      f"Tokens/row={avg_tokens:.0f}, Cost=${cost_so_far:.2f}, Time={elapsed:.0f}m | "
                      f"AUTOSAVED to {os.path.basename(autosave_file)}")

    # === FINAL SAVE ===

    final_cost = calculate_cost(total_input_tokens, total_output_tokens)

    print("\n" + "="*80)
    print("SESSION COMPLETE")
    print("="*80)

    batch_num = batch_info['session']
    if last_done >= first_row:
        output_file = os.path.join(
            OUTPUT_BASE,
            f"batch_{batch_num:02d}_rows_{first_row}-{last_done}.xlsx"
        )
        output_df.iloc[first_row:last_done + 1].to_excel(output_file, index=False)
        # Persist the position only now that the rows are written to disk.
        session.update_row(last_done)
        rows_label = f"{first_row}-{last_done}"
        output_label = os.path.basename(output_file)
    else:
        rows_label = "none"
        output_label = "none (no rows processed)"

    attempted = max(success_count + error_count, 1)

    print(f"\nBatch {batch_num}/{total_batches}:")
    print(f"  Rows: {rows_label}")
    print(f"  Success: {success_count}")
    print(f"  Errors: {error_count}")
    print(f"  Skipped: {skipped_count}")
    print(f"  Avg tokens/row: {(total_input_tokens + total_output_tokens) / attempted:.0f}")
    print(f"  Session cost: ${final_cost:.2f}")
    print(f"  Total cost: ${session.state['total_cost'] + final_cost:.2f}")
    print(f"  Time: {(time.time() - session_start_time)/60:.1f}m")
    print(f"  Output: {output_label}")
    print(f"  Output columns: {len(output_df.columns)}")

    if last_done >= batch_end - 1:
        session.mark_batch_complete(total_input_tokens, total_output_tokens, final_cost)
    else:
        # Stopped early (time limit): keep the batch open so the next run
        # resumes inside it, but still account for what this run spent.
        session.state['total_tokens_used'] += total_input_tokens + total_output_tokens
        session.state['total_cost'] += final_cost
        session.save()

    rows_completed = session.state['last_row'] + 1

    print(f"\nProgress:")
    print(f"  Rows: {rows_completed}/{total_rows} ({(rows_completed/total_rows)*100:.1f}%)")
    print(f"  Sessions: {session.state['sessions_completed']}/{total_batches}")
    print(f"  Total cost: ${session.state['total_cost']:.2f}")

    if rows_completed < total_rows:
        print(f"\n⭐ NEXT SESSION: Resume from row {session.state['last_row'] + 1}")
        print(f"   Run this script again in a new cell")
    else:
        print(f"\n✅ ALL COMPLETE!")
        print(f"  Total rows coded: {rows_completed}")
        print(f"  Total API cost: ${session.state['total_cost']:.2f}")

# Start a coding session only when this file is the program entry point,
# not when it is imported as a module.
if __name__ == "__main__":
    main()