# Root-level data engineering dependencies for testing

In [None]:
%%writefile pyproject.toml
[project]
name = "react_fastapi_railway"
version = "0.1.0"
# NOTE(review): the description mentions a PyTorch/JAX GPU container, while the
# project name and the dependency list describe a FastAPI + Airflow data
# project — confirm which one is current.
description = "Pytorch and Jax GPU docker container"
authors = [
  { name = "Geoffrey Hadfield" },
]
license = "MIT"
readme = "README.md"

# ─── Restrict to Python 3.10–3.12 ──────────────────────────────
requires-python = ">=3.10,<3.13"

dependencies = [
  # Core web framework
  "fastapi>=0.104.0",
  "uvicorn[standard]>=0.24.0",
  "python-dotenv>=1.0.0",

  # Settings and validation
  "pydantic>=2.0.0",
  "pydantic-settings>=2.0.0",

  # HTTP client and multipart parsing
  "httpx>=0.24.0",
  "python-multipart>=0.0.6",

  # Data & ML basics
  "numpy>=1.24.0",
  "pandas>=2.1.0",
  "scikit-learn>=1.3.0",
  "mlflow>=2.8.0",

  # (Your existing extras—keep if you still need them)
  "matplotlib>=3.4.0",
  "pymc>=5.0.0",
  "arviz>=0.14.0",
  "statsmodels>=0.13.0",
  "jupyterlab>=3.0.0",
  "seaborn>=0.11.0",
  "tabulate>=0.9.0",
  "shap>=0.40.0",
  "xgboost>=1.5.0",
  "lightgbm>=3.3.0",
  "catboost>=1.2.8,<1.3.0",
  "scipy>=1.7.0",
  "shapash[report]>=2.3.0",
  "shapiq>=0.1.0",
  "explainerdashboard==0.5.1",
  "ipywidgets>=8.0.0",
  "nutpie>=0.7.1",
  "numpyro>=0.18.0,<1.0.0",
  "jax==0.6.0",
  "jaxlib==0.6.0",
  "pytensor>=2.18.3",
  "aesara>=2.9.4",
  "tqdm>=4.67.0",
  "pyarrow>=12.0.0",
  "optuna>=3.0.0",
  "optuna-integration[mlflow]>=0.2.0",
  "omegaconf>=2.3.0,<2.4.0",
  "hydra-core>=1.3.2,<1.4.0",
  "aiosqlite>=0.19.0",
  "python-jose[cryptography]>=3.3.0",
  "passlib[bcrypt]>=1.7.4",
  "bcrypt==4.0.1",  # Pin bcrypt version to resolve warning
  # Rate limiting
  "fastapi-limiter>=0.1.5",
  # NOTE(review): aioredis is no longer maintained as a separate package (it was
  # merged into redis-py as `redis.asyncio`); confirm it still installs and
  # imports on every Python version allowed by `requires-python`.
  "aioredis>=2.0.0",
  # (httpx is already declared once in the HTTP client group above.)
  "psutil>=5.0.0,<8.0.0",
  "protobuf>=4.24.4",
  "confluent-kafka>=2.3.0",              # meets provider min
  "sqlalchemy>=1.4.49,<2.0.0",           # py3.12 support, keeps MLflow happy
  "psycopg2-binary>=2.9.9",
  "fastavro>=1.7.0",
  "apache-airflow==2.10.4",              # add core airflow
  "apache-airflow-providers-postgres>=6.2.1,<7.0.0",
  "apache-airflow-providers-apache-kafka>=1.10.0,<2.0.0",
  "apache-airflow-providers-standard>=1.4.1,<2.0.0",
  "apache-airflow-providers-http>=4.5.0,<5.0.0",
  "ipykernel>=6.25.0",
]

# Optional extras; each key below is an extra name that can be requested at
# install time (e.g. `pip install ".[dev]"`).
[project.optional-dependencies]
# Developer tooling: test runner, formatters, linters, type checker, task runner.
dev = [
  "pytest>=7.0.0",
  "black>=23.0.0",
  "isort>=5.0.0",
  "flake8>=5.0.0",
  "mypy>=1.0.0",
  "invoke>=2.2",
]

# GPU array support (CuPy wheel built for CUDA 12.x).
cuda = [
  "cupy-cuda12x>=12.0.0",
]

# NOTE(review): PyTensor documents its configuration through the PYTENSOR_FLAGS
# environment variable and a .pytensorrc file — confirm that something actually
# reads this [tool.pytensor] table; otherwise these settings have no effect.
[tool.pytensor]
device    = "cuda"
floatX    = "float32"
allow_gc  = true
optimizer = "fast_run"





# Astro/Airflow Local Dev Quickstart
1. Initialize & Start the Project

# From the directory where you want the project to live
cd api/src
mkdir -p airflow_project && cd airflow_project

astro dev init            # Scaffold Airflow project (dags/, Dockerfile, etc.)
astro dev start           # Build image & start all Airflow services (webserver, scheduler, DB)

    If port 8080 (webserver) or 5432 (Postgres) is already in use, set alternates before starting:
    `astro config set webserver.port 8081` and/or `astro config set postgres.port 5433`.
    (Source: Astronomer documentation.)

2. Everyday Lifecycle Commands

Use these while iterating on DAGs and dependencies:

astro dev stop            # Stop containers, keep project state
astro dev restart         # Stop → rebuild image → start (after reqs/Dockerfile changes)

    These two commands are your main “apply changes” loop during development.
    (Source: Astronomer documentation.)

3. Inspect, Logs, & Diagnostics

astro dev status          # Health & ports of services
astro dev logs            # Combined logs (Ctrl+C to quit)
astro dev tail scheduler  # Live-tail a specific service (scheduler/webserver/triggerer)
astro dev ps              # Show running containers for this project
astro dev top             # Process table inside a service container
astro dev stats           # CPU/mem stats per container

    Helpful when debugging start-up issues, stuck tasks, or resource pressure.
    (Sources: Astronomer documentation; astro-cli GitHub repository.)

4. Force a DAG Reparse (No Waiting)

astro dev run dags reserialize

    Airflow parses DAGs automatically: new files are picked up after about 5 minutes and edits to existing files after about 30 seconds. This command forces an immediate parse.
    (Source: Astronomer documentation.)

5. One-off DAG Test Runs

astro run <dag-id>

    Compiles and executes a single DAG in a throwaway worker container, giving fast feedback without touching the scheduler.
    (Source: Astronomer documentation.)


# Files created by `astro dev init` inside the Airflow project, adjusted to fit our needs

In [None]:
%%writefile api/src/airflow_project/requirements.txt
# Astro Runtime includes the following pre-installed providers packages: https://www.astronomer.io/docs/astro/runtime-image-architecture#provider-packages


In [None]:
%%writefile api/src/airflow_project/airflow_settings.yaml
# This file allows you to configure Airflow Connections, Pools, and Variables in a single place for local development only.
# NOTE: json dicts can be added to the conn_extra field as yaml key value pairs. See the example below.

# For more information, refer to our docs: https://www.astronomer.io/docs/astro/cli/develop-project#configure-airflow_settingsyaml-local-development-only
# For questions, reach out to: https://support.astronomer.io
# For issues create an issue ticket here: https://github.com/astronomer/astro-cli/issues

airflow:
  connections:
    - conn_id:
      conn_type:
      conn_host:
      conn_schema:
      conn_login:
      conn_password:
      conn_port:
      conn_extra:
        example_extra_field: example-value
  pools:
    - pool_name:
      pool_slot:
      pool_description:
  variables:
    - variable_name:
      variable_value:


In [None]:
%%writefile api/src/airflow_project/.env
#not empty










# Utils + Configs

In [None]:
%%writefile api/src/airflow_project/utils/__init__.py
"""
Utils package for the NBA Player Valuation project.
""" 

In [None]:
%%writefile api/src/airflow_project/utils/config.py
"""
Central configuration for the NBA‑Player‑Valuation project.
All magic values live here so they can be tweaked without code edits.
"""
from pathlib import Path
import os

# ── Core directories ───────────────────────────────────────────────────────────
def find_project_root(name: str = "airflow_project") -> Path:
    """
    Locate the project root by climbing the directory tree.

    The search starts at this module's resolved path (or at the current
    working directory when `__file__` is undefined) and inspects that path
    and each of its parents in turn. The first path that is either named
    `name` or contains a `.git` directory is returned. When nothing matches,
    the current working directory is returned instead.
    """
    try:
        start = Path(__file__).resolve()
    except NameError:
        start = Path.cwd()

    # Lazily scan the start path followed by every ancestor, nearest first.
    matches = (
        candidate
        for candidate in [start, *start.parents]
        if candidate.name == name or (candidate / ".git").is_dir()
    )
    hit = next(matches, None)
    if hit is not None:
        return hit
    # no match → fallback
    return Path.cwd()

# Resolve the project root once at import time.
# `find_project_root()` always returns a Path (it falls back to the cwd), and a
# Path is always truthy, so no conditional is needed around this call.
PROJECT_ROOT: Path = find_project_root().resolve()


# `/` on a Path already yields a Path, so no extra Path(...) wrapping is needed.
DATA_DIR: Path = PROJECT_ROOT / "data"
LOG_DIR: Path = PROJECT_ROOT / "logs"
DUCKDB_FILE: Path = DATA_DIR / "nba.duckdb"

# NBA stats API
NBA_API_RPM: int = int(os.getenv("NBA_API_RPM", "12"))  # requests per minute

# Spotrac scraping
# These are f-strings: the doubled braces collapse to a literal `{year}`, so each
# constant is a template that `.format(year=...)` can fill in.
SPOTRAC_BASE: str = "https://www.spotrac.com/nba"
SPOTRAC_FREE_AGENTS: str = f"{SPOTRAC_BASE}/free-agents/{{year}}/"
SPOTRAC_CAP_TRACKER: str = f"{SPOTRAC_BASE}/cap/{{year}}/"
SPOTRAC_TAX_TRACKER: str = f"{SPOTRAC_BASE}/tax/_/year/{{year}}/"
# (Spotrac directory constants are declared with the other data directories below.)


# Injury sources
# Plain (non-f) string: single braces are required for `{date}` to be a real
# `.format(date=...)` placeholder. Doubled braces here would format to a literal
# "{date}" instead of the substituted value.
NBA_OFFICIAL_INJURY_URL: str = "https://cdn.nba.com/static/json/injury/injury_report_{date}.json"
ROTOWIRE_RSS: str = "https://www.rotowire.com/rss/news.php?sport=NBA"

# StatsD (optional)
STATSD_HOST: str = os.getenv("STATSD_HOST", "localhost")
STATSD_PORT: int = int(os.getenv("STATSD_PORT", "8125"))

# ── Data ranges ────────────────────────────────────────────────────────────────
SEASONS: range = range(2015, 2026)  # 2015 through 2025 (the range end is exclusive)

# Thread pools & concurrency
MAX_WORKERS: int = int(os.getenv("NPV_MAX_WORKERS", "8"))


# ── Core data directories ───────────────────────────────────────────────────────────
# Folders that live directly under DATA_DIR.
RAW_DIR: Path = DATA_DIR / "raw"
DEBUG_DIR: Path = DATA_DIR / "debug"
EXPORTS_DIR: Path = DATA_DIR / "exports"
INJURY_DIR: Path = DATA_DIR / "injury_reports"
INJURY_DATASETS_DIR: Path = DATA_DIR / "injury_datasets"
NBA_BASIC_ADVANCED_STATS_DIR: Path = DATA_DIR / "nba_basic_advanced_stats"
ADVANCED_METRICS_DIR: Path = DATA_DIR / "new_processed" / "advanced_metrics"
NBA_BASE_DATA_DIR: Path = DATA_DIR / "nba_processed"
DEFENSE_DATA_DIR: Path = DATA_DIR / "defense_metrics"
FINAL_DATASET_DIR: Path = DATA_DIR / "merged_final_dataset"
PLAY_TYPES_DIR: Path = DATA_DIR / "synergyplay_types"
SPOTRAC_DIR: Path = DATA_DIR / "spotrac_contract_data"
# Sub-folders of the Spotrac contract data directory.
SILVER_DIR: Path = SPOTRAC_DIR / "silver"
FINAL_DIR: Path = SPOTRAC_DIR / "final"
SPOTRAC_DEBUG_DIR: Path = SPOTRAC_DIR / "debug"
SPOTRAC_RAW_DIR: Path = SPOTRAC_DIR / "raw"

# ── Helper functions (single source of truth) ────────────────────────────────
def get_injury_base_dir() -> Path:
    """
    Return the canonical injury_reports root, creating it if necessary.

    The `INJURY_DATA_DIR` environment variable wins when it is set; otherwise
    the module-level `INJURY_DIR` is used. The directory (and any missing
    parents) is created, so callers can assume it exists.

    Returns:
        The resolved base directory.

    Raises:
        OSError: If the directory cannot be created (for example, the path
            exists but is a regular file).
    """
    override = os.getenv("INJURY_DATA_DIR")
    base = Path(override if override is not None else INJURY_DIR).resolve()
    # Honour the documented contract: the returned directory always exists.
    base.mkdir(parents=True, exist_ok=True)
    return base

def injury_path(*parts: str) -> Path:
    """Join `parts` onto the injury base directory and return the resolved path."""
    base_dir = get_injury_base_dir()
    joined = base_dir.joinpath(*parts)
    return joined.resolve()

# ── One-shot: ensure all declared data dirs exist at import-time ──────────────
# INJURY_DATASETS_DIR is included so that every directory declared under
# DATA_DIR is created, matching the stated intent of this loop.
for _p in (
    DATA_DIR, RAW_DIR, DEBUG_DIR, EXPORTS_DIR,
    INJURY_DIR, INJURY_DATASETS_DIR,
    SPOTRAC_DIR, SILVER_DIR, FINAL_DIR, SPOTRAC_DEBUG_DIR, SPOTRAC_RAW_DIR,
    NBA_BASIC_ADVANCED_STATS_DIR, NBA_BASE_DATA_DIR, ADVANCED_METRICS_DIR,
    DEFENSE_DATA_DIR, FINAL_DATASET_DIR,
    PLAY_TYPES_DIR,  # ← ensure synergy play-types folder exists
):
    _p.mkdir(parents=True, exist_ok=True)



# Import-time dump of every configured directory, one "NAME: path" line each,
# framed by the same banner line before and after.
print("all directories:")
print("root directory:")
for _label, _location in (
    ("PROJECT_ROOT", PROJECT_ROOT),
    ("DATA_DIR", DATA_DIR),
    ("RAW_DIR", RAW_DIR),
    ("DEBUG_DIR", DEBUG_DIR),
    ("EXPORTS_DIR", EXPORTS_DIR),
    ("INJURY_DIR", INJURY_DIR),
    ("INJURY_DATASETS_DIR", INJURY_DATASETS_DIR),
    ("NBA_BASIC_ADVANCED_STATS_DIR", NBA_BASIC_ADVANCED_STATS_DIR),
    ("NBA_BASE_DATA_DIR", NBA_BASE_DATA_DIR),
    ("ADVANCED_METRICS_DIR", ADVANCED_METRICS_DIR),
    ("DEFENSE_DATA_DIR", DEFENSE_DATA_DIR),
    ("FINAL_DATASET_DIR", FINAL_DATASET_DIR),
    ("PLAY_TYPES_DIR", PLAY_TYPES_DIR),
    ("SPOTRAC_DIR", SPOTRAC_DIR),
    ("SILVER_DIR", SILVER_DIR),
    ("FINAL_DIR", FINAL_DIR),
    ("SPOTRAC_DEBUG_DIR", SPOTRAC_DEBUG_DIR),
    ("SPOTRAC_RAW_DIR", SPOTRAC_RAW_DIR),
):
    print(f"{_label}: {_location}")
print("all directories:")


In [None]:
%%writefile api/src/airflow_project/utils/utils.py
import logging
from pathlib import Path
import duckdb, boto3, logging
import pandas as pd
import re
import os


def configure_logging(level: int = logging.INFO, log_dir: str = "logs") -> None:
    """
    Configure logging to write to both a log file and the console.

    Args:
        level: Logging threshold handed to `logging.basicConfig`.
        log_dir: Directory that will hold `data_pipeline.log`; it is created
            (including parents) when missing.
    """
    directory = Path(log_dir)
    directory.mkdir(parents=True, exist_ok=True)

    # The file handler opens its target immediately, so the directory must
    # already exist at this point.
    file_handler = logging.FileHandler(directory / "data_pipeline.log")
    console_handler = logging.StreamHandler()

    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(name)s :: %(message)s",
        handlers=[file_handler, console_handler],
    )


def to_duck(table: str, parquet: Path) -> None:
    """
    Load a parquet file into DuckDB and re-export it as `<table>.parquet`.

    Opens the `nba.duckdb` database (a path relative to the current working
    directory), creates `table` from the parquet file if no table of that
    name exists yet, then copies the parquet contents to `<table>.parquet`.

    Args:
        table: Target table name. It is interpolated into the SQL text as-is,
            so it must be a trusted identifier.
        parquet: Path of the source parquet file.
    """
    # Double any single quote so a path containing `'` cannot terminate the
    # SQL string literal early.
    source = str(parquet).replace("'", "''")
    con = duckdb.connect(database="nba.duckdb", read_only=False)
    try:
        con.execute(f"""
        CREATE TABLE IF NOT EXISTS {table} AS
        SELECT * FROM parquet_scan('{source}')
        """)
        con.execute(f"COPY (SELECT * FROM parquet_scan('{source}')) TO '{table}.parquet' (FORMAT PARQUET);")
    finally:
        # Always close the connection, even when one of the statements fails.
        con.close()
    # NOTE(review): the message says "appended", but CREATE TABLE IF NOT EXISTS
    # leaves an already-existing table untouched — confirm the intended wording.
    logging.info("%s appended to DuckDB", parquet)

def maybe_upload(parquet: Path) -> None:
    """
    Upload `parquet` to S3 under the `bronze/` prefix when AWS credentials exist.

    Nothing happens unless the `AWS_ACCESS_KEY_ID` environment variable is set
    to a non-empty value. The bucket is read from `NBA_S3_BUCKET` and defaults
    to `nba-bronze`.
    """
    if not os.getenv("AWS_ACCESS_KEY_ID"):
        return

    client = boto3.client("s3")
    bucket = os.getenv("NBA_S3_BUCKET", "nba-bronze")
    object_key = f"bronze/{parquet.name}"
    client.upload_file(str(parquet), bucket, object_key)
    logging.info("Uploaded %s to s3://%s/bronze/", parquet.name, bucket)

def _write_parquet_safe(df: pd.DataFrame, out_path: Path) -> None:
    """
    Write `df` to `out_path` as parquet by way of a temporary file.

    The frame is first written to `<out_path>.tmp` and then moved over
    `out_path` with `os.replace`, so a failed write never leaves a partial
    file at `out_path`. Columns whose name repeats are dropped (the first
    occurrence is kept) before writing. If writing or replacing fails, the
    temporary file is removed and the exception propagates.

    Args:
        df: Frame to persist.
        out_path: Final destination of the parquet file.
    """
    # 1) Drop duplicate columns by name (keep the first occurrence)
    if df.columns.duplicated().any():
        df = df.loc[:, ~df.columns.duplicated()]
    # 2) Write to a temporary file and atomically replace
    tmp = out_path.with_suffix(out_path.suffix + ".tmp")
    try:
        df.to_parquet(tmp, index=False)
        os.replace(tmp, out_path)
    finally:
        # After a successful replace the temp file is already gone; after a
        # failure this removes the partial leftover.
        tmp.unlink(missing_ok=True)


# --- NEW final writer ---
def write_final_dataset(df: pd.DataFrame, out_path: Path) -> Path:
    """
    Persist the final merged dataset as parquet and return its path.

    The parent directory of `out_path` is created when missing; the write
    itself is delegated to `_write_parquet_safe`.
    """
    target_dir = out_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    _write_parquet_safe(df, out_path)
    logging.info("Final merged parquet -> %s", out_path)
    return out_path

In [10]:
%%writefile api/src/airflow_project/utils/data_check_utils.py
"""
Reusable data-pull + data-quality checks (clear-arg edition)

Dimensions covered:
- Min/Max summary, Nulls, Duplicates, Outliers (IQR), Consistency (rules),
  Completeness, Accuracy (vs. reference or lookup lists), Validity (column rules),
  Uniqueness (primary key).

Style & references:
- Naming follows PEP 8 / Google Python Style Guide.
- Data-quality dimensions per common industry taxonomy.
- Outliers via 1.5×IQR; duplicates via pandas .duplicated().
"""
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd


# =========================
# Report dataclasses
# =========================
@dataclass
class NullsReport:
    """Per-column null statistics for a DataFrame."""
    per_column: pd.DataFrame  # columns: column, null_count, total_rows, null_pct
    total_rows: int  # number of rows in the analysed frame

@dataclass
class DuplicatesReport:
    """Result of a duplicate-row scan."""
    has_duplicates: bool  # True when at least one row is flagged as a duplicate
    duplicate_row_count: int  # number of rows flagged as duplicates
    sample_rows: pd.DataFrame  # a small sample of the flagged rows

@dataclass
class OutliersReport:
    """Result of the IQR-based outlier scan."""
    per_column_summary: pd.DataFrame  # column, outlier_count, pct, lower_fence, upper_fence
    sample_rows: pd.DataFrame  # a small sample of rows flagged as outliers

@dataclass
class ConsistencyReport:
    """Result of applying named consistency rules."""
    violations_by_rule: Dict[str, int]  # rule name -> number of violating rows
    sample_rows_by_rule: Dict[str, pd.DataFrame]  # rule name -> sample of violating rows

@dataclass
class CompletenessReport:
    """Result of the required-column completeness check."""
    required_columns: List[str]  # the columns that were checked
    nonnull_ratio_by_column: Dict[str, float]  # column -> fraction of non-null values
    failing_columns: List[str]  # columns that are missing or below the required ratio

@dataclass
class AccuracyReport:
    """Result of the lookup and reference-join accuracy checks."""
    check_status_by_name: Dict[str, str]  # check name -> human-readable status string
    sample_rows_by_check: Dict[str, pd.DataFrame]  # check name -> sample of offending rows

@dataclass
class ValidityReport:
    """Result of the per-column validity rules."""
    status: str  # "passed" | "failed" | "skipped"
    details: Dict[str, Any]  # free-form details about the rule failures

@dataclass
class UniquenessReport:
    """Result of the primary-key uniqueness check."""
    primary_key_columns: List[str]  # columns that together form the key
    has_violations: bool  # True when any key value occurs more than once
    duplicate_key_row_count: int  # number of rows whose key is not unique
    sample_duplicate_keys: pd.DataFrame  # a small sample of the repeated key values

@dataclass
class RangeReport:
    """Summary statistics for the numeric columns of a DataFrame."""
    numeric_summary: pd.DataFrame   # column, count, min, max, mean, std

@dataclass
class DataQualityReport:
    """Bundle of every individual quality report produced for one DataFrame."""
    minmax: RangeReport
    nulls: NullsReport
    duplicates: DuplicatesReport
    outliers: OutliersReport
    consistency: ConsistencyReport
    completeness: CompletenessReport
    accuracy: AccuracyReport
    validity: ValidityReport
    uniqueness: UniquenessReport

    def to_json(self, pretty: bool = True) -> str:
        """
        Convert the report to JSON format.

        DataFrames are emitted as lists of row records. Values the standard
        encoder cannot handle (numpy scalars, timestamps/dates, other objects
        found in sample rows) are converted instead of raising `TypeError`.

        Args:
            pretty: If True, format with indentation for readability

        Returns:
            JSON string representation of the report
        """
        def df_to_records(df: pd.DataFrame) -> List[Dict[str, Any]]:
            return df.to_dict(orient="records") if not df.empty else []

        def json_default(value: Any) -> Any:
            # Only called for values json cannot encode natively.
            if isinstance(value, np.generic):
                return value.item()  # numpy scalar -> plain Python scalar
            if hasattr(value, "isoformat"):
                return value.isoformat()  # datetime-like -> ISO 8601 text
            return str(value)  # last resort keeps the report serialisable

        data = {
            "minmax": {"numeric_summary": df_to_records(self.minmax.numeric_summary)},
            "nulls": {
                "total_rows": self.nulls.total_rows,
                "per_column": df_to_records(self.nulls.per_column),
            },
            "duplicates": {
                "has_duplicates": self.duplicates.has_duplicates,
                "duplicate_row_count": self.duplicates.duplicate_row_count,
                "sample_rows": df_to_records(self.duplicates.sample_rows),
            },
            "outliers": {
                "per_column_summary": df_to_records(self.outliers.per_column_summary),
                "sample_rows": df_to_records(self.outliers.sample_rows),
            },
            "consistency": {
                "violations_by_rule": self.consistency.violations_by_rule,
                "sample_rows_by_rule": {k: df_to_records(v) for k, v in self.consistency.sample_rows_by_rule.items()},
            },
            "completeness": {
                "required_columns": self.completeness.required_columns,
                "nonnull_ratio_by_column": self.completeness.nonnull_ratio_by_column,
                "failing_columns": self.completeness.failing_columns,
            },
            "accuracy": {
                "check_status_by_name": self.accuracy.check_status_by_name,
                "sample_rows_by_check": {k: df_to_records(v) for k, v in self.accuracy.sample_rows_by_check.items()},
            },
            "validity": {"status": self.validity.status, "details": self.validity.details},
            "uniqueness": {
                "primary_key_columns": self.uniqueness.primary_key_columns,
                "has_violations": self.uniqueness.has_violations,
                "duplicate_key_row_count": self.uniqueness.duplicate_key_row_count,
                "sample_duplicate_keys": df_to_records(self.uniqueness.sample_duplicate_keys),
            },
        }
        return json.dumps(data, indent=2 if pretty else None, default=json_default)


def _json_or_none(obj):
    """Serialise `obj` to a JSON string (non-ASCII kept as-is); return None if that fails."""
    try:
        encoded = json.dumps(obj, ensure_ascii=False)
    except Exception:
        # Best-effort helper: anything that cannot be serialised becomes None.
        return None
    return encoded

def report_to_rows(report: DataQualityReport) -> list[dict]:
    """
    Flatten a DataQualityReport into long-format rows for a single table.

    Every row is a dict with the keys section, subsection, field, value and
    extra. Rows are produced section by section in this order: minmax, nulls,
    duplicates, outliers, consistency, completeness, accuracy, validity,
    uniqueness. Sample DataFrames are capped at their first 5 rows.
    """
    rows: list[dict] = []

    def emit(section, subsection, field, value, extra=None):
        rows.append({
            "section": section, "subsection": subsection, "field": field,
            "value": value, "extra": extra,
        })

    def has_rows(frame) -> bool:
        return isinstance(frame, pd.DataFrame) and not frame.empty

    def first_records(frame):
        return frame.head(5).to_dict("records")

    # ---- Min/Max: one row per statistic and column
    for rec in report.minmax.numeric_summary.to_dict("records"):
        for stat in ("count", "min", "max", "mean", "std"):
            emit("minmax", rec["column"], stat, rec[stat])

    # ---- Nulls
    for rec in report.nulls.per_column.to_dict("records"):
        emit(
            "nulls", rec["column"], "null_count", rec["null_count"],
            {"total_rows": rec["total_rows"], "null_pct": rec["null_pct"]},
        )

    # ---- Duplicates
    dup = report.duplicates
    emit(
        "duplicates", "_overall", "has_duplicates", bool(dup.has_duplicates),
        {"duplicate_row_count": int(dup.duplicate_row_count),
         "sample_rows": first_records(dup.sample_rows)},
    )

    # ---- Outliers (IQR fences: Q1-1.5*IQR, Q3+1.5*IQR)
    out = report.outliers
    for rec in out.per_column_summary.to_dict("records"):
        emit(
            "outliers", rec["column"], "outlier_count", rec.get("outlier_count", 0),
            {"pct": rec.get("pct"), "lower_fence": rec.get("lower_fence"), "upper_fence": rec.get("upper_fence")},
        )
    if not out.sample_rows.empty:
        emit("outliers", "_samples", "sample_rows", None, first_records(out.sample_rows))

    # ---- Consistency
    cons = report.consistency
    for rule, count in (cons.violations_by_rule or {}).items():
        emit("consistency", rule, "violations", int(count))
        sample = cons.sample_rows_by_rule.get(rule)
        if has_rows(sample):
            emit("consistency", rule, "sample_rows", None, first_records(sample))

    # ---- Completeness
    comp = report.completeness
    emit("completeness", "_overall", "required_columns", None, comp.required_columns)
    for column, ratio in (comp.nonnull_ratio_by_column or {}).items():
        emit("completeness", column, "nonnull_ratio", float(ratio))
    if comp.failing_columns:
        emit("completeness", "_overall", "failing_columns", None, comp.failing_columns)

    # ---- Accuracy
    acc = report.accuracy
    for check, status in (acc.check_status_by_name or {}).items():
        emit("accuracy", check, "status", status)
        sample = acc.sample_rows_by_check.get(check)
        if has_rows(sample):
            emit("accuracy", check, "sample_rows", None, first_records(sample))

    # ---- Validity
    emit("validity", "_overall", "status", report.validity.status)
    if report.validity.details:
        emit("validity", "_overall", "details", None, report.validity.details)

    # ---- Uniqueness
    uniq = report.uniqueness
    emit("uniqueness", "_overall", "primary_key_columns", None, uniq.primary_key_columns)
    emit("uniqueness", "_overall", "has_violations", bool(uniq.has_violations))
    emit("uniqueness", "_overall", "duplicate_key_row_count", int(uniq.duplicate_key_row_count))
    if has_rows(uniq.sample_duplicate_keys):
        emit("uniqueness", "_samples", "sample_duplicate_keys", None, first_records(uniq.sample_duplicate_keys))

    return rows

def report_to_table(report: DataQualityReport) -> pd.DataFrame:
    """
    Build a tidy DataFrame from a report.

    The columns are section, subsection, field, value and extra, where extra
    holds a JSON string. Rows are sorted by section, subsection and field,
    and the index is reset.
    """
    column_order = ["section", "subsection", "field", "value", "extra"]
    table = pd.DataFrame.from_records(report_to_rows(report), columns=column_order)
    if "extra" in table.columns:
        # keep 'extra' as a compact JSON string for easy viewing
        table["extra"] = table["extra"].map(_json_or_none)
    ordered = table.sort_values(["section", "subsection", "field"])
    return ordered.reset_index(drop=True)

# ---------- Render helpers ----------
def report_to_markdown_table(report: DataQualityReport) -> str:
    """Return the flattened report table as Markdown text, without the index column."""
    table = report_to_table(report)
    markdown = table.to_markdown(index=False)
    return markdown

def report_to_html_table(report: DataQualityReport, caption: str | None = None) -> str:
    """
    Return the flattened report table as a compact HTML table.

    The table carries the CSS classes `table table-sm`, the index is hidden,
    and `caption` is applied when it is non-empty.
    """
    table = report_to_table(report)
    styled = table.style.set_table_attributes('class="table table-sm"')
    styled = styled.hide(axis="index")
    if caption:
        styled = styled.set_caption(caption)
    return styled.to_html()


# =========================
# Core check helpers
# =========================
def summarize_min_max(df: pd.DataFrame) -> RangeReport:
    """
    Summarise count/min/max/mean/std for the numeric columns of `df`.

    The result has one row per numeric column, sorted by column name. The
    standard deviation uses ddof=1. When the numeric selection is empty, an
    empty summary carrying the expected column names is returned.
    """
    numeric = df.select_dtypes(include=[np.number])
    if numeric.empty:
        return RangeReport(
            pd.DataFrame(columns=["column", "count", "min", "max", "mean", "std"])
        )

    stats = {
        "column": numeric.columns,
        "count": numeric.count().values,
        "min": numeric.min().values,
        "max": numeric.max().values,
        "mean": numeric.mean().values,
        "std": numeric.std(ddof=1).values,
    }
    summary = pd.DataFrame(stats).sort_values("column")
    return RangeReport(summary.reset_index(drop=True))


def summarize_nulls(df: pd.DataFrame) -> NullsReport:
    """
    Per-column null counts and percentages.

    Returns a NullsReport whose `per_column` frame has the columns column,
    null_count, total_rows and null_pct (a fraction in [0, 1]), sorted by
    null_pct in descending order. A frame with zero rows reports null_pct 0.0
    for every column; a frame with no columns yields an empty summary that
    still carries the expected column names.
    """
    total = len(df)
    rows = []
    for col in df.columns:
        n = int(df[col].isna().sum())
        rows.append({
            "column": col,
            "null_count": n,
            "total_rows": total,
            "null_pct": (n / total) if total else 0.0,
        })
    # Naming the columns explicitly keeps `null_pct` present even when `rows`
    # is empty; otherwise sort_values would raise KeyError on a column-less frame.
    out = pd.DataFrame(rows, columns=["column", "null_count", "total_rows", "null_pct"])
    out = out.sort_values("null_pct", ascending=False).reset_index(drop=True)
    return NullsReport(per_column=out, total_rows=total)


def find_duplicates(
    df: pd.DataFrame,
    *,
    key_columns: Optional[Sequence[str]] = None,
    sample_size: int = 10,
) -> DuplicatesReport:
    """
    Flag duplicate rows with pandas `.duplicated(keep=False)`.

    When `key_columns` is given, rows are compared on that subset only;
    otherwise the whole row is compared. Every member of a duplicate group is
    counted, and the first `sample_size` flagged rows are kept as a sample.
    """
    is_duplicate = df.duplicated(subset=key_columns, keep=False)
    n_duplicates = int(is_duplicate.sum())
    sample = df.loc[is_duplicate].head(sample_size)
    return DuplicatesReport(
        has_duplicates=n_duplicates > 0,
        duplicate_row_count=n_duplicates,
        sample_rows=sample,
    )


def detect_outliers_iqr(
    df: pd.DataFrame,
    *,
    numeric_columns: Optional[Sequence[str]] = None,
    sample_size: int = 10,
) -> OutliersReport:
    """
    Flag values outside the 1.5×IQR fences, column by column.

    A value is an outlier when it is below Q1-1.5*IQR or above Q3+1.5*IQR,
    with the quartiles computed on the non-null values of its column. Only
    numeric columns are scanned; `numeric_columns` narrows the scan to the
    listed ones that are actually numeric. Columns with no non-null values
    are skipped. The summary is sorted by outlier_count (descending), and the
    sample holds the first `sample_size` rows flagged in at least one column.
    """
    numeric = df.select_dtypes(include=[np.number])
    if numeric_columns:
        wanted = [c for c in numeric_columns if c in numeric.columns]
        numeric = numeric[wanted]

    per_column = []
    flagged_rows = pd.Series(False, index=df.index)
    for name in numeric.columns:
        values = numeric[name]
        observed = values.dropna()
        if observed.empty:
            continue

        q1 = observed.quantile(0.25)
        q3 = observed.quantile(0.75)
        spread = q3 - q1
        lower = q1 - 1.5 * spread
        upper = q3 + 1.5 * spread

        # Compare against the full column (nulls included) so that `pct` is
        # relative to all rows.
        is_outlier = (values < lower) | (values > upper)
        flagged_rows |= is_outlier.fillna(False)
        per_column.append({
            "column": name,
            "outlier_count": int(is_outlier.sum()),
            "pct": float(is_outlier.mean()) if is_outlier.size else 0.0,
            "lower_fence": float(lower),
            "upper_fence": float(upper),
        })

    if per_column:
        summary = pd.DataFrame(per_column)
    else:
        summary = pd.DataFrame(columns=["column","outlier_count","pct","lower_fence","upper_fence"])
    summary = summary.sort_values("outlier_count", ascending=False).reset_index(drop=True)

    return OutliersReport(
        per_column_summary=summary,
        sample_rows=df.loc[flagged_rows].head(sample_size),
    )


def check_consistency_rules(
    df: pd.DataFrame,
    *,
    rules_by_name: Dict[str, Callable[[pd.DataFrame], pd.Series]],
    sample_size: int = 5,
) -> ConsistencyReport:
    """
    Run each named rule against `df` and count its violations.

    Each rule is a callable that receives the frame and returns a boolean
    mask marking the violating rows (e.g. season string format, set
    membership, cross-field dependencies). Rules with at least one violation
    also get a sample of up to `sample_size` violating rows.
    """
    counts: Dict[str, int] = {}
    samples: Dict[str, pd.DataFrame] = {}
    for rule_name, rule in rules_by_name.items():
        violating = rule(df)
        n_violations = int(violating.sum())
        counts[rule_name] = n_violations
        if not n_violations:
            continue
        samples[rule_name] = df.loc[violating].head(sample_size)
    return ConsistencyReport(violations_by_rule=counts, sample_rows_by_rule=samples)


def check_completeness(
    df: pd.DataFrame,
    *,
    required_columns: Sequence[str],
    min_nonnull_ratio: float = 1.0,
) -> CompletenessReport:
    """
    Check that each required column exists and is sufficiently populated.

    A missing column gets ratio 0.0 and fails. A present column fails when
    its non-null fraction is below `min_nonnull_ratio` (default 100%). For a
    frame with zero rows, every present column is treated as fully populated.
    """
    ratios: Dict[str, float] = {}
    failing: List[str] = []
    n_rows = len(df)
    for column in required_columns:
        if column in df.columns:
            ratio = float(df[column].notna().mean()) if n_rows else 1.0
            ratios[column] = ratio
            if ratio < min_nonnull_ratio:
                failing.append(column)
        else:
            ratios[column] = 0.0
            failing.append(column)
    return CompletenessReport(
        required_columns=list(required_columns),
        nonnull_ratio_by_column=ratios,
        failing_columns=failing,
    )


def check_accuracy(
    df: pd.DataFrame,
    *,
    allowed_values_by_column: Optional[Dict[str, Iterable[Any]]] = None,
    reference_table: Optional[pd.DataFrame] = None,
    reference_join_keys: Optional[Sequence[str]] = None,
    sample_size: int = 5,
) -> AccuracyReport:
    """
    Run two optional accuracy checks.

    (A) Lookup: for each column in `allowed_values_by_column` that exists in
        `df`, non-null values outside the allowed set are violations. The
        check is recorded under the name "lookup:<column>".
    (B) Reference coverage: the distinct key combinations of `df` are
        left-joined onto those of `reference_table`; combinations with no
        match are reported under "reference_join_coverage".

    Each failing check also keeps a sample of up to `sample_size` rows.
    """
    statuses: Dict[str, str] = {}
    samples: Dict[str, pd.DataFrame] = {}

    # (A) domain / lookup checks
    for column, allowed in (allowed_values_by_column or {}).items():
        if column not in df.columns:
            continue
        check_name = f"lookup:{column}"
        outside_domain = ~df[column].isin(list(allowed)) & df[column].notna()
        n_bad = int(outside_domain.sum())
        statuses[check_name] = "ok (0 violations)" if n_bad == 0 else f"violations={n_bad}"
        if n_bad:
            samples[check_name] = df.loc[outside_domain].head(sample_size)

    # (B) reference join coverage
    if reference_table is not None and reference_join_keys:
        keys = list(reference_join_keys)
        ours = df[keys].drop_duplicates()
        theirs = reference_table[keys].drop_duplicates()
        joined = ours.merge(theirs, on=keys, how="left", indicator=True)
        unmatched = joined["_merge"] == "left_only"
        n_missing = int(unmatched.sum())
        statuses["reference_join_coverage"] = "ok (0 missing)" if n_missing == 0 else f"missing={n_missing}"
        if n_missing:
            samples["reference_join_coverage"] = (
                joined.loc[unmatched].drop(columns=["_merge"]).head(sample_size)
            )

    return AccuracyReport(check_status_by_name=statuses, sample_rows_by_check=samples)


def check_validity(
    df: pd.DataFrame,
    *,
    column_rules: Optional[Dict[str, Dict[str, Any]]] = None,
) -> ValidityReport:
    """
    Lightweight validity (type/regex/range/membership). For strict contracts, consider Pandera/GE.
    Example column_rules:
    {
      "PLAYER_ID":   {"dtype": "int64", "ge": 0},
      "SEASON":      {"regex": r"^\\d{4}-\\d{2}$"},
      "SEASON_TYPE": {"in": ["Regular Season","Playoffs","Pre Season","All Star"]},
      "E_OFF_RATING":{"ge": 0, "le": 200}
    }

    Supported rule keys per column:
      - "dtype": the column's dtype, as a string, must equal this value.
      - "regex": every non-null value, converted to str, must match the
        pattern from its start.
      - "in":    every non-null value must be a member of this collection.
      - "ge" / "le": every non-null value must be >= / <= the bound. These
        are applied to signed-integer, unsigned-integer and float columns.

    Returns:
        A ValidityReport with status "skipped" when no rules are given,
        "passed" when no error was recorded, and "failed" otherwise. The
        errors are listed under details["errors"].
    """
    if not column_rules:
        return ValidityReport(status="skipped", details={})

    details = {"errors": []}
    for col, rules in column_rules.items():
        if col not in df.columns:
            details["errors"].append({"column": col, "error": "missing column"})
            continue
        s = df[col]
        if "dtype" in rules:
            expected = rules["dtype"]
            if str(s.dtype) != expected:
                details["errors"].append({"column": col, "error": f"dtype {s.dtype} != {expected}"})
        if "regex" in rules:
            pat = re.compile(rules["regex"])
            bad = s.dropna().astype(str).map(lambda x: not bool(pat.match(x)))
            if bad.any():
                details["errors"].append({"column": col, "error": f"regex_fail_count={int(bad.sum())}"})
        if "in" in rules:
            allowed = set(rules["in"])
            bad = ~s.isin(allowed) & s.notna()
            if bad.any():
                details["errors"].append({"column": col, "error": f"in_fail_count={int(bad.sum())}"})
        # "u" (unsigned integers) is included so that uint columns are
        # range-checked too instead of being skipped silently.
        # NOTE(review): ge/le on a non-numeric column is still skipped without
        # an error — confirm whether that should be reported.
        if s.dtype.kind in "iuf":
            if "ge" in rules:
                bad = s.dropna() < rules["ge"]
                if bad.any():
                    details["errors"].append({"column": col, "error": f"ge_fail_count={int(bad.sum())}"})
            if "le" in rules:
                bad = s.dropna() > rules["le"]
                if bad.any():
                    details["errors"].append({"column": col, "error": f"le_fail_count={int(bad.sum())}"})
    status = "passed" if not details["errors"] else "failed"
    return ValidityReport(status=status, details=details)


def check_uniqueness(
    df: pd.DataFrame,
    *,
    primary_key_columns: Sequence[str],
    sample_size: int = 10,
) -> UniquenessReport:
    """
    Report rows whose primary-key combination occurs more than once.

    Args:
        df: Frame to inspect.
        primary_key_columns: Columns that together form the key. When empty,
            an empty "no violations" report is returned without scanning `df`.
        sample_size: Maximum number of distinct duplicated keys to return as examples.

    Returns:
        UniquenessReport holding the key columns, a violation flag, the number
        of rows that share a key with at least one other row, and a
        de-duplicated sample of the offending key values.
    """
    if not primary_key_columns:
        return UniquenessReport([], False, 0, pd.DataFrame())

    key_cols = list(primary_key_columns)

    # keep=False flags every occurrence of a repeated key, not just the later ones.
    is_dup_row = df.duplicated(subset=primary_key_columns, keep=False)
    n_dup_rows = int(is_dup_row.sum())

    # One example row per distinct duplicated key, capped at `sample_size`.
    distinct_dup_keys = df.loc[is_dup_row, key_cols].drop_duplicates()

    return UniquenessReport(
        primary_key_columns=key_cols,
        has_violations=n_dup_rows > 0,
        duplicate_key_row_count=n_dup_rows,
        sample_duplicate_keys=distinct_dup_keys.head(sample_size),
    )


# ==========================================
# Orchestrator: run all checks & build report
# ==========================================
def run_quality_suite(
    df: pd.DataFrame,
    *,
    required_columns: Sequence[str],
    primary_key_columns: Sequence[str],
    outlier_numeric_columns: Optional[Sequence[str]] = None,
    consistency_rules_by_name: Optional[Dict[str, Callable[[pd.DataFrame], pd.Series]]] = None,
    allowed_values_by_column: Optional[Dict[str, Iterable[Any]]] = None,
    reference_table: Optional[pd.DataFrame] = None,
    reference_join_keys: Optional[Sequence[str]] = None,
    validity_column_rules: Optional[Dict[str, Dict[str, Any]]] = None,
    output_format: str = "table",
    table_caption: Optional[str] = None,
) -> Union[DataQualityReport, pd.DataFrame, str]:
    """
    Run every data-quality check on `df` and return the combined result in the requested format.

    Args:
        df: The DataFrame to analyze.
        required_columns: Columns handed to the completeness check.
        primary_key_columns: Columns forming the business key (checked for uniqueness).
        outlier_numeric_columns: Numeric columns to scan for IQR outliers (None = all numeric).
        consistency_rules_by_name: {rule_name: fn(df) -> mask_of_violations}; None means no rules.
        allowed_values_by_column: {column: iterable of allowed values} for lookup/domain checks.
        reference_table: DataFrame to test join coverage/accuracy against.
        reference_join_keys: Columns used to join `df` to `reference_table`.
        validity_column_rules: {column: {dtype/regex/in/ge/le}} lightweight validity rules.
        output_format: One of "table" (default), "json", "markdown", "html" or "report".
        table_caption: Optional caption; only used by the "html" format.

    Returns:
        - "table": pd.DataFrame with flattened results
        - "json": JSON string
        - "markdown": Markdown table string
        - "html": HTML table string
        - "report": the raw DataQualityReport object for programmatic access

    Raises:
        ValueError: If `output_format` is not one of the supported names. This is
            checked before any quality check runs.

    Examples:
        # Default table format
        df_results = run_quality_suite(df, required_columns=["id"], primary_key_columns=["id"])

        # JSON format for APIs
        json_results = run_quality_suite(df, required_columns=["id"], primary_key_columns=["id"],
                                         output_format="json")

        # Raw report object for complex logic
        report = run_quality_suite(df, required_columns=["id"], primary_key_columns=["id"],
                                   output_format="report")
    """
    # One renderer per supported format. Each is wrapped in a lambda so the
    # formatting helper is only looked up and called for the format requested.
    renderers = {
        "table": lambda rpt: report_to_table(rpt),
        "json": lambda rpt: rpt.to_json(),
        "markdown": lambda rpt: report_to_markdown_table(rpt),
        "html": lambda rpt: report_to_html_table(rpt, caption=table_caption),
        "report": lambda rpt: rpt,
    }

    # Fail fast on a bad format name, before any check is executed.
    valid_formats = list(renderers)
    if output_format not in valid_formats:
        raise ValueError(f"output_format must be one of {valid_formats}, got '{output_format}'")

    # Keyword arguments are evaluated top to bottom, so the checks run in this order.
    report = DataQualityReport(
        minmax=summarize_min_max(df),
        nulls=summarize_nulls(df),
        duplicates=find_duplicates(df, key_columns=None),
        outliers=detect_outliers_iqr(df, numeric_columns=outlier_numeric_columns),
        consistency=check_consistency_rules(df, rules_by_name=(consistency_rules_by_name or {})),
        completeness=check_completeness(df, required_columns=required_columns),
        accuracy=check_accuracy(
            df,
            allowed_values_by_column=allowed_values_by_column,
            reference_table=reference_table,
            reference_join_keys=reference_join_keys,
        ),
        validity=check_validity(df, column_rules=validity_column_rules),
        uniqueness=check_uniqueness(df, primary_key_columns=primary_key_columns),
    )

    return renderers[output_format](report)


# ------------------------------------------
# Convenience functions for specific formats
# ------------------------------------------
def run_quality_suite_json(df: pd.DataFrame, **kwargs) -> str:
    """
    Run the quality suite and return the result as a JSON string.

    Every keyword argument is forwarded to run_quality_suite; any
    caller-supplied `output_format` is overridden with "json".
    """
    kwargs["output_format"] = "json"  # force the format regardless of caller input
    return run_quality_suite(df, **kwargs)


def run_quality_suite_table(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
    """
    Run the quality suite and return the result as a DataFrame table.

    Every keyword argument is forwarded to run_quality_suite; any
    caller-supplied `output_format` is overridden with "table".
    """
    kwargs["output_format"] = "table"  # force the format regardless of caller input
    return run_quality_suite(df, **kwargs)


def run_quality_suite_markdown(df: pd.DataFrame, **kwargs) -> str:
    """
    Run the quality suite and return the result as a Markdown string.

    Every keyword argument is forwarded to run_quality_suite; any
    caller-supplied `output_format` is overridden with "markdown".
    """
    kwargs["output_format"] = "markdown"  # force the format regardless of caller input
    return run_quality_suite(df, **kwargs)


def run_quality_suite_html(df: pd.DataFrame, **kwargs) -> str:
    """
    Run the quality suite and return the result as an HTML string.

    Every keyword argument (including `table_caption`) is forwarded to
    run_quality_suite; any caller-supplied `output_format` is overridden
    with "html".
    """
    kwargs["output_format"] = "html"  # force the format regardless of caller input
    return run_quality_suite(df, **kwargs)


# ------------------------------------------
# Backward-compatibility aliases (soft-deprecate)
# ------------------------------------------
def _alias_run_quality_suite(
    df: pd.DataFrame,
    *,
    required_cols: Sequence[str],
    pk: Sequence[str],
    iqr_outlier_cols: Optional[Sequence[str]] = None,
    consistency_rules: Optional[Dict[str, Callable[[pd.DataFrame], pd.Series]]] = None,
    accuracy_lookups: Optional[Dict[str, Iterable[Any]]] = None,
    accuracy_reference_df: Optional[pd.DataFrame] = None,
    accuracy_join_key: Optional[Sequence[str]] = None,
    validity_schema: Optional[Dict[str, Dict[str, Any]]] = None,
    output_format: str = "table",
    table_caption: Optional[str] = None,
) -> Union[DataQualityReport, pd.DataFrame, str]:
    """
    Deprecated arg names wrapper. Prefer run_quality_suite(...).

    Translates the legacy keyword names to the current ones and delegates:

        required_cols         -> required_columns
        pk                    -> primary_key_columns
        iqr_outlier_cols      -> outlier_numeric_columns
        consistency_rules     -> consistency_rules_by_name
        accuracy_lookups      -> allowed_values_by_column
        accuracy_reference_df -> reference_table
        accuracy_join_key     -> reference_join_keys
        validity_schema       -> validity_column_rules

    `output_format` and `table_caption` are forwarded unchanged.
    `table_caption` is accepted here so that callers using
    output_format="html" through this alias can still set a caption.

    Returns:
        Whatever run_quality_suite returns for the chosen `output_format`.
    """
    return run_quality_suite(
        df,
        required_columns=required_cols,
        primary_key_columns=pk,
        outlier_numeric_columns=iqr_outlier_cols,
        consistency_rules_by_name=consistency_rules,
        allowed_values_by_column=accuracy_lookups,
        reference_table=accuracy_reference_df,
        reference_join_keys=accuracy_join_key,
        validity_column_rules=validity_schema,
        output_format=output_format,
        table_caption=table_caption,
    )


# =========================
# Example usage
# =========================
def _example_fetch() -> pd.DataFrame:
    """
    Build the small demo frame used by example_usage.

    The data is intentionally imperfect: it contains a repeated row, a
    missing player name, a malformed season string, an extreme
    E_OFF_RATING value and a missing TEAM_ID.
    """
    regular = "Regular Season"
    columns = {
        "PLAYER_ID": [1, 1, 2, 3, 4],
        "PLAYER_NAME": ["A", "A", "B", "C", None],
        "SEASON": ["2023-24"] * 4 + ["BAD"],
        "SEASON_TYPE": [regular, regular, regular, "Playoffs", regular],
        "E_OFF_RATING": [110.0, 110.0, 5000.0, 108.2, 105.0],
        "TEAM_ID": [10, 10, 11, None, 13],
        "TEAM_NAME": ["X", "X", "Y", "Z", "W"],
    }
    return pd.DataFrame(columns)


def _example_rules() -> Dict[str, Callable[[pd.DataFrame], pd.Series]]:
    """
    Return the demo consistency rules, keyed by rule name.

    Each rule takes a DataFrame and returns a boolean Series that is True
    on the rows violating the rule.
    """
    allowed_season_types = ["Regular Season", "Playoffs", "Pre Season", "All Star"]

    def season_format(df):
        # Violation: SEASON (as text) does not look like "YYYY-YY".
        season_text = df["SEASON"].astype(str)
        return ~season_text.str.match(r"^\d{4}-\d{2}$")

    def team_id_name_coupling(df):
        # Violation: exactly one of TEAM_ID / TEAM_NAME is missing.
        id_missing = df["TEAM_ID"].isna()
        name_missing = df["TEAM_NAME"].isna()
        return (id_missing & ~name_missing) | (~id_missing & name_missing)

    def season_type_allowed(df):
        # Violation: SEASON_TYPE is outside the allowed set.
        return ~df["SEASON_TYPE"].isin(allowed_season_types)

    return {
        "season_format": season_format,
        "team_id_name_coupling": team_id_name_coupling,
        "season_type_allowed": season_type_allowed,
    }


def _example_validity_rules() -> Dict[str, Dict[str, Any]]:
    """Return the demo per-column validity rules (dtype / regex / membership / range)."""
    season_types = ["Regular Season", "Playoffs", "Pre Season", "All Star"]

    rules: Dict[str, Dict[str, Any]] = {}
    rules["PLAYER_ID"] = {"dtype": "int64", "ge": 0}
    rules["PLAYER_NAME"] = {"dtype": "object"}
    rules["SEASON"] = {"regex": r"^\d{4}-\d{2}$"}
    rules["SEASON_TYPE"] = {"in": season_types}
    rules["E_OFF_RATING"] = {"ge": 0, "le": 200}
    return rules


def example_usage() -> None:
    """
    Demonstrate each output format on the demo frame and save the results.

    Prints the table, a JSON preview, a Markdown preview and a few fields of
    the raw report, then writes JSON, Markdown, CSV and HTML files under
    `data/quality_reports` (created if missing).
    """
    df = _example_fetch()

    # Default table format
    print("=== TABLE FORMAT (DEFAULT) ===")
    table_result = run_quality_suite(
        df,
        required_columns=["PLAYER_ID", "PLAYER_NAME", "SEASON", "SEASON_TYPE"],
        primary_key_columns=["PLAYER_ID", "SEASON", "SEASON_TYPE"],
        outlier_numeric_columns=["E_OFF_RATING"],
        consistency_rules_by_name=_example_rules(),
        allowed_values_by_column={"SEASON_TYPE": ["Regular Season","Playoffs","Pre Season","All Star"]},
        reference_table=pd.DataFrame({"TEAM_ID": [10, 11, 12, 13], "TEAM_NAME": ["X","Y","Z","W"]}),
        reference_join_keys=["TEAM_ID"],
        validity_column_rules=_example_validity_rules(),
    )
    print(table_result)
    print(f"Table shape: {table_result.shape}")

    # JSON format
    print("\n=== JSON FORMAT ===")
    json_result = run_quality_suite_json(
        df,
        required_columns=["PLAYER_ID", "PLAYER_NAME"],
        primary_key_columns=["PLAYER_ID"],
    )
    print(json_result[:500] + "...")

    # Markdown format
    print("\n=== MARKDOWN FORMAT ===")
    md_result = run_quality_suite_markdown(
        df,
        required_columns=["PLAYER_ID"],
        primary_key_columns=["PLAYER_ID"],
    )
    print(md_result[:500] + "...")

    # Raw report for programmatic access
    print("\n=== REPORT OBJECT ===")
    report = run_quality_suite(
        df,
        required_columns=["PLAYER_ID"],
        primary_key_columns=["PLAYER_ID"],
        output_format="report"
    )
    print(f"Report type: {type(report)}")
    print(f"Nulls found: {report.nulls.per_column.shape[0]} columns checked")
    print(f"Duplicates found: {report.duplicates.has_duplicates}")

    # Save outputs
    out = Path("data/quality_reports")
    out.mkdir(parents=True, exist_ok=True)

    # Save in different formats. The encoding is pinned to UTF-8 so the files
    # are written identically on every platform; without it, write_text uses
    # the locale encoding and can fail on non-ASCII report content.
    (out / "example_report.json").write_text(json_result, encoding="utf-8")
    (out / "example_report.md").write_text(md_result, encoding="utf-8")
    table_result.to_csv(out / "example_report.csv", index=False)

    html_result = run_quality_suite_html(
        df,
        required_columns=["PLAYER_ID"],
        primary_key_columns=["PLAYER_ID"],
        table_caption="Data Quality Report - Player Metrics"
    )
    (out / "example_report.html").write_text(html_result, encoding="utf-8")

    print(f"\nReports saved to {out}/")


# Run the demo only when this module is executed as a script, not on import.
if __name__ == "__main__":
    example_usage()

Overwriting api/src/airflow_project/utils/data_check_utils.py


# EDA

In [None]:
%%writefile api/src/airflow_project/eda/data_pull.py
"""

**Checks on the data source**
Import `run_quality_suite` from data_check_utils and call it as shown below. `required_columns` and `primary_key_columns` must be supplied; every other argument is optional:
report = run_quality_suite(
    df,
    required_columns=["PLAYER_ID", "PLAYER_NAME", "SEASON", "SEASON_TYPE"],
    primary_key_columns=["PLAYER_ID", "SEASON", "SEASON_TYPE"],
    outlier_numeric_columns=["E_OFF_RATING"],
    consistency_rules_by_name=_example_rules(),
    allowed_values_by_column={"SEASON_TYPE": ["Regular Season","Playoffs","Pre Season","All Star"]},
    reference_table=pd.DataFrame({"TEAM_ID": [10, 11, 12, 13], "TEAM_NAME": ["X","Y","Z","W"]}),
    reference_join_keys=["TEAM_ID"],
    validity_column_rules=_example_validity_rules(),
)
    
**Data Saving Utils**
Pull in the utils for the logging/duckdb/s3/etc.

    
    
"""











# Dags

In [None]:
%%writefile api/src/airflow_project/dags/exampledag.py
"""
## Astronaut ETL example DAG

This DAG queries the list of astronauts currently in space from the
Open Notify API and prints each astronaut's name and flying craft.

There are two tasks, one to get the data from the API and save the results,
and another to print the results. Both tasks are written in Python using
Airflow's TaskFlow API, which allows you to easily turn Python functions into
Airflow tasks, and automatically infer dependencies and pass data.

The second task uses dynamic task mapping to create a copy of the task for
each Astronaut in the list retrieved from the API. This list will change
depending on how many Astronauts are in space, and the DAG will adjust
accordingly each time it runs.

For more explanation and getting started instructions, see our Write your
first DAG tutorial: https://www.astronomer.io/docs/learn/get-started-with-airflow

![Picture of the ISS](https://www.esa.int/var/esa/storage/images/esa_multimedia/images/2010/02/space_station_over_earth/10293696-3-eng-GB/Space_Station_over_Earth_card_full.jpg)
"""

try:
    # Airflow 3.x: assets are exposed through the Task SDK.
    from airflow.sdk.definitions.asset import Asset
except ImportError:
    # Airflow 2.x (pyproject pins apache-airflow==2.10.4) has no `airflow.sdk`;
    # the equivalent concept there is `Dataset`.
    from airflow.datasets import Dataset as Asset
from airflow.decorators import dag, task
from pendulum import datetime
import requests


# Define the basic parameters of the DAG, like schedule and start_date
@dag(
    start_date=datetime(2024, 1, 1),
    schedule="@daily",
    catchup=False,
    doc_md=__doc__,
    default_args={"owner": "Astro", "retries": 3},
    tags=["example"],
)
def example_astronauts():
    # Two tasks: fetch the list of astronauts, then print one line per
    # astronaut using dynamic task mapping.

    # Define tasks
    @task(
        # Define a dataset outlet for the task. This can be used to schedule downstream DAGs when this task has run.
        outlets=[Asset("current_astronauts")]
    )  # Define that this task updates the `current_astronauts` Dataset
    def get_astronauts(**context) -> list[dict]:
        """
        Retrieve the list of astronauts currently in space.

        The head-count is pushed to XCom under the key
        `number_of_people_in_space`; the list of `{"craft", "name"}` dicts is
        returned so the downstream task can be mapped over it. If the API call
        fails for any reason, a hard-coded list is used instead so the example
        DAG still runs.
        """
        try:
            # Bound the wait: without a timeout a stalled endpoint would block
            # the task indefinitely instead of falling back.
            response = requests.get(
                "http://api.open-notify.org/astros.json", timeout=10
            )
            response.raise_for_status()
            payload = response.json()  # parse the body once
            number_of_people_in_space = payload["number"]
            list_of_people_in_space = payload["people"]
        except Exception:
            # Deliberately broad: this is a best-effort demo and any failure
            # (network, HTTP status, bad JSON, missing key) uses the fallback.
            print("API currently not available, using hardcoded data instead.")
            list_of_people_in_space = [
                {"craft": "ISS", "name": "Oleg Kononenko"},
                {"craft": "ISS", "name": "Nikolai Chub"},
                {"craft": "ISS", "name": "Tracy Caldwell Dyson"},
                {"craft": "ISS", "name": "Matthew Dominick"},
                {"craft": "ISS", "name": "Michael Barratt"},
                {"craft": "ISS", "name": "Jeanette Epps"},
                {"craft": "ISS", "name": "Alexander Grebenkin"},
                {"craft": "ISS", "name": "Butch Wilmore"},
                {"craft": "ISS", "name": "Sunita Williams"},
                {"craft": "Tiangong", "name": "Li Guangsu"},
                {"craft": "Tiangong", "name": "Li Cong"},
                {"craft": "Tiangong", "name": "Ye Guangfu"},
            ]
            # Derived from the list so the two cannot drift apart (12 entries).
            number_of_people_in_space = len(list_of_people_in_space)

        context["ti"].xcom_push(
            key="number_of_people_in_space", value=number_of_people_in_space
        )
        return list_of_people_in_space

    @task
    def print_astronaut_craft(greeting: str, person_in_space: dict) -> None:
        """
        Print one line naming an astronaut and the craft they are on,
        followed by `greeting`.

        `person_in_space` is one element of the list returned by
        get_astronauts and must have "craft" and "name" keys.
        """
        craft = person_in_space["craft"]
        name = person_in_space["name"]

        print(f"{name} is currently in space flying on the {craft}! {greeting}")

    # Use dynamic task mapping to run the print_astronaut_craft task for each
    # Astronaut in space
    print_astronaut_craft.partial(greeting="Hello! :)").expand(
        person_in_space=get_astronauts()  # Define dependencies using TaskFlow API syntax
    )


# Instantiate the DAG
example_astronauts()


# Plugins
- plugins: Add custom or community plugins for your project to this file. It is empty by default.

In [None]:
%%writefile api/src/airflow_project/plugins/custom_operator.py



# testing folder

In [None]:
%%writefile api/src/airflow_project/tests/dags/test_dag_example.py
"""Example DAG tests. These tests ensure that all DAGs have tags, have retries set to at least two, and load without import errors. This is an example pytest module and may not fit the context of your DAGs. Feel free to add and remove tests."""

import os
import logging
from contextlib import contextmanager
import pytest
from airflow.models import DagBag


@contextmanager
def suppress_logging(namespace):
    """Disable the logger named `namespace` for the duration of the `with` block."""
    target = logging.getLogger(namespace)
    # Remember the prior state and switch the logger off in one step.
    was_disabled, target.disabled = target.disabled, True
    try:
        yield
    finally:
        # Restore rather than force-enable, so a logger that was already
        # disabled stays disabled. Runs even if the body raises.
        target.disabled = was_disabled


def get_import_errors():
    """
    Return a list of (relative_path, error_message) pairs for DagBag import errors.

    The list always starts with a `(None, None)` placeholder so that the
    parametrized test receives at least one case even when nothing failed.
    Paths are made relative to the AIRFLOW_HOME environment variable.
    """
    with suppress_logging("airflow"):
        dag_bag = DagBag(include_examples=False)

        airflow_home = os.environ.get("AIRFLOW_HOME")

        # prepend "(None,None)" to ensure that a test object is always created even if it's a no op.
        results = [(None, None)]
        for file_path, message in dag_bag.import_errors.items():
            results.append((os.path.relpath(file_path, airflow_home), message.strip()))
        return results


def get_dags():
    """
    Return a list of (dag_id, DAG object, relative file location) for every DAG in the DagBag.

    File locations are made relative to the AIRFLOW_HOME environment variable.
    """
    with suppress_logging("airflow"):
        dag_bag = DagBag(include_examples=False)

    airflow_home = os.environ.get("AIRFLOW_HOME")

    collected = []
    for dag_id, dag in dag_bag.dags.items():
        collected.append((dag_id, dag, os.path.relpath(dag.fileloc, airflow_home)))
    return collected


@pytest.mark.parametrize(
    "rel_path,rv",
    get_import_errors(),
    ids=[case[0] for case in get_import_errors()],
)
def test_file_imports(rel_path, rv):
    """Fail when a DAG file could not be imported; the (None, None) placeholder case passes."""
    if not (rel_path and rv):
        return
    raise Exception(f"{rel_path} failed to import with message \n {rv}")


# Tags a DAG may carry; leave empty to accept any tag.
# This must be a set: `{}` is an empty dict, and `set(...) - dict` raises
# TypeError once entries are added.
APPROVED_TAGS = set()


@pytest.mark.parametrize(
    "dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]
)
def test_dag_tags(dag_id, dag, fileloc):
    """
    Test that a DAG is tagged and, when APPROVED_TAGS is non-empty, that all
    of its tags are in the approved set.
    """
    assert dag.tags, f"{dag_id} in {fileloc} has no tags"
    if APPROVED_TAGS:
        unapproved = set(dag.tags) - APPROVED_TAGS
        assert not unapproved, (
            f"{dag_id} in {fileloc} has unapproved tags: {sorted(unapproved)}"
        )


@pytest.mark.parametrize(
    "dag_id,dag,fileloc", get_dags(), ids=[x[2] for x in get_dags()]
)
def test_dag_retries(dag_id, dag, fileloc):
    """
    Test that a DAG's default_args set `retries` to at least 2.

    A DAG with no `retries` entry fails the assertion with the message below
    rather than raising TypeError from comparing None with an int.
    """
    retries = (dag.default_args or {}).get("retries")
    assert (
        retries is not None and retries >= 2
    ), f"{dag_id} in {fileloc} must have task retries >= 2."
