In [None]:
# ============================================================================
# COLOR EXTRACTION - REVISED WITH ADAPTIVE DOMAIN HANDLING
# ============================================================================

from pathlib import Path
import pandas as pd
import numpy as np
import requests
import warnings
from collections import Counter
from io import BytesIO
from PIL import Image
from sklearn.cluster import KMeans
import time
import re

warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================

BASE_DIR = Path('/Users/leoss')
DATA_DIR = BASE_DIR / 'Downloads'
OUTPUT_BASE = BASE_DIR / 'Desktop/Portfolio/Website-/UK pink tax/Outputs'

PATH_MAIN_DATA = DATA_DIR / 'items_fin.csv'
COLOR_CACHE_PATH = OUTPUT_BASE / 'color_features_cache_v3_filtered.csv'
FAILED_URLS_PATH = OUTPUT_BASE / 'failed_urls.csv'

# Extraction settings
N_COLORS = 3
TIMEOUT = 15
MAX_SAMPLES = 20000
PRIORITIZE_GENDERED = False
MAX_SAMPLES = 20000  # already set, just confirming — 5,517 is well under this

# Retry settings (only for transient errors)
MAX_RETRIES = 1
RETRY_DELAY = 1

# Incremental save interval
SAVE_EVERY = 200

# Adaptive domain thresholds
MIN_DOMAIN_SAMPLES = 30        # need this many attempts before judging a domain
HEAD_CHECK_THRESHOLD = 0.40    # if success rate < 40%, use HEAD pre-checks
SKIP_THRESHOLD = 0.05          # if success rate < 5%, skip entirely

# Categories to exclude
EXCLUDE_CATEGORIES = [
    'food', 'grocery', 'groceries', 'snacks', 'drinks', 'beverages',
    'pet food', 'pet supplies', 'cleaning', 'household', 'kitchen',
    'office', 'stationery', 'electronics', 'tech', 'garden', 'automotive'
]

COL_PRODUCT_ID = 'product_id'
COL_IMAGE = 'image_url'
COL_BREADCRUMB = 'standardized_breadcrumbs'
COL_NAME = 'product_title_x'
COL_DESC = 'description'

RANDOM_STATE = 42

# ============================================================================
# COLOR DEFINITIONS
# ============================================================================

STANDARD_COLORS = {
    'dark_red': (139, 0, 0), 'red': (255, 0, 0), 'coral': (255, 127, 80),
    'salmon': (250, 128, 114), 'crimson': (220, 20, 60), 'brown': (139, 69, 19),
    'tan': (210, 180, 140), 'orange': (255, 165, 0), 'gold': (255, 215, 0),
    'yellow': (255, 255, 0), 'khaki': (240, 230, 140), 'dark_green': (0, 100, 0),
    'green': (0, 128, 0), 'lime': (50, 205, 50), 'olive': (128, 128, 0),
    'teal': (0, 128, 128), 'navy': (0, 0, 128), 'blue': (0, 0, 255),
    'royal_blue': (65, 105, 225), 'sky_blue': (135, 206, 235), 'cyan': (0, 255, 255),
    'purple': (128, 0, 128), 'magenta': (255, 0, 255), 'violet': (238, 130, 238),
    'lavender': (230, 230, 250), 'pink': (255, 192, 203), 'hot_pink': (255, 105, 180),
    'gray': (128, 128, 128), 'silver': (192, 192, 192),
    'black': (0, 0, 0), 'white': (255, 255, 255),
}

FEMALE_KEYWORDS = [
    'women', 'woman', 'female', 'ladies', 'lady', 'girls',
    'womens', "women's", 'femme', 'her', 'feminine', 'fem',
]
MALE_KEYWORDS = [
    'men', 'man', 'male', 'gentleman', 'gentlemen', 'boys',
    'mens', "men's", 'homme', 'his', 'masculine',
]

# Shared session for connection pooling
SESSION = requests.Session()
SESSION.headers.update({
    'User-Agent': (
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
        'AppleWebKit/537.36 (KHTML, like Gecko) '
        'Chrome/91.0.4472.124 Safari/537.36'
    ),
    'Accept': 'image/avif,image/webp,image/apng,image/*,*/*;q=0.8',
    'Accept-Language': 'en-US,en;q=0.9',
})


# ============================================================================
# ADAPTIVE DOMAIN TRACKER
# ============================================================================

class DomainTracker:
    """
    Tracks per-domain success/failure rates during the run.
    After MIN_DOMAIN_SAMPLES attempts, adjusts strategy:
      - success rate < HEAD_CHECK_THRESHOLD → HEAD pre-check before GET
      - success rate < SKIP_THRESHOLD → skip entirely
    """

    def __init__(self, min_samples, head_threshold, skip_threshold):
        self.min_samples = min_samples
        self.head_threshold = head_threshold
        self.skip_threshold = skip_threshold
        self.attempts = Counter()   # domain → total attempts
        self.successes = Counter()  # domain → successful extractions
        self._notified_head = set()
        self._notified_skip = set()

    def record(self, domain, success):
        self.attempts[domain] += 1
        if success:
            self.successes[domain] += 1

    def success_rate(self, domain):
        total = self.attempts[domain]
        if total == 0:
            return 1.0
        return self.successes[domain] / total

    def should_skip(self, domain):
        if self.attempts[domain] < self.min_samples:
            return False
        skip = self.success_rate(domain) < self.skip_threshold
        if skip and domain not in self._notified_skip:
            rate = self.success_rate(domain)
            print(f"  [domain tracker] Skipping {domain} "
                  f"(success rate {rate:.0%} after {self.attempts[domain]} attempts)")
            self._notified_skip.add(domain)
        return skip

    def should_head_check(self, domain):
        if self.attempts[domain] < self.min_samples:
            return False
        rate = self.success_rate(domain)
        head_check = rate < self.head_threshold and rate >= self.skip_threshold
        if head_check and domain not in self._notified_head:
            print(f"  [domain tracker] HEAD pre-checking {domain} "
                  f"(success rate {rate:.0%} after {self.attempts[domain]} attempts)")
            self._notified_head.add(domain)
        return head_check

    def summary(self):
        """Return a summary dict of {domain: (successes, attempts, rate)}."""
        out = {}
        for domain in sorted(self.attempts, key=lambda d: self.attempts[d], reverse=True):
            total = self.attempts[domain]
            ok = self.successes[domain]
            out[domain] = (ok, total, ok / total if total else 0)
        return out


# Global tracker instance (initialized in main)
domain_tracker = None


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def get_domain(url):
    try:
        return str(url).split('/')[2]
    except (IndexError, AttributeError):
        return 'unknown'


def color_distance(c1, c2):
    return np.sqrt(sum((a - b) ** 2 for a, b in zip(c1, c2)))


def closest_standard_color(rgb):
    min_dist = float('inf')
    closest = 'gray'
    for name, std_rgb in STANDARD_COLORS.items():
        dist = color_distance(rgb, std_rgb)
        if dist < min_dist:
            min_dist = dist
            closest = name
    return closest


def is_background_color(rgb):
    r, g, b = rgb
    if r > 240 and g > 240 and b > 240:
        return True
    if r < 15 and g < 15 and b < 15:
        return True
    max_diff = max(abs(r - g), abs(g - b), abs(r - b))
    avg = (r + g + b) / 3
    if max_diff < 20 and 100 < avg < 160:
        return True
    return False


def head_check_alive(url, timeout=5):
    """Quick HEAD request to check if URL exists before full download."""
    try:
        resp = SESSION.head(url, timeout=timeout, allow_redirects=True)
        return resp.status_code == 200
    except requests.RequestException:
        return False


def is_transient_error(status_code):
    """Whether an HTTP error code is worth retrying (server-side / rate limit)."""
    if status_code is None:
        return True
    return status_code >= 500 or status_code == 429


def extract_colors(image_url, n_colors=3, timeout=15, max_retries=1):
    """
    Extract dominant colors from a product image.
    Uses domain_tracker to decide whether to HEAD-check or skip.

    Returns (colors_list | None, error_reason | None).
    """
    url = str(image_url).strip()
    domain = get_domain(url)

    # --- Adaptive domain handling ---
    if domain_tracker.should_skip(domain):
        return None, 'adaptive_skip'

    if domain_tracker.should_head_check(domain):
        if not head_check_alive(url, timeout=5):
            return None, 'head_check_dead'

    # --- Download with selective retries ---
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            response = SESSION.get(url, timeout=timeout, allow_redirects=True)

            if response.status_code != 200:
                last_error = f'http_{response.status_code}'
                # Only retry transient server errors, not 404/403
                if is_transient_error(response.status_code) and attempt < max_retries:
                    time.sleep(RETRY_DELAY)
                    continue
                return None, last_error

            # --- Image processing ---
            img = Image.open(BytesIO(response.content)).convert('RGB')
            img = img.resize((100, 100))

            pixels = np.array(img).reshape(-1, 3)

            non_bg = np.array([p for p in pixels if not is_background_color(tuple(p))])
            if len(non_bg) < 50:
                non_bg = pixels

            n_clusters = min(n_colors + 2, len(non_bg))
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            kmeans.fit(non_bg)

            counts = Counter(kmeans.labels_)
            total = len(kmeans.labels_)

            colors = []
            for cluster_id, count in sorted(counts.items(), key=lambda x: x[1], reverse=True):
                rgb = tuple(int(c) for c in kmeans.cluster_centers_[cluster_id])
                if not is_background_color(rgb):
                    colors.append({'name': closest_standard_color(rgb), 'weight': count / total})
                if len(colors) >= n_colors:
                    break

            return (colors, None) if colors else (None, 'no_non_bg_colors')

        except requests.exceptions.Timeout:
            last_error = 'timeout'
        except requests.exceptions.ConnectionError:
            last_error = 'connection_error'
        except requests.exceptions.RequestException:
            last_error = 'request_error'
        except Exception:
            last_error = 'processing_error'

        if attempt < max_retries:
            time.sleep(RETRY_DELAY)

    return None, last_error


def contains_excluded_category(text):
    if pd.isna(text):
        return False
    text_lower = str(text).lower()
    return any(cat in text_lower for cat in EXCLUDE_CATEGORIES)


def extract_gender_label(row):
    def check_gender(text):
        if pd.isna(text) or str(text).strip() == '':
            return 'none'
        text_lower = str(text).lower()
        has_female = any(re.search(r'\b' + kw + r'\b', text_lower) for kw in FEMALE_KEYWORDS)
        has_male = any(re.search(r'\b' + kw + r'\b', text_lower) for kw in MALE_KEYWORDS)
        if has_female and not has_male:
            return 'female'
        elif has_male and not has_female:
            return 'male'
        elif has_female and has_male:
            return 'both'
        return 'none'

    for col in [COL_BREADCRUMB, COL_NAME, COL_DESC]:
        if col in row.index:
            gender = check_gender(row[col])
            if gender in ['female', 'male']:
                return gender
    return 'none'


def save_incremental(new_results, cache_path):
    """Append new results to the cache CSV, deduplicating on product_id."""
    if not new_results:
        return 0

    new_df = pd.DataFrame(new_results)

    if cache_path.exists():
        existing = pd.read_csv(cache_path)
        combined = pd.concat([existing, new_df]).drop_duplicates(subset=[COL_PRODUCT_ID])
    else:
        combined = new_df

    combined.to_csv(cache_path, index=False)
    return len(combined)


# ============================================================================
# MAIN
# ============================================================================

def extract_and_save_colors():
    global domain_tracker
    domain_tracker = DomainTracker(
        min_samples=MIN_DOMAIN_SAMPLES,
        head_threshold=HEAD_CHECK_THRESHOLD,
        skip_threshold=SKIP_THRESHOLD,
    )

    print("=" * 70)
    print("COLOR EXTRACTION  (adaptive domain handling)")
    print("=" * 70)
    print(f"Timeout: {TIMEOUT}s | Retries: {MAX_RETRIES} (transient only)")
    print(f"Adaptive thresholds: HEAD-check < {HEAD_CHECK_THRESHOLD:.0%} "
          f"success, skip < {SKIP_THRESHOLD:.0%} success "
          f"(after {MIN_DOMAIN_SAMPLES} samples)")
    print()

    OUTPUT_BASE.mkdir(parents=True, exist_ok=True)

    # ---- Resume support ----
    already_done = set()
    if COLOR_CACHE_PATH.exists():
        existing = pd.read_csv(COLOR_CACHE_PATH)
        if COL_PRODUCT_ID in existing.columns:
            already_done = set(existing[COL_PRODUCT_ID].astype(str))
            print(f"Existing cache: {len(already_done):,} products — will resume.\n")

    # ---- Load & filter ----
    print("STEP 1: Loading data...")
    df = pd.read_csv(PATH_MAIN_DATA, encoding='latin-1')
    df.columns = df.columns.str.lower().str.strip().str.replace(' ', '_')
    if 'unnamed:_0' in df.columns:
        df = df.drop(columns=['unnamed:_0'])
    print(f"  Loaded {len(df):,} products")

    print("\nSTEP 2: Filtering categories...")
    df['is_excluded'] = df[COL_BREADCRUMB].apply(contains_excluded_category)
    df = df[~df['is_excluded']].copy().reset_index(drop=True)
    print(f"  Remaining: {len(df):,} products")

    print("\nSTEP 3: Gender labels...")
    df['label_extracted'] = df.apply(extract_gender_label, axis=1)
    for label, count in df['label_extracted'].value_counts().items():
        print(f"  {label}: {count:,}")

    print("\nSTEP 4: Selecting products...")
    df_with_images = df[df[COL_IMAGE].notna()].copy()

    # Show domain breakdown before any filtering
    df_with_images['_domain'] = df_with_images[COL_IMAGE].apply(get_domain)
    print("  Domain breakdown:")
    for domain, count in df_with_images['_domain'].value_counts().items():
        print(f"    {domain}: {count:,}")

    if PRIORITIZE_GENDERED:
        gendered = df_with_images[df_with_images['label_extracted'].isin(['female', 'male'])].copy()
        none_products = df_with_images[df_with_images['label_extracted'] == 'none']
        none_sample_size = min(NONE_SAMPLE_SIZE, len(none_products))
        none_sample = (
            none_products.sample(n=none_sample_size, random_state=RANDOM_STATE)
            if none_sample_size > 0 else pd.DataFrame()
        )
        to_extract = pd.concat([gendered, none_sample]).drop_duplicates()
        print(f"  Gendered: {len(gendered):,} | None sample: {len(none_sample):,}")
    else:
        to_extract = df_with_images.copy()

    if len(to_extract) > MAX_SAMPLES:
        to_extract = to_extract.sample(n=MAX_SAMPLES, random_state=RANDOM_STATE)

    # Skip already-processed
    to_extract = to_extract[
        ~to_extract[COL_PRODUCT_ID].astype(str).isin(already_done)
    ]
    print(f"  To process (after resume filter): {len(to_extract):,}")

    # ---- Extract ----
    print(f"\nSTEP 5: Extracting colors (top {N_COLORS} per image)...")
    print()

    color_results = []
    failed_records = []
    error_counter = Counter()
    start_time = time.time()

    for idx, (row_idx, row) in enumerate(to_extract.iterrows()):
        if (idx + 1) % 100 == 0 or idx == 0:
            elapsed = time.time() - start_time
            rate = (idx + 1) / elapsed if elapsed > 0 else 0
            remaining = (len(to_extract) - idx - 1) / rate if rate > 0 else 0
            ok = len(color_results)
            total_so_far = idx + 1
            pct = 100 * ok / total_so_far if total_so_far else 0
            print(
                f"  {idx+1:,}/{len(to_extract):,} "
                f"| OK: {ok} ({pct:.0f}%) "
                f"| {rate:.1f} img/s "
                f"| ETA: {remaining/60:.0f} min"
            )

        url = row[COL_IMAGE]
        domain = get_domain(str(url))

        colors, error = extract_colors(
            url, n_colors=N_COLORS, timeout=TIMEOUT, max_retries=MAX_RETRIES
        )

        success = colors is not None
        domain_tracker.record(domain, success)

        if success:
            entry = {
                'original_index': row_idx,
                COL_PRODUCT_ID: row[COL_PRODUCT_ID],
                'label_extracted': row['label_extracted'],
            }
            for i, c in enumerate(colors):
                entry[f'color{i+1}_name'] = c['name']
                entry[f'color{i+1}_weight'] = c['weight']
            color_results.append(entry)
        else:
            error_counter[error] += 1
            failed_records.append({
                'product_id': row[COL_PRODUCT_ID],
                'url': url,
                'label': row['label_extracted'],
                'error': error,
            })

        # Incremental save
        if (idx + 1) % SAVE_EVERY == 0 and color_results:
            total_in_cache = save_incremental(color_results, COLOR_CACHE_PATH)
            print(f"  [checkpoint] {total_in_cache:,} products in cache")

    # ---- Final save ----
    print(f"\nSTEP 6: Saving...")
    if color_results:
        total_in_cache = save_incremental(color_results, COLOR_CACHE_PATH)
        print(f"  Cache: {total_in_cache:,} products → {COLOR_CACHE_PATH}")

    if failed_records:
        pd.DataFrame(failed_records).to_csv(FAILED_URLS_PATH, index=False)
        print(f"  Failed URLs log → {FAILED_URLS_PATH}")

    # ---- Summary ----
    total = len(to_extract)
    success_count = len(color_results)
    success_rate = 100 * success_count / total if total else 0

    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(f"Processed: {total:,}")
    print(f"Extracted: {success_count:,} ({success_rate:.1f}%)")
    print(f"Failed:    {total - success_count:,} ({100 - success_rate:.1f}%)")

    print("\nFailure breakdown:")
    for error, count in error_counter.most_common(10):
        print(f"  {error}: {count} ({100*count/total:.1f}%)")

    print("\nPer-domain results:")
    print(f"  {'domain':<40s} {'ok':>5s} / {'total':>5s}  {'rate':>6s}")
    print(f"  {'-'*60}")
    for domain, (ok, tot, rate) in domain_tracker.summary().items():
        print(f"  {domain:<40s} {ok:>5d} / {tot:>5d}  {rate:>5.0%}")

    if success_count:
        color_df = pd.read_csv(COLOR_CACHE_PATH)
        if 'label_extracted' in color_df.columns:
            print(f"\nGender split in cache:")
            for label, count in color_df['label_extracted'].value_counts().items():
                print(f"  {label}: {count:,}")
        print(f"\nFile size: {COLOR_CACHE_PATH.stat().st_size / 1024:.1f} KB")


if __name__ == "__main__":
    try:
        extract_and_save_colors()
    except KeyboardInterrupt:
        print("\n\nInterrupted. Partial results saved if any checkpoints were reached.")
    except Exception as e:
        print(f"\nERROR: {e}")
        import traceback
        traceback.print_exc()

COLOR EXTRACTION  (adaptive domain handling)
Timeout: 15s | Retries: 1 (transient only)
Adaptive thresholds: HEAD-check < 40% success, skip < 5% success (after 30 samples)

Existing cache: 1,163 products — will resume.

STEP 1: Loading data...
  Loaded 21,436 products

STEP 2: Filtering categories...
  Remaining: 12,832 products

STEP 3: Gender labels...
  none: 10,913
  female: 1,075
  male: 844

STEP 4: Selecting products...
  Domain breakdown:
    digitalcontent.api.tesco.com: 6,973
    groceries.morrisons.com: 5,517
    ui.assets-asda.com:443: 342
  To process (after resume filter): 11,669

STEP 5: Extracting colors (top 3 per image)...

  1/11,669 | OK: 0 (0%) | 92.7 img/s | ETA: 2 min
  [domain tracker] HEAD pre-checking digitalcontent.api.tesco.com (success rate 33% after 30 attempts)
  100/11,669 | OK: 10 (10%) | 5.2 img/s | ETA: 37 min
  200/11,669 | OK: 10 (5%) | 7.5 img/s | ETA: 25 min
  [checkpoint] 1,173 products in cache
  [domain tracker] Skipping digitalcontent.api.tesc

In [7]:
# ============================================================================
# PINK TAX ANALYSIS: ML-BASED GENDER PREDICTION PIPELINE v4
# ============================================================================
#
# Changes from v3:
#   - Color extraction removed (uses pre-built cache: 5,617 products)
#   - Color features merged on product_id, not index (fixes alignment bug
#     caused by reset_index after category filtering)
#   - Morrisons-only color data noted as limitation
# ============================================================================

# ============================================================================
# CONFIGURATION
# ============================================================================

from pathlib import Path

BASE_DIR = Path('/Users/leoss')
DATA_DIR = BASE_DIR / 'Downloads'
OUTPUT_BASE = BASE_DIR / 'Desktop/Portfolio/Website-/UK pink tax/Outputs'

PATH_MAIN_DATA = DATA_DIR / 'items_fin.csv'
PATH_HUMAN_CODED = DATA_DIR / 'items_prices_description_gender_humancode_sample.csv'
PATH_YOUR_LABELED = DATA_DIR / 'available_validation.xlsx'
OUTPUT_DIR = OUTPUT_BASE / 'charts/ml_pipeline_v4'

# Pre-built color cache (5,617 products, Morrisons only)
COLOR_CACHE_PATH = OUTPUT_BASE / 'color_features_cache_v3_filtered.csv'

# ML settings
RANDOM_STATE = 42
TEST_SIZE = 0.25
CV_FOLDS = 5
N_COLORS = 3

# Categories to exclude
EXCLUDE_CATEGORIES = [
    'food', 'grocery', 'groceries', 'snacks', 'drinks', 'beverages',
    'pet food', 'pet supplies', 'cleaning', 'household', 'kitchen',
    'office', 'stationery', 'electronics', 'tech', 'garden', 'automotive'
]

# Edge case thresholds
MIN_CLASS_SIZE = 50
MIN_TEST_SAMPLES = 10

# ============================================================================
# IMPORTS
# ============================================================================

import pandas as pd
import numpy as np
import re
import json
import warnings
from collections import Counter

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegressionCV
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.metrics import (accuracy_score, f1_score,
                             confusion_matrix, classification_report)
from scipy.sparse import hstack, csr_matrix

import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print("✓ Imports complete")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ============================================================================
# COLUMN NAMES
# ============================================================================

COL_NAME = 'product_title_x'
COL_DESC = 'description'
COL_BREADCRUMB = 'standardized_breadcrumbs'
COL_PRICE = 'price'
COL_UNIT_PRICE = 'unit_price'
COL_IMAGE = 'image_url'
COL_STORE = 'store_id'
COL_PRODUCT_ID = 'product_id'
COL_URL = 'product_url_x'

FEMALE_KEYWORDS = ['women', 'woman', 'female', 'ladies', 'lady', 'girls',
                   'womens', "women's", 'femme', 'her', 'feminine', 'fem']
MALE_KEYWORDS = ['men', 'man', 'male', 'gentleman', 'gentlemen', 'boys',
                 'mens', "men's", 'homme', 'his', 'masculine']
ALL_GENDER_KEYWORDS = set(FEMALE_KEYWORDS + MALE_KEYWORDS)

STANDARD_COLORS = {
    'dark_red': (139, 0, 0), 'red': (255, 0, 0), 'coral': (255, 127, 80),
    'salmon': (250, 128, 114), 'crimson': (220, 20, 60), 'brown': (139, 69, 19),
    'tan': (210, 180, 140), 'orange': (255, 165, 0), 'gold': (255, 215, 0),
    'yellow': (255, 255, 0), 'khaki': (240, 230, 140), 'dark_green': (0, 100, 0),
    'green': (0, 128, 0), 'lime': (50, 205, 50), 'olive': (128, 128, 0),
    'teal': (0, 128, 128), 'navy': (0, 0, 128), 'blue': (0, 0, 255),
    'royal_blue': (65, 105, 225), 'sky_blue': (135, 206, 235), 'cyan': (0, 255, 255),
    'purple': (128, 0, 128), 'magenta': (255, 0, 255), 'violet': (238, 130, 238),
    'lavender': (230, 230, 250), 'pink': (255, 192, 203), 'hot_pink': (255, 105, 180),
    'gray': (128, 128, 128), 'silver': (192, 192, 192),
    'black': (0, 0, 0), 'white': (255, 255, 255),
}

# ============================================================================
# LOAD DATA
# ============================================================================

print("\n" + "="*70)
print("LOADING DATA")
print("="*70)

if not PATH_MAIN_DATA.exists():
    raise FileNotFoundError(f"Main data file not found: {PATH_MAIN_DATA}")

df = pd.read_csv(PATH_MAIN_DATA, encoding='latin-1')
df.columns = df.columns.str.lower().str.strip().str.replace(' ', '_')
if 'unnamed:_0' in df.columns:
    df = df.drop(columns=['unnamed:_0'])
print(f"✓ Main dataset: {len(df):,} products")

if not PATH_HUMAN_CODED.exists():
    raise FileNotFoundError(f"Human-coded file not found: {PATH_HUMAN_CODED}")

human_coded = pd.read_csv(PATH_HUMAN_CODED, encoding='latin-1')
human_coded.columns = human_coded.columns.str.lower().str.strip().str.replace(' ', '_')
if 'unnamed:_0' in human_coded.columns:
    human_coded = human_coded.drop(columns=['unnamed:_0'])
print(f"✓ Human-coded: {len(human_coded)} products")

try:
    your_labeled = pd.read_excel(PATH_YOUR_LABELED)
    your_labeled.columns = your_labeled.columns.str.lower().str.strip().str.replace(' ', '_')
    print(f"✓ Your labeled: {len(your_labeled)} products")
except FileNotFoundError:
    your_labeled = pd.DataFrame()
    print("⚠ Your labeled file not found (optional)")

# ============================================================================
# STEP 1: FILTER CATEGORIES
# ============================================================================

print("\n" + "="*70)
print("STEP 1: FILTER NON-GENDERED CATEGORIES")
print("="*70)

original_count = len(df)

def contains_excluded_category(text):
    if pd.isna(text):
        return False
    text_lower = str(text).lower()
    return any(cat in text_lower for cat in EXCLUDE_CATEGORIES)

df['is_excluded'] = df[COL_BREADCRUMB].apply(contains_excluded_category)
excluded_count = df['is_excluded'].sum()

print(f"Original: {original_count:,}")
print(f"Excluded: {excluded_count:,}")

df = df[~df['is_excluded']].copy().reset_index(drop=True)
print(f"Remaining: {len(df):,}")

if len(df) == 0:
    raise ValueError("No products remaining after filtering.")

# ============================================================================
# STEP 2: EXTRACT GENDER LABELS
# ============================================================================

print("\n" + "="*70)
print("STEP 2: EXTRACT GENDER LABELS")
print("="*70)

def extract_gender_explicit(text):
    if pd.isna(text) or str(text).strip() == '':
        return 'none'
    text_lower = str(text).lower()
    has_female = any(re.search(r'\b' + kw + r'\b', text_lower) for kw in FEMALE_KEYWORDS)
    has_male = any(re.search(r'\b' + kw + r'\b', text_lower) for kw in MALE_KEYWORDS)
    if has_female and not has_male:
        return 'female'
    elif has_male and not has_female:
        return 'male'
    elif has_female and has_male:
        return 'both'
    return 'none'

df['label_bc'] = df[COL_BREADCRUMB].apply(extract_gender_explicit)
df['label_name'] = df[COL_NAME].apply(extract_gender_explicit)
df['label_desc'] = df[COL_DESC].apply(extract_gender_explicit)

def combine_labels(row):
    for col in ['label_bc', 'label_name', 'label_desc']:
        if row[col] in ['female', 'male']:
            return row[col]
    return 'none'

df['label_extracted'] = df.apply(combine_labels, axis=1)

print(f"Label distribution:")
print(df['label_extracted'].value_counts().to_dict())

# Merge human labels
if 'human_gender_label' in human_coded.columns and COL_PRODUCT_ID in human_coded.columns:
    human_labels = human_coded[[COL_PRODUCT_ID, 'human_gender_label']].drop_duplicates()
    human_labels.columns = [COL_PRODUCT_ID, 'label_human']
    human_labels['label_human'] = human_labels['label_human'].str.lower().str.strip()
    df = df.merge(human_labels, on=COL_PRODUCT_ID, how='left')
else:
    df['label_human'] = None

human_count = df['label_human'].notna().sum()
print(f"Human labels merged: {human_count}")

# ============================================================================
# STEP 3: LOAD COLOR CACHE
# ============================================================================

print("\n" + "="*70)
print("STEP 3: LOAD COLOR DATA")
print("="*70)

if COLOR_CACHE_PATH.exists():
    color_df = pd.read_csv(COLOR_CACHE_PATH)
    print(f"✓ Loaded color cache: {len(color_df):,} products")
    if 'label_extracted' in color_df.columns:
        print(f"  Gender split: {color_df['label_extracted'].value_counts().to_dict()}")

    # Check how many match the filtered dataset
    matched = color_df[COL_PRODUCT_ID].isin(df[COL_PRODUCT_ID]).sum()
    print(f"  Matched to filtered data: {matched:,} / {len(color_df):,}")
else:
    print("⚠ No color cache found — proceeding without color features")
    color_df = pd.DataFrame()

# ============================================================================
# STEP 4: PREPARE TRAINING DATA
# ============================================================================

print("\n" + "="*70)
print("STEP 4: PREPARE TRAINING DATA")
print("="*70)

female_all = df[df['label_extracted'] == 'female'].copy()
male_all = df[df['label_extracted'] == 'male'].copy()

print(f"Explicitly female: {len(female_all)}")
print(f"Explicitly male: {len(male_all)}")

if len(female_all) < MIN_CLASS_SIZE or len(male_all) < MIN_CLASS_SIZE:
    raise ValueError(f"Insufficient gendered samples. Need at least {MIN_CLASS_SIZE} per class.")

# None class: human-coded none + sample from extracted none
human_none = df[(df['label_human'] == 'none')].copy()
extracted_none = df[(df['label_extracted'] == 'none') & (df['label_human'].isna())].copy()

min_gendered = min(len(female_all), len(male_all))
target_none = min_gendered

print(f"\nNone class sources:")
print(f"  Human-coded none: {len(human_none)}")
print(f"  Extracted none (unlabeled): {len(extracted_none)}")

total_none_available = len(human_none) + len(extracted_none)
if total_none_available < MIN_CLASS_SIZE:
    raise ValueError(f"Insufficient 'none' samples. Need at least {MIN_CLASS_SIZE}.")

if len(human_none) >= target_none:
    none_all = human_none.sample(n=target_none, random_state=RANDOM_STATE)
else:
    remaining = target_none - len(human_none)
    sampled_none = extracted_none.sample(
        n=min(remaining, len(extracted_none)), random_state=RANDOM_STATE
    )
    none_all = pd.concat([human_none, sampled_none])

print(f"  Final none class: {len(none_all)}")

# Balance
min_class = min(len(female_all), len(male_all), len(none_all))
print(f"\nBalancing to: {min_class} per class")

if min_class < MIN_CLASS_SIZE:
    raise ValueError(f"After balancing, class size ({min_class}) < minimum ({MIN_CLASS_SIZE})")

female_balanced = female_all.sample(n=min_class, random_state=RANDOM_STATE)
male_balanced = male_all.sample(n=min_class, random_state=RANDOM_STATE)
none_balanced = none_all.sample(n=min_class, random_state=RANDOM_STATE)

ml_data = pd.concat([female_balanced, male_balanced, none_balanced]).copy()

ml_data['target'] = ml_data['label_extracted'].map({'female': 0, 'male': 1})
ml_data.loc[ml_data['target'].isna(), 'target'] = 2
ml_data['target'] = ml_data['target'].astype(int)

print(f"\n✓ Training data: {len(ml_data)} products")
print(f"  Class distribution: {ml_data['target'].value_counts().sort_index().to_dict()}")
print(f"  (0=female, 1=male, 2=none)")

# ============================================================================
# STEP 5: TRAIN-TEST SPLIT
# ============================================================================

print("\n" + "="*70)
print("STEP 5: TRAIN-TEST SPLIT")
print("="*70)

train_idx, test_idx = train_test_split(
    ml_data.index, test_size=TEST_SIZE,
    random_state=RANDOM_STATE, stratify=ml_data['target']
)

train_data = ml_data.loc[train_idx].copy()
test_data = ml_data.loc[test_idx].copy()

print(f"Train: {len(train_data)}")
print(f"Test: {len(test_data)}")
print(f"Train distribution: {train_data['target'].value_counts().sort_index().to_dict()}")

# ============================================================================
# STEP 6: FEATURE ENGINEERING
# ============================================================================

print("\n" + "="*70)
print("STEP 6: FEATURE ENGINEERING")
print("="*70)

def clean_text_remove_gender(text, remove_words=ALL_GENDER_KEYWORDS):
    if pd.isna(text):
        return ''
    text = str(text).lower()
    for word in remove_words:
        text = re.sub(r'\b' + word + r'\b', '', text)
    text = re.sub(r'[^a-z\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

all_datasets = [train_data, test_data, df]
for dataset in all_datasets:
    dataset['breadcrumb_clean'] = dataset[COL_BREADCRUMB].apply(clean_text_remove_gender)
    dataset['description_clean'] = dataset[COL_DESC].apply(clean_text_remove_gender)

# --- Price features ---
price_features = ['feat_price_log', 'feat_unit_price']
for dataset in all_datasets:
    dataset['feat_price'] = pd.to_numeric(dataset[COL_PRICE], errors='coerce')
    dataset['feat_price_log'] = np.log1p(dataset['feat_price'])
    if COL_UNIT_PRICE in dataset.columns:
        dataset['feat_unit_price'] = (
            dataset[COL_UNIT_PRICE].astype(str).str.extract(r'([\d.]+)')[0].astype(float)
        )
    else:
        dataset['feat_unit_price'] = 0
print(f"✓ Price features: {len(price_features)}")

# --- Store encoding ---
store_encoder = LabelEncoder()
all_stores = pd.concat([d[COL_STORE] for d in all_datasets]).fillna('unknown')
store_encoder.fit(all_stores.unique())

def encode_stores(data, encoder):
    stores = data[COL_STORE].fillna('unknown')
    encoded = []
    for s in stores:
        if s in encoder.classes_:
            encoded.append(encoder.transform([s])[0])
        else:
            encoded.append(-1)
    return np.array(encoded)

for dataset in all_datasets:
    dataset['store_encoded'] = encode_stores(dataset, store_encoder)

n_stores = len(store_encoder.classes_) + 1
print(f"✓ Store features: {n_stores}")

# --- TF-IDF (fit on train only) ---
breadcrumb_vectorizer = TfidfVectorizer(
    max_features=200, min_df=5, max_df=0.9,
    ngram_range=(1, 2), stop_words='english'
)
breadcrumb_vectorizer.fit(train_data['breadcrumb_clean'])
print(f"✓ Breadcrumb TF-IDF: {len(breadcrumb_vectorizer.get_feature_names_out())} features")

description_vectorizer = TfidfVectorizer(
    max_features=500, min_df=5, max_df=0.9,
    ngram_range=(1, 2), stop_words='english'
)
description_vectorizer.fit(train_data['description_clean'])
print(f"✓ Description TF-IDF: {len(description_vectorizer.get_feature_names_out())} features")

# --- Color features (merged on product_id, not index) ---
color_feature_cols = []
for color_name in STANDARD_COLORS.keys():
    for i in range(1, N_COLORS + 1):
        color_feature_cols.append(f'feat_color{i}_{color_name}')

# Initialize to 0
for dataset in all_datasets:
    for col in color_feature_cols:
        dataset[col] = 0.0

if len(color_df) > 0 and COL_PRODUCT_ID in color_df.columns:
    # Build a lookup: product_id → color features dict
    color_lookup = {}
    for _, row in color_df.iterrows():
        pid = row[COL_PRODUCT_ID]
        feats = {}
        for i in range(1, N_COLORS + 1):
            cname = row.get(f'color{i}_name')
            cweight = row.get(f'color{i}_weight')
            if pd.notna(cname) and cname in STANDARD_COLORS and pd.notna(cweight):
                feats[f'feat_color{i}_{cname}'] = cweight
        if feats:
            color_lookup[pid] = feats

    print(f"  Color lookup built: {len(color_lookup):,} products")

    # Apply to each dataset via product_id
    for dataset in all_datasets:
        matched = 0
        for idx, row in dataset.iterrows():
            pid = row[COL_PRODUCT_ID]
            if pid in color_lookup:
                for col, val in color_lookup[pid].items():
                    dataset.at[idx, col] = val
                matched += 1
        # Only print for train/test, not full df (too noisy)
        if len(dataset) < 10000:
            print(f"  Color features filled for {matched}/{len(dataset)} rows")

    train_has_color = sum(1 for _, r in train_data.iterrows() if r[COL_PRODUCT_ID] in color_lookup)
else:
    color_lookup = {}
    train_has_color = 0

print(f"✓ Color features: {len(color_feature_cols)} "
      f"(available for {train_has_color}/{len(train_data)} train samples)")

# ============================================================================
# STEP 7: BUILD FEATURE MATRICES
# ============================================================================

print("\n" + "="*70)
print("STEP 7: BUILD FEATURE MATRICES")
print("="*70)

def build_feature_matrix(data, bc_vec, desc_vec, color_cols, price_cols,
                         n_stores, include_colors=True):
    feature_names = []
    blocks = []

    # Price
    X_price = data[price_cols].fillna(0).values
    blocks.append(csr_matrix(X_price))
    feature_names.extend(price_cols)

    # Store (one-hot)
    store_enc = data['store_encoded'].values
    X_store = np.zeros((len(data), n_stores))
    for i, s in enumerate(store_enc):
        if s >= 0:
            X_store[i, s] = 1
        else:
            X_store[i, -1] = 1
    blocks.append(csr_matrix(X_store))
    feature_names.extend([f'store_{i}' for i in range(n_stores)])

    # Breadcrumb TF-IDF
    X_bc = bc_vec.transform(data['breadcrumb_clean'])
    blocks.append(X_bc)
    feature_names.extend([f'bc_{f}' for f in bc_vec.get_feature_names_out()])

    # Description TF-IDF
    X_desc = desc_vec.transform(data['description_clean'])
    blocks.append(X_desc)
    feature_names.extend([f'desc_{f}' for f in desc_vec.get_feature_names_out()])

    # Color (optional)
    if include_colors:
        X_color = data[color_cols].values
        blocks.append(csr_matrix(X_color))
        feature_names.extend(color_cols)

    X = hstack(blocks)
    return X, feature_names

X_train, feature_names = build_feature_matrix(
    train_data, breadcrumb_vectorizer, description_vectorizer,
    color_feature_cols, price_features, n_stores, include_colors=True
)
X_test, _ = build_feature_matrix(
    test_data, breadcrumb_vectorizer, description_vectorizer,
    color_feature_cols, price_features, n_stores, include_colors=True
)

y_train = train_data['target'].values
y_test = test_data['target'].values

print(f"X_train: {X_train.shape}")
print(f"X_test: {X_test.shape}")
print(f"Features: {len(feature_names)}")
assert X_train.shape[1] == len(feature_names)
print(f"✓ Feature alignment validated")

# ============================================================================
# STEP 8: TRAIN MODELS
# ============================================================================

print("\n" + "="*70)
print("STEP 8: TRAIN MODELS")
print("="*70)

results = []

# --- L1 Logistic Regression ---
print("\n--- Logistic Regression (L1) ---")
model_l1 = LogisticRegressionCV(
    cv=CV_FOLDS, penalty='l1', solver='saga', max_iter=2000,
    multi_class='multinomial', class_weight='balanced', random_state=RANDOM_STATE
)
model_l1.fit(X_train, y_train)
y_pred_l1 = model_l1.predict(X_test)
acc_l1 = accuracy_score(y_test, y_pred_l1)
f1_l1 = f1_score(y_test, y_pred_l1, average='weighted')
print(f"Accuracy: {acc_l1:.4f}, F1: {f1_l1:.4f}")
results.append({'Model': 'L1 (LASSO)', 'Accuracy': acc_l1, 'F1_weighted': f1_l1})

# --- L2 Logistic Regression ---
print("\n--- Logistic Regression (L2) ---")
model_l2 = LogisticRegressionCV(
    cv=CV_FOLDS, penalty='l2', solver='lbfgs', max_iter=2000,
    multi_class='multinomial', class_weight='balanced', random_state=RANDOM_STATE
)
model_l2.fit(X_train, y_train)
y_pred_l2 = model_l2.predict(X_test)
acc_l2 = accuracy_score(y_test, y_pred_l2)
f1_l2 = f1_score(y_test, y_pred_l2, average='weighted')
print(f"Accuracy: {acc_l2:.4f}, F1: {f1_l2:.4f}")
results.append({'Model': 'L2 (Ridge)', 'Accuracy': acc_l2, 'F1_weighted': f1_l2})

# --- Random Forest ---
print("\n--- Random Forest ---")
model_rf = RandomForestClassifier(
    n_estimators=200, max_depth=15, min_samples_split=10, min_samples_leaf=5,
    class_weight='balanced', random_state=RANDOM_STATE, n_jobs=-1
)
model_rf.fit(X_train, y_train)
y_pred_rf = model_rf.predict(X_test)
acc_rf = accuracy_score(y_test, y_pred_rf)
f1_rf = f1_score(y_test, y_pred_rf, average='weighted')
print(f"Accuracy: {acc_rf:.4f}, F1: {f1_rf:.4f}")
results.append({'Model': 'Random Forest', 'Accuracy': acc_rf, 'F1_weighted': f1_rf})

# --- Histogram Gradient Boosting ---
print("\n--- Histogram Gradient Boosting ---")
MAX_HGB_SAMPLES = 50000
if X_train.shape[0] > MAX_HGB_SAMPLES:
    sample_idx = np.random.choice(X_train.shape[0], MAX_HGB_SAMPLES, replace=False)
    X_train_hgb = X_train[sample_idx].toarray()
    y_train_hgb = y_train[sample_idx]
else:
    X_train_hgb = X_train.toarray()
    y_train_hgb = y_train

model_hgb = HistGradientBoostingClassifier(
    max_iter=200, max_depth=10, learning_rate=0.1, random_state=RANDOM_STATE
)
model_hgb.fit(X_train_hgb, y_train_hgb)
y_pred_hgb = model_hgb.predict(X_test.toarray())
acc_hgb = accuracy_score(y_test, y_pred_hgb)
f1_hgb = f1_score(y_test, y_pred_hgb, average='weighted')
print(f"Accuracy: {acc_hgb:.4f}, F1: {f1_hgb:.4f}")
results.append({'Model': 'Hist Gradient Boosting', 'Accuracy': acc_hgb, 'F1_weighted': f1_hgb})

# --- SVM ---
print("\n--- SVM (RBF) ---")
scaler = StandardScaler(with_mean=False)
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

model_svm = SVC(
    kernel='rbf', C=1.0, gamma='scale', class_weight='balanced',
    probability=True, random_state=RANDOM_STATE
)
model_svm.fit(X_train_scaled, y_train)
y_pred_svm = model_svm.predict(X_test_scaled)
acc_svm = accuracy_score(y_test, y_pred_svm)
f1_svm = f1_score(y_test, y_pred_svm, average='weighted')
print(f"Accuracy: {acc_svm:.4f}, F1: {f1_svm:.4f}")
results.append({'Model': 'SVM', 'Accuracy': acc_svm, 'F1_weighted': f1_svm})

# Results table
results_df = pd.DataFrame(results).sort_values('F1_weighted', ascending=False)
print("\n" + "="*70)
print("MODEL COMPARISON")
print("="*70)
print(results_df.to_string(index=False))
results_df.to_csv(OUTPUT_DIR / 'model_comparison.csv', index=False)

# ============================================================================
# STEP 9: BEST MODEL ANALYSIS
# ============================================================================

print("\n" + "="*70)
print("BEST MODEL ANALYSIS")
print("="*70)

best_name = results_df.iloc[0]['Model']
print(f"Best model: {best_name}")

model_map = {
    'L1': (model_l1, y_pred_l1),
    'L2': (model_l2, y_pred_l2),
    'Random': (model_rf, y_pred_rf),
    'Hist': (model_hgb, y_pred_hgb),
    'SVM': (model_svm, y_pred_svm),
}
for key, (model, preds) in model_map.items():
    if key in best_name:
        best_model = model
        y_pred_best = preds
        break

if len(np.unique(y_test)) == 3 and len(np.unique(y_pred_best)) > 0:
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred_best, target_names=['female', 'male', 'none']))

print("\nConfusion Matrix:")
cm = confusion_matrix(y_test, y_pred_best, labels=[0, 1, 2])
print(f"            Predicted")
print(f"            female  male  none")
for i, label in enumerate(['female', 'male', 'none']):
    row = cm[i] if i < len(cm) else [0, 0, 0]
    print(f"Actual {label:6s}  {row[0]:4d}  {row[1]:4d}  {row[2]:4d}")

# ============================================================================
# STEP 10: FEATURE IMPORTANCE
# ============================================================================

print("\n" + "="*70)
print("FEATURE IMPORTANCE (L1 Model)")
print("="*70)

importance_df = pd.DataFrame({
    'feature': feature_names,
    'coef_female': model_l1.coef_[0],
    'coef_male': model_l1.coef_[1],
    'coef_none': model_l1.coef_[2]
})
importance_df['max_abs'] = importance_df[['coef_female', 'coef_male', 'coef_none']].abs().max(axis=1)
importance_df = importance_df.sort_values('max_abs', ascending=False)

for label, col in [('FEMALE', 'coef_female'), ('MALE', 'coef_male'), ('NONE', 'coef_none')]:
    top = importance_df[importance_df[col] > 0].nlargest(15, col)
    print(f"\nTop 15 {label} features:")
    for _, row in top.iterrows():
        print(f"  {row['feature']:40s}: {row[col]:+.4f}")

importance_df.to_csv(OUTPUT_DIR / 'feature_importance.csv', index=False)

# ============================================================================
# STEP 11: PREDICT ON ALL PRODUCTS
# ============================================================================

print("\n" + "="*70)
print("STEP 11: PREDICT ON ALL PRODUCTS")
print("="*70)

X_all, _ = build_feature_matrix(
    df, breadcrumb_vectorizer, description_vectorizer,
    color_feature_cols, price_features, n_stores, include_colors=True
)
print(f"Full dataset: {X_all.shape}")

df['ml_prob_female'] = model_l1.predict_proba(X_all)[:, 0]
df['ml_prob_male'] = model_l1.predict_proba(X_all)[:, 1]
df['ml_prob_none'] = model_l1.predict_proba(X_all)[:, 2]
df['ml_pred'] = model_l1.predict(X_all)
df['ml_pred_label'] = df['ml_pred'].map({0: 'female', 1: 'male', 2: 'none'})
df['ml_confidence'] = df[['ml_prob_female', 'ml_prob_male', 'ml_prob_none']].max(axis=1)

print(f"\nPrediction distribution:")
print(df['ml_pred_label'].value_counts())

# ============================================================================
# STEP 12: COLOR FEATURE IMPACT
# ============================================================================

print("\n" + "="*70)
print("STEP 12: COLOR FEATURE IMPACT ANALYSIS")
print("="*70)

X_train_no_color, features_no_color = build_feature_matrix(
    train_data, breadcrumb_vectorizer, description_vectorizer,
    color_feature_cols, price_features, n_stores, include_colors=False
)
X_test_no_color, _ = build_feature_matrix(
    test_data, breadcrumb_vectorizer, description_vectorizer,
    color_feature_cols, price_features, n_stores, include_colors=False
)

model_no_color = LogisticRegressionCV(
    cv=CV_FOLDS, penalty='l1', solver='saga', max_iter=2000,
    multi_class='multinomial', class_weight='balanced', random_state=RANDOM_STATE
)
model_no_color.fit(X_train_no_color, y_train)
y_pred_no_color = model_no_color.predict(X_test_no_color)

acc_no_color = accuracy_score(y_test, y_pred_no_color)
f1_no_color = f1_score(y_test, y_pred_no_color, average='weighted')

print(f"WITH colors:    Accuracy={acc_l1:.4f}, F1={f1_l1:.4f}")
print(f"WITHOUT colors: Accuracy={acc_no_color:.4f}, F1={f1_no_color:.4f}")
print(f"Color impact:   Accuracy {'+' if acc_l1 > acc_no_color else ''}"
      f"{(acc_l1-acc_no_color)*100:.2f}pp, "
      f"F1 {'+' if f1_l1 > f1_no_color else ''}{(f1_l1-f1_no_color)*100:.2f}pp")

# Subset analysis: only products that actually have color data
if len(color_lookup) > 0:
    test_with_color = test_data[test_data[COL_PRODUCT_ID].isin(color_lookup)]

    if len(test_with_color) >= MIN_TEST_SAMPLES:
        print(f"\nOn color-available subset ({len(test_with_color)} test samples):")

        X_sub_with, _ = build_feature_matrix(
            test_with_color, breadcrumb_vectorizer, description_vectorizer,
            color_feature_cols, price_features, n_stores, include_colors=True
        )
        X_sub_without, _ = build_feature_matrix(
            test_with_color, breadcrumb_vectorizer, description_vectorizer,
            color_feature_cols, price_features, n_stores, include_colors=False
        )
        y_sub = test_with_color['target'].values

        acc_sub_with = accuracy_score(y_sub, model_l1.predict(X_sub_with))
        acc_sub_without = accuracy_score(y_sub, model_no_color.predict(X_sub_without))
        print(f"  WITH colors:    Accuracy={acc_sub_with:.4f}")
        print(f"  WITHOUT colors: Accuracy={acc_sub_without:.4f}")
    else:
        print(f"\n⚠ Too few test samples with color data ({len(test_with_color)})")

# ============================================================================
# STEP 13: VALIDATION VS HUMAN LABELS
# ============================================================================

print("\n" + "="*70)
print("STEP 13: VALIDATION VS HUMAN LABELS")
print("="*70)

human_labeled = df[df['label_human'].notna()].copy()
print(f"Products with human labels: {len(human_labeled)}")

if len(human_labeled) > 0:
    human_labeled['human_encoded'] = human_labeled['label_human'].map({
        'female': 0, 'male': 1, 'none': 2
    })
    valid = human_labeled[human_labeled['human_encoded'].notna()]

    if len(valid) >= MIN_TEST_SAMPLES:
        acc_human = accuracy_score(valid['human_encoded'], valid['ml_pred'])
        print(f"Accuracy vs human: {acc_human:.4f}")

        if len(valid['human_encoded'].unique()) >= 2:
            print("\nConfusion (ML vs Human):")
            cm_h = confusion_matrix(valid['human_encoded'], valid['ml_pred'], labels=[0, 1, 2])
            print(f"            ML Predicted")
            print(f"            female  male  none")
            for i, label in enumerate(['female', 'male', 'none']):
                print(f"Human {label:6s}  {cm_h[i,0]:4d}  {cm_h[i,1]:4d}  {cm_h[i,2]:4d}")
    else:
        print(f"⚠ Too few validated samples ({len(valid)})")
else:
    print("⚠ No human-labeled data available")

# ============================================================================
# STEP 14: IMPLICIT GENDERING
# ============================================================================

print("\n" + "="*70)
print("STEP 14: IMPLICIT GENDERING")
print("="*70)

implicit_female = df[
    (df['label_extracted'] == 'none') &
    (df['ml_pred_label'] == 'female') &
    (df['ml_confidence'] > 0.5)
]
implicit_male = df[
    (df['label_extracted'] == 'none') &
    (df['ml_pred_label'] == 'male') &
    (df['ml_confidence'] > 0.5)
]
predicted_none = df[df['ml_pred_label'] == 'none']

print(f"Implicit female (>50% conf): {len(implicit_female):,}")
print(f"Implicit male (>50% conf): {len(implicit_male):,}")
print(f"Predicted none: {len(predicted_none):,}")

if len(implicit_female) > 0:
    print(f"\nSample implicit FEMALE:")
    for _, row in implicit_female.head(5).iterrows():
        print(f"  {row[COL_NAME][:60]}... (conf: {row['ml_confidence']:.2f})")

if len(implicit_male) > 0:
    print(f"\nSample implicit MALE:")
    for _, row in implicit_male.head(5).iterrows():
        print(f"  {row[COL_NAME][:60]}... (conf: {row['ml_confidence']:.2f})")

# ============================================================================
# STEP 15: EXPORT VALIDATION SAMPLE
# ============================================================================

print("\n" + "="*70)
print("STEP 15: EXPORT VALIDATION SAMPLE")
print("="*70)

already_labeled = set()
if COL_PRODUCT_ID in human_coded.columns:
    already_labeled.update(human_coded[COL_PRODUCT_ID].values)
if len(your_labeled) > 0 and COL_PRODUCT_ID in your_labeled.columns:
    already_labeled.update(your_labeled[COL_PRODUCT_ID].values)

available = df[
    (~df[COL_PRODUCT_ID].isin(already_labeled)) &
    (df[COL_IMAGE].notna())
].copy()

print(f"Available for validation: {len(available):,}")

if len(available) == 0:
    print("⚠ No products available for validation")
    validation_export = pd.DataFrame()
else:
    N_PER = 85
    samples = []
    for pred, label in [(0, 'female'), (1, 'male'), (2, 'none')]:
        pool = available[available['ml_pred'] == pred]
        n = min(N_PER, len(pool))
        if n > 0:
            samples.append(pool.sample(n=n, random_state=RANDOM_STATE))
            print(f"  Sampled {n} {label}")

    if samples:
        validation = pd.concat(samples).sample(frac=1, random_state=RANDOM_STATE)

        export_cols = [
            COL_PRODUCT_ID, COL_NAME, COL_DESC, COL_BREADCRUMB, COL_IMAGE,
            COL_URL, COL_PRICE, 'label_extracted', 'ml_pred_label',
            'ml_prob_female', 'ml_prob_male', 'ml_prob_none', 'ml_confidence'
        ]
        export_cols = [c for c in export_cols if c in validation.columns]

        validation_export = validation[export_cols].copy()
        validation_export['manual_gender'] = ''
        validation_export['manual_confidence'] = ''
        validation_export['manual_notes'] = ''

        output_file = OUTPUT_DIR / 'validation_sample.csv'
        validation_export.to_csv(output_file, index=True)
        print(f"\n✓ Saved: {output_file} ({len(validation_export)} products)")

# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "="*70)
print("PIPELINE SUMMARY (v4)")
print("="*70)

print(f"""
DATA:
  Original: {original_count:,}
  After filtering: {len(df):,}
  Excluded categories: {excluded_count:,}

TRAINING:
  Total samples: {len(ml_data):,} (balanced 3-class)
  Train: {len(train_data):,}
  Test: {len(test_data):,}
  Products with colors: {len(color_df):,} (Morrisons only — Tesco/ASDA CDN links expired)

BEST MODEL: {best_name} (F1: {results_df.iloc[0]['F1_weighted']:.4f})

COLOR IMPACT:
  With colors:    F1={f1_l1:.4f}
  Without colors: F1={f1_no_color:.4f}

PREDICTIONS:
  Female: {(df['ml_pred_label'] == 'female').sum():,}
  Male: {(df['ml_pred_label'] == 'male').sum():,}
  None: {(df['ml_pred_label'] == 'none').sum():,}

IMPLICIT GENDERING:
  Implicit female: {len(implicit_female):,}
  Implicit male: {len(implicit_male):,}

OUTPUT:
  {OUTPUT_DIR}/
  ├── model_comparison.csv
  ├── feature_importance.csv
  └── validation_sample.csv

NOTE: Color features are limited to Morrisons products (~5,600 of 12,800).
      Tesco and ASDA image CDN links are expired. Color-based findings
      should be interpreted with this coverage limitation in mind.
""")

summary = {
    'version': '4.0',
    'data': {
        'original': original_count,
        'filtered': int(len(df)),
        'excluded': int(excluded_count),
        'training_samples': int(len(ml_data)),
        'color_samples': int(len(color_df)),
        'color_coverage_note': 'Morrisons only; Tesco/ASDA CDN links expired'
    },
    'models': results,
    'color_impact': {
        'with_colors_f1': float(f1_l1),
        'without_colors_f1': float(f1_no_color)
    },
    'predictions': {
        'female': int((df['ml_pred_label'] == 'female').sum()),
        'male': int((df['ml_pred_label'] == 'male').sum()),
        'none': int((df['ml_pred _label'] == 'none').sum())
    }
}

with open(OUTPUT_DIR / 'summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f"✓ Pipeline complete")

✓ Imports complete
✓ Output directory: /Users/leoss/Desktop/Portfolio/Website-/UK pink tax/Outputs/charts/ml_pipeline_v4

LOADING DATA
✓ Main dataset: 21,436 products
✓ Human-coded: 259 products
✓ Your labeled: 44 products

STEP 1: FILTER NON-GENDERED CATEGORIES
Original: 21,436
Excluded: 8,604
Remaining: 12,832

STEP 2: EXTRACT GENDER LABELS
Label distribution:
{'none': 10913, 'female': 1075, 'male': 844}
Human labels merged: 200

STEP 3: LOAD COLOR DATA
✓ Loaded color cache: 5,617 products
  Gender split: {'none': 4746, 'female': 460, 'male': 411}
  Matched to filtered data: 5,617 / 5,617

STEP 4: PREPARE TRAINING DATA
Explicitly female: 1075
Explicitly male: 844

None class sources:
  Human-coded none: 48
  Extracted none (unlabeled): 10849
  Final none class: 844

Balancing to: 844 per class

✓ Training data: 2532 products
  Class distribution: {0: 850, 1: 845, 2: 837}
  (0=female, 1=male, 2=none)

STEP 5: TRAIN-TEST SPLIT
Train: 1905
Test: 639
Train distribution: {0: 643, 1: 634, 

In [9]:
# ============================================================================
# PINK TAX REGRESSION ANALYSIS (v2)
# ============================================================================
#
# Improvements over v1:
#   - Fixed breadcrumb parsing (was printing dicts)
#   - Added quantile regression (10th, 25th, 50th, 75th, 90th percentiles)
#   - Added female × store interaction
#   - Added bootstrap CIs for within-category gaps
#   - Category matching at multiple granularity levels
#   - Trimmed outlier categories from within-category averages
#   - 10+ visualisations
# ============================================================================

from pathlib import Path
import pandas as pd
import numpy as np
import re
import json
import warnings

import statsmodels.api as sm
import statsmodels.formula.api as smf
from sklearn.feature_extraction.text import TfidfVectorizer

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import FancyBboxPatch
import seaborn as sns

warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================

BASE_DIR = Path('/Users/leoss')
DATA_DIR = BASE_DIR / 'Downloads'
OUTPUT_BASE = BASE_DIR / 'Desktop/Portfolio/Website-/UK pink tax/Outputs'

PATH_MAIN_DATA = DATA_DIR / 'items_fin.csv'
OUTPUT_DIR = OUTPUT_BASE / 'charts/pink_tax_regression_v2'
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
N_BOOTSTRAP = 1000

EXCLUDE_CATEGORIES = [
    'food', 'grocery', 'groceries', 'snacks', 'drinks', 'beverages',
    'pet food', 'pet supplies', 'cleaning', 'household', 'kitchen',
    'office', 'stationery', 'electronics', 'tech', 'garden', 'automotive'
]

FEMALE_KEYWORDS = ['women', 'woman', 'female', 'ladies', 'lady', 'girls',
                   'womens', "women's", 'femme', 'her', 'feminine', 'fem']
MALE_KEYWORDS = ['men', 'man', 'male', 'gentleman', 'gentlemen', 'boys',
                 'mens', "men's", 'homme', 'his', 'masculine']
ALL_GENDER_KEYWORDS = set(FEMALE_KEYWORDS + MALE_KEYWORDS)

COL_NAME = 'product_title_x'
COL_DESC = 'description'
COL_BREADCRUMB = 'standardized_breadcrumbs'
COL_PRICE = 'price'
COL_UNIT_PRICE = 'unit_price'
COL_STORE = 'store_id'
COL_PRODUCT_ID = 'product_id'

# Chart style
PALETTE = {'female': '#c44e52', 'male': '#4c72b0', 'none': '#8c8c8c'}
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': '#fafafa',
    'axes.grid': True,
    'grid.alpha': 0.3,
    'grid.linestyle': '--',
    'font.size': 10,
    'axes.titlesize': 12,
    'axes.labelsize': 10,
})

# ============================================================================
# LOAD AND PREPARE DATA
# ============================================================================

print("=" * 70)
print("PINK TAX REGRESSION ANALYSIS (v2)")
print("=" * 70)

df = pd.read_csv(PATH_MAIN_DATA, encoding='latin-1')
df.columns = df.columns.str.lower().str.strip().str.replace(' ', '_')
if 'unnamed:_0' in df.columns:
    df = df.drop(columns=['unnamed:_0'])
print(f"Loaded {len(df):,} products")

# Filter categories
def contains_excluded_category(text):
    if pd.isna(text):
        return False
    return any(cat in str(text).lower() for cat in EXCLUDE_CATEGORIES)

df = df[~df[COL_BREADCRUMB].apply(contains_excluded_category)].copy().reset_index(drop=True)
print(f"After category filter: {len(df):,}")

# Gender labels
def extract_gender(text):
    if pd.isna(text) or str(text).strip() == '':
        return 'none'
    text_lower = str(text).lower()
    has_f = any(re.search(r'\b' + kw + r'\b', text_lower) for kw in FEMALE_KEYWORDS)
    has_m = any(re.search(r'\b' + kw + r'\b', text_lower) for kw in MALE_KEYWORDS)
    if has_f and not has_m:
        return 'female'
    elif has_m and not has_f:
        return 'male'
    elif has_f and has_m:
        return 'both'
    return 'none'

for col_label, col_source in [('label_bc', COL_BREADCRUMB),
                               ('label_name', COL_NAME),
                               ('label_desc', COL_DESC)]:
    df[col_label] = df[col_source].apply(extract_gender)

def combine_labels(row):
    for col in ['label_bc', 'label_name', 'label_desc']:
        if row[col] in ['female', 'male']:
            return row[col]
    return 'none'

df['gender'] = df.apply(combine_labels, axis=1)

# Price
df['price_num'] = pd.to_numeric(df[COL_PRICE], errors='coerce')
df = df[df['price_num'].notna() & (df['price_num'] > 0)].copy()
df['log_price'] = np.log(df['price_num'])

# Unit price
if COL_UNIT_PRICE in df.columns:
    df['unit_price_num'] = df[COL_UNIT_PRICE].astype(str).str.extract(r'([\d.]+)')[0].astype(float)

# Store
df['store'] = df[COL_STORE].fillna('unknown').astype(str)

# Store name mapping (from store IDs)
store_names = {}
for sid in df['store'].unique():
    sub = df[df['store'] == sid]
    # Infer from breadcrumb or URL patterns
    bc_sample = sub[COL_BREADCRUMB].dropna().head(10).str.lower()
    if bc_sample.str.contains('morrisons').any():
        store_names[sid] = 'Morrisons'
    elif bc_sample.str.contains('tesco').any():
        store_names[sid] = 'Tesco'
    elif bc_sample.str.contains('asda').any():
        store_names[sid] = 'ASDA'
    else:
        store_names[sid] = f'Store {sid}'

df['store_name'] = df['store'].map(store_names)

# ---- Breadcrumb parsing (FIXED) ----
def parse_breadcrumb(text):
    """Extract clean category levels from breadcrumb string."""
    if pd.isna(text):
        return 'unknown', 'unknown', 'unknown'
    # Handle both " > " and " / " separators
    text = str(text).strip()
    if ' > ' in text:
        parts = [p.strip().lower() for p in text.split(' > ') if p.strip()]
    elif ' / ' in text:
        parts = [p.strip().lower() for p in text.split(' / ') if p.strip()]
    else:
        parts = [text.strip().lower()]

    # Some breadcrumbs start with store name — skip if it matches known stores
    known_stores = {'morrisons', 'tesco', 'asda', 'groceries', 'marketplace'}
    while parts and parts[0] in known_stores:
        parts = parts[1:]

    level1 = parts[0] if len(parts) > 0 else 'unknown'
    level2 = parts[1] if len(parts) > 1 else 'unknown'
    level3 = parts[2] if len(parts) > 2 else 'unknown'
    return level1, level2, level3

df[['cat1', 'cat2', 'cat3']] = df[COL_BREADCRUMB].apply(
    lambda x: pd.Series(parse_breadcrumb(x))
)

# Category labels at different granularities
df['cat_broad'] = df['cat1']
df['cat_mid'] = df['cat1'] + ' > ' + df['cat2']
df['cat_fine'] = df['cat1'] + ' > ' + df['cat2'] + ' > ' + df['cat3']

print(f"With valid prices: {len(df):,}")
print(f"\nGender distribution:")
for g in ['female', 'male', 'none', 'both']:
    sub = df[df['gender'] == g]
    if len(sub) > 0:
        print(f"  {g}: {len(sub):,}  (mean £{sub['price_num'].mean():.2f}, "
              f"median £{sub['price_num'].median():.2f})")

# ============================================================================
# ANALYSIS SAMPLE
# ============================================================================

print("\n" + "=" * 70)
print("ANALYSIS SAMPLE: FEMALE vs MALE")
print("=" * 70)

gendered = df[df['gender'].isin(['female', 'male'])].copy()
gendered['is_female'] = (gendered['gender'] == 'female').astype(int)

print(f"N = {len(gendered):,}  (F: {gendered['is_female'].sum()}, "
      f"M: {(1 - gendered['is_female']).sum()})")
print(f"Mean price  — F: £{gendered.loc[gendered['is_female']==1, 'price_num'].mean():.2f}, "
      f"M: £{gendered.loc[gendered['is_female']==0, 'price_num'].mean():.2f}")
print(f"Median price — F: £{gendered.loc[gendered['is_female']==1, 'price_num'].median():.2f}, "
      f"M: £{gendered.loc[gendered['is_female']==0, 'price_num'].median():.2f}")

# Categories with both genders
for level_name, level_col in [('broad', 'cat_broad'), ('mid', 'cat_mid'), ('fine', 'cat_fine')]:
    cats_with_both = 0
    for cat in gendered[level_col].unique():
        sub = gendered[gendered[level_col] == cat]
        if sub['is_female'].sum() >= 3 and (1 - sub['is_female']).sum() >= 3:
            cats_with_both += 1
    print(f"  Categories ({level_name}) with ≥3F + ≥3M: {cats_with_both}")

# ============================================================================
# REGRESSIONS
# ============================================================================

results_table = []

def run_and_record(name, formula_or_result, data=None, controls='', coef_name='is_female',
                   prefit=None):
    """Run OLS and record the female coefficient."""
    if prefit is not None:
        model = prefit
    else:
        model = smf.ols(formula_or_result, data=data).fit(cov_type='HC1')

    coef = model.params[coef_name]
    se = model.bse[coef_name]
    pval = model.pvalues[coef_name]
    pct = (np.exp(coef) - 1) * 100
    ci_lo = coef - 1.96 * se
    ci_hi = coef + 1.96 * se

    results_table.append({
        'spec': name,
        'coef': coef, 'se': se, 'p': pval,
        'pct': pct,
        'ci_lo': ci_lo, 'ci_hi': ci_hi,
        'pct_lo': (np.exp(ci_lo) - 1) * 100,
        'pct_hi': (np.exp(ci_hi) - 1) * 100,
        'r2': model.rsquared,
        'n': int(model.nobs),
        'controls': controls,
    })

    sig = '***' if pval < 0.01 else ('**' if pval < 0.05 else ('*' if pval < 0.1 else ''))
    print(f"  {name}: coef={coef:+.4f} ({pct:+.1f}%), SE={se:.4f}, "
          f"p={pval:.4f}{sig}, R²={model.rsquared:.3f}, N={int(model.nobs)}")
    return model

# --- Spec 1: Raw ---
print("\nSpec 1: Raw gap")
spec1 = run_and_record('(1) Raw gap',
                        'log_price ~ is_female', gendered, 'None')

# --- Spec 2: + Store FE ---
print("\nSpec 2: + Store FE")
spec2 = run_and_record('(2) + Store FE',
                        'log_price ~ is_female + C(store)', gendered, 'Store')

# --- Spec 3: + Broad category FE ---
# Filter to categories with ≥5 obs
cat_counts_broad = gendered['cat_broad'].value_counts()
valid_broad = cat_counts_broad[cat_counts_broad >= 5].index
gen_broad = gendered[gendered['cat_broad'].isin(valid_broad)].copy()

print(f"\nSpec 3: + Broad category FE (N cats: {len(valid_broad)})")
spec3 = run_and_record('(3) + Broad cat FE',
                        'log_price ~ is_female + C(store) + C(cat_broad)',
                        gen_broad, 'Store + Broad cat')

# --- Spec 4: + Mid category FE ---
cat_counts_mid = gendered['cat_mid'].value_counts()
valid_mid = cat_counts_mid[cat_counts_mid >= 5].index
gen_mid = gendered[gendered['cat_mid'].isin(valid_mid)].copy()

print(f"\nSpec 4: + Mid category FE (N cats: {len(valid_mid)})")
spec4 = run_and_record('(4) + Mid cat FE',
                        'log_price ~ is_female + C(store) + C(cat_mid)',
                        gen_mid, 'Store + Mid cat')

# --- Spec 5: + Fine category FE ---
cat_counts_fine = gendered['cat_fine'].value_counts()
valid_fine = cat_counts_fine[cat_counts_fine >= 5].index
gen_fine = gendered[gendered['cat_fine'].isin(valid_fine)].copy()

print(f"\nSpec 5: + Fine category FE (N cats: {len(valid_fine)})")
if len(valid_fine) > 0 and len(gen_fine) > 50:
    spec5 = run_and_record('(5) + Fine cat FE',
                            'log_price ~ is_female + C(store) + C(cat_fine)',
                            gen_fine, 'Store + Fine cat')
else:
    print("  Skipped (too few categories)")

# --- Spec 6: + Description TF-IDF ---
def clean_text_no_gender(text):
    if pd.isna(text):
        return ''
    text = str(text).lower()
    for w in ALL_GENDER_KEYWORDS:
        text = re.sub(r'\b' + w + r'\b', '', text)
    return re.sub(r'\s+', ' ', re.sub(r'[^a-z\s]', ' ', text)).strip()

gen_mid['desc_clean'] = gen_mid[COL_DESC].apply(clean_text_no_gender)

desc_vec = TfidfVectorizer(max_features=100, min_df=5, max_df=0.9,
                            ngram_range=(1, 2), stop_words='english')
X_desc = desc_vec.fit_transform(gen_mid['desc_clean'])
desc_names = [f'desc_{f}' for f in desc_vec.get_feature_names_out()]

y6 = gen_mid['log_price'].values
X6_parts = [
    gen_mid[['is_female']].values,
    pd.get_dummies(gen_mid['store'], prefix='store', drop_first=True).values,
    pd.get_dummies(gen_mid['cat_mid'], prefix='cat', drop_first=True).values,
    X_desc.toarray(),
]
X6 = sm.add_constant(np.hstack(X6_parts))

spec6_model = sm.OLS(y6, X6).fit(cov_type='HC1')
# is_female is column index 1 (after constant)
coef6 = spec6_model.params[1]
se6 = spec6_model.bse[1]
pval6 = spec6_model.pvalues[1]
results_table.append({
    'spec': '(6) + Description',
    'coef': coef6, 'se': se6, 'p': pval6,
    'pct': (np.exp(coef6) - 1) * 100,
    'ci_lo': coef6 - 1.96 * se6, 'ci_hi': coef6 + 1.96 * se6,
    'pct_lo': (np.exp(coef6 - 1.96 * se6) - 1) * 100,
    'pct_hi': (np.exp(coef6 + 1.96 * se6) - 1) * 100,
    'r2': spec6_model.rsquared, 'n': int(spec6_model.nobs),
    'controls': 'Store + Mid cat + Description',
})
sig6 = '***' if pval6 < 0.01 else ('**' if pval6 < 0.05 else ('*' if pval6 < 0.1 else ''))
print(f"\nSpec 6: + Description TF-IDF")
print(f"  (6) + Description: coef={coef6:+.4f} ({(np.exp(coef6)-1)*100:+.1f}%), "
      f"SE={se6:.4f}, p={pval6:.4f}{sig6}, R²={spec6_model.rsquared:.3f}, N={int(spec6_model.nobs)}")

# --- Spec 7: Female × Store interaction ---
print(f"\nSpec 7: Female × Store interaction")
spec7 = smf.ols('log_price ~ is_female * C(store) + C(cat_mid)',
                 data=gen_mid).fit(cov_type='HC1')

# Extract female coefficient and interaction terms
print(f"  Main effect (is_female): {spec7.params['is_female']:+.4f} "
      f"(p={spec7.pvalues['is_female']:.4f})")
for param in spec7.params.index:
    if 'is_female:' in param:
        store_id = param.split('[T.')[1].rstrip(']')
        sname = store_names.get(store_id, store_id)
        total_effect = spec7.params['is_female'] + spec7.params[param]
        pct_effect = (np.exp(total_effect) - 1) * 100
        print(f"  {sname}: total female effect = {total_effect:+.4f} ({pct_effect:+.1f}%), "
              f"interaction p={spec7.pvalues[param]:.4f}")

# --- Spec 8: Unit price ---
print(f"\nSpec 8: Unit price regression")
if 'unit_price_num' in gen_mid.columns:
    gen_unit = gen_mid[gen_mid['unit_price_num'].notna() & (gen_mid['unit_price_num'] > 0)].copy()
    gen_unit['log_unit_price'] = np.log(gen_unit['unit_price_num'])
    if len(gen_unit) >= 50:
        spec8 = run_and_record('(8) Unit price',
                                'log_unit_price ~ is_female + C(store) + C(cat_mid)',
                                gen_unit, 'Store + Mid cat (unit price)',
                                coef_name='is_female')

# --- Spec 9: Three-way (female, male, none) ---
print(f"\nSpec 9: Three-way comparison (none = reference)")
df_valid = df.copy()
df_valid['is_female'] = (df_valid['gender'] == 'female').astype(int)
df_valid['is_male'] = (df_valid['gender'] == 'male').astype(int)

spec9 = smf.ols('log_price ~ is_female + is_male + C(store)',
                 data=df_valid).fit(cov_type='HC1')
print(f"  is_female: {spec9.params['is_female']:+.4f} (p={spec9.pvalues['is_female']:.4f})")
print(f"  is_male:   {spec9.params['is_male']:+.4f} (p={spec9.pvalues['is_male']:.4f})")
f_vs_m = spec9.params['is_female'] - spec9.params['is_male']
print(f"  F vs M:    {(np.exp(f_vs_m)-1)*100:+.1f}%")

# ============================================================================
# QUANTILE REGRESSION
# ============================================================================

print("\n" + "=" * 70)
print("QUANTILE REGRESSION")
print("=" * 70)

quantiles = [0.10, 0.25, 0.50, 0.75, 0.90]
qreg_results = []

for q in quantiles:
    qmodel = smf.quantreg('log_price ~ is_female + C(store)', data=gendered).fit(q=q)
    coef_q = qmodel.params['is_female']
    se_q = qmodel.bse['is_female']
    pval_q = qmodel.pvalues['is_female']
    pct_q = (np.exp(coef_q) - 1) * 100

    qreg_results.append({
        'quantile': q,
        'coef': coef_q,
        'se': se_q,
        'p': pval_q,
        'pct': pct_q,
        'ci_lo': coef_q - 1.96 * se_q,
        'ci_hi': coef_q + 1.96 * se_q,
    })

    sig = '***' if pval_q < 0.01 else ('**' if pval_q < 0.05 else ('*' if pval_q < 0.1 else ''))
    print(f"  Q{q:.2f}: coef={coef_q:+.4f} ({pct_q:+.1f}%), p={pval_q:.4f}{sig}")

qreg_df = pd.DataFrame(qreg_results)

# With category controls
print("\n  With mid-category controls:")
qreg_cat_results = []
for q in quantiles:
    try:
        qmodel = smf.quantreg('log_price ~ is_female + C(store) + C(cat_mid)',
                                data=gen_mid).fit(q=q, max_iter=5000)
        coef_q = qmodel.params['is_female']
        se_q = qmodel.bse['is_female']
        pval_q = qmodel.pvalues['is_female']
        pct_q = (np.exp(coef_q) - 1) * 100
        qreg_cat_results.append({
            'quantile': q, 'coef': coef_q, 'se': se_q,
            'p': pval_q, 'pct': pct_q,
            'ci_lo': coef_q - 1.96 * se_q,
            'ci_hi': coef_q + 1.96 * se_q,
        })
        sig = '***' if pval_q < 0.01 else ('**' if pval_q < 0.05 else ('*' if pval_q < 0.1 else ''))
        print(f"  Q{q:.2f}: coef={coef_q:+.4f} ({pct_q:+.1f}%), p={pval_q:.4f}{sig}")
    except Exception as e:
        print(f"  Q{q:.2f}: failed ({e})")

qreg_cat_df = pd.DataFrame(qreg_cat_results) if qreg_cat_results else pd.DataFrame()

# ============================================================================
# WITHIN-CATEGORY ANALYSIS (with bootstrap CIs)
# ============================================================================

print("\n" + "=" * 70)
print("WITHIN-CATEGORY ANALYSIS (bootstrap CIs)")
print("=" * 70)

rng = np.random.default_rng(RANDOM_STATE)
category_gaps = []

for cat in gendered['cat_mid'].unique():
    sub = gendered[gendered['cat_mid'] == cat]
    fem = sub[sub['is_female'] == 1]['price_num']
    mal = sub[sub['is_female'] == 0]['price_num']

    if len(fem) >= 3 and len(mal) >= 3:
        gap_pct = (fem.mean() / mal.mean() - 1) * 100
        log_gap = np.log(fem).mean() - np.log(mal).mean()

        # Bootstrap CI
        boot_gaps = []
        for _ in range(N_BOOTSTRAP):
            f_boot = rng.choice(fem.values, size=len(fem), replace=True)
            m_boot = rng.choice(mal.values, size=len(mal), replace=True)
            if m_boot.mean() > 0:
                boot_gaps.append((f_boot.mean() / m_boot.mean() - 1) * 100)

        ci_lo = np.percentile(boot_gaps, 2.5) if boot_gaps else np.nan
        ci_hi = np.percentile(boot_gaps, 97.5) if boot_gaps else np.nan

        category_gaps.append({
            'category': cat,
            'n_female': len(fem),
            'n_male': len(mal),
            'n_total': len(fem) + len(mal),
            'mean_female': fem.mean(),
            'mean_male': mal.mean(),
            'gap_pct': gap_pct,
            'log_gap': log_gap,
            'ci_lo': ci_lo,
            'ci_hi': ci_hi,
            'significant': (ci_lo > 0 and ci_hi > 0) or (ci_lo < 0 and ci_hi < 0),
        })

gaps_df = pd.DataFrame(category_gaps).sort_values('gap_pct', ascending=False)

if len(gaps_df) > 0:
    print(f"Categories with both genders (≥3 each): {len(gaps_df)}")

    # Trimmed averages (drop top/bottom outlier)
    gaps_trimmed = gaps_df.copy()
    if len(gaps_trimmed) > 4:
        gaps_trimmed = gaps_trimmed.iloc[1:-1]  # drop most extreme

    weighted_gap = np.average(gaps_df['gap_pct'], weights=gaps_df['n_total'])
    trimmed_mean = gaps_trimmed['gap_pct'].mean()
    median_gap = gaps_df['gap_pct'].median()

    print(f"\nWithin-category price gap (F vs M):")
    print(f"  Weighted mean: {weighted_gap:+.1f}%")
    print(f"  Trimmed mean:  {trimmed_mean:+.1f}%")
    print(f"  Median:        {median_gap:+.1f}%")

    print(f"\n  {'category':<45s} {'F':>3s} {'M':>3s} {'gap':>8s} {'95% CI':>16s} {'sig':>4s}")
    print(f"  {'-'*80}")
    for _, row in gaps_df.iterrows():
        sig = '*' if row['significant'] else ''
        print(f"  {row['category'][:45]:<45s} {row['n_female']:>3.0f} {row['n_male']:>3.0f} "
              f"{row['gap_pct']:>+7.1f}% [{row['ci_lo']:>+6.1f}, {row['ci_hi']:>+6.1f}] {sig:>3s}")

    gaps_df.to_csv(OUTPUT_DIR / 'within_category_gaps.csv', index=False)

# ============================================================================
# BY-STORE ANALYSIS
# ============================================================================

print("\n" + "=" * 70)
print("PINK TAX BY STORE")
print("=" * 70)

store_results = []
for store_id in gendered['store'].unique():
    sub = gendered[gendered['store'] == store_id]
    fem = sub[sub['is_female'] == 1]
    mal = sub[sub['is_female'] == 0]
    sname = store_names.get(store_id, store_id)

    if len(fem) >= 10 and len(mal) >= 10:
        model = smf.ols('log_price ~ is_female', data=sub).fit(cov_type='HC1')
        coef = model.params['is_female']
        pval = model.pvalues['is_female']
        store_results.append({
            'store': sname,
            'store_id': store_id,
            'n_female': len(fem), 'n_male': len(mal),
            'mean_f': fem['price_num'].mean(),
            'mean_m': mal['price_num'].mean(),
            'coef': coef,
            'pct_gap': (np.exp(coef) - 1) * 100,
            'p_value': pval,
            'significant': pval < 0.05,
        })

store_df = pd.DataFrame(store_results).sort_values('pct_gap', ascending=False)

if len(store_df) > 0:
    print(f"{'store':<15s} {'F':>5s} {'M':>5s} {'mean F':>8s} {'mean M':>8s} {'gap':>8s} {'p':>8s}")
    print(f"{'-'*65}")
    for _, row in store_df.iterrows():
        sig = '***' if row['p_value'] < 0.01 else ('**' if row['p_value'] < 0.05 else
              ('*' if row['p_value'] < 0.1 else ''))
        print(f"{row['store']:<15s} {row['n_female']:>5.0f} {row['n_male']:>5.0f} "
              f"£{row['mean_f']:>6.2f} £{row['mean_m']:>6.2f} "
              f"{row['pct_gap']:>+7.1f}% {row['p_value']:>7.4f}{sig}")

    store_df.to_csv(OUTPUT_DIR / 'pink_tax_by_store.csv', index=False)

# ============================================================================
# RESULTS SUMMARY TABLE
# ============================================================================

print("\n" + "=" * 70)
print("REGRESSION SUMMARY")
print("=" * 70)

summary_df = pd.DataFrame(results_table)
print()
print(f"{'Spec':<25s} {'Coef':>8s} {'%gap':>8s} {'95% CI':>18s} {'p':>8s} {'R²':>6s} {'N':>6s}")
print(f"{'-'*80}")
for _, row in summary_df.iterrows():
    sig = '***' if row['p'] < 0.01 else ('**' if row['p'] < 0.05 else ('*' if row['p'] < 0.1 else ''))
    print(f"{row['spec']:<25s} {row['coef']:>+7.4f} {row['pct']:>+7.1f}% "
          f"[{row['pct_lo']:>+6.1f}, {row['pct_hi']:>+6.1f}] "
          f"{row['p']:>7.4f}{sig:<3s} {row['r2']:>5.3f} {row['n']:>6.0f}")

summary_df.to_csv(OUTPUT_DIR / 'regression_summary.csv', index=False)

# ============================================================================
# VISUALISATIONS
# ============================================================================

print("\n" + "=" * 70)
print("GENERATING CHARTS")
print("=" * 70)

# ---- 1. Coefficient plot across specifications ----
fig, ax = plt.subplots(figsize=(9, 5))

specs = summary_df['spec'].values
coefs = summary_df['coef'].values
ci_los = summary_df['ci_lo'].values
ci_his = summary_df['ci_hi'].values

y_pos = np.arange(len(specs))
xerr_lo = coefs - ci_los
xerr_hi = ci_his - coefs

colors = [PALETTE['female'] if c > 0 else PALETTE['male'] for c in coefs]

ax.barh(y_pos, coefs, color=colors, alpha=0.7, edgecolor='black', linewidth=0.5, height=0.6)
ax.errorbar(coefs, y_pos, xerr=[xerr_lo, xerr_hi], fmt='none', color='black',
            capsize=4, linewidth=1.2)
ax.axvline(x=0, color='black', linewidth=1, linestyle='-')
ax.set_yticks(y_pos)
ax.set_yticklabels(specs)
ax.set_xlabel('Coefficient on female indicator (log price)')
ax.set_title('Female price premium: sensitivity to controls')
ax.invert_yaxis()

for i, (c, p) in enumerate(zip(coefs, summary_df['pct'].values)):
    offset = max(abs(ci_his[i]), abs(ci_los[i])) + 0.01
    ax.text(ci_his[i] + 0.01, i, f'{p:+.1f}%', va='center', ha='left', fontsize=9,
            fontweight='bold' if summary_df.iloc[i]['p'] < 0.05 else 'normal')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / '01_coefficient_plot.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 01_coefficient_plot.png")

# ---- 2. Quantile regression ----
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Without category controls
ax = axes[0]
qs = qreg_df['quantile']
ax.fill_between(qs, qreg_df['ci_lo'], qreg_df['ci_hi'], alpha=0.2, color=PALETTE['female'])
ax.plot(qs, qreg_df['coef'], 'o-', color=PALETTE['female'], linewidth=2, markersize=6)
ax.axhline(y=0, color='black', linewidth=0.8, linestyle='--')
ax.axhline(y=spec1.params['is_female'], color='gray', linewidth=0.8, linestyle=':',
           label=f'OLS mean ({(np.exp(spec1.params["is_female"])-1)*100:+.1f}%)')
ax.set_xlabel('Quantile')
ax.set_ylabel('Coefficient on female (log price)')
ax.set_title('Quantile regression: store controls only')
ax.set_xticks(quantiles)
ax.legend(fontsize=9)

# With category controls
if len(qreg_cat_df) > 0:
    ax = axes[1]
    qs_c = qreg_cat_df['quantile']
    ax.fill_between(qs_c, qreg_cat_df['ci_lo'], qreg_cat_df['ci_hi'],
                    alpha=0.2, color=PALETTE['female'])
    ax.plot(qs_c, qreg_cat_df['coef'], 'o-', color=PALETTE['female'], linewidth=2, markersize=6)
    ax.axhline(y=0, color='black', linewidth=0.8, linestyle='--')
    ax.set_xlabel('Quantile')
    ax.set_ylabel('Coefficient on female (log price)')
    ax.set_title('Quantile regression: store + category controls')
    ax.set_xticks(quantiles)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / '02_quantile_regression.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 02_quantile_regression.png")

# ---- 3. Within-category gaps with CIs ----
if len(gaps_df) > 0:
    fig, ax = plt.subplots(figsize=(10, max(5, len(gaps_df) * 0.5)))

    gaps_sorted = gaps_df.sort_values('gap_pct')
    y_pos = np.arange(len(gaps_sorted))

    colors = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in gaps_sorted['gap_pct']]
    edge = ['black' if s else 'gray' for s in gaps_sorted['significant']]

    bars = ax.barh(y_pos, gaps_sorted['gap_pct'], color=colors, alpha=0.6, height=0.7)
    for bar, ec in zip(bars, edge):
        bar.set_edgecolor(ec)
        bar.set_linewidth(1 if ec == 'black' else 0.3)

    # CI whiskers
    ax.errorbar(gaps_sorted['gap_pct'].values, y_pos,
                xerr=[gaps_sorted['gap_pct'].values - gaps_sorted['ci_lo'].values,
                      gaps_sorted['ci_hi'].values - gaps_sorted['gap_pct'].values],
                fmt='none', color='black', capsize=3, linewidth=0.8)

    ax.axvline(x=0, color='black', linewidth=1)
    ax.set_yticks(y_pos)
    labels = [f"{r['category'][:42]} (F:{r['n_female']:.0f}, M:{r['n_male']:.0f})"
              for _, r in gaps_sorted.iterrows()]
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel('Female price premium (%)')
    ax.set_title('Within-category price gaps with 95% bootstrap CIs')

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / '03_within_category_gaps.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("✓ 03_within_category_gaps.png")

# ---- 4. Price distributions ----
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 4a: Histogram
ax = axes[0, 0]
for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
    sub = gendered[gendered['gender'] == gender]
    ax.hist(sub['price_num'], bins=50, alpha=0.5, color=color,
            label=f'{gender.title()} (n={len(sub)})', density=True)
ax.set_xlabel('Price (£)')
ax.set_ylabel('Density')
ax.set_title('Price distributions')
ax.legend()
ax.set_xlim(0, gendered['price_num'].quantile(0.95))

# 4b: Log price
ax = axes[0, 1]
for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
    sub = gendered[gendered['gender'] == gender]
    ax.hist(sub['log_price'], bins=50, alpha=0.5, color=color,
            label=f'{gender.title()} (n={len(sub)})', density=True)
ax.set_xlabel('Log price')
ax.set_ylabel('Density')
ax.set_title('Log price distributions')
ax.legend()

# 4c: Box plots
ax = axes[1, 0]
data_box = [gendered[gendered['gender'] == 'female']['price_num'],
            gendered[gendered['gender'] == 'male']['price_num']]
bp = ax.boxplot(data_box, labels=['Female', 'Male'], patch_artist=True,
                showfliers=False, widths=0.5)
bp['boxes'][0].set_facecolor(PALETTE['female'])
bp['boxes'][1].set_facecolor(PALETTE['male'])
for box in bp['boxes']:
    box.set_alpha(0.6)
ax.set_ylabel('Price (£)')
ax.set_title('Price box plots (outliers trimmed)')

# 4d: CDF
ax = axes[1, 1]
for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
    sub = gendered[gendered['gender'] == gender]['price_num'].sort_values()
    cdf = np.arange(1, len(sub) + 1) / len(sub)
    ax.plot(sub, cdf, color=color, label=gender.title(), linewidth=1.5)
ax.set_xlabel('Price (£)')
ax.set_ylabel('Cumulative probability')
ax.set_title('Cumulative distribution functions')
ax.legend()
ax.set_xlim(0, gendered['price_num'].quantile(0.95))

plt.tight_layout()
plt.savefig(OUTPUT_DIR / '04_price_distributions.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 04_price_distributions.png")

# ---- 5. By store ----
if len(store_df) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # 5a: Coefficient bar chart
    ax = axes[0]
    store_sorted = store_df.sort_values('pct_gap')
    y_pos = np.arange(len(store_sorted))
    colors = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in store_sorted['pct_gap']]
    edge_w = [2 if s else 0.5 for s in store_sorted['significant']]

    bars = ax.barh(y_pos, store_sorted['pct_gap'], color=colors, alpha=0.7, height=0.5)
    for bar, lw in zip(bars, edge_w):
        bar.set_edgecolor('black')
        bar.set_linewidth(lw)

    ax.axvline(x=0, color='black', linewidth=1)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(store_sorted['store'].values)
    ax.set_xlabel('Female price premium (%)')
    ax.set_title('Pink tax by store (thick border = p<0.05)')

    # 5b: Mean prices by store and gender
    ax = axes[1]
    x = np.arange(len(store_df))
    width = 0.35
    ax.bar(x - width/2, store_df['mean_f'], width, color=PALETTE['female'],
           alpha=0.7, label='Female', edgecolor='black', linewidth=0.3)
    ax.bar(x + width/2, store_df['mean_m'], width, color=PALETTE['male'],
           alpha=0.7, label='Male', edgecolor='black', linewidth=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels(store_df['store'].values)
    ax.set_ylabel('Mean price (£)')
    ax.set_title('Mean prices by store and gender')
    ax.legend()

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / '05_by_store.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("✓ 05_by_store.png")

# ---- 6. Female vs male scatter by category ----
if len(gaps_df) > 0:
    fig, ax = plt.subplots(figsize=(8, 8))

    ax.scatter(gaps_df['mean_male'], gaps_df['mean_female'],
               s=gaps_df['n_total'] * 5, alpha=0.6,
               c=[PALETTE['female'] if g > 0 else PALETTE['male'] for g in gaps_df['gap_pct']],
               edgecolors='black', linewidth=0.5)

    # 45-degree line
    lim_max = max(gaps_df['mean_male'].max(), gaps_df['mean_female'].max()) * 1.1
    ax.plot([0, lim_max], [0, lim_max], 'k--', linewidth=0.8, alpha=0.5, label='Equal price')
    ax.set_xlabel('Mean male price (£)')
    ax.set_ylabel('Mean female price (£)')
    ax.set_title('Female vs male mean prices by category\n(size = total products)')
    ax.legend()
    ax.set_aspect('equal')

    # Label the outliers
    for _, row in gaps_df.nlargest(3, 'gap_pct').iterrows():
        ax.annotate(row['category'][:30], (row['mean_male'], row['mean_female']),
                    fontsize=7, alpha=0.8,
                    xytext=(5, 5), textcoords='offset points')

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / '06_scatter_by_category.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("✓ 06_scatter_by_category.png")

# ---- 7. Three-way comparison (female, male, none) ----
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# 7a: Density by gender group
ax = axes[0]
for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male']),
                       ('none', PALETTE['none'])]:
    sub = df[df['gender'] == gender]
    ax.hist(sub['log_price'], bins=60, alpha=0.4, color=color, density=True,
            label=f'{gender.title()} (n={len(sub):,})')
ax.set_xlabel('Log price')
ax.set_ylabel('Density')
ax.set_title('Price distributions: all gender groups')
ax.legend()

# 7b: Box plots
ax = axes[1]
groups = ['female', 'male', 'none']
data_3way = [df[df['gender'] == g]['price_num'] for g in groups]
bp = ax.boxplot(data_3way, labels=[g.title() for g in groups],
                patch_artist=True, showfliers=False, widths=0.5)
for i, box in enumerate(bp['boxes']):
    box.set_facecolor(PALETTE[groups[i]])
    box.set_alpha(0.6)
ax.set_ylabel('Price (£)')
ax.set_title('Price comparison: gendered vs ungendered')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / '07_three_way_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 07_three_way_comparison.png")

# ---- 8. Category composition ----
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# 8a: Top categories by gender
for ax_idx, (gender, title) in enumerate([('female', 'Female'), ('male', 'Male')]):
    ax = axes[ax_idx]
    sub = gendered[gendered['gender'] == gender]
    top_cats = sub['cat_mid'].value_counts().head(12)
    y_pos = np.arange(len(top_cats))
    ax.barh(y_pos, top_cats.values, color=PALETTE[gender], alpha=0.7,
            edgecolor='black', linewidth=0.3)
    ax.set_yticks(y_pos)
    ax.set_yticklabels([c[:40] for c in top_cats.index], fontsize=8)
    ax.set_xlabel('Number of products')
    ax.set_title(f'Top categories: {title} products')
    ax.invert_yaxis()

plt.tight_layout()
plt.savefig(OUTPUT_DIR / '08_category_composition.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 08_category_composition.png")

# ---- 9. Price ratio distribution within matched categories ----
if len(gaps_df) > 0:
    fig, ax = plt.subplots(figsize=(8, 5))

    gap_values = gaps_df['gap_pct'].values
    colors = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in gap_values]

    ax.hist(gap_values, bins=max(5, len(gap_values) // 2), alpha=0.6,
            color=PALETTE['female'], edgecolor='black', linewidth=0.5)
    ax.axvline(x=0, color='black', linewidth=1.5)
    ax.axvline(x=gaps_df['gap_pct'].median(), color=PALETTE['female'],
               linewidth=1.5, linestyle='--',
               label=f'Median: {gaps_df["gap_pct"].median():+.1f}%')
    ax.set_xlabel('Female price premium (%)')
    ax.set_ylabel('Number of categories')
    ax.set_title('Distribution of within-category price gaps')
    ax.legend()

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / '09_gap_distribution.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("✓ 09_gap_distribution.png")

# ---- 10. R² progression ----
fig, ax = plt.subplots(figsize=(8, 5))

r2_data = summary_df[['spec', 'r2']].copy()
y_pos = np.arange(len(r2_data))

ax.barh(y_pos, r2_data['r2'], color='#555555', alpha=0.7,
        edgecolor='black', linewidth=0.5, height=0.6)
ax.set_yticks(y_pos)
ax.set_yticklabels(r2_data['spec'])
ax.set_xlabel('R²')
ax.set_title('Model fit: how much variance do controls explain?')
ax.invert_yaxis()

for i, r2 in enumerate(r2_data['r2']):
    ax.text(r2 + 0.01, i, f'{r2:.3f}', va='center', fontsize=9)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / '10_r2_progression.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 10_r2_progression.png")

# ---- 11. Heatmap: category × store price gap ----
# Only for categories present in ≥2 stores with both genders
heatmap_data = []
for cat in gendered['cat_mid'].unique():
    for store_id in gendered['store'].unique():
        sub = gendered[(gendered['cat_mid'] == cat) & (gendered['store'] == store_id)]
        fem = sub[sub['is_female'] == 1]
        mal = sub[sub['is_female'] == 0]
        if len(fem) >= 2 and len(mal) >= 2:
            gap = (fem['price_num'].mean() / mal['price_num'].mean() - 1) * 100
            heatmap_data.append({
                'category': cat[:35],
                'store': store_names.get(store_id, store_id),
                'gap': gap,
                'n': len(fem) + len(mal),
            })

if heatmap_data:
    heat_df = pd.DataFrame(heatmap_data)
    pivot = heat_df.pivot_table(values='gap', index='category', columns='store', aggfunc='mean')

    # Only keep categories that appear in at least 2 stores
    pivot = pivot.dropna(thresh=2)

    if len(pivot) > 0:
        fig, ax = plt.subplots(figsize=(8, max(5, len(pivot) * 0.45)))
        sns.heatmap(pivot, cmap='RdBu_r', center=0, annot=True, fmt='.0f',
                    linewidths=0.5, ax=ax, cbar_kws={'label': 'F vs M gap (%)'})
        ax.set_title('Price gap (%) by category and store')
        ax.set_ylabel('')
        plt.tight_layout()
        plt.savefig(OUTPUT_DIR / '11_heatmap_category_store.png', dpi=150, bbox_inches='tight')
        plt.close()
        print("✓ 11_heatmap_category_store.png")

# ---- 12. Summary dashboard ----
fig = plt.figure(figsize=(14, 8))
gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.4, wspace=0.35)

# 12a: Key finding
ax = fig.add_subplot(gs[0, 0])
ax.axis('off')
ax.text(0.5, 0.85, 'KEY FINDING', ha='center', fontsize=14, fontweight='bold')
ax.text(0.5, 0.60, f'Raw gap: {(np.exp(spec1.params["is_female"])-1)*100:+.1f}%',
        ha='center', fontsize=18, color=PALETTE['male'], fontweight='bold')
ax.text(0.5, 0.40, '(female products cheaper)', ha='center', fontsize=10, color='gray')
ax.text(0.5, 0.15, f'After controls: {summary_df.iloc[-1]["pct"]:+.1f}%\n(not significant)',
        ha='center', fontsize=12, color='gray')

# 12b: Coefficient mini-plot
ax = fig.add_subplot(gs[0, 1])
y_pos = np.arange(len(summary_df))
colors = [PALETTE['female'] if c > 0 else PALETTE['male'] for c in summary_df['coef']]
ax.barh(y_pos, summary_df['pct'], color=colors, alpha=0.7, height=0.6)
ax.axvline(x=0, color='black', linewidth=1)
ax.set_yticks(y_pos)
ax.set_yticklabels([s[:18] for s in summary_df['spec']], fontsize=8)
ax.set_xlabel('% gap')
ax.set_title('Premium by spec')
ax.invert_yaxis()

# 12c: R² bar
ax = fig.add_subplot(gs[0, 2])
ax.barh(y_pos, summary_df['r2'], color='#666', alpha=0.7, height=0.6)
ax.set_yticks(y_pos)
ax.set_yticklabels([s[:18] for s in summary_df['spec']], fontsize=8)
ax.set_xlabel('R²')
ax.set_title('Model fit')
ax.invert_yaxis()

# 12d: Price distributions
ax = fig.add_subplot(gs[1, 0])
for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
    sub = gendered[gendered['gender'] == gender]
    ax.hist(sub['log_price'], bins=40, alpha=0.5, color=color, density=True, label=gender.title())
ax.legend(fontsize=8)
ax.set_xlabel('Log price')
ax.set_title('Price distributions')

# 12e: Quantile regression
ax = fig.add_subplot(gs[1, 1])
ax.fill_between(qreg_df['quantile'], qreg_df['ci_lo'], qreg_df['ci_hi'],
                alpha=0.2, color=PALETTE['female'])
ax.plot(qreg_df['quantile'], qreg_df['coef'], 'o-', color=PALETTE['female'], linewidth=2)
ax.axhline(y=0, color='black', linewidth=0.8, linestyle='--')
ax.set_xlabel('Quantile')
ax.set_ylabel('Coefficient')
ax.set_title('Quantile regression')

# 12f: By store
if len(store_df) > 0:
    ax = fig.add_subplot(gs[1, 2])
    y_pos = np.arange(len(store_df))
    colors = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in store_df['pct_gap']]
    ax.barh(y_pos, store_df['pct_gap'], color=colors, alpha=0.7, height=0.5)
    ax.axvline(x=0, color='black', linewidth=1)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(store_df['store'].values, fontsize=9)
    ax.set_xlabel('% gap')
    ax.set_title('Gap by store')

plt.savefig(OUTPUT_DIR / '12_summary_dashboard.png', dpi=150, bbox_inches='tight')
plt.close()
print("✓ 12_summary_dashboard.png")

# ============================================================================
# FINAL OUTPUT
# ============================================================================

print("\n" + "=" * 70)
print("OUTPUT FILES")
print("=" * 70)

all_files = sorted(OUTPUT_DIR.glob('*'))
for f in all_files:
    print(f"  {f.name}")

# Save full summary
full_summary = {
    'raw_gap_pct': float((np.exp(spec1.params['is_female']) - 1) * 100),
    'raw_gap_p': float(spec1.pvalues['is_female']),
    'controlled_gap_pct': float(summary_df.iloc[-1]['pct']),
    'controlled_gap_p': float(summary_df.iloc[-1]['p']),
    'within_category_median_gap': float(median_gap) if len(gaps_df) > 0 else None,
    'n_gendered_products': int(len(gendered)),
    'n_female': int(gendered['is_female'].sum()),
    'n_male': int((1 - gendered['is_female']).sum()),
    'quantile_results': qreg_df.to_dict('records'),
    'store_results': store_df.to_dict('records') if len(store_df) > 0 else [],
}

with open(OUTPUT_DIR / 'full_summary.json', 'w') as f:
    json.dump(full_summary, f, indent=2, default=str)

print(f"\n✓ All outputs saved to {OUTPUT_DIR}")

PINK TAX REGRESSION ANALYSIS (v2)
Loaded 21,436 products
After category filter: 12,832
With valid prices: 12,832

Gender distribution:
  female: 1,075  (mean £7.03, median £4.00)
  male: 844  (mean £8.15, median £4.50)
  none: 10,913  (mean £10.66, median £5.00)

ANALYSIS SAMPLE: FEMALE vs MALE
N = 1,919  (F: 1075, M: 844)
Mean price  — F: £7.03, M: £8.15
Median price — F: £4.00, M: £4.50
  Categories (broad) with ≥3F + ≥3M: 11
  Categories (mid) with ≥3F + ≥3M: 11
  Categories (fine) with ≥3F + ≥3M: 11

Spec 1: Raw gap
  (1) Raw gap: coef=-0.0975 (-9.3%), SE=0.0411, p=0.0178**, R²=0.003, N=1919

Spec 2: + Store FE
  (2) + Store FE: coef=-0.1036 (-9.8%), SE=0.0412, p=0.0118**, R²=0.006, N=1919

Spec 3: + Broad category FE (N cats: 96)
  (3) + Broad cat FE: coef=+0.0803 (+8.4%), SE=0.0851, p=0.3454, R²=0.672, N=1374

Spec 4: + Mid category FE (N cats: 96)
  (4) + Mid cat FE: coef=+0.0803 (+8.4%), SE=0.0851, p=0.3454, R²=0.672, N=1374

Spec 5: + Fine category FE (N cats: 96)
  (5) + Fine