# In [3]:
# ============================================================================
# PINK TAX ANALYSIS - UNIFIED PIPELINE
# ============================================================================
#
# Consolidates four previously separate scripts into one modular codebase:
#   Stage 1: Color extraction from product images (adaptive domain handling)
#   Stage 2: ML-based gender prediction (L1/L2/RF/HGB/SVM)
#   Stage 3: Regression analysis (OLS, quantile, within-category, by-store)
#   Stage 4: Color visualisations for portfolio
#
# Usage:
#   python pink_tax_pipeline.py --stage [1|2|3|4|all]
#
# ============================================================================

import json
import re
import time
import warnings
from collections import Counter
from io import BytesIO
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# NOTE(review): this silences *every* Python warning for the whole process,
# including NumPy overflow RuntimeWarnings and pandas/scikit-learn
# deprecation notices. Consider narrowing the filter when debugging.
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIGURATION
# ============================================================================

# Root of the portfolio project. This is an absolute path on the author's
# machine and must be edited before the pipeline is run anywhere else.
BASE_DIR = Path('/Users/leoss/Desktop/Portfolio/Website-/projects')
DATA_DIR = BASE_DIR / 'pink-tax/data'
OUTPUT_BASE = BASE_DIR / 'pink-tax/outputs'

# Input files
# Main product table, read by load_main_data() with latin-1 encoding.
PATH_MAIN_DATA = DATA_DIR / 'items_fin.csv'
# Human-coded sample; Stage 2 reads its `human_gender_label` column if present.
PATH_HUMAN_CODED = DATA_DIR / 'items_prices_description_gender_humancode_sample.csv'
# Optional Excel file; Stage 2 carries on without it when it is missing.
PATH_YOUR_LABELED = DATA_DIR / 'available_validation.xlsx'

# Output directories
# Colour scraping outputs
COLOUR_DIR = OUTPUT_BASE / 'colour'
# Written incrementally by Stage 1 (save_incremental) and read back by
# Stage 1 (resume support) and Stage 2 (colour features).
COLOR_CACHE_PATH = COLOUR_DIR / 'color_features_cache_v3_filtered.csv'
# Log of failed image downloads written at the end of Stage 1.
FAILED_URLS_PATH = COLOUR_DIR / 'failed_urls.csv'

# Chart outputs (interactive HTML)
# NOTE(review): only CHART_ML_DIR is created in the code visible here;
# the other two are presumably used by later stages -- confirm.
CHART_REG_DIR = OUTPUT_BASE / 'charts/regression'
CHART_ML_DIR = OUTPUT_BASE / 'charts/ml'
CHART_VAL_DIR = OUTPUT_BASE / 'charts/validation'

# Tabular and summary outputs
TEXT_DIR = OUTPUT_BASE / 'text'

# Seed shared by DataFrame.sample() calls and the train/test split.
RANDOM_STATE = 42

# ---- Column names ----
# Names as they appear *after* load_main_data() normalisation
# (lower-cased, stripped, spaces replaced by underscores).
COL_PRODUCT_ID = 'product_id'
COL_IMAGE = 'image_url'
COL_BREADCRUMB = 'standardized_breadcrumbs'
COL_NAME = 'product_title_x'
COL_DESC = 'description'
COL_PRICE = 'price'
COL_UNIT_PRICE = 'unit_price'
COL_STORE = 'store_id'
COL_URL = 'product_url_x'

# ---- Gender keywords ----
# Matched as whole words (regex \b...\b) against lower-cased text by
# extract_gender_from_text(), and stripped out of the ML text features by
# clean_text_remove_gender().
# NOTE(review): 'her', 'his', 'man' and 'men' are common English words and
# may flag products that are not actually gendered -- confirm this is intended.
FEMALE_KEYWORDS = [
    'women', 'woman', 'female', 'ladies', 'lady', 'girls',
    'womens', "women's", 'femme', 'her', 'feminine', 'fem',
]
MALE_KEYWORDS = [
    'men', 'man', 'male', 'gentleman', 'gentlemen', 'boys',
    'mens', "men's", 'homme', 'his', 'masculine',
]
# Union of both lists; a set, so its iteration order is not fixed.
ALL_GENDER_KEYWORDS = set(FEMALE_KEYWORDS + MALE_KEYWORDS)

# ---- Category exclusions ----
# Matched as plain substrings of the lower-cased breadcrumb (see
# contains_excluded_category), so short entries such as 'tech' also match
# any longer word that contains them.
EXCLUDE_CATEGORIES = [
    'food', 'grocery', 'groceries', 'snacks', 'drinks', 'beverages',
    'pet food', 'pet supplies', 'cleaning', 'household', 'kitchen',
    'office', 'stationery', 'electronics', 'tech', 'garden', 'automotive',
]

# ---- Color definitions ----
# Named reference colours as (R, G, B) triples on a 0-255 scale.
# closest_standard_color() maps an arbitrary RGB value to the nearest entry
# by Euclidean distance; on ties the first entry in this dict wins.
STANDARD_COLORS = {
    'dark_red': (139, 0, 0), 'red': (255, 0, 0), 'coral': (255, 127, 80),
    'salmon': (250, 128, 114), 'crimson': (220, 20, 60), 'brown': (139, 69, 19),
    'tan': (210, 180, 140), 'orange': (255, 165, 0), 'gold': (255, 215, 0),
    'yellow': (255, 255, 0), 'khaki': (240, 230, 140), 'dark_green': (0, 100, 0),
    'green': (0, 128, 0), 'lime': (50, 205, 50), 'olive': (128, 128, 0),
    'teal': (0, 128, 128), 'navy': (0, 0, 128), 'blue': (0, 0, 255),
    'royal_blue': (65, 105, 225), 'sky_blue': (135, 206, 235), 'cyan': (0, 255, 255),
    'purple': (128, 0, 128), 'magenta': (255, 0, 255), 'violet': (238, 130, 238),
    'lavender': (230, 230, 250), 'pink': (255, 192, 203), 'hot_pink': (255, 105, 180),
    'gray': (128, 128, 128), 'silver': (192, 192, 192),
    'black': (0, 0, 0), 'white': (255, 255, 255),
}

# ---- Stage 1 settings ----
# Maximum number of dominant colours kept per image. Stage 2 also uses it to
# size the colour feature columns (one per colour name per rank).
N_COLORS = 3
# Timeout passed to the image GET request.
TIMEOUT = 15
# Upper bound on the number of products sampled for colour extraction.
MAX_SAMPLES = 20000
# NOTE(review): not referenced in the code visible here -- confirm it is used.
PRIORITIZE_GENDERED = False
# Extra attempts after the first failed download.
MAX_RETRIES = 1
# Pause (passed to time.sleep) between attempts.
RETRY_DELAY = 1
# Write a cache checkpoint every SAVE_EVERY processed rows.
SAVE_EVERY = 200
# DomainTracker thresholds: after MIN_DOMAIN_SAMPLES attempts on a domain,
# a success rate below HEAD_CHECK_THRESHOLD triggers a HEAD pre-check and a
# success rate below SKIP_THRESHOLD skips the domain entirely.
MIN_DOMAIN_SAMPLES = 30
HEAD_CHECK_THRESHOLD = 0.40
SKIP_THRESHOLD = 0.05

# ---- Stage 2 settings ----
# Fraction of the balanced sample held out by train_test_split.
TEST_SIZE = 0.25
# NOTE(review): not referenced in the code visible here -- presumably the
# cross-validation fold count; confirm.
CV_FOLDS = 5
# Stage 2 raises ValueError if either explicit gender class is smaller.
MIN_CLASS_SIZE = 50
# NOTE(review): not referenced in the code visible here -- confirm usage.
MIN_TEST_SAMPLES = 10

# ---- Stage 3 settings ----
# NOTE(review): not referenced in the code visible here -- presumably the
# number of bootstrap resamples; confirm.
N_BOOTSTRAP = 1000

# ---- Chart style ----
# Hex colours keyed by gender label.
PALETTE = {'female': '#c44e52', 'male': '#4c72b0', 'none': '#8c8c8c'}

# ---- Unified portfolio style (Plotly) ----
# Shared style tokens; base_layout() and styled_axis() read from this dict.
STYLE = {
    'font_family': 'IBM Plex Sans, -apple-system, BlinkMacSystemFont, sans-serif',
    'tick_size': 11,
    'axis_title_size': 13,
    'legend_size': 11,
    'annotation_size': 10,
    'title_color': '#1a2744',
    'navy': '#1a2744',
    'slate': '#3d4f5f',
    'steel': '#4a6fa5',
    'grey_300': '#c9cfd6',
    'grey_200': '#dde1e7',
    'grey_100': '#f0f2f5',
    'text_secondary': '#5a6675',
    'pos_color': '#2e7d4a',
    'neg_color': '#c23a3a',
    'zero_line_color': '#c9cfd6',
    'template': 'plotly_white',
    'plot_bg': 'rgba(0,0,0,0)',
    'paper_bg': 'white',
    'chart_height': 550,
    'chart_height_small': 420,
    'chart_height_tall': 700,
    'margin_default': dict(l=60, r=40, t=20, b=50),
    'margin_bar': dict(l=200, r=60, t=20, b=50),
    'margin_map': dict(l=10, r=10, t=10, b=60),
    'grid_color': '#e5e7eb',
    'grid_width': 0.5,
    'choropleth_line_color': '#c9cfd6',
    'choropleth_line_width': 0.5,
    'colorbar': dict(len=0.7, thickness=15),
    'marker_line': dict(width=0.5, color='white'),
}

# Passed as `config=` to fig.write_html() in save_html().
WRITE_CONFIG = {'displayModeBar': False}


def base_layout(**overrides):
    """Return the shared Plotly layout dict.

    Defaults are taken from STYLE; any keyword argument replaces the default
    with the same key or is added as a new layout entry.
    """
    typeface = STYLE['font_family']
    hover_style = dict(
        bgcolor='white',
        bordercolor='#c9cfd6',
        font=dict(family=typeface, size=13, color='#1a2744'),
    )
    defaults = {
        'template': STYLE['template'],
        'font': dict(family=typeface, size=STYLE['tick_size'], color='#4b5563'),
        'paper_bgcolor': STYLE['paper_bg'],
        'plot_bgcolor': STYLE['plot_bg'],
        'height': STYLE['chart_height'],
        'margin': STYLE['margin_default'],
        'title': '',
        'hoverlabel': hover_style,
    }
    # Caller-supplied values win over the defaults.
    defaults.update(overrides)
    return defaults


def styled_axis(title_text, **kw):
    """Return a Plotly axis dict in the portfolio style.

    `title_text` becomes the axis title; extra keyword arguments override or
    extend the styled defaults.
    """
    typeface = STYLE['font_family']
    axis = {
        'title': {
            'text': title_text,
            'font': {'size': STYLE['axis_title_size'], 'family': typeface},
        },
        'tickfont': {'size': STYLE['tick_size'], 'family': typeface},
        'gridcolor': STYLE['grid_color'],
        'gridwidth': STYLE['grid_width'],
        'zeroline': False,
    }
    return {**axis, **kw}


def save_html(fig, filepath):
    """Write `fig` to `filepath` as HTML (Plotly JS from CDN) and print the file name."""
    import os

    target = str(filepath)
    fig.write_html(target, config=WRITE_CONFIG, include_plotlyjs='cdn')
    print(f"   \u2713 {os.path.basename(target)}")


# ============================================================================
# SHARED UTILITIES
# ============================================================================

def load_main_data(path=PATH_MAIN_DATA):
    """Read the main product CSV and normalise its column names.

    Column names are lower-cased, stripped and have spaces replaced by
    underscores; a leftover 'unnamed:_0' index column is dropped if present.
    """
    frame = pd.read_csv(path, encoding='latin-1')
    frame.columns = frame.columns.str.lower().str.strip().str.replace(' ', '_')
    leftover_index = 'unnamed:_0'
    if leftover_index not in frame.columns:
        return frame
    return frame.drop(columns=[leftover_index])


def contains_excluded_category(text):
    """Return True if any excluded category occurs as a substring of the breadcrumb.

    Missing values are never excluded. Matching is case-insensitive.
    """
    if pd.isna(text):
        return False
    haystack = str(text).lower()
    for category in EXCLUDE_CATEGORIES:
        if category in haystack:
            return True
    return False


def filter_excluded_categories(df):
    """Return a copy of `df` without rows whose breadcrumb hits an excluded category.

    The result has a fresh 0..n-1 index.
    """
    is_excluded = df[COL_BREADCRUMB].apply(contains_excluded_category)
    kept = df[~is_excluded].copy()
    return kept.reset_index(drop=True)


def extract_gender_from_text(text):
    """Classify one text field as 'female', 'male', 'both' or 'none'.

    Keywords are matched as whole words against the lower-cased text.
    Missing or blank input yields 'none'.
    """
    if pd.isna(text) or str(text).strip() == '':
        return 'none'
    lowered = str(text).lower()

    def mentions(keywords):
        # True as soon as one keyword is found as a whole word.
        for kw in keywords:
            if re.search(r'\b' + kw + r'\b', lowered):
                return True
        return False

    female_hit = mentions(FEMALE_KEYWORDS)
    male_hit = mentions(MALE_KEYWORDS)

    if female_hit and male_hit:
        return 'both'
    if female_hit:
        return 'female'
    if male_hit:
        return 'male'
    return 'none'


def extract_gender_label(row):
    """Return the first unambiguous gender found in breadcrumb, title, description.

    Fields are checked in that order; a field counts only if it resolves to
    exactly 'female' or 'male'. Columns absent from the row are skipped.
    Falls back to 'none'.
    """
    text_columns = (COL_BREADCRUMB, COL_NAME, COL_DESC)
    for column in text_columns:
        if column not in row.index:
            continue
        verdict = extract_gender_from_text(row[column])
        if verdict == 'female' or verdict == 'male':
            return verdict
    return 'none'


def add_gender_labels(df):
    """Add per-field labels and the combined 'label_extracted' column to `df`.

    The dataframe is modified in place and also returned.
    """
    per_field_sources = {
        'label_bc': COL_BREADCRUMB,
        'label_name': COL_NAME,
        'label_desc': COL_DESC,
    }
    for label_column, source_column in per_field_sources.items():
        df[label_column] = df[source_column].apply(extract_gender_from_text)
    df['label_extracted'] = df.apply(extract_gender_label, axis=1)
    return df


def clean_text_remove_gender(text, remove_words=None):
    """Lowercase `text`, strip gender keywords, punctuation and extra spaces.

    Args:
        text: Any value; missing values (per ``pd.isna``) yield ''.
        remove_words: Iterable of keywords to delete as whole words.
            Defaults to ALL_GENDER_KEYWORDS, resolved at call time.

    Returns:
        The cleaned string containing only a-z and single spaces.

    Keywords are removed longest-first in a fixed order. Some keywords
    contain others at a word boundary ("women's" / "women", "men's" / "men"),
    so iterating a set directly made the result depend on set order:
    removing "women" first leaves a stray "s" behind.
    """
    if remove_words is None:
        remove_words = ALL_GENDER_KEYWORDS
    if pd.isna(text):
        return ''
    text = str(text).lower()
    # Longest first, alphabetical tie-break, so output is reproducible
    # regardless of the container type or string hash seed.
    for word in sorted(remove_words, key=lambda w: (-len(w), w)):
        text = re.sub(r'\b' + word + r'\b', '', text)
    text = re.sub(r'[^a-z\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def get_domain(url):
    """Return the network location (host[:port]) of `url`, or 'unknown'.

    Uses ``urllib.parse.urlsplit`` rather than splitting on '/', so that
    scheme-less strings such as 'example.com/a/b.jpg' are reported as
    'unknown' instead of a path segment, and query strings or fragments
    never leak into the domain. Protocol-relative URLs ('//host/path') are
    handled. Non-string input (None, NaN) and malformed URLs give 'unknown'.
    """
    from urllib.parse import urlsplit

    try:
        # Strip first: older Python versions do not parse a URL with
        # leading whitespace.
        netloc = urlsplit(str(url).strip()).netloc
    except ValueError:
        # urlsplit rejects some malformed inputs (e.g. bad IPv6 literals).
        return 'unknown'
    return netloc or 'unknown'


def parse_breadcrumb(text):
    """Split a breadcrumb string into three lower-cased category levels.

    ' > ' is tried as the separator first, then ' / '; a string with neither
    is treated as a single level. Leading store or marketplace names are
    dropped. Missing levels (and missing input) are reported as 'unknown'.
    """
    missing = 'unknown'
    if pd.isna(text):
        return missing, missing, missing

    crumb = str(text).strip()
    for separator in (' > ', ' / '):
        if separator in crumb:
            levels = [piece.strip().lower()
                      for piece in crumb.split(separator) if piece.strip()]
            break
    else:
        levels = [crumb.lower()]

    # Skip retailer / section names that sit in front of the real category.
    store_prefixes = {'morrisons', 'tesco', 'asda', 'groceries', 'marketplace'}
    skip = 0
    while skip < len(levels) and levels[skip] in store_prefixes:
        skip += 1

    top_three = levels[skip:skip + 3]
    top_three += [missing] * (3 - len(top_three))
    return tuple(top_three)


# ---- Color helpers ----

def color_distance(c1, c2):
    """Euclidean distance between two colour tuples, paired channel by channel."""
    squared_total = 0
    for left, right in zip(c1, c2):
        delta = left - right
        squared_total += delta ** 2
    return np.sqrt(squared_total)


def closest_standard_color(rgb):
    """Return the name of the STANDARD_COLORS entry nearest to `rgb`.

    Distance is measured with color_distance(); on ties the entry listed
    first wins. 'gray' is returned if no entry beats an infinite distance.
    """
    best_name, best_distance = 'gray', float('inf')
    for candidate, reference in STANDARD_COLORS.items():
        distance = color_distance(rgb, reference)
        if distance < best_distance:
            best_name, best_distance = candidate, distance
    return best_name


def is_background_color(rgb):
    """Return True if `rgb` looks like image background rather than product.

    A colour counts as background when it is near-white (all channels > 240),
    near-black (all channels < 15), or a flat mid-grey (channel spread < 20
    and mean strictly between 100 and 160).

    Args:
        rgb: Three channel values on a 0-255 scale. Python numbers and NumPy
            scalars are both accepted.

    Channels are converted to Python floats before any arithmetic. Pixels
    taken from an image array are ``np.uint8``; subtracting or adding those
    wraps modulo 256, which makes the spread and mean meaningless and
    prevents the mid-grey rule from ever firing.
    """
    r, g, b = (float(channel) for channel in rgb)
    if r > 240 and g > 240 and b > 240:
        return True
    if r < 15 and g < 15 and b < 15:
        return True
    max_diff = max(abs(r - g), abs(g - b), abs(r - b))
    avg = (r + g + b) / 3
    if max_diff < 20 and 100 < avg < 160:
        return True
    return False


def rgb_norm(name):
    """Return the named colour as an (r, g, b) tuple scaled to 0-1.

    Unknown names fall back to mid-grey (128, 128, 128). Kept for compatibility.
    """
    fallback = (128, 128, 128)
    red, green, blue = STANDARD_COLORS.get(name, fallback)
    return (red / 255, green / 255, blue / 255)


def rgb_hex(name):
    """Return the named colour as a '#rrggbb' string for Plotly.

    Unknown names fall back to mid-grey.
    """
    red, green, blue = STANDARD_COLORS.get(name, (128, 128, 128))
    return '#{:02x}{:02x}{:02x}'.format(red, green, blue)


def text_color_for_bg(name):
    """Pick a readable text colour for a swatch of the named colour.

    Returns 'white' on dark backgrounds (weighted luminance below 140) and
    near-black '#1a1a1a' otherwise. Unknown names are treated as mid-grey.
    """
    red, green, blue = STANDARD_COLORS.get(name, (128, 128, 128))
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    if luminance < 140:
        return 'white'
    return '#1a1a1a'


def edge_color_for(name):
    """Return a border colour for a swatch of the named colour.

    Light colours (weighted luminance above 200) get a grey '#aaaaaa' border;
    everything else gets the string 'none'. Unknown names are treated as mid-grey.
    """
    red, green, blue = STANDARD_COLORS.get(name, (128, 128, 128))
    luminance = 0.299 * red + 0.587 * green + 0.114 * blue
    if luminance > 200:
        return '#aaaaaa'
    return 'none'


def save_incremental(new_results, cache_path):
    """Merge `new_results` into the cache CSV and return the cached row count.

    `new_results` is a list of row dicts. When the cache file already exists
    its rows are kept and new rows with an already-present product id are
    dropped. Returns 0 without touching the file if there is nothing to add.
    """
    if not new_results:
        return 0

    fresh = pd.DataFrame(new_results)
    if not cache_path.exists():
        merged = fresh
    else:
        previous = pd.read_csv(cache_path)
        # Existing rows come first, so they win on duplicate ids.
        merged = pd.concat([previous, fresh]).drop_duplicates(subset=[COL_PRODUCT_ID])

    merged.to_csv(cache_path, index=False)
    return len(merged)


# ============================================================================
# STAGE 1: COLOR EXTRACTION
# ============================================================================

class DomainTracker:
    """
    Per-domain bookkeeping of download attempts and successes.

    Once a domain has at least `min_samples` recorded attempts, its success
    rate decides how further URLs on that domain are treated:
      - below `skip_threshold`: skip the domain entirely
      - at or above `skip_threshold` but below `head_threshold`:
        HEAD pre-check before the GET
    Each decision is announced on stdout once per domain.
    """

    def __init__(self, min_samples, head_threshold, skip_threshold):
        self.min_samples = min_samples
        self.head_threshold = head_threshold
        self.skip_threshold = skip_threshold
        self.attempts = Counter()
        self.successes = Counter()
        # Domains for which the corresponding notice was already printed.
        self._notified_head = set()
        self._notified_skip = set()

    def record(self, domain, success):
        """Count one attempt for `domain`, and one success if `success` is truthy."""
        self.attempts[domain] += 1
        if success:
            self.successes[domain] += 1

    def success_rate(self, domain):
        """Fraction of successful attempts; 1.0 for a domain with no attempts."""
        tried = self.attempts[domain]
        if not tried:
            return 1.0
        return self.successes[domain] / tried

    def should_skip(self, domain):
        """True when the domain has enough samples and a rate below the skip threshold."""
        if self.attempts[domain] < self.min_samples:
            return False
        if not self.success_rate(domain) < self.skip_threshold:
            return False
        if domain not in self._notified_skip:
            rate = self.success_rate(domain)
            print(f"  [domain tracker] Skipping {domain} "
                  f"(success rate {rate:.0%} after {self.attempts[domain]} attempts)")
            self._notified_skip.add(domain)
        return True

    def should_head_check(self, domain):
        """True when the rate sits between the skip and HEAD-check thresholds."""
        if self.attempts[domain] < self.min_samples:
            return False
        rate = self.success_rate(domain)
        in_head_band = rate < self.head_threshold and rate >= self.skip_threshold
        if not in_head_band:
            return False
        if domain not in self._notified_head:
            print(f"  [domain tracker] HEAD pre-checking {domain} "
                  f"(success rate {rate:.0%} after {self.attempts[domain]} attempts)")
            self._notified_head.add(domain)
        return True

    def summary(self):
        """Map domain -> (successes, attempts, rate), busiest domains first."""
        report = {}
        for domain, tried in self.attempts.most_common():
            succeeded = self.successes[domain]
            report[domain] = (succeeded, tried, succeeded / tried if tried else 0)
        return report


def _build_session():
    """Create a requests.Session with browser-like headers for image downloads."""
    import requests

    browser_headers = {
        'User-Agent': (
            'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
            'AppleWebKit/537.36 (KHTML, like Gecko) '
            'Chrome/91.0.4472.124 Safari/537.36'
        ),
        'Accept': 'image/avif,image/webp,image/apng,image/*,*/*;q=0.8',
        'Accept-Language': 'en-US,en;q=0.9',
    }
    http = requests.Session()
    http.headers.update(browser_headers)
    return http


def _head_check_alive(session, url, timeout=5):
    """Return True if a HEAD request to `url` (following redirects) ends in HTTP 200.

    Any requests-level failure is treated as "not alive".
    """
    import requests

    try:
        status = session.head(url, timeout=timeout, allow_redirects=True).status_code
    except requests.RequestException:
        return False
    return status == 200


def _is_transient_error(status_code):
    if status_code is None:
        return True
    return status_code >= 500 or status_code == 429


def extract_colors_from_url(session, domain_tracker, image_url,
                            n_colors=3, timeout=15, max_retries=1):
    """
    Download an image and extract its dominant colors via KMeans.

    Args:
        session: Object with requests-style ``get``/``head`` methods.
        domain_tracker: DomainTracker consulted for skip / HEAD-check
            decisions. This function does not call ``record`` on it; the
            caller is expected to do so.
        image_url: Image location; converted to ``str`` and stripped.
        n_colors: Maximum number of colors to return.
        timeout: Timeout passed to the GET request.
        max_retries: Extra attempts after the first one.

    Returns:
        (colors, None) on success, where colors is a list of
        ``{'name': <STANDARD_COLORS key>, 'weight': <cluster share>}`` dicts
        ordered by cluster size; otherwise (None, error_reason) with
        error_reason one of 'adaptive_skip', 'head_check_dead',
        'http_<status>', 'timeout', 'connection_error', 'request_error',
        'processing_error' or 'no_non_bg_colors'.
    """
    import requests
    from PIL import Image
    from sklearn.cluster import KMeans

    url = str(image_url).strip()
    domain = get_domain(url)

    # Adaptive shortcuts for domains that have mostly failed so far.
    if domain_tracker.should_skip(domain):
        return None, 'adaptive_skip'
    if domain_tracker.should_head_check(domain):
        if not _head_check_alive(session, url, timeout=5):
            return None, 'head_check_dead'

    last_error = None
    for attempt in range(max_retries + 1):
        try:
            response = session.get(url, timeout=timeout, allow_redirects=True)
            if response.status_code != 200:
                last_error = f'http_{response.status_code}'
                # Only 5xx / 429 responses are retried; other statuses are final.
                if _is_transient_error(response.status_code) and attempt < max_retries:
                    time.sleep(RETRY_DELAY)
                    continue
                return None, last_error

            img = Image.open(BytesIO(response.content)).convert('RGB')
            img = img.resize((100, 100))
            # Widen from uint8 before the per-pixel background test: uint8
            # scalar arithmetic wraps modulo 256, which would corrupt the
            # channel differences and mean computed by is_background_color.
            pixels = np.array(img).reshape(-1, 3).astype(np.int64)

            non_bg = np.array([p for p in pixels if not is_background_color(tuple(p))])
            # Too little foreground left: cluster the whole image instead.
            if len(non_bg) < 50:
                non_bg = pixels

            # Two spare clusters so background-like centres can be dropped below.
            n_clusters = min(n_colors + 2, len(non_bg))
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            kmeans.fit(non_bg)

            counts = Counter(kmeans.labels_)
            total = len(kmeans.labels_)
            colors = []
            # Largest clusters first; keep up to n_colors non-background centres.
            for cluster_id, count in sorted(counts.items(), key=lambda x: x[1], reverse=True):
                rgb = tuple(int(c) for c in kmeans.cluster_centers_[cluster_id])
                if not is_background_color(rgb):
                    colors.append({'name': closest_standard_color(rgb), 'weight': count / total})
                if len(colors) >= n_colors:
                    break

            return (colors, None) if colors else (None, 'no_non_bg_colors')

        except requests.exceptions.Timeout:
            last_error = 'timeout'
        except requests.exceptions.ConnectionError:
            last_error = 'connection_error'
        except requests.exceptions.RequestException:
            last_error = 'request_error'
        except Exception:
            # Anything else (image decoding, clustering) is reported, not raised.
            last_error = 'processing_error'

        if attempt < max_retries:
            time.sleep(RETRY_DELAY)

    return None, last_error


def run_color_extraction():
    """Stage 1: extract dominant colors from product images.

    Steps, in order:
      1. Create the output directories and an HTTP session.
      2. Read the existing color cache (if any) to find product ids that
         are already done.
      3. Load the main dataset, drop excluded categories, add gender labels.
      4. Keep rows with an image URL, sample at most MAX_SAMPLES of them,
         then remove the ids already present in the cache.
      5. Download each image and extract its colors, checkpointing the cache
         every SAVE_EVERY rows.
      6. Save the cache and the failed-URL log, then print a summary.

    Takes no arguments and returns nothing; results go to COLOR_CACHE_PATH,
    FAILED_URLS_PATH and stdout.
    """
    print("=" * 70)
    print("STAGE 1: COLOR EXTRACTION (adaptive domain handling)")
    print("=" * 70)
    print(f"Timeout: {TIMEOUT}s | Retries: {MAX_RETRIES} (transient only)")
    print(f"Adaptive thresholds: HEAD-check < {HEAD_CHECK_THRESHOLD:.0%} "
          f"success, skip < {SKIP_THRESHOLD:.0%} success "
          f"(after {MIN_DOMAIN_SAMPLES} samples)")
    print()

    OUTPUT_BASE.mkdir(parents=True, exist_ok=True)
    COLOUR_DIR.mkdir(parents=True, exist_ok=True)

    session = _build_session()
    domain_tracker = DomainTracker(MIN_DOMAIN_SAMPLES, HEAD_CHECK_THRESHOLD, SKIP_THRESHOLD)

    # Resume support
    # Ids are compared as strings here and in the resume filter below.
    already_done = set()
    if COLOR_CACHE_PATH.exists():
        existing = pd.read_csv(COLOR_CACHE_PATH)
        if COL_PRODUCT_ID in existing.columns:
            already_done = set(existing[COL_PRODUCT_ID].astype(str))
            print(f"Existing cache: {len(already_done):,} products -- will resume.\n")

    # Load and filter
    print("Loading data...")
    df = load_main_data()
    print(f"  Loaded {len(df):,} products")

    print("Filtering categories...")
    df = filter_excluded_categories(df)
    print(f"  Remaining: {len(df):,} products")

    print("Extracting gender labels...")
    df = add_gender_labels(df)
    for label, count in df['label_extracted'].value_counts().items():
        print(f"  {label}: {count:,}")

    print("Selecting products...")
    df_with_images = df[df[COL_IMAGE].notna()].copy()

    # '_domain' is only used for the breakdown printed below.
    df_with_images['_domain'] = df_with_images[COL_IMAGE].apply(get_domain)
    print("  Domain breakdown:")
    for domain, count in df_with_images['_domain'].value_counts().items():
        print(f"    {domain}: {count:,}")

    # The sample is drawn with a fixed seed *before* the resume filter, so
    # the filter removes already-cached ids from the sampled rows.
    to_extract = df_with_images.copy()
    if len(to_extract) > MAX_SAMPLES:
        to_extract = to_extract.sample(n=MAX_SAMPLES, random_state=RANDOM_STATE)

    to_extract = to_extract[~to_extract[COL_PRODUCT_ID].astype(str).isin(already_done)]
    print(f"  To process (after resume filter): {len(to_extract):,}")

    # Extract
    print(f"\nExtracting colors (top {N_COLORS} per image)...\n")

    color_results = []
    failed_records = []
    error_counter = Counter()
    start_time = time.time()

    for idx, (row_idx, row) in enumerate(to_extract.iterrows()):
        # Progress line on the first row and every 100th row thereafter.
        if (idx + 1) % 100 == 0 or idx == 0:
            elapsed = time.time() - start_time
            rate = (idx + 1) / elapsed if elapsed > 0 else 0
            remaining = (len(to_extract) - idx - 1) / rate if rate > 0 else 0
            ok = len(color_results)
            total_so_far = idx + 1
            pct = 100 * ok / total_so_far if total_so_far else 0
            print(f"  {idx+1:,}/{len(to_extract):,} "
                  f"| OK: {ok} ({pct:.0f}%) "
                  f"| {rate:.1f} img/s "
                  f"| ETA: {remaining/60:.0f} min")

        url = row[COL_IMAGE]
        domain = get_domain(str(url))

        colors, error = extract_colors_from_url(
            session, domain_tracker, url,
            n_colors=N_COLORS, timeout=TIMEOUT, max_retries=MAX_RETRIES,
        )

        # Every outcome is recorded, including adaptive skips and dead HEAD
        # checks, which count as failures for the domain.
        success = colors is not None
        domain_tracker.record(domain, success)

        if success:
            entry = {
                'original_index': row_idx,
                COL_PRODUCT_ID: row[COL_PRODUCT_ID],
                'label_extracted': row['label_extracted'],
            }
            # One name/weight column pair per extracted color, numbered from 1.
            for i, c in enumerate(colors):
                entry[f'color{i+1}_name'] = c['name']
                entry[f'color{i+1}_weight'] = c['weight']
            color_results.append(entry)
        else:
            error_counter[error] += 1
            failed_records.append({
                'product_id': row[COL_PRODUCT_ID],
                'url': url,
                'label': row['label_extracted'],
                'error': error,
            })

        # color_results holds every success of this run, so each checkpoint
        # passes the full list again; save_incremental de-duplicates on
        # product id.
        if (idx + 1) % SAVE_EVERY == 0 and color_results:
            total_in_cache = save_incremental(color_results, COLOR_CACHE_PATH)
            print(f"  [checkpoint] {total_in_cache:,} products in cache")

    # Final save
    print(f"\nSaving...")
    if color_results:
        total_in_cache = save_incremental(color_results, COLOR_CACHE_PATH)
        print(f"  Cache: {total_in_cache:,} products -> {COLOR_CACHE_PATH}")

    # The failure log is rewritten from scratch with this run's failures only.
    if failed_records:
        pd.DataFrame(failed_records).to_csv(FAILED_URLS_PATH, index=False)
        print(f"  Failed URLs log -> {FAILED_URLS_PATH}")

    # Summary
    total = len(to_extract)
    success_count = len(color_results)
    success_rate = 100 * success_count / total if total else 0

    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    print(f"Processed: {total:,}")
    print(f"Extracted: {success_count:,} ({success_rate:.1f}%)")
    print(f"Failed:    {total - success_count:,} ({100 - success_rate:.1f}%)")

    print("\nFailure breakdown:")
    # error_counter is non-empty only when at least one row was processed,
    # so `total` is non-zero inside this loop.
    for error, count in error_counter.most_common(10):
        print(f"  {error}: {count} ({100*count/total:.1f}%)")

    print("\nPer-domain results:")
    print(f"  {'domain':<40s} {'ok':>5s} / {'total':>5s}  {'rate':>6s}")
    print(f"  {'-'*60}")
    for domain, (ok, tot, rate) in domain_tracker.summary().items():
        print(f"  {domain:<40s} {ok:>5d} / {tot:>5d}  {rate:>5.0%}")


# ============================================================================
# STAGE 2: ML GENDER PREDICTION
# ============================================================================

def run_ml_pipeline():
    """Stage 2: train classifiers, predict gender, export validation sample."""
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler, LabelEncoder
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegressionCV
    from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
    from sklearn.svm import SVC
    from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
    from scipy.sparse import hstack, csr_matrix

    CHART_ML_DIR.mkdir(parents=True, exist_ok=True)
    TEXT_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 70)
    print("STAGE 2: ML GENDER PREDICTION PIPELINE")
    print("=" * 70)

    # ---- Load data ----
    df = load_main_data()
    print(f"Main dataset: {len(df):,} products")

    human_coded = pd.read_csv(PATH_HUMAN_CODED, encoding='latin-1')
    human_coded.columns = human_coded.columns.str.lower().str.strip().str.replace(' ', '_')
    if 'unnamed:_0' in human_coded.columns:
        human_coded = human_coded.drop(columns=['unnamed:_0'])
    print(f"Human-coded: {len(human_coded)} products")

    try:
        your_labeled = pd.read_excel(PATH_YOUR_LABELED)
        your_labeled.columns = your_labeled.columns.str.lower().str.strip().str.replace(' ', '_')
        print(f"Your labeled: {len(your_labeled)} products")
    except FileNotFoundError:
        your_labeled = pd.DataFrame()
        print("Your labeled file not found (optional)")

    # ---- Filter and label ----
    original_count = len(df)
    df = filter_excluded_categories(df)
    excluded_count = original_count - len(df)
    print(f"Filtered: {original_count:,} -> {len(df):,} (excluded {excluded_count:,})")

    df = add_gender_labels(df)
    print(f"Label distribution: {df['label_extracted'].value_counts().to_dict()}")

    # Merge human labels
    if 'human_gender_label' in human_coded.columns and COL_PRODUCT_ID in human_coded.columns:
        human_labels = human_coded[[COL_PRODUCT_ID, 'human_gender_label']].drop_duplicates()
        human_labels.columns = [COL_PRODUCT_ID, 'label_human']
        human_labels['label_human'] = human_labels['label_human'].str.lower().str.strip()
        df = df.merge(human_labels, on=COL_PRODUCT_ID, how='left')
    else:
        df['label_human'] = None
    print(f"Human labels merged: {df['label_human'].notna().sum()}")

    # ---- Load color cache ----
    if COLOR_CACHE_PATH.exists():
        color_df = pd.read_csv(COLOR_CACHE_PATH)
        print(f"Color cache: {len(color_df):,} products")
        matched = color_df[COL_PRODUCT_ID].isin(df[COL_PRODUCT_ID]).sum()
        print(f"  Matched to filtered data: {matched:,} / {len(color_df):,}")
    else:
        print("No color cache found -- proceeding without color features")
        color_df = pd.DataFrame()

    # ---- Training data ----
    female_all = df[df['label_extracted'] == 'female'].copy()
    male_all = df[df['label_extracted'] == 'male'].copy()
    print(f"\nExplicitly female: {len(female_all)}, male: {len(male_all)}")

    if len(female_all) < MIN_CLASS_SIZE or len(male_all) < MIN_CLASS_SIZE:
        raise ValueError(f"Insufficient gendered samples (need >= {MIN_CLASS_SIZE} per class).")

    human_none = df[df['label_human'] == 'none'].copy()
    extracted_none = df[(df['label_extracted'] == 'none') & (df['label_human'].isna())].copy()
    min_gendered = min(len(female_all), len(male_all))
    target_none = min_gendered

    if len(human_none) >= target_none:
        none_all = human_none.sample(n=target_none, random_state=RANDOM_STATE)
    else:
        remaining = target_none - len(human_none)
        sampled_none = extracted_none.sample(
            n=min(remaining, len(extracted_none)), random_state=RANDOM_STATE)
        none_all = pd.concat([human_none, sampled_none])

    min_class = min(len(female_all), len(male_all), len(none_all))
    print(f"Balancing to: {min_class} per class")

    female_balanced = female_all.sample(n=min_class, random_state=RANDOM_STATE)
    male_balanced = male_all.sample(n=min_class, random_state=RANDOM_STATE)
    none_balanced = none_all.sample(n=min_class, random_state=RANDOM_STATE)

    ml_data = pd.concat([female_balanced, male_balanced, none_balanced]).copy()
    ml_data['target'] = ml_data['label_extracted'].map({'female': 0, 'male': 1})
    ml_data.loc[ml_data['target'].isna(), 'target'] = 2
    ml_data['target'] = ml_data['target'].astype(int)
    print(f"Training data: {len(ml_data)} (classes: {ml_data['target'].value_counts().sort_index().to_dict()})")

    # ---- Train/test split ----
    train_idx, test_idx = train_test_split(
        ml_data.index, test_size=TEST_SIZE,
        random_state=RANDOM_STATE, stratify=ml_data['target'],
    )
    train_data = ml_data.loc[train_idx].copy()
    test_data = ml_data.loc[test_idx].copy()
    print(f"Train: {len(train_data)}, Test: {len(test_data)}")

    # ---- Feature engineering ----
    all_datasets = [train_data, test_data, df]
    for dataset in all_datasets:
        dataset['breadcrumb_clean'] = dataset[COL_BREADCRUMB].apply(clean_text_remove_gender)
        dataset['description_clean'] = dataset[COL_DESC].apply(clean_text_remove_gender)

    # Price features
    price_features = ['feat_price_log', 'feat_unit_price']
    for dataset in all_datasets:
        dataset['feat_price'] = pd.to_numeric(dataset[COL_PRICE], errors='coerce')
        dataset['feat_price_log'] = np.log1p(dataset['feat_price'])
        if COL_UNIT_PRICE in dataset.columns:
            dataset['feat_unit_price'] = (
                dataset[COL_UNIT_PRICE].astype(str).str.extract(r'([\d.]+)')[0].astype(float))
        else:
            dataset['feat_unit_price'] = 0

    # Store encoding
    store_encoder = LabelEncoder()
    all_stores = pd.concat([d[COL_STORE] for d in all_datasets]).fillna('unknown')
    store_encoder.fit(all_stores.unique())

    def encode_stores(data, encoder):
        stores = data[COL_STORE].fillna('unknown')
        encoded = []
        for s in stores:
            if s in encoder.classes_:
                encoded.append(encoder.transform([s])[0])
            else:
                encoded.append(-1)
        return np.array(encoded)

    for dataset in all_datasets:
        dataset['store_encoded'] = encode_stores(dataset, store_encoder)
    n_stores = len(store_encoder.classes_) + 1

    # TF-IDF (fit on train only)
    breadcrumb_vectorizer = TfidfVectorizer(
        max_features=80, min_df=8, max_df=0.8, ngram_range=(1, 1), stop_words='english')
    breadcrumb_vectorizer.fit(train_data['breadcrumb_clean'])

    description_vectorizer = TfidfVectorizer(
        max_features=150, min_df=8, max_df=0.8, ngram_range=(1, 1), stop_words='english')
    description_vectorizer.fit(train_data['description_clean'])

    print(f"Breadcrumb TF-IDF: {len(breadcrumb_vectorizer.get_feature_names_out())} features")
    print(f"Description TF-IDF: {len(description_vectorizer.get_feature_names_out())} features")

    # Color features
    color_feature_cols = []
    for color_name in STANDARD_COLORS.keys():
        for i in range(1, N_COLORS + 1):
            color_feature_cols.append(f'feat_color{i}_{color_name}')

    for dataset in all_datasets:
        for col in color_feature_cols:
            dataset[col] = 0.0

    color_lookup = {}
    if len(color_df) > 0 and COL_PRODUCT_ID in color_df.columns:
        for _, row in color_df.iterrows():
            pid = row[COL_PRODUCT_ID]
            feats = {}
            for i in range(1, N_COLORS + 1):
                cname = row.get(f'color{i}_name')
                cweight = row.get(f'color{i}_weight')
                if pd.notna(cname) and cname in STANDARD_COLORS and pd.notna(cweight):
                    feats[f'feat_color{i}_{cname}'] = cweight
            if feats:
                color_lookup[pid] = feats

        print(f"Color lookup built: {len(color_lookup):,} products")
        for dataset in all_datasets:
            matched = 0
            for idx_row, row in dataset.iterrows():
                pid = row[COL_PRODUCT_ID]
                if pid in color_lookup:
                    for col, val in color_lookup[pid].items():
                        dataset.at[idx_row, col] = val
                    matched += 1
            if len(dataset) < 10000:
                print(f"  Color features filled for {matched}/{len(dataset)} rows")

    train_has_color = sum(1 for _, r in train_data.iterrows() if r[COL_PRODUCT_ID] in color_lookup)
    print(f"Color features: {len(color_feature_cols)} "
          f"(available for {train_has_color}/{len(train_data)} train samples)")

    # ---- Build feature matrices ----
    def build_feature_matrix(data, bc_vec, desc_vec, color_cols, p_features,
                             ns, include_colors=True):
        """Assemble the sparse design matrix and the matching column names.

        Column blocks, in order: price features, one-hot store, breadcrumb
        TF-IDF, description TF-IDF and (optionally) color weights.

        Args:
            data: DataFrame holding ``p_features``, ``store_encoded``,
                ``breadcrumb_clean``, ``description_clean`` and ``color_cols``.
            bc_vec: Fitted vectorizer applied to ``breadcrumb_clean``.
            desc_vec: Fitted vectorizer applied to ``description_clean``.
            color_cols: Names of the color feature columns.
            p_features: Names of the price feature columns (NaN is filled with 0).
            ns: Width of the store one-hot block; its last column collects
                every row whose store code is not >= 0.
            include_colors: When False the color block is left out.

        Returns:
            Tuple ``(X, feature_names)``: ``X`` is a CSR matrix with one row per
            row of ``data``; ``feature_names`` labels its columns in order.
        """
        n_rows = len(data)
        blocks = []
        feature_names = []

        # Price block.
        blocks.append(csr_matrix(data[p_features].fillna(0).values))
        feature_names.extend(p_features)

        # Store block: built directly as a sparse one-hot (one entry per row)
        # instead of filling a dense array row by row in Python.
        store_codes = np.asarray(data['store_encoded'].values)
        store_cols = np.where(store_codes >= 0, store_codes, ns - 1).astype(np.intp)
        X_store = csr_matrix(
            (np.ones(n_rows), (np.arange(n_rows), store_cols)),
            shape=(n_rows, ns))
        blocks.append(X_store)
        feature_names.extend([f'store_{i}' for i in range(ns)])

        # Text blocks: the vectorizers are only applied here, never re-fitted.
        blocks.append(bc_vec.transform(data['breadcrumb_clean']))
        feature_names.extend([f'bc_{f}' for f in bc_vec.get_feature_names_out()])

        blocks.append(desc_vec.transform(data['description_clean']))
        feature_names.extend([f'desc_{f}' for f in desc_vec.get_feature_names_out()])

        if include_colors:
            blocks.append(csr_matrix(data[color_cols].values))
            feature_names.extend(color_cols)

        # Request CSR explicitly: callers slice rows out of the result, which
        # requires a row-indexable format regardless of the block formats.
        return hstack(blocks, format='csr'), feature_names

    X_train, feature_names = build_feature_matrix(
        train_data, breadcrumb_vectorizer, description_vectorizer,
        color_feature_cols, price_features, n_stores)
    X_test, _ = build_feature_matrix(
        test_data, breadcrumb_vectorizer, description_vectorizer,
        color_feature_cols, price_features, n_stores)

    y_train = train_data['target'].values
    y_test = test_data['target'].values
    print(f"X_train: {X_train.shape}, X_test: {X_test.shape}")

    # ---- Train models ----
    results = []

    print("\n--- Logistic Regression (L1) ---")
    model_l1 = LogisticRegressionCV(
        cv=CV_FOLDS, penalty='l1', solver='saga', max_iter=2000,
        multi_class='multinomial', class_weight='balanced', random_state=RANDOM_STATE)
    model_l1.fit(X_train, y_train)
    y_pred_l1 = model_l1.predict(X_test)
    acc_l1 = accuracy_score(y_test, y_pred_l1)
    f1_l1 = f1_score(y_test, y_pred_l1, average='weighted')
    print(f"Accuracy: {acc_l1:.4f}, F1: {f1_l1:.4f}")
    results.append({'Model': 'L1 (LASSO)', 'Accuracy': acc_l1, 'F1_weighted': f1_l1})

    print("\n--- Logistic Regression (L2) ---")
    model_l2 = LogisticRegressionCV(
        cv=CV_FOLDS, penalty='l2', solver='lbfgs', max_iter=2000,
        multi_class='multinomial', class_weight='balanced', random_state=RANDOM_STATE)
    model_l2.fit(X_train, y_train)
    y_pred_l2 = model_l2.predict(X_test)
    acc_l2 = accuracy_score(y_test, y_pred_l2)
    f1_l2 = f1_score(y_test, y_pred_l2, average='weighted')
    print(f"Accuracy: {acc_l2:.4f}, F1: {f1_l2:.4f}")
    results.append({'Model': 'L2 (Ridge)', 'Accuracy': acc_l2, 'F1_weighted': f1_l2})

    print("\n--- Random Forest ---")
    model_rf = RandomForestClassifier(
        n_estimators=200, max_depth=15, min_samples_split=10, min_samples_leaf=5,
        class_weight='balanced', random_state=RANDOM_STATE, n_jobs=-1)
    model_rf.fit(X_train, y_train)
    y_pred_rf = model_rf.predict(X_test)
    acc_rf = accuracy_score(y_test, y_pred_rf)
    f1_rf = f1_score(y_test, y_pred_rf, average='weighted')
    print(f"Accuracy: {acc_rf:.4f}, F1: {f1_rf:.4f}")
    results.append({'Model': 'Random Forest', 'Accuracy': acc_rf, 'F1_weighted': f1_rf})

    print("\n--- Histogram Gradient Boosting ---")
    MAX_HGB_SAMPLES = 50000
    if X_train.shape[0] > MAX_HGB_SAMPLES:
        sample_idx = np.random.choice(X_train.shape[0], MAX_HGB_SAMPLES, replace=False)
        X_train_hgb = X_train[sample_idx].toarray()
        y_train_hgb = y_train[sample_idx]
    else:
        X_train_hgb = X_train.toarray()
        y_train_hgb = y_train

    model_hgb = HistGradientBoostingClassifier(
        max_iter=200, max_depth=10, learning_rate=0.1, random_state=RANDOM_STATE)
    model_hgb.fit(X_train_hgb, y_train_hgb)
    y_pred_hgb = model_hgb.predict(X_test.toarray())
    acc_hgb = accuracy_score(y_test, y_pred_hgb)
    f1_hgb = f1_score(y_test, y_pred_hgb, average='weighted')
    print(f"Accuracy: {acc_hgb:.4f}, F1: {f1_hgb:.4f}")
    results.append({'Model': 'Hist Gradient Boosting', 'Accuracy': acc_hgb, 'F1_weighted': f1_hgb})

    print("\n--- SVM (RBF) ---")
    scaler = StandardScaler(with_mean=False)
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    model_svm = SVC(
        kernel='rbf', C=1.0, gamma='scale', class_weight='balanced',
        probability=True, random_state=RANDOM_STATE)
    model_svm.fit(X_train_scaled, y_train)
    y_pred_svm = model_svm.predict(X_test_scaled)
    acc_svm = accuracy_score(y_test, y_pred_svm)
    f1_svm = f1_score(y_test, y_pred_svm, average='weighted')
    print(f"Accuracy: {acc_svm:.4f}, F1: {f1_svm:.4f}")
    results.append({'Model': 'SVM', 'Accuracy': acc_svm, 'F1_weighted': f1_svm})

    results_df = pd.DataFrame(results).sort_values('F1_weighted', ascending=False)
    print(f"\n{'='*70}\nMODEL COMPARISON\n{'='*70}")
    print(results_df.to_string(index=False))
    results_df.to_csv(TEXT_DIR / 'model_comparison.csv', index=False)

    # ---- Best model analysis ----
    best_name = results_df.iloc[0]['Model']
    print(f"\nBest model: {best_name}")

    model_map = {
        'L1': (model_l1, y_pred_l1), 'L2': (model_l2, y_pred_l2),
        'Random': (model_rf, y_pred_rf), 'Hist': (model_hgb, y_pred_hgb),
        'SVM': (model_svm, y_pred_svm),
    }
    best_model, y_pred_best = model_l1, y_pred_l1  # default
    for key, (model, preds) in model_map.items():
        if key in best_name:
            best_model, y_pred_best = model, preds
            break

    if len(np.unique(y_test)) == 3:
        print("\nClassification Report:")
        print(classification_report(y_test, y_pred_best, target_names=['female', 'male', 'none']))

    cm = confusion_matrix(y_test, y_pred_best, labels=[0, 1, 2])
    print("Confusion Matrix:")
    print(f"            Predicted")
    print(f"            female  male  none")
    for i, label in enumerate(['female', 'male', 'none']):
        row = cm[i] if i < len(cm) else [0, 0, 0]
        print(f"Actual {label:6s}  {row[0]:4d}  {row[1]:4d}  {row[2]:4d}")

    # ---- Feature importance ----
    importance_df = pd.DataFrame({
        'feature': feature_names,
        'coef_female': model_l1.coef_[0],
        'coef_male': model_l1.coef_[1],
        'coef_none': model_l1.coef_[2],
    })
    importance_df['max_abs'] = importance_df[['coef_female', 'coef_male', 'coef_none']].abs().max(axis=1)
    importance_df = importance_df.sort_values('max_abs', ascending=False)

    for label, col in [('FEMALE', 'coef_female'), ('MALE', 'coef_male'), ('NONE', 'coef_none')]:
        top = importance_df[importance_df[col] > 0].nlargest(15, col)
        print(f"\nTop 15 {label} features:")
        for _, row in top.iterrows():
            print(f"  {row['feature']:40s}: {row[col]:+.4f}")

    importance_df.to_csv(TEXT_DIR / 'feature_importance.csv', index=False)

    # ---- Predict on all products ----
    X_all, _ = build_feature_matrix(
        df, breadcrumb_vectorizer, description_vectorizer,
        color_feature_cols, price_features, n_stores)

    df['ml_prob_female'] = model_l1.predict_proba(X_all)[:, 0]
    df['ml_prob_male'] = model_l1.predict_proba(X_all)[:, 1]
    df['ml_prob_none'] = model_l1.predict_proba(X_all)[:, 2]
    df['ml_pred'] = model_l1.predict(X_all)
    df['ml_pred_label'] = df['ml_pred'].map({0: 'female', 1: 'male', 2: 'none'})
    df['ml_confidence'] = df[['ml_prob_female', 'ml_prob_male', 'ml_prob_none']].max(axis=1)

    print(f"\nPrediction distribution:\n{df['ml_pred_label'].value_counts()}")

    # ---- Color feature impact ----
    X_train_no_color, _ = build_feature_matrix(
        train_data, breadcrumb_vectorizer, description_vectorizer,
        color_feature_cols, price_features, n_stores, include_colors=False)
    X_test_no_color, _ = build_feature_matrix(
        test_data, breadcrumb_vectorizer, description_vectorizer,
        color_feature_cols, price_features, n_stores, include_colors=False)

    model_no_color = LogisticRegressionCV(
        cv=CV_FOLDS, penalty='l1', solver='saga', max_iter=2000,
        multi_class='multinomial', class_weight='balanced', random_state=RANDOM_STATE)
    model_no_color.fit(X_train_no_color, y_train)
    y_pred_no_color = model_no_color.predict(X_test_no_color)

    acc_no_color = accuracy_score(y_test, y_pred_no_color)
    f1_no_color = f1_score(y_test, y_pred_no_color, average='weighted')
    print(f"\nColor impact:")
    print(f"  WITH colors:    Accuracy={acc_l1:.4f}, F1={f1_l1:.4f}")
    print(f"  WITHOUT colors: Accuracy={acc_no_color:.4f}, F1={f1_no_color:.4f}")
    print(f"  Delta:          Accuracy {(acc_l1-acc_no_color)*100:+.2f}pp, "
          f"F1 {(f1_l1-f1_no_color)*100:+.2f}pp")

    # ---- Morrisons-only HGB color ablation (fairer test) ----
    # The above comparison underestimates color's contribution because:
    #   - It uses L1 (weakest model) instead of HGB (best model)
    #   - Most training rows have zero-filled color columns (Tesco/ASDA)
    # Here we restrict to products with actual color data and use HGB.
    morrisons_pids = set(color_df[COL_PRODUCT_ID].values) if len(color_df) > 0 else set()
    train_morr = train_data[train_data[COL_PRODUCT_ID].isin(morrisons_pids)]
    test_morr = test_data[test_data[COL_PRODUCT_ID].isin(morrisons_pids)]

    if len(train_morr) >= 30 and len(test_morr) >= 10:
        print(f"\nMorrisons-only color ablation (HGB):  train={len(train_morr)}, test={len(test_morr)}")
        y_train_morr = train_morr['target'].values
        y_test_morr = test_morr['target'].values

        # With color
        X_train_morr_wc, _ = build_feature_matrix(
            train_morr, breadcrumb_vectorizer, description_vectorizer,
            color_feature_cols, price_features, n_stores, include_colors=True)
        X_test_morr_wc, _ = build_feature_matrix(
            test_morr, breadcrumb_vectorizer, description_vectorizer,
            color_feature_cols, price_features, n_stores, include_colors=True)

        hgb_morr_wc = HistGradientBoostingClassifier(
            max_iter=200, max_depth=10, learning_rate=0.1, random_state=RANDOM_STATE)
        hgb_morr_wc.fit(X_train_morr_wc.toarray(), y_train_morr)
        pred_morr_wc = hgb_morr_wc.predict(X_test_morr_wc.toarray())
        acc_morr_wc = accuracy_score(y_test_morr, pred_morr_wc)
        f1_morr_wc = f1_score(y_test_morr, pred_morr_wc, average='weighted')

        # Without color
        X_train_morr_nc, _ = build_feature_matrix(
            train_morr, breadcrumb_vectorizer, description_vectorizer,
            color_feature_cols, price_features, n_stores, include_colors=False)
        X_test_morr_nc, _ = build_feature_matrix(
            test_morr, breadcrumb_vectorizer, description_vectorizer,
            color_feature_cols, price_features, n_stores, include_colors=False)

        hgb_morr_nc = HistGradientBoostingClassifier(
            max_iter=200, max_depth=10, learning_rate=0.1, random_state=RANDOM_STATE)
        hgb_morr_nc.fit(X_train_morr_nc.toarray(), y_train_morr)
        pred_morr_nc = hgb_morr_nc.predict(X_test_morr_nc.toarray())
        acc_morr_nc = accuracy_score(y_test_morr, pred_morr_nc)
        f1_morr_nc = f1_score(y_test_morr, pred_morr_nc, average='weighted')

        morr_delta_acc = (acc_morr_wc - acc_morr_nc) * 100
        morr_delta_f1 = (f1_morr_wc - f1_morr_nc) * 100

        print(f"  HGB WITH colors:    Accuracy={acc_morr_wc:.4f}, F1={f1_morr_wc:.4f}")
        print(f"  HGB WITHOUT colors: Accuracy={acc_morr_nc:.4f}, F1={f1_morr_nc:.4f}")
        print(f"  Delta (Morrisons HGB): Accuracy {morr_delta_acc:+.2f}pp, F1 {morr_delta_f1:+.2f}pp")

        # Save to summary later
        morr_color_ablation = {
            'n_train': int(len(train_morr)),
            'n_test': int(len(test_morr)),
            'hgb_with_colors_acc': float(acc_morr_wc),
            'hgb_with_colors_f1': float(f1_morr_wc),
            'hgb_without_colors_acc': float(acc_morr_nc),
            'hgb_without_colors_f1': float(f1_morr_nc),
            'delta_acc_pp': float(morr_delta_acc),
            'delta_f1_pp': float(morr_delta_f1),
        }
    else:
        print(f"\nMorrisons-only ablation skipped: insufficient data "
              f"(train={len(train_morr)}, test={len(test_morr)})")
        morr_color_ablation = None

    # ---- Validation vs human labels ----
    human_labeled = df[df['label_human'].notna()].copy()
    print(f"\nProducts with human labels: {len(human_labeled)}")
    if len(human_labeled) > 0:
        human_labeled['human_encoded'] = human_labeled['label_human'].map(
            {'female': 0, 'male': 1, 'none': 2})
        valid = human_labeled[human_labeled['human_encoded'].notna()]
        if len(valid) >= MIN_TEST_SAMPLES:
            acc_human = accuracy_score(valid['human_encoded'], valid['ml_pred'])
            print(f"Accuracy vs human: {acc_human:.4f}")

    # ---- Implicit gendering ----
    implicit_female = df[
        (df['label_extracted'] == 'none') &
        (df['ml_pred_label'] == 'female') &
        (df['ml_confidence'] > 0.5)]
    implicit_male = df[
        (df['label_extracted'] == 'none') &
        (df['ml_pred_label'] == 'male') &
        (df['ml_confidence'] > 0.5)]

    print(f"\nImplicit female (>50% conf): {len(implicit_female):,}")
    print(f"Implicit male (>50% conf): {len(implicit_male):,}")

    # ---- Export validation sample ----
    already_labeled = set()
    if COL_PRODUCT_ID in human_coded.columns:
        already_labeled.update(human_coded[COL_PRODUCT_ID].values)
    if len(your_labeled) > 0 and COL_PRODUCT_ID in your_labeled.columns:
        already_labeled.update(your_labeled[COL_PRODUCT_ID].values)

    available = df[
        (~df[COL_PRODUCT_ID].isin(already_labeled)) & (df[COL_IMAGE].notna())].copy()
    print(f"\nAvailable for validation: {len(available):,}")

    if len(available) > 0:
        N_PER = 85
        samples = []
        for pred, label in [(0, 'female'), (1, 'male'), (2, 'none')]:
            pool = available[available['ml_pred'] == pred]
            n = min(N_PER, len(pool))
            if n > 0:
                samples.append(pool.sample(n=n, random_state=RANDOM_STATE))
                print(f"  Sampled {n} {label}")

        if samples:
            validation = pd.concat(samples).sample(frac=1, random_state=RANDOM_STATE)
            export_cols = [
                COL_PRODUCT_ID, COL_NAME, COL_DESC, COL_BREADCRUMB, COL_IMAGE,
                COL_URL, COL_PRICE, 'label_extracted', 'ml_pred_label',
                'ml_prob_female', 'ml_prob_male', 'ml_prob_none', 'ml_confidence',
            ]
            export_cols = [c for c in export_cols if c in validation.columns]
            validation_export = validation[export_cols].copy()
            validation_export['manual_gender'] = ''
            validation_export['manual_confidence'] = ''
            validation_export['manual_notes'] = ''
            validation_export.to_csv(TEXT_DIR / 'validation_sample.csv', index=True)
            print(f"Saved validation sample: {len(validation_export)} products")

    # ---- Summary ----
    summary = {
        'version': '4.0',
        'data': {
            'original': original_count,
            'filtered': int(len(df)),
            'excluded': int(excluded_count),
            'training_samples': int(len(ml_data)),
            'color_samples': int(len(color_df)),
            'color_coverage_note': 'Morrisons only; Tesco/ASDA CDN links expired',
        },
        'models': results,
        'color_impact': {
            'full_sample_l1': {
                'with_colors_f1': float(f1_l1),
                'without_colors_f1': float(f1_no_color),
            },
            'morrisons_only_hgb': morr_color_ablation,
        },
        'predictions': {
            'female': int((df['ml_pred_label'] == 'female').sum()),
            'male': int((df['ml_pred_label'] == 'male').sum()),
            'none': int((df['ml_pred_label'] == 'none').sum()),
        },
    }
    with open(TEXT_DIR / 'summary.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"\nPipeline complete. Charts in {CHART_ML_DIR}, tables in {TEXT_DIR}")


# ============================================================================
# STAGE 3: REGRESSION ANALYSIS
# ============================================================================

def run_regression_analysis():
    """Stage 3: OLS, quantile, within-category, and by-store regressions."""
    import statsmodels.api as sm
    import statsmodels.formula.api as smf
    from sklearn.feature_extraction.text import TfidfVectorizer

    CHART_REG_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 70)
    print("STAGE 3: REGRESSION ANALYSIS")
    print("=" * 70)

    # ---- Load and prepare ----
    df = load_main_data()
    df = filter_excluded_categories(df)

    for col_label, col_source in [('label_bc', COL_BREADCRUMB),
                                   ('label_name', COL_NAME),
                                   ('label_desc', COL_DESC)]:
        df[col_label] = df[col_source].apply(extract_gender_from_text)

    def combine_labels(row):
        """Return the first 'female'/'male' label found, else 'none'.

        Sources are consulted in priority order: breadcrumb, name, description.
        """
        ordered_labels = (
            row[source] for source in ('label_bc', 'label_name', 'label_desc'))
        return next(
            (label for label in ordered_labels if label in ('female', 'male')),
            'none')

    df['gender'] = df.apply(combine_labels, axis=1)

    df['price_num'] = pd.to_numeric(df[COL_PRICE], errors='coerce')
    df = df[df['price_num'].notna() & (df['price_num'] > 0)].copy()
    df['log_price'] = np.log(df['price_num'])

    if COL_UNIT_PRICE in df.columns:
        df['unit_price_num'] = df[COL_UNIT_PRICE].astype(str).str.extract(r'([\d.]+)')[0].astype(float)

    df['store'] = df[COL_STORE].fillna('unknown').astype(str)

    # Infer store names
    store_names = {}
    for sid in df['store'].unique():
        sub = df[df['store'] == sid]
        bc_sample = sub[COL_BREADCRUMB].dropna().head(10).str.lower()
        if bc_sample.str.contains('morrisons').any():
            store_names[sid] = 'Morrisons'
        elif bc_sample.str.contains('tesco').any():
            store_names[sid] = 'Tesco'
        elif bc_sample.str.contains('asda').any():
            store_names[sid] = 'ASDA'
        else:
            store_names[sid] = f'Store {sid}'
    df['store_name'] = df['store'].map(store_names)

    # Breadcrumb parsing
    df[['cat1', 'cat2', 'cat3']] = df[COL_BREADCRUMB].apply(
        lambda x: pd.Series(parse_breadcrumb(x)))
    df['cat_broad'] = df['cat1']
    df['cat_mid'] = df['cat1'] + ' > ' + df['cat2']
    df['cat_fine'] = df['cat1'] + ' > ' + df['cat2'] + ' > ' + df['cat3']

    print(f"Products with valid prices: {len(df):,}")
    print(f"\nGender distribution:")
    for g in ['female', 'male', 'none', 'both']:
        sub = df[df['gender'] == g]
        if len(sub) > 0:
            print(f"  {g}: {len(sub):,}  (mean {sub['price_num'].mean():.2f}, "
                  f"median {sub['price_num'].median():.2f})")

    # ---- Analysis sample ----
    gendered = df[df['gender'].isin(['female', 'male'])].copy()
    gendered['is_female'] = (gendered['gender'] == 'female').astype(int)
    print(f"\nGendered sample: N = {len(gendered):,} "
          f"(F: {gendered['is_female'].sum()}, M: {(1-gendered['is_female']).sum()})")

    # ---- Regressions ----
    results_table = []

    def run_and_record(name, formula, data, controls, coef_name='is_female'):
        """Fit one OLS specification, log it, and append it to ``results_table``.

        The model is fitted with ``cov_type='HC1'``. The coefficient on
        ``coef_name`` is recorded together with a +/-1.96*SE interval, and both
        are also expressed as percentage effects via ``(exp(b) - 1) * 100``.

        Args:
            name: Label of the specification (used in the table and the log line).
            formula: Formula string passed to ``smf.ols``.
            data: DataFrame the model is fitted on.
            controls: Free-text description of the controls, stored as-is.
            coef_name: Name of the parameter to report.

        Returns:
            The fitted results object.
        """
        def as_percent(log_effect):
            return (np.exp(log_effect) - 1) * 100

        fitted = smf.ols(formula, data=data).fit(cov_type='HC1')
        beta = fitted.params[coef_name]
        std_err = fitted.bse[coef_name]
        p_value = fitted.pvalues[coef_name]
        n_obs = int(fitted.nobs)

        lower = beta - 1.96 * std_err
        upper = beta + 1.96 * std_err
        effect_pct = as_percent(beta)

        record = {
            'spec': name,
            'coef': beta,
            'se': std_err,
            'p': p_value,
            'pct': effect_pct,
            'ci_lo': lower,
            'ci_hi': upper,
            'pct_lo': as_percent(lower),
            'pct_hi': as_percent(upper),
            'r2': fitted.rsquared,
            'n': n_obs,
            'controls': controls,
        }
        results_table.append(record)

        # Conventional significance stars at the 1% / 5% / 10% levels.
        if p_value < 0.01:
            stars = '***'
        elif p_value < 0.05:
            stars = '**'
        elif p_value < 0.1:
            stars = '*'
        else:
            stars = ''

        print(f"  {name}: coef={beta:+.4f} ({effect_pct:+.1f}%), SE={std_err:.4f}, "
              f"p={p_value:.4f}{stars}, R2={fitted.rsquared:.3f}, N={n_obs}")
        return fitted

    print("\nSpec 1: Raw gap")
    spec1 = run_and_record('(1) Raw gap', 'log_price ~ is_female', gendered, 'None')

    print("\nSpec 2: + Store FE")
    run_and_record('(2) + Store FE', 'log_price ~ is_female + C(store)', gendered, 'Store')

    # Broad category
    cat_counts = gendered['cat_broad'].value_counts()
    valid_broad = cat_counts[cat_counts >= 5].index
    gen_broad = gendered[gendered['cat_broad'].isin(valid_broad)].copy()
    print(f"\nSpec 3: + Broad cat FE (N cats: {len(valid_broad)})")
    run_and_record('(3) + Broad cat FE',
                   'log_price ~ is_female + C(store) + C(cat_broad)',
                   gen_broad, 'Store + Broad cat')

    # Mid category
    cat_counts = gendered['cat_mid'].value_counts()
    valid_mid = cat_counts[cat_counts >= 5].index
    gen_mid = gendered[gendered['cat_mid'].isin(valid_mid)].copy()
    print(f"\nSpec 4: + Mid cat FE (N cats: {len(valid_mid)})")
    run_and_record('(4) + Mid cat FE',
                   'log_price ~ is_female + C(store) + C(cat_mid)',
                   gen_mid, 'Store + Mid cat')

    # Fine category
    cat_counts = gendered['cat_fine'].value_counts()
    valid_fine = cat_counts[cat_counts >= 5].index
    gen_fine = gendered[gendered['cat_fine'].isin(valid_fine)].copy()
    print(f"\nSpec 5: + Fine cat FE (N cats: {len(valid_fine)})")
    if len(valid_fine) > 0 and len(gen_fine) > 50:
        run_and_record('(5) + Fine cat FE',
                       'log_price ~ is_female + C(store) + C(cat_fine)',
                       gen_fine, 'Store + Fine cat')

    # Description TF-IDF
    gen_mid['desc_clean'] = gen_mid[COL_DESC].apply(clean_text_remove_gender)
    desc_vec = TfidfVectorizer(max_features=100, min_df=5, max_df=0.9,
                                ngram_range=(1, 2), stop_words='english')
    X_desc = desc_vec.fit_transform(gen_mid['desc_clean'])

    y6 = gen_mid['log_price'].values
    X6_parts = [
        gen_mid[['is_female']].values,
        pd.get_dummies(gen_mid['store'], prefix='store', drop_first=True).values,
        pd.get_dummies(gen_mid['cat_mid'], prefix='cat', drop_first=True).values,
        X_desc.toarray(),
    ]
    X6 = sm.add_constant(np.hstack(X6_parts))
    spec6_model = sm.OLS(y6, X6).fit(cov_type='HC1')

    coef6 = spec6_model.params[1]
    se6 = spec6_model.bse[1]
    pval6 = spec6_model.pvalues[1]
    results_table.append({
        'spec': '(6) + Description', 'coef': coef6, 'se': se6, 'p': pval6,
        'pct': (np.exp(coef6) - 1) * 100,
        'ci_lo': coef6 - 1.96 * se6, 'ci_hi': coef6 + 1.96 * se6,
        'pct_lo': (np.exp(coef6 - 1.96 * se6) - 1) * 100,
        'pct_hi': (np.exp(coef6 + 1.96 * se6) - 1) * 100,
        'r2': spec6_model.rsquared, 'n': int(spec6_model.nobs),
        'controls': 'Store + Mid cat + Description',
    })
    sig6 = '***' if pval6 < 0.01 else ('**' if pval6 < 0.05 else ('*' if pval6 < 0.1 else ''))
    print(f"\nSpec 6: + Description TF-IDF")
    print(f"  (6) + Description: coef={coef6:+.4f} ({(np.exp(coef6)-1)*100:+.1f}%), "
          f"SE={se6:.4f}, p={pval6:.4f}{sig6}, R2={spec6_model.rsquared:.3f}")

    # Female x Store interaction
    print(f"\nSpec 7: Female x Store interaction")
    spec7 = smf.ols('log_price ~ is_female * C(store) + C(cat_mid)',
                     data=gen_mid).fit(cov_type='HC1')
    print(f"  Main effect (is_female): {spec7.params['is_female']:+.4f} "
          f"(p={spec7.pvalues['is_female']:.4f})")
    for param in spec7.params.index:
        if 'is_female:' in param:
            store_id = param.split('[T.')[1].rstrip(']')
            sname = store_names.get(store_id, store_id)
            total_effect = spec7.params['is_female'] + spec7.params[param]
            pct_effect = (np.exp(total_effect) - 1) * 100
            print(f"  {sname}: total female effect = {total_effect:+.4f} ({pct_effect:+.1f}%), "
                  f"interaction p={spec7.pvalues[param]:.4f}")

    # Unit price
    print(f"\nSpec 8: Unit price regression")
    if 'unit_price_num' in gen_mid.columns:
        gen_unit = gen_mid[gen_mid['unit_price_num'].notna() & (gen_mid['unit_price_num'] > 0)].copy()
        gen_unit['log_unit_price'] = np.log(gen_unit['unit_price_num'])
        if len(gen_unit) >= 50:
            run_and_record('(8) Unit price',
                           'log_unit_price ~ is_female + C(store) + C(cat_mid)',
                           gen_unit, 'Store + Mid cat (unit price)')

    # Three-way comparison
    print(f"\nSpec 9: Three-way comparison (none = reference)")
    df_valid = df.copy()
    df_valid['is_female'] = (df_valid['gender'] == 'female').astype(int)
    df_valid['is_male'] = (df_valid['gender'] == 'male').astype(int)
    spec9 = smf.ols('log_price ~ is_female + is_male + C(store)',
                     data=df_valid).fit(cov_type='HC1')
    print(f"  is_female: {spec9.params['is_female']:+.4f} (p={spec9.pvalues['is_female']:.4f})")
    print(f"  is_male:   {spec9.params['is_male']:+.4f} (p={spec9.pvalues['is_male']:.4f})")

    # ---- Quantile regression ----
    print(f"\n{'='*70}\nQUANTILE REGRESSION\n{'='*70}")
    quantiles = [0.10, 0.25, 0.50, 0.75, 0.90]
    qreg_results = []

    for q in quantiles:
        qmodel = smf.quantreg('log_price ~ is_female + C(store)', data=gendered).fit(q=q)
        coef_q = qmodel.params['is_female']
        se_q = qmodel.bse['is_female']
        pval_q = qmodel.pvalues['is_female']
        pct_q = (np.exp(coef_q) - 1) * 100
        qreg_results.append({
            'quantile': q, 'coef': coef_q, 'se': se_q, 'p': pval_q, 'pct': pct_q,
            'ci_lo': coef_q - 1.96 * se_q, 'ci_hi': coef_q + 1.96 * se_q,
        })
        sig = '***' if pval_q < 0.01 else ('**' if pval_q < 0.05 else ('*' if pval_q < 0.1 else ''))
        print(f"  Q{q:.2f}: coef={coef_q:+.4f} ({pct_q:+.1f}%), p={pval_q:.4f}{sig}")

    qreg_df = pd.DataFrame(qreg_results)

    print("\n  With mid-category controls:")
    qreg_cat_results = []
    for q in quantiles:
        try:
            qmodel = smf.quantreg('log_price ~ is_female + C(store) + C(cat_mid)',
                                    data=gen_mid).fit(q=q, max_iter=5000)
            coef_q = qmodel.params['is_female']
            se_q = qmodel.bse['is_female']
            pval_q = qmodel.pvalues['is_female']
            pct_q = (np.exp(coef_q) - 1) * 100
            qreg_cat_results.append({
                'quantile': q, 'coef': coef_q, 'se': se_q, 'p': pval_q, 'pct': pct_q,
                'ci_lo': coef_q - 1.96 * se_q, 'ci_hi': coef_q + 1.96 * se_q,
            })
            sig = '***' if pval_q < 0.01 else ('**' if pval_q < 0.05 else ('*' if pval_q < 0.1 else ''))
            print(f"  Q{q:.2f}: coef={coef_q:+.4f} ({pct_q:+.1f}%), p={pval_q:.4f}{sig}")
        except Exception as e:
            print(f"  Q{q:.2f}: failed ({e})")
    qreg_cat_df = pd.DataFrame(qreg_cat_results) if qreg_cat_results else pd.DataFrame()

    # ---- Within-category analysis (bootstrap CIs) ----
    print(f"\n{'='*70}\nWITHIN-CATEGORY ANALYSIS (bootstrap CIs)\n{'='*70}")
    rng = np.random.default_rng(RANDOM_STATE)
    category_gaps = []

    for cat in gendered['cat_mid'].unique():
        sub = gendered[gendered['cat_mid'] == cat]
        fem = sub[sub['is_female'] == 1]['price_num']
        mal = sub[sub['is_female'] == 0]['price_num']

        if len(fem) >= 3 and len(mal) >= 3:
            gap_pct = (fem.mean() / mal.mean() - 1) * 100
            log_gap = np.log(fem).mean() - np.log(mal).mean()

            boot_gaps = []
            for _ in range(N_BOOTSTRAP):
                f_boot = rng.choice(fem.values, size=len(fem), replace=True)
                m_boot = rng.choice(mal.values, size=len(mal), replace=True)
                if m_boot.mean() > 0:
                    boot_gaps.append((f_boot.mean() / m_boot.mean() - 1) * 100)

            ci_lo = np.percentile(boot_gaps, 2.5) if boot_gaps else np.nan
            ci_hi = np.percentile(boot_gaps, 97.5) if boot_gaps else np.nan

            category_gaps.append({
                'category': cat, 'n_female': len(fem), 'n_male': len(mal),
                'n_total': len(fem) + len(mal),
                'mean_female': fem.mean(), 'mean_male': mal.mean(),
                'gap_pct': gap_pct, 'log_gap': log_gap,
                'ci_lo': ci_lo, 'ci_hi': ci_hi,
                'significant': (ci_lo > 0 and ci_hi > 0) or (ci_lo < 0 and ci_hi < 0),
            })

    gaps_df = pd.DataFrame(category_gaps).sort_values('gap_pct', ascending=False)

    if len(gaps_df) > 0:
        print(f"Categories with both genders (>=3 each): {len(gaps_df)}")
        weighted_gap = np.average(gaps_df['gap_pct'], weights=gaps_df['n_total'])
        median_gap = gaps_df['gap_pct'].median()
        print(f"  Weighted mean: {weighted_gap:+.1f}%")
        print(f"  Median: {median_gap:+.1f}%")
        gaps_df.to_csv(TEXT_DIR / 'within_category_gaps.csv', index=False)
    else:
        median_gap = 0

    # ---- By-store analysis ----
    print(f"\n{'='*70}\nPINK TAX BY STORE\n{'='*70}")
    store_results = []
    for store_id in gendered['store'].unique():
        sub = gendered[gendered['store'] == store_id]
        fem = sub[sub['is_female'] == 1]
        mal = sub[sub['is_female'] == 0]
        sname = store_names.get(store_id, store_id)
        if len(fem) >= 10 and len(mal) >= 10:
            model = smf.ols('log_price ~ is_female', data=sub).fit(cov_type='HC1')
            coef = model.params['is_female']
            pval = model.pvalues['is_female']
            store_results.append({
                'store': sname, 'store_id': store_id,
                'n_female': len(fem), 'n_male': len(mal),
                'mean_f': fem['price_num'].mean(), 'mean_m': mal['price_num'].mean(),
                'coef': coef, 'pct_gap': (np.exp(coef) - 1) * 100,
                'p_value': pval, 'significant': pval < 0.05,
            })

    store_df = pd.DataFrame(store_results).sort_values('pct_gap', ascending=False)
    if len(store_df) > 0:
        for _, row in store_df.iterrows():
            sig = '***' if row['p_value'] < 0.01 else ('**' if row['p_value'] < 0.05 else '')
            print(f"  {row['store']:<15s} F:{row['n_female']:>4.0f} M:{row['n_male']:>4.0f} "
                  f"gap={row['pct_gap']:>+7.1f}% p={row['p_value']:.4f}{sig}")
        store_df.to_csv(TEXT_DIR / 'pink_tax_by_store.csv', index=False)

    # ---- Summary table ----
    summary_df = pd.DataFrame(results_table)
    print(f"\n{'='*70}\nREGRESSION SUMMARY\n{'='*70}")
    print(f"{'Spec':<25s} {'Coef':>8s} {'%gap':>8s} {'95% CI':>18s} {'p':>8s} {'R2':>6s} {'N':>6s}")
    for _, row in summary_df.iterrows():
        sig = '***' if row['p'] < 0.01 else ('**' if row['p'] < 0.05 else ('*' if row['p'] < 0.1 else ''))
        print(f"{row['spec']:<25s} {row['coef']:>+7.4f} {row['pct']:>+7.1f}% "
              f"[{row['pct_lo']:>+6.1f}, {row['pct_hi']:>+6.1f}] "
              f"{row['p']:>7.4f}{sig:<3s} {row['r2']:>5.3f} {row['n']:>6.0f}")
    summary_df.to_csv(TEXT_DIR / 'regression_summary.csv', index=False)

    # ---- Charts (13 interactive Plotly visualisations) ----
    print(f"\n{'='*70}\nGENERATING CHARTS (Plotly)\n{'='*70}")

    # 1. Coefficient plot
    specs = summary_df['spec'].values
    coefs = summary_df['coef'].values
    ci_los = summary_df['ci_lo'].values
    ci_his = summary_df['ci_hi'].values
    pct_vals = summary_df['pct'].values
    p_vals_arr = summary_df['p'].values
    colors_01 = [PALETTE['female'] if c > 0 else PALETTE['male'] for c in coefs]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=specs, x=coefs, orientation='h',
        marker=dict(color=colors_01, opacity=0.75, line=dict(width=0.4, color='#333333')),
        error_x=dict(type='data', symmetric=False,
                     array=ci_his - coefs, arrayminus=coefs - ci_los,
                     color='#333333', thickness=1.0, width=3),
        hovertemplate='%{y}<br>Coef: %{x:.4f}<extra></extra>',
    ))
    for i in range(len(specs)):
        fw = 'bold' if p_vals_arr[i] < 0.05 else 'normal'
        fig.add_annotation(x=ci_his[i] + 0.01, y=specs[i],
                           text=f'{pct_vals[i]:+.1f}%', showarrow=False,
                           xanchor='left', font=dict(size=10, weight=fw))
    fig.add_vline(x=0, line=dict(color='#1a1a2e', width=0.8))
    fig.update_layout(**base_layout(height=STYLE['chart_height'],
                                     margin=STYLE['margin_bar']),
                       xaxis=styled_axis('Coefficient on female indicator (log price)'),
                       yaxis=styled_axis('', autorange='reversed'))
    save_html(fig, CHART_REG_DIR / '01_coefficient_plot.html')

    # 2. Quantile regression
    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.08,
                        subplot_titles=['Baseline', 'With category FE'])
    fig.add_trace(go.Scatter(
        x=qreg_df['quantile'], y=qreg_df['ci_hi'],
        mode='lines', line=dict(width=0), showlegend=False), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=qreg_df['quantile'], y=qreg_df['ci_lo'],
        mode='lines', line=dict(width=0), fill='tonexty',
        fillcolor='rgba(196,78,82,0.15)', showlegend=False), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=qreg_df['quantile'], y=qreg_df['coef'],
        mode='lines+markers', line=dict(color=PALETTE['female'], width=2),
        marker=dict(size=6), name='Baseline',
        hovertemplate='Q%{x:.2f}: %{y:.4f}<extra></extra>'), row=1, col=1)
    fig.add_hline(y=0, line=dict(color='black', width=0.8, dash='dash'), row=1, col=1)
    if len(qreg_cat_df) > 0:
        fig.add_trace(go.Scatter(
            x=qreg_cat_df['quantile'], y=qreg_cat_df['ci_hi'],
            mode='lines', line=dict(width=0), showlegend=False), row=1, col=2)
        fig.add_trace(go.Scatter(
            x=qreg_cat_df['quantile'], y=qreg_cat_df['ci_lo'],
            mode='lines', line=dict(width=0), fill='tonexty',
            fillcolor='rgba(196,78,82,0.15)', showlegend=False), row=1, col=2)
        fig.add_trace(go.Scatter(
            x=qreg_cat_df['quantile'], y=qreg_cat_df['coef'],
            mode='lines+markers', line=dict(color=PALETTE['female'], width=2),
            marker=dict(size=6), name='With FE',
            hovertemplate='Q%{x:.2f}: %{y:.4f}<extra></extra>'), row=1, col=2)
        fig.add_hline(y=0, line=dict(color='black', width=0.8, dash='dash'), row=1, col=2)
    fig.update_layout(**base_layout(height=STYLE['chart_height'], margin=dict(l=60, r=40, t=40, b=50)))
    fig.update_xaxes(**styled_axis('Quantile'))
    fig.update_yaxes(**styled_axis('Coefficient on female (log price)'))
    save_html(fig, CHART_REG_DIR / '02_quantile_regression.html')

    # 3. Within-category gaps
    if len(gaps_df) > 0:
        gaps_sorted = gaps_df.sort_values('gap_pct')
        labels_03 = [f"{r['category'][:42]} (F:{r['n_female']:.0f}, M:{r['n_male']:.0f})"
                     for _, r in gaps_sorted.iterrows()]
        colors_03 = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in gaps_sorted['gap_pct']]
        border_03 = ['black' if s else '#cccccc' for s in gaps_sorted['significant']]
        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=labels_03, x=gaps_sorted['gap_pct'].values, orientation='h',
            marker=dict(color=colors_03, opacity=0.6,
                        line=dict(width=[1.0 if s else 0.3 for s in gaps_sorted['significant']],
                                  color=border_03)),
            error_x=dict(type='data', symmetric=False,
                         array=gaps_sorted['ci_hi'].values - gaps_sorted['gap_pct'].values,
                         arrayminus=gaps_sorted['gap_pct'].values - gaps_sorted['ci_lo'].values,
                         color='black', thickness=0.8, width=3),
            hovertemplate='%{y}<br>Gap: %{x:+.1f}%<extra></extra>',
        ))
        fig.add_vline(x=0, line=dict(color='black', width=1))
        h = max(STYLE['chart_height_tall'], len(gaps_sorted) * 28)
        fig.update_layout(**base_layout(height=h, margin=dict(l=350, r=60, t=20, b=50)),
                           xaxis=styled_axis('Female price premium (%)'),
                           yaxis=styled_axis(''))
        save_html(fig, CHART_REG_DIR / '03_within_category_gaps.html')

    # 4. Price distributions (4 panels)
    fig = make_subplots(rows=2, cols=2, horizontal_spacing=0.08, vertical_spacing=0.1,
                        subplot_titles=['Price density', 'Log price density',
                                        'Price box plot', 'Price CDF'])
    for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
        sub = gendered[gendered['gender'] == gender]
        fig.add_trace(go.Histogram(
            x=sub['price_num'], nbinsx=50, histnorm='probability density',
            marker=dict(color=color, opacity=0.5),
            name=f'{gender.title()} (n={len(sub):,})', legendgroup=gender,
            hovertemplate='Price: %{x:.1f}<br>Density: %{y:.4f}<extra></extra>',
        ), row=1, col=1)
        fig.add_trace(go.Histogram(
            x=sub['log_price'], nbinsx=50, histnorm='probability density',
            marker=dict(color=color, opacity=0.5),
            name=f'{gender.title()}', legendgroup=gender, showlegend=False,
        ), row=1, col=2)
    for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
        sub = gendered[gendered['gender'] == gender]
        fig.add_trace(go.Box(
            y=sub['price_num'], name=gender.title(), marker_color=color,
            boxmean=True, showlegend=False,
            hovertemplate='%{y:.1f}<extra></extra>',
        ), row=2, col=1)
    for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
        sub = gendered[gendered['gender'] == gender]['price_num'].sort_values()
        cdf = np.arange(1, len(sub) + 1) / len(sub)
        fig.add_trace(go.Scatter(
            x=sub, y=cdf, mode='lines', line=dict(color=color, width=1.5),
            name=gender.title(), showlegend=False,
            hovertemplate='Price: %{x:.1f}<br>CDF: %{y:.2f}<extra></extra>',
        ), row=2, col=2)
    q95 = gendered['price_num'].quantile(0.95)
    fig.update_xaxes(range=[0, q95], row=1, col=1)
    fig.update_xaxes(range=[0, q95], row=2, col=2)
    fig.update_layout(**base_layout(height=700, margin=dict(l=60, r=40, t=40, b=50)),
                       barmode='overlay')
    save_html(fig, CHART_REG_DIR / '04_price_distributions.html')

    # 5. By store
    if len(store_df) > 0:
        fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.08,
                            subplot_titles=['Price premium by store', 'Mean price by store'])
        store_sorted = store_df.sort_values('pct_gap')
        colors_05 = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in store_sorted['pct_gap']]
        fig.add_trace(go.Bar(
            y=store_sorted['store'].values, x=store_sorted['pct_gap'].values, orientation='h',
            marker=dict(color=colors_05, opacity=0.7,
                        line=dict(width=[2 if s else 0.5 for s in store_sorted['significant']],
                                  color='black')),
            hovertemplate='%{y}<br>Gap: %{x:+.1f}%<extra></extra>', showlegend=False,
        ), row=1, col=1)
        fig.add_vline(x=0, line=dict(color='black', width=1), row=1, col=1)
        fig.add_trace(go.Bar(
            x=store_df['store'].values, y=store_df['mean_f'].values,
            marker=dict(color=PALETTE['female'], opacity=0.7), name='Female',
        ), row=1, col=2)
        fig.add_trace(go.Bar(
            x=store_df['store'].values, y=store_df['mean_m'].values,
            marker=dict(color=PALETTE['male'], opacity=0.7), name='Male',
        ), row=1, col=2)
        fig.update_layout(**base_layout(height=STYLE['chart_height'],
                                         margin=dict(l=160, r=40, t=40, b=50)),
                           barmode='group')
        save_html(fig, CHART_REG_DIR / '05_by_store.html')

    # 6. Scatter by category
    if len(gaps_df) > 0:
        colors_06 = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in gaps_df['gap_pct']]
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=gaps_df['mean_male'], y=gaps_df['mean_female'],
            mode='markers',
            marker=dict(size=np.sqrt(gaps_df['n_total'].values) * 2,
                        color=colors_06, opacity=0.6,
                        line=dict(width=0.5, color='black')),
            text=gaps_df['category'].str[:30],
            hovertemplate='%{text}<br>Male: %{x:.2f}<br>Female: %{y:.2f}<extra></extra>',
        ))
        lim_max = max(gaps_df['mean_male'].max(), gaps_df['mean_female'].max()) * 1.1
        fig.add_trace(go.Scatter(
            x=[0, lim_max], y=[0, lim_max], mode='lines',
            line=dict(color='black', width=0.8, dash='dash'),
            name='Equal price', hoverinfo='skip'))
        for _, row in gaps_df.nlargest(3, 'gap_pct').iterrows():
            fig.add_annotation(x=row['mean_male'], y=row['mean_female'],
                               text=row['category'][:30], showarrow=True,
                               arrowhead=0, ax=20, ay=-20,
                               font=dict(size=9, color='#555'))
        fig.update_layout(**base_layout(height=550, margin=dict(l=60, r=40, t=20, b=60)),
                           xaxis=styled_axis('Mean male price'),
                           yaxis=styled_axis('Mean female price'))
        save_html(fig, CHART_REG_DIR / '06_scatter_by_category.html')

    # 7. Three-way comparison
    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.08,
                        subplot_titles=['Log price density', 'Price box plot'])
    for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male']),
                           ('none', PALETTE['none'])]:
        sub = df[df['gender'] == gender]
        fig.add_trace(go.Histogram(
            x=sub['log_price'], nbinsx=60, histnorm='probability density',
            marker=dict(color=color, opacity=0.4),
            name=f'{gender.title()} (n={len(sub):,})', legendgroup=gender,
        ), row=1, col=1)
    for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male']),
                           ('none', PALETTE['none'])]:
        sub = df[df['gender'] == gender]
        fig.add_trace(go.Box(
            y=sub['price_num'], name=gender.title(), marker_color=color,
            showlegend=False, boxmean=True,
        ), row=1, col=2)
    fig.update_layout(**base_layout(height=STYLE['chart_height'],
                                     margin=dict(l=60, r=40, t=40, b=50)),
                       barmode='overlay')
    save_html(fig, CHART_REG_DIR / '07_three_way_comparison.html')

    # 8. Category composition
    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.08,
                        subplot_titles=['Female', 'Male'])
    for col_idx, gender in enumerate(['female', 'male'], 1):
        sub = gendered[gendered['gender'] == gender]
        top_cats = sub['cat_mid'].value_counts().head(12)
        fig.add_trace(go.Bar(
            y=[c[:40] for c in top_cats.index], x=top_cats.values, orientation='h',
            marker=dict(color=PALETTE[gender], opacity=0.7,
                        line=dict(width=0.3, color='black')),
            hovertemplate='%{y}<br>Count: %{x:,}<extra></extra>',
            showlegend=False,
        ), row=1, col=col_idx)
    fig.update_yaxes(autorange='reversed')
    fig.update_layout(**base_layout(height=STYLE['chart_height'],
                                     margin=dict(l=280, r=40, t=40, b=50)))
    fig.update_xaxes(**styled_axis('Number of products'))
    save_html(fig, CHART_REG_DIR / '08_category_composition.html')

    # 9. Gap distribution
    if len(gaps_df) > 0:
        med_gap = gaps_df['gap_pct'].median()
        fig = go.Figure()
        fig.add_trace(go.Histogram(
            x=gaps_df['gap_pct'], nbinsx=max(5, len(gaps_df) // 2),
            marker=dict(color=PALETTE['female'], opacity=0.6,
                        line=dict(width=0.5, color='black')),
            hovertemplate='Gap: %{x:.1f}%<br>Count: %{y}<extra></extra>',
        ))
        fig.add_vline(x=0, line=dict(color='black', width=1.5))
        fig.add_vline(x=med_gap, line=dict(color=PALETTE['female'], width=1.5, dash='dash'),
                      annotation_text=f'Median: {med_gap:+.1f}%',
                      annotation_position='top right')
        fig.update_layout(**base_layout(height=STYLE['chart_height_small']),
                           xaxis=styled_axis('Female price premium (%)'),
                           yaxis=styled_axis('Number of categories'))
        save_html(fig, CHART_REG_DIR / '09_gap_distribution.html')

    # 10. R2 progression
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=summary_df['spec'].values, x=summary_df['r2'].values, orientation='h',
        marker=dict(color='#555555', opacity=0.7, line=dict(width=0.5, color='black')),
        hovertemplate='%{y}<br>R\u00b2: %{x:.4f}<extra></extra>',
    ))
    for i, r2 in enumerate(summary_df['r2']):
        fig.add_annotation(x=r2 + 0.01, y=summary_df['spec'].values[i],
                           text=f'{r2:.3f}', showarrow=False,
                           xanchor='left', font=dict(size=10))
    fig.update_layout(**base_layout(height=STYLE['chart_height_small'],
                                     margin=STYLE['margin_bar']),
                       xaxis=styled_axis('R\u00b2'),
                       yaxis=styled_axis('', autorange='reversed'))
    save_html(fig, CHART_REG_DIR / '10_r2_progression.html')

    # 11. Heatmap: category x store
    heatmap_data = []
    for cat in gendered['cat_mid'].unique():
        for store_id in gendered['store'].unique():
            sub = gendered[(gendered['cat_mid'] == cat) & (gendered['store'] == store_id)]
            fem = sub[sub['is_female'] == 1]
            mal = sub[sub['is_female'] == 0]
            if len(fem) >= 2 and len(mal) >= 2:
                gap = (fem['price_num'].mean() / mal['price_num'].mean() - 1) * 100
                heatmap_data.append({
                    'category': cat[:35],
                    'store': store_names.get(store_id, store_id),
                    'gap': gap, 'n': len(fem) + len(mal),
                })

    if heatmap_data:
        heat_df = pd.DataFrame(heatmap_data)
        pivot = heat_df.pivot_table(values='gap', index='category', columns='store', aggfunc='mean')
        pivot = pivot.dropna(thresh=2)
        if len(pivot) > 0:
            finite_vals = pivot.values[np.isfinite(pivot.values)]
            max_abs = max(abs(finite_vals.min()), abs(finite_vals.max()))
            # Build text array handling NaN
            text_arr = []
            for row in pivot.values:
                text_row = []
                for v in row:
                    if np.isfinite(v):
                        text_row.append(f'{v:.0f}')
                    else:
                        text_row.append('')
                text_arr.append(text_row)
            fig = go.Figure(data=go.Heatmap(
                z=pivot.values, x=pivot.columns.tolist(), y=pivot.index.tolist(),
                colorscale='RdBu_r', zmid=0, zmin=-max_abs, zmax=max_abs,
                text=text_arr,
                texttemplate='%{text}', textfont=dict(size=10),
                colorbar=dict(title='F vs M gap (%)', **STYLE['colorbar']),
                hovertemplate='%{y}<br>%{x}<br>Gap: %{z:.1f}%<extra></extra>',
            ))
            h = max(STYLE['chart_height'], len(pivot) * 28)
            fig.update_layout(**base_layout(height=h, margin=dict(l=250, r=80, t=20, b=50)),
                               yaxis=styled_axis(''))
            save_html(fig, CHART_REG_DIR / '11_heatmap_category_store.html')

    # 12. Summary dashboard (6 panels)
    fig = make_subplots(
        rows=2, cols=3, horizontal_spacing=0.08, vertical_spacing=0.12,
        subplot_titles=['Key Finding', '% Gap by Specification', 'R\u00b2 by Specification',
                        'Log Price Density', 'Quantile Regression', 'Store Gaps'],
        specs=[[{'type': 'xy'}, {'type': 'xy'}, {'type': 'xy'}],
               [{'type': 'xy'}, {'type': 'xy'}, {'type': 'xy'}]],
    )
    raw_gap = (np.exp(spec1.params['is_female']) - 1) * 100
    ctrl_gap = summary_df.iloc[-1]['pct']
    fig.add_annotation(x=0.5, y=0.8, xref='x domain', yref='y domain',
                       text=f'Raw gap: {raw_gap:+.1f}%', showarrow=False,
                       font=dict(size=20, color=PALETTE['male'], weight='bold'),
                       row=1, col=1)
    fig.add_annotation(x=0.5, y=0.45, xref='x domain', yref='y domain',
                       text='(female products cheaper)', showarrow=False,
                       font=dict(size=11, color='gray'), row=1, col=1)
    fig.add_annotation(x=0.5, y=0.15, xref='x domain', yref='y domain',
                       text=f'After controls: {ctrl_gap:+.1f}% (n.s.)', showarrow=False,
                       font=dict(size=13, color='gray'), row=1, col=1)
    fig.update_xaxes(visible=False, row=1, col=1)
    fig.update_yaxes(visible=False, row=1, col=1)

    colors_12b = [PALETTE['female'] if c > 0 else PALETTE['male'] for c in summary_df['coef']]
    fig.add_trace(go.Bar(
        y=[s[:18] for s in summary_df['spec']], x=summary_df['pct'], orientation='h',
        marker=dict(color=colors_12b, opacity=0.7), showlegend=False,
    ), row=1, col=2)
    fig.add_vline(x=0, line=dict(color='black', width=1), row=1, col=2)
    fig.update_yaxes(autorange='reversed', row=1, col=2)

    fig.add_trace(go.Bar(
        y=[s[:18] for s in summary_df['spec']], x=summary_df['r2'], orientation='h',
        marker=dict(color='#666', opacity=0.7), showlegend=False,
    ), row=1, col=3)
    fig.update_yaxes(autorange='reversed', row=1, col=3)

    for gender, color in [('female', PALETTE['female']), ('male', PALETTE['male'])]:
        sub = gendered[gendered['gender'] == gender]
        fig.add_trace(go.Histogram(
            x=sub['log_price'], nbinsx=40, histnorm='probability density',
            marker=dict(color=color, opacity=0.5),
            name=gender.title(), showlegend=True, legendgroup=gender,
        ), row=2, col=1)

    fig.add_trace(go.Scatter(
        x=qreg_df['quantile'], y=qreg_df['ci_hi'],
        mode='lines', line=dict(width=0), showlegend=False), row=2, col=2)
    fig.add_trace(go.Scatter(
        x=qreg_df['quantile'], y=qreg_df['ci_lo'],
        mode='lines', line=dict(width=0), fill='tonexty',
        fillcolor='rgba(196,78,82,0.15)', showlegend=False), row=2, col=2)
    fig.add_trace(go.Scatter(
        x=qreg_df['quantile'], y=qreg_df['coef'],
        mode='lines+markers', line=dict(color=PALETTE['female'], width=2),
        marker=dict(size=5), showlegend=False), row=2, col=2)
    fig.add_hline(y=0, line=dict(color='black', width=0.8, dash='dash'), row=2, col=2)

    if len(store_df) > 0:
        colors_12f = [PALETTE['female'] if g > 0 else PALETTE['male'] for g in store_df['pct_gap']]
        fig.add_trace(go.Bar(
            y=store_df['store'].values, x=store_df['pct_gap'].values, orientation='h',
            marker=dict(color=colors_12f, opacity=0.7), showlegend=False,
        ), row=2, col=3)
        fig.add_vline(x=0, line=dict(color='black', width=1), row=2, col=3)
    fig.update_layout(**base_layout(height=750, margin=dict(l=140, r=40, t=40, b=50)),
                       barmode='overlay')
    save_html(fig, CHART_REG_DIR / '12_summary_dashboard.html')

    # 13. Waterfall decomposition chart
    if len(summary_df) >= 4:
        spec_labels = summary_df['spec'].values
        pct_gaps = summary_df['pct'].values
        p_vals_w = summary_df['p'].values
        n = len(spec_labels)
        deltas = np.zeros(n)
        starts = np.zeros(n)
        deltas[0] = pct_gaps[0]
        starts[0] = 0
        for i in range(1, n):
            deltas[i] = pct_gaps[i] - pct_gaps[i - 1]
            starts[i] = pct_gaps[i - 1]

        bar_colors_w = []
        for i in range(n):
            if i == 0:
                bar_colors_w.append(PALETTE['male'] if pct_gaps[0] < 0 else PALETTE['female'])
            elif i == n - 1:
                bar_colors_w.append('#555555')
            else:
                bar_colors_w.append(PALETTE['female'] if deltas[i] > 0 else PALETTE['male'])

        fig = go.Figure()
        fig.add_trace(go.Bar(
            y=spec_labels[:-1], x=deltas[:-1], orientation='h',
            base=starts[:-1],
            marker=dict(color=bar_colors_w[:-1], opacity=0.75,
                        line=dict(width=0.4, color='#333333')),
            hovertemplate='%{y}<br>Change: %{x:+.1f}%<extra></extra>',
            showlegend=False,
        ))
        fig.add_trace(go.Bar(
            y=[spec_labels[-1]], x=[pct_gaps[-1]], orientation='h',
            base=[0],
            marker=dict(color=bar_colors_w[-1], opacity=0.75,
                        line=dict(width=0.4, color='#333333')),
            hovertemplate='%{y}<br>Total: %{x:+.1f}%<extra></extra>',
            showlegend=False,
        ))
        for i in range(n):
            sig_mark = '' if p_vals_w[i] < 0.05 else ' (n.s.)'
            x_pos = pct_gaps[i]
            fig.add_annotation(x=x_pos + (0.8 if x_pos >= 0 else -0.8),
                               y=spec_labels[i],
                               text=f'{pct_gaps[i]:+.1f}%{sig_mark}',
                               showarrow=False,
                               xanchor='left' if x_pos >= 0 else 'right',
                               font=dict(size=10, weight=600, color='#333333'))
        fig.add_vline(x=0, line=dict(color='#1a1a2e', width=0.8))
        fig.update_layout(**base_layout(height=STYLE['chart_height'],
                                         margin=STYLE['margin_bar']),
                           xaxis=styled_axis('Female price gap (%)'),
                           yaxis=styled_axis('', autorange='reversed'))
        save_html(fig, CHART_REG_DIR / '13_waterfall_decomposition.html')

    # Save full summary JSON
    full_summary = {
        'raw_gap_pct': float((np.exp(spec1.params['is_female']) - 1) * 100),
        'raw_gap_p': float(spec1.pvalues['is_female']),
        'controlled_gap_pct': float(summary_df.iloc[-1]['pct']),
        'controlled_gap_p': float(summary_df.iloc[-1]['p']),
        'within_category_median_gap': float(median_gap) if len(gaps_df) > 0 else None,
        'n_gendered_products': int(len(gendered)),
        'n_female': int(gendered['is_female'].sum()),
        'n_male': int((1 - gendered['is_female']).sum()),
        'quantile_results': qreg_df.to_dict('records'),
        'store_results': store_df.to_dict('records') if len(store_df) > 0 else [],
    }
    with open(TEXT_DIR / 'full_summary.json', 'w') as f:
        json.dump(full_summary, f, indent=2, default=str)

    print(f"\nCharts saved to {CHART_REG_DIR}, tables in {TEXT_DIR}")


# ============================================================================
# STAGE 4: COLOR VISUALISATIONS
# ============================================================================

def run_color_charts():
    """Stage 4: color distribution, importance, and comparison charts (Plotly).

    Reads the color cache at ``COLOR_CACHE_PATH`` and, when the file exists,
    ``feature_importance.csv`` under ``TEXT_DIR``. Each figure is handed to
    ``save_html`` with a target path under ``CHART_VAL_DIR``:

    1. ``color_distribution_by_gender.html`` -- share of each primary color
       (``color1_name``) per label (female / male / none).
    2. ``color_importance.html`` -- per-color coefficients, one panel per
       class (only when the importance file exists and has color features).
    3. ``color_importance_heatmap.html`` -- coefficients summed over the
       three color slots (only when the importance file exists).
    4. ``color_comparison_butterfly.html`` -- female shares plotted to the
       left of zero, male shares to the right.

    Takes no arguments and returns nothing.
    """
    CHART_VAL_DIR.mkdir(parents=True, exist_ok=True)
    IMPORTANCE_PATH = TEXT_DIR / 'feature_importance.csv'

    def _outline(color_name):
        """Bar outline color; a 'none' edge becomes a fully transparent line."""
        edge = edge_color_for(color_name)
        return edge if edge != 'none' else 'rgba(0,0,0,0)'

    print("=" * 70)
    print("STAGE 4: COLOR VISUALISATIONS (Plotly)")
    print("=" * 70)

    color_df = pd.read_csv(COLOR_CACHE_PATH)
    print(f"Color cache: {len(color_df):,} products")

    # The importance table is optional; charts 2 and 3 are skipped without it.
    has_importance = IMPORTANCE_PATH.exists()
    if has_importance:
        importance_df = pd.read_csv(IMPORTANCE_PATH)
        print(f"Feature importance: {len(importance_df)} features")
    else:
        importance_df = None
        print("No feature_importance.csv found -- will skip LASSO charts")

    # Build frequency table by gender
    color_gender = color_df[['label_extracted', 'color1_name']].copy()
    color_gender = color_gender[color_gender['label_extracted'].isin(['female', 'male', 'none'])]
    color_gender = color_gender.rename(columns={'color1_name': 'color'})

    # Rows: label, columns: color. freq_pct is the within-label share in %.
    freq = color_gender.groupby(['label_extracted', 'color']).size().unstack(fill_value=0)
    freq_pct = freq.div(freq.sum(axis=1), axis=0) * 100

    # Keep colors exceeding a 1% share in at least one label, ordered by
    # their summed share across labels.
    mask = (freq_pct > 1).any(axis=0)
    freq_pct_filtered = freq_pct.loc[:, mask].copy()
    col_order = freq_pct_filtered.sum().sort_values(ascending=False).index
    freq_pct_filtered = freq_pct_filtered[col_order]
    print(f"Colors with >1% share: {len(col_order)}")

    gender_labels = {'female': 'Female Products', 'male': 'Male Products', 'none': 'Neutral Products'}
    gender_accent = {'female': '#c44e52', 'male': '#4c72b0', 'none': '#777777'}

    # ---- Chart 1: Color distribution by gender ----
    print("\nGenerating color distribution chart...")
    fig = make_subplots(rows=1, cols=3, horizontal_spacing=0.06,
                        subplot_titles=[gender_labels.get(g, g) for g in ['female', 'male', 'none']])

    for col_idx, gender in enumerate(['female', 'male', 'none'], 1):
        if gender not in freq_pct_filtered.index:
            continue
        # Ascending sort + tail(15) keeps the 15 largest shares.
        row = freq_pct_filtered.loc[gender].sort_values(ascending=True).tail(15)
        color_names = row.index.tolist()
        values = row.values
        n_total = int(freq.loc[gender].sum()) if gender in freq.index else 0

        fig.add_trace(go.Bar(
            y=[c.replace('_', ' ').title() for c in color_names],
            x=values, orientation='h',
            marker=dict(color=[rgb_hex(c) for c in color_names],
                        line=dict(width=0.8,
                                  color=[_outline(c) for c in color_names])),
            text=[f'{v:.1f}%' for v in values],
            textposition='outside', textfont=dict(size=9),
            hovertemplate='%{y}: %{x:.1f}%<extra></extra>',
            showlegend=False,
        ), row=1, col=col_idx)
        # Axis ids are 'x', 'x2', 'x3', ...: the first subplot has no suffix.
        axis_suffix = '' if col_idx == 1 else col_idx
        fig.add_annotation(x=0.97, y=0.03,
                           xref=f'x{axis_suffix} domain', yref=f'y{axis_suffix} domain',
                           text=f'n = {n_total:,}', showarrow=False,
                           font=dict(size=9, color='#999999'), xanchor='right', yanchor='bottom')

    fig.update_layout(**base_layout(height=550, margin=dict(l=120, r=60, t=40, b=50)))
    fig.update_xaxes(**styled_axis('Share of products (%)'))
    save_html(fig, CHART_VAL_DIR / 'color_distribution_by_gender.html')

    # ---- Chart 2: Color importance (LASSO coefficients) ----
    if has_importance:
        print("\nGenerating color importance chart...")
        color_feats = importance_df[importance_df['feature'].str.startswith('feat_color1_')].copy()
        color_feats['color_name'] = color_feats['feature'].str.replace('feat_color1_', '', regex=False)
        color_feats['max_coef'] = color_feats[['coef_female', 'coef_male', 'coef_none']].abs().max(axis=1)
        color_feats = color_feats[color_feats['max_coef'] > 0.01].sort_values('max_coef', ascending=False)

        # Fall back to the secondary/tertiary color slots only when no
        # primary-color feature clears the 0.01 threshold.
        if len(color_feats) == 0:
            for slot in ['color2_', 'color3_']:
                extra = importance_df[importance_df['feature'].str.startswith(f'feat_{slot}')].copy()
                extra['color_name'] = extra['feature'].str.replace(f'feat_{slot}', '', regex=False)
                extra['max_coef'] = extra[['coef_female', 'coef_male', 'coef_none']].abs().max(axis=1)
                color_feats = pd.concat([color_feats, extra[extra['max_coef'] > 0.01]])

        if len(color_feats) > 0:
            agg = color_feats.groupby('color_name')[['coef_female', 'coef_male', 'coef_none']].sum()
            agg['max_abs'] = agg.abs().max(axis=1)
            agg = agg.sort_values('max_abs', ascending=False).head(18)

            coef_cols = [('coef_female', 'Female Predictors', '#c44e52'),
                         ('coef_male', 'Male Predictors', '#4c72b0'),
                         ('coef_none', 'Neutral Predictors', '#777777')]

            fig = make_subplots(rows=1, cols=3, shared_yaxes=True, horizontal_spacing=0.04,
                                subplot_titles=[lbl for _, lbl, _ in coef_cols])

            for col_idx, (col, label, accent) in enumerate(coef_cols, 1):
                sorted_data = agg[col].sort_values()
                color_names = sorted_data.index.tolist()
                values = sorted_data.values
                fig.add_trace(go.Bar(
                    y=[c.replace('_', ' ').title() for c in color_names],
                    x=values, orientation='h',
                    marker=dict(color=[rgb_hex(c) for c in color_names],
                                line=dict(width=0.8,
                                          color=[_outline(c) for c in color_names])),
                    # Label only bars large enough to be worth annotating.
                    text=[f'{v:+.2f}' if abs(v) > 0.02 else '' for v in values],
                    textposition='outside', textfont=dict(size=8, color='#444444'),
                    hovertemplate='%{y}: %{x:+.3f}<extra></extra>',
                    showlegend=False,
                ), row=1, col=col_idx)
                fig.add_vline(x=0, line=dict(color='#1a1a1a', width=0.8), row=1, col=col_idx)

            fig.update_layout(**base_layout(height=600, margin=dict(l=130, r=60, t=40, b=50)))
            fig.update_xaxes(**styled_axis('LASSO coefficient'))
            save_html(fig, CHART_VAL_DIR / 'color_importance.html')

    # ---- Chart 3: Color importance heatmap ----
    if has_importance:
        print("\nGenerating color importance heatmap...")
        # Collect the features of all three color slots, then sum per color.
        all_color = []
        for slot in range(1, 4):
            prefix = f'feat_color{slot}_'
            sub = importance_df[importance_df['feature'].str.startswith(prefix)].copy()
            sub['color_name'] = sub['feature'].str.replace(prefix, '', regex=False)
            sub['slot'] = slot
            all_color.append(sub)

        all_color_df = pd.concat(all_color)
        hm_data = all_color_df.groupby('color_name')[
            ['coef_female', 'coef_male', 'coef_none']].sum()
        hm_data.columns = ['Female', 'Male', 'Neutral']
        hm_data['max_abs'] = hm_data.abs().max(axis=1)
        hm_data = hm_data[hm_data['max_abs'] > 0.01].drop(columns='max_abs')
        hm_data = hm_data.sort_values('Female', ascending=True)

        if len(hm_data) > 0:
            # Symmetric color range so that zero sits at the scale midpoint.
            max_val = max(abs(hm_data.values.min()), abs(hm_data.values.max()))
            fig = go.Figure(data=go.Heatmap(
                z=hm_data.values,
                x=hm_data.columns.tolist(),
                y=[c.replace('_', ' ').title() for c in hm_data.index.tolist()],
                colorscale='RdBu_r', zmid=0, zmin=-max_val, zmax=max_val,
                text=[[f'{v:+.2f}' if abs(v) > 0.01 else '' for v in row] for row in hm_data.values],
                texttemplate='%{text}', textfont=dict(size=9),
                colorbar=dict(title='L1 coefficient', **STYLE['colorbar']),
                hovertemplate='%{y}<br>%{x}: %{z:+.3f}<extra></extra>',
            ))
            h = max(STYLE['chart_height'], len(hm_data) * 26)
            fig.update_layout(**base_layout(height=h, margin=dict(l=160, r=80, t=20, b=50)),
                               yaxis=styled_axis(''))
            save_html(fig, CHART_VAL_DIR / 'color_importance_heatmap.html')

    # ---- Chart 4: Female vs male butterfly chart ----
    print("\nGenerating female vs male color comparison...")
    if 'female' in freq_pct_filtered.index and 'male' in freq_pct_filtered.index:
        fem = freq_pct_filtered.loc['female']
        mal = freq_pct_filtered.loc['male']
        # Top 18 colors by combined female + male share.
        all_colors = sorted(set(fem.index) | set(mal.index),
                            key=lambda c: fem.get(c, 0) + mal.get(c, 0), reverse=True)[:18]

        fig = go.Figure()
        y_labels = [c.replace('_', ' ').title() for c in all_colors]

        for i, cname in enumerate(all_colors):
            f_val = fem.get(cname, 0)
            m_val = mal.get(cname, 0)
            hex_c = rgb_hex(cname)
            ec = _outline(cname)
            # Female (left, negative x)
            fig.add_trace(go.Bar(
                y=[y_labels[i]], x=[-f_val], orientation='h',
                marker=dict(color=hex_c, opacity=0.85, line=dict(width=0.6, color=ec)),
                text=f'{f_val:.1f}%' if f_val > 1 else '',
                textposition='outside', textfont=dict(size=8, color='#555'),
                hovertemplate=f'{cname}<br>Female: {f_val:.1f}%<extra></extra>',
                showlegend=False,
            ))
            # Male (right, positive x)
            fig.add_trace(go.Bar(
                y=[y_labels[i]], x=[m_val], orientation='h',
                marker=dict(color=hex_c, opacity=0.85, line=dict(width=0.6, color=ec)),
                text=f'{m_val:.1f}%' if m_val > 1 else '',
                textposition='outside', textfont=dict(size=8, color='#555'),
                hovertemplate=f'{cname}<br>Male: {m_val:.1f}%<extra></extra>',
                showlegend=False,
            ))

        fig.add_vline(x=0, line=dict(color='#1a1a1a', width=0.8))
        # 30% headroom so the outside text labels stay inside the plot area.
        x_max = max(fem.max(), mal.max()) * 1.3
        fig.add_annotation(x=-x_max * 0.3, y=-0.08, yref='paper',
                           text='<b>\u2190 Female</b>', showarrow=False,
                           font=dict(size=11, color='#c44e52'))
        fig.add_annotation(x=x_max * 0.3, y=-0.08, yref='paper',
                           text='<b>Male \u2192</b>', showarrow=False,
                           font=dict(size=11, color='#4c72b0'))

        # Ticks at symmetric multiples of the step so that zero is always a
        # tick and both halves carry the same labels. Stepping up from
        # -int(x_max) only hits zero when int(x_max) is a multiple of the
        # step, which produced mismatched labels left and right of the axis.
        tick_step = 5
        tick_limit = int(x_max // tick_step) * tick_step
        tick_vals = list(range(-tick_limit, tick_limit + 1, tick_step))
        # Female shares are plotted as negative x; show magnitudes only.
        tick_text = [str(abs(v)) for v in tick_vals]

        h = max(STYLE['chart_height'], len(all_colors) * 28)
        fig.update_layout(**base_layout(height=h, margin=dict(l=120, r=60, t=20, b=60)),
                           xaxis=styled_axis('Share of products (%)',
                                             range=[-x_max, x_max],
                                             tickvals=tick_vals,
                                             ticktext=tick_text),
                           yaxis=styled_axis('', autorange='reversed'),
                           barmode='relative')
        save_html(fig, CHART_VAL_DIR / 'color_comparison_butterfly.html')

    print(f"\nAll charts saved to {CHART_VAL_DIR}")


# ============================================================================
# MAIN ENTRY POINT
# ============================================================================

# Stage registry: maps the stage key accepted by run_pipeline() to a
# (display name, entry-point function) pair. Keys are the 1-based position
# of each entry in the list below, as strings.
STAGE_MAP = {
    str(position): entry
    for position, entry in enumerate(
        [
            ('Color extraction', run_color_extraction),
            ('ML gender prediction', run_ml_pipeline),
            ('Regression analysis', run_regression_analysis),
            ('Color visualisations', run_color_charts),
        ],
        start=1,
    )
}


def run_pipeline(stages='all'):
    """
    Run one or more pipeline stages, in the order they are given.

    Args:
        stages: '1', '2', '3', '4', 'all', a comma-separated string such as
                '3,4', or a list/tuple like ['2', '3'].
                Can also pass ints: run_pipeline(3) or run_pipeline([2, 3]).
                'all' is matched case-insensitively and may appear as an item
                too ('1,all', ['all']); it expands to every key of STAGE_MAP
                in the map's own order. Blank items (e.g. from a trailing
                comma) are ignored.

    Returns:
        None. Progress and errors are printed rather than raised: an unknown
        stage id is reported and skipped, and a stage that raises (or is
        interrupted with Ctrl-C) does not prevent the following stages from
        running.

    Examples (notebook):
        run_pipeline('all')
        run_pipeline(2)
        run_pipeline([2, 3])
        run_pipeline('3,4')
    """
    import traceback

    # Normalise every accepted input form to a flat list of raw string tokens.
    if isinstance(stages, (list, tuple)):
        raw_tokens = [str(s) for s in stages]
    else:
        raw_tokens = str(stages).split(',')

    stage_list = []
    for token in raw_tokens:
        token = token.strip()
        if not token:
            # '' or a stray comma: nothing was requested, so nothing to report.
            continue
        if token.lower() == 'all':
            stage_list.extend(STAGE_MAP)
        else:
            stage_list.append(token)

    if not stage_list:
        print("No stages selected. Choose from 1, 2, 3, 4, or 'all'.")
        return

    banner = '#' * 70
    for stage in stage_list:
        if stage not in STAGE_MAP:
            print(f"Unknown stage: {stage}. Choose from 1, 2, 3, 4, or 'all'.")
            continue
        name, func = STAGE_MAP[stage]
        print(f"\n{banner}")
        print(f"# RUNNING STAGE {stage}: {name.upper()}")
        print(f"{banner}\n")
        try:
            func()
        except KeyboardInterrupt:
            # An interrupt abandons only the current stage; the loop then
            # moves on to the next requested stage.
            print(f"\nStage {stage} interrupted.")
        except Exception as e:
            # Top-level boundary: report the failure with its traceback and
            # keep going so later stages still get a chance to run.
            print(f"\nStage {stage} failed: {e}")
            traceback.print_exc()


if __name__ == '__main__':
    import argparse

    # Honour the `--stage [1|2|3|4|all]` option documented in the file header.
    # The default keeps the previous hard-coded behaviour (stages 2, 3 and 4).
    _parser = argparse.ArgumentParser(
        description='Pink tax analysis pipeline',
        allow_abbrev=False,
    )
    _parser.add_argument(
        '--stage',
        default='2,3,4',
        help="Stage(s) to run: 1, 2, 3, 4, a comma-separated list such as "
             "'2,3', or 'all' (default: %(default)s)",
    )
    # parse_known_args() tolerates unrelated entries in sys.argv (for example
    # flags injected by a notebook kernel) instead of exiting with an error.
    _args, _unknown = _parser.parse_known_args()

    run_pipeline(_args.stage)
    print("Usage: run_pipeline(stages) where stages = 1, 2, 3, 4, or 'all'")
    print("  e.g. run_pipeline(2)  or  run_pipeline([2, 3])  or  run_pipeline('all')")



######################################################################
# RUNNING STAGE 2: ML GENDER PREDICTION
######################################################################

STAGE 2: ML GENDER PREDICTION PIPELINE
Main dataset: 21,436 products
Human-coded: 259 products
Your labeled: 44 products
Filtered: 21,436 -> 12,832 (excluded 8,604)
Label distribution: {'none': 10913, 'female': 1075, 'male': 844}
Human labels merged: 200
Color cache: 5,525 products
  Matched to filtered data: 5,525 / 5,525

Explicitly female: 1075, male: 844
Balancing to: 844 per class
Training data: 2532 (classes: {0: 850, 1: 845, 2: 837})
Train: 1905, Test: 639
Breadcrumb TF-IDF: 80 features
Description TF-IDF: 150 features
Color lookup built: 5,525 products
  Color features filled for 822/1905 rows
  Color features filled for 293/639 rows
Color features: 93 (available for 822/1905 train samples)
X_train: (1905, 329), X_test: (639, 329)

--- Logistic Regression (L1) ---
Accuracy: 0.6495, F1: 0.6462

---