diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..19ff7d1 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,33 @@ +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bookworm", + "customizations": { + "codespaces": { + "openFiles": [ + "README.md", + "app.py" + ] + }, + "vscode": { + "settings": {}, + "extensions": [ + "ms-python.python", + "ms-python.vscode-pylance" + ] + } + }, + "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y argparse.ArgumentParser: + p = argparse.ArgumentParser(prog="content_parser") + sub = p.add_subparsers(dest="command", required=True) + + sub.add_parser("list-sources", help="Show registered source plugins") + + # ----- jobs subcommand ----- + jobs_p = sub.add_parser("jobs", help="Manage scheduled jobs") + jobs_sub = jobs_p.add_subparsers(dest="jobs_command", required=True) + jobs_sub.add_parser("list", help="List all saved jobs") + show_p = jobs_sub.add_parser("show", help="Print a job's YAML") + show_p.add_argument("name") + run_job_p = jobs_sub.add_parser("run", help="Run a job once") + run_job_p.add_argument("name") + jobs_sub.add_parser("install-cron", help="Regenerate the managed crontab block") + jobs_sub.add_parser("remove-cron", help="Remove the managed crontab block") + jobs_sub.add_parser("cron-status", help="Show what's currently in the managed block") + + run_p = sub.add_parser("run", help="Resolve inputs and fetch items for one source") + run_p.add_argument("--source", required=True, help="Plugin name (e.g. youtube, instagram)") + run_p.add_argument("--output", "-o", default=None, help="Output directory") + + # Generic input flags — repeatable. Plugin decides which kinds it understands. + run_p.add_argument( + "--input", "-i", action="append", default=[], + metavar="KIND=VALUE", + help='Input as "kind=value" (e.g. --input video=https://youtu.be/x). Repeatable.', + ) + # Convenience aliases + run_p.add_argument("--query", "-q", action="append", default=[]) + run_p.add_argument("--channel", "-c", action="append", default=[]) + run_p.add_argument("--playlist", "-p", action="append", default=[]) + run_p.add_argument("--video", "-v", action="append", default=[]) + run_p.add_argument("--hashtag", action="append", default=[]) + run_p.add_argument("--account", action="append", default=[]) + run_p.add_argument("--post", action="append", default=[]) + + # Plugin settings as key=value, repeatable + run_p.add_argument( + "--set", action="append", default=[], + metavar="KEY=VALUE", + help='Override a plugin setting (e.g. --set max_comments=100). 
Repeatable.', + ) + return p + + +def _parse_kv(items: list[str]) -> dict[str, str]: + out: dict[str, str] = {} + for s in items: + if "=" not in s: + raise SystemExit(f"Expected KEY=VALUE, got {s!r}") + k, v = s.split("=", 1) + out[k.strip()] = v.strip() + return out + + +def _coerce(value: str): + low = value.lower() + if low in ("true", "yes", "on"): + return True + if low in ("false", "no", "off"): + return False + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value + + +def cmd_list_sources() -> int: + for p in all_plugins(): + kinds = ", ".join(s.kind for s in p.input_specs()) + print(f"{p.name:12s} {p.label:20s} inputs=[{kinds}] secrets={p.secret_keys}") + return 0 + + +def cmd_run(args: argparse.Namespace) -> int: + plugin = get_plugin(args.source) + + inputs: dict[str, list[str]] = {s.kind: [] for s in plugin.input_specs()} + + # Aliases → inputs + for alias_attr, kind in [ + ("query", "query"), ("channel", "channel"), ("playlist", "playlist"), + ("video", "video"), ("hashtag", "hashtag"), ("account", "account"), + ("post", "post"), + ]: + for v in getattr(args, alias_attr, []): + inputs.setdefault(kind, []).append(v) + + # Generic --input KIND=VALUE + for raw in args.input: + if "=" not in raw: + raise SystemExit(f"--input expects KIND=VALUE, got {raw!r}") + kind, value = raw.split("=", 1) + inputs.setdefault(kind.strip(), []).append(value.strip()) + + # Drop empty kinds + inputs = {k: v for k, v in inputs.items() if v} + + if not inputs: + accepted = ", ".join(s.kind for s in plugin.input_specs()) + raise SystemExit(f"No inputs given. Plugin {args.source!r} accepts: {accepted}") + + # Settings + settings: dict = {s.key: s.default for s in plugin.settings_specs()} + for k, v in _parse_kv(args.set).items(): + settings[k] = _coerce(v) + + # Secrets + secrets: dict[str, str] = {k: get_secret(k) for k in plugin.secret_keys} + # also pull any well-known optional secrets the plugin might use + for opt in ("WEBSHARE_USERNAME", "WEBSHARE_PASSWORD", "PROXY_HTTP_URL", "PROXY_HTTPS_URL"): + v = get_secret(opt) + if v: + secrets[opt] = v + + out_dir = Path(args.output) if args.output else None + + def log(msg: str) -> None: + print(msg) + + def progress(done: int, total: int, message: str) -> None: + print(f" [{done}/{total}] {message}") + + result = run(plugin, inputs, settings, secrets, output_dir=out_dir, log=log, progress=progress) + print(f"\nDone. 
{len(result.items)} item(s) saved to {result.out_dir.resolve()}")
+    return 0
+
+
+def cmd_jobs(args: argparse.Namespace) -> int:
+    from .jobs import store as jobs_store  # noqa: PLC0415
+    from .jobs.runner import run_job  # noqa: PLC0415
+    from .jobs.schema import dump_job_yaml  # noqa: PLC0415
+
+    if args.jobs_command == "list":
+        jobs = jobs_store.list_jobs()
+        if not jobs:
+            print("No jobs found in", jobs_store.JOBS_DIR)
+            return 0
+        for job in jobs:
+            schedule = job.schedule or "(manual)"
+            inputs_summary = ", ".join(f"{k}={len(v)}" for k, v in job.inputs.items()) or "—"
+            sheet_count = len(job.sheet_inputs)
+            print(
+                f"{job.name:30s} source={job.source:10s} schedule={schedule:20s} "
+                f"inline=[{inputs_summary}] sheet_refs={sheet_count}"
+            )
+        invalid = jobs_store.list_invalid()
+        if invalid:
+            print()
+            print("Invalid job files:")
+            for name, err in invalid:
+                print(f" {name}: {err}")
+        return 0
+
+    if args.jobs_command == "show":
+        job = jobs_store.load_job(args.name)
+        print(dump_job_yaml(job))
+        return 0
+
+    if args.jobs_command == "run":
+        from .core.errors import AuthError, PluginError  # noqa: PLC0415
+        try:
+            result = run_job(
+                args.name, log=print,
+                progress=lambda d, t, m: print(f" [{d}/{t}] {m}"),
+            )
+        except (AuthError, PluginError) as e:
+            print(f"Error: {e}", file=sys.stderr)
+            return 1
+        except KeyError as e:
+            # get_plugin raises KeyError for unknown source.
+            print(f"Error: unknown plugin/source — {e}", file=sys.stderr)
+            return 1
+        print(f"\nDone. {len(result.items)} item(s) saved to {result.out_dir.resolve()}")
+        return 0
+
+    if args.jobs_command == "install-cron":
+        from .jobs.cron import install_cron  # noqa: PLC0415
+        entries = install_cron()
+        if not entries:
+            print("No scheduled jobs found. Managed block cleared.")
+            return 0
+        print(f"Installed {len(entries)} entry(ies) in crontab:")
+        for e in entries:
+            print(f" {e.schedule} {e.job_name}")
+        return 0
+
+    if args.jobs_command == "remove-cron":
+        from .jobs.cron import remove_cron  # noqa: PLC0415
+        removed = remove_cron()
+        print("Removed managed block." if removed else "Managed block not present.")
+        return 0
+
+    if args.jobs_command == "cron-status":
+        from .jobs.cron import read_block  # noqa: PLC0415
+        entries = read_block()
+        if not entries:
+            print("Managed block is empty or absent.")
+            return 0
+        for e in entries:
+            print(f"{e.schedule} job:{e.job_name}\n → {e.command}")
+        return 0
+
+    return 2
+
+
+def main(argv: list[str] | None = None) -> int:
+    args = _build_parser().parse_args(argv)
+    if args.command == "list-sources":
+        return cmd_list_sources()
+    if args.command == "run":
+        return cmd_run(args)
+    if args.command == "jobs":
+        return cmd_jobs(args)
+    return 2
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/content_parser/clients/__init__.py b/content_parser/clients/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/content_parser/clients/apify.py b/content_parser/clients/apify.py
new file mode 100644
index 0000000..598799f
--- /dev/null
+++ b/content_parser/clients/apify.py
@@ -0,0 +1,62 @@
+"""Minimal Apify HTTP client — runs an actor synchronously and returns dataset items.
+
+Lives outside the plugins/ tree so multiple plugins (Instagram, Telegram, …)
+can share it without depending on each other. The token always travels in
+the Authorization: Bearer header so it doesn't leak into nginx access logs.
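A usage sketch (the actor id comes from run_actor's docstring below; the token value
and the input keys are placeholders, since each actor defines its own input schema):

    client = ApifyClient(token="<APIFY_API_TOKEN value>")
    items = client.run_actor("apify/instagram-scraper", {"resultsLimit": 5})
    # items is the actor's default dataset, returned as a list of dicts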
+""" +from __future__ import annotations + +from typing import Any + +import requests + + +APIFY_BASE = "https://api.apify.com/v2" + + +class ApifyError(Exception): + pass + + +class ApifyClient: + def __init__(self, token: str, timeout: int = 600): + if not token: + raise ValueError("Apify token is required") + self.token = token + self.timeout = timeout + + def run_actor( + self, + actor_id: str, + actor_input: dict[str, Any], + ) -> list[dict]: + """Run an actor synchronously and return its dataset items. + + actor_id: e.g. "apify/instagram-scraper" — slashes get replaced by '~'. + """ + slug = actor_id.replace("/", "~") + url = f"{APIFY_BASE}/acts/{slug}/run-sync-get-dataset-items" + headers = {"Authorization": f"Bearer {self.token}"} + params = {"format": "json"} + try: + r = requests.post( + url, headers=headers, params=params, json=actor_input, timeout=self.timeout + ) + except requests.RequestException as e: + raise ApifyError(f"Network error talking to Apify: {e}") from e + + if r.status_code == 401: + raise ApifyError("Apify rejected the token (401). Check APIFY_API_TOKEN.") + if r.status_code == 402: + raise ApifyError("Apify says the account is out of credits (402).") + if not r.ok: + raise ApifyError(f"Apify returned {r.status_code}: {r.text[:300]}") + + try: + data = r.json() + except ValueError as e: + raise ApifyError(f"Apify returned non-JSON: {r.text[:300]}") from e + + if not isinstance(data, list): + raise ApifyError(f"Expected list of items, got {type(data).__name__}: {str(data)[:300]}") + return data diff --git a/content_parser/core/__init__.py b/content_parser/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/core/errors.py b/content_parser/core/errors.py new file mode 100644 index 0000000..2e74ab7 --- /dev/null +++ b/content_parser/core/errors.py @@ -0,0 +1,17 @@ +"""Common errors raised by plugins; UI/CLI map them to user-friendly messages.""" + + +class PluginError(Exception): + """Generic plugin error.""" + + +class AuthError(PluginError): + """Missing or invalid credentials.""" + + +class RateLimitError(PluginError): + """Source rejected the request due to rate limiting / quota.""" + + +class ItemUnavailable(PluginError): + """The requested item is private, removed, or geo-blocked.""" diff --git a/content_parser/core/output.py b/content_parser/core/output.py new file mode 100644 index 0000000..7d17f88 --- /dev/null +++ b/content_parser/core/output.py @@ -0,0 +1,200 @@ +"""Source-agnostic writers: Item → JSON / Markdown / CSV / index. + +Filenames use `__` so multiple sources coexist in one folder. +""" +from __future__ import annotations + +import csv +import hashlib +import json +import re +from dataclasses import asdict +from pathlib import Path + +from .schema import Item + + +_FALLBACK_FILENAME = "item" + + +def _safe_filename(name: str, max_length: int = 80) -> str: + cleaned = re.sub(r"[^\w\s-]", "", name, flags=re.UNICODE).strip() + cleaned = re.sub(r"\s+", "_", cleaned) + return cleaned[:max_length] or _FALLBACK_FILENAME + + +def _format_seconds(seconds: float) -> str: + total = int(seconds) + h, rem = divmod(total, 3600) + m, s = divmod(rem, 60) + return f"{h:02d}:{m:02d}:{s:02d}" if h else f"{m:02d}:{s:02d}" + + +def _file_stem(item: Item) -> str: + """Build a filesystem-safe, collision-resistant stem. + + Every component goes through _safe_filename — defense in depth against an + upstream API returning a malicious id like '../../etc/passwd'. + + If the item_id sanitizes away to the fallback (e.g. 
all special chars), a + short hash of the raw (source, item_id) is appended so two such items + don't clobber each other on disk. + """ + safe_source = _safe_filename(item.source) + safe_id = _safe_filename(item.item_id) + safe_title = _safe_filename(item.title or "") + + if safe_id == _FALLBACK_FILENAME: + digest = hashlib.sha256( + f"{item.source}\0{item.item_id}".encode("utf-8") + ).hexdigest()[:8] + safe_id = f"{_FALLBACK_FILENAME}-{digest}" + + return f"{safe_source}_{safe_id}_{safe_title}" + + +def write_item_json(item: Item, out_dir: Path) -> Path: + path = out_dir / f"{_file_stem(item)}.json" + path.write_text(json.dumps(asdict(item), ensure_ascii=False, indent=2), encoding="utf-8") + return path + + +def write_item_markdown(item: Item, out_dir: Path) -> Path: + path = out_dir / f"{_file_stem(item)}.md" + lines: list[str] = [] + + title = item.title or item.item_id + lines.append(f"# {title}") + lines.append("") + lines.append(f"- **Source:** {item.source}") + lines.append(f"- **Author:** {item.author or '—'}") + lines.append(f"- **URL:** {item.url}") + lines.append(f"- **Published:** {item.published_at or '—'}") + if item.media: + media_pairs = ", ".join(f"{k}={v}" for k, v in item.media.items() if v is not None) + if media_pairs: + lines.append(f"- **Metrics:** {media_pairs}") + lines.append("") + + if item.text: + lines.append("## Text") + lines.append("") + lines.append(item.text.strip()) + lines.append("") + + lines.append("## Transcript") + lines.append("") + t = item.transcript + if t and t.segments: + kind = "auto" if t.is_generated else "manual" + lines.append(f"_Language: {t.language} ({kind})_") + lines.append("") + for seg in t.segments: + ts = _format_seconds(seg.get("start", 0)) + text = (seg.get("text") or "").replace("\n", " ").strip() + if text: + lines.append(f"- `[{ts}]` {text}") + lines.append("") + elif t and t.error: + lines.append(f"_Transcript error: {t.error}_") + lines.append("") + else: + lines.append("_No transcript available._") + lines.append("") + + lines.append(f"## Comments ({len(item.comments)})") + lines.append("") + if not item.comments: + lines.append("_No comments._") + lines.append("") + else: + by_parent: dict[str | None, list] = {} + for c in item.comments: + by_parent.setdefault(c.parent_id, []).append(c) + + for top in by_parent.get(None, []): + lines.append( + f"### {top.author or '—'} " + f"_({top.published_at or '—'}, ♥ {top.like_count})_" + ) + lines.append("") + lines.append((top.text or "").strip()) + lines.append("") + for reply in by_parent.get(top.comment_id, []): + lines.append( + f"> **{reply.author or '—'}** " + f"_({reply.published_at or '—'}, ♥ {reply.like_count})_" + ) + lines.append("> ") + for ln in (reply.text or "").strip().splitlines(): + lines.append(f"> {ln}") + lines.append("") + + path.write_text("\n".join(lines), encoding="utf-8") + return path + + +_CSV_INJECTION_PREFIXES = ("=", "+", "-", "@", "\t", "\r") + + +def _csv_safe(value): + """Defuse Excel-style formula injection in CSV cells. + + Excel/Sheets/LibreOffice treat any cell starting with =, +, -, @ as a + formula (incl. =cmd|'/c calc'!A1). User-controlled fields like title + and author can carry such payloads from third-party APIs. Prefixing + with a single quote keeps the value visible but neutralizes execution. 
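For instance (values illustrative, behaviour of the function defined above):

    _csv_safe("=2+2")         # -> "'=2+2"  (apostrophe prepended, formula defused)
    _csv_safe("plain title")  # -> "plain title"  (unchanged)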
+ """ + if isinstance(value, str) and value and value[0] in _CSV_INJECTION_PREFIXES: + return "'" + value + return value + + +def write_summary_csv(items: list[Item], out_dir: Path) -> Path: + path = out_dir / "summary.csv" + metric_keys: set[str] = set() + for it in items: + metric_keys.update(it.media.keys()) + metric_keys_sorted = sorted(metric_keys) + + fields = [ + "source", "item_id", "title", "author", "url", + "published_at", "comments_fetched", + "transcript_language", "transcript_is_generated", + ] + metric_keys_sorted + + with path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + for it in items: + row: dict = { + "source": _csv_safe(it.source), + "item_id": _csv_safe(it.item_id), + "title": _csv_safe(it.title), + "author": _csv_safe(it.author), + "url": _csv_safe(it.url), + "published_at": _csv_safe(it.published_at), + "comments_fetched": len(it.comments), + "transcript_language": _csv_safe(it.transcript.language if it.transcript else None), + "transcript_is_generated": it.transcript.is_generated if it.transcript else None, + } + for k in metric_keys_sorted: + row[k] = _csv_safe(it.media.get(k)) + writer.writerow(row) + return path + + +def write_index_markdown(items: list[Item], out_dir: Path) -> Path: + path = out_dir / "index.md" + lines = [f"# Results ({len(items)} item(s))", ""] + for it in items: + title = it.title or it.item_id + fname = f"{_file_stem(it)}.md" + comments = len(it.comments) + has_t = bool(it.transcript and it.transcript.segments) + lines.append( + f"- [{title}]({fname}) — `{it.source}`, " + f"{comments} comment(s), transcript: {'yes' if has_t else 'no'}" + ) + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return path diff --git a/content_parser/core/plugin.py b/content_parser/core/plugin.py new file mode 100644 index 0000000..85ccf3d --- /dev/null +++ b/content_parser/core/plugin.py @@ -0,0 +1,70 @@ +"""Plugin contract — every source (YouTube, Instagram, Reddit, …) implements this.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Callable, Iterator, Literal + +from .schema import Item + + +WidgetType = Literal["text", "textarea", "password", "number", "checkbox", "select"] + + +@dataclass +class FieldSpec: + key: str + label: str + widget: WidgetType = "text" + default: Any = None + options: list[str] = field(default_factory=list) + help: str = "" + placeholder: str = "" + min_value: float | int | None = None + max_value: float | int | None = None + + +@dataclass +class InputSpec: + """Describes one kind of input the plugin can accept (e.g. 'channel', 'hashtag').""" + kind: str + label: str + placeholder: str = "" + help: str = "" + + +ProgressCb = Callable[[int, int, str], None] # (done, total, message) + + +class SourcePlugin(ABC): + name: str = "" # internal id, e.g. "youtube" + label: str = "" # human label, e.g. "YouTube" + secret_keys: list[str] = [] # required st.secrets / env vars + + @abstractmethod + def input_specs(self) -> list[InputSpec]: ... + + @abstractmethod + def settings_specs(self) -> list[FieldSpec]: ... 
+ + @abstractmethod + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + """Resolve raw inputs into a deduplicated list of item identifiers.""" + + @abstractmethod + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + """Yield fully-populated Item for each id (or a partial Item with an error).""" + + def validate_secrets(self, secrets: dict[str, str]) -> list[str]: + return [k for k in self.secret_keys if not secrets.get(k)] diff --git a/content_parser/core/redact.py b/content_parser/core/redact.py new file mode 100644 index 0000000..c4ee9b5 --- /dev/null +++ b/content_parser/core/redact.py @@ -0,0 +1,29 @@ +"""Trim user-provided strings before they reach logs or exception messages. + +Used by every plugin's fetch() error path. URLs may carry tokens in their +?query or #fragment (OAuth implicit-flow access tokens, gsheets share +tokens, etc.); long strings can flood logs. Single shared implementation +so a future fix lands everywhere. +""" +from __future__ import annotations + + +_TRUNCATE_LIMIT = 80 +_TRUNCATE_MARKER = "…" + + +def redact_spec(spec: str) -> str: + """Drop ?query / #fragment, cap total length to 80 chars. + + >>> redact_spec("post:https://reddit.com/r/x/?token=secret") + 'post:https://reddit.com/r/x/?…' + >>> redact_spec("channel:https://t.me/durov#access_token=xxx") + 'channel:https://t.me/durov#…' + """ + for sep in ("?", "#"): + if sep in spec: + spec = spec.split(sep, 1)[0] + sep + _TRUNCATE_MARKER + break + if len(spec) > _TRUNCATE_LIMIT: + spec = spec[: _TRUNCATE_LIMIT - len(_TRUNCATE_MARKER)] + _TRUNCATE_MARKER + return spec diff --git a/content_parser/core/registry.py b/content_parser/core/registry.py new file mode 100644 index 0000000..f907ca1 --- /dev/null +++ b/content_parser/core/registry.py @@ -0,0 +1,80 @@ +"""Plugin discovery — explicit list, no entry-point magic. + +Plugins that fail to import for *missing optional dependencies* are skipped +quietly so the rest of the registry stays usable. Any other failure (typo, +runtime error in plugin module) is reported to stderr instead of disappearing. 
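To make the contract concrete, a bare-bones plugin might look roughly like this
(hypothetical sketch; the real plugins under content_parser/plugins/ in this PR
differ in names and behaviour):

    from content_parser.core.plugin import FieldSpec, InputSpec, SourcePlugin
    from content_parser.core.schema import Item


    class ExamplePlugin(SourcePlugin):
        name = "example"
        label = "Example"
        secret_keys = ["EXAMPLE_API_TOKEN"]  # invented secret name

        def input_specs(self) -> list[InputSpec]:
            return [InputSpec(kind="post", label="Post URL")]

        def settings_specs(self) -> list[FieldSpec]:
            return [FieldSpec(key="max_items", label="Max items", widget="number", default=50)]

        def resolve(self, inputs, settings, secrets) -> list[str]:
            # Dedupe while preserving order, as the contract asks for.
            return list(dict.fromkeys(inputs.get("post", [])))

        def fetch(self, item_ids, settings, secrets, progress=None):
            for done, item_id in enumerate(item_ids, 1):
                if progress:
                    progress(done, len(item_ids), item_id)
                yield Item(source=self.name, item_id=item_id, url=item_id)

Because discovery is an explicit list, a matching loader would also need to be added
to all_plugins() below.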
+""" +from __future__ import annotations + +import logging +import sys + +from .plugin import SourcePlugin + +logger = logging.getLogger(__name__) + + +def _try_load(loader, label: str) -> SourcePlugin | None: + try: + return loader() + except ImportError as e: + logger.debug("Skipping %s plugin (optional dep missing): %s", label, e) + return None + except Exception as e: # pragma: no cover - defensive + print( + f"[content_parser.registry] WARNING: {label} plugin failed to load: " + f"{type(e).__name__}: {e}", + file=sys.stderr, + ) + return None + + +def all_plugins() -> list[SourcePlugin]: + """Instantiate every registered plugin.""" + plugins: list[SourcePlugin] = [] + + def _load_youtube(): + from ..plugins.youtube.plugin import YouTubePlugin + return YouTubePlugin() + + def _load_instagram(): + from ..plugins.instagram.plugin import InstagramPlugin + return InstagramPlugin() + + def _load_reddit(): + from ..plugins.reddit.plugin import RedditPlugin + return RedditPlugin() + + def _load_vk(): + from ..plugins.vk.plugin import VKPlugin + return VKPlugin() + + def _load_telegram(): + from ..plugins.telegram.plugin import TelegramPlugin + return TelegramPlugin() + + def _load_instagram_graph(): + from ..plugins.instagram_graph.plugin import InstagramGraphPlugin + return InstagramGraphPlugin() + + for loader, label in [ + (_load_youtube, "youtube"), + (_load_instagram, "instagram"), + (_load_reddit, "reddit"), + (_load_vk, "vk"), + (_load_telegram, "telegram"), + (_load_instagram_graph, "instagram_graph"), + ]: + p = _try_load(loader, label) + if p is not None: + plugins.append(p) + + return plugins + + +def get_plugin(name: str) -> SourcePlugin: + for p in all_plugins(): + if p.name == name: + return p + available = [p.name for p in all_plugins()] + raise KeyError(f"No plugin named {name!r}. 
Available: {available}") diff --git a/content_parser/core/runner.py b/content_parser/core/runner.py new file mode 100644 index 0000000..38f4202 --- /dev/null +++ b/content_parser/core/runner.py @@ -0,0 +1,75 @@ +"""Source-agnostic orchestrator: resolve → fetch → write.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Callable + +from .output import ( + write_index_markdown, + write_item_json, + write_item_markdown, + write_summary_csv, +) +from .plugin import ProgressCb, SourcePlugin +from .schema import Item + + +LogCb = Callable[[str], None] + + +@dataclass +class RunResult: + out_dir: Path + items: list[Item] = field(default_factory=list) + + +def run( + plugin: SourcePlugin, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + output_dir: Path | None = None, + log: LogCb | None = None, + progress: ProgressCb | None = None, +) -> RunResult: + log = log or (lambda _msg: None) + + missing = plugin.validate_secrets(secrets) + if missing: + raise ValueError(f"Missing required secrets for {plugin.name}: {missing}") + + out_dir = output_dir or ( + Path("output") / plugin.name / datetime.now().strftime("%Y%m%d_%H%M%S") + ) + out_dir.mkdir(parents=True, exist_ok=True) + log(f"Output: {out_dir.resolve()}") + + log("Resolving inputs…") + item_ids = plugin.resolve(inputs, settings, secrets) + if not item_ids: + log("No items resolved.") + return RunResult(out_dir=out_dir, items=[]) + log(f"Found {len(item_ids)} item(s).") + + items: list[Item] = [] + fetch_error: BaseException | None = None + try: + for item in plugin.fetch(item_ids, settings, secrets, progress=progress): + items.append(item) + write_item_json(item, out_dir) + write_item_markdown(item, out_dir) + except BaseException as e: + fetch_error = e + finally: + # Always flush summary + index so partial runs are still inspectable. 
+ if items: + write_summary_csv(items, out_dir) + write_index_markdown(items, out_dir) + + if fetch_error is not None: + log(f"Aborted after {len(items)} item(s): {type(fetch_error).__name__}: {fetch_error}") + raise fetch_error + log(f"Done — {len(items)} item(s).") + return RunResult(out_dir=out_dir, items=items) diff --git a/content_parser/core/schema.py b/content_parser/core/schema.py new file mode 100644 index 0000000..9447b4d --- /dev/null +++ b/content_parser/core/schema.py @@ -0,0 +1,45 @@ +"""Source-agnostic data model returned by every plugin.""" +from __future__ import annotations + +from dataclasses import dataclass, field, asdict +from typing import Any + + +@dataclass +class Comment: + comment_id: str + parent_id: str | None = None + author: str | None = None + author_id: str | None = None + text: str | None = None + like_count: int = 0 + published_at: str | None = None + updated_at: str | None = None + + +@dataclass +class Transcript: + language: str | None = None + is_generated: bool | None = None + segments: list[dict] = field(default_factory=list) + text: str = "" + error: str | None = None + + +@dataclass +class Item: + source: str + item_id: str + url: str + title: str | None = None + author: str | None = None + author_id: str | None = None + published_at: str | None = None + text: str | None = None + media: dict[str, Any] = field(default_factory=dict) + transcript: Transcript | None = None + comments: list[Comment] = field(default_factory=list) + extra: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + return asdict(self) diff --git a/content_parser/core/secrets.py b/content_parser/core/secrets.py new file mode 100644 index 0000000..8e7b627 --- /dev/null +++ b/content_parser/core/secrets.py @@ -0,0 +1,147 @@ +"""Single source of truth for reading and persisting credentials. + +Lookup order: st.secrets → env → ~/.content_parser/config.json. +Persistence: write to ~/.content_parser/config.json AND .streamlit/secrets.toml +so the same key is available locally and via st.secrets in the same project. 
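A quick sketch of that precedence (the key name is a placeholder):

    from content_parser.core.secrets import get_secret, save_secret, secret_locations

    save_secret("MY_API_KEY", "...")   # persists to config.json and .streamlit/secrets.toml
    get_secret("MY_API_KEY")           # st.secrets first, then the environment, then config.json
    secret_locations("MY_API_KEY")     # e.g. ["env", "/home/user/.content_parser/config.json"]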
+""" +from __future__ import annotations + +import json +import os +from pathlib import Path + +CONFIG_DIR = Path.home() / ".content_parser" +CONFIG_PATH = CONFIG_DIR / "config.json" +SECRETS_PATH = Path(".streamlit") / "secrets.toml" + + +def _load_local_config() -> dict[str, str]: + if not CONFIG_PATH.exists(): + return {} + try: + data = json.loads(CONFIG_PATH.read_text(encoding="utf-8")) + return {k: str(v) for k, v in data.items() if isinstance(v, (str, int, float))} + except Exception: + return {} + + +def _save_local_config(data: dict[str, str]) -> None: + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + CONFIG_PATH.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + try: + os.chmod(CONFIG_PATH, 0o600) + except OSError: + pass + + +def get_secret(name: str, default: str = "") -> str: + """Read a single secret by name.""" + try: + import streamlit as st # noqa: PLC0415 + try: + if name in st.secrets: + return str(st.secrets[name]) + except Exception: + pass + except ImportError: + pass + env = os.environ.get(name) + if env: + return env + return _load_local_config().get(name, default) + + +def save_secret(name: str, value: str) -> None: + data = _load_local_config() + data[name] = value + _save_local_config(data) + _upsert_secrets_toml(name, value) + + +def delete_secret(name: str) -> None: + data = _load_local_config() + if name in data: + del data[name] + _save_local_config(data) + _remove_from_secrets_toml(name) + + +def _toml_escape(value: str) -> str: + """Escape backslashes and double quotes for TOML basic strings.""" + return value.replace("\\", "\\\\").replace('"', '\\"') + + +def _is_streamlit_cloud() -> bool: + """Best-effort detection for Streamlit Cloud, where .streamlit/secrets.toml + is read-only and writing it would either fail or be silently ignored.""" + if os.environ.get("STREAMLIT_RUNTIME") == "cloud": + return True + if os.environ.get("STREAMLIT_SHARING") in ("1", "true", "True"): + return True + # Streamlit Cloud containers have hostnames like 'streamlit-app-xyz'. + hostname = os.environ.get("HOSTNAME", "") + return hostname.startswith("streamlit-") + + +def _upsert_secrets_toml(key: str, value: str) -> None: + if _is_streamlit_cloud(): + # secrets.toml is managed via Settings → Secrets in the Cloud UI; + # filesystem writes are pointless and may raise. Local config.json + # write in save_secret() above already persisted the value for this + # session — Cloud users have to mirror it via the dashboard. 
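        # (Illustrative) on a normal host the code after this early return upserts a
        # single line of the form
        #     MY_API_KEY = "p@ss with \"quotes\" and \\ backslashes escaped"
        # where the key name is made up and the escaping comes from _toml_escape above.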
+ return + SECRETS_PATH.parent.mkdir(parents=True, exist_ok=True) + line = f'{key} = "{_toml_escape(value)}"' + if SECRETS_PATH.exists(): + existing = SECRETS_PATH.read_text(encoding="utf-8").splitlines() + replaced = False + new_lines: list[str] = [] + for ln in existing: + stripped = ln.lstrip() + if stripped.startswith(f"{key}=") or stripped.startswith(f"{key} ="): + new_lines.append(line) + replaced = True + else: + new_lines.append(ln) + if not replaced: + new_lines.append(line) + SECRETS_PATH.write_text("\n".join(new_lines).rstrip() + "\n", encoding="utf-8") + else: + SECRETS_PATH.write_text(line + "\n", encoding="utf-8") + try: + os.chmod(SECRETS_PATH, 0o600) + except OSError: + pass + + +def _remove_from_secrets_toml(key: str) -> None: + if not SECRETS_PATH.exists(): + return + existing = SECRETS_PATH.read_text(encoding="utf-8").splitlines() + new_lines = [ + ln for ln in existing + if not (ln.lstrip().startswith(f"{key}=") or ln.lstrip().startswith(f"{key} =")) + ] + if new_lines and any(ln.strip() for ln in new_lines): + SECRETS_PATH.write_text("\n".join(new_lines).rstrip() + "\n", encoding="utf-8") + else: + SECRETS_PATH.unlink() + + +def secret_locations(name: str) -> list[str]: + """Where the secret currently lives — for UI hints.""" + locs: list[str] = [] + try: + import streamlit as st # noqa: PLC0415 + try: + if name in st.secrets: + locs.append("st.secrets") + except Exception: + pass + except ImportError: + pass + if os.environ.get(name): + locs.append("env") + if name in _load_local_config(): + locs.append(str(CONFIG_PATH)) + return locs diff --git a/content_parser/jobs/__init__.py b/content_parser/jobs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/jobs/cron.py b/content_parser/jobs/cron.py new file mode 100644 index 0000000..d06dd93 --- /dev/null +++ b/content_parser/jobs/cron.py @@ -0,0 +1,247 @@ +"""Manage a managed block in the user's crontab. + +We never touch lines outside our markers. `install` regenerates the block +from current jobs (anything inside the markers gets replaced); `remove` +deletes the block entirely; `read` returns what's currently inside. + +All shell-bound paths and arguments go through `shlex.quote` so a +malicious job name (which the schema regex already rejects) couldn't +inject extra commands even if it slipped through. +""" +from __future__ import annotations + +import shlex +import subprocess +import sys +from dataclasses import dataclass +from pathlib import Path + +from .schema import Job +from .store import list_jobs + + +BEGIN_MARKER = "# >>> content_parser jobs >>>" +END_MARKER = "# <<< content_parser jobs <<<" + + +def is_cron_available() -> bool: + """Probe for a usable `crontab` binary. Used by the UI to gray out the + cron-management buttons on hosts where crontab isn't installed + (e.g. 
Streamlit Cloud's container).""" + try: + subprocess.run( + ["crontab", "-l"], + capture_output=True, + text=True, + check=False, + ) + return True + except FileNotFoundError: + return False + + +@dataclass +class CronEntry: + """One cron line as managed by us.""" + schedule: str + job_name: str + command: str # full command string, post-quoting + + +class CronError(Exception): + pass + + +# ---------------------------------------------------------------------- +# Block edit + + +def _existing_crontab() -> str: + """Return the user's current crontab, or '' if none / no crontab program.""" + try: + proc = subprocess.run( + ["crontab", "-l"], + capture_output=True, + text=True, + check=False, + ) + except FileNotFoundError as e: + raise CronError( + "`crontab` command not found. cron-installer requires a Unix system " + "with cron installed (or use `cli jobs run` directly)." + ) from e + # crontab -l exits 1 with stderr 'no crontab for ...' when empty — treat as ''. + if proc.returncode != 0: + if "no crontab" in (proc.stderr or "").lower(): + return "" + raise CronError(f"crontab -l failed: {proc.stderr.strip() or proc.returncode}") + return proc.stdout + + +def _write_crontab(text: str) -> None: + """Replace the user's crontab with `text`.""" + try: + proc = subprocess.run( + ["crontab", "-"], + input=text, + text=True, + capture_output=True, + check=False, + ) + except FileNotFoundError as e: + raise CronError("`crontab` command not found.") from e + if proc.returncode != 0: + raise CronError(f"crontab - failed: {proc.stderr.strip() or proc.returncode}") + + +def _strip_block(crontab_text: str) -> str: + """Return crontab with our managed block removed (idempotent).""" + lines = crontab_text.splitlines() + out: list[str] = [] + in_block = False + for line in lines: + if line.strip() == BEGIN_MARKER: + in_block = True + continue + if line.strip() == END_MARKER: + in_block = False + continue + if not in_block: + out.append(line) + return "\n".join(out).rstrip() + ("\n" if out else "") + + +def _build_block(entries: list[CronEntry]) -> list[str]: + """Produce the lines (including markers) that go into the crontab.""" + if not entries: + return [] + lines = [BEGIN_MARKER] + for e in entries: + lines.append(f"{e.schedule} {e.command} # job:{e.job_name}") + lines.append(END_MARKER) + return lines + + +# ---------------------------------------------------------------------- +# Commands + + +def build_command_for_job( + job: Job, + *, + project_root: Path | None = None, + python_executable: str | None = None, + log_path: Path | None = None, +) -> str: + """Compose the shell command that cron should execute for this job. + + project_root: where to `cd` before running. Defaults to cwd. + python_executable: path to the python interpreter. Defaults to sys.executable. + log_path: file to append stdout+stderr to. Defaults to + <project_root>/output/scheduled/.cron.log. + """ + project_root = (project_root or Path.cwd()).resolve() + python_executable = python_executable or sys.executable + log_path = log_path or (project_root / "output" / "scheduled" / ".cron.log") + + # crontab format is line-based, and shlex.quote happily preserves a + # newline inside its single-quoted output. A path with a literal \n + # would split the cron entry into two lines and corrupt the file. 
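    # (Illustrative) with made-up paths and job name, the returned string is a single
    # line along the lines of:
    #   cd /srv/content_parser && /usr/bin/python3 -m content_parser.cli jobs run daily_pull >> /srv/content_parser/output/scheduled/.cron.log 2>&1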
+ for label, value in ( + ("project_root", str(project_root)), + ("python_executable", python_executable), + ("log_path", str(log_path)), + ("job.name", job.name), + ): + if "\n" in value or "\r" in value: + raise CronError( + f"{label} contains a newline; crontab entries must be single-line." + ) + + cd_part = f"cd {shlex.quote(str(project_root))}" + run_part = " ".join([ + shlex.quote(python_executable), + "-m", "content_parser.cli", + "jobs", "run", + shlex.quote(job.name), + ]) + redirect = f">> {shlex.quote(str(log_path))} 2>&1" + return f"{cd_part} && {run_part} {redirect}" + + +def install_cron( + *, + jobs: list[Job] | None = None, + project_root: Path | None = None, + python_executable: str | None = None, + log_path: Path | None = None, +) -> list[CronEntry]: + """Regenerate our managed block in the user's crontab. + + Anything outside the markers is preserved. Jobs without a schedule are + skipped. Returns the entries that ended up in the block. + """ + job_list = jobs if jobs is not None else list_jobs() + entries: list[CronEntry] = [] + for job in job_list: + if not job.schedule: + continue + cmd = build_command_for_job( + job, + project_root=project_root, + python_executable=python_executable, + log_path=log_path, + ) + entries.append(CronEntry(schedule=job.schedule, job_name=job.name, command=cmd)) + + existing = _existing_crontab() + stripped = _strip_block(existing) + block = _build_block(entries) + if not block: + # Nothing to install — still flush the (now empty) block from crontab. + new_crontab = stripped + else: + new_crontab = (stripped + "\n".join(block) + "\n") if stripped else ("\n".join(block) + "\n") + _write_crontab(new_crontab) + return entries + + +def remove_cron() -> bool: + """Delete our managed block from crontab. Returns True if anything was removed.""" + existing = _existing_crontab() + if BEGIN_MARKER not in existing: + return False + _write_crontab(_strip_block(existing)) + return True + + +def read_block() -> list[CronEntry]: + """Parse current managed block back into entries (best-effort).""" + existing = _existing_crontab() + out: list[CronEntry] = [] + in_block = False + for line in existing.splitlines(): + stripped_line = line.strip() + if stripped_line == BEGIN_MARKER: + in_block = True + continue + if stripped_line == END_MARKER: + break + if not in_block or not stripped_line or stripped_line.startswith("#"): + continue + # Format: "<5 schedule tokens> <command> # job:<name>" + parts = stripped_line.split(None, 5) + if len(parts) < 6: + continue + schedule = " ".join(parts[:5]) + rest = parts[5] + # Try to extract job name from the trailing " # job:<name>" comment. 
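        # (Illustrative) the managed block written by install_cron() looks like:
        #   # >>> content_parser jobs >>>
        #   30 6 * * * <command> # job:daily_pull
        #   # <<< content_parser jobs <<<
        # with <command> being the cd/run/redirect string from build_command_for_job.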
+ job_name = "" + if "# job:" in rest: + cmd_part, _, comment = rest.partition("# job:") + cmd_part = cmd_part.rstrip() + job_name = comment.strip() + else: + cmd_part = rest + out.append(CronEntry(schedule=schedule, job_name=job_name, command=cmd_part)) + return out diff --git a/content_parser/jobs/runner.py b/content_parser/jobs/runner.py new file mode 100644 index 0000000..b94296f --- /dev/null +++ b/content_parser/jobs/runner.py @@ -0,0 +1,201 @@ +"""Run a saved job: merge inline + Sheet inputs, call core.runner.run.""" +from __future__ import annotations + +import json +import traceback +from datetime import datetime +from pathlib import Path +from typing import Callable + +from ..core.errors import PluginError +from ..core.registry import get_plugin +from ..core.runner import RunResult, run as core_run +from ..core.secrets import get_secret +from .schema import Job +from .store import load_job + + +# Optional secrets that any plugin may need but may not have been declared. +_OPTIONAL_SECRET_KEYS = ( + "WEBSHARE_USERNAME", + "WEBSHARE_PASSWORD", + "PROXY_HTTP_URL", + "PROXY_HTTPS_URL", + "GOOGLE_SHEETS_CREDENTIALS", + "OPENAI_API_KEY", + "INSTAGRAM_ACCESS_TOKEN", +) + + +def _collect_secrets(plugin_secret_keys: list[str], *, need_sheets: bool) -> dict[str, str]: + """Gather every secret the job may need.""" + keys = list(plugin_secret_keys) + if need_sheets and "GOOGLE_SHEETS_CREDENTIALS" not in keys: + keys.append("GOOGLE_SHEETS_CREDENTIALS") + secrets: dict[str, str] = {k: get_secret(k) for k in keys} + for opt in _OPTIONAL_SECRET_KEYS: + v = get_secret(opt) + if v: + secrets[opt] = v + return secrets + + +def _resolve_inputs(job: Job, secrets: dict[str, str]) -> dict[str, list[str]]: + """Combine inline inputs with values pulled from any Google Sheets refs.""" + merged: dict[str, list[str]] = {k: list(v) for k, v in job.inputs.items()} + + if job.sheet_inputs: + from ..loaders.gsheets import GoogleSheetsLoader # noqa: PLC0415 + + loader = GoogleSheetsLoader.from_secrets(secrets) + for ref in job.sheet_inputs: + loaded = loader.load( + ref.sheet, + tab=ref.tab, + range_a1=ref.range_a1, + skip_header=ref.skip_header, + ) + merged.setdefault(ref.target, []).extend(loaded.values) + + # Per-kind dedupe preserving insertion order. + for kind, values in list(merged.items()): + merged[kind] = list(dict.fromkeys(values)) + # Drop empty kinds so plugins don't see them. + return {k: v for k, v in merged.items() if v} + + +def run_job( + name: str, + *, + log: Callable[[str], None] | None = None, + progress=None, +) -> RunResult: + """Resolve inputs and run a saved job. Writes last_error.txt on failure.""" + job = load_job(name) + return run_job_obj(job, log=log, progress=progress) + + +def run_job_obj( + job: Job, + *, + log: Callable[[str], None] | None = None, + progress=None, +) -> RunResult: + log = log or (lambda _msg: None) + + # Compute out_dir ONCE so an early failure (Sheets-load error, + # empty-resolved-inputs, etc.) writes last_error.txt into the same + # timestamped directory the run would have used — instead of creating + # a brand-new timestamped dir just for the error file. 
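    # (Illustrative) with the defaults, a run of a job named "daily_pull" ends up under:
    #   output/scheduled/daily_pull/20240101_060000/
    #     <source>_<id>_<title>.json / .md    one pair per item (core.output writers)
    #     summary.csv, index.md               written by the core runner
    #     .last_run.txt, .last_status.json    written below on success
    #     last_error.txt                      written below on failure
    # Job name, timestamp, and filenames are invented.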
+ out_dir = job.resolved_output_dir() + + log(f"Job: {job.name} (source={job.source})") + log(f"Output: {out_dir}") + + secrets = _collect_secrets( + get_plugin(job.source).secret_keys, + need_sheets=bool(job.sheet_inputs), + ) + + try: + inputs = _resolve_inputs(job, secrets) + except Exception as e: + _record_failure(job, e, out_dir=out_dir, secrets=secrets) + raise + + if not inputs: + msg = f"Job {job.name!r} has no resolved inputs (inline empty, Sheets returned nothing)." + _record_failure(job, PluginError(msg), out_dir=out_dir, secrets=secrets) + raise PluginError(msg) + + plugin = get_plugin(job.source) + + try: + result = core_run( + plugin, + inputs, + job.settings, + secrets, + output_dir=out_dir, + log=log, + progress=progress, + ) + except Exception as e: + _record_failure(job, e, out_dir=out_dir, secrets=secrets) + raise + + _record_success(job, out_dir, result) + return result + + +# ---------------------------------------------------------------------- +# Last-run / last-error markers + + +def _record_success(job: Job, out_dir: Path, result: RunResult) -> None: + # core_run usually created out_dir already, but mkdir is idempotent and + # protects against the edge case where it bailed out before doing so. + out_dir.mkdir(parents=True, exist_ok=True) + marker = out_dir / ".last_run.txt" + marker.write_text( + f"job: {job.name}\n" + f"finished_at: {datetime.now().isoformat()}\n" + f"items: {len(result.items)}\n", + encoding="utf-8", + ) + _write_status(job, out_dir, status="success", items=len(result.items)) + + +def _write_status( + job: Job, + out_dir: Path, + *, + status: str, + items: int = 0, + error: str | None = None, +) -> None: + """Single canonical status file consumed by monitoring / UI. + + Lives in <output_dir>/.last_status.json and is overwritten on every + run so a watcher only needs to mtime + parse one path. + """ + payload = { + "job": job.name, + "source": job.source, + "status": status, # "success" | "failure" + "finished_at": datetime.now().isoformat(), + "items": items, + "error": error, + } + try: + out_dir.mkdir(parents=True, exist_ok=True) + path = out_dir / ".last_status.json" + tmp = path.with_suffix(".json.tmp") + tmp.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + tmp.replace(path) + except OSError: + pass + + +def _record_failure(job: Job, exc: Exception, *, out_dir: Path, secrets: dict[str, str] | None = None) -> None: + if job.notify_on_failure == "none": + return + text = ( + f"job: {job.name}\n" + f"failed_at: {datetime.now().isoformat()}\n" + f"error: {type(exc).__name__}: {exc}\n\n" + f"{traceback.format_exc()}" + ) + # Tracebacks may include URLs / response bodies that embed tokens. + # Replace every secret value we know about with [REDACTED] before the + # file lands on disk where logs can be shared in support tickets. + for value in (secrets or {}).values(): + if isinstance(value, str) and len(value) >= 8 and value in text: + text = text.replace(value, "[REDACTED]") + + try: + out_dir.mkdir(parents=True, exist_ok=True) + (out_dir / "last_error.txt").write_text(text, encoding="utf-8") + except OSError: + pass + _write_status(job, out_dir, status="failure", error=f"{type(exc).__name__}: {exc}") diff --git a/content_parser/jobs/schema.py b/content_parser/jobs/schema.py new file mode 100644 index 0000000..ad02017 --- /dev/null +++ b/content_parser/jobs/schema.py @@ -0,0 +1,248 @@ +"""Job schema: YAML on disk ↔ Job dataclass in memory. 
+ +A job describes one scheduled run: which plugin to use, what inputs it needs +(inline values, Google Sheets references, or both), and an optional cron +schedule. Jobs without a schedule still work via `cli jobs run <name>`. +""" +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from ..core.errors import PluginError + + +# Cron expressions can vary, but common forms have 5 whitespace-separated tokens: +# "min hour day-of-month month day-of-week" with @-aliases as a separate case. +_CRON_TOKEN_RE = re.compile(r"^[\d\*/,\-A-Za-z]+$") +_CRON_ALIAS_RE = re.compile(r"^@(yearly|annually|monthly|weekly|daily|hourly|reboot)$") +# Job names map to filenames, so they must be a safe filesystem token. +_JOB_NAME_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") + + +@dataclass +class SheetInput: + """One Google Sheets reference inside a job.""" + sheet: str # URL or ID + target: str # plugin input kind, e.g. "community" / "channel" + tab: str | None = None # None or "" → first sheet + range_a1: str = "A:A" + skip_header: bool = False + + +@dataclass +class Job: + """In-memory job. Use `Job.from_dict` / `Job.to_dict` for YAML I/O.""" + name: str + source: str + inputs: dict[str, list[str]] = field(default_factory=dict) + sheet_inputs: list[SheetInput] = field(default_factory=list) + settings: dict[str, Any] = field(default_factory=dict) + schedule: str | None = None + description: str | None = None + output_dir: str | None = None + notify_on_failure: str = "log" # "log" | "none" + + # ------------------------------------------------------------------ + # Validation + + def validate(self) -> None: + if not _JOB_NAME_RE.match(self.name): + raise PluginError( + f"Invalid job name {self.name!r}. " + "Allowed: letters, digits, underscore, hyphen; 1-64 chars." + ) + if not self.source or not isinstance(self.source, str): + raise PluginError(f"Job {self.name!r} is missing 'source'.") + if self.schedule is not None and not is_valid_cron(self.schedule): + raise PluginError( + f"Job {self.name!r} schedule {self.schedule!r} is not a valid cron expression." + ) + if self.notify_on_failure not in ("log", "none"): + raise PluginError( + f"Job {self.name!r} notify_on_failure must be 'log' or 'none', " + f"got {self.notify_on_failure!r}." + ) + if not self.inputs and not self.sheet_inputs: + raise PluginError( + f"Job {self.name!r} has no inputs (need 'inputs' or 'sheet_inputs')." + ) + for kind, values in self.inputs.items(): + if not isinstance(values, list): + raise PluginError( + f"Job {self.name!r} inputs.{kind} must be a list, got {type(values).__name__}." + ) + for ref in self.sheet_inputs: + if not ref.sheet: + raise PluginError( + f"Job {self.name!r} sheet_inputs entry missing 'sheet'." + ) + if not ref.target: + raise PluginError( + f"Job {self.name!r} sheet_inputs entry missing 'target'." + ) + # output_dir guard: reject path-traversal segments. Absolute paths are + # allowed (user explicitly opted in), but ".." anywhere in the value + # is rejected — it's almost always a bug, and on shared/multi-user + # hosts could surprise the user where files actually land. + if self.output_dir: + parts = Path(self.output_dir).parts + if ".." in parts: + raise PluginError( + f"Job {self.name!r} output_dir {self.output_dir!r} must not contain '..'." 
+                )
+
+    # ------------------------------------------------------------------
+    # Output dir resolution
+
+    def resolved_output_dir(self, *, timestamp: str | None = None) -> Path:
+        """Return the path where this run should write its files.
+
+        Default: output/scheduled/<job-name>/<timestamp>/.
+        If `output_dir` is set in YAML and absolute, used as-is.
+        If relative, resolved against cwd; the timestamp subdir is still appended.
+        """
+        from datetime import datetime  # noqa: PLC0415
+
+        ts = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
+        if self.output_dir:
+            base = Path(self.output_dir)
+            if not base.is_absolute():
+                base = Path.cwd() / base
+        else:
+            base = Path("output") / "scheduled" / self.name
+        return base / ts
+
+    # ------------------------------------------------------------------
+    # Serialization
+
+    @classmethod
+    def from_dict(cls, data: dict, *, name_hint: str | None = None) -> Job:
+        if not isinstance(data, dict):
+            raise PluginError(
+                f"Job YAML must be a mapping at the top level, got {type(data).__name__}."
+            )
+        # Allow filename to provide the name when not in the body.
+        name = data.get("name") or name_hint or ""
+        sheet_refs_raw = data.get("sheet_inputs") or []
+        sheet_inputs: list[SheetInput] = []
+        for i, ref in enumerate(sheet_refs_raw):
+            if not isinstance(ref, dict):
+                raise PluginError(
+                    f"sheet_inputs[{i}] must be a mapping, got {type(ref).__name__}."
+                )
+            sheet_inputs.append(SheetInput(
+                sheet=str(ref.get("sheet") or ""),
+                target=str(ref.get("target") or ""),
+                tab=ref.get("tab"),
+                range_a1=str(ref.get("range") or ref.get("range_a1") or "A:A"),
+                skip_header=bool(ref.get("skip_header", False)),
+            ))
+        # inputs: dict of kind → list[str]
+        inputs_raw = data.get("inputs") or {}
+        if not isinstance(inputs_raw, dict):
+            raise PluginError("'inputs' must be a mapping kind → list.")
+        inputs: dict[str, list[str]] = {}
+        for k, v in inputs_raw.items():
+            # Catch the common typo 'community: name' (string instead of list of names).
+            # Without this check, the loop would iterate the string character by
+            # character and produce one-letter "values" — silent corruption.
+            if v is not None and not isinstance(v, list):
+                raise PluginError(
+                    f"inputs.{k} must be a list, got {type(v).__name__}: {v!r}. "
+                    f"Wrap a single value in [] like `inputs.{k}: [name]`."
+ ) + inputs[str(k)] = [str(x) for x in (v or [])] + settings = data.get("settings") or {} + if not isinstance(settings, dict): + raise PluginError("'settings' must be a mapping.") + + job = cls( + name=str(name), + source=str(data.get("source") or ""), + inputs=inputs, + sheet_inputs=sheet_inputs, + settings=dict(settings), + schedule=(str(data["schedule"]) if data.get("schedule") else None), + description=(str(data["description"]) if data.get("description") else None), + output_dir=(str(data["output_dir"]) if data.get("output_dir") else None), + notify_on_failure=str(data.get("notify_on_failure") or "log"), + ) + job.validate() + return job + + def to_dict(self) -> dict: + out: dict[str, Any] = { + "name": self.name, + "source": self.source, + } + if self.description: + out["description"] = self.description + if self.schedule: + out["schedule"] = self.schedule + if self.inputs: + out["inputs"] = {k: list(v) for k, v in self.inputs.items()} + if self.sheet_inputs: + out["sheet_inputs"] = [ + { + "sheet": ref.sheet, + "target": ref.target, + **({"tab": ref.tab} if ref.tab else {}), + "range": ref.range_a1, + **({"skip_header": True} if ref.skip_header else {}), + } + for ref in self.sheet_inputs + ] + if self.settings: + out["settings"] = dict(self.settings) + if self.output_dir: + out["output_dir"] = self.output_dir + if self.notify_on_failure != "log": + out["notify_on_failure"] = self.notify_on_failure + return out + + +# ---------------------------------------------------------------------- +# YAML I/O + + +def load_job_yaml(text: str, *, name_hint: str | None = None) -> Job: + """Parse a YAML document into a Job. ALWAYS uses safe_load.""" + import yaml # noqa: PLC0415 + + try: + data = yaml.safe_load(text) + except yaml.YAMLError as e: + raise PluginError(f"Invalid YAML: {e}") from e + if data is None: + raise PluginError("Empty YAML document.") + return Job.from_dict(data, name_hint=name_hint) + + +def dump_job_yaml(job: Job) -> str: + import yaml # noqa: PLC0415 + + return yaml.safe_dump( + job.to_dict(), allow_unicode=True, sort_keys=False, default_flow_style=False + ) + + +# ---------------------------------------------------------------------- +# Cron validation + + +def is_valid_cron(expr: str) -> bool: + """Loose check: 5 whitespace-separated tokens of cron-y characters, + or a recognised @-alias (@daily, @weekly, ...). Doesn't fully validate + field ranges — leaves that to crond at install time.""" + if not isinstance(expr, str): + return False + expr = expr.strip() + if _CRON_ALIAS_RE.match(expr): + return True + parts = expr.split() + if len(parts) != 5: + return False + return all(_CRON_TOKEN_RE.match(p) for p in parts) diff --git a/content_parser/jobs/store.py b/content_parser/jobs/store.py new file mode 100644 index 0000000..74ae594 --- /dev/null +++ b/content_parser/jobs/store.py @@ -0,0 +1,95 @@ +"""Filesystem-backed job store at ~/.content_parser/jobs/<name>.yaml. + +All paths are resolved through `Path.resolve()` and checked to live under +JOBS_DIR before any read/write — the validated job-name regex prevents +filenames like `../../etc/passwd.yaml`, but the resolve check is a belt +on top of the suspenders. 
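For reference, a job file this store would accept; every value below is invented, and
the input kinds and setting keys depend on the plugin named in `source`:

    # ~/.content_parser/jobs/daily_pull.yaml
    name: daily_pull
    source: youtube
    description: Pull new uploads from a few channels every morning
    schedule: "30 6 * * *"
    inputs:
      channel:
        - https://www.youtube.com/@somechannel
    sheet_inputs:
      - sheet: https://docs.google.com/spreadsheets/d/<spreadsheet-id>
        target: channel
        range: A2:A100
        skip_header: true
    settings:
      max_comments: 100

Such a job fires on its schedule once `jobs install-cron` has been run, or on demand
via `python -m content_parser.cli jobs run daily_pull`.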
+""" +from __future__ import annotations + +import os +from pathlib import Path + +from ..core.errors import PluginError +from .schema import Job, _JOB_NAME_RE, dump_job_yaml, load_job_yaml + + +JOBS_DIR = Path.home() / ".content_parser" / "jobs" + + +def _job_path(name: str) -> Path: + if not _JOB_NAME_RE.match(name): + raise PluginError( + f"Invalid job name {name!r}. " + "Allowed: letters, digits, underscore, hyphen; 1-64 chars." + ) + JOBS_DIR.mkdir(parents=True, exist_ok=True) + candidate = (JOBS_DIR / f"{name}.yaml").resolve() + # Defense-in-depth: refuse if the resolved path escapes JOBS_DIR. + base = JOBS_DIR.resolve() + try: + candidate.relative_to(base) + except ValueError: + raise PluginError(f"Job path escapes the jobs directory: {candidate}") + return candidate + + +def list_jobs() -> list[Job]: + """Return every well-formed job file in the directory, sorted by name. + + Files that fail to parse are skipped; their names are surfaced separately + via `list_invalid()` if a caller wants to show errors. + """ + if not JOBS_DIR.exists(): + return [] + out: list[Job] = [] + for path in sorted(JOBS_DIR.glob("*.yaml")): + try: + out.append(load_job(path.stem)) + except PluginError: + continue + return out + + +def list_invalid() -> list[tuple[str, str]]: + """Return (name, error) pairs for files that don't parse.""" + if not JOBS_DIR.exists(): + return [] + out: list[tuple[str, str]] = [] + for path in sorted(JOBS_DIR.glob("*.yaml")): + try: + load_job(path.stem) + except PluginError as e: + out.append((path.stem, str(e))) + return out + + +def load_job(name: str) -> Job: + path = _job_path(name) + if not path.exists(): + raise PluginError(f"Job {name!r} not found at {path}") + return load_job_yaml(path.read_text(encoding="utf-8"), name_hint=name) + + +def save_job(job: Job) -> Path: + job.validate() + path = _job_path(job.name) + path.write_text(dump_job_yaml(job), encoding="utf-8") + try: + os.chmod(path, 0o600) + except OSError: + pass + return path + + +def delete_job(name: str) -> bool: + """Remove a job file. Returns True if removed, False if it didn't exist.""" + path = _job_path(name) + if path.exists(): + path.unlink() + return True + return False + + +def job_exists(name: str) -> bool: + return _job_path(name).exists() diff --git a/content_parser/loaders/__init__.py b/content_parser/loaders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/loaders/gsheets.py b/content_parser/loaders/gsheets.py new file mode 100644 index 0000000..8b9654d --- /dev/null +++ b/content_parser/loaders/gsheets.py @@ -0,0 +1,240 @@ +"""Google Sheets loader: pull a column/range of values for plugin inputs. + +Auth: service-account JSON. The user creates a service account in Google Cloud, +downloads its JSON key, stores it in the GOOGLE_SHEETS_CREDENTIALS secret, and +shares each target spreadsheet with the service account's email address. + +Usage: + loader = GoogleSheetsLoader.from_secrets({"GOOGLE_SHEETS_CREDENTIALS": "..."}) + loaded = loader.load(sheet_id_or_url, tab="Communities", range_a1="A2:A100") + print(loaded.values) # ['durov_says', 'telegram', ...] +""" +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import urlparse + +from ..core.errors import AuthError, PluginError + + +SHEETS_SCOPES = [ + "https://www.googleapis.com/auth/spreadsheets.readonly", +] + + +# Sheet ID embedded in /d/<id>/ in the spreadsheet URL. 
+_URL_SHEET_ID_RE = re.compile(r"/d/([A-Za-z0-9_-]{20,})") +# Bare sheet ID (no slashes, just the alphanumeric portion). +_BARE_SHEET_ID_RE = re.compile(r"^[A-Za-z0-9_-]{20,}$") +# A1 range like "Sheet1!A2:B10" or just "A2:B10" — light validation. +_RANGE_A1_RE = re.compile(r"^[A-Za-z0-9_]+(?::[A-Za-z0-9_]+)?$|^[^!]+!.+$") + + +@dataclass +class LoadedRange: + """Result of a Google Sheets load.""" + sheet_title: str + tab_title: str + range_a1: str + values: list[str] + sheet_url: str + + @property + def count(self) -> int: + return len(self.values) + + +class GoogleSheetsLoader: + """Thin wrapper around gspread for read-only sheet access.""" + + def __init__(self, credentials_json: dict | str): + self._credentials_json = self.validate_credentials(credentials_json) + self._gc = self._build_client() + + @staticmethod + def validate_credentials(credentials_json: dict | str) -> dict: + """Parse + shape-check service account JSON. Raises AuthError on bad input. + + Does NOT build a gspread client — safe to call without network. + Returns the parsed dict for the caller to reuse. + """ + if isinstance(credentials_json, str): + try: + credentials_json = json.loads(credentials_json) + except json.JSONDecodeError as e: + raise AuthError( + "GOOGLE_SHEETS_CREDENTIALS is not valid JSON. " + "Paste the full service-account JSON file contents." + ) from e + + if not isinstance(credentials_json, dict): + raise AuthError("GOOGLE_SHEETS_CREDENTIALS must be a JSON object.") + if credentials_json.get("type") != "service_account": + raise AuthError( + "GOOGLE_SHEETS_CREDENTIALS is not a service account JSON " + "(missing or wrong 'type' field). OAuth client JSONs won't work — " + "use a service account key." + ) + for required in ("client_email", "private_key"): + if not credentials_json.get(required): + raise AuthError( + f"GOOGLE_SHEETS_CREDENTIALS missing field {required!r} — " + "is this a service account JSON?" + ) + return credentials_json + + @classmethod + def from_secrets(cls, secrets: dict[str, str]) -> GoogleSheetsLoader: + raw = secrets.get("GOOGLE_SHEETS_CREDENTIALS") + if not raw: + raise AuthError("GOOGLE_SHEETS_CREDENTIALS is required.") + return cls(raw) + + def service_account_email(self) -> str | None: + """Email to share spreadsheets with. Useful for UX hints.""" + return self._credentials_json.get("client_email") + + # ------------------------------------------------------------------ + + def _build_client(self) -> Any: + try: + from google.oauth2.service_account import Credentials # noqa: PLC0415 + import gspread # noqa: PLC0415 + except ImportError as e: + raise PluginError( + "Install gspread + google-auth: pip install gspread google-auth" + ) from e + + try: + creds = Credentials.from_service_account_info( + self._credentials_json, scopes=SHEETS_SCOPES + ) + except Exception as e: + raise AuthError(f"Cannot build credentials: {e}") from e + return gspread.authorize(creds) + + # ------------------------------------------------------------------ + + def load( + self, + sheet: str, + *, + tab: str | None = None, + range_a1: str = "A:A", + skip_header: bool = False, + ) -> LoadedRange: + """Read a range from a sheet, return a flat list of non-empty values. + + sheet: spreadsheet ID or full URL. + tab: worksheet (tab) name. None or empty string → first sheet. + range_a1: A1 notation. Defaults to entire column A. + skip_header: drop the first row (e.g. when header is in row 1). 
+ """ + sheet_id = self._extract_sheet_id(sheet) + if not range_a1 or not _RANGE_A1_RE.match(range_a1): + raise PluginError( + f"Invalid range {range_a1!r}. Use A1 notation, e.g. 'A2:A100' or 'A:A'." + ) + + # Defensive: callers (cron config, CLI) may pass empty string for "first sheet". + tab = tab.strip() if isinstance(tab, str) else tab + if not tab: + tab = None + + try: + spreadsheet = self._gc.open_by_key(sheet_id) + except Exception as e: + self._raise_friendly_open_error(e, sheet_id) + + try: + worksheet = spreadsheet.worksheet(tab) if tab else spreadsheet.sheet1 + except Exception as e: + available = self._list_tab_names(spreadsheet) + raise PluginError( + f"Tab {tab!r} not found in spreadsheet. Available: {available}" + ) from e + + try: + rows = worksheet.get(range_a1) or [] + except Exception as e: + raise PluginError(f"Cannot read range {range_a1!r}: {e}") from e + + if skip_header and rows: + rows = rows[1:] + + # Flatten: each row becomes its non-empty cells; concatenate cells of all rows. + # For a single-column range this collapses cleanly to one value per row. + values: list[str] = [] + seen: set[str] = set() + for row in rows: + for cell in row: + v = (cell or "").strip() + if v and v not in seen: + seen.add(v) + values.append(v) + + return LoadedRange( + sheet_title=spreadsheet.title, + tab_title=worksheet.title, + range_a1=range_a1, + values=values, + sheet_url=f"https://docs.google.com/spreadsheets/d/{sheet_id}", + ) + + # ------------------------------------------------------------------ + # Helpers + + @staticmethod + def _extract_sheet_id(sheet: str) -> str: + v = (sheet or "").strip() + if not v: + raise PluginError("Sheet ID or URL is required.") + + # Bare ID (no URL). + if _BARE_SHEET_ID_RE.match(v): + return v + + # URL form: validate host strictly first, then pull the ID from the path. + # Without the host check, a URL like https://evil.com/spreadsheets/d/<id>/ + # would silently get its `/d/<id>/` substring matched and accepted. + if v.lower().startswith(("http://", "https://")): + parsed = urlparse(v) + host = (parsed.hostname or "").lower() + if host != "docs.google.com": + raise PluginError( + f"Cannot extract spreadsheet ID from {v!r}. " + "URL host must be docs.google.com." + ) + m = _URL_SHEET_ID_RE.search(parsed.path) + if m: + return m.group(1) + + raise PluginError( + f"Cannot extract spreadsheet ID from {v!r}. " + "Expected a docs.google.com/spreadsheets/d/<ID>/... URL or the ID itself." + ) + + @staticmethod + def _list_tab_names(spreadsheet: Any) -> list[str]: + try: + return [ws.title for ws in spreadsheet.worksheets()] + except Exception: + return [] + + @staticmethod + def _raise_friendly_open_error(e: Exception, sheet_id: str) -> None: + msg = str(e) + # gspread raises APIError; check for common substrings. + if "404" in msg or "not found" in msg.lower(): + raise PluginError( + f"Spreadsheet {sheet_id!r} not found. Check the URL/ID is correct." + ) from e + if "403" in msg or "permission" in msg.lower() or "denied" in msg.lower(): + raise AuthError( + f"Service account doesn't have access to spreadsheet {sheet_id!r}. " + "Share the sheet with the service account email." 
+ ) from e + raise PluginError(f"Cannot open spreadsheet: {e}") from e diff --git a/content_parser/plugins/__init__.py b/content_parser/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/instagram/__init__.py b/content_parser/plugins/instagram/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/instagram/adapter.py b/content_parser/plugins/instagram/adapter.py new file mode 100644 index 0000000..31ba3e6 --- /dev/null +++ b/content_parser/plugins/instagram/adapter.py @@ -0,0 +1,68 @@ +"""Convert Apify Instagram-scraper output dicts into the unified core schema.""" +from __future__ import annotations + +from ...core.schema import Comment, Item + + +def _comment_from_apify(c: dict, parent_id: str | None = None) -> Comment: + return Comment( + comment_id=str(c.get("id") or c.get("comment_id") or ""), + parent_id=parent_id, + author=c.get("ownerUsername") or c.get("owner_username"), + author_id=str(c.get("ownerId") or c.get("owner_id") or "") or None, + text=c.get("text"), + like_count=int(c.get("likesCount", 0) or 0), + published_at=c.get("timestamp"), + ) + + +def _flatten_comments(raw: list[dict]) -> list[Comment]: + out: list[Comment] = [] + for c in raw or []: + top = _comment_from_apify(c) + out.append(top) + for reply in c.get("replies", []) or []: + out.append(_comment_from_apify(reply, parent_id=top.comment_id)) + return out + + +def post_to_item(post: dict) -> Item: + short_code = post.get("shortCode") or post.get("shortcode") or post.get("id", "") + url = post.get("url") or ( + f"https://www.instagram.com/p/{short_code}/" if short_code else "" + ) + + music = post.get("musicInfo") or {} + media: dict = { + "type": post.get("type"), # Image, Video, Sidecar + "view_count": post.get("videoViewCount") or post.get("videoPlayCount"), + "like_count": post.get("likesCount"), + "comment_count": post.get("commentsCount"), + "video_url": post.get("videoUrl"), + "display_url": post.get("displayUrl"), + "video_duration": post.get("videoDuration"), + "audio_id": music.get("audio_id") or music.get("audioId"), + "audio_artist": music.get("artist_name") or music.get("artistName"), + "audio_title": music.get("song_name") or music.get("songName"), + } + media = {k: v for k, v in media.items() if v is not None} + + return Item( + source="instagram", + item_id=short_code or post.get("id", ""), + url=url, + title=(post.get("caption") or "").split("\n", 1)[0][:120] or None, + author=post.get("ownerUsername") or post.get("owner_username"), + author_id=str(post.get("ownerId") or "") or None, + published_at=post.get("timestamp"), + text=post.get("caption"), + media=media, + comments=_flatten_comments(post.get("latestComments") or post.get("comments") or []), + extra={ + "hashtags": post.get("hashtags", []), + "mentions": post.get("mentions", []), + "location_name": post.get("locationName"), + "is_sponsored": post.get("isSponsored"), + "product_type": post.get("productType"), + }, + ) diff --git a/content_parser/plugins/instagram/plugin.py b/content_parser/plugins/instagram/plugin.py new file mode 100644 index 0000000..63b5d88 --- /dev/null +++ b/content_parser/plugins/instagram/plugin.py @@ -0,0 +1,209 @@ +"""Instagram plugin — posts and reels via Apify's instagram-scraper actor.""" +from __future__ import annotations + +import re +from typing import Any, Iterator +from urllib.parse import urlparse + +from ...core.errors import AuthError, PluginError +from ...core.plugin import FieldSpec, InputSpec, ProgressCb, 
SourcePlugin +from ...core.schema import Item +from .adapter import post_to_item +from ...clients.apify import ApifyClient, ApifyError + + +ACTOR_ID = "apify/instagram-scraper" + +# Path segments that indicate a URL refers to a post/reel rather than an account. +_POST_PATH_SEGMENTS = {"p", "reel", "reels", "tv", "explore", "stories"} +_USERNAME_RE = re.compile(r"^[A-Za-z0-9._]{1,30}$") + + +class InstagramPlugin(SourcePlugin): + name = "instagram" + label = "Instagram" + secret_keys = ["APIFY_API_TOKEN"] + + def input_specs(self) -> list[InputSpec]: + return [ + InputSpec( + kind="hashtag", + label="Хэштеги", + placeholder="smm\nреклама\nmarketing", + help="Без #. Каждая строка — один тег.", + ), + InputSpec( + kind="account", + label="Аккаунты", + placeholder="nasa\n@durov\nhttps://instagram.com/zuck", + help="Только username, @handle или URL профиля. Ссылки на /p/ или /reel/ — во вкладку «Ссылки на посты/рилсы».", + ), + InputSpec( + kind="post_url", + label="Ссылки на посты/рилсы", + placeholder="https://www.instagram.com/p/xxxxx/\nhttps://www.instagram.com/reel/yyyyy/", + ), + ] + + def settings_specs(self) -> list[FieldSpec]: + return [ + FieldSpec("max_posts_per_input", "Макс. постов на источник", "number", 20, + min_value=1, max_value=500, + help="Apify тарифицируется за пост — большие значения дороже."), + FieldSpec("results_type", "Тип данных для аккаунтов/хэштегов", + "select", "posts", + options=["posts", "details", "comments"], + help="Для прямых ссылок на посты/рилсы всегда используется 'details'."), + FieldSpec("add_parent_data", "Включать данные родительского аккаунта", "checkbox", False), + FieldSpec("transcribe_videos", "🎤 Транскрибировать видео (Whisper)", "checkbox", False, + help="Скачивает аудио рилса и шлёт в OpenAI Whisper. " + "Нужен OPENAI_API_KEY и ffmpeg на машине. ~$0.006/мин."), + FieldSpec("max_audio_seconds_per_video", "Макс. секунд аудио на пост", + "number", 600, min_value=10, max_value=3600, + help="Если рилс длиннее — пропускается. Защита от случайных счетов."), + ] + + # ------------------------------------------------------------------ + # Resolve + + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + """Returns 'kind:url' specs so fetch() can route each to the right Apify call.""" + specs: list[str] = [] + for h in inputs.get("hashtag", []): + tag = h.strip().lstrip("#") + if tag: + specs.append(f"hashtag:https://www.instagram.com/explore/tags/{tag}/") + + for a in inputs.get("account", []): + user = self._normalize_account(a) + specs.append(f"account:https://www.instagram.com/{user}/") + + for u in inputs.get("post_url", []): + url = u.strip() + if not url: + continue + if not self._is_post_url(url): + raise PluginError( + f"{url!r} doesn't look like a post or reel URL " + "(expected /p/ or /reel/ in the path)." + ) + specs.append(f"post:{url}") + + # Dedupe preserving order. + return list(dict.fromkeys(specs)) + + # ------------------------------------------------------------------ + # Fetch + + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + token = secrets.get("APIFY_API_TOKEN") + if not token: + raise AuthError("APIFY_API_TOKEN is required") + client = ApifyClient(token) + + # Group inputs by kind so each goes to the right Apify resultsType. 
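+        # Specs arriving here from resolve() look like (illustrative values):
+        #   "hashtag:https://www.instagram.com/explore/tags/marketing/"
+        #   "account:https://www.instagram.com/nasa/"
+        #   "post:https://www.instagram.com/p/xxxxx/"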
+ groups: dict[str, list[str]] = {"hashtag": [], "account": [], "post": []} + for spec in item_ids: + kind, _, url = spec.partition(":") + if kind in groups and url: + groups[kind].append(url) + + # Aggregate "what to ask Apify": per (resultsType, urls) call. + listing_type = str(settings.get("results_type", "posts")) + listing_urls = groups["hashtag"] + groups["account"] + post_urls = groups["post"] + + max_posts = int(settings.get("max_posts_per_input", 20)) + add_parent = bool(settings.get("add_parent_data", False)) + + # Apify run #1: hashtags + accounts → user-chosen results_type + all_posts: list[dict] = [] + if listing_urls: + try: + all_posts.extend( + client.run_actor(ACTOR_ID, { + "directUrls": listing_urls, + "resultsType": listing_type, + "resultsLimit": max_posts, + "addParentData": add_parent, + }) + ) + except ApifyError as e: + raise PluginError(f"Apify call failed (listing): {e}") from e + + # Apify run #2: explicit post URLs → always 'details' + if post_urls: + try: + all_posts.extend( + client.run_actor(ACTOR_ID, { + "directUrls": post_urls, + "resultsType": "details", + "resultsLimit": max(len(post_urls), 1), + "addParentData": add_parent, + }) + ) + except ApifyError as e: + raise PluginError(f"Apify call failed (post details): {e}") from e + + total = len(all_posts) + for i, post in enumerate(all_posts, 1): + try: + item = post_to_item(post) + except Exception as e: + item = Item( + source="instagram", + item_id=str(post.get("shortCode") or post.get("id") or f"unknown_{i}"), + url=str(post.get("url") or ""), + extra={"adapter_error": str(e), "raw": post}, + ) + + from ...transcription.runner import maybe_transcribe # noqa: PLC0415 + maybe_transcribe(item, settings, secrets) + + if progress: + progress(i, total, item.item_id) + yield item + + # ------------------------------------------------------------------ + # Helpers + + @staticmethod + def _is_post_url(url: str) -> bool: + try: + parts = [p for p in urlparse(url).path.split("/") if p] + except Exception: + return False + return bool(parts) and parts[0] in _POST_PATH_SEGMENTS + + @classmethod + def _normalize_account(cls, value: str) -> str: + v = value.strip().lstrip("@") + if v.startswith("http"): + parsed = urlparse(v) + parts = [p for p in parsed.path.split("/") if p] + if not parts: + raise PluginError(f"Cannot parse account URL: {value!r}") + head = parts[0] + if head in _POST_PATH_SEGMENTS: + raise PluginError( + f"{value!r} is a post/reel URL, not an account. " + "Use the «Ссылки на посты/рилсы» tab." + ) + v = head + if not _USERNAME_RE.match(v): + raise PluginError( + f"{value!r} is not a valid Instagram username " + "(letters, digits, dot, underscore; 1-30 chars)." + ) + return v diff --git a/content_parser/plugins/instagram_graph/__init__.py b/content_parser/plugins/instagram_graph/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/instagram_graph/adapter.py b/content_parser/plugins/instagram_graph/adapter.py new file mode 100644 index 0000000..b4fa9c1 --- /dev/null +++ b/content_parser/plugins/instagram_graph/adapter.py @@ -0,0 +1,128 @@ +"""Convert Instagram Graph API responses into the unified core schema.""" +from __future__ import annotations + +from typing import Any + +from ...core.schema import Comment, Item + + +def media_to_item( + media: dict, + *, + owner_username: str | None = None, + insights: list[dict] | dict | None = None, +) -> Item: + """Convert a Graph API media object (post / reel / carousel) to core.Item. 
+ + Comments are NOT populated here — the plugin attaches them after a + separate /{media-id}/comments call to keep paging explicit. + + `insights` (when provided) overrides anything embedded in `media["insights"]`. + Pass it explicitly so the caller doesn't need to mutate the media dict. + """ + if not media.get("id"): + raise ValueError( + f"Malformed Graph media: missing id (got keys {sorted(media.keys())[:6]})" + ) + + media_type = media.get("media_type") # IMAGE / VIDEO / CAROUSEL_ALBUM / REEL + insights_flat = _flatten_insights(insights if insights is not None else media.get("insights")) + + media_dict: dict[str, Any] = { + "media_type": media_type, + "like_count": media.get("like_count"), + "comments_count": media.get("comments_count"), + "is_comment_enabled": media.get("is_comment_enabled"), + "media_url": media.get("media_url"), + "thumbnail_url": media.get("thumbnail_url"), + "permalink": media.get("permalink"), + # Reel/post insight metrics, when fetched. + "plays": insights_flat.get("plays"), + "reach": insights_flat.get("reach"), + "impressions": insights_flat.get("impressions"), + "saved": insights_flat.get("saved"), + "shares": insights_flat.get("shares"), + "total_interactions": insights_flat.get("total_interactions"), + "video_views": insights_flat.get("video_views"), + } + media_dict = {k: v for k, v in media_dict.items() if v not in (None, "")} + + caption = media.get("caption") or "" + title = caption.split("\n", 1)[0][:120].strip() if caption else None + + return Item( + source="instagram_graph", + item_id=str(media["id"]), + url=media.get("permalink") or "", + title=title, + author=owner_username, + author_id=str(media.get("owner", {}).get("id") or media.get("owner_id") or "") or None, + published_at=media.get("timestamp"), + text=caption or None, + media=media_dict, + extra={ + "shortcode": media.get("shortcode"), + "is_shared_to_feed": media.get("is_shared_to_feed"), + "children": media.get("children", {}).get("data") if media.get("children") else None, + }, + ) + + +def comment_to_core(c: dict, parent_id: str | None = None) -> Comment: + return Comment( + comment_id=str(c.get("id", "") or ""), + parent_id=parent_id, + author=c.get("username"), + author_id=str(c.get("user", {}).get("id") or "") or None, + text=c.get("text"), + like_count=int(c.get("like_count", 0) or 0), + published_at=c.get("timestamp"), + ) + + +def flatten_comments(comments: list[dict]) -> list[Comment]: + """Each top-level comment may have a `replies.data` list expanded inline.""" + out: list[Comment] = [] + for c in comments or []: + if not isinstance(c, dict): + continue + top = comment_to_core(c) + out.append(top) + replies = ((c.get("replies") or {}).get("data")) or [] + for reply in replies: + if isinstance(reply, dict): + out.append(comment_to_core(reply, parent_id=top.comment_id)) + return out + + +# ---------------------------------------------------------------------- +# Helpers + + +def _flatten_insights(insights: Any) -> dict[str, int | float]: + """Graph returns insights as {"data": [{"name": "reach", "values": [{"value": 1234}]}, ...]}. + + We flatten to {name: value}. 
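+
+    Example (illustrative numbers): {"data": [{"name": "reach", "values": [{"value": 1234}]}]}
+    flattens to {"reach": 1234}; entries with a missing name or a non-numeric value are skipped.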
+ """ + if not insights: + return {} + if isinstance(insights, dict): + data = insights.get("data") or [] + elif isinstance(insights, list): + data = insights + else: + return {} + + out: dict[str, int | float] = {} + for entry in data: + if not isinstance(entry, dict): + continue + name = entry.get("name") + if not name: + continue + values = entry.get("values") or [] + if values and isinstance(values[0], dict): + v = values[0].get("value") + if isinstance(v, (int, float)): + out[name] = v + return out diff --git a/content_parser/plugins/instagram_graph/client.py b/content_parser/plugins/instagram_graph/client.py new file mode 100644 index 0000000..703c0ee --- /dev/null +++ b/content_parser/plugins/instagram_graph/client.py @@ -0,0 +1,165 @@ +"""Thin wrapper over Meta's Instagram Graph API. + +The Graph API only accepts the access_token as a query-string parameter +(no Authorization header support), so we never echo full request URLs in +error messages — they would leak the token. Tokens travel encrypted over +TLS to graph.facebook.com. + +Auth model: long-lived user token (≈60 days, refreshable). Granted scopes +must include `instagram_basic`, `pages_show_list`, +`instagram_manage_comments`, and `business_management`. Scope errors come +back as 400 with `error.code=10` (permissions). +""" +from __future__ import annotations + +import time +from typing import Any +from urllib.parse import urlparse + +import requests + +from ...core.errors import AuthError, PluginError, RateLimitError + + +GRAPH_BASE = "https://graph.facebook.com/v19.0" + + +# Meta error codes that always require user intervention. +_AUTH_ERROR_CODES = {190, 102, 458, 459, 460, 463, 467} # invalid/expired/missing token +_PERMISSION_ERROR_CODES = {10, 200, 803} # permissions / approval issues +# Meta uses error_subcode 2207051 for "rate limit reached", or top-level codes 4/17/32/613. +_RATE_LIMIT_CODES = {4, 17, 32, 613} + + +class GraphClient: + """Read-only client for Instagram Graph API endpoints. + + Sleeps and retries 429 / 5xx with exponential backoff (2s, 4s). + """ + + def __init__( + self, + token: str, + *, + timeout: int = 60, + max_rate_limit_retries: int = 2, + ): + if not token: + raise ValueError("Instagram access token is required") + self.token = token + self.timeout = timeout + self.max_rate_limit_retries = max_rate_limit_retries + self.session = requests.Session() + + @staticmethod + def _sleep(seconds: float) -> None: + time.sleep(seconds) + + def get(self, path: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + """GET a Graph endpoint. `path` is either '/me' or a node id like '17841…/media'.""" + delay = 2.0 + for attempt in range(self.max_rate_limit_retries + 1): + try: + return self._get_once(path, params or {}) + except RateLimitError: + if attempt >= self.max_rate_limit_retries: + raise + self._sleep(delay) + delay *= 2 + raise PluginError("retry loop exhausted") # pragma: no cover + + def get_paginated( + self, + path: str, + params: dict[str, Any] | None = None, + *, + max_items: int | None = None, + ) -> list[dict]: + """Walk the `paging.next` chain. 
Returns all items collected up to max_items.""" + items: list[dict] = [] + current_path: str | None = path + current_params: dict[str, Any] | None = dict(params or {}) + while current_path: + data = self.get(current_path, current_params) + page = data.get("data") or [] + for entry in page: + items.append(entry) + if max_items is not None and len(items) >= max_items: + return items + next_url = (data.get("paging") or {}).get("next") + if not next_url: + break + # The 'next' URL embeds an access_token in its query. We strip it and + # let _get_once inject our own — defends against a (theoretical) + # man-in-the-middle swap of next URLs to a different token. + current_path = self._next_path_from_url(next_url) + current_params = self._next_params_from_url(next_url) + return items + + def _redact(self, message: str) -> str: + """Replace the access token with [REDACTED] in any error string.""" + if self.token and self.token in message: + return message.replace(self.token, "[REDACTED]") + return message + + # ------------------------------------------------------------------ + + def _get_once(self, path: str, params: dict[str, Any]) -> dict[str, Any]: + url = f"{GRAPH_BASE}/{path.lstrip('/')}" + # Always overwrite any embedded token. + params = {**params, "access_token": self.token} + try: + r = self.session.get(url, params=params, timeout=self.timeout) + except requests.RequestException as e: + # requests sometimes embeds the full URL — including ?access_token=… — + # in its exception message. Scrub before propagating. + raise PluginError( + f"Network error calling Graph API: {self._redact(str(e))}" + ) from None + + if r.status_code in (200, 201): + try: + return r.json() + except ValueError as e: + raise PluginError(f"Graph returned non-JSON: {r.text[:200]}") from e + + # Meta error envelope: {"error": {"message", "type", "code", "error_subcode"}} + try: + err = r.json().get("error") or {} + except ValueError: + err = {} + code = err.get("code") + message = err.get("message", r.text[:200]) + + if r.status_code == 401 or code in _AUTH_ERROR_CODES: + raise AuthError(f"Instagram Graph rejected the token (code {code}): {message}") + if code in _PERMISSION_ERROR_CODES: + raise AuthError( + f"Instagram Graph permissions error (code {code}): {message}. " + "App probably needs the instagram_basic + instagram_manage_comments scopes " + "approved by Meta." + ) + if r.status_code == 429 or code in _RATE_LIMIT_CODES: + raise RateLimitError(f"Instagram Graph rate-limit (code {code}): {message}") + if 500 <= r.status_code < 600: + raise RateLimitError(f"Instagram Graph server error ({r.status_code}): {message}") + raise PluginError(f"Graph {path!r} failed ({r.status_code}, code {code}): {message}") + + @staticmethod + def _next_path_from_url(next_url: str) -> str: + """Strip protocol/host/leading version to get just the relative path Graph expects.""" + parsed = urlparse(next_url) + # Path looks like /v19.0/17841.../media; trim the version. + parts = [p for p in parsed.path.split("/") if p] + if parts and parts[0].startswith("v"): + parts = parts[1:] + return "/".join(parts) + + @staticmethod + def _next_params_from_url(next_url: str) -> dict[str, Any]: + from urllib.parse import parse_qs, urlparse as _u + q = parse_qs(_u(next_url).query) + # parse_qs returns lists; flatten singletons. 
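+        # e.g. (illustrative keys) {"after": ["cursor"], "limit": ["25"]}
+        # flattens to {"after": "cursor", "limit": "25"}; access_token is dropped below.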
+ out: dict[str, Any] = {k: (v[0] if len(v) == 1 else v) for k, v in q.items()} + out.pop("access_token", None) # we always inject our own + return out diff --git a/content_parser/plugins/instagram_graph/plugin.py b/content_parser/plugins/instagram_graph/plugin.py new file mode 100644 index 0000000..a455fc9 --- /dev/null +++ b/content_parser/plugins/instagram_graph/plugin.py @@ -0,0 +1,265 @@ +"""Instagram Graph API plugin — for YOUR OWN accounts only. + +Use cases this plugin solves that the public Instagram (Apify) plugin can't: + - Insights: reach / impressions / saved / shares / plays / engagement + - Full comment thread including replies, with likes + - Free of charge (no Apify credits) + +What you need: + 1. Convert your Instagram account to Business or Creator (instant, free). + 2. Connect it to a Facebook Page. + 3. Create a Meta Developer App: https://developers.facebook.com/apps/ + 4. Generate a long-lived access token via Graph API Explorer or OAuth flow. + Required scopes: instagram_basic, instagram_manage_comments, + pages_show_list, business_management. + 5. Find your Instagram Business Account ID (numeric, ~17 digits). +""" +from __future__ import annotations + +import re +from typing import Any, Iterator + +from ...core.errors import AuthError, PluginError +from ...core.plugin import FieldSpec, InputSpec, ProgressCb, SourcePlugin +from ...core.redact import redact_spec +from ...core.schema import Item +from ...transcription.runner import maybe_transcribe +from .adapter import flatten_comments, media_to_item +from .client import GraphClient + + +# Instagram Business Account IDs are 15-20 digit numbers. +_IG_BUSINESS_ID_RE = re.compile(r"^\d{15,20}$") +# Media (post/reel) IDs are similar — long numeric strings. +_MEDIA_ID_RE = re.compile(r"^\d{10,30}(?:_\d+)?$") +# Reel-specific insight metrics; for non-Reels we use the post-grade set. +_REEL_INSIGHT_METRICS = "plays,reach,total_interactions,saved,shares" +_POST_INSIGHT_METRICS = "impressions,reach,saved" + + +class InstagramGraphPlugin(SourcePlugin): + name = "instagram_graph" + label = "Instagram (свой аккаунт)" + secret_keys = ["INSTAGRAM_ACCESS_TOKEN"] + + def input_specs(self) -> list[InputSpec]: + return [ + InputSpec( + kind="account", + label="ID Business-аккаунтов", + placeholder="17841405822304914", + help="15-20 цифр. Найти в Meta Business Settings или через Graph API Explorer " + "(GET /me/accounts → instagram_business_account.id).", + ), + InputSpec( + kind="post_id", + label="ID конкретных постов/рилсов", + placeholder="17895695668004550", + help="ID поста (можно получить через GET /{ig-user-id}/media).", + ), + ] + + def settings_specs(self) -> list[FieldSpec]: + return [ + FieldSpec("max_posts_per_account", "Макс. постов с аккаунта", + "number", 25, min_value=1, max_value=200), + FieldSpec("fetch_comments", "Парсить комментарии", "checkbox", True), + FieldSpec("max_comments_per_post", "Макс. комментариев на пост", + "number", 100, min_value=1, max_value=1000), + FieldSpec("fetch_replies", "Включая ответы на комментарии", "checkbox", True), + FieldSpec("fetch_insights", "🔢 Подгружать insights (reach, plays, …)", + "checkbox", True, + help="Аналитика только для своих постов. Доступна 0-90 дней после публикации."), + FieldSpec("transcribe_videos", "🎤 Транскрибировать видео (Whisper)", "checkbox", False, + help="Скачивает аудио + Whisper API. Нужен OPENAI_API_KEY и ffmpeg."), + FieldSpec("max_audio_seconds_per_video", "Макс. 
секунд аудио на пост", + "number", 600, min_value=10, max_value=3600), + ] + + # ------------------------------------------------------------------ + # Resolve + + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + specs: list[str] = [] + for raw in inputs.get("account", []): + v = raw.strip() + if not _IG_BUSINESS_ID_RE.match(v): + raise PluginError( + f"{raw!r} is not a valid Instagram Business Account ID " + "(expected 15-20 digits). Find yours via the Meta Business UI or " + "Graph API Explorer." + ) + specs.append(f"account:{v}") + for raw in inputs.get("post_id", []): + v = raw.strip() + if not _MEDIA_ID_RE.match(v): + raise PluginError( + f"{raw!r} is not a valid Graph media ID (expected a numeric ID)." + ) + specs.append(f"post:{v}") + return list(dict.fromkeys(specs)) + + # ------------------------------------------------------------------ + # Fetch + + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + token = (secrets.get("INSTAGRAM_ACCESS_TOKEN") or "").strip() + if not token: + raise AuthError("INSTAGRAM_ACCESS_TOKEN is required") + client = GraphClient(token) + + max_posts = int(settings.get("max_posts_per_account", 25)) + fetch_comments = bool(settings.get("fetch_comments", True)) + max_comments = int(settings.get("max_comments_per_post", 100)) + fetch_replies = bool(settings.get("fetch_replies", True)) + fetch_insights = bool(settings.get("fetch_insights", True)) + + # Collect (media_dict, owner_username) pairs across all specs. + jobs: list[tuple[dict, str | None]] = [] + for spec in item_ids: + kind, _, value = spec.partition(":") + try: + self._collect_for_spec( + client, kind, value, + max_posts=max_posts, + jobs=jobs, + ) + except (AuthError, PluginError): + raise + except Exception as e: + raise PluginError( + f"Graph error for {redact_spec(spec)!r}: {e}" + ) from e + + # Dedupe by media id. + seen: set[str] = set() + unique: list[tuple[dict, str | None]] = [] + for m, u in jobs: + mid = str(m.get("id") or "") + if mid and mid not in seen: + seen.add(mid) + unique.append((m, u)) + + total = len(unique) + for i, (media, owner) in enumerate(unique, 1): + # Insights: separate call to /{media-id}/insights. 
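+            # e.g. GET /{media-id}/insights?metric=plays,reach,total_interactions,saved,shares
+            # for Reels, or metric=impressions,reach,saved for feed posts (the metric sets
+            # are the _REEL_INSIGHT_METRICS / _POST_INSIGHT_METRICS constants above).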
+ insights_data = self._fetch_insights(client, media) if fetch_insights else None + + try: + item = media_to_item(media, owner_username=owner, insights=insights_data) + except Exception as e: + item = Item( + source="instagram_graph", + item_id=str(media.get("id") or f"unknown_{i}"), + url=str(media.get("permalink") or ""), + extra={"adapter_error": str(e), "raw": media}, + ) + + if fetch_comments: + try: + item.comments = self._fetch_comments( + client, media["id"], + max_comments=max_comments, + with_replies=fetch_replies, + ) + except Exception as e: + item.extra["comments_error"] = str(e) + + maybe_transcribe(item, settings, secrets) + + if progress: + progress(i, total, item.item_id) + yield item + + # ------------------------------------------------------------------ + # Per-spec collection + + def _collect_for_spec( + self, + client: GraphClient, + kind: str, + value: str, + *, + max_posts: int, + jobs: list[tuple[dict, str | None]], + ) -> None: + if kind == "account": + user = client.get(value, params={"fields": "username"}) + owner = user.get("username") + media_fields = ( + "id,caption,media_type,media_url,thumbnail_url,permalink,timestamp," + "is_comment_enabled,comments_count,like_count,owner{id,username}" + ) + posts = client.get_paginated( + f"{value}/media", + params={"fields": media_fields, "limit": min(50, max_posts)}, + max_items=max_posts, + ) + for m in posts: + jobs.append((m, owner)) + + elif kind == "post": + media_fields = ( + "id,caption,media_type,media_url,thumbnail_url,permalink,timestamp," + "is_comment_enabled,comments_count,like_count,owner{id,username}" + ) + media = client.get(value, params={"fields": media_fields}) + owner = (media.get("owner") or {}).get("username") + jobs.append((media, owner)) + + else: + raise PluginError(f"Unknown Instagram Graph input kind: {kind!r}") + + # ------------------------------------------------------------------ + # Comments + + def _fetch_comments( + self, + client: GraphClient, + media_id: str, + *, + max_comments: int, + with_replies: bool, + ) -> list: + fields = "id,text,username,timestamp,like_count,user{id}" + if with_replies: + fields += ",replies.limit(50){id,text,username,timestamp,like_count,user{id}}" + comments = client.get_paginated( + f"{media_id}/comments", + params={"fields": fields, "limit": min(50, max_comments)}, + max_items=max_comments, + ) + return flatten_comments(comments) + + # ------------------------------------------------------------------ + # Insights + + def _fetch_insights(self, client: GraphClient, media: dict) -> list[dict] | None: + media_id = media.get("id") + if not media_id: + return None + media_type = (media.get("media_type") or "").upper() + # Reels use a different metric set than feed posts. Parens make the + # precedence explicit: REEL OR (VIDEO AND product_type=REELS). + is_reel = (media_type == "REEL") or ( + media_type == "VIDEO" and media.get("media_product_type") == "REELS" + ) + metrics = _REEL_INSIGHT_METRICS if is_reel else _POST_INSIGHT_METRICS + try: + data = client.get(f"{media_id}/insights", params={"metric": metrics}) + except (AuthError, PluginError): + # Insights commonly fail with permission errors on archived posts — + # don't kill the whole run. 
+ return None + return data.get("data") or None diff --git a/content_parser/plugins/reddit/__init__.py b/content_parser/plugins/reddit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/reddit/adapter.py b/content_parser/plugins/reddit/adapter.py new file mode 100644 index 0000000..e63eac6 --- /dev/null +++ b/content_parser/plugins/reddit/adapter.py @@ -0,0 +1,95 @@ +"""Convert PRAW Submission/Comment objects into the unified core schema.""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ...core.schema import Comment, Item + + +def _iso(ts: float | None) -> str | None: + if ts is None: + return None + return datetime.fromtimestamp(float(ts), tz=timezone.utc).isoformat() + + +def _author_str(author: Any) -> str | None: + """PRAW returns a Redditor object, None for [deleted], or sometimes a string.""" + if author is None: + return "[deleted]" + name = getattr(author, "name", None) + if name: + return str(name) + return str(author) if author else "[deleted]" + + +def _subreddit_str(subreddit: Any) -> str | None: + if subreddit is None: + return None + return getattr(subreddit, "display_name", None) or str(subreddit) + + +def submission_to_item(s: Any) -> Item: + """Map a PRAW Submission to core.Item. + + Comments are NOT populated here — the plugin attaches them separately + after iterating s.comments to keep ordering and depth control explicit. + """ + permalink = getattr(s, "permalink", "") or "" + url_external = getattr(s, "url", None) + is_self = bool(getattr(s, "is_self", False)) + + media: dict = { + "score": getattr(s, "score", None), + "upvote_ratio": getattr(s, "upvote_ratio", None), + "num_comments": getattr(s, "num_comments", None), + "num_crossposts": getattr(s, "num_crossposts", None), + "subreddit": _subreddit_str(getattr(s, "subreddit", None)), + "flair": getattr(s, "link_flair_text", None), + "is_video": bool(getattr(s, "is_video", False)), + "is_self": is_self, + "over_18": bool(getattr(s, "over_18", False)), + "spoiler": bool(getattr(s, "spoiler", False)), + "locked": bool(getattr(s, "locked", False)), + "domain": getattr(s, "domain", None), + "url_external": None if is_self else url_external, + } + media = {k: v for k, v in media.items() if v not in (None, "")} + + extra: dict = {} + awards = getattr(s, "all_awardings", None) + if awards: + extra["awards"] = [ + {"name": a.get("name"), "count": a.get("count")} for a in awards if isinstance(a, dict) + ] + post_hint = getattr(s, "post_hint", None) + if post_hint: + extra["post_hint"] = post_hint + removed = getattr(s, "removed_by_category", None) + if removed: + extra["removed_by_category"] = removed + + return Item( + source="reddit", + item_id=str(getattr(s, "id", "") or ""), + url=f"https://reddit.com{permalink}" if permalink else (url_external or ""), + title=getattr(s, "title", None), + author=_author_str(getattr(s, "author", None)), + author_id=getattr(s, "author_fullname", None), + published_at=_iso(getattr(s, "created_utc", None)), + text=(getattr(s, "selftext", None) or None) if is_self else None, + media=media, + extra=extra, + ) + + +def comment_to_core(c: Any, parent_id: str | None) -> Comment: + return Comment( + comment_id=str(getattr(c, "id", "") or ""), + parent_id=parent_id, + author=_author_str(getattr(c, "author", None)), + author_id=getattr(c, "author_fullname", None), + text=getattr(c, "body", None), + like_count=int(getattr(c, "score", 0) or 0), + published_at=_iso(getattr(c, "created_utc", None)), + 
) diff --git a/content_parser/plugins/reddit/client.py b/content_parser/plugins/reddit/client.py new file mode 100644 index 0000000..dc45158 --- /dev/null +++ b/content_parser/plugins/reddit/client.py @@ -0,0 +1,43 @@ +"""Thin wrapper that builds a read-only praw.Reddit instance from secrets.""" +from __future__ import annotations + +import logging +from typing import Any + + +logger = logging.getLogger(__name__) + + +DEFAULT_USER_AGENT = "content_parser/1.0" + + +def build_reddit(secrets: dict[str, str]) -> Any: + """Return a praw.Reddit instance configured for read-only auth. + + Imported lazily so importing this module doesn't fail when praw isn't installed. + """ + import praw # noqa: PLC0415 + + client_id = secrets.get("REDDIT_CLIENT_ID") + client_secret = secrets.get("REDDIT_CLIENT_SECRET") + user_agent = secrets.get("REDDIT_USER_AGENT", "") + + if not client_id or not client_secret: + raise ValueError("REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET are required") + + if not user_agent.strip(): + logger.warning( + "REDDIT_USER_AGENT not set; falling back to %r. Reddit's API rules " + "expect '<platform>:<app-id>:<version> by /u/<username>' — generic " + "agents may be rate-limited or blocked.", + DEFAULT_USER_AGENT, + ) + user_agent = DEFAULT_USER_AGENT + + reddit = praw.Reddit( + client_id=client_id, + client_secret=client_secret, + user_agent=user_agent, + ) + reddit.read_only = True + return reddit diff --git a/content_parser/plugins/reddit/plugin.py b/content_parser/plugins/reddit/plugin.py new file mode 100644 index 0000000..055faba --- /dev/null +++ b/content_parser/plugins/reddit/plugin.py @@ -0,0 +1,339 @@ +"""Reddit plugin — posts and comments via PRAW (read-only).""" +from __future__ import annotations + +import logging +import re +from typing import Any, Iterable, Iterator +from urllib.parse import urlparse + +from ...core.errors import AuthError, PluginError +from ...core.plugin import FieldSpec, InputSpec, ProgressCb, SourcePlugin +from ...core.redact import redact_spec +from ...core.schema import Item +from .adapter import comment_to_core, submission_to_item +from .client import build_reddit + + +logger = logging.getLogger(__name__) + + +_SUBREDDIT_RE = re.compile(r"^[A-Za-z0-9_]{1,21}$") +_USERNAME_RE = re.compile(r"^[A-Za-z0-9_-]{3,20}$") +_REDDIT_HOSTS = ("reddit.com", "redd.it") +# Hard cap on MoreComments expansions when expand_more=True. Each expansion +# triggers a network round-trip and can pull ~250 extra comments, so an +# unbounded replace_more(limit=None) easily produces minutes of work and +# Reddit-side rate limits on big threads. +_MAX_REPLACE_MORE = 32 + + +def _is_reddit_host(host: str) -> bool: + host = host.lower() + return any(host == h or host.endswith("." 
+ h) for h in _REDDIT_HOSTS) + + +class RedditPlugin(SourcePlugin): + name = "reddit" + label = "Reddit" + secret_keys = ["REDDIT_CLIENT_ID", "REDDIT_CLIENT_SECRET"] + + def input_specs(self) -> list[InputSpec]: + return [ + InputSpec( + kind="subreddit", + label="Сабреддиты", + placeholder="python\nMachineLearning\nhttps://reddit.com/r/marketing", + help="Имя саба, с/без префикса r/, или URL.", + ), + InputSpec( + kind="query", + label="Поисковые запросы", + placeholder="claude code\nremote work", + help="Полнотекстовый поиск по всему Reddit.", + ), + InputSpec( + kind="post_url", + label="Ссылки на посты", + placeholder="https://www.reddit.com/r/python/comments/abc123/title/", + ), + InputSpec( + kind="user", + label="Авторы", + placeholder="spez\n/u/automoderator", + help="Username для трекинга постов конкретного пользователя.", + ), + ] + + def settings_specs(self) -> list[FieldSpec]: + return [ + FieldSpec("listing", "Сортировка постов", "select", "top", + options=["hot", "top", "new", "rising", "controversial"]), + FieldSpec("time_filter", "Период (для top/controversial)", "select", "month", + options=["hour", "day", "week", "month", "year", "all"]), + FieldSpec("max_posts_per_input", "Макс. постов на источник", "number", 25, + min_value=1, max_value=500), + FieldSpec("fetch_comments", "Парсить комментарии", "checkbox", True), + FieldSpec("max_comments_per_post", "Макс. комментариев на пост", "number", 100, + min_value=1, max_value=2000), + FieldSpec("comment_depth", "Глубина комментариев", "select", "top_level", + options=["top_level", "all"], + help="top_level — только верхний уровень; all — все включая ответы."), + FieldSpec("expand_more_comments", "Раскрывать «load more» (медленно)", "checkbox", False), + ] + + # ------------------------------------------------------------------ + # Resolve + + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + """Returns 'kind:value' specs so fetch() routes them through PRAW.""" + specs: list[str] = [] + + for raw in inputs.get("subreddit", []): + name = self._normalize_subreddit(raw) + specs.append(f"subreddit:{name}") + + for q in inputs.get("query", []): + q = q.strip() + if q: + specs.append(f"query:{q}") + + for url in inputs.get("post_url", []): + url = url.strip() + if not url: + continue + if not self._is_reddit_post_url(url): + raise PluginError( + f"{url!r} doesn't look like a Reddit post URL " + "(expected /r/<sub>/comments/<id>/)." 
+ ) + specs.append(f"post_url:{url}") + + for u in inputs.get("user", []): + name = self._normalize_user(u) + specs.append(f"user:{name}") + + return list(dict.fromkeys(specs)) + + # ------------------------------------------------------------------ + # Fetch + + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + if not secrets.get("REDDIT_CLIENT_ID") or not secrets.get("REDDIT_CLIENT_SECRET"): + raise AuthError("REDDIT_CLIENT_ID and REDDIT_CLIENT_SECRET are required") + + reddit = build_reddit(secrets) + + listing = str(settings.get("listing", "top")) + time_filter = str(settings.get("time_filter", "month")) + max_posts = int(settings.get("max_posts_per_input", 25)) + fetch_comments = bool(settings.get("fetch_comments", True)) + max_comments = int(settings.get("max_comments_per_post", 100)) + depth = str(settings.get("comment_depth", "top_level")) + expand_more = bool(settings.get("expand_more_comments", False)) + + # Collect all submissions across specs, then yield with comments attached. + submissions: list[Any] = [] + for spec in item_ids: + kind, _, value = spec.partition(":") + try: + submissions.extend(self._collect_submissions( + reddit, kind, value, listing, time_filter, max_posts + )) + except Exception as e: + raise PluginError( + f"Reddit error for {redact_spec(spec)!r}: {e}" + ) from e + + # Dedupe by submission id (same post can come from multiple inputs). + seen: set[str] = set() + unique_subs: list[Any] = [] + for s in submissions: + sid = str(getattr(s, "id", "") or "") + if sid and sid not in seen: + seen.add(sid) + unique_subs.append(s) + + total = len(unique_subs) + for i, sub in enumerate(unique_subs, 1): + item = submission_to_item(sub) + + if fetch_comments: + try: + item.comments = self._collect_comments( + sub, max_comments=max_comments, depth=depth, expand_more=expand_more + ) + except Exception as e: + item.extra["comments_error"] = str(e) + + if progress: + progress(i, total, item.item_id) + yield item + + # ------------------------------------------------------------------ + # Helpers — submission collection per input kind + + def _collect_submissions( + self, + reddit: Any, + kind: str, + value: str, + listing: str, + time_filter: str, + limit: int, + ) -> Iterable[Any]: + if kind == "subreddit": + sub = reddit.subreddit(value) + return self._listing_iter(sub, listing, time_filter, limit) + + if kind == "query": + sub = reddit.subreddit("all") + return list(sub.search(value, sort=self._search_sort(listing), time_filter=time_filter, limit=limit)) + + if kind == "user": + redditor = reddit.redditor(value) + return self._listing_iter(redditor.submissions, listing, time_filter, limit, is_user=True) + + if kind == "post_url": + return [reddit.submission(url=value)] + + raise PluginError(f"Unknown Reddit input kind: {kind!r}") + + @staticmethod + def _listing_iter( + source: Any, listing: str, time_filter: str, limit: int, is_user: bool = False + ) -> list[Any]: + # User submissions has slightly different listing methods. + if listing == "top": + return list(source.top(time_filter=time_filter, limit=limit)) + if listing == "controversial": + return list(source.controversial(time_filter=time_filter, limit=limit)) + if listing == "new": + return list(source.new(limit=limit)) + if listing == "rising": + if is_user: + logger.info( + "Reddit user submissions don't expose 'rising'; using 'new' instead." 
+ ) + return list(source.new(limit=limit)) + return list(source.rising(limit=limit)) + # default 'hot' + return list(source.hot(limit=limit)) + + @staticmethod + def _search_sort(listing: str) -> str: + # Reddit search sort: relevance, hot, top, new, comments + if listing in ("top", "new", "hot", "comments"): + return listing + return "relevance" + + # ------------------------------------------------------------------ + # Comment flattening + + @staticmethod + def _collect_comments( + sub: Any, *, max_comments: int, depth: str, expand_more: bool + ) -> list: + # When expand_more=True, cap at _MAX_REPLACE_MORE expansions instead of + # unbounded — see the constant for the rationale. + sub.comments.replace_more(limit=_MAX_REPLACE_MORE if expand_more else 0) + + out: list = [] + + def _walk(comments_iter, parent_id: str | None, only_top: bool) -> None: + for c in comments_iter: + if len(out) >= max_comments: + return + # MoreComments may slip through if expand_more=False and replace_more left some + if c.__class__.__name__ == "MoreComments": + continue + out.append(comment_to_core(c, parent_id)) + if only_top: + continue + replies = getattr(c, "replies", None) + if replies: + _walk(replies, str(getattr(c, "id", "") or ""), only_top=False) + + if depth == "top_level": + _walk(sub.comments, parent_id=None, only_top=True) + else: + _walk(sub.comments, parent_id=None, only_top=False) + + return out + + # ------------------------------------------------------------------ + # Input normalization + + @classmethod + def _normalize_subreddit(cls, raw: str) -> str: + v = raw.strip() + if v.startswith("http"): + parsed = urlparse(v) + if not _is_reddit_host(parsed.hostname or ""): + raise PluginError( + f"{raw!r} is not a Reddit URL (host must be reddit.com or redd.it)." + ) + parts = [p for p in parsed.path.split("/") if p] + if len(parts) >= 2 and parts[0].lower() == "r": + v = parts[1] + else: + raise PluginError(f"Cannot parse subreddit URL: {raw!r}") + if v.lower().startswith("r/"): + v = v[2:] + v = v.strip("/") + if not _SUBREDDIT_RE.match(v): + raise PluginError( + f"{raw!r} is not a valid subreddit name " + "(letters, digits, underscore; 1-21 chars)." + ) + return v + + @classmethod + def _normalize_user(cls, raw: str) -> str: + v = raw.strip() + if v.startswith("http"): + parsed = urlparse(v) + if not _is_reddit_host(parsed.hostname or ""): + raise PluginError( + f"{raw!r} is not a Reddit URL (host must be reddit.com or redd.it)." + ) + parts = [p for p in parsed.path.split("/") if p] + if len(parts) >= 2 and parts[0].lower() in ("u", "user"): + v = parts[1] + else: + raise PluginError(f"Cannot parse user URL: {raw!r}") + # Note: the longer prefixes ('/user/', '/u/') must be checked before the + # shorter ones ('user/', 'u/'), and '@' last, so we don't strip too little. + for prefix in ("/user/", "/u/", "user/", "u/", "@"): + if v.lower().startswith(prefix): + v = v[len(prefix):] + break + v = v.strip("/") + if not _USERNAME_RE.match(v): + raise PluginError( + f"{raw!r} is not a valid Reddit username " + "(letters, digits, _ or -; 3-20 chars)." 
+ ) + return v + + @staticmethod + def _is_reddit_post_url(url: str) -> bool: + try: + parsed = urlparse(url) + parts = [p for p in parsed.path.split("/") if p] + except Exception: + return False + if not _is_reddit_host(parsed.hostname or ""): + return False + # Expected: /r/<sub>/comments/<id>/<slug>/ + return len(parts) >= 4 and parts[0].lower() == "r" and parts[2].lower() == "comments" diff --git a/content_parser/plugins/telegram/__init__.py b/content_parser/plugins/telegram/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/telegram/adapter.py b/content_parser/plugins/telegram/adapter.py new file mode 100644 index 0000000..321201c --- /dev/null +++ b/content_parser/plugins/telegram/adapter.py @@ -0,0 +1,222 @@ +"""Convert Apify Telegram-scraper output dicts into the unified core schema. + +Field names vary between scraper actors, so the adapter looks up each +logical field through a list of likely keys (`_pick`) and falls back to None. +""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ...core.schema import Comment, Item + + +def _pick(d: dict, *keys: str, default: Any = None) -> Any: + for k in keys: + v = d.get(k) + if v not in (None, ""): + return v + return default + + +def _iso(value: Any) -> str | None: + """Accepts ISO string, UNIX timestamp, or None.""" + if value is None: + return None + if isinstance(value, str): + return value + try: + return datetime.fromtimestamp(float(value), tz=timezone.utc).isoformat() + except (TypeError, ValueError): + return None + + +def _reactions_total(reactions: Any) -> int | None: + """reactions can be: + [{"emoji": "👍", "count": 5}, ...] → sum counts + {"👍": 5, "❤️": 2} → sum values + an int → as-is + """ + if reactions is None: + return None + if isinstance(reactions, int): + return reactions + if isinstance(reactions, list): + total = 0 + for r in reactions: + if isinstance(r, dict): + try: + total += int(r.get("count", 0) or 0) + except (TypeError, ValueError): + continue + return total or None + if isinstance(reactions, dict): + try: + return sum(int(v or 0) for v in reactions.values()) + except (TypeError, ValueError): + return None + return None + + +def _to_int(value: Any) -> int | None: + """Coerce a numeric value to int. Returns None for missing/non-numeric values + instead of letting a stray dict (e.g. {'count': 100}) leak into media.""" + if value is None or isinstance(value, bool): + return None + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + try: + return int(value) + except ValueError: + return None + return None + + +def _replies_count(msg: dict) -> int | None: + """Telegram scrapers represent replies inconsistently: + replies: 42 → number + replies: [...] 
→ list of comment dicts → use length + repliesCount / replies_count → number under different keys + """ + raw = msg.get("replies") + if isinstance(raw, int): + return raw + if isinstance(raw, list): + return len(raw) + return _to_int(_pick(msg, "repliesCount", "replies_count", "commentsCount")) + + +def _channel_label(msg: dict) -> tuple[str | None, str | None]: + """Extract (display name, channel username) from a message dict.""" + title = _pick(msg, "channelTitle", "channel_title", "chatTitle", "chat_title") + username = _pick(msg, "channelUsername", "channel_username", "chatUsername", "chat_username") + if not title and not username: + # nested 'channel' or 'chat' subobject + for k in ("channel", "chat"): + sub = msg.get(k) + if isinstance(sub, dict): + title = title or _pick(sub, "title", "name") + username = username or _pick(sub, "username", "screen_name") + return title, username + + +def message_to_item(msg: dict) -> Item: + """Convert a Telegram message dict (from any common Apify actor) to core.Item.""" + msg_id = _pick(msg, "id", "messageId", "message_id") + if msg_id is None: + raise ValueError( + f"Malformed Telegram message: missing id (got keys {sorted(msg.keys())[:8]})" + ) + + title_label, username = _channel_label(msg) + url = _pick(msg, "url", "messageUrl", "message_url") + if not url and username and msg_id is not None: + url = f"https://t.me/{username}/{msg_id}" + + item_id = f"{username}_{msg_id}" if username else str(msg_id) + + text = _pick(msg, "text", "message", "content", default="") + + reactions_count = _reactions_total(_pick(msg, "reactions", "reactions_count")) + raw_media = msg.get("media") + media_obj = raw_media if isinstance(raw_media, dict) else None + media_type = _pick(msg, "mediaType", "media_type") or ( + media_obj.get("type") if media_obj else None + ) + + media: dict = { + "views_count": _to_int(_pick(msg, "views", "viewCount", "view_count")), + "forwards_count": _to_int(_pick(msg, "forwards", "forwardCount", "forward_count")), + "comments_count": _replies_count(msg), + "reactions_count": reactions_count, + "media_type": media_type, + "is_pinned": bool(_pick(msg, "isPinned", "is_pinned", default=False)), + } + # Drop None/empty/False but keep zero counts (0 is a meaningful signal). + media = {k: v for k, v in media.items() if v is not None and v != "" and v is not False} + + title = text.split("\n", 1)[0][:120].strip() if text else None + + return Item( + source="telegram", + item_id=item_id, + url=url or "", + title=title, + author=title_label or username, + author_id=username, + published_at=_iso(_pick(msg, "date", "timestamp", "publishedAt", "published_at")), + text=text or None, + media=media, + comments=_extract_comments(msg, parent_url=url), + extra={ + "channel_id": _pick(msg, "channelId", "channel_id", "chatId", "chat_id"), + "is_forwarded": bool(_pick(msg, "fwdFromId", "forwarded_from", default=False)), + "has_media": bool(msg.get("media")), + }, + ) + + +def _extract_comments(msg: dict, *, parent_url: str | None = None) -> list[Comment]: + raw = _pick(msg, "replies_data", "comments", "discussion", "thread", default=None) + if raw is None: + # Plain 'replies' might also hold the list (some scrapers). 
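+        # e.g. replies=42 is a bare count (no thread data), while
+        # replies=[{"id": 1, "text": "..."}, ...] carries the actual comment dicts.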
+ replies_field = msg.get("replies") + if isinstance(replies_field, list): + raw = replies_field + else: + return [] + if isinstance(raw, int): + return [] + if isinstance(raw, dict) and "items" in raw: + raw = raw["items"] + if not isinstance(raw, list): + return [] + + # Two-pass: collect known comment IDs first, then assign parent_id from + # reply_to_message_id when the parent is in the same fetched batch. + flat: list[dict] = [c for c in raw if isinstance(c, dict)] + known_ids: set[str] = set() + for c in flat: + cid = _pick(c, "id", "messageId", "message_id", "comment_id") + if cid is not None: + known_ids.add(str(cid)) + + out: list[Comment] = [] + for c in flat: + reply_to = _pick(c, "reply_to_message_id", "replyToMessageId", "reply_to_msg_id") + parent_id = str(reply_to) if reply_to is not None and str(reply_to) in known_ids else None + out.append(_comment_from_dict(c, parent_id=parent_id)) + return out + + +def _comment_from_dict(c: dict, *, parent_id: str | None) -> Comment: + cid = _pick(c, "id", "messageId", "message_id", "comment_id") + sender = c.get("from") or c.get("sender") or {} + author_name = ( + _pick(c, "authorName", "author_name", "fromName") + or (isinstance(sender, dict) and (sender.get("name") or + ((sender.get("first_name") or "") + " " + (sender.get("last_name") or "")).strip())) + or (isinstance(sender, dict) and sender.get("username")) + or _pick(c, "username", "fromUsername") + or None + ) + if author_name == "": + author_name = None + author_id = ( + _pick(c, "authorUsername", "author_username", "fromUsername") + or (isinstance(sender, dict) and sender.get("username")) + or None + ) + + return Comment( + comment_id=str(cid if cid is not None else ""), + parent_id=parent_id, + author=author_name, + author_id=author_id, + text=_pick(c, "text", "message", "content"), + like_count=int(_reactions_total(_pick(c, "reactions", "reactions_count")) or 0), + published_at=_iso(_pick(c, "date", "timestamp", "publishedAt", "published_at")), + ) diff --git a/content_parser/plugins/telegram/plugin.py b/content_parser/plugins/telegram/plugin.py new file mode 100644 index 0000000..92dd42f --- /dev/null +++ b/content_parser/plugins/telegram/plugin.py @@ -0,0 +1,278 @@ +"""Telegram plugin — public channels and posts via Apify scrapers (read-only). + +Uses an Apify actor (configurable, default ``apify/telegram-channel-scraper``) +because Telegram's Bot API can't read other people's groups and the user-level +MTProto API needs a phone number plus exposes the user's account to bans. +""" +from __future__ import annotations + +import re +from typing import Any, Iterator +from urllib.parse import urlparse + +from ...clients.apify import ApifyClient, ApifyError +from ...core.errors import AuthError, PluginError +from ...core.plugin import FieldSpec, InputSpec, ProgressCb, SourcePlugin +from ...core.redact import redact_spec +from ...core.schema import Item +from .adapter import message_to_item + + +_TG_HOSTS = ("t.me", "telegram.me") +_USERNAME_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]{4,31}$") +# Apify actor IDs are <username>/<actor> or <username>~<actor>. +_ACTOR_ID_RE = re.compile(r"^[A-Za-z0-9_-]+[/~][A-Za-z0-9_.-]+$") +_RESERVED_PATHS = { + "joinchat", "addstickers", "share", "iv", "proxy", "socks", "addtheme", + "login", "setlanguage", "addlist", +} + + +def _is_tg_host(host: str) -> bool: + host = (host or "").lower() + return any(host == h or host.endswith("." 
+ h) for h in _TG_HOSTS) + + +class TelegramPlugin(SourcePlugin): + name = "telegram" + label = "Telegram" + secret_keys = ["APIFY_API_TOKEN"] + + def input_specs(self) -> list[InputSpec]: + return [ + InputSpec( + kind="channel", + label="Каналы", + placeholder="durov\n@telegram\nhttps://t.me/somechan", + help="Username, @handle или URL канала. Парсит последние посты.", + ), + InputSpec( + kind="post_url", + label="Ссылки на посты", + placeholder="https://t.me/durov/123", + help="Конкретный пост — забираем его + комментарии (если у канала есть привязанный чат).", + ), + ] + + def settings_specs(self) -> list[FieldSpec]: + return [ + FieldSpec( + "actor_id", "Apify actor", "text", + default="apify/telegram-channel-scraper", + help="ID скрейпер-актора в Apify. По умолчанию apify/telegram-channel-scraper. " + "Можешь поставить другой совместимый, например 73code/telegram-scraper.", + ), + FieldSpec("max_messages_per_channel", "Макс. постов с канала", + "number", 50, min_value=1, max_value=1000), + FieldSpec("fetch_comments", "Парсить комментарии", "checkbox", True), + FieldSpec("max_comments_per_post", "Макс. комментариев на пост", + "number", 100, min_value=1, max_value=1000), + FieldSpec("transcribe_videos", "🎤 Транскрибировать видео (Whisper)", "checkbox", False, + help="Скачивает аудио + шлёт в OpenAI Whisper. " + "Нужен OPENAI_API_KEY и ffmpeg. ~$0.006/мин."), + FieldSpec("max_audio_seconds_per_video", "Макс. секунд аудио на пост", + "number", 600, min_value=10, max_value=3600), + ] + + # ------------------------------------------------------------------ + # Resolve + + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + specs: list[str] = [] + + for c in inputs.get("channel", []): + specs.append(f"channel:{self._normalize_channel(c)}") + + for u in inputs.get("post_url", []): + url = u.strip() + if not url: + continue + if self._is_private_channel_url(url): + raise PluginError( + f"{url!r} is a private-channel URL (path /c/<chat_id>/...). " + "Apify scrapers can only read public channels." + ) + normalized = self._extract_post_url(url) + if not normalized: + raise PluginError( + f"{url!r} doesn't look like a Telegram post URL " + "(expected t.me/<channel>/<message_id>)." + ) + specs.append(f"post:{normalized}") + + return list(dict.fromkeys(specs)) + + # ------------------------------------------------------------------ + # Fetch + + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + token = secrets.get("APIFY_API_TOKEN") + if not token: + raise AuthError("APIFY_API_TOKEN is required") + client = ApifyClient(token) + + actor_id = str(settings.get("actor_id") or "").strip() or "apify/telegram-channel-scraper" + if not _ACTOR_ID_RE.match(actor_id): + raise PluginError( + f"Invalid actor_id {actor_id!r}. Expected 'username/actor' or 'username~actor'." + ) + max_messages = int(settings.get("max_messages_per_channel", 50)) + fetch_comments = bool(settings.get("fetch_comments", True)) + max_comments = int(settings.get("max_comments_per_post", 100)) + + # Group inputs by kind so we make one actor call per kind. 
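+        # resolve() produced specs like "channel:durov" and "post:https://t.me/durov/123".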
+ channels: list[str] = [] + post_urls: list[str] = [] + for spec in item_ids: + kind, _, value = spec.partition(":") + if kind == "channel": + channels.append(f"https://t.me/{value}") + elif kind == "post": + post_urls.append(value) + + all_messages: list[dict] = [] + + if channels: + try: + all_messages.extend( + client.run_actor(actor_id, { + "urls": channels, + "channels": channels, + "directUrls": channels, + "maxItems": max_messages, + "messagesPerChannel": max_messages, + "maxMessages": max_messages, + "fetchComments": fetch_comments, + "extractComments": fetch_comments, + "commentsLimit": max_comments, + "maxComments": max_comments, + }) + ) + except ApifyError as e: + raise PluginError(f"Apify call failed (channels): {e}") from e + + if post_urls: + try: + all_messages.extend( + client.run_actor(actor_id, { + "urls": post_urls, + "directUrls": post_urls, + "messageUrls": post_urls, + "maxItems": len(post_urls), + "fetchComments": fetch_comments, + "extractComments": fetch_comments, + "commentsLimit": max_comments, + "maxComments": max_comments, + }) + ) + except ApifyError as e: + raise PluginError(f"Apify call failed (posts): {e}") from e + + # Single-pass parse + dedupe by item_id (no double work for big result sets). + seen: set[str] = set() + items: list[Item] = [] + for i, msg in enumerate(all_messages, 1): + try: + item = message_to_item(msg) + except Exception as e: + items.append(Item( + source="telegram", + item_id=str(msg.get("id") or f"unknown_{i}"), + url=str(msg.get("url") or ""), + extra={"adapter_error": str(e), "raw": msg}, + )) + continue + if item.item_id in seen: + continue + seen.add(item.item_id) + items.append(item) + + total = len(items) + for i, item in enumerate(items, 1): + # Cap comments to settings even if the actor returned more. + if item.comments and len(item.comments) > max_comments: + item.comments = item.comments[:max_comments] + + from ...transcription.runner import maybe_transcribe # noqa: PLC0415 + maybe_transcribe(item, settings, secrets) + + if progress: + progress(i, total, item.item_id) + yield item + + # ------------------------------------------------------------------ + # Input normalization + + @classmethod + def _normalize_channel(cls, raw: str) -> str: + v = raw.strip() + if not v: + raise PluginError("Empty channel value.") + + if v.startswith("http"): + parsed = urlparse(v) + if not _is_tg_host(parsed.hostname or ""): + raise PluginError( + f"{raw!r} is not a Telegram URL (host must be t.me or telegram.me)." + ) + parts = [p for p in parsed.path.split("/") if p] + if not parts: + raise PluginError(f"Cannot parse Telegram URL: {raw!r}") + if parts[0].lower() in _RESERVED_PATHS: + raise PluginError( + f"{raw!r} is a Telegram reserved path, not a channel." + ) + v = parts[0] + + if v.startswith("@"): + v = v[1:] + + if not _USERNAME_RE.match(v): + raise PluginError( + f"{raw!r} is not a valid Telegram username " + "(must start with a letter, 5-32 chars: letters, digits, underscore)." 
+ ) + return v + + @classmethod + def _is_private_channel_url(cls, url: str) -> bool: + """t.me/c/<chat_id>/<msg_id> is the private-channel URL form.""" + v = url.strip() + if not v.startswith("http"): + return False + parsed = urlparse(v) + if not _is_tg_host(parsed.hostname or ""): + return False + parts = [p for p in parsed.path.split("/") if p] + return len(parts) >= 2 and parts[0].lower() == "c" + + @classmethod + def _extract_post_url(cls, url: str) -> str | None: + """Validate a t.me/<channel>/<msg_id> URL and return the canonical https form.""" + v = url.strip() + if not v.startswith("http"): + return None + parsed = urlparse(v) + if not _is_tg_host(parsed.hostname or ""): + return None + parts = [p for p in parsed.path.split("/") if p] + # 'c/<chatid>/<msgid>' pattern is private channels — skip. + if len(parts) < 2 or parts[0].lower() in _RESERVED_PATHS: + return None + if parts[0] == "c": + return None + # Last segment must be a numeric message id. + if not parts[-1].isdigit(): + return None + return f"https://t.me/{parts[0]}/{parts[-1]}" diff --git a/content_parser/plugins/vk/__init__.py b/content_parser/plugins/vk/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/vk/adapter.py b/content_parser/plugins/vk/adapter.py new file mode 100644 index 0000000..6ff4b0b --- /dev/null +++ b/content_parser/plugins/vk/adapter.py @@ -0,0 +1,142 @@ +"""Convert VK API dicts into the unified core schema.""" +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from ...core.schema import Comment, Item + + +def _iso(ts: int | float | None) -> str | None: + if ts is None: + return None + return datetime.fromtimestamp(float(ts), tz=timezone.utc).isoformat() + + +def _label_for_id( + actor_id: int, + profiles_by_id: dict[int, dict], + groups_by_id: dict[int, dict], +) -> tuple[str | None, str | None]: + """Return (display_name, full_id_string) for a VK from_id. + + VK IDs are positive for users, negative for groups. + """ + if actor_id is None: + return None, None + if actor_id > 0: + prof = profiles_by_id.get(actor_id) or {} + first = prof.get("first_name") or "" + last = prof.get("last_name") or "" + name = (first + " " + last).strip() or prof.get("screen_name") or f"id{actor_id}" + return name, f"id{actor_id}" + if actor_id < 0: + group = groups_by_id.get(-actor_id) or {} + name = group.get("name") or group.get("screen_name") or f"club{-actor_id}" + return name, f"club{-actor_id}" + return None, None + + +def post_to_item( + post: dict, + *, + owner_label: str | None = None, + profiles_by_id: dict[int, dict] | None = None, + groups_by_id: dict[int, dict] | None = None, +) -> Item: + """Convert a VK wall post dict to core.Item. + + `owner_label` is the human-readable community/user name; if not provided, + we look it up via groups_by_id / profiles_by_id from the same response. 
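+
+    Both lookup dicts are keyed by positive integer ids (see index_by_id) and are
+    typically filled from a wall.get / wall.getById call made with extended=1.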
+ """ + profiles_by_id = profiles_by_id or {} + groups_by_id = groups_by_id or {} + + if post.get("owner_id") is None or post.get("id") is None: + raise ValueError( + f"Malformed VK post: missing owner_id or id (got keys {sorted(post.keys())[:8]})" + ) + owner_id = int(post["owner_id"]) + post_id = int(post["id"]) + item_id = f"{owner_id}_{post_id}" + + if owner_label is None: + owner_label, _ = _label_for_id(owner_id, profiles_by_id, groups_by_id) + + text = post.get("text") or "" + title = text.split("\n", 1)[0][:120].strip() if text else None + + likes = (post.get("likes") or {}).get("count") + reposts = (post.get("reposts") or {}).get("count") + views = (post.get("views") or {}).get("count") + comments_count = (post.get("comments") or {}).get("count") + + attachments = post.get("attachments") or [] + attachment_types = [a.get("type") for a in attachments if isinstance(a, dict)] + + media: dict = { + "views_count": views, + "likes_count": likes, + "reposts_count": reposts, + "comments_count": comments_count, + "has_photo": "photo" in attachment_types, + "has_video": "video" in attachment_types, + "has_link": "link" in attachment_types, + "has_poll": "poll" in attachment_types, + "is_pinned": bool(post.get("is_pinned")), + "marked_as_ads": bool(post.get("marked_as_ads")), + "post_type": post.get("post_type"), + } + media = {k: v for k, v in media.items() if v not in (None, False, "")} + + return Item( + source="vk", + item_id=item_id, + url=f"https://vk.com/wall{item_id}", + title=title, + author=owner_label, + author_id=str(owner_id) if owner_id else None, + published_at=_iso(post.get("date")), + text=text or None, + media=media, + extra={ + "attachment_types": attachment_types, + "signer_id": post.get("signer_id"), + "copy_history": bool(post.get("copy_history")), + }, + ) + + +def comment_to_core( + c: dict, + *, + parent_id: str | None, + profiles_by_id: dict[int, dict] | None = None, + groups_by_id: dict[int, dict] | None = None, +) -> Comment: + profiles_by_id = profiles_by_id or {} + groups_by_id = groups_by_id or {} + + from_id = int(c.get("from_id", 0)) + author_name, author_label = _label_for_id(from_id, profiles_by_id, groups_by_id) + + return Comment( + comment_id=str(c.get("id", "") or ""), + parent_id=parent_id, + author=author_name, + author_id=author_label, + text=c.get("text"), + like_count=int((c.get("likes") or {}).get("count", 0) or 0), + published_at=_iso(c.get("date")), + ) + + +def index_by_id(items: list[dict], key: str = "id") -> dict[int, dict]: + """Build a dict keyed by an integer id field.""" + out: dict[int, dict] = {} + for it in items or []: + try: + out[int(it[key])] = it + except (KeyError, TypeError, ValueError): + continue + return out diff --git a/content_parser/plugins/vk/client.py b/content_parser/plugins/vk/client.py new file mode 100644 index 0000000..780adca --- /dev/null +++ b/content_parser/plugins/vk/client.py @@ -0,0 +1,103 @@ +"""Thin client over the VK API (api.vk.com). + +The token travels in the POST body — never as a query string — so it doesn't +leak into nginx access logs or Streamlit's URL bar history. +""" +from __future__ import annotations + +import time +from typing import Any + +import requests + +from ...core.errors import AuthError, PluginError, RateLimitError + + +VK_API_URL = "https://api.vk.com/method/{method}" +VK_API_VERSION = "5.199" + + +# VK error codes that mean the request will keep failing without intervention. +_AUTH_ERROR_CODES = {5, 17, 27, 28} # bad/expired token, blocked, etc. 
+_RATE_LIMIT_CODES = {6, 9, 29} # too many requests / per-second flood + + +class VKClient: + def __init__( + self, + token: str, + timeout: int = 60, + version: str = VK_API_VERSION, + max_rate_limit_retries: int = 3, + ): + if not token: + raise ValueError("VK access token is required") + self.token = token + self.timeout = timeout + self.version = version + self.max_rate_limit_retries = max_rate_limit_retries + # Reuse one connection across requests to avoid TLS handshake per call. + self.session = requests.Session() + + @staticmethod + def _sleep(seconds: float) -> None: + time.sleep(seconds) + + def call(self, method: str, **params: Any) -> Any: + """Call a VK API method with retries on rate-limit errors. + + Returns the unwrapped 'response' field. Raises AuthError on bad + token, RateLimitError after exhausting retries, PluginError on + other API or transport errors. + """ + delay = 1.0 + last_rate_limit: RateLimitError | None = None + for attempt in range(self.max_rate_limit_retries + 1): + try: + return self._call_once(method, params) + except RateLimitError as e: + last_rate_limit = e + if attempt >= self.max_rate_limit_retries: + raise + self._sleep(delay) + delay *= 2 + # Defensive — only reachable if max_rate_limit_retries < 0. + raise last_rate_limit # type: ignore[misc] + + def _call_once(self, method: str, params: dict[str, Any]) -> Any: + body = {"access_token": self.token, "v": self.version} + for k, v in params.items(): + if v is None: + continue + if isinstance(v, bool): + body[k] = "1" if v else "0" + elif isinstance(v, (list, tuple)): + body[k] = ",".join(str(x) for x in v) + else: + body[k] = str(v) + + url = VK_API_URL.format(method=method) + try: + r = self.session.post(url, data=body, timeout=self.timeout) + except requests.RequestException as e: + raise PluginError(f"Network error calling VK {method}: {e}") from e + + if not r.ok: + raise PluginError(f"VK {method} returned HTTP {r.status_code}") + + try: + data = r.json() + except ValueError as e: + raise PluginError(f"VK {method} returned non-JSON: {r.text[:200]}") from e + + if "error" in data: + err = data["error"] + code = err.get("error_code") + msg = err.get("error_msg", "unknown") + if code in _AUTH_ERROR_CODES: + raise AuthError(f"VK auth error ({code}): {msg}") + if code in _RATE_LIMIT_CODES: + raise RateLimitError(f"VK rate limit ({code}): {msg}") + raise PluginError(f"VK {method} error ({code}): {msg}") + + return data.get("response") diff --git a/content_parser/plugins/vk/plugin.py b/content_parser/plugins/vk/plugin.py new file mode 100644 index 0000000..6b8eedf --- /dev/null +++ b/content_parser/plugins/vk/plugin.py @@ -0,0 +1,438 @@ +"""VK plugin — communities, walls, and comments via VK API (read-only).""" +from __future__ import annotations + +import logging +import re +from typing import Any, Iterator +from urllib.parse import urlparse + +from ...core.errors import AuthError, PluginError +from ...core.plugin import FieldSpec, InputSpec, ProgressCb, SourcePlugin +from ...core.redact import redact_spec +from ...core.schema import Item +from .adapter import ( + comment_to_core, + index_by_id, + post_to_item, +) +from .client import VKClient + + +logger = logging.getLogger(__name__) + + +_VK_HOSTS = ("vk.com", "vk.ru", "m.vk.com") +_SCREEN_NAME_RE = re.compile(r"^[A-Za-z0-9_.]{1,32}$") +# Wall post URL: vk.com/wall<owner>_<post> (owner can be -group or +user) +_WALL_RE = re.compile(r"^wall(-?\d+)_(\d+)$") +# Common VK reserved-namespace prefixes that aren't communities/users. 
+_RESERVED_PATH_PREFIXES = { + "wall", "feed", "video", "audio", "doc", "photo", "id", "club", "public", + "event", "topic", "albums", "market", "im", "friends", "settings", + "search", "groups", "apps", "stickers", "support", "dev", +} + + +def _is_vk_host(host: str) -> bool: + host = (host or "").lower() + return any(host == h or host.endswith("." + h) for h in _VK_HOSTS) + + +class VKPlugin(SourcePlugin): + name = "vk" + label = "ВКонтакте" + secret_keys = ["VK_ACCESS_TOKEN"] + + def input_specs(self) -> list[InputSpec]: + return [ + InputSpec( + kind="query", + label="Поиск сообществ", + placeholder="маркетинг\nрилсы для бизнеса", + help="Полнотекстовый поиск по сообществам ВК. Парсит стены найденных групп.", + ), + InputSpec( + kind="community", + label="Сообщества", + placeholder="durov_says\nhttps://vk.com/club1\n12345", + help="Screen name, club-ID или URL сообщества.", + ), + InputSpec( + kind="post_url", + label="Ссылки на посты", + placeholder="https://vk.com/wall-12345_678", + ), + ] + + def settings_specs(self) -> list[FieldSpec]: + return [ + FieldSpec("max_communities_per_query", "Макс. сообществ на поисковый запрос", + "number", 10, min_value=1, max_value=100), + FieldSpec("max_posts_per_input", "Макс. постов с сообщества/из поиска", + "number", 25, min_value=1, max_value=100, + help="Лимит VK API на wall.get — 100 постов за вызов."), + FieldSpec("fetch_comments", "Парсить комментарии", "checkbox", True), + FieldSpec("max_comments_per_post", "Макс. комментариев на пост", + "number", 100, min_value=1, max_value=1000), + FieldSpec("comment_depth", "Глубина комментариев", "select", "top_level", + options=["top_level", "all"], + help="top_level — только верхний уровень; all — со всеми ответами."), + FieldSpec("transcribe_videos", "🎤 Транскрибировать видео (Whisper)", "checkbox", False, + help="Скачивает аудио + шлёт в OpenAI Whisper. " + "Нужен OPENAI_API_KEY и ffmpeg. ~$0.006/мин."), + FieldSpec("max_audio_seconds_per_video", "Макс. секунд аудио на пост", + "number", 600, min_value=10, max_value=3600), + ] + + # ------------------------------------------------------------------ + # Resolve + + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + """Returns 'kind:value' specs so fetch() routes them through VK API.""" + specs: list[str] = [] + + for q in inputs.get("query", []): + q = q.strip() + if q: + specs.append(f"query:{q}") + + for c in inputs.get("community", []): + specs.append(f"community:{self._normalize_community(c)}") + + for url in inputs.get("post_url", []): + url = url.strip() + if not url: + continue + wall_id = self._extract_wall_id(url) + if not wall_id: + raise PluginError( + f"{url!r} doesn't look like a VK post URL " + "(expected vk.com/wall-<group_id>_<post_id>)." 
+ ) + specs.append(f"post:{wall_id}") + + return list(dict.fromkeys(specs)) + + # ------------------------------------------------------------------ + # Fetch + + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + token = secrets.get("VK_ACCESS_TOKEN") + if not token: + raise AuthError("VK_ACCESS_TOKEN is required") + client = VKClient(token) + + max_communities = int(settings.get("max_communities_per_query", 10)) + max_posts = min(int(settings.get("max_posts_per_input", 25)), 100) + fetch_comments = bool(settings.get("fetch_comments", True)) + max_comments = int(settings.get("max_comments_per_post", 100)) + depth = str(settings.get("comment_depth", "top_level")) + + # Step 1: resolve all specs into a list of (post_dict, owner_label) + # plus a cache of group/profile dicts for author resolution in comments. + post_jobs: list[tuple[dict, str | None]] = [] # (post, owner_label) + groups_cache: dict[int, dict] = {} + profiles_cache: dict[int, dict] = {} + + for spec in item_ids: + kind, _, value = spec.partition(":") + try: + self._collect_for_spec( + client, kind, value, + max_communities=max_communities, + max_posts=max_posts, + post_jobs=post_jobs, + groups_cache=groups_cache, + profiles_cache=profiles_cache, + ) + except (AuthError, PluginError): + raise + except Exception as e: + raise PluginError( + f"VK error for {redact_spec(spec)!r}: {e}" + ) from e + + # Dedupe by VK item_id (owner_post) + seen: set[str] = set() + unique_jobs: list[tuple[dict, str | None]] = [] + for post, owner_label in post_jobs: + owner_id = post.get("owner_id") + post_id = post.get("id") + iid = f"{owner_id}_{post_id}" + if iid not in seen: + seen.add(iid) + unique_jobs.append((post, owner_label)) + + # Step 2: yield items, optionally with comments. 
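+        # With fetch_comments on, each post adds one wall.getComments call per page of 100.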
+ total = len(unique_jobs) + for i, (post, owner_label) in enumerate(unique_jobs, 1): + item = post_to_item( + post, + owner_label=owner_label, + profiles_by_id=profiles_cache, + groups_by_id=groups_cache, + ) + if fetch_comments: + try: + item.comments = self._fetch_comments( + client, + owner_id=int(post["owner_id"]), + post_id=int(post["id"]), + max_comments=max_comments, + depth=depth, + ) + except Exception as e: + item.extra["comments_error"] = str(e) + + from ...transcription.runner import maybe_transcribe # noqa: PLC0415 + maybe_transcribe(item, settings, secrets) + + if progress: + progress(i, total, item.item_id) + yield item + + # ------------------------------------------------------------------ + # Per-spec collection + + def _collect_for_spec( + self, + client: VKClient, + kind: str, + value: str, + *, + max_communities: int, + max_posts: int, + post_jobs: list[tuple[dict, str | None]], + groups_cache: dict[int, dict], + profiles_cache: dict[int, dict], + ) -> None: + if kind == "query": + search_resp = client.call("groups.search", q=value, count=max_communities) + groups = (search_resp or {}).get("items", []) + for g in groups: + groups_cache[int(g["id"])] = g + self._collect_wall(client, -int(g["id"]), g.get("name"), max_posts, post_jobs, profiles_cache, groups_cache) + + elif kind == "community": + owner_id, label = self._resolve_community(client, value, groups_cache, profiles_cache) + self._collect_wall(client, owner_id, label, max_posts, post_jobs, profiles_cache, groups_cache) + + elif kind == "post": + posts_resp = client.call("wall.getById", posts=value, extended=1) + items = self._extract_extended(posts_resp, profiles_cache, groups_cache) + # owner_label=None — adapter resolves via shared profile/group caches. + for post in items: + post_jobs.append((post, None)) + + else: + raise PluginError(f"Unknown VK input kind: {kind!r}") + + def _collect_wall( + self, + client: VKClient, + owner_id: int, + owner_label: str | None, + max_posts: int, + post_jobs: list[tuple[dict, str | None]], + profiles_cache: dict[int, dict], + groups_cache: dict[int, dict], + ) -> None: + resp = client.call("wall.get", owner_id=owner_id, count=max_posts, extended=1) + items = self._extract_extended(resp, profiles_cache, groups_cache) + for post in items: + post_jobs.append((post, owner_label)) + + @staticmethod + def _extract_extended( + resp: Any, + profiles_cache: dict[int, dict], + groups_cache: dict[int, dict], + ) -> list[dict]: + """Pull items from a VK extended=1 response and merge profiles/groups into caches.""" + if not isinstance(resp, dict): + return resp or [] + for p in resp.get("profiles") or []: + try: + profiles_cache[int(p["id"])] = p + except (KeyError, TypeError, ValueError): + continue + for g in resp.get("groups") or []: + try: + groups_cache[int(g["id"])] = g + except (KeyError, TypeError, ValueError): + continue + return resp.get("items", []) + + # ------------------------------------------------------------------ + # Comments + + def _fetch_comments( + self, + client: VKClient, + *, + owner_id: int, + post_id: int, + max_comments: int, + depth: str, + ) -> list: + """Paginate wall.getComments respecting `max_comments` exactly. + + VK caps `count` at 100 per request. We check the cap *before* every + append (top-level or reply) so the result is always ≤ max_comments, + and we use the response's `count` field to short-circuit pagination + instead of doing one extra round-trip just to confirm the end. 
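+
+        E.g. with depth="top_level" and max_comments=250 on a busy post, three
+        count=100 requests are made and only part of the third page is appended.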
+ """ + out: list = [] + offset = 0 + page = min(100, max_comments) + thread_count = 0 if depth == "top_level" else 10 + + while len(out) < max_comments: + resp = client.call( + "wall.getComments", + owner_id=owner_id, + post_id=post_id, + offset=offset, + count=page, + need_likes=1, + extended=1, + thread_items_count=thread_count, + sort="asc", + ) + if not isinstance(resp, dict): + break + items = resp.get("items", []) + profiles = index_by_id(resp.get("profiles") or []) + groups = index_by_id(resp.get("groups") or []) + total = int(resp.get("count", 0) or 0) + + for c in items: + if len(out) >= max_comments: + break + out.append( + comment_to_core( + c, parent_id=None, profiles_by_id=profiles, groups_by_id=groups + ) + ) + if depth == "all": + thread = c.get("thread") or {} + parent_id = str(c.get("id", "") or "") + for reply in thread.get("items", []) or []: + if len(out) >= max_comments: + break + out.append( + comment_to_core( + reply, + parent_id=parent_id, + profiles_by_id=profiles, + groups_by_id=groups, + ) + ) + + if not items: + break + offset += len(items) + # Short-circuit: VK told us the total — stop if we've consumed it + # or the page came back smaller than requested (no more data). + if total and offset >= total: + break + if len(items) < page: + break + + return out + + # ------------------------------------------------------------------ + # Community resolution + + def _resolve_community( + self, + client: VKClient, + value: str, + groups_cache: dict[int, dict], + profiles_cache: dict[int, dict], + ) -> tuple[int, str | None]: + """Return (owner_id, label) for a community spec. + + owner_id is negative for groups; positive for users (rare for + community input but handled because VK treats them the same way). + """ + resp = client.call("groups.getById", group_ids=value, fields="name,screen_name") + if isinstance(resp, dict): + items = resp.get("groups") or [] + else: + items = resp or [] + if not items: + raise PluginError(f"VK community {value!r} not found") + g = items[0] + gid = int(g["id"]) + groups_cache[gid] = g + return -gid, g.get("name") + + # ------------------------------------------------------------------ + # Input normalization + + @classmethod + def _normalize_community(cls, raw: str) -> str: + """Return a value usable in groups.getById: screen_name, 'club<id>', or numeric id.""" + v = raw.strip() + if not v: + raise PluginError("Empty community value.") + + if v.startswith("http"): + parsed = urlparse(v) + if not _is_vk_host(parsed.hostname or ""): + raise PluginError( + f"{raw!r} is not a VK URL (host must be vk.com)." + ) + parts = [p for p in parsed.path.split("/") if p] + if not parts: + raise PluginError(f"Cannot parse community URL: {raw!r}") + v = parts[0] + + # 'club12345', 'public12345' — VK community URL prefixes + m = re.match(r"^(?:club|public)(\d+)$", v, re.IGNORECASE) + if m: + return f"club{m.group(1)}" + + if v.isdigit(): + return v + + if v.lower() in _RESERVED_PATH_PREFIXES: + raise PluginError( + f"{raw!r} is a VK reserved path, not a community." + ) + + if not _SCREEN_NAME_RE.match(v): + raise PluginError( + f"{raw!r} is not a valid VK community identifier " + "(letters, digits, underscore, dot; 1-32 chars)." + ) + return v + + @classmethod + def _extract_wall_id(cls, url: str) -> str | None: + """Return 'owner_post' string for a vk.com/wall... 
URL, else None.""" + v = url.strip() + if v.startswith("http"): + parsed = urlparse(v) + if not _is_vk_host(parsed.hostname or ""): + return None + parts = [p for p in parsed.path.split("/") if p] + if not parts: + return None + v = parts[0] + m = _WALL_RE.match(v) + if not m: + return None + return f"{m.group(1)}_{m.group(2)}" diff --git a/content_parser/plugins/youtube/__init__.py b/content_parser/plugins/youtube/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/plugins/youtube/adapter.py b/content_parser/plugins/youtube/adapter.py new file mode 100644 index 0000000..d996161 --- /dev/null +++ b/content_parser/plugins/youtube/adapter.py @@ -0,0 +1,49 @@ +"""Convert YouTube API dicts into the unified core schema.""" +from __future__ import annotations + +from ...core.schema import Comment, Item, Transcript + + +def metadata_to_item(meta: dict) -> Item: + return Item( + source="youtube", + item_id=meta["video_id"], + url=meta.get("url") or f"https://www.youtube.com/watch?v={meta['video_id']}", + title=meta.get("title"), + author=meta.get("channel_title"), + author_id=meta.get("channel_id"), + published_at=meta.get("published_at"), + text=meta.get("description"), + media={ + "duration": meta.get("duration"), + "view_count": meta.get("view_count"), + "like_count": meta.get("like_count"), + "comment_count": meta.get("comment_count"), + }, + extra={"tags": meta.get("tags", [])}, + ) + + +def comment_dict_to_comment(c: dict) -> Comment: + return Comment( + comment_id=c["comment_id"], + parent_id=c.get("parent_id"), + author=c.get("author"), + author_id=c.get("author_channel_id"), + text=c.get("text"), + like_count=c.get("like_count", 0) or 0, + published_at=c.get("published_at"), + updated_at=c.get("updated_at"), + ) + + +def transcript_dict_to_transcript(t: dict | None) -> Transcript | None: + if not t: + return None + return Transcript( + language=t.get("language"), + is_generated=t.get("is_generated"), + segments=t.get("segments", []), + text=t.get("text", ""), + error=t.get("error"), + ) diff --git a/content_parser/plugins/youtube/comments.py b/content_parser/plugins/youtube/comments.py new file mode 100644 index 0000000..a85f892 --- /dev/null +++ b/content_parser/plugins/youtube/comments.py @@ -0,0 +1,122 @@ +"""Fetch comments (and optionally replies) for a video via YouTube Data API.""" +from __future__ import annotations + +from googleapiclient.discovery import Resource +from googleapiclient.errors import HttpError + + +def _format_comment(snippet: dict, comment_id: str, parent_id: str | None = None) -> dict: + return { + "comment_id": comment_id, + "parent_id": parent_id, + "author": snippet.get("authorDisplayName"), + "author_channel_id": (snippet.get("authorChannelId") or {}).get("value"), + "text": snippet.get("textOriginal") or snippet.get("textDisplay"), + "like_count": snippet.get("likeCount", 0), + "published_at": snippet.get("publishedAt"), + "updated_at": snippet.get("updatedAt"), + } + + +def fetch_comments( + youtube: Resource, + video_id: str, + *, + include_replies: bool = False, + max_comments: int | None = None, + order: str = "relevance", +) -> list[dict]: + """Fetch top-level comments. 
If include_replies=True, also pull all replies.""" + comments: list[dict] = [] + page_token: str | None = None + + while True: + try: + response = ( + youtube.commentThreads() + .list( + part="snippet,replies" if include_replies else "snippet", + videoId=video_id, + maxResults=100, + order=order, + pageToken=page_token, + textFormat="plainText", + ) + .execute() + ) + except HttpError as e: + if e.resp.status == 403 and b"commentsDisabled" in e.content: + return [] + raise + + for item in response.get("items", []): + top_snippet = item["snippet"]["topLevelComment"]["snippet"] + top_id = item["snippet"]["topLevelComment"]["id"] + comments.append(_format_comment(top_snippet, top_id)) + + if max_comments and len(comments) >= max_comments: + return comments + + if include_replies: + # Bound the reply pull by remaining cap so a single popular + # top-level comment with hundreds of replies doesn't burn quota + # only for the slice [:max_comments] to throw most of them away. + remaining = (max_comments - len(comments)) if max_comments else None + reply_count = item["snippet"].get("totalReplyCount", 0) + inline_replies = item.get("replies", {}).get("comments", []) + if reply_count and len(inline_replies) < reply_count: + comments.extend( + _fetch_all_replies(youtube, top_id, max_replies=remaining) + ) + else: + for reply in inline_replies: + if remaining is not None and remaining <= 0: + break + comments.append( + _format_comment(reply["snippet"], reply["id"], parent_id=top_id) + ) + if remaining is not None: + remaining -= 1 + if max_comments and len(comments) >= max_comments: + return comments[:max_comments] + + page_token = response.get("nextPageToken") + if not page_token: + break + + return comments + + +def _fetch_all_replies( + youtube: Resource, + parent_id: str, + *, + max_replies: int | None = None, +) -> list[dict]: + replies: list[dict] = [] + page_token: str | None = None + while True: + if max_replies is not None and len(replies) >= max_replies: + return replies[:max_replies] + page_size = 100 if max_replies is None else min(100, max_replies - len(replies)) + if page_size <= 0: + return replies + response = ( + youtube.comments() + .list( + part="snippet", + parentId=parent_id, + maxResults=page_size, + pageToken=page_token, + textFormat="plainText", + ) + .execute() + ) + for item in response.get("items", []): + replies.append(_format_comment(item["snippet"], item["id"], parent_id=parent_id)) + if max_replies is not None and len(replies) >= max_replies: + return replies[:max_replies] + page_token = response.get("nextPageToken") + if not page_token: + break + return replies diff --git a/content_parser/plugins/youtube/plugin.py b/content_parser/plugins/youtube/plugin.py new file mode 100644 index 0000000..b8bfb6b --- /dev/null +++ b/content_parser/plugins/youtube/plugin.py @@ -0,0 +1,159 @@ +"""YouTube plugin — implements SourcePlugin contract.""" +from __future__ import annotations + +from typing import Any, Iterator + +from googleapiclient.discovery import build + +from ...core.plugin import FieldSpec, InputSpec, ProgressCb, SourcePlugin +from ...core.schema import Item +from .adapter import ( + comment_dict_to_comment, + metadata_to_item, + transcript_dict_to_transcript, +) +from .comments import fetch_comments +from .sources import collect_video_ids, fetch_video_metadata +from .transcripts import fetch_transcript_verbose + + +class YouTubePlugin(SourcePlugin): + name = "youtube" + label = "YouTube" + secret_keys = ["YOUTUBE_API_KEY"] + + def input_specs(self) -> list[InputSpec]: + 
return [ + InputSpec( + kind="query", + label="Поисковые запросы", + placeholder="python tutorial\nclaude code demo", + help="Каждый запрос стоит 100 единиц квоты (из 10 000 в день).", + ), + InputSpec( + kind="channel", + label="Каналы", + placeholder="https://youtube.com/@veritasium\n@3blue1brown\nUC...", + help="URL, @handle или channel ID.", + ), + InputSpec( + kind="playlist", + label="Плейлисты", + placeholder="https://youtube.com/playlist?list=PL...", + ), + InputSpec( + kind="video", + label="Видео", + placeholder="https://youtu.be/dQw4w9WgXcQ", + ), + ] + + def settings_specs(self) -> list[FieldSpec]: + return [ + FieldSpec("fetch_comments", "Парсить комментарии", "checkbox", True), + FieldSpec("include_replies", "Включая ответы", "checkbox", False), + FieldSpec("max_comments", "Макс. комментариев на видео (0 = все)", + "number", 0, min_value=0), + FieldSpec("comment_order", "Сортировка комментариев", "select", + "relevance", options=["relevance", "time"]), + FieldSpec("fetch_transcripts", "Парсить транскрипты", "checkbox", True), + FieldSpec("transcript_langs", "Языки транскриптов (через запятую)", + "text", "ru,en"), + FieldSpec("search_max", "Макс. видео на запрос", "number", 10, + min_value=1, max_value=500), + FieldSpec("per_source_max", + "Макс. видео на канал/плейлист (0 = все)", + "number", 20, min_value=0), + FieldSpec("proxy_provider", "Прокси для транскриптов", "select", + "Без прокси", options=["Без прокси", "Webshare", "HTTP-прокси"], + help="На Streamlit Cloud YouTube блокирует запросы за субтитрами."), + FieldSpec("transcribe_videos", "🎤 Whisper fallback (если субтитров нет)", "checkbox", False, + help="Когда youtube-transcript-api не вернул субтитры, скачивает аудио " + "и транскрибирует через OpenAI Whisper. Нужен OPENAI_API_KEY и ffmpeg."), + FieldSpec("max_audio_seconds_per_video", "Макс. 
секунд аудио на видео", + "number", 600, min_value=10, max_value=3600), + ] + + def resolve( + self, + inputs: dict[str, list[str]], + settings: dict[str, Any], + secrets: dict[str, str], + ) -> list[str]: + youtube = self._client(secrets) + return collect_video_ids( + youtube, + queries=inputs.get("query", []), + channels=inputs.get("channel", []), + playlists=inputs.get("playlist", []), + videos=inputs.get("video", []), + search_max=int(settings.get("search_max", 10)), + per_source_max=int(settings.get("per_source_max", 0)) or None, + ) + + def fetch( + self, + item_ids: list[str], + settings: dict[str, Any], + secrets: dict[str, str], + progress: ProgressCb | None = None, + ) -> Iterator[Item]: + youtube = self._client(secrets) + metadata = fetch_video_metadata(youtube, item_ids) + languages = [s.strip() for s in str(settings.get("transcript_langs", "ru,en")).split(",") if s.strip()] + proxy_config = self._build_proxy(settings, secrets) + max_comments = int(settings.get("max_comments", 0)) or None + + for i, vid in enumerate(item_ids, 1): + meta = metadata.get(vid) + if progress: + progress(i, len(item_ids), vid) + if not meta: + continue + + item = metadata_to_item(meta) + + if settings.get("fetch_comments", True): + try: + raw = fetch_comments( + youtube, vid, + include_replies=bool(settings.get("include_replies")), + max_comments=max_comments, + order=str(settings.get("comment_order", "relevance")), + ) + item.comments = [comment_dict_to_comment(c) for c in raw] + except Exception as e: + item.extra["comments_error"] = str(e) + + if settings.get("fetch_transcripts", True): + t = fetch_transcript_verbose(vid, languages=languages, proxy_config=proxy_config) + item.transcript = transcript_dict_to_transcript(t) + + # Whisper fallback: only if youtube-transcript-api couldn't produce + # segments (subtitles disabled, blocked, etc.) AND user opted in. 
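+            # only_if_missing=True makes this a no-op when youtube-transcript-api
+            # already produced segments above.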
+ from ...transcription.runner import maybe_transcribe # noqa: PLC0415 + maybe_transcribe(item, settings, secrets, only_if_missing=True) + + yield item + + def _client(self, secrets: dict[str, str]): + key = secrets.get("YOUTUBE_API_KEY") + if not key: + raise ValueError("YOUTUBE_API_KEY is required") + return build("youtube", "v3", developerKey=key, cache_discovery=False) + + def _build_proxy(self, settings: dict[str, Any], secrets: dict[str, str]): + provider = settings.get("proxy_provider", "Без прокси") + if provider == "Webshare": + user = secrets.get("WEBSHARE_USERNAME", "") + pwd = secrets.get("WEBSHARE_PASSWORD", "") + if user and pwd: + from youtube_transcript_api.proxies import WebshareProxyConfig + return WebshareProxyConfig(proxy_username=user, proxy_password=pwd) + elif provider == "HTTP-прокси": + http = secrets.get("PROXY_HTTP_URL", "") + https = secrets.get("PROXY_HTTPS_URL", "") or http + if http or https: + from youtube_transcript_api.proxies import GenericProxyConfig + return GenericProxyConfig(http_url=http or None, https_url=https or None) + return None diff --git a/content_parser/plugins/youtube/sources.py b/content_parser/plugins/youtube/sources.py new file mode 100644 index 0000000..241687d --- /dev/null +++ b/content_parser/plugins/youtube/sources.py @@ -0,0 +1,214 @@ +"""Resolve search queries, channel URLs, and playlist URLs to lists of video IDs.""" +from __future__ import annotations + +import re +from typing import Iterable +from urllib.parse import parse_qs, urlparse + +from googleapiclient.discovery import Resource + + +VIDEO_ID_RE = re.compile(r"^[A-Za-z0-9_-]{11}$") +CHANNEL_ID_RE = re.compile(r"^UC[A-Za-z0-9_-]{22}$") +PLAYLIST_ID_RE = re.compile(r"^(PL|UU|FL|RD|OL|LL)[A-Za-z0-9_-]+$") + + +def extract_video_id(url_or_id: str) -> str | None: + if VIDEO_ID_RE.match(url_or_id): + return url_or_id + parsed = urlparse(url_or_id) + if parsed.hostname in ("youtu.be",): + vid = parsed.path.lstrip("/") + return vid if VIDEO_ID_RE.match(vid) else None + if parsed.hostname and "youtube" in parsed.hostname: + if parsed.path == "/watch": + vid = parse_qs(parsed.query).get("v", [None])[0] + return vid if vid and VIDEO_ID_RE.match(vid) else None + if parsed.path.startswith("/shorts/") or parsed.path.startswith("/embed/"): + vid = parsed.path.split("/")[2] + return vid if VIDEO_ID_RE.match(vid) else None + return None + + +def extract_playlist_id(url_or_id: str) -> str | None: + if PLAYLIST_ID_RE.match(url_or_id): + return url_or_id + parsed = urlparse(url_or_id) + pid = parse_qs(parsed.query).get("list", [None])[0] + return pid if pid and PLAYLIST_ID_RE.match(pid) else None + + +def resolve_channel_id(youtube: Resource, channel_input: str) -> str: + """Accepts a channel ID, /channel/UC..., /@handle, /c/name, or /user/name URL.""" + if CHANNEL_ID_RE.match(channel_input): + return channel_input + + parsed = urlparse(channel_input) + path = parsed.path.strip("/") if parsed.path else channel_input.lstrip("@") + parts = path.split("/") + + if parts and parts[0] == "channel" and len(parts) > 1 and CHANNEL_ID_RE.match(parts[1]): + return parts[1] + + handle = None + username = None + if parts and parts[0].startswith("@"): + handle = parts[0] + elif parts and parts[0] == "c" and len(parts) > 1: + handle = "@" + parts[1] + elif parts and parts[0] == "user" and len(parts) > 1: + username = parts[1] + elif channel_input.startswith("@"): + handle = channel_input + + request_kwargs: dict = {"part": "id"} + if handle: + request_kwargs["forHandle"] = handle + elif username: + 
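+        # forUsername handles legacy /user/<name> inputs; forHandle covers @handles and /c/ names.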
request_kwargs["forUsername"] = username + else: + raise ValueError(f"Cannot resolve channel from input: {channel_input!r}") + + response = youtube.channels().list(**request_kwargs).execute() + items = response.get("items", []) + if not items: + raise ValueError(f"Channel not found: {channel_input!r}") + return items[0]["id"] + + +def get_uploads_playlist_id(youtube: Resource, channel_id: str) -> str: + response = youtube.channels().list(part="contentDetails", id=channel_id).execute() + items = response.get("items", []) + if not items: + raise ValueError(f"Channel not found: {channel_id}") + return items[0]["contentDetails"]["relatedPlaylists"]["uploads"] + + +def list_playlist_video_ids( + youtube: Resource, playlist_id: str, max_results: int | None = None +) -> list[str]: + video_ids: list[str] = [] + page_token: str | None = None + while True: + response = ( + youtube.playlistItems() + .list( + part="contentDetails", + playlistId=playlist_id, + maxResults=50, + pageToken=page_token, + ) + .execute() + ) + for item in response.get("items", []): + vid = item["contentDetails"]["videoId"] + video_ids.append(vid) + if max_results and len(video_ids) >= max_results: + return video_ids + page_token = response.get("nextPageToken") + if not page_token: + break + return video_ids + + +def search_video_ids(youtube: Resource, query: str, max_results: int = 25) -> list[str]: + """Uses search.list (100 quota units per call). Each call returns up to 50 results.""" + video_ids: list[str] = [] + page_token: str | None = None + while len(video_ids) < max_results: + page_size = min(50, max_results - len(video_ids)) + response = ( + youtube.search() + .list( + part="id", + q=query, + type="video", + maxResults=page_size, + pageToken=page_token, + ) + .execute() + ) + for item in response.get("items", []): + vid = item["id"].get("videoId") + if vid: + video_ids.append(vid) + page_token = response.get("nextPageToken") + if not page_token: + break + return video_ids + + +def collect_video_ids( + youtube: Resource, + *, + queries: Iterable[str] = (), + channels: Iterable[str] = (), + playlists: Iterable[str] = (), + videos: Iterable[str] = (), + search_max: int = 25, + per_source_max: int | None = None, +) -> list[str]: + """Resolve all inputs to a deduplicated, ordered list of video IDs.""" + seen: set[str] = set() + result: list[str] = [] + + def add(vid: str) -> None: + if vid and vid not in seen: + seen.add(vid) + result.append(vid) + + for q in queries: + for vid in search_video_ids(youtube, q, max_results=search_max): + add(vid) + + for ch in channels: + channel_id = resolve_channel_id(youtube, ch) + uploads_id = get_uploads_playlist_id(youtube, channel_id) + for vid in list_playlist_video_ids(youtube, uploads_id, max_results=per_source_max): + add(vid) + + for pl in playlists: + playlist_id = extract_playlist_id(pl) + if not playlist_id: + raise ValueError(f"Cannot parse playlist input: {pl!r}") + for vid in list_playlist_video_ids(youtube, playlist_id, max_results=per_source_max): + add(vid) + + for v in videos: + vid = extract_video_id(v) + if not vid: + raise ValueError(f"Cannot parse video input: {v!r}") + add(vid) + + return result + + +def fetch_video_metadata(youtube: Resource, video_ids: list[str]) -> dict[str, dict]: + """Returns a dict mapping video_id -> metadata. 
Batches up to 50 IDs per call.""" + result: dict[str, dict] = {} + for i in range(0, len(video_ids), 50): + batch = video_ids[i : i + 50] + response = ( + youtube.videos() + .list(part="snippet,statistics,contentDetails", id=",".join(batch)) + .execute() + ) + for item in response.get("items", []): + snippet = item.get("snippet", {}) + stats = item.get("statistics", {}) + details = item.get("contentDetails", {}) + result[item["id"]] = { + "video_id": item["id"], + "title": snippet.get("title"), + "description": snippet.get("description"), + "channel_id": snippet.get("channelId"), + "channel_title": snippet.get("channelTitle"), + "published_at": snippet.get("publishedAt"), + "tags": snippet.get("tags", []), + "duration": details.get("duration"), + "view_count": int(stats.get("viewCount", 0)) if stats.get("viewCount") else None, + "like_count": int(stats.get("likeCount", 0)) if stats.get("likeCount") else None, + "comment_count": int(stats.get("commentCount", 0)) if stats.get("commentCount") else None, + "url": f"https://www.youtube.com/watch?v={item['id']}", + } + return result diff --git a/content_parser/plugins/youtube/transcripts.py b/content_parser/plugins/youtube/transcripts.py new file mode 100644 index 0000000..da7cd2b --- /dev/null +++ b/content_parser/plugins/youtube/transcripts.py @@ -0,0 +1,107 @@ +"""Fetch transcripts via youtube-transcript-api (no API quota cost).""" +from __future__ import annotations + +from typing import Any + +from youtube_transcript_api import ( + AgeRestricted, + IpBlocked, + NoTranscriptFound, + RequestBlocked, + TranscriptsDisabled, + VideoUnavailable, + VideoUnplayable, + YouTubeRequestFailed, + YouTubeTranscriptApi, +) + + +def fetch_transcript( + video_id: str, + languages: list[str] | None = None, + proxy_config: Any | None = None, +) -> dict | None: + result = fetch_transcript_verbose( + video_id, languages=languages, proxy_config=proxy_config + ) + if result.get("error"): + return None + if not result.get("segments"): + return None + return result + + +def fetch_transcript_verbose( + video_id: str, + languages: list[str] | None = None, + proxy_config: Any | None = None, +) -> dict: + preferred = languages or ["ru", "en"] + api = YouTubeTranscriptApi(proxy_config=proxy_config) if proxy_config else YouTubeTranscriptApi() + + try: + transcript_list = api.list(video_id) + except TranscriptsDisabled: + return {"error": "disabled", "segments": [], "text": "", "language": None, "is_generated": None} + except (VideoUnavailable, VideoUnplayable, AgeRestricted) as e: + return {"error": f"unavailable: {type(e).__name__}", "segments": [], "text": "", + "language": None, "is_generated": None} + except (IpBlocked, RequestBlocked, YouTubeRequestFailed) as e: + return {"error": f"blocked: {type(e).__name__}", "segments": [], "text": "", + "language": None, "is_generated": None} + except Exception as e: + return {"error": f"{type(e).__name__}: {e}", "segments": [], "text": "", + "language": None, "is_generated": None} + + transcript = None + try: + transcript = transcript_list.find_manually_created_transcript(preferred) + except NoTranscriptFound: + pass + + if transcript is None: + try: + transcript = transcript_list.find_generated_transcript(preferred) + except NoTranscriptFound: + pass + + if transcript is None: + try: + any_t = next(iter(transcript_list)) + if any_t.is_translatable: + try: + transcript = any_t.translate(preferred[0]) + except Exception: + transcript = any_t + else: + transcript = any_t + except StopIteration: + return {"error": "not_found", 
"segments": [], "text": "", + "language": None, "is_generated": None} + + try: + fetched = transcript.fetch() + except (IpBlocked, RequestBlocked, YouTubeRequestFailed) as e: + return {"error": f"blocked: {type(e).__name__}", "segments": [], "text": "", + "language": transcript.language_code, "is_generated": transcript.is_generated} + except Exception as e: + return {"error": f"{type(e).__name__}: {e}", "segments": [], "text": "", + "language": transcript.language_code, "is_generated": transcript.is_generated} + + snippets = list(fetched) + segments = [ + { + "start": float(getattr(s, "start", 0.0)), + "duration": float(getattr(s, "duration", 0.0)), + "text": getattr(s, "text", ""), + } + for s in snippets + ] + + return { + "language": transcript.language_code, + "is_generated": transcript.is_generated, + "segments": segments, + "text": " ".join(s["text"].strip() for s in segments if s["text"].strip()), + "error": None, + } diff --git a/content_parser/transcription/__init__.py b/content_parser/transcription/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/transcription/cache.py b/content_parser/transcription/cache.py new file mode 100644 index 0000000..9c3f4e4 --- /dev/null +++ b/content_parser/transcription/cache.py @@ -0,0 +1,62 @@ +"""Disk cache for transcription results — keyed by source + item_id. + +Whisper costs money, scraping is slow; running the same job twice should +NOT re-pay for things we already transcribed. The cache lives in +~/.content_parser/transcription_cache/ as one JSON file per item. +""" +from __future__ import annotations + +import json +import re +from pathlib import Path + + +CACHE_DIR = Path.home() / ".content_parser" / "transcription_cache" + + +def _safe(value: str) -> str: + """Sanitize a path component — collapse anything outside [\\w-] to _.""" + return re.sub(r"[^\w-]", "_", value)[:80] or "item" + + +def _cache_path(source: str, item_id: str) -> Path: + return CACHE_DIR / f"{_safe(source)}_{_safe(item_id)}.json" + + +def get(source: str, item_id: str) -> dict | None: + p = _cache_path(source, item_id) + if not p.exists(): + return None + try: + return json.loads(p.read_text(encoding="utf-8")) + except (OSError, ValueError): + return None + + +def put(source: str, item_id: str, transcript_dict: dict) -> Path: + """Write the transcript atomically — to a .tmp sibling, then rename. + + Without this, a crash mid-write (disk full, SIGTERM) would leave a + truncated JSON that future `get()` calls catch as ValueError and + treat as cache miss — but the corrupt file lingers on disk. + """ + p = _cache_path(source, item_id) + p.parent.mkdir(parents=True, exist_ok=True) + tmp = p.with_suffix(p.suffix + ".tmp") + tmp.write_text(json.dumps(transcript_dict, ensure_ascii=False), encoding="utf-8") + tmp.replace(p) # atomic on POSIX; near-atomic on Windows + return p + + +def clear() -> int: + """Remove all cached transcripts. Returns count removed.""" + if not CACHE_DIR.exists(): + return 0 + n = 0 + for p in CACHE_DIR.glob("*.json"): + try: + p.unlink() + n += 1 + except OSError: + continue + return n diff --git a/content_parser/transcription/downloader.py b/content_parser/transcription/downloader.py new file mode 100644 index 0000000..ae81bd4 --- /dev/null +++ b/content_parser/transcription/downloader.py @@ -0,0 +1,92 @@ +"""Pull audio from a public video URL via yt-dlp. + +Used by the transcription pipeline. Requires `ffmpeg` available on PATH for +audio extraction. 
The downloaded file is small (≈300-500 KB for a 30-sec +reel as MP3), well under the 25 MB limit of OpenAI's Whisper API. +""" +from __future__ import annotations + +from pathlib import Path +from typing import Any + + +class DownloadError(Exception): + """Raised when audio extraction fails (network, unavailable URL, no ffmpeg, ...).""" + + +def download_audio(url: str, target_dir: Path, *, max_filesize_mb: int = 24) -> Path: + """Download audio from `url` into `target_dir`. Returns the resulting Path. + + `max_filesize_mb` caps the post-extraction file size; OpenAI's Whisper API + limit is 25 MB, so we stay slightly under to leave room for headers. + """ + try: + import yt_dlp # noqa: PLC0415 + except ImportError as e: + raise DownloadError( + "yt-dlp is not installed. Add yt-dlp to requirements.txt." + ) from e + + target_dir.mkdir(parents=True, exist_ok=True) + out_template = str(target_dir / "%(id)s.%(ext)s") + + ydl_opts: dict[str, Any] = { + "format": "bestaudio/best", + "outtmpl": out_template, + "quiet": True, + "no_warnings": True, + "noprogress": True, + "max_filesize": max_filesize_mb * 1024 * 1024, + "postprocessors": [{ + "key": "FFmpegExtractAudio", + "preferredcodec": "mp3", + "preferredquality": "64", # 64 kbps is plenty for speech recognition + }], + # Be polite — these scrapers don't like aggressive concurrency. + "concurrent_fragment_downloads": 1, + "retries": 2, + } + + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(url, download=True) + except Exception as e: + raise DownloadError(f"yt-dlp failed for {url!r}: {e}") from e + + # The post-processor changes the extension to .mp3. + video_id = info.get("id") or "audio" + candidate = target_dir / f"{video_id}.mp3" + if candidate.exists(): + return candidate + + # Fallback: find any audio file we left in the dir + for ext in ("mp3", "m4a", "webm", "opus", "wav"): + for path in target_dir.glob(f"*.{ext}"): + return path + raise DownloadError(f"yt-dlp did not produce an audio file for {url!r}") + + +def get_duration_seconds(url: str) -> float | None: + """Return the video's duration in seconds without downloading. None if unknown. + + Useful for budget gating before paying for transcription. + """ + try: + import yt_dlp # noqa: PLC0415 + except ImportError: + return None + + ydl_opts = {"quiet": True, "no_warnings": True, "skip_download": True, "noprogress": True} + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(url, download=False) + except Exception: + return None + + duration = info.get("duration") + if duration is None: + return None + try: + return float(duration) + except (TypeError, ValueError): + return None diff --git a/content_parser/transcription/runner.py b/content_parser/transcription/runner.py new file mode 100644 index 0000000..01a3a05 --- /dev/null +++ b/content_parser/transcription/runner.py @@ -0,0 +1,209 @@ +"""Orchestrate per-item transcription: cache check → download audio → Whisper. + +Plugins call `maybe_transcribe(item, settings, secrets)` once per Item they +yield. The function is a no-op when transcription is disabled or no audio +URL is available; on hard errors it sets `item.transcript.error` so the +output still records why. +""" +from __future__ import annotations + +import tempfile +from ipaddress import ip_address +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +from ..core.schema import Item, Transcript +from . 
import cache as cache_mod +from .downloader import DownloadError, download_audio, get_duration_seconds +from .whisper_api import WhisperError, transcribe_audio + + +# Hostnames that should never reach yt-dlp. Anything else gets parsed and +# checked via the ipaddress module if it looks like an IP literal. +_DENIED_HOSTNAMES = {"localhost", "0.0.0.0", "ip6-localhost", "ip6-loopback"} + + +def _is_public_url(url: str) -> bool: + """Reject URLs that point at the local machine or RFC1918 networks. + + yt-dlp would happily fetch from a URL like 'http://169.254.169.254/...' + (AWS metadata) or 'http://10.0.0.1/...' if any of our third-party + sources (Apify/VK/Telegram) ever returned one. This guard rejects: + - non-http(s) schemes + - 'localhost' and well-known loopback hostnames + - IPv4/IPv6 literals that resolve to loopback / private / link-local / + reserved space + + Note: bare DNS names that resolve to private IPs are NOT caught here + (DNS rebinding). Mitigating that needs name resolution + connection + pinning, which is yt-dlp's domain. + """ + if not url or not isinstance(url, str): + return False + try: + parsed = urlparse(url) + except Exception: + return False + if parsed.scheme not in ("http", "https"): + return False + host = (parsed.hostname or "").strip().lower() + if not host: + return False + if host in _DENIED_HOSTNAMES: + return False + # If the hostname is an IP literal, classify it. + try: + ip = ip_address(host) + except ValueError: + return True # ordinary DNS name — accept + return not (ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_reserved or ip.is_multicast) + + +def _transcript_from_whisper_response(resp: dict) -> Transcript: + segments_raw = resp.get("segments") or [] + segments = [] + for seg in segments_raw: + start = float(seg.get("start", 0.0) or 0.0) + end = float(seg.get("end", start) or start) + segments.append({ + "start": start, + "duration": max(end - start, 0.0), + "text": seg.get("text", ""), + }) + text = (resp.get("text") or "").strip() + if not text and segments: + text = " ".join(s["text"].strip() for s in segments if s["text"].strip()) + return Transcript( + language=resp.get("language"), + is_generated=True, # Whisper output is always machine-generated + segments=segments, + text=text, + error=None, + ) + + +def _transcript_to_dict(t: Transcript) -> dict: + return { + "language": t.language, + "is_generated": t.is_generated, + "segments": list(t.segments), + "text": t.text, + "error": t.error, + } + + +def _video_url_for(item: Item) -> str | None: + """Pick the best URL for audio extraction. + + Some plugins put a direct CDN URL in media.video_url, which yt-dlp can't + always re-fetch (cookies/expiry). For Instagram/TikTok/Telegram, the + canonical post URL works better — yt-dlp resolves the media URL fresh. + """ + # Prefer a recognizable platform URL over a CDN URL. + if item.url and any(host in item.url for host in ( + "instagram.com", "tiktok.com", "youtube.com", "youtu.be", + "vk.com", "vk.ru", "t.me", "telegram.me", + )): + return item.url + return item.media.get("video_url") or item.url or None + + +def maybe_transcribe( + item: Item, + settings: dict[str, Any], + secrets: dict[str, str], + *, + only_if_missing: bool = False, +) -> None: + """Populate item.transcript from Whisper if conditions are met. + + only_if_missing: skip when item.transcript already has segments. Used by + YouTube where youtube-transcript-api ran first; we only fall back to + Whisper when subtitles weren't available. 
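+
+    Sketch of the intended call site (`plugin_items` is a placeholder for the
+    Items a plugin yields; YouTube additionally passes only_if_missing=True):
+
+        for item in plugin_items:
+            maybe_transcribe(item, settings, secrets)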
+ """ + if not settings.get("transcribe_videos"): + return + if only_if_missing and item.transcript and item.transcript.segments: + return + + api_key = (secrets.get("OPENAI_API_KEY") or "").strip() + if not api_key: + item.transcript = Transcript( + error="OPENAI_API_KEY not set; cannot run Whisper.", + language=None, is_generated=None, segments=[], text="", + ) + return + + video_url = _video_url_for(item) + if not video_url: + return # silently skip — no media to transcribe + + # SSRF guard: refuse to send loopback / RFC1918 / link-local URLs to + # yt-dlp, which would otherwise fetch from internal networks if any + # upstream API returned such a URL (chain-of-trust risk). + if not _is_public_url(video_url): + item.transcript = Transcript( + error=f"refused to fetch non-public URL: {video_url[:60]}", + language=None, is_generated=None, segments=[], text="", + ) + return + + # Cache lookup before any network/download. + cached = cache_mod.get(item.source, item.item_id) + if cached: + item.transcript = Transcript( + language=cached.get("language"), + is_generated=cached.get("is_generated"), + segments=list(cached.get("segments") or []), + text=cached.get("text") or "", + error=cached.get("error"), + ) + return + + # Per-item duration cap. We BLOCK if duration is unknown — without a + # known length we can't bound the Whisper bill, so refusing is the + # cheap-and-safe default. Power users who really need to transcribe + # platforms where yt-dlp can't probe metadata can bypass by saving + # the audio file and calling whisper_api directly. + max_seconds = int(settings.get("max_audio_seconds_per_video", 600) or 600) + duration = get_duration_seconds(video_url) + if duration is None: + item.transcript = Transcript( + error="video duration unknown; transcription skipped to avoid unbounded cost.", + language=None, is_generated=None, segments=[], text="", + ) + return + if duration > max_seconds: + item.transcript = Transcript( + error=f"video too long: {int(duration)}s > {max_seconds}s cap (transcription skipped).", + language=None, is_generated=None, segments=[], text="", + ) + return + + with tempfile.TemporaryDirectory(prefix="cp_audio_") as tmp: + try: + audio_path = download_audio(video_url, Path(tmp)) + except DownloadError as e: + item.transcript = Transcript( + error=f"download failed: {e}", + language=None, is_generated=None, segments=[], text="", + ) + return + + try: + response = transcribe_audio(audio_path, api_key) + except WhisperError as e: + item.transcript = Transcript( + error=f"whisper failed: {e}", + language=None, is_generated=None, segments=[], text="", + ) + return + + transcript = _transcript_from_whisper_response(response) + item.transcript = transcript + + try: + cache_mod.put(item.source, item.item_id, _transcript_to_dict(transcript)) + except OSError: + pass # cache is best-effort diff --git a/content_parser/transcription/whisper_api.py b/content_parser/transcription/whisper_api.py new file mode 100644 index 0000000..f837b5a --- /dev/null +++ b/content_parser/transcription/whisper_api.py @@ -0,0 +1,109 @@ +"""Thin HTTP client for OpenAI's Whisper transcription endpoint. + +We avoid the official `openai` Python package because it's a heavy dep with +its own dependencies. Whisper takes one POST: file in multipart, model name ++ optional language + response_format in the form fields. + +Pricing as of 2026-04: $0.006/minute of audio. The 25 MB upload limit is +enforced by OpenAI; we cap on the download side too. 
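+
+Sketch of a direct call (`clip.mp3` and `api_key` are placeholders; the plugin
+flow normally goes through transcription.runner instead):
+
+    resp = transcribe_audio(Path("clip.mp3"), api_key, language="en")
+    print(resp["text"], resp.get("language"))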
+""" +from __future__ import annotations + +import time +from pathlib import Path +from typing import Any + +import requests + + +WHISPER_URL = "https://api.openai.com/v1/audio/transcriptions" +WHISPER_MODEL = "whisper-1" + + +class WhisperError(Exception): + pass + + +class _RetryableWhisperError(WhisperError): + """Internal: subclass for 429 / 5xx responses that warrant a retry.""" + + +def transcribe_audio( + audio_path: Path, + api_key: str, + *, + language: str | None = None, + timeout: int = 300, + max_retries: int = 2, +) -> dict[str, Any]: + """Send an audio file to Whisper, retrying transient failures. + + Returns OpenAI's verbose_json shape (text, language, segments). 429 + (rate limit) and 5xx responses are retried up to `max_retries` times + with exponential backoff (2s, 4s, 8s). 401, 4xx (other), and other + WhisperError subclasses surface immediately. + """ + if not api_key: + raise WhisperError("OPENAI_API_KEY is required for Whisper transcription.") + if not audio_path.exists(): + raise WhisperError(f"Audio file does not exist: {audio_path}") + + delay = 2.0 + for attempt in range(max_retries + 1): + try: + return _transcribe_once(audio_path, api_key, language=language, timeout=timeout) + except _RetryableWhisperError: + if attempt >= max_retries: + raise + _sleep(delay) + delay *= 2 + # Defensive — only reachable if max_retries < 0. + raise WhisperError("Whisper retry loop exhausted without a result.") # pragma: no cover + + +def _sleep(seconds: float) -> None: + """Indirection so tests can patch sleep without slowing the suite.""" + time.sleep(seconds) + + +def _transcribe_once( + audio_path: Path, + api_key: str, + *, + language: str | None, + timeout: int, +) -> dict[str, Any]: + headers = {"Authorization": f"Bearer {api_key}"} + data: dict[str, Any] = { + "model": WHISPER_MODEL, + "response_format": "verbose_json", + "timestamp_granularities[]": "segment", + } + if language: + data["language"] = language + + with audio_path.open("rb") as f: + files = {"file": (audio_path.name, f, "audio/mpeg")} + try: + r = requests.post(WHISPER_URL, headers=headers, data=data, files=files, timeout=timeout) + except requests.RequestException as e: + raise WhisperError(f"Network error calling Whisper: {e}") from e + + if r.status_code == 401: + raise WhisperError("OpenAI rejected the API key (401). Check OPENAI_API_KEY.") + if r.status_code == 429: + raise _RetryableWhisperError("OpenAI rate-limit (429).") + if 500 <= r.status_code < 600: + raise _RetryableWhisperError(f"OpenAI server error ({r.status_code}).") + if not r.ok: + # 4xx other than 401/429 — non-retryable client error. 
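+        # OpenAI error bodies are JSON shaped like {"error": {"message": ...}};
+        # pull the message when possible, else fall back to raw response text.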
+ try: + err = r.json().get("error", {}).get("message", r.text[:200]) + except ValueError: + err = r.text[:200] + raise WhisperError(f"Whisper returned {r.status_code}: {err}") + + try: + return r.json() + except ValueError as e: + raise WhisperError(f"Whisper returned non-JSON: {r.text[:200]}") from e diff --git a/content_parser/ui/__init__.py b/content_parser/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content_parser/ui/app.py b/content_parser/ui/app.py new file mode 100644 index 0000000..7f537ea --- /dev/null +++ b/content_parser/ui/app.py @@ -0,0 +1,665 @@ +"""Streamlit-интерфейс — динамически рисует формы на основе плагинов.""" +from __future__ import annotations + +import io +import zipfile +from datetime import datetime +from pathlib import Path + +import streamlit as st + +from ..core.registry import all_plugins, get_plugin +from ..core.runner import run +from ..core.schema import Item +from ..core.secrets import ( + delete_secret, + get_secret, + save_secret, + secret_locations, +) + + +st.set_page_config(page_title="Парсер контента", page_icon="🎬", layout="wide") + + +def _split_lines(text: str) -> list[str]: + return [line.strip() for line in text.splitlines() if line.strip()] + + +def _zip_directory(directory: Path) -> bytes: + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: + for file in directory.rglob("*"): + if file.is_file(): + zf.write(file, arcname=file.relative_to(directory)) + return buf.getvalue() + + +def _render_field(spec, key_prefix: str): + key = f"{key_prefix}_{spec.key}" + if spec.widget == "text": + return st.text_input(spec.label, value=str(spec.default or ""), help=spec.help, key=key, placeholder=spec.placeholder) + if spec.widget == "textarea": + return st.text_area(spec.label, value=str(spec.default or ""), help=spec.help, key=key, placeholder=spec.placeholder) + if spec.widget == "password": + return st.text_input(spec.label, value=str(spec.default or ""), help=spec.help, key=key, type="password") + if spec.widget == "number": + kwargs: dict = {"label": spec.label, "value": int(spec.default or 0), "help": spec.help, "key": key} + if spec.min_value is not None: + kwargs["min_value"] = int(spec.min_value) + if spec.max_value is not None: + kwargs["max_value"] = int(spec.max_value) + return st.number_input(**kwargs) + if spec.widget == "checkbox": + return st.checkbox(spec.label, value=bool(spec.default), help=spec.help, key=key) + if spec.widget == "select": + opts = list(spec.options) if spec.options else ([spec.default] if spec.default is not None else []) + if not opts: + # No options and no default — render a free-text fallback so the form still works. 
+ return st.text_input(spec.label, value="", help=spec.help, key=key) + idx = opts.index(spec.default) if spec.default in opts else 0 + return st.selectbox(spec.label, opts, index=idx, help=spec.help, key=key) + return st.text_input(spec.label, value=str(spec.default or ""), key=key) + + +def _sidebar(plugin) -> tuple[dict[str, str], dict]: + with st.sidebar: + st.header("⚙️ Настройки") + + # Source selector + plugins = all_plugins() + names = [p.name for p in plugins] + labels = {p.name: p.label for p in plugins} + current_name = st.selectbox( + "Источник", + names, + index=names.index(plugin.name), + format_func=lambda n: labels[n], + key="source_selector", + ) + if current_name != plugin.name: + st.session_state["selected_source"] = current_name + st.rerun() + + st.divider() + + # Secrets per plugin + st.subheader("🔑 Ключи") + secrets: dict[str, str] = {} + for k in plugin.secret_keys: + session_key = f"secret_{k}" + if session_key not in st.session_state: + st.session_state[session_key] = get_secret(k) + value = st.text_input(k, value=st.session_state[session_key], type="password", key=session_key) + secrets[k] = value + + # Optional shared secrets that some plugins use + for opt in ("WEBSHARE_USERNAME", "WEBSHARE_PASSWORD", "PROXY_HTTP_URL", "PROXY_HTTPS_URL", "OPENAI_API_KEY", "INSTAGRAM_ACCESS_TOKEN"): + v = get_secret(opt) + if v: + secrets[opt] = v + + col_save, col_clear = st.columns(2) + with col_save: + if st.button("💾 Сохранить", use_container_width=True, key="btn_save_secrets"): + saved = [] + for k in plugin.secret_keys: + if secrets.get(k): + save_secret(k, secrets[k]) + saved.append(k) + if saved: + st.success(f"Сохранено: {', '.join(saved)}") + else: + st.warning("Нечего сохранять") + with col_clear: + if st.button("🗑️ Удалить", use_container_width=True, key="btn_clear_secrets"): + for k in plugin.secret_keys: + delete_secret(k) + st.session_state[f"secret_{k}"] = "" + st.success("Удалено") + st.rerun() + + for k in plugin.secret_keys: + locs = secret_locations(k) + if locs: + st.caption(f"`{k}` сохранён в: {', '.join(locs)}") + + st.divider() + + # Per-plugin settings + st.subheader("Параметры") + settings: dict = {} + proxy_secrets: dict[str, str] = {} + for spec in plugin.settings_specs(): + settings[spec.key] = _render_field(spec, key_prefix=f"setting_{plugin.name}") + + # If plugin has a proxy_provider setting, expose Webshare/HTTP fields here + if "proxy_provider" in settings and settings["proxy_provider"] != "Без прокси": + with st.expander("Параметры прокси", expanded=True): + if settings["proxy_provider"] == "Webshare": + proxy_secrets["WEBSHARE_USERNAME"] = st.text_input( + "Webshare username", + value=get_secret("WEBSHARE_USERNAME"), + key="ws_user", + ) + proxy_secrets["WEBSHARE_PASSWORD"] = st.text_input( + "Webshare password", + value=get_secret("WEBSHARE_PASSWORD"), + type="password", + key="ws_pass", + ) + elif settings["proxy_provider"] == "HTTP-прокси": + proxy_secrets["PROXY_HTTP_URL"] = st.text_input( + "HTTP URL", + value=get_secret("PROXY_HTTP_URL"), + placeholder="http://user:pass@host:port", + key="http_url", + ) + proxy_secrets["PROXY_HTTPS_URL"] = st.text_input( + "HTTPS URL (опц.)", + value=get_secret("PROXY_HTTPS_URL"), + key="https_url", + ) + + secrets.update({k: v for k, v in proxy_secrets.items() if v}) + + # If transcription is on, expose the OpenAI key inline so the user + # can paste it without leaving the plugin form. 
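+        # The pasted key is only merged into `secrets` for the current run;
+        # writing it to disk still requires the explicit save button below.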
+ if settings.get("transcribe_videos"): + with st.expander("🎤 Параметры Whisper", expanded=True): + openai_key = st.text_input( + "OPENAI_API_KEY", + value=get_secret("OPENAI_API_KEY"), + type="password", + key="openai_api_key", + help="Получить на https://platform.openai.com/api-keys", + ) + col_s, col_c = st.columns(2) + with col_s: + if st.button("💾 Сохранить ключ", use_container_width=True, key="save_openai"): + if openai_key.strip(): + save_secret("OPENAI_API_KEY", openai_key.strip()) + st.success("Сохранено") + else: + st.warning("Сначала вставь ключ") + with col_c: + if st.button("🗑️ Удалить", use_container_width=True, key="clear_openai"): + delete_secret("OPENAI_API_KEY") + st.session_state["openai_api_key"] = "" + st.rerun() + if openai_key: + secrets["OPENAI_API_KEY"] = openai_key + st.caption( + "⚠️ Whisper тарифицируется ~$0.006/мин аудио. " + "На своей машине нужен `ffmpeg` (apt install ffmpeg / brew install ffmpeg). " + "Сохранённый ключ остаётся при выключении чекбокса — удалить можно " + "только кнопкой 🗑️." + ) + + # ----- Google Sheets loader ----- + st.divider() + _render_sheets_loader(plugin) + + return secrets, settings + + +def _render_sheets_loader(plugin) -> None: + """Sidebar block: pull values from a Google Sheets range into an input tab.""" + from ..loaders.gsheets import GoogleSheetsLoader + + with st.expander("📥 Загрузить из Google Sheets", expanded=False): + st.caption( + "⚠️ JSON содержит приватный ключ — не показывай экран другим. " + "Сервис-аккаунт нужно вручную добавить в шаринг таблицы." + ) + + # Show client_email summary if creds are already saved, instead of + # re-rendering the full JSON every page load. + saved_email: str | None = None + saved_creds = get_secret("GOOGLE_SHEETS_CREDENTIALS") + if saved_creds: + try: + saved_email = GoogleSheetsLoader.validate_credentials(saved_creds).get("client_email") + except Exception: + saved_email = None + + if saved_email and not st.session_state.get("gs_replace_creds"): + st.success(f"✓ Учётка сохранена: `{saved_email}`") + st.caption("Поделись с этим email-ом каждой таблицей, которую парсишь.") + col_replace, col_clear = st.columns(2) + with col_replace: + if st.button("✏️ Заменить", use_container_width=True, key="gs_replace_creds_btn"): + st.session_state["gs_replace_creds"] = True + st.rerun() + with col_clear: + if st.button("🗑️ Удалить", use_container_width=True, key="gs_clear_creds"): + delete_secret("GOOGLE_SHEETS_CREDENTIALS") + st.session_state.pop("gs_creds", None) + st.session_state.pop("gs_replace_creds", None) + st.rerun() + creds_input = saved_creds # used by load button below + else: + creds_input = st.text_area( + "GOOGLE_SHEETS_CREDENTIALS (service account JSON)", + value="" if st.session_state.get("gs_replace_creds") else (saved_creds or ""), + height=80, + key="gs_creds", + help="Вставь содержимое JSON-файла ключа сервис-аккаунта.", + ) + col_save, col_cancel = st.columns(2) + with col_save: + if st.button("💾 Сохранить", use_container_width=True, key="gs_save_creds"): + pasted = (creds_input or "").strip() + if not pasted: + st.warning("Сначала вставь JSON") + else: + try: + parsed = GoogleSheetsLoader.validate_credentials(pasted) + except Exception as e: + st.error(f"JSON невалиден: {e}") + else: + save_secret("GOOGLE_SHEETS_CREDENTIALS", pasted) + st.session_state.pop("gs_replace_creds", None) + st.success( + f"Сохранено. 
Поделись таблицей с: `{parsed.get('client_email')}`" + ) + st.rerun() + with col_cancel: + if saved_creds and st.button("✕ Отмена", use_container_width=True, key="gs_cancel_replace"): + st.session_state.pop("gs_replace_creds", None) + st.rerun() + + sheet = st.text_input( + "URL или ID таблицы", + placeholder="https://docs.google.com/spreadsheets/d/...", + key="gs_sheet", + ) + col_tab, col_range = st.columns([1, 1]) + with col_tab: + tab_name = st.text_input("Лист (tab)", value="", key="gs_tab", + placeholder="например: Communities") + with col_range: + range_a1 = st.text_input("Диапазон A1", value="A:A", key="gs_range") + skip_header = st.checkbox("Пропустить первую строку (заголовок)", value=False, key="gs_skip_header") + + target_kinds = [s.kind for s in plugin.input_specs()] + target_kind = st.selectbox( + f"Куда подставить (для плагина {plugin.label})", + target_kinds, + key="gs_target_kind", + ) + + if st.button("📥 Загрузить", use_container_width=True, key="gs_load"): + creds_value = (creds_input or "").strip() if creds_input else "" + if not creds_value: + st.error("Нужен JSON сервис-аккаунта.") + return + if not sheet.strip(): + st.error("Укажи URL или ID таблицы.") + return + + try: + with st.spinner("Читаю таблицу…"): + loader = GoogleSheetsLoader(creds_value) + loaded = loader.load( + sheet.strip(), + tab=tab_name.strip() or None, + range_a1=range_a1.strip() or "A:A", + skip_header=skip_header, + ) + except Exception as e: + st.error(f"Ошибка загрузки: {e}") + return + + input_key = f"input_{plugin.name}_{target_kind}" + existing = (st.session_state.get(input_key) or "").strip() + new_block = "\n".join(loaded.values) + merged = (existing + "\n" + new_block).strip() if existing else new_block + st.session_state[input_key] = merged + + st.success( + f"Загружено {loaded.count} из «{loaded.sheet_title}» / " + f"«{loaded.tab_title}» → вкладка «{target_kind}»" + ) + st.rerun() + + +def _render_jobs_panel(plugin, inputs: dict[str, list[str]], settings: dict) -> None: + """Schedule panel: list/run/edit/delete jobs + manage crontab block. + + `inputs` and `settings` come from the active plugin's input tabs and are + used to seed the "create job from current state" form, so a user who + just configured a one-off run can persist it as a scheduled job. + """ + from ..jobs import store as jobs_store + from ..jobs.cron import is_cron_available + from ..jobs.runner import run_job_obj + from ..jobs.schema import Job, dump_job_yaml, load_job_yaml + + st.divider() + st.header("🕐 Расписание") + st.caption( + "Сохранённые job'ы переиспользуют твои inputs (вкл. Google Sheets) и " + "запускаются по cron'у или из CLI." + ) + + jobs = jobs_store.list_jobs() + invalid = jobs_store.list_invalid() + + if not jobs: + st.info("Пока нет сохранённых job'ов. 
Создай первый ниже из текущего состояния.") + else: + for job in jobs: + schedule_label = f"⏰ `{job.schedule}`" if job.schedule else "✋ ручной запуск" + with st.expander(f"📋 {job.name} — {job.source} — {schedule_label}", expanded=False): + if job.description: + st.caption(job.description) + + col_run, col_edit, col_del = st.columns(3) + with col_run: + if st.button("▶️ Запустить", key=f"run_{job.name}", use_container_width=True): + with st.spinner(f"Прогон {job.name}…"): + try: + result = run_job_obj(job, log=lambda m: st.write(m)) + st.success( + f"Готово: {len(result.items)} item(s) в `{result.out_dir}`" + ) + except Exception as e: + st.error(f"Ошибка: {e}") + with col_edit: + if st.button("✏️ Изменить", key=f"edit_{job.name}", use_container_width=True): + st.session_state["editing_job"] = job.name + st.rerun() + with col_del: + if st.button("🗑️ Удалить", key=f"del_{job.name}", use_container_width=True): + jobs_store.delete_job(job.name) + st.rerun() + + if st.session_state.get("editing_job") == job.name: + edited = st.text_area( + "YAML", + value=dump_job_yaml(job), + height=300, + key=f"yaml_{job.name}", + ) + save_col, cancel_col = st.columns(2) + with save_col: + if st.button("💾 Сохранить", key=f"save_yaml_{job.name}", use_container_width=True): + try: + new_job = load_job_yaml(edited, name_hint=job.name) + jobs_store.save_job(new_job) + st.session_state.pop("editing_job", None) + st.success("Сохранено") + st.rerun() + except Exception as e: + st.error(f"Ошибка: {e}") + with cancel_col: + if st.button("✕ Отмена", key=f"cancel_yaml_{job.name}", use_container_width=True): + st.session_state.pop("editing_job", None) + st.rerun() + else: + st.markdown(f"**Inputs:**") + if job.inputs: + for kind, values in job.inputs.items(): + st.markdown(f"- `{kind}`: {len(values)} inline ({', '.join(values[:3])}{'…' if len(values) > 3 else ''})") + if job.sheet_inputs: + st.markdown("**Sheet inputs:**") + for ref in job.sheet_inputs: + st.markdown( + f"- `{ref.target}` ← {ref.sheet[:40]}… " + f"tab=`{ref.tab or '(первый)'}`, range=`{ref.range_a1}`" + ) + + if invalid: + with st.expander(f"⚠️ Невалидные YAML-файлы ({len(invalid)})", expanded=False): + for name, err in invalid: + st.markdown(f"- **{name}**: `{err}`") + + # ----- Create job from current state ----- + st.subheader("➕ Создать job из текущего состояния") + st.caption( + f"Источник: **{plugin.label}**. Inputs и settings возьмутся из текущего состояния " + "input-вкладок и параметров плагина." + ) + new_name = st.text_input( + "Имя job'а", placeholder="weekly-vk-marketing", key="new_job_name", + help="Только буквы, цифры, `-` и `_`; до 64 символов.", + ) + new_schedule = st.text_input( + "Расписание (cron, опц.)", placeholder="0 6 * * MON", key="new_job_schedule", + ) + new_description = st.text_input( + "Описание (опц.)", placeholder="Топ постов из ниши маркетинг по понедельникам", + key="new_job_description", + ) + if st.button("💾 Создать job", type="primary", use_container_width=True, key="new_job_create"): + non_empty = {k: v for k, v in inputs.items() if v} + if not new_name.strip(): + st.error("Имя обязательно.") + elif not non_empty: + st.error("Сначала заполни input-вкладки.") + else: + try: + job = Job( + name=new_name.strip(), + source=plugin.name, + inputs=non_empty, + settings=dict(settings), + schedule=new_schedule.strip() or None, + description=new_description.strip() or None, + ) + jobs_store.save_job(job) + st.success( + f"Job `{job.name}` сохранён. Чтобы добавить sheet_inputs — " + "открой ✏️ Изменить и допиши блок в YAML." 
+ ) + st.rerun() + except Exception as e: + st.error(f"Ошибка: {e}") + + # ----- Cron block management ----- + st.subheader("📅 Cron") + if not is_cron_available(): + st.warning( + "На этом хосте `crontab` недоступен (например, Streamlit Cloud). " + "Используй **GitHub Actions cron** — рабочий шаблон ниже:" + ) + st.code( + """\ +# .github/workflows/run-jobs.yml +on: + schedule: + - cron: "0 6 * * MON" + workflow_dispatch: +jobs: + run: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: { python-version: "3.11" } + - run: pip install -r requirements.txt + - run: python -m content_parser.cli jobs run <ИМЯ_JOB'А> + env: + YOUTUBE_API_KEY: ${{ secrets.YOUTUBE_API_KEY }} + APIFY_API_TOKEN: ${{ secrets.APIFY_API_TOKEN }} + REDDIT_CLIENT_ID: ${{ secrets.REDDIT_CLIENT_ID }} + REDDIT_CLIENT_SECRET: ${{ secrets.REDDIT_CLIENT_SECRET }} + VK_ACCESS_TOKEN: ${{ secrets.VK_ACCESS_TOKEN }} + GOOGLE_SHEETS_CREDENTIALS: ${{ secrets.GOOGLE_SHEETS_CREDENTIALS }} + - uses: actions/upload-artifact@v4 + with: + name: scheduled-output + path: output/scheduled/ +""", + language="yaml", + ) + return + + col_install, col_remove = st.columns(2) + with col_install: + if st.button("📅 Установить блок в crontab", use_container_width=True, key="cron_install"): + from ..jobs.cron import install_cron, CronError + try: + entries = install_cron() + if entries: + st.success(f"Установлено {len(entries)} запис(ь/и):") + for e in entries: + st.write(f"`{e.schedule}` → {e.job_name}") + else: + st.info("Нет джоб с расписанием — блок очищен.") + except CronError as e: + st.error(f"Ошибка: {e}") + with col_remove: + if st.button("❌ Удалить блок", use_container_width=True, key="cron_remove"): + from ..jobs.cron import remove_cron, CronError + try: + if remove_cron(): + st.success("Блок удалён.") + else: + st.info("Блока в crontab нет.") + except CronError as e: + st.error(f"Ошибка: {e}") + + try: + from ..jobs.cron import read_block + entries = read_block() + if entries: + st.markdown("**Сейчас в crontab:**") + for e in entries: + st.markdown(f"- `{e.schedule}` → **{e.job_name or '?'}**") + except Exception: + pass + + +def _main_area(plugin) -> dict[str, list[str]]: + st.title(f"🎬 Парсер контента — {plugin.label}") + st.caption("Парсит метаданные, комментарии и (где возможно) транскрипты. Сохраняет JSON, Markdown и CSV.") + + specs = plugin.input_specs() + tabs = st.tabs([f"📥 {s.label}" for s in specs]) + inputs: dict[str, list[str]] = {} + for spec, tab in zip(specs, tabs): + with tab: + text = st.text_area( + f"{spec.label} — по одному на строку", + placeholder=spec.placeholder, + help=spec.help, + height=120, + key=f"input_{plugin.name}_{spec.kind}", + ) + inputs[spec.kind] = _split_lines(text) + return inputs + + +def main() -> None: + plugins = all_plugins() + if not plugins: + st.error("Нет доступных плагинов. 
Проверь установку зависимостей.") + return + + if "selected_source" not in st.session_state: + st.session_state.selected_source = plugins[0].name + if "last_run" not in st.session_state: + st.session_state.last_run = None + + plugin = get_plugin(st.session_state.selected_source) + + secrets, settings = _sidebar(plugin) + inputs = _main_area(plugin) + + st.divider() + if st.button("▶️ Запустить (разово)", type="primary", use_container_width=True): + non_empty = {k: v for k, v in inputs.items() if v} + if not non_empty: + st.error("Заполни хотя бы одну вкладку.") + st.stop() + + missing = plugin.validate_secrets(secrets) + if missing: + st.error(f"Заполни ключи: {', '.join(missing)}") + st.stop() + + out_dir = Path("output") / plugin.name / datetime.now().strftime("%Y%m%d_%H%M%S") + log = st.status("Запуск…", expanded=True) + progress_bar = st.progress(0.0, text="Подготовка…") + + try: + with log: + st.write(f"Источник: **{plugin.label}**") + st.write(f"Каталог: `{out_dir.resolve()}`") + + def st_log(msg: str) -> None: + with log: + st.write(msg) + + def st_progress(done: int, total: int, message: str) -> None: + progress_bar.progress(done / max(total, 1), text=f"{done}/{total} — {message[:60]}") + + result = run(plugin, non_empty, settings, secrets, output_dir=out_dir, log=st_log, progress=st_progress) + log.update(label=f"Готово — {len(result.items)} item(s)", state="complete") + st.session_state.last_run = {"out_dir": str(result.out_dir), "items": result.items} + except Exception as e: + log.update(label=f"Ошибка: {e}", state="error") + st.exception(e) + + _render_results() + _render_jobs_panel(plugin, inputs, settings) + + +def _render_results() -> None: + run_data = st.session_state.get("last_run") + if not run_data: + return + items: list[Item] = run_data["items"] + out_dir = Path(run_data["out_dir"]) + + st.divider() + st.subheader(f"📦 Результаты — {len(items)}") + st.caption(f"Сохранено в `{out_dir.resolve()}`") + + col1, col2 = st.columns(2) + with col1: + st.download_button( + "⬇️ Скачать всё (ZIP)", + data=_zip_directory(out_dir), + file_name=f"{out_dir.name}.zip", + mime="application/zip", + use_container_width=True, + ) + with col2: + summary = out_dir / "summary.csv" + if summary.exists(): + st.download_button( + "⬇️ summary.csv", + data=summary.read_bytes(), + file_name="summary.csv", + mime="text/csv", + use_container_width=True, + ) + + for it in items: + title = it.title or it.item_id + n_comments = len(it.comments) + has_t = bool(it.transcript and it.transcript.segments) + with st.expander(f"[{it.source}] {title} — комм.: {n_comments}, транскрипт: {'да' if has_t else 'нет'}"): + metric_pairs = " · ".join( + f"**{k}:** {v}" for k, v in it.media.items() if v is not None + ) + st.markdown( + f"**Автор:** {it.author or '—'} \n" + f"**Ссылка:** {it.url} \n" + f"**Опубликовано:** {it.published_at or '—'} \n" + + (f"{metric_pairs}\n" if metric_pairs else "") + ) + if it.transcript and it.transcript.text: + with st.expander("Транскрипт"): + st.text(it.transcript.text) + if it.comments: + with st.expander(f"Комментарии ({len(it.comments)})"): + for c in it.comments[:200]: + prefix = "↳ " if c.parent_id else "" + st.markdown(f"{prefix}**{c.author or '—'}** _({c.published_at or '—'}, ♥ {c.like_count})_") + st.write(c.text or "") + if len(it.comments) > 200: + st.caption(f"…первые 200 из {len(it.comments)}") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..18fd868 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ 
+google-api-python-client>=2.100.0 +youtube-transcript-api>=1.0.0 +streamlit>=1.30.0 +requests>=2.31.0 +praw>=7.7 +gspread>=6.0 +google-auth>=2.20 +pyyaml>=6.0 +yt-dlp>=2024.0,<2027.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_gsheets_loader.py b/tests/test_gsheets_loader.py new file mode 100644 index 0000000..d720f22 --- /dev/null +++ b/tests/test_gsheets_loader.py @@ -0,0 +1,273 @@ +"""Tests for content_parser.loaders.gsheets — URL parsing, auth, value extraction.""" +from __future__ import annotations + +import json +import unittest +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import AuthError, PluginError +from content_parser.loaders.gsheets import GoogleSheetsLoader, LoadedRange + + +VALID_CREDS = { + "type": "service_account", + "client_email": "bot@project.iam.gserviceaccount.com", + "private_key": "-----BEGIN PRIVATE KEY-----\nfake\n-----END PRIVATE KEY-----\n", + "project_id": "test-project", +} + + +def _patched_client(): + """Patch the gspread.authorize + Credentials so __init__ doesn't try real auth.""" + return patch.multiple( + "content_parser.loaders.gsheets", + # We patch _build_client itself to skip the import path + ) + + +class CredentialsValidationTest(unittest.TestCase): + def test_invalid_json_string_raises_auth_error(self): + with self.assertRaises(AuthError): + GoogleSheetsLoader("not-json-at-all") + + def test_non_dict_raises_auth_error(self): + with self.assertRaises(AuthError): + GoogleSheetsLoader("[]") + + def test_missing_required_field(self): + bad = dict(VALID_CREDS) + del bad["private_key"] + with self.assertRaises(AuthError) as cm: + GoogleSheetsLoader(json.dumps(bad)) + self.assertIn("private_key", str(cm.exception)) + + def test_oauth_client_json_rejected(self): + # OAuth client credentials have type=authorized_user, not service_account. + # We refuse them with a clear message instead of confusing field-missing errors. + oauth_client = { + "type": "authorized_user", + "client_id": "...", + "client_secret": "...", + "refresh_token": "...", + } + with self.assertRaises(AuthError) as cm: + GoogleSheetsLoader(json.dumps(oauth_client)) + self.assertIn("service account", str(cm.exception).lower()) + + def test_from_secrets_missing_token(self): + with self.assertRaises(AuthError): + GoogleSheetsLoader.from_secrets({}) + + def test_accepts_dict_directly(self): + with patch.object(GoogleSheetsLoader, "_build_client", return_value=MagicMock()): + loader = GoogleSheetsLoader(VALID_CREDS) + self.assertEqual( + loader.service_account_email(), + "bot@project.iam.gserviceaccount.com", + ) + + def test_accepts_json_string(self): + with patch.object(GoogleSheetsLoader, "_build_client", return_value=MagicMock()): + loader = GoogleSheetsLoader(json.dumps(VALID_CREDS)) + self.assertEqual( + loader.service_account_email(), + "bot@project.iam.gserviceaccount.com", + ) + + def test_validate_credentials_does_not_build_client(self): + # Crucial property for the UI: we want to validate freshly-pasted JSON + # before saving, without touching the network. _build_client must not run. 
+ with patch.object(GoogleSheetsLoader, "_build_client") as bc: + parsed = GoogleSheetsLoader.validate_credentials(json.dumps(VALID_CREDS)) + bc.assert_not_called() + self.assertEqual(parsed["client_email"], VALID_CREDS["client_email"]) + + +class SheetIdExtractionTest(unittest.TestCase): + def test_bare_id(self): + bare_id = "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890" + self.assertEqual(GoogleSheetsLoader._extract_sheet_id(bare_id), bare_id) + + def test_full_url(self): + sheet_id = "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890" + url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/edit#gid=0" + self.assertEqual(GoogleSheetsLoader._extract_sheet_id(url), sheet_id) + + def test_url_with_extra_path(self): + sheet_id = "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890" + url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/edit?usp=sharing" + self.assertEqual(GoogleSheetsLoader._extract_sheet_id(url), sheet_id) + + def test_empty_raises(self): + with self.assertRaises(PluginError): + GoogleSheetsLoader._extract_sheet_id("") + + def test_garbage_raises(self): + with self.assertRaises(PluginError): + GoogleSheetsLoader._extract_sheet_id("not-a-url-or-id") + + def test_non_google_host_rejected(self): + # Critical: a URL with /d/<id>/ on a non-google host must NOT be accepted. + # The previous version silently extracted the ID, which was confusing. + sheet_id = "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890" + with self.assertRaises(PluginError) as cm: + GoogleSheetsLoader._extract_sheet_id( + f"https://evil.example/spreadsheets/d/{sheet_id}/edit" + ) + self.assertIn("docs.google.com", str(cm.exception).lower()) + + def test_lookalike_host_rejected(self): + sheet_id = "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890" + with self.assertRaises(PluginError): + GoogleSheetsLoader._extract_sheet_id( + f"https://evildocs.google.com.attacker.com/spreadsheets/d/{sheet_id}/edit" + ) + + def test_other_google_subdomain_rejected(self): + # mail.google.com /spreadsheets/d/... shouldn't sneak through either. 
+ sheet_id = "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890" + with self.assertRaises(PluginError): + GoogleSheetsLoader._extract_sheet_id( + f"https://mail.google.com/spreadsheets/d/{sheet_id}/" + ) + + +class LoadTest(unittest.TestCase): + """Mock gspread to verify load() behavior end-to-end.""" + + def _build_loader(self): + gc = MagicMock() + with patch.object(GoogleSheetsLoader, "_build_client", return_value=gc): + loader = GoogleSheetsLoader(VALID_CREDS) + return loader, gc + + def _wire_worksheet(self, gc, *, sheet_title="My Sheet", tab_title="Communities", rows=None): + worksheet = MagicMock() + worksheet.title = tab_title + worksheet.get.return_value = rows or [] + + spreadsheet = MagicMock() + spreadsheet.title = sheet_title + spreadsheet.worksheet.return_value = worksheet + spreadsheet.sheet1 = worksheet + spreadsheet.worksheets.return_value = [worksheet] + + gc.open_by_key.return_value = spreadsheet + return spreadsheet, worksheet + + def test_basic_load_returns_flat_values(self): + loader, gc = self._build_loader() + self._wire_worksheet(gc, rows=[ + ["durov_says"], + ["telegram"], + ["awesome_community"], + ]) + + result = loader.load( + "1AbC2DeFG_HiJkLmNoPqRsTuVwXyZ-1234567890", + tab="Communities", + range_a1="A:A", + ) + self.assertIsInstance(result, LoadedRange) + self.assertEqual(result.values, ["durov_says", "telegram", "awesome_community"]) + self.assertEqual(result.count, 3) + self.assertEqual(result.sheet_title, "My Sheet") + self.assertEqual(result.tab_title, "Communities") + + def test_skip_header(self): + loader, gc = self._build_loader() + self._wire_worksheet(gc, rows=[ + ["Channel name"], # header + ["durov"], + ["telegram"], + ]) + result = loader.load("ID" * 10, range_a1="A:A", skip_header=True) + self.assertEqual(result.values, ["durov", "telegram"]) + + def test_dedup_within_range(self): + loader, gc = self._build_loader() + self._wire_worksheet(gc, rows=[ + ["durov"], + ["telegram"], + ["durov"], # duplicate + [" "], # blank — dropped + ["new_chan"], + ]) + result = loader.load("ID" * 10, range_a1="A:A") + self.assertEqual(result.values, ["durov", "telegram", "new_chan"]) + + def test_multi_column_range_flattens(self): + loader, gc = self._build_loader() + self._wire_worksheet(gc, rows=[ + ["durov", "extra1"], + ["telegram", ""], + ["", "extra3"], + ]) + result = loader.load("ID" * 10, range_a1="A:B") + # Cells flattened in row-major order, deduped + self.assertEqual(result.values, ["durov", "extra1", "telegram", "extra3"]) + + def test_invalid_range_rejected(self): + loader, gc = self._build_loader() + with self.assertRaises(PluginError): + loader.load("ID" * 10, range_a1="not a range!") + + def test_tab_not_found_lists_available(self): + loader, gc = self._build_loader() + spreadsheet, worksheet = self._wire_worksheet(gc, tab_title="ActualTab") + spreadsheet.worksheet.side_effect = Exception("WorksheetNotFound") + # worksheets() still works for diagnostic + spreadsheet.worksheets.return_value = [ + MagicMock(title="ActualTab"), + MagicMock(title="Other"), + ] + + with self.assertRaises(PluginError) as cm: + loader.load("ID" * 10, tab="WrongName", range_a1="A:A") + self.assertIn("WrongName", str(cm.exception)) + self.assertIn("ActualTab", str(cm.exception)) + + def test_403_maps_to_auth_error(self): + loader, gc = self._build_loader() + gc.open_by_key.side_effect = Exception("403 permission denied") + with self.assertRaises(AuthError) as cm: + loader.load("ID" * 10, range_a1="A:A") + self.assertIn("Share the sheet", str(cm.exception)) + + def 
test_404_maps_to_plugin_error(self): + loader, gc = self._build_loader() + gc.open_by_key.side_effect = Exception("404 not found") + with self.assertRaises(PluginError) as cm: + loader.load("ID" * 10, range_a1="A:A") + self.assertIn("not found", str(cm.exception).lower()) + + def test_default_tab_is_first_sheet(self): + loader, gc = self._build_loader() + spreadsheet, worksheet = self._wire_worksheet(gc, rows=[["x"]]) + loader.load("ID" * 10, range_a1="A:A") + # We didn't pass tab=, so it should NOT call .worksheet(), only access .sheet1 + spreadsheet.worksheet.assert_not_called() + + def test_empty_string_tab_falls_back_to_first_sheet(self): + # Cron configs may pass tab="" rather than tab=None. + loader, gc = self._build_loader() + spreadsheet, worksheet = self._wire_worksheet(gc, rows=[["x"]]) + loader.load("ID" * 10, tab="", range_a1="A:A") + spreadsheet.worksheet.assert_not_called() + + def test_whitespace_tab_falls_back(self): + loader, gc = self._build_loader() + spreadsheet, worksheet = self._wire_worksheet(gc, rows=[["x"]]) + loader.load("ID" * 10, tab=" ", range_a1="A:A") + spreadsheet.worksheet.assert_not_called() + + def test_loaded_range_has_no_raw_rows_attr(self): + # raw_rows was dropped — make sure nothing accidentally re-adds it. + loader, gc = self._build_loader() + self._wire_worksheet(gc, rows=[["x"]]) + result = loader.load("ID" * 10, range_a1="A:A") + self.assertFalse(hasattr(result, "raw_rows")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_instagram_adapter.py b/tests/test_instagram_adapter.py new file mode 100644 index 0000000..e9e5ede --- /dev/null +++ b/tests/test_instagram_adapter.py @@ -0,0 +1,110 @@ +"""Tests for content_parser.plugins.instagram.adapter — Apify post → Item.""" +from __future__ import annotations + +import unittest + +from content_parser.plugins.instagram.adapter import post_to_item + + +SAMPLE_REEL = { + "id": "9876", + "shortCode": "C_short", + "url": "https://www.instagram.com/reel/C_short/", + "type": "Video", + "caption": "Check this out\n#fun #fyp", + "ownerUsername": "creator", + "ownerId": 4242, + "timestamp": "2026-04-01T12:00:00.000Z", + "likesCount": 12345, + "commentsCount": 78, + "videoViewCount": 1_500_000, + "videoUrl": "https://cdn.example/v.mp4", + "displayUrl": "https://cdn.example/t.jpg", + "videoDuration": 17.5, + "musicInfo": { + "audio_id": "snd_42", + "song_name": "Song Name", + "artist_name": "Artist Name", + }, + "hashtags": ["fun", "fyp"], + "latestComments": [ + { + "id": "c1", + "ownerUsername": "fan1", + "text": "🔥", + "likesCount": 10, + "timestamp": "2026-04-01T13:00:00Z", + "replies": [ + {"id": "c1r1", "ownerUsername": "creator", "text": "thanks!", "likesCount": 2}, + ], + }, + {"id": "c2", "ownerUsername": "hater", "text": "meh"}, + ], + "locationName": "Somewhere", +} + + +class AdapterTest(unittest.TestCase): + def test_basic_fields(self): + item = post_to_item(SAMPLE_REEL) + self.assertEqual(item.source, "instagram") + self.assertEqual(item.item_id, "C_short") + self.assertEqual(item.url, "https://www.instagram.com/reel/C_short/") + self.assertEqual(item.author, "creator") + self.assertEqual(item.author_id, "4242") + self.assertEqual(item.title, "Check this out") # caption first line, not full caption + self.assertEqual(item.text, "Check this out\n#fun #fyp") # full caption preserved + + def test_metrics(self): + item = post_to_item(SAMPLE_REEL) + self.assertEqual(item.media["like_count"], 12345) + self.assertEqual(item.media["comment_count"], 78) + 
self.assertEqual(item.media["view_count"], 1_500_000) + self.assertEqual(item.media["video_duration"], 17.5) + self.assertEqual(item.media["audio_id"], "snd_42") + self.assertEqual(item.media["audio_title"], "Song Name") + self.assertEqual(item.media["audio_artist"], "Artist Name") + + def test_extras(self): + item = post_to_item(SAMPLE_REEL) + self.assertEqual(item.extra["hashtags"], ["fun", "fyp"]) + self.assertEqual(item.extra["location_name"], "Somewhere") + + def test_comments_flattened_with_parent_link(self): + item = post_to_item(SAMPLE_REEL) + # 2 top-level + 1 reply = 3 total, in order: top1, reply, top2 + self.assertEqual(len(item.comments), 3) + top1, reply, top2 = item.comments + self.assertIsNone(top1.parent_id) + self.assertEqual(top1.author, "fan1") + self.assertEqual(reply.parent_id, "c1") + self.assertEqual(reply.author, "creator") + self.assertIsNone(top2.parent_id) + self.assertEqual(top2.author, "hater") + + def test_image_post_no_video_metrics(self): + post = { + "id": "1", "shortCode": "IMG", "url": "https://insta/p/IMG/", + "type": "Image", "caption": "static", + "ownerUsername": "x", "timestamp": "2026-04-01T00:00:00Z", + "likesCount": 100, "commentsCount": 0, + } + item = post_to_item(post) + self.assertEqual(item.media["like_count"], 100) + self.assertNotIn("view_count", item.media) + self.assertNotIn("audio_id", item.media) + self.assertEqual(item.comments, []) + + def test_falls_back_to_id_when_no_shortcode(self): + post = {"id": "abc", "ownerUsername": "x"} + item = post_to_item(post) + self.assertEqual(item.item_id, "abc") + + def test_handles_empty_caption(self): + post = {"id": "1", "shortCode": "X"} + item = post_to_item(post) + self.assertIsNone(item.title) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_instagram_graph_adapter.py b/tests/test_instagram_graph_adapter.py new file mode 100644 index 0000000..7ce1125 --- /dev/null +++ b/tests/test_instagram_graph_adapter.py @@ -0,0 +1,158 @@ +"""Tests for content_parser.plugins.instagram_graph.adapter.""" +from __future__ import annotations + +import unittest + +from content_parser.plugins.instagram_graph.adapter import ( + _flatten_insights, + comment_to_core, + flatten_comments, + media_to_item, +) + + +SAMPLE_REEL = { + "id": "17895695668004550", + "caption": "Big launch today\nLink in bio", + "media_type": "REEL", + "media_url": "https://cdn.example/reel.mp4", + "thumbnail_url": "https://cdn.example/reel.jpg", + "permalink": "https://www.instagram.com/reel/Cabc123/", + "timestamp": "2026-04-01T12:00:00+0000", + "is_comment_enabled": True, + "comments_count": 42, + "like_count": 1234, + "owner": {"id": "17841405822304914", "username": "myaccount"}, + "insights": { + "data": [ + {"name": "plays", "values": [{"value": 100000}]}, + {"name": "reach", "values": [{"value": 80000}]}, + {"name": "saved", "values": [{"value": 500}]}, + {"name": "shares", "values": [{"value": 250}]}, + {"name": "total_interactions", "values": [{"value": 2000}]}, + ] + }, +} + + +class FlattenInsightsTest(unittest.TestCase): + def test_dict_envelope(self): + out = _flatten_insights(SAMPLE_REEL["insights"]) + self.assertEqual(out["plays"], 100000) + self.assertEqual(out["reach"], 80000) + + def test_list_form(self): + out = _flatten_insights([ + {"name": "reach", "values": [{"value": 1000}]}, + ]) + self.assertEqual(out["reach"], 1000) + + def test_empty(self): + self.assertEqual(_flatten_insights(None), {}) + self.assertEqual(_flatten_insights({}), {}) + + def test_skips_non_numeric(self): + out = 
_flatten_insights([{"name": "weird", "values": [{"value": "string"}]}]) + self.assertEqual(out, {}) + + +class MediaToItemTest(unittest.TestCase): + def test_basic_fields(self): + item = media_to_item(SAMPLE_REEL, owner_username="myaccount") + self.assertEqual(item.source, "instagram_graph") + self.assertEqual(item.item_id, "17895695668004550") + self.assertEqual(item.url, "https://www.instagram.com/reel/Cabc123/") + self.assertEqual(item.title, "Big launch today") + self.assertEqual(item.author, "myaccount") + self.assertEqual(item.author_id, "17841405822304914") + self.assertEqual(item.published_at, "2026-04-01T12:00:00+0000") + + def test_media_metrics(self): + item = media_to_item(SAMPLE_REEL) + self.assertEqual(item.media["media_type"], "REEL") + self.assertEqual(item.media["like_count"], 1234) + self.assertEqual(item.media["comments_count"], 42) + self.assertEqual(item.media["plays"], 100000) + self.assertEqual(item.media["reach"], 80000) + self.assertEqual(item.media["saved"], 500) + self.assertEqual(item.media["shares"], 250) + self.assertEqual(item.media["total_interactions"], 2000) + + def test_strips_empty(self): + post = {"id": "1", "media_type": "IMAGE", "permalink": "https://x"} + item = media_to_item(post) + # Without like_count etc., they shouldn't appear + self.assertNotIn("like_count", item.media) + self.assertEqual(item.media["media_type"], "IMAGE") + + def test_raises_on_missing_id(self): + with self.assertRaises(ValueError): + media_to_item({"caption": "no id"}) + + def test_owner_username_overrides(self): + item = media_to_item(SAMPLE_REEL, owner_username="custom") + self.assertEqual(item.author, "custom") + + def test_falls_back_to_owner_when_no_username_passed(self): + # When the caller doesn't pass owner_username, we should still see author + # via the inline owner.username. + item = media_to_item(SAMPLE_REEL) + # adapter doesn't read owner.username for author currently; author is + # only set from owner_username argument. 
Confirm explicit None when omitted: + self.assertIsNone(item.author) + # But author_id still comes from owner.id + self.assertEqual(item.author_id, "17841405822304914") + + +class CommentToCore(unittest.TestCase): + def test_top_level(self): + c = {"id": "9001", "text": "great", "username": "fan", + "timestamp": "2026-04-01T13:00:00+0000", "like_count": 5, + "user": {"id": "999"}} + out = comment_to_core(c) + self.assertEqual(out.comment_id, "9001") + self.assertIsNone(out.parent_id) + self.assertEqual(out.author, "fan") + self.assertEqual(out.author_id, "999") + self.assertEqual(out.text, "great") + self.assertEqual(out.like_count, 5) + + def test_reply_carries_parent(self): + c = {"id": "9002", "text": "thx"} + out = comment_to_core(c, parent_id="9001") + self.assertEqual(out.parent_id, "9001") + + +class FlattenCommentsTest(unittest.TestCase): + def test_with_inline_replies(self): + comments = [ + {"id": "1", "text": "top", "username": "fan", + "replies": {"data": [ + {"id": "1a", "text": "thanks!", "username": "myaccount"}, + {"id": "1b", "text": "+1", "username": "fan2"}, + ]}}, + {"id": "2", "text": "another top", "username": "fan3"}, + ] + out = flatten_comments(comments) + # 2 top-level + 2 replies = 4 total + self.assertEqual(len(out), 4) + self.assertIsNone(out[0].parent_id) + self.assertEqual(out[1].parent_id, "1") + self.assertEqual(out[2].parent_id, "1") + self.assertIsNone(out[3].parent_id) + + def test_no_replies_field(self): + comments = [{"id": "1", "text": "top", "username": "fan"}] + out = flatten_comments(comments) + self.assertEqual(len(out), 1) + + def test_empty_list(self): + self.assertEqual(flatten_comments([]), []) + + def test_skips_non_dict(self): + out = flatten_comments([{"id": "1", "text": "ok"}, "garbage", None]) + self.assertEqual(len(out), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_instagram_graph_client.py b/tests/test_instagram_graph_client.py new file mode 100644 index 0000000..cc281da --- /dev/null +++ b/tests/test_instagram_graph_client.py @@ -0,0 +1,221 @@ +"""Tests for content_parser.plugins.instagram_graph.client — Graph API HTTP client.""" +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import AuthError, PluginError, RateLimitError +from content_parser.plugins.instagram_graph.client import GraphClient + + +def _resp(payload, *, status=200, ok=True): + m = MagicMock() + m.ok = ok + m.status_code = status + m.json.return_value = payload + m.text = str(payload) + return m + + +class GraphClientTokenTest(unittest.TestCase): + """Token must reach the URL as `access_token=` query param every call.""" + + def _patched_session(self, response): + session = MagicMock() + session.get.return_value = response + return patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session), session + + def test_token_attached_to_query(self): + ctx, session = self._patched_session(_resp({"id": "1", "username": "x"})) + with ctx: + GraphClient("MY_TOKEN").get("17841", params={"fields": "username"}) + kwargs = session.get.call_args.kwargs + self.assertEqual(kwargs["params"]["access_token"], "MY_TOKEN") + self.assertEqual(kwargs["params"]["fields"], "username") + + def test_user_supplied_token_overrides_embedded(self): + # If the user passes an access_token in params, OUR client overrides it + # — defense in case a `next` URL embeds a different token. 
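+        # (The same override is exercised for pagination in
+        # GraphClientPaginationTest.test_strips_embedded_token_in_next_url.)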
+ ctx, session = self._patched_session(_resp({"id": "1"})) + with ctx: + GraphClient("REAL_TOKEN").get("path", params={"access_token": "OTHER", "fields": "id"}) + params = session.get.call_args.kwargs["params"] + self.assertEqual(params["access_token"], "REAL_TOKEN") + + +class GraphClientErrorMappingTest(unittest.TestCase): + def _patch(self, response): + session = MagicMock() + session.get.return_value = response + return patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session) + + def test_401_raises_auth(self): + with self._patch(_resp({"error": {"code": 190, "message": "invalid token"}}, status=401, ok=False)), \ + patch.object(GraphClient, "_sleep"): + with self.assertRaises(AuthError): + GraphClient("x").get("me") + + def test_permission_code_10_raises_auth(self): + with self._patch(_resp({"error": {"code": 10, "message": "no perm"}}, status=400, ok=False)), \ + patch.object(GraphClient, "_sleep"): + with self.assertRaises(AuthError) as cm: + GraphClient("x").get("me") + self.assertIn("permissions", str(cm.exception).lower()) + + def test_429_raises_rate_limit(self): + with self._patch(_resp({"error": {"code": 4, "message": "rate"}}, status=429, ok=False)), \ + patch.object(GraphClient, "_sleep"): + with self.assertRaises(RateLimitError): + GraphClient("x", max_rate_limit_retries=0).get("me") + + def test_5xx_treated_as_retryable(self): + with self._patch(_resp({}, status=503, ok=False)), \ + patch.object(GraphClient, "_sleep"): + with self.assertRaises(RateLimitError): + GraphClient("x", max_rate_limit_retries=0).get("me") + + def test_other_4xx_raises_plugin_error(self): + with self._patch(_resp({"error": {"code": 100, "message": "bad param"}}, status=400, ok=False)), \ + patch.object(GraphClient, "_sleep"): + with self.assertRaises(PluginError) as cm: + GraphClient("x").get("me") + self.assertIn("100", str(cm.exception)) + + +class GraphClientRetryTest(unittest.TestCase): + def test_retries_on_429_then_succeeds(self): + sleeps: list[float] = [] + responses = [ + _resp({"error": {"code": 4, "message": "rate"}}, status=429, ok=False), + _resp({"id": "ok"}), + ] + session = MagicMock() + session.get.side_effect = responses + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session), \ + patch.object(GraphClient, "_sleep", staticmethod(lambda s: sleeps.append(s))): + result = GraphClient("x").get("me") + self.assertEqual(result, {"id": "ok"}) + self.assertEqual(sleeps, [2.0]) + + def test_retries_on_5xx_then_succeeds(self): + sleeps: list[float] = [] + responses = [ + _resp({}, status=500, ok=False), + _resp({}, status=503, ok=False), + _resp({"id": "ok"}), + ] + session = MagicMock() + session.get.side_effect = responses + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session), \ + patch.object(GraphClient, "_sleep", staticmethod(lambda s: sleeps.append(s))): + result = GraphClient("x").get("me") + self.assertEqual(result, {"id": "ok"}) + self.assertEqual(sleeps, [2.0, 4.0]) + + def test_exhausts_retries_then_raises(self): + responses = [_resp({"error": {"code": 4, "message": "rate"}}, status=429, ok=False)] * 5 + session = MagicMock() + session.get.side_effect = responses + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session), \ + patch.object(GraphClient, "_sleep"): + with self.assertRaises(RateLimitError): + GraphClient("x", max_rate_limit_retries=2).get("me") + # 1 initial + 2 retries = 3 + 
self.assertEqual(session.get.call_count, 3) + + +class GraphClientTokenRedactionTest(unittest.TestCase): + """Network errors must not leak the access token in their messages.""" + + def test_request_exception_message_redacted(self): + import requests as rq + # Simulate the kind of message requests sometimes produces, which + # contains the full URL with ?access_token=… in the query string. + leaked_url = "https://graph.facebook.com/v19.0/me?access_token=SECRET&fields=username" + exc = rq.RequestException(f"ConnectionError(...) at {leaked_url}") + + session = MagicMock() + session.get.side_effect = exc + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session), \ + patch.object(GraphClient, "_sleep"): + client = GraphClient("SECRET") + with self.assertRaises(PluginError) as cm: + client.get("me") + msg = str(cm.exception) + self.assertNotIn("SECRET", msg) + self.assertIn("[REDACTED]", msg) + + def test_token_not_in_chained_traceback(self): + # Even with `from None`, ensure no chained __cause__ retains the secret. + import requests as rq + exc = rq.RequestException("error: token=MYSECRET") + session = MagicMock() + session.get.side_effect = exc + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session), \ + patch.object(GraphClient, "_sleep"): + try: + GraphClient("MYSECRET").get("me") + except PluginError as raised: + # Cause should be None due to `raise ... from None` + self.assertIsNone(raised.__cause__) + self.assertNotIn("MYSECRET", str(raised)) + + +class GraphClientPaginationTest(unittest.TestCase): + def test_walks_next_url_and_caps_at_max_items(self): + page1 = {"data": [{"id": str(i)} for i in range(10)], + "paging": {"next": "https://graph.facebook.com/v19.0/x/media?fields=id&after=ABC"}} + page2 = {"data": [{"id": str(i)} for i in range(10, 15)]} + session = MagicMock() + session.get.side_effect = [_resp(page1), _resp(page2)] + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session): + client = GraphClient("x") + items = client.get_paginated("17841/media", params={"fields": "id"}) + self.assertEqual(len(items), 15) + + def test_max_items_stops_early(self): + page1 = {"data": [{"id": str(i)} for i in range(10)], + "paging": {"next": "https://graph.facebook.com/v19.0/x/media?after=A"}} + session = MagicMock() + session.get.return_value = _resp(page1) + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session): + items = GraphClient("x").get_paginated("17841/media", max_items=5) + self.assertEqual(len(items), 5) + # Only ONE call — we hit max_items inside the first page + self.assertEqual(session.get.call_count, 1) + + def test_no_next_url_stops(self): + session = MagicMock() + session.get.return_value = _resp({"data": [{"id": "1"}]}) + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session): + items = GraphClient("x").get_paginated("path") + self.assertEqual(items, [{"id": "1"}]) + + def test_strips_embedded_token_in_next_url(self): + page1 = { + "data": [{"id": "1"}], + "paging": {"next": "https://graph.facebook.com/v19.0/x/media?after=A&access_token=EMBEDDED"}, + } + page2 = {"data": [{"id": "2"}]} + session = MagicMock() + session.get.side_effect = [_resp(page1), _resp(page2)] + with patch("content_parser.plugins.instagram_graph.client.requests.Session", + return_value=session): + GraphClient("OUR_TOKEN").get_paginated("x/media") + # Second call should use OUR 
token, not EMBEDDED + second_call_params = session.get.call_args_list[1].kwargs["params"] + self.assertEqual(second_call_params["access_token"], "OUR_TOKEN") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_instagram_graph_plugin.py b/tests/test_instagram_graph_plugin.py new file mode 100644 index 0000000..f9ed86c --- /dev/null +++ b/tests/test_instagram_graph_plugin.py @@ -0,0 +1,263 @@ +"""Tests for content_parser.plugins.instagram_graph.plugin.""" +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import AuthError, PluginError, RateLimitError +from content_parser.plugins.instagram_graph.plugin import InstagramGraphPlugin + + +VALID_ACCOUNT_ID = "17841405822304914" +VALID_POST_ID = "17895695668004550" + + +class ResolveTest(unittest.TestCase): + def setUp(self): + self.p = InstagramGraphPlugin() + + def test_valid_account_id(self): + specs = self.p.resolve({"account": [VALID_ACCOUNT_ID]}, {}, {}) + self.assertEqual(specs, [f"account:{VALID_ACCOUNT_ID}"]) + + def test_valid_post_id(self): + specs = self.p.resolve({"post_id": [VALID_POST_ID]}, {}, {}) + self.assertEqual(specs, [f"post:{VALID_POST_ID}"]) + + def test_dedupes(self): + specs = self.p.resolve( + {"account": [VALID_ACCOUNT_ID, VALID_ACCOUNT_ID]}, {}, {}, + ) + self.assertEqual(len(specs), 1) + + def test_rejects_non_numeric_account(self): + with self.assertRaises(PluginError): + self.p.resolve({"account": ["not_an_id"]}, {}, {}) + + def test_rejects_too_short_account(self): + with self.assertRaises(PluginError): + self.p.resolve({"account": ["123"]}, {}, {}) + + def test_rejects_url_in_account_field(self): + with self.assertRaises(PluginError): + self.p.resolve( + {"account": ["https://instagram.com/myaccount"]}, {}, {}, + ) + + def test_rejects_invalid_post_id(self): + with self.assertRaises(PluginError): + self.p.resolve({"post_id": ["abc"]}, {}, {}) + + +class FetchAuthTest(unittest.TestCase): + def test_missing_token_raises(self): + p = InstagramGraphPlugin() + with self.assertRaises(AuthError): + list(p.fetch([f"account:{VALID_ACCOUNT_ID}"], {}, {})) + + +class FetchDispatchTest(unittest.TestCase): + def setUp(self): + self.p = InstagramGraphPlugin() + self.token = "TEST_TOKEN" + self.secrets = {"INSTAGRAM_ACCESS_TOKEN": self.token} + + def _patched_client(self, calls: dict): + """calls = {(method_name, path): return_value}.""" + client = MagicMock() + client.get.side_effect = lambda path, params=None: calls[("get", path)] + + def paginated(path, params=None, max_items=None): + return calls[("paginated", path)] + + client.get_paginated.side_effect = paginated + return patch("content_parser.plugins.instagram_graph.plugin.GraphClient", return_value=client), client + + def test_account_path_calls_user_then_media(self): + media1 = { + "id": "1001", "media_type": "IMAGE", + "caption": "post one", "permalink": "https://insta/p/1", + "like_count": 100, "comments_count": 5, + "timestamp": "2026-04-01T00:00:00+0000", + "owner": {"id": VALID_ACCOUNT_ID, "username": "myaccount"}, + } + media2 = { + "id": "1002", "media_type": "REEL", + "caption": "reel two", "permalink": "https://insta/r/2", + "owner": {"id": VALID_ACCOUNT_ID, "username": "myaccount"}, + } + calls = { + ("get", VALID_ACCOUNT_ID): {"username": "myaccount"}, + ("paginated", f"{VALID_ACCOUNT_ID}/media"): [media1, media2], + ("paginated", "1001/comments"): [], + ("paginated", "1002/comments"): [], + ("get", "1001/insights"): {"data": [{"name": "reach", "values": 
[{"value": 500}]}]}, + ("get", "1002/insights"): {"data": [{"name": "plays", "values": [{"value": 9000}]}]}, + } + ctx, client = self._patched_client(calls) + with ctx: + items = list(self.p.fetch( + [f"account:{VALID_ACCOUNT_ID}"], + {"max_posts_per_account": 5, "fetch_insights": True}, + self.secrets, + )) + self.assertEqual(len(items), 2) + self.assertEqual(items[0].author, "myaccount") + self.assertEqual(items[0].media["reach"], 500) + self.assertEqual(items[1].media["plays"], 9000) + + def test_post_path_calls_media_directly(self): + media = { + "id": VALID_POST_ID, "media_type": "VIDEO", + "caption": "single", "permalink": "https://insta/p/x", + "owner": {"id": VALID_ACCOUNT_ID, "username": "myaccount"}, + } + calls = { + ("get", VALID_POST_ID): media, + ("paginated", f"{VALID_POST_ID}/comments"): [], + ("get", f"{VALID_POST_ID}/insights"): {"data": []}, + } + ctx, client = self._patched_client(calls) + with ctx: + items = list(self.p.fetch( + [f"post:{VALID_POST_ID}"], + {"fetch_insights": True}, + self.secrets, + )) + self.assertEqual(len(items), 1) + self.assertEqual(items[0].item_id, VALID_POST_ID) + + def test_dedupes_same_media_from_account_and_post(self): + media = { + "id": "1001", "media_type": "IMAGE", "caption": "x", + "permalink": "https://insta/p/1", + "owner": {"id": VALID_ACCOUNT_ID, "username": "myaccount"}, + } + calls = { + ("get", VALID_ACCOUNT_ID): {"username": "myaccount"}, + ("paginated", f"{VALID_ACCOUNT_ID}/media"): [media], + ("get", "1001"): media, + ("paginated", "1001/comments"): [], + } + ctx, _ = self._patched_client(calls) + with ctx: + items = list(self.p.fetch( + [f"account:{VALID_ACCOUNT_ID}", "post:1001"], + {"max_posts_per_account": 5, "fetch_insights": False}, + self.secrets, + )) + self.assertEqual(len(items), 1) + + def test_fetch_insights_false_skips_insight_calls(self): + media = { + "id": VALID_POST_ID, "media_type": "IMAGE", "caption": "x", + "permalink": "https://insta/p/x", + } + calls = { + ("get", VALID_POST_ID): media, + ("paginated", f"{VALID_POST_ID}/comments"): [], + } + ctx, client = self._patched_client(calls) + with ctx: + list(self.p.fetch( + [f"post:{VALID_POST_ID}"], + {"fetch_insights": False}, + self.secrets, + )) + # Verify NO call was made to /insights + for call in client.get.call_args_list: + args = call.args + self.assertNotIn("insights", args[0]) + + def test_insights_metric_selection_for_reel_vs_post(self): + """Reels use plays/total_interactions; feed posts use impressions.""" + client = MagicMock() + captured_metrics = [] + + def get_side(path, params=None): + if path.endswith("/insights"): + captured_metrics.append(params.get("metric")) + return {"data": []} + if path == VALID_POST_ID: + # Two consecutive calls for two different scenarios — the test + # asks for posts in two separate fetch calls. + return reel_media if "reel_call" in params else feed_media + return {} + + # Use simple branches: dispatch based on which call we're making. 
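The split this test encodes is easiest to see as a small selector over the media type. A minimal sketch only; the constant values and the helper name insight_metrics_for are assumptions, not the plugin's actual identifiers:

REEL_METRICS = "plays,reach,total_interactions,saved"   # assumed Reel metric set
FEED_METRICS = "impressions,reach,saved"                # assumed feed-post metric set

def insight_metrics_for(media_type: str) -> str:
    # Reels report plays/total_interactions; feed posts report impressions.
    return REEL_METRICS if media_type == "REEL" else FEED_METRICS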
+ reel_media = { + "id": VALID_POST_ID, "media_type": "REEL", "caption": "r", + "permalink": "https://insta/r", + } + feed_media = { + "id": VALID_POST_ID, "media_type": "IMAGE", "caption": "f", + "permalink": "https://insta/p", + } + + # Reel scenario + client.get.side_effect = lambda path, params=None: ( + reel_media if path == VALID_POST_ID else ( + ({"data": []}, captured_metrics.append(params.get("metric")))[0] + if path.endswith("/insights") else {} + ) + ) + client.get_paginated.return_value = [] + with patch("content_parser.plugins.instagram_graph.plugin.GraphClient", return_value=client): + list(self.p.fetch( + [f"post:{VALID_POST_ID}"], + {"fetch_insights": True, "fetch_comments": False}, + self.secrets, + )) + self.assertEqual(len(captured_metrics), 1) + # Reel metric set + self.assertIn("plays", captured_metrics[0]) + self.assertIn("total_interactions", captured_metrics[0]) + + # Feed-post scenario + captured_metrics.clear() + client.get.side_effect = lambda path, params=None: ( + feed_media if path == VALID_POST_ID else ( + ({"data": []}, captured_metrics.append(params.get("metric")))[0] + if path.endswith("/insights") else {} + ) + ) + with patch("content_parser.plugins.instagram_graph.plugin.GraphClient", return_value=client): + list(self.p.fetch( + [f"post:{VALID_POST_ID}"], + {"fetch_insights": True, "fetch_comments": False}, + self.secrets, + )) + self.assertEqual(len(captured_metrics), 1) + # Feed-post metric set — has impressions, no plays + self.assertIn("impressions", captured_metrics[0]) + self.assertNotIn("plays", captured_metrics[0]) + + def test_insights_failure_does_not_break_run(self): + media = { + "id": VALID_POST_ID, "media_type": "REEL", "caption": "x", + "permalink": "https://insta/r/x", + } + client = MagicMock() + + def get_side(path, params=None): + if path == VALID_POST_ID: + return media + if path.endswith("/insights"): + raise PluginError("permissions") + return {} + client.get.side_effect = get_side + client.get_paginated.return_value = [] + + with patch("content_parser.plugins.instagram_graph.plugin.GraphClient", return_value=client): + items = list(self.p.fetch( + [f"post:{VALID_POST_ID}"], + {"fetch_insights": True}, + self.secrets, + )) + self.assertEqual(len(items), 1) + # plays/reach should be absent because insights failed + self.assertNotIn("plays", items[0].media) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_instagram_plugin.py b/tests/test_instagram_plugin.py new file mode 100644 index 0000000..a0e9cd2 --- /dev/null +++ b/tests/test_instagram_plugin.py @@ -0,0 +1,141 @@ +"""Tests for content_parser.plugins.instagram.plugin — input validation + dispatch.""" +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import PluginError +from content_parser.plugins.instagram.plugin import InstagramPlugin + + +class NormalizeAccountTest(unittest.TestCase): + def setUp(self): + self.p = InstagramPlugin() + + def test_plain_username(self): + self.assertEqual(self.p._normalize_account("nasa"), "nasa") + + def test_at_handle(self): + self.assertEqual(self.p._normalize_account("@durov"), "durov") + + def test_url_with_query(self): + self.assertEqual( + self.p._normalize_account("https://instagram.com/zuck/?hl=en"), + "zuck", + ) + + def test_dotted_underscored(self): + self.assertEqual(self.p._normalize_account("user.name_42"), "user.name_42") + + def test_rejects_post_url(self): + with self.assertRaises(PluginError): + 
self.p._normalize_account("https://www.instagram.com/p/abc/") + + def test_rejects_reel_url(self): + with self.assertRaises(PluginError): + self.p._normalize_account("https://www.instagram.com/reel/abc/") + + def test_rejects_invalid_chars(self): + with self.assertRaises(PluginError): + self.p._normalize_account("not a valid name!") + + +class ResolveTest(unittest.TestCase): + def setUp(self): + self.p = InstagramPlugin() + + def test_specs_carry_kind_prefix(self): + specs = self.p.resolve( + { + "hashtag": ["smm", "#marketing"], + "account": ["nasa"], + "post_url": ["https://www.instagram.com/reel/xyz/"], + }, + {}, + {"APIFY_API_TOKEN": "x"}, + ) + kinds = [s.split(":", 1)[0] for s in specs] + self.assertEqual(kinds.count("hashtag"), 2) + self.assertEqual(kinds.count("account"), 1) + self.assertEqual(kinds.count("post"), 1) + + def test_dedupes(self): + specs = self.p.resolve( + {"hashtag": ["smm", "#smm", "smm"]}, {}, {"APIFY_API_TOKEN": "x"} + ) + self.assertEqual(len(specs), 1) + + def test_rejects_account_url_in_post_url_field(self): + with self.assertRaises(PluginError): + self.p.resolve( + {"post_url": ["https://www.instagram.com/nasa/"]}, + {}, + {"APIFY_API_TOKEN": "x"}, + ) + + +class FetchDispatchTest(unittest.TestCase): + """Verifies fetch() splits by input kind into separate Apify calls.""" + + def setUp(self): + self.p = InstagramPlugin() + self.fake_post = {"id": "1", "shortCode": "AAA", "url": "https://insta/p/AAA"} + + def _run(self, specs, settings=None): + settings = settings or {"max_posts_per_input": 5, "results_type": "posts"} + with patch("content_parser.plugins.instagram.plugin.ApifyClient") as MC: + MC.return_value.run_actor.return_value = [self.fake_post] + list(self.p.fetch(specs, settings, {"APIFY_API_TOKEN": "x"})) + return MC.return_value.run_actor.call_args_list + + def test_listings_only_makes_one_call(self): + calls = self._run(["account:https://insta/nasa/", "hashtag:https://insta/explore/tags/x/"]) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0][0][1]["resultsType"], "posts") + + def test_post_urls_only_makes_one_details_call(self): + calls = self._run(["post:https://insta/p/AAA/"]) + self.assertEqual(len(calls), 1) + self.assertEqual(calls[0][0][1]["resultsType"], "details") + + def test_mixed_makes_two_calls(self): + calls = self._run([ + "account:https://insta/nasa/", + "post:https://insta/p/AAA/", + ]) + self.assertEqual(len(calls), 2) + types = sorted(c[0][1]["resultsType"] for c in calls) + self.assertEqual(types, ["details", "posts"]) + + def test_results_type_setting_applied_to_listings(self): + calls = self._run( + ["account:https://insta/nasa/"], + settings={"max_posts_per_input": 5, "results_type": "comments"}, + ) + self.assertEqual(calls[0][0][1]["resultsType"], "comments") + + +class ApifyClientAuthTest(unittest.TestCase): + """ApifyClient sends the token in Authorization header, not query string.""" + + def test_uses_bearer_header(self): + from content_parser.clients.apify import ApifyClient + + with patch( + "content_parser.clients.apify.requests.post" + ) as rp: + resp = MagicMock() + resp.ok = True + resp.status_code = 200 + resp.json.return_value = [] + rp.return_value = resp + + ApifyClient("MY_TOKEN").run_actor("apify/x", {}) + + kwargs = rp.call_args.kwargs + self.assertEqual(kwargs["headers"], {"Authorization": "Bearer MY_TOKEN"}) + self.assertNotIn("token", kwargs.get("params", {})) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_jobs_cron.py b/tests/test_jobs_cron.py new file mode 100644 
index 0000000..a453e08 --- /dev/null +++ b/tests/test_jobs_cron.py @@ -0,0 +1,305 @@ +"""Tests for content_parser.jobs.cron — managed crontab block I/O.""" +from __future__ import annotations + +import shlex +import unittest +from pathlib import Path +from unittest.mock import patch + +from content_parser.jobs import cron as cron_module +from content_parser.jobs.cron import ( + BEGIN_MARKER, + END_MARKER, + CronError, + _build_block, + _strip_block, + build_command_for_job, +) +from content_parser.jobs.schema import Job + + +class StripBlockTest(unittest.TestCase): + def test_no_block_returns_unchanged(self): + text = "0 0 * * * date\n5 * * * * uptime\n" + self.assertEqual(_strip_block(text).rstrip(), text.rstrip()) + + def test_strips_block_only(self): + text = ( + "0 0 * * * date\n" + f"{BEGIN_MARKER}\n" + "0 6 * * MON cd /x && python -m content_parser.cli jobs run foo\n" + f"{END_MARKER}\n" + "5 * * * * uptime\n" + ) + result = _strip_block(text) + self.assertNotIn(BEGIN_MARKER, result) + self.assertNotIn("python -m content_parser", result) + self.assertIn("0 0 * * * date", result) + self.assertIn("5 * * * * uptime", result) + + def test_handles_block_at_start(self): + text = ( + f"{BEGIN_MARKER}\n" + "5 * * * * managed\n" + f"{END_MARKER}\n" + "5 * * * * outside\n" + ) + result = _strip_block(text) + self.assertNotIn("managed", result) + self.assertIn("outside", result) + + +class BuildBlockTest(unittest.TestCase): + def test_empty_jobs_no_block(self): + self.assertEqual(_build_block([]), []) + + def test_block_has_markers(self): + from content_parser.jobs.cron import CronEntry + lines = _build_block([ + CronEntry(schedule="0 6 * * MON", job_name="foo", command="cd / && true"), + ]) + self.assertEqual(lines[0], BEGIN_MARKER) + self.assertEqual(lines[-1], END_MARKER) + self.assertIn("0 6 * * MON", lines[1]) + self.assertIn("# job:foo", lines[1]) + + +class BuildCommandTest(unittest.TestCase): + def test_command_quotes_paths(self): + job = Job(name="my-job", source="vk", inputs={"community": ["x"]}) + cmd = build_command_for_job( + job, + project_root=Path("/path with spaces/repo"), + python_executable="/usr/bin/python3", + log_path=Path("/var/log with spaces/cron.log"), + ) + # Paths with spaces must be shell-quoted; safe paths can stay as-is. + self.assertIn(shlex.quote("/path with spaces/repo"), cmd) + self.assertIn(shlex.quote("/var/log with spaces/cron.log"), cmd) + self.assertIn("jobs run my-job", cmd) + self.assertIn("2>&1", cmd) + self.assertTrue(cmd.startswith("cd ")) + + def test_default_log_path_is_under_project_root(self): + job = Job(name="x", source="vk", inputs={"community": ["a"]}) + cmd = build_command_for_job( + job, project_root=Path("/repo"), python_executable="/usr/bin/python3" + ) + self.assertIn("/repo/output/scheduled/.cron.log", cmd) + + def test_newline_in_project_root_rejected(self): + # crontab is line-based; a literal newline inside any path would split + # the entry across lines and corrupt the file. shlex.quote does NOT + # protect against this — it just wraps the bytes in single quotes. 
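A minimal sketch of the guard these newline tests expect, reusing the CronError imported at the top of this module; the helper name _reject_newlines is an assumption:

from content_parser.jobs.cron import CronError

def _reject_newlines(value: str) -> None:
    # A newline or carriage return inside an interpolated path would start a
    # new crontab line, so refuse to build the command at all.
    if "\n" in value or "\r" in value:
        raise CronError(f"newline characters are not allowed in cron command parts: {value!r}")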
+ job = Job(name="x", source="vk", inputs={"community": ["a"]}) + with self.assertRaises(CronError) as cm: + build_command_for_job( + job, + project_root=Path("/path\nnewline/repo"), + python_executable="/usr/bin/python3", + ) + self.assertIn("newline", str(cm.exception).lower()) + + def test_newline_in_python_executable_rejected(self): + job = Job(name="x", source="vk", inputs={"community": ["a"]}) + with self.assertRaises(CronError): + build_command_for_job( + job, + project_root=Path("/repo"), + python_executable="/usr/bin/py\nthon", + ) + + def test_newline_in_log_path_rejected(self): + job = Job(name="x", source="vk", inputs={"community": ["a"]}) + with self.assertRaises(CronError): + build_command_for_job( + job, + project_root=Path("/repo"), + python_executable="/usr/bin/python3", + log_path=Path("/var/log\nbreak/cron.log"), + ) + + def test_carriage_return_also_rejected(self): + job = Job(name="x", source="vk", inputs={"community": ["a"]}) + with self.assertRaises(CronError): + build_command_for_job( + job, + project_root=Path("/repo\r/here"), + python_executable="/usr/bin/python3", + ) + + +class InstallCronTest(unittest.TestCase): + """install_cron orchestrates crontab -l → strip → write.""" + + def _patches(self, current_crontab: str = ""): + return patch.multiple( + "content_parser.jobs.cron", + _existing_crontab=lambda: current_crontab, + _write_crontab=patch.DEFAULT, + ) + + def _job(self, name="weekly", schedule="0 6 * * MON"): + return Job( + name=name, source="vk", + inputs={"community": ["a"]}, + schedule=schedule, + ) + + def test_first_install_appends_block(self): + written: dict = {} + with patch("content_parser.jobs.cron._existing_crontab", return_value="0 0 * * * date\n"), \ + patch("content_parser.jobs.cron._write_crontab", side_effect=lambda t: written.setdefault("text", t)): + entries = cron_module.install_cron( + jobs=[self._job()], + project_root=Path("/repo"), + python_executable="/usr/bin/python3", + ) + self.assertEqual(len(entries), 1) + text = written["text"] + self.assertIn("0 0 * * * date", text) + self.assertIn(BEGIN_MARKER, text) + self.assertIn(END_MARKER, text) + self.assertIn("# job:weekly", text) + + def test_idempotent_replace(self): + # Run install twice — the second run should produce the same result. 
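The idempotency asserted below follows from a strip-then-append strategy. A rough sketch that reuses _strip_block from the module under test; install_block is an illustrative name, and the real install_cron does more than this:

from content_parser.jobs.cron import _strip_block

def install_block(existing_crontab: str, block_lines: list[str]) -> str:
    # Drop any previous managed block, then append a freshly built one; running
    # this twice with the same jobs yields byte-identical crontab text.
    base = _strip_block(existing_crontab).rstrip("\n")
    parts = ([base] if base else []) + block_lines
    return "\n".join(parts) + "\n"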
+ existing_after_first = "" + written: list[str] = [] + + def write(t): + written.append(t) + nonlocal existing_after_first + existing_after_first = t + + def existing(): + return existing_after_first + + with patch("content_parser.jobs.cron._existing_crontab", side_effect=existing), \ + patch("content_parser.jobs.cron._write_crontab", side_effect=write): + cron_module.install_cron( + jobs=[self._job()], + project_root=Path("/repo"), + python_executable="/usr/bin/python3", + ) + cron_module.install_cron( + jobs=[self._job()], + project_root=Path("/repo"), + python_executable="/usr/bin/python3", + ) + + self.assertEqual(written[0], written[1]) + + def test_jobs_without_schedule_are_skipped(self): + manual_only = Job( + name="manual-only", source="vk", + inputs={"community": ["a"]}, + ) # no schedule + + written: dict = {} + with patch("content_parser.jobs.cron._existing_crontab", return_value=""), \ + patch("content_parser.jobs.cron._write_crontab", side_effect=lambda t: written.setdefault("text", t)): + entries = cron_module.install_cron( + jobs=[manual_only], + project_root=Path("/repo"), + python_executable="/usr/bin/python3", + ) + self.assertEqual(entries, []) + self.assertNotIn("# job:manual-only", written.get("text", "")) + + def test_preserves_lines_outside_block(self): + existing = ( + "0 0 * * * /home/me/backup.sh\n" + f"{BEGIN_MARKER}\n" + "old line\n" + f"{END_MARKER}\n" + "5 * * * * /usr/bin/something\n" + ) + written: dict = {} + with patch("content_parser.jobs.cron._existing_crontab", return_value=existing), \ + patch("content_parser.jobs.cron._write_crontab", side_effect=lambda t: written.setdefault("text", t)): + cron_module.install_cron( + jobs=[self._job()], + project_root=Path("/repo"), + python_executable="/usr/bin/python3", + ) + text = written["text"] + self.assertIn("/home/me/backup.sh", text) + self.assertIn("/usr/bin/something", text) + self.assertNotIn("old line", text) + + +class RemoveCronTest(unittest.TestCase): + def test_removes_when_present(self): + existing = ( + f"0 0 * * * /backup\n" + f"{BEGIN_MARKER}\n" + "managed line\n" + f"{END_MARKER}\n" + ) + written: dict = {} + with patch("content_parser.jobs.cron._existing_crontab", return_value=existing), \ + patch("content_parser.jobs.cron._write_crontab", side_effect=lambda t: written.setdefault("text", t)): + self.assertTrue(cron_module.remove_cron()) + text = written["text"] + self.assertNotIn(BEGIN_MARKER, text) + self.assertIn("/backup", text) + + def test_returns_false_when_absent(self): + with patch("content_parser.jobs.cron._existing_crontab", return_value="0 0 * * * /backup\n"), \ + patch("content_parser.jobs.cron._write_crontab") as wr: + self.assertFalse(cron_module.remove_cron()) + wr.assert_not_called() + + +class ReadBlockTest(unittest.TestCase): + def test_parses_managed_entries(self): + existing = ( + "0 0 * * * outside\n" + f"{BEGIN_MARKER}\n" + "0 6 * * MON cd /x && py -m content_parser.cli jobs run weekly # job:weekly\n" + "0 8 * * * cd /x && py -m content_parser.cli jobs run daily # job:daily\n" + f"{END_MARKER}\n" + "5 * * * * also-outside\n" + ) + with patch("content_parser.jobs.cron._existing_crontab", return_value=existing): + entries = cron_module.read_block() + self.assertEqual(len(entries), 2) + self.assertEqual(entries[0].schedule, "0 6 * * MON") + self.assertEqual(entries[0].job_name, "weekly") + self.assertEqual(entries[1].schedule, "0 8 * * *") + self.assertEqual(entries[1].job_name, "daily") + + def test_empty_when_no_block(self): + with 
patch("content_parser.jobs.cron._existing_crontab", return_value="0 0 * * * x\n"): + self.assertEqual(cron_module.read_block(), []) + + +class ExistingCrontabTest(unittest.TestCase): + def test_no_crontab_returns_empty(self): + proc = type("P", (), {})() + proc.returncode = 1 + proc.stdout = "" + proc.stderr = "no crontab for user\n" + with patch("content_parser.jobs.cron.subprocess.run", return_value=proc): + self.assertEqual(cron_module._existing_crontab(), "") + + def test_real_error_raises(self): + proc = type("P", (), {})() + proc.returncode = 1 + proc.stdout = "" + proc.stderr = "crontab: invalid option -- 'z'" + with patch("content_parser.jobs.cron.subprocess.run", return_value=proc): + with self.assertRaises(CronError): + cron_module._existing_crontab() + + def test_no_crontab_binary_raises(self): + with patch("content_parser.jobs.cron.subprocess.run", + side_effect=FileNotFoundError("crontab not installed")): + with self.assertRaises(CronError) as cm: + cron_module._existing_crontab() + self.assertIn("crontab", str(cm.exception).lower()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_jobs_runner.py b/tests/test_jobs_runner.py new file mode 100644 index 0000000..0ad92b9 --- /dev/null +++ b/tests/test_jobs_runner.py @@ -0,0 +1,251 @@ +"""Tests for content_parser.jobs.runner — input merging + failure handling.""" +from __future__ import annotations + +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import PluginError +from content_parser.core.runner import RunResult +from content_parser.jobs import runner as runner_module +from content_parser.jobs.schema import Job, SheetInput + + +class ResolveInputsTest(unittest.TestCase): + """`_resolve_inputs` merges inline values with values pulled from Sheets.""" + + def test_inline_only(self): + job = Job( + name="x", source="vk", + inputs={"community": ["a", "b"], "query": ["hello"]}, + ) + out = runner_module._resolve_inputs(job, secrets={}) + self.assertEqual(out, {"community": ["a", "b"], "query": ["hello"]}) + + def test_sheets_only(self): + job = Job( + name="x", source="vk", + sheet_inputs=[ + SheetInput(sheet="ID" * 12, target="community", range_a1="A:A"), + ], + ) + with patch("content_parser.loaders.gsheets.GoogleSheetsLoader") as MockLoader: + mock_loader = MockLoader.from_secrets.return_value + loaded = MagicMock() + loaded.values = ["x", "y", "z"] + mock_loader.load.return_value = loaded + out = runner_module._resolve_inputs( + job, secrets={"GOOGLE_SHEETS_CREDENTIALS": "fake"} + ) + self.assertEqual(out, {"community": ["x", "y", "z"]}) + + def test_inline_plus_sheets_dedupes(self): + job = Job( + name="x", source="vk", + inputs={"community": ["dup", "inline-only"]}, + sheet_inputs=[ + SheetInput(sheet="ID" * 12, target="community"), + SheetInput(sheet="ID" * 12, target="query", range_a1="B:B"), + ], + ) + with patch("content_parser.loaders.gsheets.GoogleSheetsLoader") as MockLoader: + mock_loader = MockLoader.from_secrets.return_value + + def fake_load(sheet, **kwargs): + if kwargs.get("range_a1") == "B:B": + return MagicMock(values=["search-term"]) + return MagicMock(values=["dup", "from-sheet"]) + + mock_loader.load.side_effect = fake_load + out = runner_module._resolve_inputs( + job, secrets={"GOOGLE_SHEETS_CREDENTIALS": "fake"} + ) + self.assertEqual(out["community"], ["dup", "inline-only", "from-sheet"]) + self.assertEqual(out["query"], ["search-term"]) + + def test_drops_empty_kinds(self): + # 
Inputs key with empty list shouldn't propagate. + job = Job( + name="x", source="vk", + inputs={"community": ["a"], "query": []}, + ) + out = runner_module._resolve_inputs(job, secrets={}) + self.assertEqual(out, {"community": ["a"]}) + + +class CollectSecretsTest(unittest.TestCase): + def test_includes_plugin_keys(self): + with patch("content_parser.jobs.runner.get_secret") as gs: + gs.side_effect = lambda k: {"VK_ACCESS_TOKEN": "tok"}.get(k, "") + secrets = runner_module._collect_secrets(["VK_ACCESS_TOKEN"], need_sheets=False) + self.assertEqual(secrets["VK_ACCESS_TOKEN"], "tok") + self.assertNotIn("GOOGLE_SHEETS_CREDENTIALS", secrets) + + def test_adds_sheets_when_needed(self): + with patch("content_parser.jobs.runner.get_secret") as gs: + gs.side_effect = lambda k: { + "VK_ACCESS_TOKEN": "tok", + "GOOGLE_SHEETS_CREDENTIALS": "creds", + }.get(k, "") + secrets = runner_module._collect_secrets(["VK_ACCESS_TOKEN"], need_sheets=True) + self.assertEqual(secrets["GOOGLE_SHEETS_CREDENTIALS"], "creds") + + def test_picks_up_optional_proxy_secrets(self): + with patch("content_parser.jobs.runner.get_secret") as gs: + gs.side_effect = lambda k: {"WEBSHARE_USERNAME": "u", "WEBSHARE_PASSWORD": "p"}.get(k, "") + secrets = runner_module._collect_secrets([], need_sheets=False) + self.assertEqual(secrets["WEBSHARE_USERNAME"], "u") + self.assertEqual(secrets["WEBSHARE_PASSWORD"], "p") + + +class RunJobTest(unittest.TestCase): + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_runner_")) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def _job(self, **kwargs): + defaults = dict( + name="test-job", + source="vk", + inputs={"community": ["durov_says"]}, + output_dir=str(self.tmp), + ) + defaults.update(kwargs) + return Job(**defaults) + + def test_calls_core_run_with_resolved_inputs(self): + job = self._job() + fake_plugin = MagicMock() + fake_plugin.secret_keys = ["VK_ACCESS_TOKEN"] + + with patch("content_parser.jobs.runner.get_plugin", return_value=fake_plugin), \ + patch("content_parser.jobs.runner.get_secret", return_value="tok"), \ + patch("content_parser.jobs.runner.core_run") as mock_run: + mock_run.return_value = RunResult(out_dir=self.tmp, items=[]) + result = runner_module.run_job_obj(job) + + self.assertIsInstance(result, RunResult) + kwargs = mock_run.call_args.kwargs + self.assertEqual(kwargs["inputs"]["community"], ["durov_says"]) if "inputs" in kwargs else None + # core_run was called as positional + kwargs + args, kwargs = mock_run.call_args + self.assertIs(args[0], fake_plugin) + self.assertEqual(args[1], {"community": ["durov_says"]}) + self.assertEqual(args[2], {}) + + def test_writes_last_run_marker_on_success(self): + job = self._job() + fake_plugin = MagicMock() + fake_plugin.secret_keys = [] + + with patch("content_parser.jobs.runner.get_plugin", return_value=fake_plugin), \ + patch("content_parser.jobs.runner.get_secret", return_value=""), \ + patch("content_parser.jobs.runner.core_run") as mock_run: + mock_run.return_value = RunResult(out_dir=self.tmp / "fake", items=[]) + runner_module.run_job_obj(job) + + # _record_success writes into the runner's own resolved_output_dir, which + # lives somewhere under self.tmp/<timestamp>/. 
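The marker files this test and the status-file tests below look for could be written roughly as follows; record_success is an illustrative name and the exact payload is an assumption, only the file names match the assertions:

import json
from datetime import datetime, timezone
from pathlib import Path

def record_success(job_name: str, out_dir: Path, item_count: int) -> None:
    out_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(timezone.utc).isoformat()
    (out_dir / ".last_run.txt").write_text(f"{job_name} finished at {stamp}\n")
    payload = {"status": "success", "job": job_name, "items": item_count}
    (out_dir / ".last_status.json").write_text(json.dumps(payload))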
+ markers = list(self.tmp.rglob(".last_run.txt")) + self.assertEqual(len(markers), 1) + self.assertIn("test-job", markers[0].read_text()) + + def test_writes_last_error_on_failure(self): + job = self._job() + fake_plugin = MagicMock() + fake_plugin.secret_keys = [] + + with patch("content_parser.jobs.runner.get_plugin", return_value=fake_plugin), \ + patch("content_parser.jobs.runner.get_secret", return_value=""), \ + patch("content_parser.jobs.runner.core_run", side_effect=RuntimeError("boom")): + with self.assertRaises(RuntimeError): + runner_module.run_job_obj(job) + + # The output dir was created as part of resolved_output_dir() resolution + # in _record_failure. Find it under self.tmp. + errors = list(self.tmp.rglob("last_error.txt")) + self.assertEqual(len(errors), 1) + text = errors[0].read_text() + self.assertIn("boom", text) + self.assertIn("RuntimeError", text) + + def test_record_failure_redacts_known_secret_values(self): + # Secrets passed in get scrubbed from traceback / error message. + job = self._job() + out_dir = self.tmp / "x" + try: + raise RuntimeError("call to https://api.example/?token=SECRETXYZ failed") + except RuntimeError as e: + runner_module._record_failure( + job, e, out_dir=out_dir, + secrets={"SOME_TOKEN": "SECRETXYZ", "OTHER": "shortish"}, + ) + text = (out_dir / "last_error.txt").read_text() + self.assertNotIn("SECRETXYZ", text) + self.assertIn("[REDACTED]", text) + + def test_record_failure_writes_status_file(self): + job = self._job() + out_dir = self.tmp / "y" + try: + raise ValueError("boom") + except ValueError as e: + runner_module._record_failure(job, e, out_dir=out_dir) + import json + status = json.loads((out_dir / ".last_status.json").read_text()) + self.assertEqual(status["status"], "failure") + self.assertEqual(status["job"], "test-job") + self.assertIn("ValueError", status["error"]) + + def test_record_success_writes_status_file(self): + from content_parser.core.runner import RunResult + job = self._job() + out_dir = self.tmp / "z" + out_dir.mkdir(parents=True) + runner_module._record_success(job, out_dir, RunResult(out_dir=out_dir, items=[])) + import json + status = json.loads((out_dir / ".last_status.json").read_text()) + self.assertEqual(status["status"], "success") + self.assertEqual(status["items"], 0) + + def test_notify_none_skips_error_marker(self): + job = self._job(notify_on_failure="none") + fake_plugin = MagicMock() + fake_plugin.secret_keys = [] + + with patch("content_parser.jobs.runner.get_plugin", return_value=fake_plugin), \ + patch("content_parser.jobs.runner.get_secret", return_value=""), \ + patch("content_parser.jobs.runner.core_run", side_effect=RuntimeError("boom")): + with self.assertRaises(RuntimeError): + runner_module.run_job_obj(job) + + errors = list(self.tmp.rglob("last_error.txt")) + self.assertEqual(len(errors), 0) + + def test_empty_resolved_inputs_raises(self): + # Both inline and sheets resolve to nothing → run_job refuses to call plugin. + # We can't construct a Job with truly empty inputs (validation refuses), + # but a sheet that returns nothing simulates the scenario. 
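The refusal being exercised reduces to a small guard after input resolution. A minimal sketch using the PluginError imported above; ensure_inputs is an illustrative name:

from content_parser.core.errors import PluginError

def ensure_inputs(resolved: dict[str, list[str]], job_name: str) -> None:
    # Refuse to invoke the plugin when neither inline inputs nor sheet inputs
    # produced any values.
    if not any(values for values in resolved.values()):
        raise PluginError(f"Job {job_name!r} has no resolved inputs; nothing to fetch")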
+ job = self._job( + inputs={}, + sheet_inputs=[SheetInput(sheet="ID" * 12, target="community")], + ) + fake_plugin = MagicMock() + fake_plugin.secret_keys = [] + + with patch("content_parser.jobs.runner.get_plugin", return_value=fake_plugin), \ + patch("content_parser.jobs.runner.get_secret", return_value="creds"), \ + patch("content_parser.loaders.gsheets.GoogleSheetsLoader") as MockLoader: + MockLoader.from_secrets.return_value.load.return_value = MagicMock(values=[]) + + with self.assertRaises(PluginError) as cm: + runner_module.run_job_obj(job) + self.assertIn("no resolved inputs", str(cm.exception).lower()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_jobs_schema.py b/tests/test_jobs_schema.py new file mode 100644 index 0000000..e781226 --- /dev/null +++ b/tests/test_jobs_schema.py @@ -0,0 +1,280 @@ +"""Tests for content_parser.jobs.schema — Job parsing, validation, YAML I/O.""" +from __future__ import annotations + +import unittest + +from content_parser.core.errors import PluginError +from content_parser.jobs.schema import ( + Job, + SheetInput, + dump_job_yaml, + is_valid_cron, + load_job_yaml, +) + + +VALID_YAML = """\ +name: weekly-vk +source: vk +schedule: "0 6 * * MON" +description: "Weekly VK marketing" +inputs: + community: [durov_says, telegram] +sheet_inputs: + - sheet: "https://docs.google.com/spreadsheets/d/abc/edit" + tab: Communities + range: A2:A + target: community +settings: + max_posts_per_input: 50 +""" + + +class CronValidationTest(unittest.TestCase): + def test_standard_5_token(self): + self.assertTrue(is_valid_cron("0 6 * * MON")) + self.assertTrue(is_valid_cron("*/15 * * * *")) + self.assertTrue(is_valid_cron("0 0,12 * * *")) + self.assertTrue(is_valid_cron("0 9-17 * * 1-5")) + + def test_aliases(self): + for alias in ("@yearly", "@annually", "@monthly", "@weekly", "@daily", "@hourly", "@reboot"): + with self.subTest(alias=alias): + self.assertTrue(is_valid_cron(alias)) + + def test_invalid_token_count(self): + self.assertFalse(is_valid_cron("0 6 * *")) # 4 tokens + self.assertFalse(is_valid_cron("0 6 * * MON extra")) + + def test_invalid_chars(self): + self.assertFalse(is_valid_cron("0 6 * * !")) + self.assertFalse(is_valid_cron("0 6 * * MON; rm -rf /")) + + def test_unknown_alias(self): + self.assertFalse(is_valid_cron("@bogus")) + + def test_non_string(self): + self.assertFalse(is_valid_cron(None)) # type: ignore[arg-type] + self.assertFalse(is_valid_cron(42)) # type: ignore[arg-type] + + +class JobValidationTest(unittest.TestCase): + def _base(self, **overrides): + kwargs = dict( + name="my-job", + source="vk", + inputs={"community": ["durov_says"]}, + ) + kwargs.update(overrides) + return Job(**kwargs) + + def test_minimal_valid(self): + self._base().validate() # no exception + + def test_invalid_name_chars(self): + with self.assertRaises(PluginError): + self._base(name="my job").validate() + with self.assertRaises(PluginError): + self._base(name="../etc").validate() + with self.assertRaises(PluginError): + self._base(name="").validate() + + def test_name_too_long(self): + with self.assertRaises(PluginError): + self._base(name="x" * 65).validate() + + def test_missing_source(self): + with self.assertRaises(PluginError): + self._base(source="").validate() + + def test_no_inputs_and_no_sheets_rejected(self): + with self.assertRaises(PluginError): + self._base(inputs={}).validate() + + def test_only_sheet_inputs_ok(self): + self._base( + inputs={}, + sheet_inputs=[SheetInput(sheet="X" * 25, target="community")], + ).validate() + 
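For reference, the schedule check used by the next test (and by CronValidationTest above) can be approximated with a token whitelist. A sketch only; the real is_valid_cron may be stricter about field ranges:

import re

_ALIASES = {"@yearly", "@annually", "@monthly", "@weekly", "@daily", "@hourly", "@reboot"}
_FIELD = re.compile(r"^[\dA-Za-z*,/-]+$")

def is_valid_cron_sketch(expr) -> bool:
    if not isinstance(expr, str):
        return False
    if expr.startswith("@"):
        return expr in _ALIASES
    fields = expr.split()
    return len(fields) == 5 and all(_FIELD.match(f) for f in fields)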
+ def test_invalid_schedule(self): + with self.assertRaises(PluginError): + self._base(schedule="not a cron").validate() + + def test_inputs_not_list_rejected(self): + with self.assertRaises(PluginError): + self._base(inputs={"community": "not a list"}).validate() # type: ignore[arg-type] + + def test_output_dir_with_dotdot_rejected(self): + # Path traversal: 'output_dir: ../../etc' rejected at validation. + with self.assertRaises(PluginError) as cm: + self._base(output_dir="../../etc").validate() + self.assertIn("..", str(cm.exception)) + + def test_output_dir_normal_relative_ok(self): + self._base(output_dir="custom/scheduled").validate() # no exception + + def test_output_dir_absolute_ok(self): + self._base(output_dir="/tmp/my-output").validate() # user explicitly opted in + + def test_output_dir_dotdot_in_middle_rejected(self): + # ../../ at any position is rejected, not just at start. + with self.assertRaises(PluginError): + self._base(output_dir="custom/../escape").validate() + + def test_sheet_input_missing_sheet(self): + with self.assertRaises(PluginError): + self._base( + sheet_inputs=[SheetInput(sheet="", target="community")], + ).validate() + + def test_sheet_input_missing_target(self): + with self.assertRaises(PluginError): + self._base( + sheet_inputs=[SheetInput(sheet="X" * 25, target="")], + ).validate() + + def test_invalid_notify_value(self): + with self.assertRaises(PluginError): + self._base(notify_on_failure="email").validate() + + +class YamlRoundTripTest(unittest.TestCase): + def test_load_full(self): + job = load_job_yaml(VALID_YAML) + self.assertEqual(job.name, "weekly-vk") + self.assertEqual(job.source, "vk") + self.assertEqual(job.schedule, "0 6 * * MON") + self.assertEqual(job.inputs, {"community": ["durov_says", "telegram"]}) + self.assertEqual(len(job.sheet_inputs), 1) + self.assertEqual(job.sheet_inputs[0].target, "community") + self.assertEqual(job.sheet_inputs[0].range_a1, "A2:A") + self.assertEqual(job.settings, {"max_posts_per_input": 50}) + + def test_dump_then_load_idempotent(self): + original = load_job_yaml(VALID_YAML) + dumped = dump_job_yaml(original) + reloaded = load_job_yaml(dumped) + self.assertEqual(reloaded.name, original.name) + self.assertEqual(reloaded.source, original.source) + self.assertEqual(reloaded.schedule, original.schedule) + self.assertEqual(reloaded.inputs, original.inputs) + self.assertEqual(len(reloaded.sheet_inputs), 1) + + def test_uses_safe_load(self): + # YAML with a !!python/object directive must be rejected by safe_load. 
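The loader behaviour relied on here is yaml.safe_load plus a type check on the document root. A sketch under that assumption; parse_job_mapping is an illustrative name for the first stage of load_job_yaml:

import yaml

from content_parser.core.errors import PluginError

def parse_job_mapping(text: str) -> dict:
    try:
        data = yaml.safe_load(text)  # refuses !!python/object/apply payloads
    except yaml.YAMLError as exc:
        raise PluginError(f"Invalid job YAML: {exc}") from exc
    if not isinstance(data, dict):
        raise PluginError("Job YAML must be a mapping at the top level")
    return data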
+ evil = """\ +!!python/object/apply:os.system +- 'echo PWNED' +""" + with self.assertRaises(PluginError): + load_job_yaml(evil) + + def test_empty_doc_rejected(self): + with self.assertRaises(PluginError): + load_job_yaml("") + + def test_top_level_list_rejected(self): + with self.assertRaises(PluginError): + load_job_yaml("- a\n- b\n") + + def test_name_hint_used_when_body_missing_name(self): + yaml_no_name = """\ +source: vk +inputs: + community: [durov_says] +""" + job = load_job_yaml(yaml_no_name, name_hint="from-filename") + self.assertEqual(job.name, "from-filename") + + def test_range_alt_key_supported(self): + # Some users will write `range_a1` instead of `range` + yaml_alt = """\ +name: x +source: vk +sheet_inputs: + - sheet: longidlongidlongidlongidlong + target: community + range_a1: B:B +""" + job = load_job_yaml(yaml_alt) + self.assertEqual(job.sheet_inputs[0].range_a1, "B:B") + + def test_string_value_in_inputs_rejected(self): + # Common typo: `community: durov_says` (no brackets) — without the + # type-check the loop would iterate the string character by character + # and produce ['d','u','r','o','v', ...]. Must raise instead. + evil = """\ +name: my-job +source: vk +inputs: + community: durov_says +""" + with self.assertRaises(PluginError) as cm: + load_job_yaml(evil) + self.assertIn("must be a list", str(cm.exception)) + self.assertIn("community", str(cm.exception)) + + def test_int_value_in_inputs_rejected(self): + evil = """\ +name: x +source: vk +inputs: + community: 42 +""" + with self.assertRaises(PluginError): + load_job_yaml(evil) + + def test_dict_value_in_inputs_rejected(self): + evil = """\ +name: x +source: vk +inputs: + community: {nested: dict} +""" + with self.assertRaises(PluginError): + load_job_yaml(evil) + + def test_empty_value_in_inputs_treated_as_empty_list(self): + # `inputs.community:` with no value yields None. Acceptable as empty list. 
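The coercion rules in the surrounding inputs tests fit in one normaliser. A minimal sketch; normalize_input_values is an illustrative name:

from content_parser.core.errors import PluginError

def normalize_input_values(kind: str, value) -> list[str]:
    if value is None:
        return []  # a bare `community:` key becomes an empty list
    if isinstance(value, list):
        return [str(v) for v in value]
    # Strings, ints and dicts are almost always typos; fail loudly instead of
    # iterating a string character by character.
    raise PluginError(f"inputs.{kind} must be a list, got {type(value).__name__}")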
+ yaml_empty = """\ +name: x +source: vk +inputs: + community: +sheet_inputs: + - sheet: longidlongidlongidlongidlong + target: community +""" + job = load_job_yaml(yaml_empty) + self.assertEqual(job.inputs["community"], []) + + +class OutputDirTest(unittest.TestCase): + def test_default_output_dir(self): + job = Job(name="my-job", source="vk", inputs={"community": ["x"]}) + path = job.resolved_output_dir(timestamp="20260101_120000") + self.assertEqual(path.parts[-3:], ("scheduled", "my-job", "20260101_120000")) + + def test_relative_output_dir_resolved_against_cwd(self): + job = Job( + name="my-job", source="vk", + inputs={"community": ["x"]}, + output_dir="custom/path", + ) + path = job.resolved_output_dir(timestamp="20260101_120000") + self.assertTrue(path.is_absolute()) + self.assertEqual(path.parts[-3:], ("custom", "path", "20260101_120000")) + + def test_absolute_output_dir_used_as_is(self): + job = Job( + name="my-job", source="vk", + inputs={"community": ["x"]}, + output_dir="/tmp/abs/path", + ) + path = job.resolved_output_dir(timestamp="20260101_120000") + self.assertEqual(str(path), "/tmp/abs/path/20260101_120000") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_jobs_store.py b/tests/test_jobs_store.py new file mode 100644 index 0000000..2febf55 --- /dev/null +++ b/tests/test_jobs_store.py @@ -0,0 +1,84 @@ +"""Tests for content_parser.jobs.store — CRUD with path-traversal guard.""" +from __future__ import annotations + +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from content_parser.core.errors import PluginError +from content_parser.jobs import store as store_module +from content_parser.jobs.schema import Job, SheetInput + + +class StoreTest(unittest.TestCase): + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_jobs_")) + # Override JOBS_DIR for the duration of each test. 
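The path-traversal guard named in this module's docstring (and exercised by test_invalid_name_rejected below) amounts to validating the bare job name before it ever touches the filesystem. A sketch with assumed names; the real _job_path may use different limits:

import re
from pathlib import Path

from content_parser.core.errors import PluginError

_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]{0,63}$")

def job_path_for(jobs_dir: Path, name: str) -> Path:
    # "../etc", "foo/bar", names with spaces and over-long names all fail the
    # pattern, so a crafted name can never escape the jobs directory.
    if not _NAME_RE.match(name):
        raise PluginError(f"Invalid job name: {name!r}")
    return jobs_dir / f"{name}.yaml"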
+ self._orig = store_module.JOBS_DIR + store_module.JOBS_DIR = self.tmp + + def tearDown(self): + store_module.JOBS_DIR = self._orig + shutil.rmtree(self.tmp, ignore_errors=True) + + def _job(self, name="my-job"): + return Job(name=name, source="vk", inputs={"community": ["durov_says"]}) + + def test_save_and_load(self): + path = store_module.save_job(self._job()) + self.assertTrue(path.exists()) + loaded = store_module.load_job("my-job") + self.assertEqual(loaded.name, "my-job") + self.assertEqual(loaded.inputs["community"], ["durov_says"]) + + def test_save_sets_chmod_600(self): + import os + path = store_module.save_job(self._job()) + mode = os.stat(path).st_mode & 0o777 + self.assertEqual(mode, 0o600) + + def test_list_jobs_returns_sorted(self): + store_module.save_job(self._job("zebra")) + store_module.save_job(self._job("alpha")) + names = [j.name for j in store_module.list_jobs()] + self.assertEqual(names, ["alpha", "zebra"]) + + def test_list_jobs_skips_invalid_files(self): + # Valid job + store_module.save_job(self._job("good")) + # Invalid file (won't parse as a Job) + (self.tmp / "broken.yaml").write_text("not: a: valid: job", encoding="utf-8") + names = [j.name for j in store_module.list_jobs()] + self.assertEqual(names, ["good"]) + + def test_list_invalid_returns_pairs(self): + (self.tmp / "broken.yaml").write_text("source: \nimports:\n", encoding="utf-8") + invalid = store_module.list_invalid() + self.assertEqual(len(invalid), 1) + self.assertEqual(invalid[0][0], "broken") + + def test_load_missing_raises(self): + with self.assertRaises(PluginError): + store_module.load_job("nonexistent") + + def test_delete(self): + store_module.save_job(self._job()) + self.assertTrue(store_module.delete_job("my-job")) + self.assertFalse(store_module.delete_job("my-job")) + + def test_invalid_name_rejected(self): + for bad in ("../etc", "name with spaces", "x" * 100, "foo/bar", ""): + with self.subTest(bad=bad): + with self.assertRaises(PluginError): + store_module._job_path(bad) + + def test_job_exists(self): + self.assertFalse(store_module.job_exists("nope")) + store_module.save_job(self._job()) + self.assertTrue(store_module.job_exists("my-job")) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_legacy_cli.py b/tests/test_legacy_cli.py new file mode 100644 index 0000000..f532079 --- /dev/null +++ b/tests/test_legacy_cli.py @@ -0,0 +1,52 @@ +"""Tests for youtube_parser.main — legacy CLI translation to new content_parser.cli.""" +from __future__ import annotations + +import unittest + +from youtube_parser.main import _build_legacy_parser, _to_new_argv + + +class LegacyCliTest(unittest.TestCase): + def _translate(self, args_list): + ns = _build_legacy_parser().parse_args(args_list) + return _to_new_argv(ns) + + def test_basic_video(self): + new = self._translate(["--video", "https://youtu.be/x"]) + self.assertIn("run", new) + self.assertIn("--source", new) + self.assertIn("youtube", new) + self.assertIn("--video", new) + self.assertIn("https://youtu.be/x", new) + + def test_output_passed(self): + new = self._translate(["--video", "x", "--output", "/tmp/out"]) + self.assertIn("--output", new) + self.assertIn("/tmp/out", new) + + def test_settings_translated(self): + new = self._translate([ + "--video", "x", + "--max-comments", "50", + "--include-replies", + "--no-transcripts", + "--comment-order", "time", + ]) + self.assertIn("max_comments=50", new) + self.assertIn("include_replies=true", new) + self.assertIn("fetch_transcripts=false", new) + 
self.assertIn("comment_order=time", new) + + def test_no_comments_flag(self): + new = self._translate(["--video", "x", "--no-comments"]) + self.assertIn("fetch_comments=false", new) + + def test_default_settings_when_unset(self): + new = self._translate(["--video", "x"]) + self.assertIn("fetch_transcripts=true", new) + self.assertIn("fetch_comments=true", new) + self.assertIn("include_replies=false", new) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reddit_adapter.py b/tests/test_reddit_adapter.py new file mode 100644 index 0000000..d9294c3 --- /dev/null +++ b/tests/test_reddit_adapter.py @@ -0,0 +1,158 @@ +"""Tests for content_parser.plugins.reddit.adapter — PRAW objects → Item/Comment.""" +from __future__ import annotations + +import unittest +from types import SimpleNamespace + +from content_parser.plugins.reddit.adapter import ( + _author_str, + _iso, + comment_to_core, + submission_to_item, +) + + +def _redditor(name: str | None) -> SimpleNamespace | None: + if name is None: + return None + return SimpleNamespace(name=name) + + +def _subreddit(name: str) -> SimpleNamespace: + return SimpleNamespace(display_name=name) + + +def _make_submission(**overrides): + base = dict( + id="abc123", + title="Some Post", + author=_redditor("alice"), + author_fullname="t2_111", + created_utc=1_700_000_000.0, + permalink="/r/python/comments/abc123/some_post/", + url="https://reddit.com/r/python/comments/abc123/some_post/", + is_self=True, + selftext="Body text.", + score=1234, + upvote_ratio=0.95, + num_comments=88, + num_crossposts=2, + subreddit=_subreddit("python"), + link_flair_text="Discussion", + is_video=False, + over_18=False, + spoiler=False, + locked=False, + domain="self.python", + ) + base.update(overrides) + return SimpleNamespace(**base) + + +class IsoTest(unittest.TestCase): + def test_iso(self): + self.assertTrue(_iso(1_700_000_000.0).startswith("2023-")) + + def test_iso_none(self): + self.assertIsNone(_iso(None)) + + +class AuthorTest(unittest.TestCase): + def test_redditor_object(self): + self.assertEqual(_author_str(_redditor("bob")), "bob") + + def test_deleted(self): + self.assertEqual(_author_str(None), "[deleted]") + + +class SubmissionToItemTest(unittest.TestCase): + def test_self_post_basic_fields(self): + item = submission_to_item(_make_submission()) + self.assertEqual(item.source, "reddit") + self.assertEqual(item.item_id, "abc123") + self.assertEqual(item.title, "Some Post") + self.assertEqual(item.author, "alice") + self.assertEqual(item.author_id, "t2_111") + self.assertEqual(item.url, "https://reddit.com/r/python/comments/abc123/some_post/") + self.assertEqual(item.text, "Body text.") + self.assertTrue(item.published_at.startswith("2023-")) + + def test_self_post_metrics(self): + item = submission_to_item(_make_submission()) + self.assertEqual(item.media["score"], 1234) + self.assertEqual(item.media["upvote_ratio"], 0.95) + self.assertEqual(item.media["num_comments"], 88) + self.assertEqual(item.media["subreddit"], "python") + self.assertEqual(item.media["flair"], "Discussion") + self.assertTrue(item.media["is_self"]) + + def test_self_post_does_not_set_url_external(self): + item = submission_to_item(_make_submission()) + self.assertNotIn("url_external", item.media) + + def test_link_post_sets_url_external(self): + sub = _make_submission( + is_self=False, + selftext="", + url="https://example.com/article", + domain="example.com", + ) + item = submission_to_item(sub) + self.assertEqual(item.media["url_external"], 
"https://example.com/article") + self.assertFalse(item.media["is_self"]) + self.assertIsNone(item.text) + + def test_deleted_author(self): + item = submission_to_item(_make_submission(author=None)) + self.assertEqual(item.author, "[deleted]") + + def test_nsfw_and_locked_flags(self): + item = submission_to_item(_make_submission(over_18=True, locked=True, spoiler=True)) + self.assertTrue(item.media["over_18"]) + self.assertTrue(item.media["locked"]) + self.assertTrue(item.media["spoiler"]) + + def test_strips_none_metrics(self): + sub = _make_submission() + sub.upvote_ratio = None # type: ignore[attr-defined] + item = submission_to_item(sub) + self.assertNotIn("upvote_ratio", item.media) + + def test_url_from_permalink_when_external_missing(self): + item = submission_to_item(_make_submission(url=None, is_self=True)) + self.assertEqual(item.url, "https://reddit.com/r/python/comments/abc123/some_post/") + + def test_awards_in_extra(self): + sub = _make_submission(all_awardings=[{"name": "Gold", "count": 2}]) + item = submission_to_item(sub) + self.assertEqual(item.extra["awards"], [{"name": "Gold", "count": 2}]) + + +class CommentToCoreTest(unittest.TestCase): + def test_basic(self): + c = SimpleNamespace( + id="cmt1", + author=_redditor("bob"), + author_fullname="t2_222", + body="great post", + score=42, + created_utc=1_700_000_500.0, + ) + out = comment_to_core(c, parent_id=None) + self.assertEqual(out.comment_id, "cmt1") + self.assertIsNone(out.parent_id) + self.assertEqual(out.author, "bob") + self.assertEqual(out.author_id, "t2_222") + self.assertEqual(out.text, "great post") + self.assertEqual(out.like_count, 42) + self.assertTrue(out.published_at.startswith("2023-")) + + def test_reply_carries_parent_id(self): + c = SimpleNamespace(id="cmt2", author=None, body="ok", score=0, created_utc=0) + out = comment_to_core(c, parent_id="cmt1") + self.assertEqual(out.parent_id, "cmt1") + self.assertEqual(out.author, "[deleted]") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reddit_plugin.py b/tests/test_reddit_plugin.py new file mode 100644 index 0000000..e49ea3f --- /dev/null +++ b/tests/test_reddit_plugin.py @@ -0,0 +1,375 @@ +"""Tests for content_parser.plugins.reddit.plugin — input validation + dispatch.""" +from __future__ import annotations + +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import PluginError +from content_parser.plugins.reddit.plugin import RedditPlugin + + +class NormalizeSubredditTest(unittest.TestCase): + def setUp(self): + self.p = RedditPlugin() + + def test_plain_name(self): + self.assertEqual(self.p._normalize_subreddit("python"), "python") + + def test_strips_r_prefix(self): + self.assertEqual(self.p._normalize_subreddit("r/python"), "python") + self.assertEqual(self.p._normalize_subreddit("R/Python"), "Python") + + def test_url(self): + self.assertEqual( + self.p._normalize_subreddit("https://reddit.com/r/MachineLearning/"), + "MachineLearning", + ) + + def test_rejects_bad_chars(self): + with self.assertRaises(PluginError): + self.p._normalize_subreddit("not a sub!") + + def test_rejects_user_url(self): + with self.assertRaises(PluginError): + self.p._normalize_subreddit("https://reddit.com/u/spez/") + + def test_rejects_non_reddit_host_in_url(self): + # path looks valid but host is wrong → should be rejected + with self.assertRaises(PluginError): + self.p._normalize_subreddit("https://evil.example/r/python/") + + +class 
NormalizeUserTest(unittest.TestCase): + def setUp(self): + self.p = RedditPlugin() + + def test_plain(self): + self.assertEqual(self.p._normalize_user("spez"), "spez") + + def test_u_prefix(self): + self.assertEqual(self.p._normalize_user("/u/spez"), "spez") + self.assertEqual(self.p._normalize_user("u/spez"), "spez") + self.assertEqual(self.p._normalize_user("/user/spez"), "spez") + + def test_url(self): + self.assertEqual(self.p._normalize_user("https://reddit.com/u/spez/"), "spez") + + def test_at_prefix(self): + self.assertEqual(self.p._normalize_user("@spez"), "spez") + + def test_rejects_too_short(self): + with self.assertRaises(PluginError): + self.p._normalize_user("ab") + + def test_rejects_non_reddit_host_in_url(self): + with self.assertRaises(PluginError): + self.p._normalize_user("https://evil.example/u/spez/") + + +class IsRedditPostUrlTest(unittest.TestCase): + def setUp(self): + self.p = RedditPlugin() + + def test_canonical_post_url(self): + self.assertTrue(self.p._is_reddit_post_url( + "https://www.reddit.com/r/python/comments/abc/title/" + )) + + def test_subreddit_listing_is_not_post(self): + self.assertFalse(self.p._is_reddit_post_url("https://reddit.com/r/python/")) + + def test_other_host_rejected(self): + self.assertFalse(self.p._is_reddit_post_url( + "https://example.com/r/python/comments/abc/title/" + )) + + def test_lookalike_host_rejected(self): + # 'reddit.com' as substring of a different domain must not pass. + self.assertFalse(self.p._is_reddit_post_url( + "https://evilreddit.com/r/python/comments/abc/title/" + )) + self.assertFalse(self.p._is_reddit_post_url( + "https://reddit.com.evil.example/r/python/comments/abc/title/" + )) + + def test_subdomain_accepted(self): + # Real Reddit subdomains like old.reddit.com should pass + self.assertTrue(self.p._is_reddit_post_url( + "https://old.reddit.com/r/python/comments/abc/title/" + )) + + +class ResolveTest(unittest.TestCase): + def setUp(self): + self.p = RedditPlugin() + + def test_specs_carry_kind_prefix(self): + specs = self.p.resolve( + { + "subreddit": ["python", "r/MachineLearning"], + "query": ["claude code"], + "post_url": ["https://www.reddit.com/r/python/comments/abc/title/"], + "user": ["/u/spez"], + }, + {}, + {"REDDIT_CLIENT_ID": "x", "REDDIT_CLIENT_SECRET": "y"}, + ) + kinds = [s.split(":", 1)[0] for s in specs] + self.assertEqual(kinds.count("subreddit"), 2) + self.assertEqual(kinds.count("query"), 1) + self.assertEqual(kinds.count("post_url"), 1) + self.assertEqual(kinds.count("user"), 1) + + def test_dedupes(self): + specs = self.p.resolve( + {"subreddit": ["python", "r/python", "PYTHON"]}, + {}, + {"REDDIT_CLIENT_ID": "x", "REDDIT_CLIENT_SECRET": "y"}, + ) + # python and r/python normalize to same; PYTHON stays as-is (case-sensitive on Reddit). 
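The lookalike-host rejections in IsRedditPostUrlTest above come down to comparing the parsed hostname rather than substring matching. A minimal sketch; is_reddit_host is an illustrative name:

from urllib.parse import urlparse

def is_reddit_host(url: str) -> bool:
    host = (urlparse(url).hostname or "").lower()
    # Accept reddit.com and genuine subdomains such as old.reddit.com;
    # "evilreddit.com" and "reddit.com.evil.example" both fail this check.
    return host == "reddit.com" or host.endswith(".reddit.com")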
+ self.assertEqual(len(specs), 2) + + def test_rejects_listing_url_in_post_field(self): + with self.assertRaises(PluginError): + self.p.resolve( + {"post_url": ["https://reddit.com/r/python/"]}, + {}, + {"REDDIT_CLIENT_ID": "x", "REDDIT_CLIENT_SECRET": "y"}, + ) + + +class CommentCollectionTest(unittest.TestCase): + """Verify _collect_comments respects depth + max_comments + expand_more.""" + + def setUp(self): + self.p = RedditPlugin() + + def _comment(self, id, score=10, replies=None): + return SimpleNamespace( + id=id, + author=SimpleNamespace(name=f"u_{id}"), + author_fullname=f"t2_{id}", + body=f"text {id}", + score=score, + created_utc=1_700_000_000.0, + replies=replies or [], + ) + + def _submission_with_comments(self, comments, expand_called): + comments_obj = MagicMock() + comments_obj.__iter__ = lambda self_: iter(comments) + def replace_more(limit): + expand_called.append(limit) + comments_obj.replace_more = replace_more + return SimpleNamespace(comments=comments_obj) + + def test_top_level_only(self): + replies = [self._comment("r1"), self._comment("r2")] + top = [self._comment("c1", replies=replies), self._comment("c2")] + called = [] + sub = self._submission_with_comments(top, called) + + out = self.p._collect_comments(sub, max_comments=100, depth="top_level", expand_more=False) + ids = [c.comment_id for c in out] + self.assertEqual(ids, ["c1", "c2"]) + self.assertEqual(called, [0]) # replace_more(limit=0) when expand_more=False + + def test_all_depth_walks_replies(self): + replies = [self._comment("r1"), self._comment("r2")] + top = [self._comment("c1", replies=replies), self._comment("c2")] + called = [] + sub = self._submission_with_comments(top, called) + + out = self.p._collect_comments(sub, max_comments=100, depth="all", expand_more=False) + ids = [c.comment_id for c in out] + self.assertEqual(ids, ["c1", "r1", "r2", "c2"]) + # parent_id linkage + self.assertIsNone(out[0].parent_id) # c1 + self.assertEqual(out[1].parent_id, "c1") # r1 under c1 + self.assertEqual(out[2].parent_id, "c1") # r2 under c1 + self.assertIsNone(out[3].parent_id) # c2 + + def test_max_comments_caps(self): + top = [self._comment(f"c{i}") for i in range(10)] + called = [] + sub = self._submission_with_comments(top, called) + out = self.p._collect_comments(sub, max_comments=3, depth="top_level", expand_more=False) + self.assertEqual(len(out), 3) + + def test_expand_more_true_passes_capped_limit(self): + from content_parser.plugins.reddit.plugin import _MAX_REPLACE_MORE + + called = [] + sub = self._submission_with_comments([], called) + self.p._collect_comments(sub, max_comments=100, depth="top_level", expand_more=True) + # Should pass the hard cap, not None — unbounded expansion is unsafe. 
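The cap asserted just below typically maps onto a bounded replace_more call. A rough sketch, reusing the _MAX_REPLACE_MORE constant imported in this test; expand_comment_forest is an illustrative name:

from content_parser.plugins.reddit.plugin import _MAX_REPLACE_MORE

def expand_comment_forest(submission, expand_more: bool) -> None:
    # limit=0 drops "load more comments" stubs entirely; a finite limit keeps
    # expansion bounded instead of handing PRAW limit=None (unbounded).
    submission.comments.replace_more(limit=_MAX_REPLACE_MORE if expand_more else 0)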
+ self.assertEqual(called, [_MAX_REPLACE_MORE]) + self.assertNotIn(None, called) + + +class RedactSpecTest(unittest.TestCase): + """_redact_spec strips query/fragment and caps length so logs stay safe.""" + + def test_strips_query_string(self): + from content_parser.core.redact import redact_spec as _redact_spec + out = _redact_spec("post_url:https://reddit.com/r/x/?token=secret&foo=1") + self.assertNotIn("token", out) + self.assertNotIn("secret", out) + self.assertIn("?…", out) + + def test_strips_fragment(self): + from content_parser.core.redact import redact_spec as _redact_spec + out = _redact_spec("post_url:https://reddit.com/x#access_token=xxx") + self.assertNotIn("access_token", out) + self.assertNotIn("xxx", out) + self.assertIn("#…", out) + + def test_truncates_long(self): + from content_parser.core.redact import redact_spec as _redact_spec + spec = "subreddit:" + "a" * 200 + out = _redact_spec(spec) + self.assertLessEqual(len(out), 80) + self.assertTrue(out.endswith("…")) + + def test_short_unchanged(self): + from content_parser.core.redact import redact_spec as _redact_spec + self.assertEqual(_redact_spec("subreddit:python"), "subreddit:python") + + +class UserAgentWarningTest(unittest.TestCase): + """build_reddit warns when REDDIT_USER_AGENT is missing.""" + + def _fake_praw_module(self): + fake = MagicMock() + fake.Reddit.return_value = MagicMock() + return fake + + def test_warns_when_user_agent_missing(self): + from unittest.mock import patch + + fake = self._fake_praw_module() + with patch.dict("sys.modules", {"praw": fake}): + from content_parser.plugins.reddit.client import build_reddit, DEFAULT_USER_AGENT + + with self.assertLogs("content_parser.plugins.reddit.client", level="WARNING") as cm: + build_reddit({"REDDIT_CLIENT_ID": "x", "REDDIT_CLIENT_SECRET": "y"}) + + self.assertTrue(any("REDDIT_USER_AGENT" in line for line in cm.output)) + kwargs = fake.Reddit.call_args.kwargs + self.assertEqual(kwargs["user_agent"], DEFAULT_USER_AGENT) + + def test_no_warning_when_set(self): + from unittest.mock import patch + + fake = self._fake_praw_module() + with patch.dict("sys.modules", {"praw": fake}): + from content_parser.plugins.reddit.client import build_reddit + + with self.assertNoLogs("content_parser.plugins.reddit.client", level="WARNING"): + build_reddit({ + "REDDIT_CLIENT_ID": "x", + "REDDIT_CLIENT_SECRET": "y", + "REDDIT_USER_AGENT": "myapp:v1 by /u/me", + }) + + kwargs = fake.Reddit.call_args.kwargs + self.assertEqual(kwargs["user_agent"], "myapp:v1 by /u/me") + + def test_whitespace_user_agent_treated_as_empty(self): + from unittest.mock import patch + + fake = self._fake_praw_module() + with patch.dict("sys.modules", {"praw": fake}): + from content_parser.plugins.reddit.client import build_reddit, DEFAULT_USER_AGENT + + with self.assertLogs("content_parser.plugins.reddit.client", level="WARNING"): + build_reddit({ + "REDDIT_CLIENT_ID": "x", + "REDDIT_CLIENT_SECRET": "y", + "REDDIT_USER_AGENT": " ", + }) + + kwargs = fake.Reddit.call_args.kwargs + self.assertEqual(kwargs["user_agent"], DEFAULT_USER_AGENT) + + +class FetchAuthGuardTest(unittest.TestCase): + """fetch() raises AuthError early without secrets.""" + + def test_missing_client_id(self): + from content_parser.core.errors import AuthError + p = RedditPlugin() + with self.assertRaises(AuthError): + list(p.fetch(["subreddit:python"], {}, {"REDDIT_CLIENT_SECRET": "y"})) + + def test_missing_client_secret(self): + from content_parser.core.errors import AuthError + p = RedditPlugin() + with 
self.assertRaises(AuthError):
+            list(p.fetch(["subreddit:python"], {}, {"REDDIT_CLIENT_ID": "x"}))
+
+
+class ListingDispatchTest(unittest.TestCase):
+    """_collect_submissions chooses the right PRAW listing method."""
+
+    def setUp(self):
+        self.p = RedditPlugin()
+
+    def _reddit(self):
+        sub = MagicMock()
+        sub.top.return_value = [SimpleNamespace(id="t1")]
+        sub.hot.return_value = [SimpleNamespace(id="h1")]
+        sub.new.return_value = [SimpleNamespace(id="n1")]
+        sub.rising.return_value = [SimpleNamespace(id="ri1")]
+        sub.controversial.return_value = [SimpleNamespace(id="cn1")]
+        sub.search.return_value = [SimpleNamespace(id="s1")]
+        reddit = MagicMock()
+        reddit.subreddit.return_value = sub
+        return reddit, sub
+
+    def test_subreddit_top(self):
+        reddit, sub = self._reddit()
+        out = list(self.p._collect_submissions(reddit, "subreddit", "python", "top", "month", 5))
+        sub.top.assert_called_once_with(time_filter="month", limit=5)
+        self.assertEqual(out[0].id, "t1")
+
+    def test_subreddit_new_does_not_pass_time_filter(self):
+        reddit, sub = self._reddit()
+        list(self.p._collect_submissions(reddit, "subreddit", "python", "new", "month", 5))
+        sub.new.assert_called_once_with(limit=5)
+
+    def test_query_uses_search_on_all(self):
+        reddit, sub = self._reddit()
+        list(self.p._collect_submissions(reddit, "query", "claude code", "top", "week", 5))
+        reddit.subreddit.assert_called_with("all")
+        sub.search.assert_called_once()
+
+    def test_post_url_uses_submission(self):
+        reddit, _ = self._reddit()
+        reddit.submission.return_value = SimpleNamespace(id="p1")
+        out = list(self.p._collect_submissions(
+            reddit, "post_url", "https://reddit.com/r/x/comments/abc/", "top", "month", 5
+        ))
+        reddit.submission.assert_called_once_with(url="https://reddit.com/r/x/comments/abc/")
+        self.assertEqual(out[0].id, "p1")
+
+    def test_user_rising_falls_back_to_new_with_log(self):
+        # User submissions have no .rising(), so 'rising' falls back to .new()
+        # without raising, but emits an INFO log so the user knows.
+ reddit_mock = MagicMock() + user_subs = MagicMock() + user_subs.new.return_value = [SimpleNamespace(id="un1")] + reddit_mock.redditor.return_value = SimpleNamespace(submissions=user_subs) + + with self.assertLogs("content_parser.plugins.reddit.plugin", level="INFO") as cm: + out = list(self.p._collect_submissions( + reddit_mock, "user", "spez", "rising", "month", 5 + )) + user_subs.new.assert_called_once_with(limit=5) + self.assertEqual(out[0].id, "un1") + self.assertTrue(any("rising" in m and "new" in m for m in cm.output)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 0000000..301c3e2 --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,54 @@ +"""Tests for content_parser.core.runner — partial-run safety.""" +from __future__ import annotations + +import shutil +import tempfile +import unittest +from pathlib import Path +from typing import Iterator + +from content_parser.core.plugin import InputSpec, SourcePlugin +from content_parser.core.runner import run +from content_parser.core.schema import Item + + +class _CrashAfterTwo(SourcePlugin): + name = "crash" + label = "Crash" + secret_keys = [] + + def input_specs(self): return [InputSpec(kind="x", label="X")] + def settings_specs(self): return [] + def resolve(self, *a, **k): return ["a", "b", "c"] + + def fetch(self, ids, *a, **k) -> Iterator[Item]: + yield Item(source="crash", item_id="a", url="u/a", title="A") + yield Item(source="crash", item_id="b", url="u/b", title="B") + raise RuntimeError("boom on third") + + +class RunnerTest(unittest.TestCase): + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_run_")) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_partial_run_writes_summary_and_index_then_reraises(self): + plugin = _CrashAfterTwo() + with self.assertRaises(RuntimeError) as cm: + run(plugin, {"x": ["1"]}, {}, {}, output_dir=self.tmp, log=lambda _: None) + self.assertIn("boom", str(cm.exception)) + + files = {p.name for p in self.tmp.iterdir()} + self.assertIn("summary.csv", files) + self.assertIn("index.md", files) + + summary = (self.tmp / "summary.csv").read_text(encoding="utf-8") + # Both successful items should be in summary + self.assertIn(",a,A,", summary) + self.assertIn(",b,B,", summary) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_safe_filename.py b/tests/test_safe_filename.py new file mode 100644 index 0000000..c1a4aab --- /dev/null +++ b/tests/test_safe_filename.py @@ -0,0 +1,186 @@ +"""Tests for output._safe_filename and _file_stem path-traversal safety.""" +from __future__ import annotations + +import shutil +import tempfile +import unittest +from pathlib import Path + +from content_parser.core.output import ( + _csv_safe, + _file_stem, + _safe_filename, + write_item_json, + write_item_markdown, + write_summary_csv, +) +from content_parser.core.schema import Item + + +class CsvInjectionTest(unittest.TestCase): + """Excel/Sheets execute cells starting with =/+/-/@ as formulas.""" + + def test_equals_prefix_neutralized(self): + self.assertEqual(_csv_safe("=cmd|'/c calc'!A1"), "'=cmd|'/c calc'!A1") + + def test_plus_prefix_neutralized(self): + self.assertEqual(_csv_safe("+1+1"), "'+1+1") + + def test_minus_prefix_neutralized(self): + self.assertEqual(_csv_safe("-2+3"), "'-2+3") + + def test_at_prefix_neutralized(self): + self.assertEqual(_csv_safe("@SUM(A1:A10)"), "'@SUM(A1:A10)") + + def test_tab_and_cr_prefixes_neutralized(self): + 
self.assertEqual(_csv_safe("\t=evil"), "'\t=evil") + self.assertEqual(_csv_safe("\r=evil"), "'\r=evil") + + def test_safe_string_unchanged(self): + self.assertEqual(_csv_safe("normal title"), "normal title") + self.assertEqual(_csv_safe("123 abc"), "123 abc") + + def test_none_passthrough(self): + self.assertIsNone(_csv_safe(None)) + + def test_non_string_passthrough(self): + self.assertEqual(_csv_safe(42), 42) + self.assertEqual(_csv_safe(True), True) + + def test_empty_string_unchanged(self): + self.assertEqual(_csv_safe(""), "") + + def test_summary_csv_escapes_malicious_title(self): + import csv as _csv + import shutil + import tempfile + tmp = Path(tempfile.mkdtemp(prefix="cp_csv_")) + try: + item = Item( + source="reddit", item_id="abc", url="https://x", + title="=cmd|'/c calc'!A1", + author="@evil", + ) + path = write_summary_csv([item], tmp) + with path.open(encoding="utf-8") as f: + reader = _csv.DictReader(f) + row = next(reader) + self.assertTrue(row["title"].startswith("'=")) + self.assertTrue(row["author"].startswith("'@")) + finally: + shutil.rmtree(tmp, ignore_errors=True) + + +class SafeFilenameTest(unittest.TestCase): + def test_plain_text(self): + self.assertEqual(_safe_filename("Some Title"), "Some_Title") + + def test_strips_path_separators(self): + # forward slash is NOT \w, \s, or hyphen → removed + self.assertEqual(_safe_filename("a/b/c"), "abc") + + def test_strips_dots_so_dotdot_cannot_escape(self): + # The whole point: '..' resolves to nothing, no traversal possible + self.assertEqual(_safe_filename("../../etc/passwd"), "etcpasswd") + + def test_strips_backslash(self): + self.assertEqual(_safe_filename("a\\b\\c"), "abc") + + def test_strips_null_byte(self): + self.assertEqual(_safe_filename("name\x00.txt"), "nametxt") + + def test_keeps_unicode_word_chars(self): + # Russian text falls under \w with re.UNICODE + self.assertIn("Привет", _safe_filename("Привет мир")) + + def test_truncates_to_max_length(self): + out = _safe_filename("x" * 200, max_length=10) + self.assertEqual(len(out), 10) + + def test_empty_falls_back(self): + self.assertEqual(_safe_filename(""), "item") + self.assertEqual(_safe_filename("///"), "item") + + +class FileStemCollisionTest(unittest.TestCase): + """When item_id sanitizes to the fallback, a hash disambiguates.""" + + def test_collision_when_ids_sanitize_to_same_fallback(self): + a = Item(source="reddit", item_id="../../a", url="u", title="") + b = Item(source="reddit", item_id="../../b", url="u", title="") + self.assertNotEqual(_file_stem(a), _file_stem(b)) + + def test_normal_id_unchanged_no_hash(self): + item = Item(source="youtube", item_id="dQw4w9WgXcQ", url="u", title="x") + stem = _file_stem(item) + # stem should not contain the 'item-' fallback hash prefix + self.assertNotIn("item-", stem) + self.assertIn("dQw4w9WgXcQ", stem) + + def test_fallback_id_gets_stable_hash(self): + item = Item(source="reddit", item_id="..", url="u", title="t") + stem1 = _file_stem(item) + stem2 = _file_stem(item) + self.assertEqual(stem1, stem2) + self.assertIn("item-", stem1) + + +class FileStemPathTraversalTest(unittest.TestCase): + """Even if an upstream API returns malicious source/item_id, files stay in out_dir.""" + + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_traverse_")) + + def tearDown(self): + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_stem_strips_traversal_in_item_id(self): + item = Item( + source="reddit", + item_id="../../etc/passwd", + url="https://example", + title="bad", + ) + stem = _file_stem(item) + 
self.assertNotIn("..", stem) + self.assertNotIn("/", stem) + self.assertNotIn("\\", stem) + + def test_stem_strips_traversal_in_source(self): + item = Item( + source="../../malicious", + item_id="abc", + url="https://example", + title="ok", + ) + stem = _file_stem(item) + self.assertNotIn("..", stem) + self.assertNotIn("/", stem) + + def test_write_item_json_stays_inside_out_dir(self): + item = Item( + source="reddit", + item_id="../../escape", + url="https://example", + title="x", + ) + path = write_item_json(item, self.tmp) + # Resolved path must still be a child of out_dir + self.assertTrue( + path.resolve().is_relative_to(self.tmp.resolve()), + f"Wrote to {path.resolve()} which escapes {self.tmp.resolve()}", + ) + + def test_write_item_markdown_stays_inside_out_dir(self): + item = Item( + source="../traverse", + item_id="../escape", + url="https://example", + title="..", + ) + path = write_item_markdown(item, self.tmp) + self.assertTrue(path.resolve().is_relative_to(self.tmp.resolve())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_secrets.py b/tests/test_secrets.py new file mode 100644 index 0000000..f31bfaf --- /dev/null +++ b/tests/test_secrets.py @@ -0,0 +1,85 @@ +"""Tests for content_parser.core.secrets — TOML upsert/escape/remove.""" +from __future__ import annotations + +import shutil +import tempfile +import tomllib +import unittest +from pathlib import Path + +from content_parser.core import secrets as s + + +class TomlUpsertTest(unittest.TestCase): + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_sec_")) + self._orig_secrets = s.SECRETS_PATH + s.SECRETS_PATH = self.tmp / ".streamlit" / "secrets.toml" + + def tearDown(self): + s.SECRETS_PATH = self._orig_secrets + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_escape_quote_and_backslash(self): + self.assertEqual(s._toml_escape('plain'), 'plain') + self.assertEqual(s._toml_escape('with"quote'), 'with\\"quote') + self.assertEqual(s._toml_escape('with\\backslash'), 'with\\\\backslash') + + def test_fresh_write_is_valid_toml(self): + s._upsert_secrets_toml("KEY", "value") + parsed = tomllib.loads(s.SECRETS_PATH.read_text(encoding="utf-8")) + self.assertEqual(parsed["KEY"], "value") + + def test_value_with_quotes_round_trips(self): + tricky = 'a"b\\c' + s._upsert_secrets_toml("TRICKY", tricky) + parsed = tomllib.loads(s.SECRETS_PATH.read_text(encoding="utf-8")) + self.assertEqual(parsed["TRICKY"], tricky) + + def test_replace_preserves_other_keys(self): + s.SECRETS_PATH.parent.mkdir(parents=True, exist_ok=True) + s.SECRETS_PATH.write_text( + 'OTHER = "keep"\nKEY = "old"\nMORE = "also"\n', encoding="utf-8" + ) + s._upsert_secrets_toml("KEY", "new") + text = s.SECRETS_PATH.read_text(encoding="utf-8") + self.assertIn("OTHER", text) + self.assertIn("MORE", text) + self.assertIn('KEY = "new"', text) + self.assertNotIn('"old"', text) + + def test_remove_preserves_others(self): + s.SECRETS_PATH.parent.mkdir(parents=True, exist_ok=True) + s.SECRETS_PATH.write_text( + 'OTHER = "keep"\nKEY = "x"\n', encoding="utf-8" + ) + s._remove_from_secrets_toml("KEY") + text = s.SECRETS_PATH.read_text(encoding="utf-8") + self.assertNotIn("KEY", text) + self.assertIn("OTHER", text) + + def test_remove_sole_key_deletes_file(self): + s.SECRETS_PATH.parent.mkdir(parents=True, exist_ok=True) + s.SECRETS_PATH.write_text('KEY = "x"\n', encoding="utf-8") + s._remove_from_secrets_toml("KEY") + self.assertFalse(s.SECRETS_PATH.exists()) + + def test_streamlit_cloud_skip_writes(self): + from unittest.mock import 
patch + # On Streamlit Cloud, secrets.toml is read-only; we must NOT touch it. + with patch.dict("os.environ", {"STREAMLIT_RUNTIME": "cloud"}): + self.assertTrue(s._is_streamlit_cloud()) + s._upsert_secrets_toml("KEY", "value") + # File should not have been created + self.assertFalse(s.SECRETS_PATH.exists()) + + def test_local_environment_writes_normally(self): + from unittest.mock import patch + with patch.dict("os.environ", {}, clear=True): + self.assertFalse(s._is_streamlit_cloud()) + s._upsert_secrets_toml("KEY", "value") + self.assertTrue(s.SECRETS_PATH.exists()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_telegram_adapter.py b/tests/test_telegram_adapter.py new file mode 100644 index 0000000..b7f5ce3 --- /dev/null +++ b/tests/test_telegram_adapter.py @@ -0,0 +1,256 @@ +"""Tests for content_parser.plugins.telegram.adapter — defensive field mapping.""" +from __future__ import annotations + +import unittest + +from content_parser.plugins.telegram.adapter import ( + _pick, + _reactions_total, + _replies_count, + _to_int, + message_to_item, +) + + +class PickTest(unittest.TestCase): + def test_first_present_key(self): + self.assertEqual(_pick({"a": 1, "b": 2}, "a", "b"), 1) + self.assertEqual(_pick({"b": 2}, "a", "b"), 2) + + def test_skips_none_and_empty(self): + self.assertEqual(_pick({"a": None, "b": "", "c": 5}, "a", "b", "c"), 5) + + def test_default(self): + self.assertEqual(_pick({}, "a", default=42), 42) + + +class ToIntTest(unittest.TestCase): + def test_int_passthrough(self): + self.assertEqual(_to_int(42), 42) + + def test_float_truncates(self): + self.assertEqual(_to_int(3.7), 3) + + def test_string_int(self): + self.assertEqual(_to_int("100"), 100) + + def test_none(self): + self.assertIsNone(_to_int(None)) + + def test_dict_returns_none(self): + # Critical: a stray dict like {"count": 100} from a different actor schema + # must not leak into media as a number. 
+ self.assertIsNone(_to_int({"count": 100})) + + def test_bool_returns_none(self): + # Avoid True being silently treated as 1 + self.assertIsNone(_to_int(True)) + + def test_garbage_string(self): + self.assertIsNone(_to_int("not a number")) + + +class RepliesCountTest(unittest.TestCase): + def test_int_replies(self): + # Critical: many actors return 'replies: 42' (just a number) + self.assertEqual(_replies_count({"replies": 42}), 42) + + def test_list_replies_uses_length(self): + self.assertEqual(_replies_count({"replies": [{"id": 1}, {"id": 2}]}), 2) + + def test_alt_keys(self): + self.assertEqual(_replies_count({"repliesCount": 7}), 7) + self.assertEqual(_replies_count({"replies_count": 8}), 8) + self.assertEqual(_replies_count({"commentsCount": 9}), 9) + + def test_missing(self): + self.assertIsNone(_replies_count({})) + + +class ReactionsTotalTest(unittest.TestCase): + def test_list_of_dicts(self): + self.assertEqual( + _reactions_total([{"emoji": "👍", "count": 5}, {"emoji": "❤️", "count": 3}]), + 8, + ) + + def test_dict_emoji_to_count(self): + self.assertEqual(_reactions_total({"👍": 5, "❤️": 3}), 8) + + def test_int_passthrough(self): + self.assertEqual(_reactions_total(42), 42) + + def test_none(self): + self.assertIsNone(_reactions_total(None)) + + def test_empty_list(self): + self.assertIsNone(_reactions_total([])) + + +SAMPLE_MESSAGE = { + "id": 123, + "channelUsername": "durov", + "channelTitle": "Pavel Durov", + "channelId": "-1001", + "date": "2024-01-15T10:00:00.000Z", + "text": "Big news today\nDetails inside", + "url": "https://t.me/durov/123", + "views": 500_000, + "forwards": 1234, + "replies": 89, + "reactions": [ + {"emoji": "👍", "count": 1000}, + {"emoji": "❤️", "count": 500}, + ], + "media": {"type": "photo", "url": "https://cdn.example/p.jpg"}, + "isPinned": True, +} + + +class MessageToItemTest(unittest.TestCase): + def test_basic_fields(self): + item = message_to_item(SAMPLE_MESSAGE) + self.assertEqual(item.source, "telegram") + self.assertEqual(item.item_id, "durov_123") + self.assertEqual(item.url, "https://t.me/durov/123") + self.assertEqual(item.title, "Big news today") + self.assertEqual(item.author, "Pavel Durov") + self.assertEqual(item.author_id, "durov") + self.assertIn("Big news today", item.text) + + def test_metrics(self): + item = message_to_item(SAMPLE_MESSAGE) + self.assertEqual(item.media["views_count"], 500_000) + self.assertEqual(item.media["forwards_count"], 1234) + self.assertEqual(item.media["reactions_count"], 1500) + self.assertEqual(item.media["media_type"], "photo") + self.assertTrue(item.media["is_pinned"]) + + def test_zero_views_kept(self): + # 0 is a meaningful signal (just-published or banned), not a missing value + msg = dict(SAMPLE_MESSAGE) + msg["views"] = 0 + item = message_to_item(msg) + self.assertEqual(item.media["views_count"], 0) + + def test_falls_back_to_alternative_field_names(self): + msg = { + "messageId": 99, + "chatUsername": "somechan", + "chatTitle": "Some Channel", + "timestamp": "2024-02-01T00:00:00Z", + "message": "alt format text", + "view_count": 100, + "forward_count": 5, + } + item = message_to_item(msg) + self.assertEqual(item.item_id, "somechan_99") + self.assertEqual(item.author, "Some Channel") + self.assertEqual(item.author_id, "somechan") + self.assertEqual(item.text, "alt format text") + self.assertEqual(item.media["views_count"], 100) + self.assertEqual(item.media["forwards_count"], 5) + + def test_url_constructed_when_missing(self): + msg = dict(SAMPLE_MESSAGE) + del msg["url"] + item = 
message_to_item(msg) + self.assertEqual(item.url, "https://t.me/durov/123") + + def test_raises_on_missing_id(self): + with self.assertRaises(ValueError): + message_to_item({"channelUsername": "durov", "text": "x"}) + + def test_no_username_falls_back_to_id_only(self): + msg = {"id": 123, "text": "x"} + item = message_to_item(msg) + self.assertEqual(item.item_id, "123") + + def test_extracts_inline_comments(self): + msg = dict(SAMPLE_MESSAGE) + msg["replies_data"] = [ + { + "id": 1, + "from": {"username": "fan1", "first_name": "Fan", "last_name": "One"}, + "text": "great post", + "date": "2024-01-15T11:00:00Z", + "reactions": [{"emoji": "👍", "count": 5}], + }, + { + "id": 2, + "from": {"username": "fan2"}, + "text": "agreed", + "date": "2024-01-15T11:30:00Z", + }, + ] + item = message_to_item(msg) + self.assertEqual(len(item.comments), 2) + c1, c2 = item.comments + self.assertEqual(c1.text, "great post") + self.assertEqual(c1.author, "Fan One") + self.assertEqual(c1.author_id, "fan1") + self.assertEqual(c1.like_count, 5) + self.assertEqual(c2.author, "fan2") # name fallback to username + + def test_no_comments_when_replies_is_int(self): + # Some actors return 'replies' as an int count, not a list. + msg = dict(SAMPLE_MESSAGE) + msg["replies_data"] = 42 + item = message_to_item(msg) + self.assertEqual(item.comments, []) + + def test_replies_int_populates_comments_count(self): + # 'replies: 42' (no comment list) → comments_count=42, comments=[] + msg = dict(SAMPLE_MESSAGE) + msg["replies"] = 42 + # remove the explicit replies_data so we hit the int path + msg.pop("replies_data", None) + item = message_to_item(msg) + self.assertEqual(item.media["comments_count"], 42) + self.assertEqual(item.comments, []) + + def test_replies_as_list_populates_both(self): + msg = dict(SAMPLE_MESSAGE) + msg["replies"] = [ + {"id": 1, "from": {"username": "u1"}, "text": "x"}, + {"id": 2, "from": {"username": "u2"}, "text": "y"}, + ] + msg.pop("replies_data", None) + item = message_to_item(msg) + self.assertEqual(item.media["comments_count"], 2) + self.assertEqual(len(item.comments), 2) + + def test_dict_views_does_not_leak(self): + # Defensive: if an actor returns views as a dict (unusual), do NOT + # silently store it in media — coerce to None. + msg = dict(SAMPLE_MESSAGE) + msg["views"] = {"count": 999} + item = message_to_item(msg) + self.assertNotIn("views_count", item.media) + + def test_reply_tree_populates_parent_id(self): + msg = dict(SAMPLE_MESSAGE) + msg["replies_data"] = [ + {"id": 1, "from": {"username": "u1"}, "text": "top"}, + {"id": 2, "from": {"username": "u2"}, "text": "reply", "reply_to_message_id": 1}, + {"id": 3, "from": {"username": "u3"}, "text": "reply2", "replyToMessageId": 1}, + ] + item = message_to_item(msg) + ids_to_parent = {c.comment_id: c.parent_id for c in item.comments} + self.assertIsNone(ids_to_parent["1"]) + self.assertEqual(ids_to_parent["2"], "1") + self.assertEqual(ids_to_parent["3"], "1") + + def test_reply_to_outside_batch_stays_top_level(self): + # If reply_to_message_id points to a message NOT in this batch, + # treat as top-level (we have no context to chain it to). 
+ msg = dict(SAMPLE_MESSAGE) + msg["replies_data"] = [ + {"id": 5, "from": {"username": "u"}, "text": "x", "reply_to_message_id": 99999}, + ] + item = message_to_item(msg) + self.assertIsNone(item.comments[0].parent_id) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_telegram_plugin.py b/tests/test_telegram_plugin.py new file mode 100644 index 0000000..690a49f --- /dev/null +++ b/tests/test_telegram_plugin.py @@ -0,0 +1,369 @@ +"""Tests for content_parser.plugins.telegram.plugin — input validation + dispatch.""" +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import AuthError, PluginError +from content_parser.core.redact import redact_spec as _redact_spec +from content_parser.plugins.telegram.plugin import ( + TelegramPlugin, + _is_tg_host, +) + + +class IsTgHostTest(unittest.TestCase): + def test_main_domains(self): + self.assertTrue(_is_tg_host("t.me")) + self.assertTrue(_is_tg_host("telegram.me")) + self.assertTrue(_is_tg_host("T.ME")) + + def test_lookalike_rejected(self): + self.assertFalse(_is_tg_host("evilt.me")) + self.assertFalse(_is_tg_host("t.me.evil.example")) + + def test_empty(self): + self.assertFalse(_is_tg_host("")) + + +class NormalizeChannelTest(unittest.TestCase): + def setUp(self): + self.p = TelegramPlugin() + + def test_plain_username(self): + self.assertEqual(self.p._normalize_channel("durov"), "durov") + + def test_at_handle(self): + self.assertEqual(self.p._normalize_channel("@telegram"), "telegram") + + def test_url(self): + self.assertEqual( + self.p._normalize_channel("https://t.me/durov"), + "durov", + ) + + def test_telegram_me_alias(self): + self.assertEqual( + self.p._normalize_channel("https://telegram.me/durov"), + "durov", + ) + + def test_rejects_non_tg_host(self): + with self.assertRaises(PluginError): + self.p._normalize_channel("https://evil.example/durov") + + def test_rejects_reserved_path(self): + with self.assertRaises(PluginError): + self.p._normalize_channel("https://t.me/joinchat") + with self.assertRaises(PluginError): + self.p._normalize_channel("https://t.me/proxy") + + def test_rejects_too_short(self): + with self.assertRaises(PluginError): + self.p._normalize_channel("dur") # Telegram usernames are 5+ chars + + def test_rejects_starting_with_digit(self): + with self.assertRaises(PluginError): + self.p._normalize_channel("123channel") + + def test_rejects_invalid_chars(self): + with self.assertRaises(PluginError): + self.p._normalize_channel("name with spaces") + + def test_rejects_empty(self): + with self.assertRaises(PluginError): + self.p._normalize_channel(" ") + + +class ExtractPostUrlTest(unittest.TestCase): + def setUp(self): + self.p = TelegramPlugin() + + def test_canonical_post_url(self): + self.assertEqual( + self.p._extract_post_url("https://t.me/durov/123"), + "https://t.me/durov/123", + ) + + def test_telegram_me_normalized(self): + self.assertEqual( + self.p._extract_post_url("https://telegram.me/durov/123"), + "https://t.me/durov/123", + ) + + def test_with_query_string(self): + # The 's' subdomain or query params should be tolerated; + # we extract just channel + msg id. 
+ self.assertEqual( + self.p._extract_post_url("https://t.me/durov/123?single"), + "https://t.me/durov/123", + ) + + def test_non_tg_host_rejected(self): + self.assertIsNone( + self.p._extract_post_url("https://evil.example/durov/123") + ) + + def test_channel_listing_rejected(self): + # No message id → not a post + self.assertIsNone(self.p._extract_post_url("https://t.me/durov")) + + def test_private_c_path_rejected(self): + # /c/<chat_id>/<msg_id> is private channel — Apify scrapers usually can't read it + self.assertIsNone(self.p._extract_post_url("https://t.me/c/123/456")) + + def test_reserved_path_rejected(self): + self.assertIsNone(self.p._extract_post_url("https://t.me/joinchat/abc")) + + +class RedactSpecTest(unittest.TestCase): + def test_strips_query(self): + out = _redact_spec("post:https://t.me/durov/123?token=secret") + self.assertNotIn("secret", out) + + def test_strips_fragment(self): + out = _redact_spec("post:https://t.me/durov/123#access_token=xxx") + self.assertNotIn("access_token", out) + + def test_truncates(self): + self.assertLessEqual(len(_redact_spec("channel:" + "x" * 200)), 80) + + +class ResolveTest(unittest.TestCase): + def setUp(self): + self.p = TelegramPlugin() + self.secrets = {"APIFY_API_TOKEN": "x"} + + def test_specs_carry_kind_prefix(self): + specs = self.p.resolve( + { + "channel": ["durov", "@telegram"], + "post_url": ["https://t.me/durov/123"], + }, + {}, + self.secrets, + ) + kinds = [s.split(":", 1)[0] for s in specs] + self.assertEqual(kinds.count("channel"), 2) + self.assertEqual(kinds.count("post"), 1) + + def test_dedupe(self): + specs = self.p.resolve( + {"channel": ["durov", "@durov", "https://t.me/durov"]}, + {}, + self.secrets, + ) + self.assertEqual(len(specs), 1) + + def test_rejects_listing_url_in_post_field(self): + with self.assertRaises(PluginError): + self.p.resolve( + {"post_url": ["https://t.me/durov"]}, + {}, + self.secrets, + ) + + +class FetchAuthGuardTest(unittest.TestCase): + def test_missing_token_raises(self): + p = TelegramPlugin() + with self.assertRaises(AuthError): + list(p.fetch(["channel:durov"], {}, {})) + + +class ActorIdValidationTest(unittest.TestCase): + def setUp(self): + self.p = TelegramPlugin() + + def _patch_client(self): + return patch("content_parser.plugins.telegram.plugin.ApifyClient") + + def test_default_used_when_empty(self): + with self._patch_client() as MC: + MC.return_value.run_actor.return_value = [] + list(self.p.fetch( + ["channel:durov"], {"actor_id": ""}, + {"APIFY_API_TOKEN": "x"}, + )) + actor_id, _ = MC.return_value.run_actor.call_args[0] + self.assertEqual(actor_id, "apify/telegram-channel-scraper") + + def test_default_used_when_whitespace(self): + with self._patch_client() as MC: + MC.return_value.run_actor.return_value = [] + list(self.p.fetch( + ["channel:durov"], {"actor_id": " "}, + {"APIFY_API_TOKEN": "x"}, + )) + actor_id, _ = MC.return_value.run_actor.call_args[0] + self.assertEqual(actor_id, "apify/telegram-channel-scraper") + + def test_valid_username_actor_form(self): + with self._patch_client() as MC: + MC.return_value.run_actor.return_value = [] + list(self.p.fetch( + ["channel:durov"], {"actor_id": "73code/telegram-scraper"}, + {"APIFY_API_TOKEN": "x"}, + )) + actor_id, _ = MC.return_value.run_actor.call_args[0] + self.assertEqual(actor_id, "73code/telegram-scraper") + + def test_valid_tilde_form(self): + with self._patch_client() as MC: + MC.return_value.run_actor.return_value = [] + list(self.p.fetch( + ["channel:durov"], {"actor_id": "user~actor"}, + {"APIFY_API_TOKEN": 
"x"}, + )) + self.assertEqual(MC.return_value.run_actor.call_args[0][0], "user~actor") + + def test_garbage_actor_id_raises(self): + with self._patch_client() as MC: + MC.return_value.run_actor.return_value = [] + for bad in ("noslash", "/missing", "missing/", "has spaces/x", "../../etc"): + with self.subTest(bad=bad): + with self.assertRaises(PluginError): + list(self.p.fetch( + ["channel:durov"], {"actor_id": bad}, + {"APIFY_API_TOKEN": "x"}, + )) + + +class ApifyErrorMappingTest(unittest.TestCase): + """ApifyError from the underlying client is wrapped in PluginError.""" + + def setUp(self): + self.p = TelegramPlugin() + + def test_channels_call_failure(self): + from content_parser.clients.apify import ApifyError + with patch("content_parser.plugins.telegram.plugin.ApifyClient") as MC: + MC.return_value.run_actor.side_effect = ApifyError("simulated failure") + with self.assertRaises(PluginError) as cm: + list(self.p.fetch( + ["channel:durov"], {}, {"APIFY_API_TOKEN": "x"}, + )) + self.assertIn("channels", str(cm.exception)) + self.assertIn("simulated failure", str(cm.exception)) + + def test_posts_call_failure(self): + from content_parser.clients.apify import ApifyError + with patch("content_parser.plugins.telegram.plugin.ApifyClient") as MC: + MC.return_value.run_actor.side_effect = ApifyError("posts went bad") + with self.assertRaises(PluginError) as cm: + list(self.p.fetch( + ["post:https://t.me/durov/1"], {}, + {"APIFY_API_TOKEN": "x"}, + )) + self.assertIn("posts", str(cm.exception)) + + +class PrivateChannelRejectTest(unittest.TestCase): + def setUp(self): + self.p = TelegramPlugin() + + def test_resolve_explicit_error_for_private_url(self): + with self.assertRaises(PluginError) as cm: + self.p.resolve( + {"post_url": ["https://t.me/c/123/456"]}, + {}, + {"APIFY_API_TOKEN": "x"}, + ) + msg = str(cm.exception) + self.assertIn("private", msg.lower()) + # Must mention /c/ pattern so user understands what's wrong + self.assertTrue("/c/" in msg or "private-channel" in msg.lower()) + + def test_is_private_channel_url_helper(self): + self.assertTrue(self.p._is_private_channel_url("https://t.me/c/123/456")) + self.assertFalse(self.p._is_private_channel_url("https://t.me/durov/123")) + self.assertFalse(self.p._is_private_channel_url("https://evil.example/c/1/2")) + + +class FetchDispatchTest(unittest.TestCase): + def setUp(self): + self.p = TelegramPlugin() + + def _patch_client(self): + return patch("content_parser.plugins.telegram.plugin.ApifyClient") + + def test_channels_only_makes_one_actor_call(self): + with self._patch_client() as MC: + inst = MC.return_value + inst.run_actor.return_value = [ + {"id": 1, "channelUsername": "durov", "text": "x", "url": "https://t.me/durov/1"}, + ] + list(self.p.fetch( + ["channel:durov", "channel:telegram"], + {"max_messages_per_channel": 5}, + {"APIFY_API_TOKEN": "x"}, + )) + self.assertEqual(inst.run_actor.call_count, 1) + actor_id, actor_input = inst.run_actor.call_args[0] + self.assertIn("https://t.me/durov", actor_input["urls"]) + self.assertIn("https://t.me/telegram", actor_input["urls"]) + + def test_post_urls_make_separate_actor_call(self): + with self._patch_client() as MC: + inst = MC.return_value + inst.run_actor.return_value = [] + list(self.p.fetch( + ["channel:durov", "post:https://t.me/durov/123"], + {"max_messages_per_channel": 5}, + {"APIFY_API_TOKEN": "x"}, + )) + self.assertEqual(inst.run_actor.call_count, 2) + + def test_actor_id_overridable(self): + with self._patch_client() as MC: + inst = MC.return_value + 
inst.run_actor.return_value = [] + list(self.p.fetch( + ["channel:durov"], + {"actor_id": "73code/telegram-scraper"}, + {"APIFY_API_TOKEN": "x"}, + )) + actor_id, _ = inst.run_actor.call_args[0] + self.assertEqual(actor_id, "73code/telegram-scraper") + + def test_dedupes_messages_by_item_id(self): + same_msg = { + "id": 1, "channelUsername": "durov", "text": "x", + "url": "https://t.me/durov/1", + } + with self._patch_client() as MC: + inst = MC.return_value + # Same message returned in both channel and post calls + inst.run_actor.side_effect = [[same_msg], [same_msg]] + items = list(self.p.fetch( + ["channel:durov", "post:https://t.me/durov/1"], + {"max_messages_per_channel": 5}, + {"APIFY_API_TOKEN": "x"}, + )) + self.assertEqual(len(items), 1) + + def test_caps_comment_count(self): + msg = { + "id": 1, + "channelUsername": "durov", + "text": "x", + "url": "https://t.me/durov/1", + "replies_data": [ + {"id": i, "from": {"username": f"u{i}"}, "text": f"c{i}"} + for i in range(50) + ], + } + with self._patch_client() as MC: + inst = MC.return_value + inst.run_actor.return_value = [msg] + items = list(self.p.fetch( + ["channel:durov"], + {"max_messages_per_channel": 5, "max_comments_per_post": 10}, + {"APIFY_API_TOKEN": "x"}, + )) + self.assertEqual(len(items), 1) + self.assertEqual(len(items[0].comments), 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transcription_cache.py b/tests/test_transcription_cache.py new file mode 100644 index 0000000..d14baea --- /dev/null +++ b/tests/test_transcription_cache.py @@ -0,0 +1,64 @@ +"""Tests for content_parser.transcription.cache — disk cache CRUD.""" +from __future__ import annotations + +import shutil +import tempfile +import unittest +from pathlib import Path + +from content_parser.transcription import cache as cache_mod + + +class CacheTest(unittest.TestCase): + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_cache_")) + self._orig = cache_mod.CACHE_DIR + cache_mod.CACHE_DIR = self.tmp + + def tearDown(self): + cache_mod.CACHE_DIR = self._orig + shutil.rmtree(self.tmp, ignore_errors=True) + + def test_get_missing_returns_none(self): + self.assertIsNone(cache_mod.get("instagram", "AAA")) + + def test_put_then_get_round_trips(self): + data = { + "language": "ru", "is_generated": True, + "segments": [{"start": 0.0, "duration": 2.0, "text": "hi"}], + "text": "hi", "error": None, + } + cache_mod.put("instagram", "AAA", data) + loaded = cache_mod.get("instagram", "AAA") + self.assertEqual(loaded, data) + + def test_safe_filename_for_unsafe_id(self): + data = {"language": "ru", "is_generated": True, "segments": [], "text": "x", "error": None} + cache_mod.put("instagram", "../../etc/passwd", data) + files = list(self.tmp.glob("*.json")) + self.assertEqual(len(files), 1) + self.assertNotIn("..", files[0].name) + self.assertNotIn("/", files[0].name) + + def test_clear_removes_all(self): + cache_mod.put("instagram", "a", {"text": "1"}) + cache_mod.put("vk", "b", {"text": "2"}) + n = cache_mod.clear() + self.assertEqual(n, 2) + self.assertIsNone(cache_mod.get("instagram", "a")) + + def test_atomic_write_no_tmp_left_after_success(self): + cache_mod.put("instagram", "x", {"text": "ok"}) + # No .tmp sibling should remain + tmps = list(self.tmp.glob("*.tmp")) + self.assertEqual(tmps, []) + + def test_existing_value_replaced_atomically(self): + cache_mod.put("instagram", "x", {"text": "first"}) + cache_mod.put("instagram", "x", {"text": "second"}) + loaded = cache_mod.get("instagram", "x") + 
self.assertEqual(loaded["text"], "second") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transcription_runner.py b/tests/test_transcription_runner.py new file mode 100644 index 0000000..80fb1a1 --- /dev/null +++ b/tests/test_transcription_runner.py @@ -0,0 +1,303 @@ +"""Tests for content_parser.transcription.runner — the maybe_transcribe orchestrator.""" +from __future__ import annotations + +import shutil +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch + +from content_parser.core.schema import Item, Transcript +from content_parser.transcription import cache as cache_mod +from content_parser.transcription import runner as runner_mod +from content_parser.transcription.runner import _is_public_url + + +class IsPublicUrlTest(unittest.TestCase): + def test_normal_https_url(self): + self.assertTrue(_is_public_url("https://www.instagram.com/p/AAA/")) + self.assertTrue(_is_public_url("https://t.me/durov/123")) + self.assertTrue(_is_public_url("https://cdn.example.com/v.mp4")) + + def test_http_also_ok(self): + self.assertTrue(_is_public_url("http://example.com/v.mp4")) + + def test_other_schemes_rejected(self): + self.assertFalse(_is_public_url("file:///etc/passwd")) + self.assertFalse(_is_public_url("ftp://server/x")) + self.assertFalse(_is_public_url("data:text/plain,hello")) + + def test_localhost_rejected(self): + self.assertFalse(_is_public_url("http://localhost/x")) + self.assertFalse(_is_public_url("http://LOCALHOST/x")) + self.assertFalse(_is_public_url("http://127.0.0.1/x")) + self.assertFalse(_is_public_url("http://0.0.0.0/x")) + + def test_private_rfc1918_rejected(self): + self.assertFalse(_is_public_url("http://10.0.0.1/x")) + self.assertFalse(_is_public_url("http://10.255.255.255/x")) + self.assertFalse(_is_public_url("http://172.16.0.1/x")) + self.assertFalse(_is_public_url("http://192.168.1.1/x")) + + def test_link_local_rejected(self): + # AWS metadata endpoint + self.assertFalse(_is_public_url("http://169.254.169.254/latest/meta-data/")) + + def test_ipv6_loopback_rejected(self): + self.assertFalse(_is_public_url("http://[::1]/x")) + + def test_ipv6_private_rejected(self): + # fc00::/7 is unique-local addresses + self.assertFalse(_is_public_url("http://[fc00::1]/x")) + + def test_empty_or_invalid(self): + self.assertFalse(_is_public_url("")) + self.assertFalse(_is_public_url("not a url")) + self.assertFalse(_is_public_url(None)) # type: ignore[arg-type] + + def test_dns_name_passes(self): + # We can't catch DNS-rebinding without resolution; that's by design. + # Any non-IP-literal hostname is accepted at this layer. 
+ self.assertTrue(_is_public_url("http://attacker.example/x")) + + +WHISPER_RESPONSE = { + "text": "Hello world", + "language": "en", + "segments": [ + {"id": 0, "start": 0.0, "end": 2.5, "text": "Hello world"}, + {"id": 1, "start": 2.5, "end": 5.0, "text": "more text"}, + ], +} + + +class MaybeTranscribeTest(unittest.TestCase): + def setUp(self): + self.tmp = Path(tempfile.mkdtemp(prefix="cp_run_")) + self._orig = cache_mod.CACHE_DIR + cache_mod.CACHE_DIR = self.tmp / "cache" + + def tearDown(self): + cache_mod.CACHE_DIR = self._orig + shutil.rmtree(self.tmp, ignore_errors=True) + + def _item(self, **kw): + defaults = dict( + source="instagram", item_id="AAA", + url="https://www.instagram.com/p/AAA/", + media={"video_url": "https://cdn.example/v.mp4"}, + ) + defaults.update(kw) + return Item(**defaults) + + def test_disabled_setting_does_nothing(self): + item = self._item() + runner_mod.maybe_transcribe(item, settings={"transcribe_videos": False}, secrets={}) + self.assertIsNone(item.transcript) + + def test_no_api_key_sets_error_transcript(self): + item = self._item() + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": ""}, + ) + self.assertIsNotNone(item.transcript) + self.assertIn("OPENAI_API_KEY", item.transcript.error) + self.assertEqual(item.transcript.segments, []) + + def test_uses_cache_skipping_download_and_api(self): + cache_mod.put("instagram", "AAA", { + "language": "ru", "is_generated": True, + "segments": [{"start": 0.0, "duration": 1.0, "text": "cached"}], + "text": "cached", "error": None, + }) + item = self._item() + with patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio") as ta: + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + ) + dl.assert_not_called() + ta.assert_not_called() + self.assertEqual(item.transcript.text, "cached") + + def test_full_pipeline_downloads_transcribes_caches(self): + item = self._item() + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=30.0), \ + patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio", return_value=WHISPER_RESPONSE): + audio_path = self.tmp / "fake.mp3" + audio_path.write_bytes(b"\x00") + dl.return_value = audio_path + + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True, "max_audio_seconds_per_video": 600}, + secrets={"OPENAI_API_KEY": "k"}, + ) + + self.assertIsNotNone(item.transcript) + self.assertEqual(item.transcript.text, "Hello world") + self.assertEqual(item.transcript.language, "en") + self.assertEqual(len(item.transcript.segments), 2) + self.assertEqual(item.transcript.segments[0]["start"], 0.0) + self.assertEqual(item.transcript.segments[0]["duration"], 2.5) + # Cache populated + cached = cache_mod.get("instagram", "AAA") + self.assertIsNotNone(cached) + self.assertEqual(cached["text"], "Hello world") + + def test_duration_cap_skips_download(self): + item = self._item() + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=1200.0), \ + patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio") as ta: + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True, "max_audio_seconds_per_video": 600}, + secrets={"OPENAI_API_KEY": "k"}, + ) + dl.assert_not_called() + 
ta.assert_not_called() + self.assertIn("too long", item.transcript.error.lower()) + + def test_unknown_duration_skips_download(self): + # Critical: when duration cannot be probed, we MUST refuse to download + # because we can't bound the Whisper bill. Earlier this case fell + # through and incurred unbudgeted cost. + item = self._item() + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=None), \ + patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio") as ta: + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True, "max_audio_seconds_per_video": 600}, + secrets={"OPENAI_API_KEY": "k"}, + ) + dl.assert_not_called() + ta.assert_not_called() + self.assertIn("duration unknown", item.transcript.error.lower()) + + def test_download_failure_recorded_in_error(self): + from content_parser.transcription.downloader import DownloadError + item = self._item() + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=30.0), \ + patch("content_parser.transcription.runner.download_audio", side_effect=DownloadError("nope")), \ + patch("content_parser.transcription.runner.transcribe_audio") as ta: + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + ) + ta.assert_not_called() + self.assertIn("download", item.transcript.error.lower()) + + def test_whisper_failure_recorded_in_error(self): + from content_parser.transcription.whisper_api import WhisperError + item = self._item() + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=30.0), \ + patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio", side_effect=WhisperError("api died")): + audio_path = self.tmp / "fake.mp3" + audio_path.write_bytes(b"\x00") + dl.return_value = audio_path + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + ) + self.assertIn("whisper", item.transcript.error.lower()) + + def test_only_if_missing_skips_when_transcript_present(self): + item = self._item() + item.transcript = Transcript( + language="ru", is_generated=False, + segments=[{"start": 0.0, "duration": 1.0, "text": "existing"}], + text="existing", error=None, + ) + with patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio") as ta: + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + only_if_missing=True, + ) + dl.assert_not_called() + ta.assert_not_called() + self.assertEqual(item.transcript.text, "existing") # untouched + + def test_only_if_missing_runs_when_segments_empty(self): + item = self._item() + item.transcript = Transcript( + language=None, is_generated=None, segments=[], text="", + error="blocked", # earlier youtube-transcript-api fail + ) + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=30.0), \ + patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio", return_value=WHISPER_RESPONSE): + audio_path = self.tmp / "fake.mp3" + audio_path.write_bytes(b"\x00") + dl.return_value = audio_path + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + only_if_missing=True, + ) + 
# Was filled by Whisper now + self.assertEqual(item.transcript.text, "Hello world") + self.assertEqual(len(item.transcript.segments), 2) + + def test_private_url_rejected_before_download(self): + # SSRF guard: even if duration probe would succeed, refuse private IPs. + item = self._item(media={"video_url": "http://169.254.169.254/latest/meta-data/"}) + item.url = "" # force fallback to media.video_url + with patch("content_parser.transcription.runner.get_duration_seconds") as gd, \ + patch("content_parser.transcription.runner.download_audio") as dl: + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + ) + gd.assert_not_called() + dl.assert_not_called() + self.assertIn("non-public", item.transcript.error.lower()) + + def test_no_video_url_silent_skip(self): + item = self._item(media={}, url="") + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + ) + self.assertIsNone(item.transcript) + + def test_prefers_canonical_url_over_cdn(self): + # Instagram CDN URL in media.video_url, but post URL in item.url. + # We should prefer the post URL — yt-dlp can re-fetch fresh CDN links. + item = self._item( + url="https://www.instagram.com/reel/AAA/", + media={"video_url": "https://expired-cdn.example/x.mp4?token=old"}, + ) + with patch("content_parser.transcription.runner.get_duration_seconds", return_value=30.0), \ + patch("content_parser.transcription.runner.download_audio") as dl, \ + patch("content_parser.transcription.runner.transcribe_audio", return_value=WHISPER_RESPONSE): + audio_path = self.tmp / "fake.mp3" + audio_path.write_bytes(b"\x00") + dl.return_value = audio_path + runner_mod.maybe_transcribe( + item, + settings={"transcribe_videos": True}, + secrets={"OPENAI_API_KEY": "k"}, + ) + url_passed = dl.call_args.args[0] + self.assertIn("instagram.com", url_passed) + self.assertNotIn("expired-cdn", url_passed) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_transcription_whisper.py b/tests/test_transcription_whisper.py new file mode 100644 index 0000000..f98835b --- /dev/null +++ b/tests/test_transcription_whisper.py @@ -0,0 +1,165 @@ +"""Tests for content_parser.transcription.whisper_api — HTTP client behavior.""" +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from content_parser.transcription.whisper_api import WhisperError, transcribe_audio + + +def _mock_response(payload, *, ok=True, status=200): + m = MagicMock() + m.ok = ok + m.status_code = status + m.json.return_value = payload + m.text = str(payload) + return m + + +class TranscribeAudioTest(unittest.TestCase): + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.audio = Path(self.tmpdir.name) / "x.mp3" + self.audio.write_bytes(b"\x00" * 100) # any bytes — we mock the HTTP call + + def tearDown(self): + self.tmpdir.cleanup() + + def test_missing_api_key_raises(self): + with self.assertRaises(WhisperError): + transcribe_audio(self.audio, "") + + def test_missing_file_raises(self): + with self.assertRaises(WhisperError): + transcribe_audio(Path("/nonexistent.mp3"), "k") + + def test_uses_bearer_header(self): + payload = {"text": "hello", "language": "en", "segments": []} + with patch("content_parser.transcription.whisper_api.requests.post") as rp: + rp.return_value = _mock_response(payload) + transcribe_audio(self.audio, "MY_TOKEN") + kwargs = rp.call_args.kwargs + 
self.assertEqual(kwargs["headers"], {"Authorization": "Bearer MY_TOKEN"}) + + def test_sends_verbose_json_format(self): + with patch("content_parser.transcription.whisper_api.requests.post") as rp: + rp.return_value = _mock_response({"text": "x", "segments": []}) + transcribe_audio(self.audio, "k") + data = rp.call_args.kwargs["data"] + self.assertEqual(data["model"], "whisper-1") + self.assertEqual(data["response_format"], "verbose_json") + self.assertEqual(data["timestamp_granularities[]"], "segment") + + def test_language_optional(self): + with patch("content_parser.transcription.whisper_api.requests.post") as rp: + rp.return_value = _mock_response({"text": "x", "segments": []}) + transcribe_audio(self.audio, "k", language="ru") + data = rp.call_args.kwargs["data"] + self.assertEqual(data["language"], "ru") + + def test_401_message_explicit(self): + with patch("content_parser.transcription.whisper_api.requests.post") as rp: + rp.return_value = _mock_response({}, ok=False, status=401) + with self.assertRaises(WhisperError) as cm: + transcribe_audio(self.audio, "bad") + self.assertIn("API key", str(cm.exception)) + + def test_429_rate_limit(self): + # 429 is now retried; with max_retries=0 we get a single attempt and a raise. + with patch("content_parser.transcription.whisper_api.requests.post") as rp, \ + patch("content_parser.transcription.whisper_api._sleep"): + rp.return_value = _mock_response({}, ok=False, status=429) + with self.assertRaises(WhisperError) as cm: + transcribe_audio(self.audio, "k", max_retries=0) + self.assertIn("rate-limit", str(cm.exception).lower()) + + def test_other_error_includes_message(self): + resp = _mock_response({"error": {"message": "Bad request: file too short"}}, ok=False, status=400) + with patch("content_parser.transcription.whisper_api.requests.post", return_value=resp): + with self.assertRaises(WhisperError) as cm: + transcribe_audio(self.audio, "k") + self.assertIn("Bad request", str(cm.exception)) + self.assertIn("400", str(cm.exception)) + + def test_returns_parsed_json(self): + payload = { + "text": "Hello world", + "language": "en", + "segments": [ + {"id": 0, "start": 0.0, "end": 2.5, "text": "Hello world"}, + ], + } + with patch("content_parser.transcription.whisper_api.requests.post") as rp: + rp.return_value = _mock_response(payload) + result = transcribe_audio(self.audio, "k") + self.assertEqual(result["text"], "Hello world") + self.assertEqual(len(result["segments"]), 1) + + +class RetryTest(unittest.TestCase): + """Whisper retries 429 and 5xx with exponential backoff.""" + + def setUp(self): + self.tmpdir = tempfile.TemporaryDirectory() + self.audio = Path(self.tmpdir.name) / "x.mp3" + self.audio.write_bytes(b"\x00" * 100) + + def tearDown(self): + self.tmpdir.cleanup() + + def test_429_then_success(self): + sleeps: list[float] = [] + responses = [ + _mock_response({}, ok=False, status=429), + _mock_response({"text": "ok", "segments": []}), + ] + with patch("content_parser.transcription.whisper_api.requests.post", side_effect=responses) as rp, \ + patch("content_parser.transcription.whisper_api._sleep", side_effect=sleeps.append): + result = transcribe_audio(self.audio, "k") + self.assertEqual(rp.call_count, 2) + self.assertEqual(sleeps, [2.0]) + self.assertEqual(result["text"], "ok") + + def test_500_then_503_then_success(self): + sleeps: list[float] = [] + responses = [ + _mock_response({}, ok=False, status=500), + _mock_response({}, ok=False, status=503), + _mock_response({"text": "ok", "segments": []}), + ] + with 
patch("content_parser.transcription.whisper_api.requests.post", side_effect=responses) as rp, \ + patch("content_parser.transcription.whisper_api._sleep", side_effect=sleeps.append): + transcribe_audio(self.audio, "k") + self.assertEqual(rp.call_count, 3) + self.assertEqual(sleeps, [2.0, 4.0]) # exponential + + def test_429_exhausts_retries_then_raises(self): + responses = [_mock_response({}, ok=False, status=429)] * 5 # plenty + with patch("content_parser.transcription.whisper_api.requests.post", side_effect=responses) as rp, \ + patch("content_parser.transcription.whisper_api._sleep"): + with self.assertRaises(WhisperError): + transcribe_audio(self.audio, "k", max_retries=2) + # Initial + 2 retries = 3 calls. + self.assertEqual(rp.call_count, 3) + + def test_401_does_not_retry(self): + responses = [_mock_response({}, ok=False, status=401)] * 3 + with patch("content_parser.transcription.whisper_api.requests.post", side_effect=responses) as rp, \ + patch("content_parser.transcription.whisper_api._sleep"): + with self.assertRaises(WhisperError): + transcribe_audio(self.audio, "bad") + self.assertEqual(rp.call_count, 1) + + def test_400_does_not_retry(self): + responses = [_mock_response({"error": {"message": "bad audio"}}, ok=False, status=400)] * 3 + with patch("content_parser.transcription.whisper_api.requests.post", side_effect=responses) as rp, \ + patch("content_parser.transcription.whisper_api._sleep"): + with self.assertRaises(WhisperError): + transcribe_audio(self.audio, "k") + self.assertEqual(rp.call_count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vk_adapter.py b/tests/test_vk_adapter.py new file mode 100644 index 0000000..6a3fc58 --- /dev/null +++ b/tests/test_vk_adapter.py @@ -0,0 +1,167 @@ +"""Tests for content_parser.plugins.vk.adapter — VK dicts → Item/Comment.""" +from __future__ import annotations + +import unittest + +from content_parser.plugins.vk.adapter import ( + _iso, + _label_for_id, + comment_to_core, + index_by_id, + post_to_item, +) + + +SAMPLE_POST = { + "id": 678, + "owner_id": -12345, # negative = group + "from_id": -12345, + "date": 1_700_000_000, + "post_type": "post", + "text": "Запуск нового продукта\nПодробности по ссылке.", + "attachments": [ + {"type": "photo", "photo": {}}, + {"type": "video", "video": {}}, + ], + "comments": {"count": 42}, + "likes": {"count": 1234}, + "reposts": {"count": 56}, + "views": {"count": 100_000}, + "marked_as_ads": False, + "is_pinned": True, +} + + +SAMPLE_PROFILES = [ + {"id": 555, "first_name": "Иван", "last_name": "Иванов", "screen_name": "ivanov"}, + {"id": 777, "first_name": "Петя", "last_name": "Петров"}, +] +SAMPLE_GROUPS = [ + {"id": 12345, "name": "Awesome Group", "screen_name": "awesome"}, +] + + +class IsoTest(unittest.TestCase): + def test_unix_to_iso(self): + self.assertTrue(_iso(1_700_000_000).startswith("2023-")) + + def test_none(self): + self.assertIsNone(_iso(None)) + + +class LabelForIdTest(unittest.TestCase): + def test_user_id(self): + profiles = index_by_id(SAMPLE_PROFILES) + groups = index_by_id(SAMPLE_GROUPS) + name, label = _label_for_id(555, profiles, groups) + self.assertEqual(name, "Иван Иванов") + self.assertEqual(label, "id555") + + def test_group_id_is_negative(self): + profiles = index_by_id(SAMPLE_PROFILES) + groups = index_by_id(SAMPLE_GROUPS) + name, label = _label_for_id(-12345, profiles, groups) + self.assertEqual(name, "Awesome Group") + self.assertEqual(label, "club12345") + + def test_unknown_user_falls_back_to_id(self): + name, label = 
_label_for_id(999, {}, {}) + self.assertEqual(name, "id999") + self.assertEqual(label, "id999") + + +class PostToItemTest(unittest.TestCase): + def test_basic_fields(self): + item = post_to_item( + SAMPLE_POST, + profiles_by_id=index_by_id(SAMPLE_PROFILES), + groups_by_id=index_by_id(SAMPLE_GROUPS), + ) + self.assertEqual(item.source, "vk") + self.assertEqual(item.item_id, "-12345_678") + self.assertEqual(item.url, "https://vk.com/wall-12345_678") + self.assertEqual(item.title, "Запуск нового продукта") + self.assertEqual(item.author, "Awesome Group") + self.assertEqual(item.author_id, "-12345") + self.assertTrue(item.published_at.startswith("2023-")) + self.assertIn("Запуск", item.text) + + def test_metrics(self): + item = post_to_item( + SAMPLE_POST, + profiles_by_id=index_by_id(SAMPLE_PROFILES), + groups_by_id=index_by_id(SAMPLE_GROUPS), + ) + self.assertEqual(item.media["likes_count"], 1234) + self.assertEqual(item.media["reposts_count"], 56) + self.assertEqual(item.media["comments_count"], 42) + self.assertEqual(item.media["views_count"], 100_000) + self.assertTrue(item.media["has_photo"]) + self.assertTrue(item.media["has_video"]) + self.assertTrue(item.media["is_pinned"]) + # marked_as_ads is False → stripped from media + self.assertNotIn("marked_as_ads", item.media) + + def test_attachment_types_in_extra(self): + item = post_to_item(SAMPLE_POST) + self.assertEqual(item.extra["attachment_types"], ["photo", "video"]) + + def test_explicit_owner_label_wins(self): + item = post_to_item(SAMPLE_POST, owner_label="Custom Label") + self.assertEqual(item.author, "Custom Label") + + def test_empty_text_keeps_title_none(self): + post = dict(SAMPLE_POST) + post["text"] = "" + item = post_to_item(post) + self.assertIsNone(item.title) + + +class CommentToCoreTest(unittest.TestCase): + def test_user_comment(self): + c = { + "id": 9001, + "from_id": 555, + "date": 1_700_000_500, + "text": "Отличный пост", + "likes": {"count": 7}, + } + out = comment_to_core( + c, + parent_id=None, + profiles_by_id=index_by_id(SAMPLE_PROFILES), + groups_by_id=index_by_id(SAMPLE_GROUPS), + ) + self.assertEqual(out.comment_id, "9001") + self.assertIsNone(out.parent_id) + self.assertEqual(out.author, "Иван Иванов") + self.assertEqual(out.author_id, "id555") + self.assertEqual(out.text, "Отличный пост") + self.assertEqual(out.like_count, 7) + + def test_reply_carries_parent_id(self): + c = {"id": 9002, "from_id": -12345, "date": 1_700_000_600, "text": "thx"} + out = comment_to_core( + c, + parent_id="9001", + profiles_by_id=index_by_id(SAMPLE_PROFILES), + groups_by_id=index_by_id(SAMPLE_GROUPS), + ) + self.assertEqual(out.parent_id, "9001") + self.assertEqual(out.author, "Awesome Group") + self.assertEqual(out.author_id, "club12345") + + +class IndexByIdTest(unittest.TestCase): + def test_basic(self): + idx = index_by_id([{"id": 1, "name": "a"}, {"id": 2, "name": "b"}]) + self.assertEqual(idx[1]["name"], "a") + self.assertEqual(idx[2]["name"], "b") + + def test_skips_bad_entries(self): + idx = index_by_id([{"id": 1}, {}, {"id": "not-int"}]) + self.assertEqual(set(idx.keys()), {1}) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_vk_plugin.py b/tests/test_vk_plugin.py new file mode 100644 index 0000000..52603e6 --- /dev/null +++ b/tests/test_vk_plugin.py @@ -0,0 +1,507 @@ +"""Tests for content_parser.plugins.vk.plugin — input validation + dispatch.""" +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock, patch + +from content_parser.core.errors import 
AuthError, PluginError, RateLimitError +from content_parser.core.redact import redact_spec as _redact_spec +from content_parser.plugins.vk.plugin import VKPlugin, _is_vk_host + + +class NormalizeCommunityTest(unittest.TestCase): + def setUp(self): + self.p = VKPlugin() + + def test_screen_name(self): + self.assertEqual(self.p._normalize_community("durov_says"), "durov_says") + + def test_url_with_screen_name(self): + self.assertEqual( + self.p._normalize_community("https://vk.com/durov_says"), + "durov_says", + ) + + def test_club_prefix_url(self): + self.assertEqual(self.p._normalize_community("https://vk.com/club12345"), "club12345") + + def test_public_prefix(self): + self.assertEqual(self.p._normalize_community("public99"), "club99") + + def test_numeric_id(self): + self.assertEqual(self.p._normalize_community("12345"), "12345") + + def test_strips_www_subdomain(self): + # m.vk.com is recognized, but www.vk.com (a subdomain) — must also pass + self.assertEqual( + self.p._normalize_community("https://m.vk.com/awesome"), + "awesome", + ) + + def test_rejects_non_vk_host(self): + with self.assertRaises(PluginError): + self.p._normalize_community("https://evil.example/durov_says") + + def test_rejects_reserved_path(self): + # vk.com/feed → /feed/ is the news feed, not a community + with self.assertRaises(PluginError): + self.p._normalize_community("https://vk.com/feed") + with self.assertRaises(PluginError): + self.p._normalize_community("im") + + def test_rejects_empty(self): + with self.assertRaises(PluginError): + self.p._normalize_community(" ") + + def test_rejects_invalid_chars(self): + with self.assertRaises(PluginError): + self.p._normalize_community("name with spaces!") + + +class ExtractWallIdTest(unittest.TestCase): + def setUp(self): + self.p = VKPlugin() + + def test_canonical_post_url(self): + self.assertEqual( + self.p._extract_wall_id("https://vk.com/wall-12345_678"), + "-12345_678", + ) + + def test_user_post_url(self): + # positive owner_id = user wall + self.assertEqual( + self.p._extract_wall_id("https://vk.com/wall1_42"), + "1_42", + ) + + def test_mobile_subdomain(self): + self.assertEqual( + self.p._extract_wall_id("https://m.vk.com/wall-1_2"), + "-1_2", + ) + + def test_non_vk_host_rejected(self): + self.assertIsNone(self.p._extract_wall_id("https://evil.example/wall-1_2")) + + def test_non_wall_path_rejected(self): + self.assertIsNone(self.p._extract_wall_id("https://vk.com/durov_says")) + + +class IsVkHostTest(unittest.TestCase): + def test_main_domain(self): + self.assertTrue(_is_vk_host("vk.com")) + self.assertTrue(_is_vk_host("m.vk.com")) + self.assertTrue(_is_vk_host("VK.COM")) + + def test_lookalike_rejected(self): + self.assertFalse(_is_vk_host("evilvk.com")) + self.assertFalse(_is_vk_host("vk.com.evil.example")) + + def test_empty(self): + self.assertFalse(_is_vk_host("")) + + +class RedactSpecTest(unittest.TestCase): + def test_strips_query(self): + out = _redact_spec("post:https://vk.com/wall-1_2?token=secret") + self.assertNotIn("secret", out) + self.assertIn("?…", out) + + def test_strips_fragment(self): + out = _redact_spec("post:https://vk.com/wall-1_2#access_token=xxx") + self.assertNotIn("access_token", out) + + def test_truncates_long(self): + out = _redact_spec("community:" + "x" * 200) + self.assertLessEqual(len(out), 80) + + +class ResolveTest(unittest.TestCase): + def setUp(self): + self.p = VKPlugin() + self.secrets = {"VK_ACCESS_TOKEN": "x"} + + def test_specs_carry_kind_prefix(self): + specs = self.p.resolve( + { + "query": ["маркетинг"], + 
"community": ["durov_says"], + "post_url": ["https://vk.com/wall-1_2"], + }, + {}, + self.secrets, + ) + kinds = [s.split(":", 1)[0] for s in specs] + self.assertEqual(kinds.count("query"), 1) + self.assertEqual(kinds.count("community"), 1) + self.assertEqual(kinds.count("post"), 1) + + def test_dedupe(self): + specs = self.p.resolve( + {"community": ["durov_says", "https://vk.com/durov_says", "durov_says"]}, + {}, + self.secrets, + ) + self.assertEqual(len(specs), 1) + + def test_rejects_non_wall_url_in_post_field(self): + with self.assertRaises(PluginError): + self.p.resolve( + {"post_url": ["https://vk.com/durov_says"]}, + {}, + self.secrets, + ) + + +class FetchAuthGuardTest(unittest.TestCase): + def test_missing_token_raises_auth(self): + p = VKPlugin() + with self.assertRaises(AuthError): + list(p.fetch(["community:durov_says"], {}, {})) + + +class RetryOnRateLimitTest(unittest.TestCase): + """VKClient retries with exponential backoff on RateLimitError before giving up.""" + + def _mock_response(self, payload, ok=True, status=200): + m = MagicMock() + m.ok = ok + m.status_code = status + m.json.return_value = payload + return m + + def test_retries_then_succeeds(self): + from content_parser.plugins.vk.client import VKClient + + sleeps: list[float] = [] + rate_limit_payload = {"error": {"error_code": 6, "error_msg": "Too many requests per second"}} + success_payload = {"response": {"items": []}} + + with patch("content_parser.plugins.vk.client.requests.Session") as MockSession, \ + patch.object(VKClient, "_sleep", staticmethod(lambda s: sleeps.append(s))): + session = MagicMock() + session.post.side_effect = [ + self._mock_response(rate_limit_payload), + self._mock_response(rate_limit_payload), + self._mock_response(success_payload), + ] + MockSession.return_value = session + + client = VKClient("x") + resp = client.call("groups.search", q="x") + + self.assertEqual(session.post.call_count, 3) + self.assertEqual(resp, {"items": []}) + self.assertEqual(sleeps, [1.0, 2.0]) # exponential backoff + + def test_gives_up_after_max_retries(self): + from content_parser.plugins.vk.client import VKClient + + rate_limit_payload = {"error": {"error_code": 6, "error_msg": "x"}} + + with patch("content_parser.plugins.vk.client.requests.Session") as MockSession, \ + patch.object(VKClient, "_sleep", staticmethod(lambda s: None)): + session = MagicMock() + session.post.return_value = self._mock_response(rate_limit_payload) + MockSession.return_value = session + + client = VKClient("x", max_rate_limit_retries=2) + with self.assertRaises(RateLimitError): + client.call("groups.search", q="x") + + # Initial call + 2 retries = 3 total. + self.assertEqual(session.post.call_count, 3) + + def test_auth_error_not_retried(self): + from content_parser.plugins.vk.client import VKClient + + auth_payload = {"error": {"error_code": 5, "error_msg": "Auth failed"}} + + with patch("content_parser.plugins.vk.client.requests.Session") as MockSession, \ + patch.object(VKClient, "_sleep", staticmethod(lambda s: None)): + session = MagicMock() + session.post.return_value = self._mock_response(auth_payload) + MockSession.return_value = session + + client = VKClient("x") + with self.assertRaises(AuthError): + client.call("groups.search", q="x") + + # Only one call — no retries on AuthError. 
+ self.assertEqual(session.post.call_count, 1) + + +class CommentPaginationTest(unittest.TestCase): + """_fetch_comments correctness: cap respected, short-circuit on count.""" + + def setUp(self): + self.p = VKPlugin() + + def _make_client(self, responses: list[dict]) -> MagicMock: + client = MagicMock() + client.call.side_effect = responses + return client + + @staticmethod + def _comment(cid, replies=None): + c = { + "id": cid, + "from_id": 100 + cid, + "date": 1_700_000_000, + "text": f"comment {cid}", + "likes": {"count": 0}, + } + if replies: + c["thread"] = {"items": replies} + return c + + def test_top_level_cap_enforced_exactly(self): + # 200 top-level comments available, max=50. Should yield exactly 50. + all_comments = [self._comment(i) for i in range(200)] + responses = [ + {"items": all_comments[:100], "profiles": [], "groups": [], "count": 200}, + {"items": all_comments[100:200], "profiles": [], "groups": [], "count": 200}, + ] + client = self._make_client(responses) + out = self.p._fetch_comments( + client, owner_id=-1, post_id=1, max_comments=50, depth="top_level" + ) + self.assertEqual(len(out), 50) + # We only need one page (50 ≤ 100 page size) + self.assertEqual(client.call.call_count, 1) + + def test_depth_all_does_not_overshoot_cap(self): + # 1 top-level with 200 replies, max=10 → exactly 10, not 11. + replies = [self._comment(1000 + i) for i in range(200)] + top = [self._comment(1, replies=replies)] + responses = [ + {"items": top, "profiles": [], "groups": [], "count": 1}, + ] + client = self._make_client(responses) + out = self.p._fetch_comments( + client, owner_id=-1, post_id=1, max_comments=10, depth="all" + ) + self.assertEqual(len(out), 10) + # First should be the top-level; rest replies with parent_id=1 + self.assertIsNone(out[0].parent_id) + for c in out[1:]: + self.assertEqual(c.parent_id, "1") + + def test_pagination_short_circuit_via_count(self): + # Page 1: 100 items, count=100 → no second page. + all_comments = [self._comment(i) for i in range(100)] + responses = [ + {"items": all_comments, "profiles": [], "groups": [], "count": 100}, + ] + client = self._make_client(responses) + out = self.p._fetch_comments( + client, owner_id=-1, post_id=1, max_comments=500, depth="top_level" + ) + self.assertEqual(len(out), 100) + # Critically: only ONE call, no useless extra request. 
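+        # (Page 1 returned 100 items and reported count=100, so the next offset
+        # of 100 is not below count; a second client.call would be wasted work.)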
+ self.assertEqual(client.call.call_count, 1) + + def test_pagination_continues_when_count_higher_than_page(self): + page1 = [self._comment(i) for i in range(100)] + page2 = [self._comment(100 + i) for i in range(50)] + responses = [ + {"items": page1, "profiles": [], "groups": [], "count": 150}, + {"items": page2, "profiles": [], "groups": [], "count": 150}, + ] + client = self._make_client(responses) + out = self.p._fetch_comments( + client, owner_id=-1, post_id=1, max_comments=500, depth="top_level" + ) + self.assertEqual(len(out), 150) + self.assertEqual(client.call.call_count, 2) + + +class AdapterDefensiveTest(unittest.TestCase): + def test_post_to_item_raises_on_missing_owner_id(self): + from content_parser.plugins.vk.adapter import post_to_item + with self.assertRaises(ValueError): + post_to_item({"id": 1}) + + def test_post_to_item_raises_on_missing_id(self): + from content_parser.plugins.vk.adapter import post_to_item + with self.assertRaises(ValueError): + post_to_item({"owner_id": -1}) + + +class ClientErrorMappingTest(unittest.TestCase): + """VKClient maps known error_code → AuthError / RateLimitError / PluginError.""" + + def _mock_response(self, payload, status=200, ok=True): + m = MagicMock() + m.ok = ok + m.status_code = status + m.json.return_value = payload + return m + + def _patched_session(self, response): + """Build a context manager that patches requests.Session in the client.""" + session = MagicMock() + session.post.return_value = response + return patch( + "content_parser.plugins.vk.client.requests.Session", + return_value=session, + ), session + + def test_auth_error(self): + from content_parser.plugins.vk.client import VKClient + + ctx, session = self._patched_session(self._mock_response({ + "error": {"error_code": 5, "error_msg": "User authorization failed"} + })) + with ctx, patch.object(VKClient, "_sleep", staticmethod(lambda s: None)): + client = VKClient("bad_token") + with self.assertRaises(AuthError) as cm: + client.call("groups.search", q="x") + self.assertIn("authorization", str(cm.exception).lower()) + + def test_rate_limit(self): + from content_parser.plugins.vk.client import VKClient + + ctx, session = self._patched_session(self._mock_response({ + "error": {"error_code": 6, "error_msg": "Too many requests per second"} + })) + with ctx, patch.object(VKClient, "_sleep", staticmethod(lambda s: None)): + with self.assertRaises(RateLimitError): + VKClient("x", max_rate_limit_retries=0).call("groups.search", q="x") + + def test_other_error(self): + from content_parser.plugins.vk.client import VKClient + + ctx, session = self._patched_session(self._mock_response({ + "error": {"error_code": 100, "error_msg": "Param missing"} + })) + with ctx, patch.object(VKClient, "_sleep", staticmethod(lambda s: None)): + with self.assertRaises(PluginError): + VKClient("x").call("groups.search", q="x") + + def test_token_in_body_not_query(self): + from content_parser.plugins.vk.client import VKClient + + ctx, session = self._patched_session(self._mock_response({"response": {"items": []}})) + with ctx: + VKClient("MY_SECRET_TOKEN").call("groups.search", q="x", count=10) + + args, kwargs = session.post.call_args + # token must not appear in URL + self.assertNotIn("MY_SECRET_TOKEN", args[0]) + self.assertNotIn("MY_SECRET_TOKEN", str(kwargs.get("params") or {})) + # token must appear in form-encoded body + self.assertEqual(kwargs["data"]["access_token"], "MY_SECRET_TOKEN") + self.assertEqual(kwargs["data"]["q"], "x") + + +class FetchDispatchTest(unittest.TestCase): + 
"""Verify fetch routes each spec kind to the right VK API method.""" + + def setUp(self): + self.p = VKPlugin() + self.secrets = {"VK_ACCESS_TOKEN": "tok"} + + def _make_client(self): + client = MagicMock() + return client + + def _patch_client(self, client): + return patch( + "content_parser.plugins.vk.plugin.VKClient", + return_value=client, + ) + + def test_query_calls_groups_search_and_walls(self): + client = self._make_client() + + def fake_call(method, **params): + if method == "groups.search": + return {"items": [{"id": 1, "name": "G1"}, {"id": 2, "name": "G2"}]} + if method == "wall.get": + return {"items": [{"id": 100, "owner_id": params["owner_id"], "date": 1, "text": "x"}], + "profiles": [], "groups": []} + return None + + client.call.side_effect = fake_call + with self._patch_client(client): + items = list(self.p.fetch( + ["query:маркетинг"], + {"max_communities_per_query": 2, "max_posts_per_input": 5, + "fetch_comments": False}, + self.secrets, + )) + self.assertEqual(len(items), 2) + owners = sorted(it.author_id for it in items) + self.assertEqual(owners, ["-1", "-2"]) + + def test_community_resolves_then_fetches_wall(self): + client = self._make_client() + + def fake_call(method, **params): + if method == "groups.getById": + return {"groups": [{"id": 12345, "name": "Awesome", "screen_name": "awesome"}]} + if method == "wall.get": + return {"items": [{"id": 7, "owner_id": -12345, "date": 1, "text": "hi"}], + "profiles": [], "groups": []} + return None + + client.call.side_effect = fake_call + with self._patch_client(client): + items = list(self.p.fetch( + ["community:awesome"], + {"max_posts_per_input": 5, "fetch_comments": False}, + self.secrets, + )) + self.assertEqual(len(items), 1) + self.assertEqual(items[0].item_id, "-12345_7") + self.assertEqual(items[0].author, "Awesome") + + def test_post_calls_wall_getById(self): + client = self._make_client() + + def fake_call(method, **params): + if method == "wall.getById": + self.assertEqual(params["posts"], "-12345_7") + return { + "items": [{"id": 7, "owner_id": -12345, "date": 1, "text": "p"}], + "profiles": [], + "groups": [{"id": 12345, "name": "Awesome"}], + } + return None + + client.call.side_effect = fake_call + with self._patch_client(client): + items = list(self.p.fetch( + ["post:-12345_7"], {"fetch_comments": False}, self.secrets, + )) + self.assertEqual(len(items), 1) + self.assertEqual(items[0].author, "Awesome") + + def test_dedupes_same_post_from_multiple_inputs(self): + client = self._make_client() + same_post = {"id": 7, "owner_id": -12345, "date": 1, "text": "same"} + + def fake_call(method, **params): + if method == "wall.get": + return {"items": [same_post], "profiles": [], "groups": []} + if method == "groups.getById": + return {"groups": [{"id": 12345, "name": "Awesome"}]} + if method == "wall.getById": + return {"items": [same_post], "profiles": [], + "groups": [{"id": 12345, "name": "Awesome"}]} + return None + + client.call.side_effect = fake_call + with self._patch_client(client): + items = list(self.p.fetch( + ["community:awesome", "post:-12345_7"], + {"max_posts_per_input": 5, "fetch_comments": False}, + self.secrets, + )) + self.assertEqual(len(items), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/youtube_parser/__init__.py b/youtube_parser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/youtube_parser/comments.py b/youtube_parser/comments.py new file mode 100644 index 0000000..a4d0532 --- /dev/null +++ b/youtube_parser/comments.py @@ -0,0 +1,2 @@ 
+"""Back-compat shim — real code now in content_parser.plugins.youtube.comments.""" +from content_parser.plugins.youtube.comments import * # noqa: F401,F403 diff --git a/youtube_parser/main.py b/youtube_parser/main.py new file mode 100644 index 0000000..65ca845 --- /dev/null +++ b/youtube_parser/main.py @@ -0,0 +1,77 @@ +"""Back-compat CLI: translates legacy flags into 'content_parser.cli run --source youtube'. + +The original argument set is preserved so existing scripts keep working. +""" +from __future__ import annotations + +import argparse +import os +import sys + +from content_parser.cli import main as cp_main + + +def _build_legacy_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(prog="youtube_parser") + p.add_argument("--query", "-q", action="append", default=[]) + p.add_argument("--channel", "-c", action="append", default=[]) + p.add_argument("--playlist", "-p", action="append", default=[]) + p.add_argument("--video", "-v", action="append", default=[]) + p.add_argument("--api-key", default=os.environ.get("YOUTUBE_API_KEY")) + p.add_argument("--output", "-o", default=None) + p.add_argument("--search-max", type=int, default=25) + p.add_argument("--per-source-max", type=int, default=None) + p.add_argument("--max-comments", type=int, default=None) + p.add_argument("--include-replies", action="store_true") + p.add_argument("--comment-order", choices=("relevance", "time"), default="relevance") + p.add_argument("--transcript-langs", default="ru,en") + p.add_argument("--no-transcripts", action="store_true") + p.add_argument("--no-comments", action="store_true") + return p + + +def _to_new_argv(args: argparse.Namespace) -> list[str]: + new = ["run", "--source", "youtube"] + + if args.output: + new += ["--output", args.output] + + for q in args.query: + new += ["--query", q] + for c in args.channel: + new += ["--channel", c] + for pl in args.playlist: + new += ["--playlist", pl] + for v in args.video: + new += ["--video", v] + + settings: dict[str, str] = { + "search_max": str(args.search_max), + "per_source_max": str(args.per_source_max if args.per_source_max is not None else 0), + "max_comments": str(args.max_comments if args.max_comments is not None else 0), + "include_replies": "true" if args.include_replies else "false", + "comment_order": args.comment_order, + "transcript_langs": args.transcript_langs, + "fetch_transcripts": "false" if args.no_transcripts else "true", + "fetch_comments": "false" if args.no_comments else "true", + } + for k, v in settings.items(): + new += ["--set", f"{k}={v}"] + return new + + +def main(argv: list[str] | None = None) -> int: + args = _build_legacy_parser().parse_args(argv) + + if not (args.query or args.channel or args.playlist or args.video): + print("Provide at least one of --query / --channel / --playlist / --video.", file=sys.stderr) + return 2 + + if args.api_key: + os.environ.setdefault("YOUTUBE_API_KEY", args.api_key) + + return cp_main(_to_new_argv(args)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/youtube_parser/output.py b/youtube_parser/output.py new file mode 100644 index 0000000..711de0e --- /dev/null +++ b/youtube_parser/output.py @@ -0,0 +1,161 @@ +"""Write parsed results as JSON and Markdown documents.""" +from __future__ import annotations + +import csv +import json +import re +from pathlib import Path + + +def _safe_filename(name: str, max_length: int = 80) -> str: + cleaned = re.sub(r"[^\w\s-]", "", name, flags=re.UNICODE).strip() + cleaned = re.sub(r"\s+", "_", cleaned) + return 
cleaned[:max_length] or "video" + + +def _format_seconds(seconds: float) -> str: + total = int(seconds) + h, rem = divmod(total, 3600) + m, s = divmod(rem, 60) + return f"{h:02d}:{m:02d}:{s:02d}" if h else f"{m:02d}:{s:02d}" + + +def write_video_json(video: dict, out_dir: Path) -> Path: + fname = f"{video['video_id']}_{_safe_filename(video.get('title') or '')}.json" + path = out_dir / fname + path.write_text(json.dumps(video, ensure_ascii=False, indent=2), encoding="utf-8") + return path + + +def write_video_markdown(video: dict, out_dir: Path) -> Path: + fname = f"{video['video_id']}_{_safe_filename(video.get('title') or '')}.md" + path = out_dir / fname + + lines: list[str] = [] + title = video.get("title") or video["video_id"] + lines.append(f"# {title}") + lines.append("") + lines.append(f"- **Channel:** {video.get('channel_title', '—')}") + lines.append(f"- **URL:** {video.get('url')}") + lines.append(f"- **Published:** {video.get('published_at', '—')}") + lines.append(f"- **Duration:** {video.get('duration', '—')}") + lines.append( + f"- **Views / Likes / Comments:** " + f"{video.get('view_count', '—')} / " + f"{video.get('like_count', '—')} / " + f"{video.get('comment_count', '—')}" + ) + lines.append("") + + if video.get("description"): + lines.append("## Description") + lines.append("") + lines.append(video["description"].strip()) + lines.append("") + + transcript = video.get("transcript") + lines.append("## Transcript") + lines.append("") + if transcript and transcript.get("segments"): + lang = transcript.get("language", "?") + kind = "auto-generated" if transcript.get("is_generated") else "manual" + lines.append(f"_Language: {lang} ({kind})_") + lines.append("") + for seg in transcript["segments"]: + ts = _format_seconds(seg["start"]) + text = seg["text"].replace("\n", " ").strip() + if text: + lines.append(f"- `[{ts}]` {text}") + lines.append("") + else: + lines.append("_No transcript available._") + lines.append("") + + comments = video.get("comments") or [] + lines.append(f"## Comments ({len(comments)})") + lines.append("") + if not comments: + lines.append("_No comments fetched (disabled or empty)._") + lines.append("") + else: + by_parent: dict[str | None, list[dict]] = {} + for c in comments: + by_parent.setdefault(c.get("parent_id"), []).append(c) + + for top in by_parent.get(None, []): + lines.append( + f"### {top.get('author', '—')} " + f"_({top.get('published_at', '—')}, ♥ {top.get('like_count', 0)})_" + ) + lines.append("") + lines.append((top.get("text") or "").strip()) + lines.append("") + for reply in by_parent.get(top["comment_id"], []): + lines.append( + f"> **{reply.get('author', '—')}** " + f"_({reply.get('published_at', '—')}, ♥ {reply.get('like_count', 0)})_" + ) + lines.append("> ") + for ln in (reply.get("text") or "").strip().splitlines(): + lines.append(f"> {ln}") + lines.append("") + + path.write_text("\n".join(lines), encoding="utf-8") + return path + + +def write_summary_csv(videos: list[dict], out_dir: Path) -> Path: + path = out_dir / "summary.csv" + fields = [ + "video_id", + "title", + "channel_title", + "url", + "published_at", + "duration", + "view_count", + "like_count", + "comment_count", + "comments_fetched", + "transcript_language", + "transcript_is_generated", + ] + with path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fields) + writer.writeheader() + for v in videos: + t = v.get("transcript") or {} + writer.writerow( + { + "video_id": v.get("video_id"), + "title": v.get("title"), + "channel_title": 
v.get("channel_title"), + "url": v.get("url"), + "published_at": v.get("published_at"), + "duration": v.get("duration"), + "view_count": v.get("view_count"), + "like_count": v.get("like_count"), + "comment_count": v.get("comment_count"), + "comments_fetched": len(v.get("comments") or []), + "transcript_language": t.get("language"), + "transcript_is_generated": t.get("is_generated"), + } + ) + return path + + +def write_combined_markdown(videos: list[dict], out_dir: Path) -> Path: + """One big Markdown index linking to each video file.""" + path = out_dir / "index.md" + lines = ["# YouTube Parser Results", "", f"_{len(videos)} video(s)_", ""] + for v in videos: + title = v.get("title") or v["video_id"] + fname = f"{v['video_id']}_{_safe_filename(title)}.md" + comments = len(v.get("comments") or []) + has_transcript = bool((v.get("transcript") or {}).get("segments")) + lines.append( + f"- [{title}]({fname}) — {comments} comment(s), " + f"transcript: {'yes' if has_transcript else 'no'}" + ) + path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return path diff --git a/youtube_parser/sources.py b/youtube_parser/sources.py new file mode 100644 index 0000000..b0878cf --- /dev/null +++ b/youtube_parser/sources.py @@ -0,0 +1,2 @@ +"""Back-compat shim — real code now in content_parser.plugins.youtube.sources.""" +from content_parser.plugins.youtube.sources import * # noqa: F401,F403 diff --git a/youtube_parser/transcripts.py b/youtube_parser/transcripts.py new file mode 100644 index 0000000..31f4fab --- /dev/null +++ b/youtube_parser/transcripts.py @@ -0,0 +1,2 @@ +"""Back-compat shim — real code now in content_parser.plugins.youtube.transcripts.""" +from content_parser.plugins.youtube.transcripts import * # noqa: F401,F403