In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import re

# Inclusive year window drawn on each heatmap's x axis.
START_YEAR, END_YEAR = 1900, 2025

# Smallest integer accepted as a year when scanning "parsed_years"; smaller
# numbers are discarded (presumably days, months or other small counts).
MIN_NUM_AS_YEAR = 32
# Number of discrete color steps in the heatmap colormap.
N_COLORS = 15
# Exponent applied to the per-language normalized counts before coloring.
GAMMA = 0.40
# Target number of labeled year ticks per heatmap panel.
MAX_TICKS_X = 10

# Input CSV for each category; dict order is the left-to-right panel order.
GROUPS = {
    name: {"input": Path.home() / "Desktop" / csv_name}
    for name, csv_name in (
        ("sports", "parsed_years_history_of_sports_tagged.csv"),
        ("objects", "parsed_years_historical_objects_tagged.csv"),
        ("ideologies", "parsed_years_history_of_ideologies_tagged.csv"),
    )
}

def filename_to_language(fn: str) -> str:
    """Derive a language label from a file name such as ``"English.txt"``.

    Surrounding whitespace and one trailing ``.txt`` (any letter case) are
    removed. Any non-string input (e.g. a float NaN) yields ``"Unknown"``.
    """
    if isinstance(fn, str):
        label = fn.strip()
        has_txt_suffix = label.lower().endswith(".txt")
        return (label[:-4] if has_txt_suffix else label).strip()
    return "Unknown"

def make_discrete_cmap(n_colors=15):
    """Return a ``ListedColormap`` of ``n_colors`` entries sampled from magma_r.

    Reversed magma: yellow marks low values, purple marks high values.
    Sampling starts at 0.06 rather than 0, skipping the very start of the
    reversed ramp.
    """
    sample_points = np.linspace(0.06, 1.0, n_colors)
    rgba = plt.cm.magma_r(sample_points)
    return ListedColormap(rgba, name=f"magma_r_{n_colors}")

# Panel titles keyed by group name; groups not listed here are drawn untitled.
_GROUP_TITLES = {
    "objects": "Historical Objects",
    "sports": "History of Sports",
    "ideologies": "History of Ideologies",
}


def _set_year_ticks(ax, years, max_ticks):
    """Label at most ``max_ticks`` evenly spaced heatmap columns with their year."""
    num_years = len(years)
    # Ceiling division guarantees <= max_ticks labels; floor division would
    # give e.g. 11 labels for 126 years with max_ticks=10.
    step = max(1, -(-num_years // max_ticks))
    xticks = np.arange(0, num_years, step)
    ax.set_xticks(xticks)
    ax.set_xticklabels(
        [str(years[i]) for i in xticks], rotation=90, ha="center", fontsize=60
    )


def _draw_panel(ax, group, pivot, cmap, norm):
    """Draw one group's language x year heatmap, tick labels and title on ``ax``."""
    # Normalize each language row by its own maximum; all-zero rows divide by 1.
    row_max = pivot.max(axis=1).replace(0, 1)
    mat_norm = pivot.div(row_max, axis=0)
    # A GAMMA below 1 lifts small values so sparsely mentioned years stay visible.
    mat_gamma = np.power(mat_norm.values.astype(float), GAMMA)

    ax.imshow(
        mat_gamma,
        aspect="auto",
        interpolation="nearest",
        cmap=cmap,
        norm=norm,
        origin="upper",
    )

    # Thin white rules separating the language rows.
    for y in range(1, len(mat_norm.index)):
        ax.axhline(y - 0.5, color="white", linewidth=0.5)

    # Every panel shows its own language labels.
    ax.set_yticks(np.arange(len(mat_norm.index)))
    ax.set_yticklabels(mat_norm.index.tolist(), fontsize=7)

    _set_year_ticks(ax, mat_norm.columns.values, MAX_TICKS_X)

    title = _GROUP_TITLES.get(group)
    if title is not None:
        ax.set_title(title, pad=10, fontsize=65)


def draw_combined_heatmaps(pivots: dict, start: int, end: int, out_path: Path):
    """Render one heatmap per group side by side and save the figure.

    Args:
        pivots: Maps a group name to a language x year count table (rows =
            languages, columns = years). Empty tables leave their panel blank.
        start, end: Year window. Not used for drawing (the x axis comes from
            each table's columns); kept for call compatibility.
        out_path: Destination image file.

    Raises:
        ValueError: If ``pivots`` is empty.
    """
    if not pivots:
        raise ValueError("draw_combined_heatmaps() needs at least one pivot")
    n_groups = len(pivots)

    fig_w = n_groups * 20
    # 0.4 inch of height per language row of the tallest panel, minimum 10 inches.
    fig_h = max(10, max(len(p.index) * 0.40 for p in pivots.values()))
    fig, axes = plt.subplots(
        1, n_groups, figsize=(fig_w, fig_h), dpi=160,
        constrained_layout=False, sharey=False,  # each panel keeps its own language rows
        squeeze=False,
    )
    axes = axes[0]  # squeeze=False always returns a 1 x n grid; keep its single row

    cmap = make_discrete_cmap(N_COLORS)
    # N_COLORS equal-width bins over [0, 1], one per colormap entry.
    boundaries = np.linspace(0.0, 1.0, N_COLORS + 1)
    norm = BoundaryNorm(boundaries, N_COLORS, clip=True)

    try:
        for ax, (group, pivot) in zip(axes, pivots.items()):
            if pivot.empty:
                continue
            _draw_panel(ax, group, pivot, cmap, norm)

        # ===== Main title and subtitle at the top =====
        fig.suptitle(
            "Historical Dates in Wikipedia, by Language, Within Three Categories\n"
            "(Sports, Ideology, and Objects)",
            fontsize=100, y=0.95
        )
        fig.text(
            0.5, 0.90,
            "Numerical dates only, ordered by Y axis values; heatmap values are normalized per language\n"
            "Yellow = minimum, Purple = maximum; Source: Wikipedia, September 2025",
            ha="center", fontsize=85
        )

        # Vertical caption along the left edge.
        fig.text(0.02, 0.5, "Normalized Frequency (per Language)",
                 va="center", rotation=90, fontsize=80)

        # Wide gaps between panels; top margin leaves room for the titles.
        fig.subplots_adjust(
            left=0.12, right=0.88,
            wspace=0.8,
            top=0.85, bottom=0.10
        )

        # Shared colorbar on the right. NOTE: its 0-1 scale is the
        # gamma-transformed value (normalized ** GAMMA), not the raw fraction.
        fig.colorbar(
            plt.cm.ScalarMappable(norm=norm, cmap=cmap),
            ax=axes, orientation="vertical", fraction=0.02, pad=0.04
        )

        fig.savefig(out_path, dpi=160)
    finally:
        # Release this figure even if drawing or saving fails.
        plt.close(fig)
    print(f"[OK] 已保存 {out_path}")

# Integer tokens of 1-4 digits, optionally negative. (?<!\d) and (?!\d) make
# the token a whole digit run: longer runs such as "20250" no longer yield a
# spurious "2025", and the hyphen in a range like "1990-2000" is not taken as
# a minus sign on the second year.
_YEAR_TOKEN_RE = re.compile(r"(?<!\d)-?\d{1,4}(?!\d)")


def process_group(group: str, cfg: dict):
    """Load one group's CSV and build a language x year mention-count table.

    Args:
        group: Group name; used only in the progress message.
        cfg: Group config; ``cfg["input"]`` is the CSV path. The CSV needs a
            ``filename`` column; years are scanned from an optional
            ``parsed_years`` text column.

    Returns:
        DataFrame indexed by language with one column per year from
        START_YEAR to END_YEAR inclusive, holding counts (0 where absent).
        Rows are ordered by count-weighted mean year, latest first.
        An empty DataFrame if no year falls inside that window.
    """
    in_path = cfg["input"]

    print(f"[INFO] 读取 {group}: {in_path}")
    df = pd.read_csv(in_path, dtype=str, encoding="utf-8")
    df["language"] = df["filename"].apply(filename_to_language)

    if "parsed_years" not in df.columns:
        return pd.DataFrame()

    records = []
    for lang, years_str in zip(df["language"], df["parsed_years"]):
        # Missing cells (NaN) and blank strings carry no years.
        if not isinstance(years_str, str) or not years_str.strip():
            continue
        for token in _YEAR_TOKEN_RE.findall(years_str):
            year = int(token)
            # Drop numbers below MIN_NUM_AS_YEAR and anything after END_YEAR.
            if MIN_NUM_AS_YEAR <= year <= END_YEAR:
                records.append((lang, year))
    if not records:
        return pd.DataFrame()

    years_df = pd.DataFrame(records, columns=["language", "year"])
    counts = years_df.groupby(["language", "year"]).size().reset_index(name="count")

    # Keep only the plotted year window.
    mask = (counts["year"] >= START_YEAR) & (counts["year"] <= END_YEAR)
    counts_window = counts[mask]
    if counts_window.empty:
        return pd.DataFrame()

    all_years = np.arange(START_YEAR, END_YEAR + 1, dtype=int)
    pivot = (
        counts_window.pivot(index="language", columns="year", values="count")
        .reindex(columns=all_years)
        .fillna(0)
    )

    # Count-weighted mean year per language; used only to order the rows.
    weighted_means = pivot.apply(
        lambda row: np.average(pivot.columns, weights=row) if row.sum() > 0 else START_YEAR,
        axis=1,
    )

    # Latest-leaning languages first; a stable sort keeps tied languages in
    # the pivot's own row order instead of an arbitrary one.
    order = weighted_means.sort_values(ascending=False, kind="mergesort").index
    return pivot.loc[order]

def main():
    """Build one pivot per entry in GROUPS and save them together as one image.

    Groups whose pivot comes back empty are left out. If none remain,
    nothing is drawn or written.
    """
    built = ((name, process_group(name, cfg)) for name, cfg in GROUPS.items())
    pivots = {name: pivot for name, pivot in built if not pivot.empty}

    if not pivots:
        return

    out_path = Path.home() / "Desktop" / f"combined_heatmap_{START_YEAR}-{END_YEAR}.png"
    draw_combined_heatmaps(pivots, START_YEAR, END_YEAR, out_path)


if __name__ == "__main__":
    main()