# In [None]:
# -*- coding: utf-8 -*-
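"""
OpenAlex: Top-50 topics → journals via primary-topic works, then keep only the
journals whose 25 representative topics include at least MIN_TOPIC_COUNT
(default 1) of the Top-50 topics that led to them.

Intermediate results are checkpointed as JSON under SAVE_DIR so an interrupted
run can resume; results are written as timestamped summary and long-format CSVs.
"""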

import os
import re
import json
import time
from datetime import datetime
from collections import defaultdict
from typing import Dict, List, Set, Iterable, Optional, Any

import requests
import pandas as pd

# -----------------------------
# Configuration
# -----------------------------
# Root of the OpenAlex REST API.
BASE = "https://api.openalex.org"

# Contact address sent as `mailto` on every request (see req_json).
# NOTE(review): the default is a placeholder; set OPENALEX_MAILTO to a real address.
MAILTO = os.environ.get("OPENALEX_MAILTO", "YOUR_EMAIL@example.com")

# Every checkpoint and CSV lives under SAVE_DIR, which is created at import time.
BASE_DIR = os.environ.get("OPENALEX_BASE_DIR", os.path.join(".", "data"))
SAVE_DIR = os.path.join(BASE_DIR, "out_primary_Representative_topic25")
os.makedirs(SAVE_DIR, exist_ok=True)

# HTTP retry policy: number of attempts, exponential backoff base
# (req_json sleeps BACKOFF ** attempt seconds), per-request timeout in seconds.
MAX_RETRIES = int(os.environ.get("OPENALEX_MAX_RETRIES", "6"))
BACKOFF = float(os.environ.get("OPENALEX_BACKOFF", "1.6"))
TIMEOUT_S = int(os.environ.get("OPENALEX_TIMEOUT_S", "60"))

# Minimum number of matched Top-50 topics a journal needs in order to be reported.
MIN_TOPIC_COUNT = int(os.environ.get("OPENALEX_MIN_TOPIC_COUNT", "1"))

# Resumable checkpoints:
#   CP_TID2SIDS  -> {topic_id: [source_id, ...]}
#   CP_SID_TOP25 -> {source_id: [topic_id, ...]}
CP_TID2SIDS = os.path.join(SAVE_DIR, "cp_tid_to_sids_primary.json")
CP_SID_TOP25 = os.path.join(SAVE_DIR, "cp_sid_top25_topics.json")

# Output CSV path templates; "{}" is filled with the run timestamp.
OUT_SUMMARY_TPL = os.path.join(SAVE_DIR, "journals_Representative_topic25_summary_{}.csv")
OUT_LONG_TPL = os.path.join(SAVE_DIR, "journals_Representative_topic25_long_{}.csv")


# -----------------------------
# Input: Top-50 topic IDs
# -----------------------------
TOPIC_ID_CSV = """T10712 , T12289, T12863, T11813, T11657, T10102, T10068, T14330, T12028, T11937, T11891, T10215, T10609, T12377, T10557,
T10162 , T13166, T13516, T10953, T14386, T11499, T11572, T12171, T11744,
T10003, T11197, T11024, T11719, T14246, T12364, T10355, T10286, T11530,
T11986 , T13807, T10181, T13083, T11045, T14109, T10028, T12016, T12478,
T11437 , T11147, T14380, T13607, T12573, T13496, T10154, T12764"""
# Optional override, consistent with the other OPENALEX_* settings: any text
# containing topic IDs (bare "T123" or "https://openalex.org/T123"). When it is
# unset or blank, the built-in list above is used.
_TOPIC_SOURCE_TEXT = os.getenv("OPENALEX_TOPIC_IDS") or TOPIC_ID_CSV
# Topic IDs in input order, with repeats dropped so that no topic is processed twice.
TOPIC_IDS: List[str] = list(dict.fromkeys(re.findall(r"T\d+", _TOPIC_SOURCE_TEXT)))
TOPIC_SET: Set[str] = set(TOPIC_IDS)

# -----------------------------
# HTTP session
# -----------------------------
# One Session reused by req_json for every call, so connections are kept alive
# and each request carries the same identifying User-Agent.
SESSION = requests.Session()
SESSION.headers["User-Agent"] = "openalex-lis-top50-journals/1.0"


# -----------------------------
# Utilities
# -----------------------------
def req_json(url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """GET ``url`` and return the decoded JSON body, retrying transient failures.

    ``mailto`` is always added to the query parameters (copied, so the caller's
    dict is not modified).

    Retries up to ``MAX_RETRIES`` attempts (at least one) on:

    - network errors,
    - undecodable bodies,
    - HTTP 408, 429 and 5xx.

    Between attempts it sleeps ``BACKOFF ** attempt`` seconds. A numeric
    ``Retry-After`` header on a 429 response extends that delay. There is no
    sleep after the final attempt.

    Other non-200 statuses (e.g. 400 for a malformed or oversized filter) are
    deterministic, so they fail immediately instead of burning the whole retry
    budget.

    Raises:
        RuntimeError: on a non-retryable status, or when every attempt failed.
    """
    params = dict(params or {})
    params["mailto"] = MAILTO

    last_err: Optional[str] = None
    attempts = max(1, MAX_RETRIES)
    for attempt in range(1, attempts + 1):
        delay = BACKOFF ** attempt
        try:
            r = SESSION.get(url, params=params, timeout=TIMEOUT_S)
        except requests.RequestException as e:
            last_err = str(e)
        else:
            if r.status_code == 200:
                try:
                    return r.json()
                except ValueError as e:  # truncated or non-JSON body; worth retrying
                    last_err = f"invalid JSON: {e}"
            else:
                last_err = f"{r.status_code} {r.text[:200]}"
                if r.status_code == 429:
                    retry_after = str(r.headers.get("Retry-After", "")).strip()
                    if retry_after.isdigit():
                        delay = max(delay, float(retry_after))
                elif r.status_code != 408 and r.status_code < 500:
                    # Client errors other than timeouts/rate limiting will not
                    # change on retry.
                    raise RuntimeError(f"GET failed: {url} :: {last_err}")
        if attempt < attempts:
            time.sleep(delay)
    raise RuntimeError(f"GET failed: {url} :: {last_err}")


def save_json(path: str, obj: Any) -> None:
    """Write ``obj`` to ``path`` as pretty-printed UTF-8 JSON, atomically.

    The data is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. If the write is interrupted (Ctrl-C, crash,
    unserializable value, full disk), the previous checkpoint stays intact.
    Without this, a truncated file would make ``json.load`` fail on the next
    run. The temp file is removed on failure and the error is re-raised.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def load_json(path: str) -> Optional[Any]:
    """Return the JSON document stored at ``path``, or ``None`` when no such file exists."""
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


def to_sid(openalex_id_or_sid: str) -> str:
    """Return the short source key (e.g. 'S123') for an OpenAlex URL or bare ID.

    Only values starting with 'https://openalex.org/' are shortened, to their
    last path segment. Anything else is returned stripped but otherwise unchanged.
    """
    text = str(openalex_id_or_sid).strip()
    return text.split("/")[-1] if text.startswith("https://openalex.org/") else text


def to_tid(openalex_id_or_tid: str) -> str:
    """Return the short topic key (e.g. 'T123') for an OpenAlex URL or bare ID.

    Only values starting with 'https://openalex.org/' are shortened, to their
    last path segment. Anything else is returned stripped but otherwise unchanged.
    """
    text = str(openalex_id_or_tid).strip()
    if not text.startswith("https://openalex.org/"):
        return text
    return text.split("/")[-1]


# -----------------------------
# Step 1: topic → works → journals (primary topic only, journal sources only)
# -----------------------------
def collect_journal_sids_from_primary_topic(tid: str) -> Set[str]:
    """Return the IDs of journals that published works whose primary topic is ``tid``.

    Walks every page of ``/works`` using cursor pagination (200 per page). The
    query is filtered to ``primary_topic.id:<tid>`` and to journal-typed primary
    sources. For each work it keeps the short 'S…' ID of the primary-location
    source.

    Raises:
        RuntimeError: propagated from ``req_json`` when a page cannot be fetched.
    """
    journal_ids: Set[str] = set()
    page_cursor: Optional[str] = "*"
    query_filter = f"primary_topic.id:{tid},primary_location.source.type:journal"

    while True:
        page = req_json(
            f"{BASE}/works",
            params={
                "filter": query_filter,
                "select": "id,primary_location",
                "per-page": 200,
                "cursor": page_cursor,
            },
        )
        works = page.get("results", [])

        for work in works:
            source = (work.get("primary_location") or {}).get("source") or {}
            if source.get("type") != "journal":
                continue
            sid = to_sid(source.get("id", ""))
            if sid.startswith("S"):
                journal_ids.add(sid)

        # Stop once the API stops handing out a cursor or returns an empty page.
        page_cursor = page.get("meta", {}).get("next_cursor")
        if not (page_cursor and works):
            return journal_ids


# -----------------------------
# Step 2: journal → 25 representative topics (batch, consistent with UI)
# -----------------------------
def fetch_sources_top25_topics_batch(
    sids: Iterable[str], chunk_size: int = 50
) -> Dict[str, List[str]]:
    """Fetch the representative topics of many sources with batched ``/sources`` calls.

    Each request uses ``filter=openalex:S…|S…`` with at most ``chunk_size`` IDs.
    OpenAlex caps how many values may be OR-ed in a single filter (well below
    200; NOTE(review): confirm the current limit in the OpenAlex filter docs).
    The default of 50 matches batch_fetch_sources_meta.

    A chunk whose request fails is logged and skipped rather than aborting the
    whole batch. Its sources are simply missing from the result, so the caller
    can retry them individually.

    Args:
        sids: source IDs, short ('S123') or full URL form; duplicates are ignored.
        chunk_size: maximum IDs per request, clamped to 1..200 (the page size).

    Returns:
        {sid: [tid, ...]} in API order, only for sources that returned at least
        one topic.
    """
    size = max(1, min(int(chunk_size), 200))  # must not exceed per-page below
    ids = list(dict.fromkeys(to_sid(s) for s in sids))  # dedupe, keep order
    out: Dict[str, List[str]] = {}

    for start in range(0, len(ids), size):
        chunk = ids[start:start + size]
        try:
            js = req_json(
                f"{BASE}/sources",
                params={
                    "filter": "openalex:" + "|".join(chunk),
                    "select": "id,topics",
                    "per-page": 200,
                },
            )
        except RuntimeError as e:
            print(f"[WARN] topic batch of {len(chunk)} sources failed: {e}")
            continue

        for r in js.get("results", []):
            sid = to_sid(r.get("id", ""))
            normalized = (to_tid(t.get("id", "")) for t in (r.get("topics") or []))
            tids = [tid for tid in normalized if tid.startswith("T")]
            if tids:
                out[sid] = tids

        time.sleep(0.1)  # polite pause between batches

    return out


def fetch_source_top25_topics_single(sid: str) -> List[str]:
    """Fetch one source's representative topic IDs from ``/sources/{sid}``.

    Used as a per-source fallback. Entries whose normalized ID does not start
    with 'T' are dropped. An empty list means the source reported no topics.
    """
    payload = req_json(f"{BASE}/sources/{sid}", params={"select": "topics"})
    normalized = (to_tid(topic.get("id", "")) for topic in payload.get("topics") or [])
    return [tid for tid in normalized if tid.startswith("T")]


def hydrate_top25_cache(all_sids: List[str], cache: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Ensure ``cache`` holds representative topics for every SID in ``all_sids``.

    SIDs already in ``cache`` are not re-fetched. Missing ones are requested in
    batches first. Any still missing afterwards are fetched one by one. That
    happens when the batch call failed or the batch omitted a source because
    it has no topics.

    A successful single fetch is cached even when it returns no topics, so
    topic-less sources are not re-requested on every run. A failed fetch is
    logged and left uncached, so the next run retries it.

    ``cache`` is updated in place and also returned.
    """
    missing = [sid for sid in dict.fromkeys(all_sids) if sid not in cache]
    if missing:
        try:
            cache.update(fetch_sources_top25_topics_batch(missing))
        except Exception as e:  # best-effort: fall through to single calls
            print(f"[WARN] batch topic fetch failed, falling back to single calls: {e}")

    for sid in missing:
        if sid in cache:
            continue
        try:
            cache[sid] = fetch_source_top25_topics_single(sid)
        except Exception as e:  # best-effort: keep going, retried on the next run
            print(f"[WARN] failed source {sid}: {e}")

    return cache


# -----------------------------
# Meta: sources metadata (for output)
# -----------------------------
def batch_fetch_sources_meta(sids: List[str], chunk_size: int = 50) -> Dict[str, Dict[str, Any]]:
    """Fetch display metadata for sources, keyed by short SID.

    Fetched fields: name, ISSN-L, ISSNs, publisher, country and type. Queries
    ``/sources`` with ``openalex:`` OR-filters of at most ``chunk_size`` IDs
    (clamped to 1..200, the page size).

    The metadata only decorates the output CSVs. A failed chunk is therefore
    logged and skipped, and its rows get blank metadata. This avoids aborting
    the run after all the expensive collection work has already been done.

    Returns:
        {sid: source record as returned by the API, restricted to the selected fields}.
    """
    size = max(1, min(int(chunk_size), 200))
    ids = list(dict.fromkeys(to_sid(s) for s in sids))  # dedupe, keep order
    out: Dict[str, Dict[str, Any]] = {}

    for start in range(0, len(ids), size):
        chunk = ids[start:start + size]
        try:
            js = req_json(
                f"{BASE}/sources",
                params={
                    "filter": "openalex:" + "|".join(chunk),
                    "select": "id,display_name,issn_l,issn,host_organization_name,country_code,type",
                    "per-page": 200,
                },
            )
        except RuntimeError as e:
            print(f"[WARN] metadata batch of {len(chunk)} sources failed: {e}")
            continue

        for r in js.get("results", []):
            out[to_sid(r.get("id", ""))] = r

        time.sleep(0.4)  # polite pause between batches

    return out


# -----------------------------
# Main pipeline (steps 1–5)
# -----------------------------
def main(min_topics: int = MIN_TOPIC_COUNT) -> None:
    """Run the pipeline end to end and write the summary and long-format CSVs.

    1. For each configured topic, collect the journals that published works
       with that primary topic (checkpointed per topic in CP_TID2SIDS).
    2. Invert this to journal → topics it was found under (current TOPIC_IDS only).
    3. Fetch each journal's representative topics (checkpointed in CP_SID_TOP25).
       Keep journals whose matched topics intersect those representative topics.
    4. Keep journals with at least ``min_topics`` such topics, most matches
       first (ties broken by SID for a reproducible order).
    5. Attach source metadata and write two timestamped CSVs to SAVE_DIR.

    A topic whose collection fails is reported and NOT written to the
    checkpoint, so re-running the script retries it instead of treating it as
    "cached with no journals".

    Raises:
        RuntimeError: if no topic IDs are configured.
    """
    if not TOPIC_IDS:
        raise RuntimeError("No topic IDs provided.")
    threshold = int(min_topics)

    # Step 1: for each topic, collect journal SIDs from works with that primary topic.
    cp_tid2sids: Dict[str, List[str]] = load_json(CP_TID2SIDS) or {}
    failed: List[str] = []
    for idx, tid in enumerate(TOPIC_IDS, 1):
        if tid in cp_tid2sids:
            print(f"[{idx}/{len(TOPIC_IDS)}] {tid} cached ({len(cp_tid2sids[tid])} sids)")
            continue

        try:
            sids = sorted(collect_journal_sids_from_primary_topic(tid))
        except Exception as e:
            # Deliberately not checkpointed: an empty entry would be read back
            # as "cached, no journals", and the topic would never be retried.
            print(f"[WARN] topic {tid} failed: {e}")
            failed.append(tid)
        else:
            cp_tid2sids[tid] = sids
            print(f"[{idx}/{len(TOPIC_IDS)}] {tid}: {len(sids)} sids (primary works)")
            save_json(CP_TID2SIDS, cp_tid2sids)

        if idx % 3 == 0:
            time.sleep(0.5)

    if failed:
        print(f"[WARN] {len(failed)} topic(s) not collected; re-run to retry: {', '.join(failed)}")

    # Step 2: invert to source -> topics it was found under. Only the current
    # TOPIC_IDS are used, so checkpoint entries left over from an earlier topic
    # list do not trigger extra API calls.
    source_to_topics: Dict[str, Set[str]] = defaultdict(set)
    for tid in TOPIC_IDS:
        for sid in cp_tid2sids.get(tid, []):
            source_to_topics[to_sid(sid)].add(tid)

    all_sids = list(source_to_topics)

    # Step 3: fetch each journal's representative topics and retain the journals
    # where (topics matched via works) ∩ (representative topics) is non-empty.
    sid_top25_cache: Dict[str, List[str]] = load_json(CP_SID_TOP25) or {}
    sid_top25_cache = hydrate_top25_cache(all_sids, sid_top25_cache)
    save_json(CP_SID_TOP25, sid_top25_cache)

    filtered: Dict[str, Set[str]] = {}
    for sid, matched_top50 in source_to_topics.items():
        keep = matched_top50 & set(sid_top25_cache.get(sid, [])) & TOPIC_SET
        if keep:
            filtered[sid] = keep

    # Step 4: apply the minimum-match threshold.
    kept = [
        (sid, len(tset), sorted(tset))
        for sid, tset in filtered.items()
        if len(tset) >= threshold
    ]
    # Most matches first; ties broken by SID so the output order is reproducible.
    kept.sort(key=lambda row: (-row[1], row[0]))

    if not kept:
        print(f"[DONE] No journals found with ≥{threshold} topics.")
        return

    # Step 5: attach metadata and write the CSVs.
    meta = batch_fetch_sources_meta([sid for sid, _, _ in kept])

    rows: List[Dict[str, Any]] = []
    long_rows: List[Dict[str, Any]] = []
    for sid, cnt, tids in kept:
        m = meta.get(sid, {})  # blank metadata if the source's batch failed
        source_url = f"https://openalex.org/{sid}"
        rows.append(
            {
                "source_id": source_url,
                "display_name": m.get("display_name", ""),
                "issn_l": m.get("issn_l", ""),
                "issn": ";".join(m.get("issn", []) or []),
                "publisher": m.get("host_organization_name", ""),
                "country_code": m.get("country_code", ""),
                "type": m.get("type", ""),
                "topic_match_count": cnt,             # bounded by the representative-topic list size
                "matched_topic_ids": ";".join(tids),  # representative topics ∩ Top-50
            }
        )
        long_rows.extend(
            {
                "source_id": source_url,
                "display_name": m.get("display_name", ""),
                "topic_id": t,
            }
            for t in tids
        )

    df_summary = pd.DataFrame(rows)
    df_long = pd.DataFrame(long_rows)

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_summary = OUT_SUMMARY_TPL.format(ts)
    out_long = OUT_LONG_TPL.format(ts)

    df_summary.to_csv(out_summary, index=False, encoding="utf-8-sig")
    df_long.to_csv(out_long, index=False, encoding="utf-8-sig")

    print(f"[DONE] Journals with ≥{threshold} topics: {len(df_summary)}")
    print("Saved CSV (summary):", out_summary)
    print("Saved CSV (pairs):", out_long)


if __name__ == "__main__":
    # Script entry point; the threshold comes from OPENALEX_MIN_TOPIC_COUNT.
    main(MIN_TOPIC_COUNT)