In [None]:
# Efficiently sample 10% of a large CSV without loading into memory
# Adjust input/output paths as needed. This streams line-by-line.
! LC_ALL=C awk 'BEGIN{srand();} NR==1{print; next} { if (rand() <= 0.1) print }' data/transactions_train.csv > 'data/transactions_train_sample (10%).csv'
# This is a shell command that uses awk to sample 10% of the data without loading it all into memory.
# Usage: Run this in a Jupyter cell by prefixing with `!` or execute in the terminal, e.g., `!LC_ALL=C awk ...`.

# 01 - Data Preprocessing

This notebook refactors the large `JewelryDataPreprocessor` class into small, focused functions so it's easier to read and present in a Jupyter Book or on GitHub. Use each cell interactively to explain and run parts of the pipeline.

## Notebook structure
- Imports
- Load data helper
- Quick inspection
- Cleaning helpers (columns, missing values, types)
- Feature engineering
- Pipeline wrapper and example usage

In [32]:
# Imports
import pandas as pd  # 導入 pandas，用於資料操作
import numpy as np  # 導入 numpy，用於數值運算
from pathlib import Path  # 用於處理檔案路徑

In [33]:
def load_csv(path, **kwargs):  # 載入 CSV 的輕量包裝函式，回傳 DataFrame
    """Load a CSV into a DataFrame. Returns DataFrame.
    Keeps a thin wrapper so we can set defaults and handle common encodings.
    """
    path = Path(path)  # 將路徑轉為 Path 物件，方便後續操作
    if not path.exists():  # 若檔案不存在，拋出錯誤提醒
        raise FileNotFoundError(f'File not found: {path}')
    return pd.read_csv(path, **kwargs)  # 使用 pandas 讀取 CSV，支援傳入額外參數

In [34]:
def preview(df, n=5):  # 簡單顯示資料形狀與前幾列，方便快速檢查
    """Display basic info and top rows for quick inspection.
    Returns nothing; print summary.
    """
    print('Shape:', df.shape)  # 顯示 (rows, columns)
    print('\nColumn types:')  # 顯示欄位型別的前 20 個
    print(df.dtypes.head(20))
    display(df.head(n))  # 顯示前 n 列作為預覽

### Column cleanup helper
Small helpers to standardize column names and drop obviously useless columns.

In [35]:
def clean_column_names(df, lower=True, strip=True, replace_spaces='_'):  # 標準化欄位名稱
    cols = list(df.columns)  # 原欄位清單
    new_cols = []  # 建立新的欄位名稱清單
    for c in cols:
        nc = c
        if lower:
            nc = nc.lower()  # 轉為小寫
        if strip:
            nc = nc.strip()  # 去除前後空白
        if replace_spaces is not None:
            nc = nc.replace(' ', replace_spaces)  # 將空白改為底線或指定字元
        new_cols.append(nc)
    df.columns = new_cols  # 套用新的欄位名稱
    return df

def drop_columns(df, cols):  # 刪除不需要的欄位（安全檢查只有存在才刪）
    cols = [c for c in cols if c in df.columns]  # 只保留存在於 DataFrame 的欄位
    return df.drop(columns=cols)  # 回傳刪除後的 DataFrame

### Missing values and type conversion
Helpers to fill or drop missing values and convert columns to appropriate types.

In [36]:
def fill_missing(df, strategy=None, value=None, cols=None):  # 處理遺失值的通用函式
    """Fill missing values.
    - If `strategy` is 'median' or 'mean', uses numeric column aggregation per-col.
    - If `value` is provided, uses that for all selected cols.
    """
    if cols is None:
        cols = df.columns.tolist()  # 預設處理所有欄位
    if strategy in ('median','mean'):
        for c in cols:
            if pd.api.types.is_numeric_dtype(df[c]):
                agg = getattr(df[c].median if strategy=='median' else df[c].mean, '__call__')()  # 計算中位數或平均數
                df[c] = df[c].fillna(agg)  # 用聚合值填補遺失值
    elif value is not None:
        df[cols] = df[cols].fillna(value)  # 使用指定常數填補
    else:
        df = df.dropna(subset=cols)  # 若無策略則直接移除缺失值列
    return df

def convert_types(df, conversions):  # 轉換欄位型別，`conversions` 為字典
    """Convert column types using a dict `conversions`, e.g. {'date_col':'datetime64', 'id':'int'}
    """
    for c, t in conversions.items():
        if c not in df.columns:
            continue  # 若欄位不存在則跳過
        try:
            if t == 'datetime':
                df[c] = pd.to_datetime(df[c], errors='coerce')  # 轉為 datetime，失敗會變成 NaT
            else:
                df[c] = df[c].astype(t)  # 嘗試使用 pandas astype 轉換
        except Exception as e:
            print(f'Could not convert {c} to {t}:', e)  # 若失敗則顯示錯誤
    return df

### Feature engineering
Keep this focused and small; add functions as needed for your analysis.

In [37]:
def add_datetime_parts(df, date_col, drop=True):  # 從日期欄位拆出年/月/日
    df = df.copy()  # 複製避免修改原始 DataFrame
    if date_col not in df.columns:
        raise KeyError(date_col)  # 欄位不存在則丟出錯誤
    df[date_col] = pd.to_datetime(df[date_col], errors='coerce')  # 轉為 datetime
    df[date_col + '_year'] = df[date_col].dt.year  # 年
    df[date_col + '_month'] = df[date_col].dt.month  # 月
    df[date_col + '_day'] = df[date_col].dt.day  # 日
    if drop:
        df = df.drop(columns=[date_col])  # 若需要則移除原始日期欄
    return df

def simple_encoding(df, cols):  # 簡單的 one-hot 編碼（適用欄位種類少的類別欄）
    """One-hot encode small-cardinality categorical columns.
    Returns a new DataFrame with encoded columns.
    """
    return pd.get_dummies(df, columns=cols, dummy_na=False)  # 回傳編碼後的新 DataFrame

### Pipeline wrapper
A small pipeline that composes the helpers above for repeatable preprocessing.

In [38]:
def run_preprocessing(input_path, output_path=None, *,  # 將前面 helper 串成一個執行流程
                      drop_cols=None,
                      fill_strategy=None, fill_value=None,
                      type_conversions=None,
                      date_cols=None, encode_cols=None):
    df = load_csv(input_path)  # 讀取資料
    df = clean_column_names(df)  # 標準化欄位名稱
    if drop_cols:
        df = drop_columns(df, drop_cols)  # 刪除指定欄位
    if fill_strategy or fill_value is not None:
        df = fill_missing(df, strategy=fill_strategy, value=fill_value)  # 處理遺失值
    if type_conversions:
        df = convert_types(df, type_conversions)  # 轉換欄位型別
    if date_cols:
        for d in date_cols:
            df = add_datetime_parts(df, d)  # 拆出日期部位
    if encode_cols:
        df = simple_encoding(df, encode_cols)  # 類別欄位編碼
    if output_path is not None:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)  # 確保輸出資料夾存在
        df.to_csv(output_path, index=False)  # 寫出清理後的 CSV
    return df  # 回傳處理後的 DataFrame

In [39]:
# Example usage (uncomment and set real paths to run)
# input_path = 'data/jewelry_sales_data.csv'  # 範例輸入路徑
# out = run_preprocessing(input_path, output_path='data/jewelry_cleaned_data.csv',
#                     drop_cols=['unnamed: 0'],  # 範例要刪除的欄位
#                     fill_strategy='median',  # 使用中位數填補遺失值
#                     type_conversions={'transaction_date':'datetime'},  # 欄位型別轉換
#                     date_cols=['transaction_date'],  # 要拆解的日期欄位
#                     encode_cols=['category'])  # 要進行編碼的類別欄
# preview(out)  # 顯示處理後的 DataFrame 預覽
print('Example usage cell: edit paths and uncomment to run')  # 提示文字

Example usage cell: edit paths and uncomment to run


**Next steps**:
- Replace placeholders to match your column names from `01_Data_Preprocessing.py`.
- Add any domain-specific cleaning (pricing rules, SKU parsing) as small functions and keep them in their own cells so the Jupyter Book renders them clearly.
- If you prefer, move the clean functions into a small `src/preprocessing.py` module and import them in the notebook for brevity.

In [40]:
# Gather slide-ready summaries and save to files
import json  # 用於輸出 JSON
from pathlib import Path  # 已在上方引入，但在函式區域再確認一次


def gather_slide_data(input_csv, out_json='results/slide_02_03_data.json', out_md='results/slide_02_03.md'):
    """Compute a compact set of metrics for Slides 2-3 and write JSON and Markdown outputs.
    Relies on helper functions defined earlier in this notebook: `load_csv`, `clean_column_names`.
    Adds richer missingness summary and item/customer statistics for jewelry analysis.
    """
    Path('results').mkdir(parents=True, exist_ok=True)  # 確保 results 資料夾存在
    df = load_csv(input_csv, low_memory=True)  # 讀取 CSV（low_memory 以節省記憶體）
    df = clean_column_names(df)  # 標準化欄位名稱

    # If a price column exists and needs scaling the caller may have already scaled;
    # do not multiply here unconditionally. (caller may pass scaled file.)

    metrics = {}  # 儲存所有統計指標的字典
    metrics['rows'] = int(len(df))  # 總列數（總交易量）
    metrics['columns'] = df.columns.tolist()  # 欄位清單

    # customers
    if 'customer_id' in df.columns:
        metrics['unique_customers'] = int(df['customer_id'].nunique())  # 不重複客戶數

    # date range (pick a column containing "date")
    date_cols = [c for c in df.columns if 'date' in c]
    if date_cols:
        dcol = date_cols[0]
        dser = pd.to_datetime(df[dcol], errors='coerce')
        metrics['date_column'] = dcol
        metrics['date_min'] = str(dser.min())
        metrics['date_max'] = str(dser.max())

    # revenue / price
    revenue_candidates = [c for c in df.columns if c.lower() in ('revenue','amount','price','total','sales')]
    rcol = None
    if revenue_candidates:
        rcol = revenue_candidates[0]
        metrics['revenue_column'] = rcol
        # coerce to numeric for safety
        rev = pd.to_numeric(df[rcol], errors='coerce')
        metrics['total_revenue'] = float(rev.sum(skipna=True))
        metrics['median_order_value'] = float(rev.median(skipna=True))
        # richer revenue stats
        metrics['revenue_stats'] = {
            'mean': float(rev.mean(skipna=True)),
            'std': float(rev.std(skipna=True)),
            'min': float(rev.min(skipna=True)) if rev.count() else None,
            'max': float(rev.max(skipna=True)) if rev.count() else None,
        }

    # top categories
    if 'category' in df.columns:
        top = df['category'].value_counts().head(10)
        metrics['top_categories'] = top.to_dict()
        top.to_csv('results/top_categories_slide_02_03.csv')

    # attempt to identify an item/product column for item-level summaries
    item_candidates = [c for c in df.columns if any(x in c for x in ('item','product','article','sku','title','name'))]
    item_col = item_candidates[0] if item_candidates else None
    if item_col:
        top_items = df[item_col].value_counts().head(200)
        metrics['unique_items'] = int(df[item_col].nunique())
        metrics['top_items_by_count'] = top_items.to_dict()
        top_items.head(100).to_csv('results/top_items_by_count.csv')

    # top customers by revenue (if possible)
    if rcol and 'customer_id' in df.columns:
        grp = df.groupby('customer_id')[rcol].apply(lambda s: pd.to_numeric(s, errors='coerce').sum(skipna=True))
        top_cust = grp.sort_values(ascending=False).head(200)
        metrics['top_customers_by_revenue'] = top_cust.to_dict()
        top_cust.head(100).to_csv('results/top_customers_by_revenue.csv', header=[rcol])

    # missingness: fraction, count, non-null
    fraction = (df.isnull().mean()).round(3)
    count = df.isnull().sum()
    non_null = df.notnull().sum()
    metrics['missingness'] = {
        'fraction': fraction.to_dict(),
        'count': count.to_dict(),
        'non_null': non_null.to_dict(),
    }

    # Jewelry-specific statistics: try to load articles and compute jewelry metrics when possible
    jewelry_metrics = {}
    articles_path = Path('data') / 'articles.csv'
    try:
        if articles_path.exists() and 'article_id' in df.columns:
            articles_df = load_csv(articles_path)
            # normalize name column candidates
            cols = [c for c in articles_df.columns]
            name_col = None
            if 'product_type_name' in cols:
                name_col = 'product_type_name'
            elif 'product_group_name' in cols:
                name_col = 'product_group_name'
            # find jewelry articles using name_col if available
            jewelry_keywords = ['jewelry', 'jewellery', 'necklace', 'bracelet', 'earring', 'ring', 'pendant', 'charm', 'brooch', 'cufflink']
            if name_col:
                mask = articles_df[name_col].astype(str).str.lower().str.contains('|'.join(jewelry_keywords), na=False)
                jewelry_articles = articles_df[mask]
            else:
                jewelry_articles = pd.DataFrame(columns=articles_df.columns)

            # transactions that match jewelry articles
            if not jewelry_articles.empty:
                jew_tx = df[df['article_id'].isin(jewelry_articles['article_id'])]
                jewelry_metrics['jewelry_transaction_count'] = int(len(jew_tx))
                if 'customer_id' in jew_tx.columns:
                    jewelry_metrics['jewelry_unique_customers'] = int(jew_tx['customer_id'].nunique())
                if rcol and not jew_tx.empty:
                    jr = pd.to_numeric(jew_tx[rcol], errors='coerce')
                    jewelry_metrics['jewelry_total_revenue'] = float(jr.sum(skipna=True))
                    jewelry_metrics['jewelry_median_order_value'] = float(jr.median(skipna=True))
                # top jewelry items by count and revenue if item_col present
                if item_col and item_col in jew_tx.columns:
                    top_j_items = jew_tx[item_col].value_counts().head(50)
                    jewelry_metrics['top_jewelry_items_by_count'] = top_j_items.to_dict()
                    top_j_items.to_csv('results/top_jewelry_items_by_count.csv')
                if rcol and 'customer_id' in jew_tx.columns:
                    cust_jrev = jew_tx.groupby('customer_id')[rcol].apply(lambda s: pd.to_numeric(s, errors='coerce').sum(skipna=True))
                    jewelry_metrics['top_jewelry_customers_by_revenue'] = cust_jrev.sort_values(ascending=False).head(50).to_dict()
    except Exception as e:
        # if anything fails, record the error for debugging
        jewelry_metrics['error'] = str(e)

    if jewelry_metrics:
        metrics['jewelry'] = jewelry_metrics

    # small sample rows to preview in slides
    sample_rows = df.sample(n=min(50, len(df)), random_state=1) if len(df) > 0 else df
    sample_rows.head(20).to_csv('results/sample_preview_slide_02_03.csv', index=False)

    # write json and markdown
    with open(out_json, 'w') as f:
        json.dump(metrics, f, indent=2, default=str)

    # build markdown summary
    md_lines = [
        '# Data Foundation & Key Metrics (Slides 2-3)\\n',
        f'- Rows: {metrics.get("rows")}',
    ]
    if metrics.get('unique_customers') is not None:
        md_lines.append(f'- Unique customers: {metrics.get("unique_customers")}')
    if metrics.get('date_column'):
        md_lines.append(f'- Date range ({metrics.get("date_column")}): {metrics.get("date_min")} → {metrics.get("date_max")})')
    if metrics.get('total_revenue') is not None:
        md_lines.append(f'- Total revenue ({metrics.get("revenue_column")}): {metrics.get("total_revenue"): .2f}')
        md_lines.append(f'- Median order value: {metrics.get("median_order_value"): .2f}')
    if metrics.get('revenue_stats'):
        rs = metrics['revenue_stats']
        md_lines.append(f'- Revenue mean: {rs.get("mean"): .2f} | std: {rs.get("std"): .2f} | min: {rs.get("min")} | max: {rs.get("max")}')
    if metrics.get('unique_items') is not None:
        md_lines.append(f'- Unique items: {metrics.get("unique_items")}')

    md_lines.append('\n## Top categories (top 10)')
    if 'top_categories' in metrics and metrics['top_categories']:
        for k, v in metrics['top_categories'].items():
            md_lines.append(f'- {k}: {v}')

    md_lines.append('\n## Top items (by count)')
    if 'top_items_by_count' in metrics and metrics['top_items_by_count']:
        for k, v in list(metrics['top_items_by_count'].items())[:10]:
            md_lines.append(f'- {k}: {v}')

    md_lines.append('\n## Top customers (by revenue)')
    if 'top_customers_by_revenue' in metrics and metrics['top_customers_by_revenue']:
        for k, v in list(metrics['top_customers_by_revenue'].items())[:10]:
            md_lines.append(f'- {k}: {v:.2f}')

    # jewelry-specific section
    if 'jewelry' in metrics:
        jm = metrics['jewelry']
        md_lines.append('\n## Jewelry-specific metrics')
        if 'jewelry_transaction_count' in jm:
            md_lines.append(f'- Jewelry transactions: {jm.get("jewelry_transaction_count")}')
        if 'jewelry_unique_customers' in jm:
            md_lines.append(f'- Jewelry unique customers: {jm.get("jewelry_unique_customers")}')
        if 'jewelry_total_revenue' in jm:
            md_lines.append(f'- Jewelry total revenue: {jm.get("jewelry_total_revenue"): .2f}')
        if 'jewelry_median_order_value' in jm:
            md_lines.append(f'- Jewelry median order value: {jm.get("jewelry_median_order_value"): .2f}')
        if 'top_jewelry_items_by_count' in jm:
            md_lines.append('\nTop jewelry items (by count):')
            for k, v in list(jm['top_jewelry_items_by_count'].items())[:10]:
                md_lines.append(f'- {k}: {v}')
        if 'top_jewelry_customers_by_revenue' in jm:
            md_lines.append('\nTop jewelry customers (by revenue):')
            for k, v in list(jm['top_jewelry_customers_by_revenue'].items())[:10]:
                md_lines.append(f'- {k}: {v:.2f}')

    md_lines.append('\n## Missingness (fraction / count / non-null)')
    miss = metrics.get('missingness', {})
    for k, v in list(miss.get('fraction', {}).items())[:10]:
        md_lines.append(f'- {k}: fraction={v} count={miss.get("count", {}).get(k, None)} non_null={miss.get("non_null", {}).get(k, None)}')

    md_lines.append('\n## Notes')
    md_lines.append('- Missingness and column list saved to JSON.')
    md_lines.append('- CSV previews in results/ for top categories, items, and customers.')

    with open(out_md, 'w') as f:
        f.write('\n'.join(md_lines))

    print('Wrote:', out_json, out_md)
    return metrics


# Example call (uncomment to run):  # 範例呼叫：把註解拿掉即可執行
# gather_slide_data('data/transactions_train_sample.csv')



# Adjust the path below if your sample filename differs

# sample_path = 'data/transactions_train_sample.csv'
# sample_path = '/Users/jzou/Library/Mobile Documents/com~apple~CloudDocs/data/transactions_train_sample.csv'

# print('Running gather_slide_data on', sample_path)
# metrics = gather_slide_data(sample_path)
# print('Done. Metrics keys:', list(metrics.keys()))

In [41]:
# Integrate figure creation and price-scaling into the notebook
import matplotlib.pyplot as plt
import os

navy_blue = '#0A2463'
accent_gold = '#D4AF37'
white = '#FFFFFF'


def scale_price_in_transactions(in_path, out_path, factor=590):
    """Multiply numeric price-like columns by `factor` and write a new CSV."""
    df = load_csv(in_path)
    # detect candidate price/revenue columns
    candidates = [c for c in df.columns if c.lower() in ('revenue','amount','price','total','sales')]
    if not candidates:
        # try columns that contain 'price' or 'amount'
        candidates = [c for c in df.columns if 'price' in c.lower() or 'amount' in c.lower()]
    for c in candidates:
        df[c] = pd.to_numeric(df[c], errors='coerce') * factor
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(out_path, index=False)
    print(f'Scaled {len(candidates)} columns by {factor} and wrote: {out_path}')
    return out_path


def load_three_figures_data(transactions_path, data_dir='data'):
    """Load customers, articles, and transactions (transactions_path can be custom)."""
    customers_df = load_csv(Path(data_dir) / 'customers.csv')
    articles_df = load_csv(Path(data_dir) / 'articles.csv')
    transactions_df = load_csv(transactions_path)

    jewelry_keywords = ['jewelry', 'jewellery', 'necklace', 'bracelet', 'earring', 'ring', 'pendant', 'charm', 'brooch', 'cufflink']

    if 'product_type_name' in articles_df.columns:
        articles_df['is_jewelry'] = articles_df['product_type_name'].astype(str).str.lower().str.contains(
            '|'.join(jewelry_keywords), na=False
        )
    else:
        articles_df['is_jewelry'] = False
    if 'product_type_name' in articles_df.columns:
        print(f'\n--- Top Jewelry Types ---')
        jewelry_types = articles_df[articles_df['is_jewelry']]['product_type_name'].value_counts()
        print(jewelry_types)

    # Create jewelry articles reference
    # 建立珠寶商品參考資料
    jewelry_article_ids = set(articles_df[articles_df['is_jewelry']]['article_id'].unique())
    print(f'\nJewelry article IDs: {len(jewelry_article_ids)} unique items')

    # Identify transactions that are jewelry purchases
    # 識別珠寶購買交易
    transactions_df['is_jewelry_purchase'] = transactions_df['article_id'].isin(jewelry_article_ids)
    jewelry_transactions = transactions_df[transactions_df['is_jewelry_purchase']]
    jewelry_customers = set(jewelry_transactions['customer_id'].unique()) if 'customer_id' in jewelry_transactions.columns else set()

    customers_df['is_jewelry_buyer'] = customers_df['customer_id'].isin(jewelry_customers)
    customers_df.drop_duplicates(subset=['customer_id'], inplace=True)

    jewelry_buyers = customers_df[customers_df['is_jewelry_buyer']]
    all_customers = customers_df

    articles_df.to_csv(os.path.join(data_dir,'articles_with_jewelry_flag.csv'), index=False)
    customers_df.to_csv(os.path.join(data_dir,'customers_with_jewelry_flag.csv'), index=False)
    transactions_df.to_csv(os.path.join(data_dir,'transactions_with_jewelry_flag.csv'), index=False)
    print(f"Shape of customers_df: {customers_df.shape} | jewelry_buyers: {jewelry_buyers.shape} | all_customers: {all_customers.shape}")
    return jewelry_buyers, all_customers


# Chart functions (adapted from src/create_three_figures.py)
def create_market_penetration_chart(jewelry_buyers, all_customers, out_path='results/figure1_market_penetration_by_01ipynb.png'):
    fig, ax = plt.subplots(figsize=(8, 8))
    fig.patch.set_facecolor(white)

    customer_categories = ['Jewelry Buyers', 'Non-Jewelry Buyers']

    total_customers = all_customers['customer_id'].nunique()
    total_jewelry_buyers = jewelry_buyers['customer_id'].nunique()

    customer_counts = [total_jewelry_buyers, total_customers - total_jewelry_buyers]
    print(f"For plotting purpuroses; \n\nCustomer categories: {customer_categories}")
    print(f"Customer counts: {customer_counts}")
    colors_cust = [accent_gold, navy_blue]

    wedges, texts, autotexts = ax.pie(
        customer_counts, labels=customer_categories, colors=colors_cust,
        autopct='%1.1f%%', startangle=90,
        wedgeprops=dict(edgecolor='white', linewidth=2),
        textprops={'fontsize': 11, 'fontweight': 'bold'}
    )
    for autotext in autotexts:
        autotext.set_color('white')
        autotext.set_fontweight('bold')
    ax.set_title('Customer Penetration', fontsize=12, fontweight='bold', color=navy_blue, pad=10)

    # jewelry_pct = (len(jewelry_buyers) / len(all_customers)) * 100 if len(all_customers) else 0
    # non_jewelry_pct = 100 - jewelry_pct
    # sizes = [non_jewelry_pct, jewelry_pct]
    # colors = [navy_blue, accent_gold]
    # labels = ['Non-Jewelry Buyers', 'Jewelry Buyers']
    # wedges, texts, autotexts = ax.pie(sizes, colors=colors, labels=labels, 
    #                                  autopct='%1.1f%%', startangle=90,
    #                                  wedgeprops=dict(width=0.5, edgecolor='white', linewidth=2),
    #                                  textprops={'fontsize': 14, 'fontweight': 'bold', 'color': navy_blue})
    # for autotext in autotexts:
    #     autotext.set_color('white')
    #     autotext.set_fontsize(14)
    #     autotext.set_fontweight('bold')
    # ax.set_title('Market Penetration', fontsize=18, fontweight='bold', 
    #               color=navy_blue, pad=20)
    ax.axis('equal')
    plt.tight_layout()
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor=white)
    plt.close()
    print(f"Saved {out_path}")


def create_age_group_comparison_chart(jewelry_buyers, all_customers, out_path='results/figure2_age_group_comparison.png'):
    fig, ax = plt.subplots(figsize=(12, 8))
    fig.patch.set_facecolor(white)

    age_bins = [20, 25, 30, 35, 40, 45, 50, 55, 100]
    age_labels = ['20-24', '25-29', '30-34', '35-39', '40-44', '45-49', '50-54', '55+']
    jewelry_age_dist = pd.cut(jewelry_buyers['age'], bins=age_bins, labels=age_labels, right=False).value_counts().sort_index()
    all_age_dist = pd.cut(all_customers['age'], bins=age_bins, labels=age_labels, right=False).value_counts().sort_index()

    jewelry_pct = (jewelry_age_dist / len(jewelry_buyers)) * 100 if len(jewelry_buyers) else np.zeros(len(age_labels))
    all_pct = (all_age_dist / len(all_customers)) * 100 if len(all_customers) else np.zeros(len(age_labels))

    x = np.arange(len(age_labels))
    width = 0.35
    bars1 = ax.bar(x - width/2, jewelry_pct.values, width, label='Jewelry Buyers', 
                    color=accent_gold, edgecolor=navy_blue, linewidth=2)
    bars2 = ax.bar(x + width/2, all_pct.values, width, label='All Customers', 
                    color=navy_blue, edgecolor=navy_blue, linewidth=2)
    ax.set_xlabel('Age Group', fontsize=14, fontweight='bold', color=navy_blue)
    ax.set_ylabel('Percentage (%)', fontsize=14, fontweight='bold', color=navy_blue)
    ax.set_title('Age Group Distribution Comparison', fontsize=18, fontweight='bold', 
                  color=navy_blue, pad=20)
    ax.set_xticks(x)
    ax.set_xticklabels(age_labels, fontsize=12)
    ax.legend(fontsize=12, loc='upper right')
    ax.tick_params(axis='both', colors=navy_blue, labelsize=12)
    ax.spines['bottom'].set_color(navy_blue)
    ax.spines['left'].set_color(navy_blue)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['left'].set_linewidth(2)
    ax.set_facecolor(white)

    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.1f}%',
                    ha='center', va='bottom', fontsize=10, fontweight='bold', 
                    color=navy_blue)

    plt.tight_layout()
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor=white)
    plt.close()
    print(f"Saved {out_path}")


def create_age_distribution_boxplot(jewelry_buyers, all_customers, out_path='results/figure3_age_distribution_boxplot.png'):
    fig, ax = plt.subplots(figsize=(10, 8))
    fig.patch.set_facecolor(white)

    jewelry_mean = jewelry_buyers['age'].mean() if len(jewelry_buyers) else np.nan
    all_mean = all_customers['age'].mean() if len(all_customers) else np.nan

    box_data = [jewelry_buyers['age'], all_customers['age']]
    bp = ax.boxplot(box_data, labels=['Jewelry Buyers', 'All Customers'], 
                     patch_artist=True, widths=0.6)
    for patch in bp['boxes']:
        patch.set_facecolor(white)
        patch.set_edgecolor(navy_blue)
        patch.set_linewidth(3)
    for element in ['whiskers', 'fliers', 'medians', 'caps']:
        plt.setp(bp[element], color=navy_blue, linewidth=2)
    if 'means' in bp:
        plt.setp(bp['means'], color=accent_gold, linewidth=2, marker='D', markersize=8)

    ax.text(1, jewelry_mean, f'Mean: {jewelry_mean:.1f}' if not np.isnan(jewelry_mean) else 'Mean: N/A', 
             ha='center', va='bottom', fontsize=12, fontweight='bold', color=accent_gold)
    ax.text(2, all_mean, f'Mean: {all_mean:.1f}' if not np.isnan(all_mean) else 'Mean: N/A', 
             ha='center', va='bottom', fontsize=12, fontweight='bold', color=navy_blue)

    ax.set_ylabel('Age', fontsize=14, fontweight='bold', color=navy_blue)
    ax.set_title('Age Distribution Comparison', fontsize=18, fontweight='bold', 
                  color=navy_blue, pad=20)
    ax.tick_params(axis='both', colors=navy_blue, labelsize=12)
    ax.spines['bottom'].set_color(navy_blue)
    ax.spines['left'].set_color(navy_blue)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['left'].set_linewidth(2)
    ax.set_facecolor(white)

    plt.tight_layout()
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor=white)
    plt.close()
    print(f"Saved {out_path}")


# Use original (unscaled) data for consistency across notebooks
# 使用原始(未縮放)資料以確保跨 notebook 的一致性
# orig_sample = 'data/transactions_train_sample.csv'
jiacheng_data_folder = '/Users/jzou/Library/Mobile Documents/com~apple~CloudDocs/data'
serena_data_folder = 'data'

orig_sample = os.path.join(serena_data_folder, 'transactions_train_sample.csv')


# load data for figures using original transactions
jewelry_buyers, all_customers = load_three_figures_data(orig_sample, data_dir=serena_data_folder)

# create figures
create_market_penetration_chart(jewelry_buyers, all_customers)
create_age_group_comparison_chart(jewelry_buyers, all_customers)
create_age_distribution_boxplot(jewelry_buyers, all_customers)

# recompute slide metrics on original (unscaled) file for consistency
# 在原始檔案上重新計算投影片指標以確保一致性
metrics_original = gather_slide_data(orig_sample, out_json='results/slide_02_03_data.json', out_md='results/slide_02_03.md')
print('Done. Metrics keys:', list(metrics_original.keys()))


--- Top Jewelry Types ---
product_type_name
Earring        1159
Necklace        581
Ring            240
Hair string     238
Bracelet        180
Earrings         11
Name: count, dtype: int64

Jewelry article IDs: 2409 unique items
Shape of customers_df: (1371980, 8) | jewelry_buyers: (35592, 8) | all_customers: (1371980, 8)
For plotting purpuroses; 

Customer categories: ['Jewelry Buyers', 'Non-Jewelry Buyers']
Customer counts: [35592, 1336388]
Saved results/figure1_market_penetration_by_01ipynb.png
Saved results/figure2_age_group_comparison.png


  bp = ax.boxplot(box_data, labels=['Jewelry Buyers', 'All Customers'],
  plt.tight_layout()


Saved results/figure3_age_distribution_boxplot.png
Wrote: results/slide_02_03_data.json results/slide_02_03.md
Done. Metrics keys: ['rows', 'columns', 'unique_customers', 'revenue_column', 'total_revenue', 'median_order_value', 'revenue_stats', 'unique_items', 'top_items_by_count', 'top_customers_by_revenue', 'missingness', 'jewelry']
