Section 1: Compare CLIP and YOLO; choose YOLO

In [None]:
pip install jupyter_contrib_nbextensions

In [None]:
!jupyter contrib nbextension install --user
!pip install pandas requests beautifulsoup4 python-dateutil openpyxl tqdm


In [None]:
import pandas as pd
import requests
from bs4 import BeautifulSoup
from dateutil.parser import parse
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import re

# Load input Excel file and collect the distinct, non-missing article URLs
# that the crawl below will fetch.
input_path = "ai_articles_only.xlsx"
df = pd.read_excel(input_path)
urls = df['article_url'].dropna().unique()

# Extract readable text from article
def extract_text(soup):
    """Return the page's readable text, one content element per line.

    The search is limited to the first ``<article>`` element when the page
    has one; otherwise the whole document is used. Text is taken from
    ``p``, ``h2``, ``h3``, ``li`` and ``blockquote`` tags, and fragments of
    20 characters or fewer (after stripping) are discarded.
    """
    container = soup.find("article") or soup
    fragments = []
    for node in container.find_all(["p", "h2", "h3", "li", "blockquote"]):
        content = node.get_text(strip=True)
        if len(content) > 20:
            fragments.append(content)
    return "\n".join(fragments).strip()

# Extract image URLs from <article>
def extract_article_images(soup):
    """Return the absolute image URLs found inside the first ``<article>``.

    Only ``<img>`` tags with a ``src`` attribute starting with ``http`` are
    kept. Pages without an ``<article>`` element yield an empty list.
    """
    article = soup.find("article")
    if not article:
        return []
    sources = (img.get("src") for img in article.find_all("img"))
    return [src for src in sources if src and src.startswith("http")]

# Clean illegal characters for Excel
# Control characters that cannot be written into an Excel cell. Tab (\x09),
# newline (\x0A) and carriage return (\x0D) are valid cell content and are
# kept, so text joined with "\n" keeps its paragraph breaks. DEL (\x7F) is
# still removed, as before.
_EXCEL_ILLEGAL_CHARS = re.compile(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]")


def clean_excel_string(s):
    """Remove characters Excel cannot store from a string value.

    Args:
        s: Any cell value. Only ``str`` values are modified.

    Returns:
        The string without illegal control characters, or ``s`` unchanged
        when it is not a string.
    """
    if isinstance(s, str):
        return _EXCEL_ILLEGAL_CHARS.sub("", s)
    return s

# Fetch one article
def fetch_article(url):
    """Download one article page and extract its metadata.

    Args:
        url: Address of the article page.

    Returns:
        A dict with ``article_url``, ``title``, ``author``, ``date``
        (``YYYY-MM`` when the ``<time datetime>`` value parses, the raw
        attribute value when it does not, ``""`` when absent), ``text``,
        ``num_images`` and the first ten ``image_urls``. When anything
        fails, the dict holds only ``article_url`` and ``error`` (the
        exception message), so one bad page never aborts the crawl.
    """
    try:
        res = requests.get(url, timeout=10)
        # Report 4xx/5xx replies as errors instead of scraping the error page.
        res.raise_for_status()
        # Pass the raw bytes so BeautifulSoup detects the encoding from the
        # document; ``res.text`` falls back to ISO-8859-1 when the server
        # sends no charset, which garbles non-ASCII text.
        soup = BeautifulSoup(res.content, "html.parser")

        title = soup.title.text.strip() if soup.title else ""

        author_tag = soup.find("meta", {"name": "author"})
        # A <meta name="author"> without a content attribute yields "" rather
        # than failing the whole article.
        author = (author_tag.get("content") or "").strip() if author_tag else ""

        date_tag = soup.find("time")
        raw_date = date_tag.get("datetime") if date_tag else None
        if raw_date:
            try:
                date = parse(raw_date).strftime("%Y-%m")
            except (ValueError, OverflowError):
                # Unparseable timestamp: keep the original attribute value.
                date = raw_date
        else:
            date = ""

        text = extract_text(soup)
        image_urls = extract_article_images(soup)

        return {
            "article_url": url,
            "title": title,
            "author": author,
            "date": date,
            "text": text,
            "num_images": len(image_urls),
            "image_urls": image_urls[:10],
        }

    except Exception as e:
        return {
            "article_url": url,
            "error": str(e),
        }

# Crawl all articles with threads; results are collected in completion order.
results = []
with ThreadPoolExecutor(max_workers=20) as executor:
    futures = {executor.submit(fetch_article, url): url for url in urls}
    for future in tqdm(as_completed(futures), total=len(futures), desc="Crawling articles"):
        results.append(future.result())

# Clean and save to Excel
result_df = pd.DataFrame(results)
# DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map; use whichever this pandas version provides.
elementwise = result_df.map if hasattr(result_df, "map") else result_df.applymap
result_df = elementwise(clean_excel_string)
result_df.to_excel("article_metadata.xlsx", index=False)
print("Done. Saved to article_metadata.xlsx")


In [None]:
# NOTE: the filtered metadata loaded below was produced after excluding
# related-story images and irrelevant pictures such as advertisements, so this
# cell should be kept. The first two (presumably the two cells above) do not
# need to be kept.
import ast

import pandas as pd

# === Step 1: Load metadata ===
input_path = "article_metadata_filtered.xlsx"  # Replace if needed
df = pd.read_excel(input_path)

# === Step 2: Filter out articles with no images ===
df = df[df['num_images'] > 0].reset_index(drop=True)

# === Step 3: Expand image_urls list ===
rows = []

for _, row in df.iterrows():
    article_url = row['article_url']
    article_date = row['date']

    raw_images = row['image_urls']
    if isinstance(raw_images, str):
        # The list was stored as its text form; parse it as a Python literal
        # only, so spreadsheet content can never execute code.
        try:
            image_list = ast.literal_eval(raw_images)
        except (ValueError, SyntaxError):
            image_list = []
    else:
        image_list = raw_images
    # Empty cells arrive as NaN; anything that is not a sequence of URLs
    # contributes no rows instead of crashing the loop.
    if not isinstance(image_list, (list, tuple)):
        image_list = []

    for i, img_url in enumerate(image_list):
        rows.append({
            "article_id": article_url,
            "image_id": f"img{i}",
            "image_path": img_url,
            "date": article_date,
            "has_person": "",
            "race_image": "",
            "gender_image": "",
            "role_image": "",
            "identity_label": "",
            "simulated_human": ""
        })

# === Step 4: Save to Excel ===
output_path = "image_annotation_extended_template.xlsx"
df_images = pd.DataFrame(rows)
df_images.to_excel(output_path, index=False)
print(f"Done! Saved to {output_path}")


In [None]:
#clip
import pandas as pd
import requests
import os
from tqdm import tqdm
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch
from concurrent.futures import ThreadPoolExecutor

# === Step 1: Read the image metadata workbook ===
input_path = "/Users/dengqiuyue/Downloads/final project/scraping/final_merged_metadata_with_labels.xlsx"
df = pd.read_excel(input_path)

# === Step 2: Output folder, empty result columns, request header and the two
# text prompts CLIP compares each image against
base_folder = "images"
os.makedirs(base_folder, exist_ok=True)
for result_column in ("local_path", "has_person_clip"):
    df[result_column] = ""
headers = {"User-Agent": "Mozilla/5.0"}
texts = ["a photo of a person", "a photo without any people"]

# === Step 3: Load the CLIP model and its processor from the same checkpoint
checkpoint_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(checkpoint_name)
processor = CLIPProcessor.from_pretrained(checkpoint_name)

# === Step 4: Define thread task
def process_image(row_idx, row):
    """Download one image (unless already saved) and ask CLIP if it shows a person.

    Args:
        row_idx: Index label of the row in ``df``; returned unchanged so the
            caller can write the outcome back.
        row: Row providing ``image_path`` (URL), ``article_id`` and
            ``image_id``.

    Returns:
        ``(row_idx, save_path, "yes" | "no")`` on success, or
        ``(row_idx, "download_error", "error")`` when anything fails. The
        error marker is also used when opening or classifying the image
        fails, not only for download failures.
    """
    try:
        url = str(row["image_path"]).strip()
        article_id = str(row["article_id"]).replace("/", "_")
        image_id = str(row["image_id"])
        article_folder = os.path.join(base_folder, article_id)
        os.makedirs(article_folder, exist_ok=True)

        filename = f"{image_id}.jpg"
        save_path = os.path.join(article_folder, filename)

        # Download
        if not os.path.exists(save_path):
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            with open(save_path, "wb") as f:
                f.write(response.content)

        # CLIP prediction. The context manager closes the image file once the
        # pixels have been read; otherwise every processed image leaks an
        # open file handle.
        with Image.open(save_path) as image:
            if image.mode in ["P", "LA", "RGBA"]:
                image = image.convert("RGB")
            inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1).squeeze()
        # probs[0] belongs to "a photo of a person", probs[1] to the
        # "without any people" prompt.
        prediction = "yes" if probs[0] > probs[1] else "no"

        return (row_idx, save_path, prediction)

    except Exception:
        return (row_idx, "download_error", "error")

# === Step 5: Submit every row to the thread pool, then record each outcome
# in submission order
with ThreadPoolExecutor(max_workers=20) as executor:
    futures = []
    for idx, row in df.iterrows():
        futures.append(executor.submit(process_image, idx, row))

    for future in tqdm(futures, desc="Downloading + CLIP Judging"):
        idx, path, result = future.result()
        df.at[idx, "local_path"] = path
        df.at[idx, "has_person_clip"] = result

# === Step 6: Write the annotated table to Excel
output_path = "image_annotation_with_clip_multithreaded.xlsx"
df.to_excel(output_path, index=False)
print(f"Done! Saved to {output_path}")


In [None]:
pip install ultralytics

In [None]:
import os
import pandas as pd
import requests
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

# === Step 1: Load Excel ===
# Read the image metadata and add an empty column that will receive the path
# of each saved image.
input_path = "/Users/dengqiuyue/Downloads/final project/scraping/final_merged_metadata_with_labels.xlsx"
df = pd.read_excel(input_path)
df["local_path"] = ""

# === Step 2: Prepare base folder ===
# Root folder for downloads (relative to the working directory); one
# sub-folder per article is created below it by download_image.
base_folder = "final project code/image1"
os.makedirs(base_folder, exist_ok=True)

# === Step 3: Clean article_url to valid folder path
def safe_folder_name(article_url):
    """Turn an article URL into a string usable as a single folder name.

    Every ``/`` and ``:`` becomes ``_`` (so ``https://`` turns into
    ``https___``) and surrounding whitespace is removed. Non-string input is
    converted with ``str`` first.
    """
    separators_to_underscores = str.maketrans({"/": "_", ":": "_"})
    return str(article_url).translate(separators_to_underscores).strip()

# === Step 4: Download function
def download_image(row):
    """Save the image referenced by ``row`` and return where it was stored.

    The file goes to ``<base_folder>/<sanitised article_id>/<image_id>.jpg``.
    An existing file is reused without contacting the server.

    Returns:
        The local file path, or the string ``"download_error"`` when any
        step fails (missing column, HTTP error, write failure, ...).
    """
    try:
        image_url = str(row["image_path"]).strip()
        target_dir = os.path.join(base_folder, safe_folder_name(row["article_id"]))
        file_stem = str(row["image_id"]).replace("/", "_")

        os.makedirs(target_dir, exist_ok=True)
        save_path = os.path.join(target_dir, f"{file_stem}.jpg")

        # Already downloaded: nothing to fetch.
        if os.path.exists(save_path):
            return save_path

        reply = requests.get(image_url, headers={"User-Agent": "Mozilla/5.0"}, timeout=10)
        reply.raise_for_status()
        with open(save_path, "wb") as out_file:
            out_file.write(reply.content)
        return save_path
    except Exception:
        return "download_error"

# === Step 5: Multithreaded execution ===
def process(index_row):
    """Worker for ``executor.map``: download the image of one ``(index, row)`` pair.

    Returns:
        ``(index, path)`` where ``path`` is whatever ``download_image``
        returned for the row.
    """
    row_index, row_data = index_row
    return row_index, download_image(row_data)

with ThreadPoolExecutor(max_workers=20) as executor:
    # executor.map yields outcomes in row order; tqdm shows progress as they
    # are consumed.
    outcomes = tqdm(executor.map(process, df.iterrows()), total=len(df), desc="Downloading images")
    results = list(outcomes)
    for idx, path in results:
        df.at[idx, "local_path"] = path

# === Step 6: Write the table, now including local paths, back to Excel ===
output_path = "image1_metadata_with_local_paths.xlsx"
df.to_excel(output_path, index=False)
print(f"Saved: {output_path}")


In [None]:
import time
import base64
import mimetypes
import pandas as pd
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import os
import logging
from datetime import datetime

# ---------------- LOGGING CONFIG ----------------
# Log to both a UTF-8 file and the console.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("image_analysis.log", encoding="utf-8"),
        logging.StreamHandler()
    ],
    # basicConfig does nothing when the root logger already has handlers
    # (for example after another cell configured logging in the same
    # kernel). force=True replaces them so this job logs to its own file.
    force=True,
)
logger = logging.getLogger(__name__)

# ---------------- CONFIG ----------------
# Placeholder credentials: set a real key and the API base URL before running.
# Requests are sent to f"{BASE_URL}/chat/completions".
API_KEY = "your_api_key"
BASE_URL = ""
MODEL = "gemini-2.5-pro"

# Working directory of the current Python process; the relative input and
# output files below are resolved against it.
CURRENT_DIR = os.getcwd()

INPUT_XLSX = os.path.join(CURRENT_DIR, "image_analysis_api.xlsx")
OUTPUT_XLSX = os.path.join(CURRENT_DIR, "image_analysis_results_prompt.xlsx")
# Absolute folder where downloaded images are stored. It is written directly
# because os.path.join drops every component before an absolute path, so
# joining it onto CURRENT_DIR had no effect.
IMAGE_DOWNLOAD_DIR = "/Users/dengqiuyue/Downloads/final project/images"

# Checkpointing and incremental saves
TEMP_OUTPUT_XLSX = os.path.join(CURRENT_DIR, "image_analysis_temp.xlsx")
CHECKPOINT_FILE = os.path.join(CURRENT_DIR, "checkpoint.txt")
SAVE_INTERVAL = 20  # save every N processed tasks

MAX_WORKERS = 8         # worker threads in the pool
REQUESTS_PER_MIN = 60   # rate-limiter budget
RETRIES = 3             # API attempts per image
BACKOFF = 2.0           # retry delay is BACKOFF ** (attempt - 1) seconds

# Prompt sent with every image (passed to call_ai_api by analyze_one_row).
# It asks for a free-text description (Part 1) followed by 19 fixed-category
# features (Part 2), with "99" marking unknown values. The text is part of
# the request payload, so any edit changes what the model is asked.
PROMPT_TEXT_ENGLISH = """
Definition: "Core figure" refers to the most visually prominent person in the image, determined by factors like size, central placement, lighting, or interaction focus.

Part 1: Describe this image in detail.

Part 2: Structured Feature List (19 fixed-category features)
List the following 19 features using the specified categories. For numerical/categorical features, use only the provided options (do not use free text). For "unknown" cases, mark with "99".

Definition: "Core figure" refers to the most visually prominent person in the image, determined by factors like size, central placement, lighting, or interaction focus.

1. Image Subject Type
Category: 1=Single person; 2=Multiple people; 3=No people (landscape/object)

2. Core Figure Count
Category: [Specific number, e.g., "3"; 99=Unknown)

3. Core Figure Visual Weight
Category: 1=Less than or equal to 1/3 of frame; 2=1/3 to 2/3 of frame (inclusive); 3=Greater than or equal to 2/3 of frame

4. Source Scene Type
Category: 1=News scene; 2=Lab/Office; 3=Public Space; 4=Private Space; 5=Abstract (illustration/diagram)

5. Image Medium
Category: 1=Photograph; 2=Illustration/Cartoon; 3=Diagram/Data Visualization; 4=Screenshot

6. Core Figure Age Group
Category: 1=Child (clear pre-pubescent features); 2=Adult (typical working-age appearance); 3=Elderly (visible gray hair/wrinkles/walking aid); 99=Uncertain (when face is obscured or unclear)

7. Core Figure Gender
Category: 1=Male; 2=Female; 3=Unclear. 

8. Core Figure Skin Tone
Category: 1=White; 2=Black; 3=Beige; 4=Mixed/Unclear 

9. Clothing Style
Category: 1=Formal (suit/uniform); 2=Casual; 3=Professional (lab coat/workwear); 4=Mixed
*Note: "Professional" refers to clothing specific to a tech/occupational role; "Formal" refers to general ceremonial/official attire

10. Accessories
Category: 1=No significant accessories; 2=Tech devices (headphones/smartwatch); 3=Identity markers (badge/ID); 4=Personalized (tattoos/statement jewelry)

11. Core Figure Action
Category: 1=Static (standing/sitting); 2=Dynamic (walking/operating); 3=Interactive (speaking/demonstrating); 4=Passive (being photographed/observed)

12. Core Figure Emotion
Category: 1=Positive (smiling/confident); 2=Neutral (calm/focused); 3=Negative (serious/anxious); 4=Unclear

13. Tech Device Presence
Category: 1=No tech; 2=General (phone/laptop); 3=Specialized (AI device/instrument); 4=Media equipment (camera/microphone)

14. Human-Tech Relationship
Category: 1=Using device; 2=Surrounded by devices; 3=Independent of device; 4=No device

15. Background Elements
Category: 1=Natural (plants/sky); 2=Architectural (walls/doors); 3=Symbolic (flag/slogan); 4=Cluttered/Unfocused

16. Lighting Intensity
Category: 1=Bright; 2=Dim; 3=High Contrast; 4=Soft

17. Dominant Color Tone
Category: 1=Cool (blue/green); 2=Warm (red/yellow); 3=Neutral (black/white/gray); 4=Mixed

18. Interpersonal Interaction Pattern
Category: 1=No interaction; 2=Collaborative (working together); 3=Interview (question/answer); 4=Bystander (observing core figure)

19. Core Figure Visual Subject
Category: 1=Person; 2=Object; 3=Other

Output Format:
- Start with "Part 1: [Caption]"
- Then "Part 2: [Numbered list of features 1-19, each on a new line, e.g., '1. Image Subject Type: 2']"
"""

# ---------------- RATE LIMITER ----------------
import threading


class TokenBucket:
    """Token-bucket rate limiter that is safe to share between threads.

    The bucket holds up to ``rate_per_min`` tokens (at least 1) and refills
    continuously at ``rate_per_min / 60`` tokens per second. Each
    ``consume`` call takes one token, sleeping first when none is available.
    """

    def __init__(self, rate_per_min):
        self.capacity = max(1, rate_per_min)
        self.tokens = self.capacity
        self.refill_time = time.time()
        self.rate_per_sec = rate_per_min / 60.0
        # Guards tokens/refill_time: the bucket is used concurrently by the
        # worker threads, and unsynchronised updates would lose tokens.
        self._lock = threading.Lock()

    def consume(self, row_idx=None):
        """Take one token, blocking until the rate limit allows it.

        Args:
            row_idx: Only used to label log messages.
        """
        with self._lock:
            now = time.time()
            elapsed = max(0.0, now - self.refill_time)
            self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_per_sec)
            self.refill_time = now

            available = self.tokens
            # Reserve the token immediately, letting the balance go negative,
            # so threads arriving meanwhile queue behind this one instead of
            # all waking up at the same moment.
            self.tokens -= 1
            remaining = self.tokens
            wait_time = (1 - available) / self.rate_per_sec if available < 1 else 0.0

        # Sleep outside the lock so other threads can reserve their slot.
        if wait_time > 0:
            logger.info(f"[Row {row_idx}] Rate limit hit, sleeping {wait_time:.2f}s (tokens: {available:.2f})")
            time.sleep(wait_time)

        logger.debug(f"[Row {row_idx}] Consumed a token, remaining: {remaining:.2f}")

# Single limiter instance; analyze_one_row calls bucket.consume() before each
# API attempt, so the REQUESTS_PER_MIN budget is shared by all tasks.
bucket = TokenBucket(REQUESTS_PER_MIN)

# ---------------- CHECKPOINT / RESUME ----------------
def load_checkpoint():
    """Return the set of row indices recorded in the checkpoint file.

    The file holds one integer per line; blank lines are ignored. A missing
    or unreadable file yields an empty set (read problems are logged as a
    warning).
    """
    if not os.path.exists(CHECKPOINT_FILE):
        return set()

    try:
        with open(CHECKPOINT_FILE, "r") as handle:
            stripped = (raw.strip() for raw in handle)
            completed_indices = {int(entry) for entry in stripped if entry}
        logger.info(f"Loaded {len(completed_indices)} completed tasks from checkpoint")
        return completed_indices
    except Exception as e:
        logger.warning(f"Failed to read checkpoint file: {e}")
    return set()

def save_checkpoint(completed_indices):
    """Overwrite the checkpoint file with the given row indices.

    Indices are written in ascending order, one per line. Failures are
    logged and swallowed; nothing is returned.
    """
    try:
        with open(CHECKPOINT_FILE, "w") as handle:
            handle.writelines(f"{idx}\n" for idx in sorted(completed_indices))
        logger.debug(f"Checkpoint saved: {len(completed_indices)} tasks")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}")

def load_existing_results():
    """Load the records saved in the temporary results file.

    Returns:
        A list of result dicts, or an empty list when the file is missing or
        cannot be read (the read problem is logged as a warning).
    """
    if not os.path.exists(TEMP_OUTPUT_XLSX):
        return []

    try:
        df = pd.read_excel(TEMP_OUTPUT_XLSX)
        # Success and error records are saved with different keys, so the
        # sheet has empty cells that read back as NaN. NaN is truthy, which
        # would make every successful row look like it carries an "error".
        # Dropping the missing entries restores the original record shape.
        results = [
            {key: value for key, value in record.items() if not pd.isna(value)}
            for record in df.to_dict("records")
        ]
        logger.info(f"Loaded {len(results)} existing results from temporary file")
        return results
    except Exception as e:
        logger.warning(f"Failed to read temporary results file: {e}")
    return []

def save_results_incremental(results, force=False):
    """Write all results collected so far to the temporary Excel file.

    Args:
        results: List of result dicts.
        force: Accepted but not used by this function.

    Returns:
        ``True`` when the file was written, ``False`` when saving failed
        (the failure is logged).
    """
    try:
        pd.DataFrame(results).to_excel(TEMP_OUTPUT_XLSX, index=False)
        logger.info(f"Temporary results saved: {len(results)} records")
        saved = True
    except Exception as e:
        logger.error(f"Failed to save temporary results: {e}")
        saved = False
    return saved

# ---------------- IMAGE FUNCTIONS ----------------
def download_image(image_url, save_path, row_idx=None):
    """Stream an image from ``image_url`` into ``save_path``.

    Args:
        image_url: Address of the image.
        save_path: Destination file; its parent directory is created.
        row_idx: Only used to label log messages.

    Returns:
        ``save_path`` on success, ``None`` on any failure (logged).
    """
    import threading

    # Write to a private temp file first and rename on success. Callers skip
    # the download when save_path exists, so a truncated file left by a
    # failed transfer must never appear under the final name.
    partial_path = f"{save_path}.{threading.get_ident()}.part"
    try:
        logger.info(f"[Row {row_idx}] Start downloading image: {image_url}")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        file_size = 0
        # The context manager releases the streamed connection in all cases.
        with requests.get(image_url, stream=True, timeout=30) as resp:
            resp.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in resp.iter_content(8192):
                    if chunk:
                        f.write(chunk)
                        file_size += len(chunk)
        os.replace(partial_path, save_path)

        logger.info(f"[Row {row_idx}] Image downloaded: {save_path} ({file_size:,} bytes)")
        return save_path
    except Exception as e:
        logger.error(f"[Row {row_idx}] Image download failed {image_url}: {e}")
        # Best-effort removal of the incomplete temp file.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return None

def _guess_mime_from_path(path):
    mime, _ = mimetypes.guess_type(path)
    return mime or "image/jpeg"

def _extract_response_content(response_data, row_idx=None):
    """Validate a chat-completions reply and return ``(content, choice)``.

    ``content`` is always a string: ``None`` becomes ``""`` and non-string
    values are converted with ``str``. Empty or whitespace-only content is
    logged as a warning but still returned.

    Raises:
        ValueError: When ``choices``, ``message`` or ``content`` is missing.
    """
    if "choices" not in response_data:
        logger.error(f"[Row {row_idx}] Missing 'choices' in response: {response_data}")
        raise ValueError("Bad API response: missing 'choices'")

    if len(response_data["choices"]) == 0:
        logger.error(f"[Row {row_idx}] Empty 'choices' array: {response_data}")
        raise ValueError("Bad API response: empty 'choices'")

    choice = response_data["choices"][0]
    if "message" not in choice:
        logger.error(f"[Row {row_idx}] Missing 'message' in choice: {choice}")
        raise ValueError("Bad API response: missing 'message'")

    if "content" not in choice["message"]:
        logger.error(f"[Row {row_idx}] Missing 'content' in message: {choice['message']}")
        raise ValueError("Bad API response: missing 'content'")

    response_content = choice["message"]["content"]

    # Detailed logging of content shape
    if response_content is None:
        logger.warning(f"[Row {row_idx}] API returned content=None (empty response cause)")
        logger.warning(f"[Row {row_idx}] Full choice: {choice}")
        response_content = ""
    elif response_content == "":
        logger.warning(f"[Row {row_idx}] API returned empty string for content")
        logger.warning(f"[Row {row_idx}] Full choice: {choice}")
        if "finish_reason" in choice:
            logger.warning(f"[Row {row_idx}] finish_reason: {choice['finish_reason']}")
        if "usage" in response_data:
            logger.warning(f"[Row {row_idx}] usage: {response_data['usage']}")
    elif not isinstance(response_content, str):
        logger.warning(f"[Row {row_idx}] content is not a string: {type(response_content)}, value: {response_content}")
        response_content = str(response_content)
    elif response_content.strip() == "":
        logger.warning(f"[Row {row_idx}] content contains only whitespace")
        logger.warning(f"[Row {row_idx}] repr(content): {repr(response_content)}")

    return response_content, choice


def call_ai_api(image_path_local, prompt_text, row_idx=None):
    """Send one local image plus the prompt to the chat-completions endpoint.

    The image is embedded as a base64 data URL. When the request to
    ``BASE_URL`` fails with a ``requests`` error and ``BASE_URL`` contains
    ``/v1``, the call is repeated once against the URL with ``/v1`` removed.

    Args:
        image_path_local: Path of the image file to send.
        prompt_text: Instruction text placed before the image.
        row_idx: Only used to label log messages.

    Returns:
        The text content of the first choice (possibly an empty string).

    Raises:
        requests.exceptions.RequestException: The request failed and no
            fallback URL applies.
        ValueError: The reply lacks the expected structure.
        Exception: Whatever the fallback attempt raised, re-raised unchanged.
    """
    logger.info(f"[Row {row_idx}] Calling AI API to analyze image: {image_path_local}")

    with open(image_path_local, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    image_mime = _guess_mime_from_path(image_path_local)
    data_url = f"data:{image_mime};base64,{b64}"

    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}

    # Gemini-style chat completions payload
    payload = {
        "model": MODEL,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": data_url}}
                ]
            }
        ],
        "temperature": 0.2,
        "max_tokens": 20000
    }

    logger.info(f"[Row {row_idx}] Sending API request to: {BASE_URL}/chat/completions")
    start_time = time.time()

    try:
        resp = requests.post(f"{BASE_URL}/chat/completions", headers=headers, json=payload, timeout=120)
        resp.raise_for_status()

        response_data = resp.json()
        logger.info(f"[Row {row_idx}] Raw API response (first 200 chars): {str(response_data)[:200]}...")

        response_content, choice = _extract_response_content(response_data, row_idx)

        # Record success metrics
        elapsed_time = time.time() - start_time
        usage = response_data.get("usage", {})

        finish_reason = choice.get("finish_reason", "unknown")
        if finish_reason == "length" and len(response_content.strip()) == 0:
            logger.warning(f"[Row {row_idx}] finish_reason='length' with empty response; possible token limit")

        logger.info(
            f"[Row {row_idx}] API call OK. Elapsed: {elapsed_time:.2f}s, "
            f"total_tokens: {usage.get('total_tokens', 'N/A')}, "
            f"response_len: {len(response_content)} chars, "
            f"finish_reason: {finish_reason}"
        )

        return response_content

    except requests.exceptions.RequestException as e:
        elapsed_time = time.time() - start_time
        logger.warning(f"[Row {row_idx}] Primary URL failed (elapsed {elapsed_time:.2f}s): {e}")

        # Fallback: try removing '/v1'
        if "/v1" in BASE_URL:
            fallback_url = BASE_URL.replace("/v1", "")
            logger.info(f"[Row {row_idx}] Trying fallback URL: {fallback_url}")
            start_time = time.time()

            try:
                resp = requests.post(f"{fallback_url}/chat/completions", headers=headers, json=payload, timeout=120)
                resp.raise_for_status()

                response_data = resp.json()
                # Same validation as the primary path, so a None or
                # non-string content cannot break the len() call below.
                response_content, _ = _extract_response_content(response_data, row_idx)

                elapsed_time = time.time() - start_time
                usage = response_data.get("usage", {})
                logger.info(
                    f"[Row {row_idx}] Fallback call OK. Elapsed: {elapsed_time:.2f}s, "
                    f"total_tokens: {usage.get('total_tokens', 'N/A')}, "
                    f"response_len: {len(response_content)} chars"
                )

                return response_content

            except Exception as fallback_e:
                elapsed_time = time.time() - start_time
                logger.error(f"[Row {row_idx}] Fallback URL failed (elapsed {elapsed_time:.2f}s): {fallback_e}")
                raise
        else:
            logger.error(f"[Row {row_idx}] API call failed, no fallback URL: {e}")
            raise
    except Exception as e:
        elapsed_time = time.time() - start_time
        logger.error(f"[Row {row_idx}] API call exception (elapsed {elapsed_time:.2f}s): {e}")
        raise

# ---------------- ROW PROCESSING ----------------
def analyze_one_row(row_idx, row):
    """Download (if needed) and analyze the image described by one input row.

    Args:
        row_idx: Position of the row in the filtered input; echoed back in
            the result so the caller can match it to its task.
        row: Mapping with ``image_path`` (image URL) and, optionally,
            ``article_id`` and ``image_id``.

    Returns:
        On success a dict with ``row_idx``, ``article_id``, ``image_id``,
        ``image_path`` and ``response`` (``"API_RESPONSE_EMPTY"`` when the
        API returned no text). On failure a dict with ``row_idx`` and
        ``error``: ``"image_url_missing"``, ``"download_failed"``, or the
        message of the last API exception after ``RETRIES`` attempts.
    """
    logger.info(f"[Row {row_idx}] ==================== START ====================")
    start_time = time.time()

    article_id = row.get("article_id")
    image_id = row.get("image_id")
    raw_url = row.get("image_path", "")
    # Empty spreadsheet cells arrive as NaN/None; str() would turn them into
    # the text "nan"/"None" and slip past the missing-URL check below.
    image_url = "" if pd.isna(raw_url) else str(raw_url).strip()

    logger.info(f"[Row {row_idx}] Task info: article_id={article_id}, image_id={image_id}")

    if not image_url:
        logger.error(f"[Row {row_idx}] Error: image_url is missing")
        return {"row_idx": row_idx, "error": "image_url_missing"}

    # Local file name: URL basename without query string, prefixed with the
    # image id, with ".jpg" appended when no known image extension is present.
    base_name = os.path.basename(image_url.split("?")[0]) or f"image_{row_idx}.jpg"
    if image_id:
        base_name = f"{image_id}_{base_name}"
    if not base_name.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
        base_name += ".jpg"
    local_path = os.path.join(IMAGE_DOWNLOAD_DIR, base_name)

    logger.info(f"[Row {row_idx}] Local save path: {local_path}")

    # Download if missing
    if not os.path.exists(local_path):
        logger.info(f"[Row {row_idx}] File not found locally; starting download")
        if not download_image(image_url, local_path, row_idx):
            logger.error(f"[Row {row_idx}] Task failed: image download failed")
            return {"row_idx": row_idx, "error": "download_failed"}
    else:
        logger.info(f"[Row {row_idx}] File already exists; skip download")

    # Retry API calls with backoff
    last_err, response_content = "", ""
    for attempt in range(1, RETRIES + 1):
        try:
            logger.info(f"[Row {row_idx}] API attempt {attempt}/{RETRIES}")
            bucket.consume(row_idx)
            response_content = call_ai_api(local_path, PROMPT_TEXT_ENGLISH, row_idx)
            logger.info(f"[Row {row_idx}] API call succeeded")
            break
        except Exception as e:
            last_err = str(e)
            logger.error(f"[Row {row_idx}] Attempt {attempt}/{RETRIES} failed: {e}")
            if attempt < RETRIES:
                wait_time = BACKOFF ** (attempt - 1)
                logger.info(f"[Row {row_idx}] Sleeping {wait_time:.2f}s before retry")
                time.sleep(wait_time)
    else:
        # Reached only when no attempt hit ``break``: every try failed.
        total_time = time.time() - start_time
        logger.error(f"[Row {row_idx}] Task failed after all retries (elapsed {total_time:.2f}s)")
        return {"row_idx": row_idx, "error": last_err or "api_failed"}

    total_time = time.time() - start_time
    logger.info(f"[Row {row_idx}] ==================== DONE ==================== (elapsed {total_time:.2f}s)")

    # Validate response content
    if not response_content:
        logger.warning(f"[Row {row_idx}] Warning: API returned empty content")
        response_content = "API_RESPONSE_EMPTY"

    result = {
        "row_idx": row_idx,
        "article_id": article_id,
        "image_id": image_id,
        "image_path": image_url,
        "response": response_content
    }

    logger.info(f"[Row {row_idx}] Returning result: response length={len(str(response_content))}")
    return result

# ---------------- MAIN ----------------
def main():
    """Run the image-analysis job with checkpointing and incremental saves.

    Steps: read ``INPUT_XLSX``, keep rows with a non-empty ``image_path``,
    skip row indices already listed in the checkpoint, analyze the rest in a
    thread pool, save progress every ``SAVE_INTERVAL`` tasks, then write
    ``OUTPUT_XLSX``, delete the temporary files and log summary statistics.

    NOTE(review): rows that end in an error are also recorded as completed,
    so a resumed run does not retry them; confirm this is intended.
    """
    logger.info("========================================")
    logger.info("Image analysis job started")
    logger.info(f"Config: workers={MAX_WORKERS}, rate_limit={REQUESTS_PER_MIN}/min, retries={RETRIES}")
    logger.info(f"Incremental save interval: every {SAVE_INTERVAL} tasks")
    logger.info("========================================")

    start_time = time.time()

    # Read input
    try:
        df = pd.read_excel(INPUT_XLSX)
        logger.info(f"Input file loaded: {INPUT_XLSX}, total rows: {len(df)}")
    except Exception as e:
        logger.error(f"Failed to read input file {INPUT_XLSX}: {e}")
        return

    # Filter valid rows
    original_count = len(df)
    df = df[df["image_path"].astype(str).str.len() > 0].reset_index(drop=True)
    valid_count = len(df)
    logger.info(f"Valid rows after filtering: {valid_count}/{original_count}")

    if valid_count == 0:
        logger.warning("No valid image rows to process")
        return

    # Load checkpoint and existing results
    completed_indices = load_checkpoint()
    existing_results = load_existing_results()

    # Determine remaining tasks
    remaining_indices = set(range(valid_count)) - completed_indices
    logger.info(f"Checkpoint: {len(completed_indices)} completed, {len(remaining_indices)} remaining")

    if not remaining_indices:
        logger.info("All tasks already completed; saving final results")
        if existing_results:
            pd.DataFrame(existing_results).to_excel(OUTPUT_XLSX, index=False)
            logger.info(f"Final results saved to: {OUTPUT_XLSX}")
        return

    # Initialize counters
    results = existing_results.copy()
    completed_count = len(completed_indices)
    success_count = sum(1 for r in existing_results if not r.get("error") and str(r.get("response", "")).strip())
    processed_this_run = 0

    # row_idx -> position in ``results``, so a re-processed row replaces its
    # earlier record in O(1) instead of scanning the list for every result.
    result_position = {}
    for pos, r in enumerate(results):
        if r.get("row_idx") is not None:
            result_position.setdefault(r.get("row_idx"), pos)

    logger.info(f"Continuing: {len(existing_results)} existing results, {success_count} successful")

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        logger.info(f"Starting thread pool with max_workers={MAX_WORKERS}")

        futures = {}
        # Sorted for a deterministic submission order.
        for i in sorted(remaining_indices):
            row = df.iloc[i].to_dict()
            future = pool.submit(analyze_one_row, i, row)
            futures[future] = i

        logger.info(f"Submitted {len(futures)} tasks")

        for fut in tqdm(as_completed(futures), total=len(futures), desc="Processing Images"):
            try:
                result = fut.result()
            except Exception as e:
                # An unexpected worker exception must not abort the whole
                # job and lose unsaved progress; record it as a failed row.
                failed_idx = futures[fut]
                logger.error(f"[Row {failed_idx}] Unhandled exception in worker: {e}")
                result = {"row_idx": failed_idx, "error": str(e)}
            row_idx = result["row_idx"]

            # Update results list
            if row_idx in result_position:
                results[result_position[row_idx]] = result
            else:
                result_position[row_idx] = len(results)
                results.append(result)

            # Update counters
            completed_count += 1
            processed_this_run += 1
            completed_indices.add(row_idx)

            # Inspect result
            has_error = "error" in result and result.get("error")
            has_response = "response" in result and result.get("response") and str(result.get("response")).strip()

            if has_error:
                logger.error(f"Task {result['row_idx']} FAILED: {result['error']} ({completed_count}/{valid_count})")
            elif has_response:
                success_count += 1
                response_preview = str(result.get("response", ""))[:50].replace("\n", " ")
                logger.info(f"Task {result['row_idx']} SUCCESS ({completed_count}/{valid_count}) - Preview: {response_preview}...")
            else:
                logger.warning(f"Task {result['row_idx']} WARNING: neither error nor response present ({completed_count}/{valid_count})")
                logger.warning(f"[Row {result['row_idx']}] Full result: {result}")

            # Periodic save: every SAVE_INTERVAL tasks and once the last
            # task submitted in this run has finished.
            should_save = (
                (completed_count % SAVE_INTERVAL == 0) or
                (completed_count == valid_count) or
                (processed_this_run == len(futures))
            )

            if should_save:
                # Save checkpoint
                save_checkpoint(completed_indices)

                # Save temporary results
                if save_results_incremental(results):
                    elapsed = time.time() - start_time
                    success_rate = (success_count / completed_count) * 100
                    empty_count = sum(1 for r in results if not r.get("error") and not str(r.get("response", "")).strip())
                    logger.info(f"Progress: {completed_count}/{valid_count} done, "
                                f"success_rate: {success_rate:.1f}%, empty_responses: {empty_count}, elapsed: {elapsed:.1f}s")
                    logger.info(f"Progress saved ({len(results)} records)")

    # Save final results
    try:
        results.sort(key=lambda x: x.get("row_idx", 0))
        pd.DataFrame(results).to_excel(OUTPUT_XLSX, index=False)
        logger.info(f"Final results saved to: {OUTPUT_XLSX}")

        # Cleanup temporary files; a failure here is only logged.
        try:
            if os.path.exists(TEMP_OUTPUT_XLSX):
                os.remove(TEMP_OUTPUT_XLSX)
                logger.info("Temporary results file removed")
            if os.path.exists(CHECKPOINT_FILE):
                os.remove(CHECKPOINT_FILE)
                logger.info("Checkpoint file removed")
        except Exception as e:
            logger.warning(f"Error cleaning up temporary files: {e}")

    except Exception as e:
        logger.error(f"Failed to save final results: {e}")
        logger.info(f"You can recover from the temporary file: {TEMP_OUTPUT_XLSX}")
        return

    # Final stats
    total_time = time.time() - start_time
    final_success_rate = (success_count / valid_count) * 100

    error_count = sum(1 for r in results if r.get("error"))
    empty_response_count = sum(1 for r in results if not r.get("error") and not str(r.get("response", "")).strip())
    valid_response_count = sum(1 for r in results if not r.get("error") and str(r.get("response", "")).strip())

    logger.info("========================================")
    logger.info("Final statistics:")
    logger.info(f"  Total tasks: {valid_count}")
    logger.info(f"  Valid responses: {valid_response_count}")
    logger.info(f"  Empty responses: {empty_response_count}")
    logger.info(f"  Error tasks: {error_count}")
    logger.info(f"  Success rate: {final_success_rate:.2f}%")
    logger.info(f"  Total time: {total_time:.1f}s")
    logger.info(f"  Avg per task: {total_time / valid_count:.2f}s")

    if empty_response_count > 0:
        logger.warning(f"Found {empty_response_count} empty responses; please check logs to diagnose")
        # Show a few examples
        empty_tasks = [r for r in results if not r.get("error") and not str(r.get("response", "")).strip()][:3]
        for task in empty_tasks:
            logger.warning(f"Empty response example: Row {task['row_idx']}, result: {task}")

    logger.info("========================================")

# Start the job only when this code runs as the main program (``__name__`` is
# "__main__"), not when it is imported from elsewhere.
if __name__ == "__main__":
    main()


In [None]:
import os
import io
import time
import json
import base64
import logging
import mimetypes
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import requests
from tqdm import tqdm

# ================= LOGGING CONFIG =================
# Route every record at INFO level or above to both a UTF-8 log file and a
# console stream handler, using one shared timestamped format.
# NOTE(review): logging.basicConfig() does nothing when the root logger already
# has handlers (e.g. configured by an earlier notebook cell) — confirm that
# "image_role3_classifier.log" is actually being written in that situation.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("image_role3_classifier.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
# Module-level logger used by every function in this cell.
logger = logging.getLogger(__name__)

# ================= CONFIG =================
# API access. API_KEY is sent as a Bearer token and "/chat/completions" is
# appended to BASE_URL when the request is built.
API_KEY = ""  # REQUIRED: replace with your key
BASE_URL = ""                    # base URL of the OpenAI-compatible API
MODEL = "gemini-2.5-pro"                                       # or your available model name

# abspath() resolves the relative name against the current working directory,
# so dirname() of it yields the working directory itself.
CURRENT_DIR = os.path.dirname(os.path.abspath("image_role3_classifier"))

# Input & output
INPUT_XLSX  = os.path.join(CURRENT_DIR, "cleaned_image_data.xlsx")   # input (must include image_path; optional Caption, etc.)
OUTPUT_XLSX = os.path.join(CURRENT_DIR, "role_results.xlsx")         # final output (keep all input cols + new result cols)
TEMP_OUTPUT_XLSX = os.path.join(CURRENT_DIR, "role_temp.xlsx")       # incremental save
CHECKPOINT_FILE  = os.path.join(CURRENT_DIR, "role_checkpoint.txt")  # checkpoint (one completed row index per line)

# Directory containing the already-downloaded images.
# NOTE(review): absolute, machine-specific path — must be adjusted on any other machine.
LOCAL_IMAGE_DIR = "/Users/dengqiuyue/Downloads/final project/images"

# Concurrency and stability controls
MAX_WORKERS = 8          # size of the thread pool that processes rows
REQUESTS_PER_MIN = 60    # budget handed to the TokenBucket rate limiter
RETRIES = 3              # maximum attempts per API call
BACKOFF = 2.0            # sleep between attempts is BACKOFF ** (attempt - 1) seconds
SAVE_INTERVAL = 20       # save temp results + checkpoint every N completed rows

# ================= PROMPT (3-class; image first, caption secondary) =================
# Instruction text sent as the first text part of every request. It asks for a
# strict-JSON answer with the keys "role" (1, 2 or 99) and "role_label"; the
# response parser in analyze_one_row relies on exactly those keys and codes.
PROMPT_TEXT = """
You are given a news image to analyze. Your task is to classify the main human subject in the image into ONE of only three categories below.
If no person is clearly visible, or if the visual identity is ambiguous/abstract/only objects/illustrations are shown, choose 99.

Target label set (choose exactly one):
1. Elite (e.g., scientist/researcher/engineer; CEO/entrepreneur/business leader; politician/government official)
2. General Public (non-elite everyday people, users, workers, citizens, students, audiences, crowds)
99. Unclear/Abstract (no person clearly visible, or identity cannot be determined, or purely abstract/object/robot/diagram)

Important instructions:
- Base your judgment primarily on the IMAGE CONTENT. If a caption is provided, use it only as a secondary hint.
- If it's not strongly evident that the person is an elite (scientist/CEO/politician), do NOT guess "Elite".
- If multiple people are present, focus on the most visually prominent or socially central figure.
- If the image does not show a person (e.g., robots, diagrams, product shots), or it's too ambiguous, use 99.

Output format (strict JSON only):
{
  "role": [1 or 2 or 99],
  "role_label": "[Elite | General Public | Unclear/Abstract]"
}
"""

# ================= RATE LIMITER =================
class TokenBucket:
    """Thread-safe token-bucket rate limiter.

    The bucket starts full with ``capacity`` tokens and refills continuously at
    ``rate_per_min / 60`` tokens per second, never exceeding ``capacity``.
    Each call to :meth:`consume` takes one token, sleeping first when the
    bucket is empty.

    Args:
        rate_per_min: Allowed number of calls per minute; must be positive.

    Raises:
        ValueError: If ``rate_per_min`` is zero or negative (the refill rate
            would make the wait time undefined or negative).
    """

    def __init__(self, rate_per_min):
        if rate_per_min <= 0:
            raise ValueError(f"rate_per_min must be positive, got {rate_per_min!r}")
        self.capacity = max(1, rate_per_min)
        self.tokens = self.capacity
        # Monotonic clock: elapsed-time arithmetic must not be affected by
        # wall-clock adjustments.
        self.refill_time = time.monotonic()
        self.rate_per_sec = rate_per_min / 60.0
        # The bucket is shared by all worker threads, so every read-modify-write
        # of tokens/refill_time has to happen under this lock.
        self._lock = threading.Lock()

    def consume(self, row_idx=None):
        """Take one token, blocking until the rate limit allows it.

        Args:
            row_idx: Optional row identifier, used only in the log message.
        """
        with self._lock:
            now = time.monotonic()
            elapsed = now - self.refill_time
            self.tokens = min(self.capacity, self.tokens + elapsed * self.rate_per_sec)
            self.refill_time = now
            available = self.tokens
            # Reserve the token immediately. A negative balance is the "debt"
            # this caller has to wait out; later callers see the debt and wait
            # correspondingly longer, so the overall rate is respected.
            self.tokens -= 1
            wait_time = -self.tokens / self.rate_per_sec if self.tokens < 0 else 0.0

        # Sleep outside the lock so other threads can still reserve their slot.
        if wait_time > 0:
            logger.info(f"[Row {row_idx}] Rate limit triggered, waiting {wait_time:.2f}s (tokens={available:.2f})")
            time.sleep(wait_time)

# Single limiter instance used by analyze_one_row, which runs in the worker
# threads of the pool created in main().
bucket = TokenBucket(REQUESTS_PER_MIN)

# ================= CHECKPOINT & TEMP =================
def load_checkpoint():
    """Read the set of completed row indices from CHECKPOINT_FILE.

    The file holds one integer per line; blank lines are ignored.

    Returns:
        set[int]: The recorded indices, or an empty set when the file does not
        exist or cannot be read/parsed (a warning is logged in that case).
    """
    if not os.path.exists(CHECKPOINT_FILE):
        return set()
    try:
        with open(CHECKPOINT_FILE, "r") as handle:
            return {int(entry) for entry in map(str.strip, handle) if entry}
    except Exception as exc:
        logger.warning(f"Failed to read checkpoint: {exc}")
        return set()

def save_checkpoint(done_set):
    """Overwrite CHECKPOINT_FILE with the given indices, sorted, one per line.

    Args:
        done_set: Iterable of completed row indices.

    Failures are logged as errors and not re-raised.
    """
    try:
        with open(CHECKPOINT_FILE, "w") as handle:
            handle.writelines(f"{row}\n" for row in sorted(done_set))
    except Exception as exc:
        logger.error(f"Failed to save checkpoint: {exc}")

def load_existing_results():
    """Load previously saved partial results from TEMP_OUTPUT_XLSX.

    Returns:
        list[dict]: One dict per saved row, or an empty list when the file is
        missing or unreadable (a warning is logged in the latter case).
    """
    if not os.path.exists(TEMP_OUTPUT_XLSX):
        return []
    try:
        saved = pd.read_excel(TEMP_OUTPUT_XLSX)
        return saved.to_dict("records")
    except Exception as exc:
        logger.warning(f"Failed to read temporary results: {exc}")
        return []

def save_results_incremental(results):
    """Write all results collected so far to TEMP_OUTPUT_XLSX (no index column).

    Args:
        results: List of per-row result dicts.
    """
    snapshot = pd.DataFrame(results)
    snapshot.to_excel(TEMP_OUTPUT_XLSX, index=False)
    logger.info(f"Temporary results saved: {len(results)} records")

# ================= UTIL: compose local image path =================
def guess_local_image_path(row_idx, row_dict):
    """Locate the downloaded image file that belongs to a spreadsheet row.

    Lookup order (first existing file wins):
      1. An explicit path in 'local_path', 'image_local' or 'local_image_path'.
      2. LOCAL_IMAGE_DIR joined with the 'image_filename' value.
      3. LOCAL_IMAGE_DIR joined with names derived from the 'image_path' URL:
         its basename (query string removed), the basename plus ".jpg" when it
         has no known image extension, and the same two names prefixed with
         "<image_id>_" when 'image_id' is non-empty.

    Args:
        row_idx: Row index; only used to build a fallback file name when the
            URL has no basename.
        row_dict: Mapping of column name to cell value for the row.

    Returns:
        The path of the first candidate that exists, or None if none does.
    """
    known_exts = (".jpg", ".jpeg", ".png", ".webp")

    def cell(key):
        # Every cell is read as a stripped string; a missing key becomes "".
        return str(row_dict.get(key, "")).strip()

    # 1) Explicit local path columns
    for column in ("local_path", "image_local", "local_image_path"):
        explicit = cell(column)
        if explicit and os.path.exists(explicit):
            return explicit

    # 2) File name inside the local image directory
    named = cell("image_filename")
    if named:
        in_dir = os.path.join(LOCAL_IMAGE_DIR, named)
        if os.path.exists(in_dir):
            return in_dir

    # 3) Names derived from the image URL
    url = cell("image_path")
    if not url:
        return None

    stem = os.path.basename(url.split("?")[0]) or f"image_{row_idx}.jpg"
    needs_ext = not stem.lower().endswith(known_exts)
    image_id = cell("image_id")

    candidates = [stem]
    if needs_ext:
        candidates.append(stem + ".jpg")
    if image_id:
        candidates.append(f"{image_id}_{stem}")
        if needs_ext:
            candidates.append(f"{image_id}_{stem}.jpg")

    for name in candidates:
        full_path = os.path.join(LOCAL_IMAGE_DIR, name)
        if os.path.exists(full_path):
            return full_path

    return None

def encode_image_as_data_url(local_path):
    """Return the file at ``local_path`` as a base64 ``data:`` URL.

    The MIME type is guessed from the file name; "image/jpeg" is used when no
    type can be guessed.

    Args:
        local_path: Path of the image file to read.

    Returns:
        str: ``data:<mime>;base64,<payload>``.
    """
    guessed_type = mimetypes.guess_type(local_path)[0]
    if not guessed_type:
        guessed_type = "image/jpeg"
    with open(local_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read())
    return "data:" + guessed_type + ";base64," + encoded.decode("utf-8")

# ================= API CALL (image + optional caption) =================
def call_llm_for_image_role3(local_image_path: str, caption_text: str | None, row_idx=None) -> str | None:
    """Ask the model to classify one image and return its raw text reply.

    Builds a single user message made of the prompt text, an optional caption
    hint, and the image as a base64 data URL, then POSTs it to
    ``<BASE_URL>/chat/completions`` in the OpenAI-compatible chat format.

    Args:
        local_image_path: Path of the image file to send.
        caption_text: Optional caption; omitted from the message when falsy.
        row_idx: Row identifier, used only in log messages.

    Returns:
        The ``choices[0].message.content`` value of the first successful
        response, or None when all RETRIES attempts fail.

    Note:
        The image is read and encoded before the retry loop, so an unreadable
        file raises to the caller instead of being retried.
    """
    parts = [{"type": "text", "text": PROMPT_TEXT}]
    if caption_text:
        parts.append({"type": "text", "text": f"(Optional hint) Caption: {caption_text}"})
    parts.append({
        "type": "image_url",
        "image_url": {"url": encode_image_as_data_url(local_image_path)},
    })

    auth_headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}
    request_body = {
        "model": MODEL,
        "messages": [{"role": "user", "content": parts}],
        "temperature": 0.2,
        "max_tokens": 40000
    }
    url = BASE_URL.rstrip("/") + "/chat/completions"

    attempt = 0
    while attempt < RETRIES:
        attempt += 1
        try:
            reply = requests.post(url, headers=auth_headers, json=request_body, timeout=90)
            reply.raise_for_status()
            return reply.json()["choices"][0]["message"]["content"]
        except Exception as exc:
            logger.warning(f"[Row {row_idx}] Attempt {attempt}/{RETRIES} failed: {exc}")
            if attempt < RETRIES:
                # Exponential back-off: 1s, BACKOFF s, BACKOFF**2 s, ...
                time.sleep(BACKOFF ** (attempt - 1))

    return None

# ================= ROW PROCESSING =================
def _clean_caption(raw):
    """Return the caption as a stripped string, or None when it is missing, NaN or blank."""
    if raw is None:
        return None
    try:
        # Empty Excel cells arrive from pandas as NaN; str(NaN) would be "nan".
        if pd.isna(raw):
            return None
    except (TypeError, ValueError):
        # Non-scalar cell content: fall through and stringify it.
        pass
    text = str(raw).strip()
    return text or None


def _parse_role_response(response):
    """Extract (role, role_label) from the model reply; raise if it is not a valid answer.

    Accepts bare JSON as well as JSON surrounded by extra text such as
    Markdown code fences, and tolerates the role being given as "1" or 1.0.
    """
    text = response.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span (handles ```json fences / prose).
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end < start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got: {type(parsed).__name__}")

    role_val = parsed.get("role")
    if isinstance(role_val, str) and role_val.strip().isdigit():
        role_val = int(role_val.strip())
    elif isinstance(role_val, float) and role_val.is_integer():
        role_val = int(role_val)

    # bool is a subclass of int (True == 1), so exclude it explicitly.
    if isinstance(role_val, bool) or role_val not in (1, 2, 99):
        raise ValueError(f"role must be 1/2/99, got: {role_val}")
    return role_val, parsed.get("role_label")


def analyze_one_row(row_idx: int, row: dict) -> dict:
    """Classify the image of one input row.

    Args:
        row_idx: Index of the row in the input table.
        row: Column name -> cell value mapping for that row.

    Returns:
        A dict with all original columns plus ``row_idx``,
        ``local_image_path``, ``response`` (raw model text), ``role3``
        (1, 2 or 99), ``role3_label`` and ``error``. ``error`` is None on
        success, otherwise "image_not_found", "api_failed" or
        "json_parse_error: <details>".
    """
    logger.info(f"[Row {row_idx}] ===== START =====")

    # Try to locate the local image file for this row
    local_path = guess_local_image_path(row_idx, row)
    # Optional caption, used only as a secondary hint for the model
    caption = _clean_caption(row.get("Caption"))

    # Start from the original row and add the result columns
    result = dict(row)
    result.update({
        "row_idx": row_idx,
        "local_image_path": local_path if local_path else "",
        "response": "",
        "role3": None,
        "role3_label": None,
        "error": None
    })

    if not local_path or not os.path.exists(local_path):
        logger.error(f"[Row {row_idx}] Local image not found")
        result["error"] = "image_not_found"
        return result

    # Rate limit (only rows that actually hit the API consume a token)
    bucket.consume(row_idx)

    response = call_llm_for_image_role3(local_path, caption, row_idx)

    if response:
        # Keep the raw reply in every case so failures can be inspected later.
        result["response"] = response
        try:
            role_val, label = _parse_role_response(response)
            result["role3"] = role_val
            result["role3_label"] = label
        except Exception as e:
            logger.warning(f"[Row {row_idx}] JSON parse failed: {e} | raw: {str(response)[:160]}...")
            result["error"] = f"json_parse_error: {e}"
    else:
        result["error"] = "api_failed"

    logger.info(f"[Row {row_idx}] DONE: role3={result['role3']}, role3_label={result['role3_label']}")
    return result

# ================= MAIN =================
def main():
    """Run the image classification job with checkpoint/resume support.

    Steps:
      1. Read INPUT_XLSX.
      2. Load the checkpoint (completed row indices) and any temporary results.
      3. Classify the remaining rows in a thread pool.
      4. Every SAVE_INTERVAL completed rows (and at the end) save the temporary
         results and then the checkpoint.
      5. Write all results, sorted by ``row_idx``, to OUTPUT_XLSX.

    Returns early (after logging) when the input cannot be read.
    """
    logger.info("========================================")
    logger.info("Local image 3-class classification job started (Elite / General Public / Unclear)")
    logger.info(f"Concurrency={MAX_WORKERS}, Rate limit={REQUESTS_PER_MIN}/min, Retries={RETRIES}, Save interval={SAVE_INTERVAL}")
    logger.info("========================================")

    # Read input (keep all columns)
    try:
        df = pd.read_excel(INPUT_XLSX)
        logger.info(f"Loaded input: {INPUT_XLSX}, total rows: {len(df)}")
    except Exception as e:
        logger.error(f"Failed to read input: {e}")
        return

    total = len(df)

    # Resume from checkpoint
    done_set = load_checkpoint()
    results = load_existing_results()

    # row_idx -> position in `results`, so a re-processed row replaces its old
    # entry in O(1). setdefault keeps the first occurrence of a duplicate.
    position = {}
    for pos, saved in enumerate(results):
        if "row_idx" in saved:
            position.setdefault(saved["row_idx"], pos)

    remaining = sorted(set(range(total)) - done_set)
    logger.info(f"Checkpoint resume: completed {len(done_set)}, remaining {len(remaining)}")

    if not remaining:
        logger.info("No remaining tasks; saving final results")
        final_df = pd.DataFrame(results)
        # With no saved results there is no row_idx column to sort by.
        if "row_idx" in final_df.columns:
            final_df = final_df.sort_values("row_idx")
        final_df.to_excel(OUTPUT_XLSX, index=False)
        logger.info(f"Final results saved to: {OUTPUT_XLSX}")
        return

    # Count against the current input only, so stale checkpoint entries that
    # are outside range(total) cannot distort the progress numbers.
    completed_count = total - len(remaining)
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        futures = {pool.submit(analyze_one_row, i, df.iloc[i].to_dict()): i for i in remaining}
        for fut in tqdm(as_completed(futures), total=len(futures), desc="Classifying Images"):
            ridx = futures[fut]
            try:
                res = fut.result()
            except Exception as e:
                # One failing row (e.g. an unreadable image file) must not
                # abort the whole job; record it as an error row instead.
                logger.error(f"[Row {ridx}] Unhandled error: {e}")
                res = df.iloc[ridx].to_dict()
                res.update({
                    "row_idx": ridx,
                    "local_image_path": "",
                    "response": "",
                    "role3": None,
                    "role3_label": None,
                    "error": f"unhandled_exception: {e}"
                })

            if ridx in position:
                results[position[ridx]] = res
            else:
                position[ridx] = len(results)
                results.append(res)

            done_set.add(ridx)
            completed_count += 1

            # Incremental save & checkpoint. Results are written first: if that
            # write fails, the checkpoint must not claim rows whose results
            # were never persisted.
            if (completed_count % SAVE_INTERVAL == 0) or (completed_count == total):
                save_results_incremental(results)
                save_checkpoint(done_set)
                logger.info(f"Progress: {completed_count}/{total}")

    # Final save
    results.sort(key=lambda x: x.get("row_idx", 0))
    pd.DataFrame(results).to_excel(OUTPUT_XLSX, index=False)
    logger.info(f"Final results saved to: {OUTPUT_XLSX}")

# Start the classification job when this cell/script is executed directly.
if __name__ == "__main__":
    main()


In [None]:
# Image-only analysis for final_images_data1.xlsx (no age features)
# - Uses pandas + numpy + matplotlib + scipy; statsmodels is optional (falls back to scikit-learn)
# - Pure matplotlib (no seaborn), one chart per figure, no explicit colors/styles
# - Outputs figures & CSVs to ./outputs_image_analysis (see OUT_DIR)
# - Produces:
#   1) Overall distributions (Gender / Skin / Role)
#   2) Cross-tabs: Gender×Role, Skin×Role + chi-square + standardized residuals
#   3) Heatmaps (row % and standardized residuals)
#   4) Optional logistic regression: Elite (1) vs General Public (0)

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import chi2_contingency

# ------------------------- Paths & Setup -------------------------
IN_XLSX = "final_images_data1.xlsx"   # change if needed
OUT_DIR = "outputs_image_analysis"    # relative to the current working directory
os.makedirs(OUT_DIR, exist_ok=True)

# --------------------------- Load data ---------------------------
# `df` is a module-level table that the helper functions below read directly.
df = pd.read_excel(IN_XLSX, sheet_name="Sheet1")

# Map categorical codes to labels.
# Series.map() turns any code that is not a key of the mapping into NaN.
# NOTE(review): gender_map has no entry for 99 although skin_map and role_map
# do — confirm that the gender column never uses code 99.
gender_map = {1: "Male", 2: "Female", 3: "Unclear"}
skin_map   = {1: "White", 2: "Black", 3: "Beige", 4: "Mixed/Unclear", 99: "Unknown"}
role_map   = {1: "Elite", 2: "General Public", 99: "Unclear/Abstract"}

# Labelled copies of the coded columns; the rest of the analysis uses these.
df["Gender"] = df["Core Figure Gender"].map(gender_map)
df["Skin"]   = df["Core Figure skin"].map(skin_map)
df["Role"]   = df["role"].map(role_map)

# -------------------- Helper: save bar chart ---------------------
def save_bar_counts(series, title, fname, rotation=0):
    """Save a bar chart of category counts as a PNG in OUT_DIR.

    Categories are ordered by label; missing values are not counted
    (``value_counts`` default).

    Args:
        series: Categorical pandas Series to count.
        title: Chart title.
        fname: Output file name inside OUT_DIR.
        rotation: Rotation of the x tick labels in degrees.

    Returns:
        str: Path of the written image.
    """
    tally = series.value_counts().sort_index()
    labels = tally.index.astype(str)

    figure = plt.figure(figsize=(7, 4))
    axes = figure.add_subplot(111)
    axes.bar(labels, tally.values)  # default colors
    axes.set_title(title)
    axes.set_xlabel("Category")
    axes.set_ylabel("Count")
    plt.xticks(rotation=rotation)
    figure.tight_layout()

    target = os.path.join(OUT_DIR, fname)
    figure.savefig(target, dpi=220)
    plt.close(figure)
    return target

# Save counts & shares to CSV
def counts_and_perc(series, name):
    """Write a CSV with the count and share of every category (NaN included).

    Args:
        series: Categorical pandas Series.
        name: Variable name; used for the column names ("<name>_count",
            "<name>_share") and, lower-cased, in the file name.

    Returns:
        str: Path of the written CSV inside OUT_DIR.
    """
    tally = series.value_counts(dropna=False)
    shares = (tally / tally.sum()).round(4)
    table = pd.DataFrame({f"{name}_count": tally, f"{name}_share": shares})

    csv_name = f"overall_{name.lower()}_counts_shares.csv"
    destination = os.path.join(OUT_DIR, csv_name)
    table.to_csv(destination)
    return destination

# ------------------ 1) Overall distributions --------------------
# `paths` collects label -> output file path for the index file written at the
# end of this cell.
paths = {}
# Count/share tables; these include NaN as its own category (dropna=False).
paths["gender_csv"] = counts_and_perc(df["Gender"], "Gender")
paths["skin_csv"]   = counts_and_perc(df["Skin"],   "Skin")
paths["role_csv"]   = counts_and_perc(df["Role"],   "Role")

# Bar charts; unlike the CSVs above, these omit NaN (value_counts default).
paths["gender_bar"] = save_bar_counts(df["Gender"], "Overall Distribution: Gender", "overall_gender_bar.png")
paths["skin_bar"]   = save_bar_counts(df["Skin"],   "Overall Distribution: Skin",   "overall_skin_bar.png", rotation=20)
paths["role_bar"]   = save_bar_counts(df["Role"],   "Overall Distribution: Role",   "overall_role_bar.png", rotation=15)

# --------------- 2) Cross-tabs & chi-square tests ---------------
def chi2_and_exports(row_var, col_var, prefix):
    """Cross-tabulate two columns of ``df``, run a chi-square test, export everything.

    Writes four CSVs (observed counts, expected counts, residuals, row shares)
    and a text report to OUT_DIR, all named ``<prefix>_...``. The residuals are
    computed as ``(observed - expected) / sqrt(expected)``.

    Args:
        row_var: Column of ``df`` used for the crosstab rows.
        col_var: Column of ``df`` used for the crosstab columns.
        prefix: File name prefix for all outputs.

    Returns:
        dict: File paths under "ct", "exp", "resid", "pct", "report", plus the
        DataFrames "row_pct_df" (row shares) and "resid_df" (residuals).
    """
    def out_path(suffix):
        return os.path.join(OUT_DIR, f"{prefix}_{suffix}")

    observed = pd.crosstab(df[row_var], df[col_var])  # counts
    chi2, p, dof, expected_values = chi2_contingency(observed)
    expected = pd.DataFrame(expected_values, index=observed.index, columns=observed.columns)
    residuals = (observed - expected) / np.sqrt(expected)
    row_share = observed.div(observed.sum(axis=1), axis=0)

    # Save tables
    exports = {
        "ct": (observed, out_path("crosstab_counts.csv")),
        "exp": (expected, out_path("expected_counts.csv")),
        "resid": (residuals, out_path("std_residuals.csv")),
        "pct": (row_share, out_path("row_percent.csv")),
    }
    for table, target in exports.values():
        table.to_csv(target)

    # Save a small report
    report_path = out_path("chi2_report.txt")
    with open(report_path, "w", encoding="utf-8") as report:
        sections = [
            f"Chi-square test: {row_var} × {col_var}\n",
            f"chi2 = {chi2:.4f}, dof = {dof}, p = {p:.6g}\n\n",
            "Crosstab (counts):\n",
            observed.to_string(),
            "\n\nExpected counts:\n",
            expected.to_string(),
            "\n\nStandardized residuals:\n",
            residuals.round(3).to_string(),
            "\n",
        ]
        report.write("".join(sections))

    outcome = {key: target for key, (_, target) in exports.items()}
    outcome["report"] = report_path
    outcome["row_pct_df"] = row_share
    outcome["resid_df"] = residuals
    return outcome

# Run the crosstab + chi-square export for Role against each demographic
# variable; the returned dicts feed the heatmaps and the index file below.
res_gender_role = chi2_and_exports("Gender", "Role", "gender_role")
res_skin_role   = chi2_and_exports("Skin",   "Role", "skin_role")

# ------------------ 3) Heatmaps with matplotlib -----------------
def save_heatmap(df_mat, title, fname, fmt="percent"):
    """Save an annotated heatmap of a 2-D table as a PNG in OUT_DIR.

    Args:
        df_mat: DataFrame whose values are drawn; its index and columns become
            the y and x tick labels.
        title: Chart title.
        fname: Output file name inside OUT_DIR.
        fmt: "percent" prints each cell as ``value * 100`` with one decimal and
            a percent sign; any other value prints the raw number with two
            decimals.

    Returns:
        str: Path of the written image.
    """
    grid = df_mat.values
    n_rows, n_cols = df_mat.shape
    as_percent = fmt == "percent"

    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111)
    image = ax.imshow(grid, aspect="auto")  # default colormap

    # Ticks
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(df_mat.columns.astype(str))
    ax.set_yticklabels(df_mat.index.astype(str), rotation=0)

    # Annotate each cell
    for (row, col), value in np.ndenumerate(grid):
        if as_percent:
            label = f"{value*100:.1f}%"
        else:
            label = f"{value:.2f}"
        ax.text(col, row, label, ha="center", va="center")

    ax.set_title(title)
    fig.colorbar(image, ax=ax)
    fig.tight_layout()

    target = os.path.join(OUT_DIR, fname)
    fig.savefig(target, dpi=220)
    plt.close(fig)
    return target

# Two heatmaps per crosstab: row shares (printed as percentages) and the
# residuals from chi2_and_exports (printed as plain numbers).
paths["hm_gender_role_pct"] = save_heatmap(res_gender_role["row_pct_df"], "Gender × Role (Row %)", "heatmap_gender_role_pct.png", fmt="percent")
paths["hm_skin_role_pct"]   = save_heatmap(res_skin_role["row_pct_df"],   "Skin × Role (Row %)",   "heatmap_skin_role_pct.png",   fmt="percent")
paths["hm_gender_role_res"] = save_heatmap(res_gender_role["resid_df"],  "Gender × Role (Std Residuals)", "heatmap_gender_role_resid.png", fmt="resid")
paths["hm_skin_role_res"]   = save_heatmap(res_skin_role["resid_df"],    "Skin × Role (Std Residuals)",   "heatmap_skin_role_resid.png",   fmt="resid")

# ----------- 4) Optional logistic regression (Elite vs Public) -----------
# Predict probability of "Elite" (1) vs "General Public" (0) using Gender & Skin
# Keep only rows with a definite role (Elite / General Public) and known
# predictors. Rows with NaN Gender or Skin would otherwise receive all-zero
# dummy columns and be silently counted as the reference category.
df_lr = (
    df[df["Role"].isin(["Elite", "General Public"])]
    .dropna(subset=["Gender", "Skin"])
    .copy()
)
df_lr["y_elite"] = (df_lr["Role"] == "Elite").astype(int)

# One-hot encode predictors (drop first to avoid multicollinearity).
# dtype=float: recent pandas versions emit bool dummies by default, which turn
# the design matrix into an object array that statsmodels refuses to fit.
X = pd.get_dummies(df_lr[["Gender", "Skin"]], drop_first=True, dtype=float)
y = df_lr["y_elite"]

logit_summary_path = os.path.join(OUT_DIR, "logit_elite_vs_public_summary.txt")
coef_csv_path      = os.path.join(OUT_DIR, "logit_coefficients_odds.csv")

# Set to a human-readable reason when the statsmodels path cannot be used, so
# the fallback report states what actually happened.
fallback_reason = None
try:
    import statsmodels.api as sm
except ImportError:
    fallback_reason = "statsmodels not available"
else:
    try:
        X_const = sm.add_constant(X)
        model = sm.Logit(y, X_const, missing="drop").fit(disp=False)
        with open(logit_summary_path, "w", encoding="utf-8") as f:
            f.write(model.summary().as_text())
        params = model.params
        conf = model.conf_int()
        or_df = pd.DataFrame({
            "coef": params,
            "odds_ratio": np.exp(params),
            "ci_low": np.exp(conf[0]),
            "ci_high": np.exp(conf[1]),
            "p_value": model.pvalues
        })
        or_df.to_csv(coef_csv_path)
    except Exception as e:
        # e.g. perfect separation or a singular design matrix
        fallback_reason = f"statsmodels Logit failed ({type(e).__name__}: {e})"

if fallback_reason is not None:
    # Fallback (no p-values)
    from sklearn.linear_model import LogisticRegression
    lr = LogisticRegression(max_iter=1000)
    lr.fit(X.values, y.values)
    coef = lr.coef_[0]
    intercept = lr.intercept_[0]
    or_df = pd.Series(np.exp(np.r_[intercept, coef]), index=["Intercept"] + list(X.columns)).to_frame("odds_ratio")
    or_df.to_csv(coef_csv_path)
    with open(logit_summary_path, "w", encoding="utf-8") as f:
        f.write(f"{fallback_reason}; used sklearn LogisticRegression (no p-values)\n")
        f.write(or_df.to_string())

# ------------------------- Output index --------------------------
# Write a plain-text index listing every file produced by this cell: the
# entries collected in `paths`, the per-crosstab exports, and the two logistic
# regression outputs.
index_md = os.path.join(OUT_DIR, "_INDEX.txt")
with open(index_md, "w", encoding="utf-8") as f:
    f.write("Generated files (image-only analysis):\n\n")
    for k, v in paths.items():
        f.write(f"{k}: {v}\n")
    f.write("\nCross-tabs & chi2 reports:\n")
    for key, res in [("gender_role", res_gender_role), ("skin_role", res_skin_role)]:
        f.write(f"\n{key}:\n")
        f.write(f" - counts: {res['ct']}\n")
        f.write(f" - expected: {res['exp']}\n")
        f.write(f" - std residuals: {res['resid']}\n")
        f.write(f" - row percent: {res['pct']}\n")
        f.write(f" - report: {res['report']}\n")
    f.write(f"\nLogistic regression summary: {logit_summary_path}\n")
    f.write(f"Logistic regression coefficients & odds ratios: {coef_csv_path}\n")

print(f"Done. Outputs saved to: {OUT_DIR}")


In [None]:
# ----------- 5) Yearly trends for Gender / Skin / Role -----------

def proportions_by_year(col):
    """Return the share of each category of ``df[col]`` within every Year.

    Args:
        col: Name of the categorical column of ``df`` to break down.

    Returns:
        DataFrame indexed by ``df["Year"]`` with one column per category; each
        row is normalized to sum to 1.
    """
    return pd.crosstab(df["Year"], df[col], normalize="index")

def save_line_from_table(tab, title, fname):
    """Save a line chart with one line per column of ``tab`` as a PNG in OUT_DIR.

    Args:
        tab: DataFrame whose index is plotted on the x axis ("Year") and whose
            columns are plotted as separate lines ("Share").
        title: Chart title.
        fname: Output file name inside OUT_DIR.

    Returns:
        str: Path of the written image.
    """
    figure = plt.figure(figsize=(8, 5))
    axes = figure.add_subplot(111)

    for category, shares in tab.items():
        axes.plot(tab.index, shares, marker="o", label=str(category))

    axes.legend()
    axes.set_title(title)
    axes.set_ylabel("Share")
    axes.set_xlabel("Year")
    figure.tight_layout()

    target = os.path.join(OUT_DIR, fname)
    figure.savefig(target, dpi=220)
    plt.close(figure)
    return target

# 1. Gender by year
# For each variable: compute per-year category shares, export them as CSV and
# draw them as a line chart. This cell relies on `df`, `OUT_DIR`, `os`, `pd`
# and `plt` defined by the previous cell, and on a "Year" column in `df`.
year_gender = proportions_by_year("Gender")
year_gender.to_csv(os.path.join(OUT_DIR, "year_gender_shares.csv"))
save_line_from_table(year_gender, "Gender Shares by Year", "year_gender_lines.png")

# 2. Skin by year
year_skin = proportions_by_year("Skin")
year_skin.to_csv(os.path.join(OUT_DIR, "year_skin_shares.csv"))
save_line_from_table(year_skin, "Skin Shares by Year", "year_skin_lines.png")

# 3. Role by year
year_role = proportions_by_year("Role")
year_role.to_csv(os.path.join(OUT_DIR, "year_role_shares.csv"))
save_line_from_table(year_role, "Role Shares by Year", "year_role_lines.png")