# In [None]:
"""
Fetch representative topics for OpenAlex sources (as returned by the API).

Input : a text file with one normalized OpenAlex source ID per line (e.g., S1234567890)
Output: CSV (long format):
        source_id, source_title, topic_rank, topic_id, topic_name, count

Usage:
  python openalex_source_topics.py input_source_ids.txt output_topics_long.csv

"""

import sys
import time
import requests
import pandas as pd
from urllib.parse import urlencode

# Root of the OpenAlex REST API; per-source URLs are built from it.
BASE_URL = "https://api.openalex.org"
# Sent as the User-Agent header on every request.
USER_AGENT = "openalex-source-topics-fetcher/1.0"
# Value of the `select` query parameter used when fetching a source.
SELECT_FIELDS = "display_name,topics"


def read_source_ids(path: str) -> list[str]:
    """Read already-normalized source IDs (one per line), de-duplicated.

    Surrounding whitespace is stripped from each line and blank lines are
    skipped. Duplicates are dropped, keeping the first occurrence, so the
    returned list preserves the order of the file.

    Args:
        path: Path to a UTF-8 text file (with or without a BOM).

    Returns:
        The unique, non-empty IDs in file order.
    """
    # "utf-8-sig" decodes plain UTF-8 as well, but also strips a leading BOM;
    # with plain "utf-8" the BOM would stick to the first ID as "\ufeff".
    with open(path, "r", encoding="utf-8-sig") as f:
        stripped = (line.strip() for line in f)
        # dict keys keep insertion order, which gives an order-preserving dedupe.
        uniq = list(dict.fromkeys(s for s in stripped if s))

    print(f"[INFO] loaded: {len(uniq)} unique source IDs")
    return uniq


def req_json(url: str, max_retries: int = 5, backoff: float = 1.6) -> dict:
    """GET a URL and return its decoded JSON body, retrying transient errors.

    An attempt is retried when the request itself fails (connection error,
    timeout, ...), when the 200 response body is not valid JSON, or when the
    server answers with 429/500/502/503/504. Any other non-200 status is
    treated as permanent and fails immediately instead of burning through
    the retry budget.

    Args:
        url: Fully-formed URL to fetch.
        max_retries: Maximum number of attempts.
        backoff: Base of the exponential delay; attempt ``i`` (0-based) is
            followed by a sleep of ``backoff ** i`` seconds.

    Returns:
        The parsed JSON payload.

    Raises:
        RuntimeError: On a non-retryable HTTP status, or once all attempts
            have been used up.
    """
    retryable_statuses = (429, 500, 502, 503, 504)
    last_err = None
    for attempt in range(max_retries):
        try:
            r = requests.get(url, headers={"User-Agent": USER_AGENT}, timeout=30)
            if r.status_code == 200:
                return r.json()
        except (requests.RequestException, ValueError) as e:
            # Network-level failure or undecodable JSON body: worth retrying.
            last_err = e
        else:
            if r.status_code not in retryable_statuses:
                # e.g. 404 for an unknown ID: retrying cannot change the outcome.
                raise RuntimeError(
                    f"Request failed with non-retryable HTTP {r.status_code}: {url}"
                )
            # Record the status so the final error message is informative.
            last_err = f"HTTP {r.status_code}"

        # No point in sleeping after the last attempt; we are about to give up.
        if attempt < max_retries - 1:
            time.sleep(backoff ** attempt)

    raise RuntimeError(f"Request failed after {max_retries} retries: {last_err}")


def fetch_topics_for_source(sid: str) -> list[dict]:
    """
    Return one row per topic listed on the given OpenAlex source.

    The source is requested with only the fields named in SELECT_FIELDS.
    Topics are kept in the order the API returned them; ``topic_rank`` is the
    1-based position in that list. A source with no topics yields ``[]``.
    """
    query = urlencode({"select": SELECT_FIELDS})
    payload = req_json(f"{BASE_URL}/sources/{sid}?{query}")

    title = payload.get("display_name", "")
    # A missing or null "topics" value is treated as an empty list.
    return [
        {
            "source_id": sid,
            "source_title": title,
            "topic_rank": position,
            "topic_id": topic.get("id", ""),
            "topic_name": topic.get("display_name", ""),
            "count": topic.get("count"),
        }
        for position, topic in enumerate(payload.get("topics") or [], start=1)
    ]


def fetch_all(source_ids: list[str], sleep_sec: float = 0.15) -> pd.DataFrame:
    """Fetch topic rows for every source and collect them in one DataFrame.

    Each source is fetched independently: if one raises, its ID and error
    text are recorded and processing continues with the next. The function
    pauses ``sleep_sec`` seconds after every source, prints progress every
    25 sources, and prints a summary of failures (up to three) at the end.

    Returns:
        A long-format DataFrame with a fixed column order; empty if nothing
        was fetched.
    """
    columns = [
        "source_id", "source_title", "topic_rank", "topic_id", "topic_name", "count"
    ]
    collected = []
    errors = []

    for done, source_id in enumerate(source_ids, start=1):
        try:
            collected.extend(fetch_topics_for_source(source_id))
        except Exception as exc:
            # Best effort: remember the failure and move on.
            errors.append((source_id, str(exc)))

        time.sleep(sleep_sec)

        if done % 25 == 0:
            print(f"[INFO] {done}/{len(source_ids)} processed | rows={len(collected)}")

    if errors:
        print(f"[WARN] failed: {len(errors)} sources (showing up to 3): {errors[:3]}")

    return pd.DataFrame(collected, columns=columns)


def main():
    """Command-line entry point: read IDs, fetch topics, write the CSV.

    Expects exactly two arguments (input ID file, output CSV path); otherwise
    prints a usage line and exits with status 1.
    """
    args = sys.argv[1:]
    if len(args) != 2:
        print("Usage: python openalex_source_topics.py input_source_ids.txt output.csv")
        sys.exit(1)

    in_path, out_path = args

    df = fetch_all(read_source_ids(in_path))

    # The CSV is written as "utf-8-sig", i.e. UTF-8 with a leading BOM.
    df.to_csv(out_path, index=False, encoding="utf-8-sig")
    print(f"[SAVE] {out_path} (rows={len(df)})")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()