<a href="https://colab.research.google.com/github/jmcconne100/Pandas_Notebook_Project/blob/main/Schema.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# dq_contracts.py
from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple, Union
from pathlib import Path
import json
import numpy as np
import pandas as pd


# ----------------------------- Helpers -----------------------------

_PANDAS2SCHEMA = {
    "int64": "integer", "int32": "integer", "Int64": "integer",
    "float64": "float", "float32": "float",
    "bool": "boolean",
    "object": "string",
    "category": "category",
    "datetime64[ns]": "datetime",
}

_ALLOWED_COERCIONS = {
    ("integer", "float"),
    ("integer", "string"),  # safe-ish as text
    ("float", "string"),
    ("boolean", "string"),
    ("datetime", "string"),
    ("category", "string"),
    ("string", "category"),
}

def _schema_type_from_dtype(dtype: pd.api.types.CategoricalDtype | np.dtype) -> str:
    s = str(dtype)
    for k, v in _PANDAS2SCHEMA.items():
        if k in s:
            return v
    return "string"


@dataclass
class ColumnConstraint:
    name: str
    type: str
    nullable: bool
    min: Optional[float] = None
    max: Optional[float] = None
    allowed: Optional[List[Any]] = None   # for enums/categories
    unique: Optional[bool] = None         # enforces column uniqueness (optional)


@dataclass
class TableSchema:
    columns: List[ColumnConstraint]
    primary_key: Optional[List[str]] = None
    description: Optional[str] = None
    version: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "columns": [asdict(c) for c in self.columns],
            "primary_key": self.primary_key,
            "description": self.description,
            "version": self.version,
        }

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "TableSchema":
        cols = [ColumnConstraint(**c) for c in d["columns"]]
        return TableSchema(columns=cols, primary_key=d.get("primary_key"),
                           description=d.get("description"), version=d.get("version"))


# ----------------------------- Core APIs -----------------------------

def infer_schema(
    df: pd.DataFrame,
    *,
    primary_key: Optional[List[str]] = None,
    description: Optional[str] = None,
    version: Optional[str] = None,
    sample_for_ranges: int = 200_000,
) -> TableSchema:
    """
    Infer a lightweight, JSON-serializable table schema from a DataFrame.
    - Maps pandas dtypes to simple types
    - Computes nullability, min/max for numeric/datetime, and allowed values for small-cardinality categoricals
    """
    cols: List[ColumnConstraint] = []
    n = len(df)

    for col in df.columns:
        s = df[col]
        stype = _schema_type_from_dtype(s.dtype)
        nullable = bool(s.isna().any())

        c = ColumnConstraint(name=col, type=stype, nullable=nullable)

        # Ranges
        if stype in ("integer", "float"):
            sample = s.dropna()
            if len(sample) > sample_for_ranges:
                sample = sample.sample(sample_for_ranges, random_state=0)
            if not sample.empty:
                c.min = float(np.nanmin(sample))
                c.max = float(np.nanmax(sample))
        elif stype == "datetime":
            sample = s.dropna()
            if not sample.empty:
                c.min = pd.to_datetime(sample.min()).value / 1e9  # seconds since epoch
                c.max = pd.to_datetime(sample.max()).value / 1e9

        # Allowed values for small-cardinality categoricals/strings
        unique_ct = s.nunique(dropna=True)
        if stype in ("category", "string", "boolean") and unique_ct <= 50:
            allowed = s.dropna().unique().tolist()
            c.allowed = allowed

        # Uniqueness hint if looks like a key
        if primary_key and col in primary_key:
            c.unique = True
        elif unique_ct == n and n > 0:
            c.unique = True

        cols.append(c)

    return TableSchema(columns=cols, primary_key=primary_key, description=description, version=version)


def save_schema(schema: TableSchema, path: Union[str, Path]) -> None:
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w", encoding="utf-8") as f:
        json.dump(schema.to_dict(), f, indent=2)


def load_schema(path: Union[str, Path]) -> TableSchema:
    with Path(path).open("r", encoding="utf-8") as f:
        d = json.load(f)
    return TableSchema.from_dict(d)


def validate_schema(
    df: pd.DataFrame,
    expected: TableSchema,
    *,
    strict_types: bool = False,
    check_ranges: bool = True,
    check_allowed: bool = True,
) -> Dict[str, Any]:
    """
    Validate df against a TableSchema.
    - Column presence/absence
    - Type compatibility (strict or with allowed coercions)
    - Nullability violations
    - Range checks for numeric/datetime (if available)
    - Allowed value checks for categorical/string columns (if provided)
    - Primary key uniqueness if present
    Returns a structured report dict with 'ok': bool and detailed findings.
    """
    report = {
        "ok": True,
        "missing_columns": [],
        "extra_columns": [],
        "type_mismatches": [],
        "nullability_violations": [],
        "range_violations": [],
        "allowed_value_violations": [],
        "primary_key_violations": None,
    }

    expected_cols = {c.name: c for c in expected.columns}
    # Missing / extra
    for col in expected_cols:
        if col not in df.columns:
            report["ok"] = False
            report["missing_columns"].append(col)
    for col in df.columns:
        if col not in expected_cols:
            report["extra_columns"].append(col)

    # Per-column checks
    for col, constraint in expected_cols.items():
        if col not in df.columns:
            continue
        s = df[col]
        actual_type = _schema_type_from_dtype(s.dtype)

        # Type check
        if strict_types:
            if actual_type != constraint.type:
                report["ok"] = False
                report["type_mismatches"].append({"column": col, "expected": constraint.type, "actual": actual_type})
        else:
            if actual_type != constraint.type and (constraint.type, actual_type) not in _ALLOWED_COERCIONS:
                report["ok"] = False
                report["type_mismatches"].append({"column": col, "expected": constraint.type, "actual": actual_type})

        # Nullability
        if constraint.nullable is False and s.isna().any():
            report["ok"] = False
            report["nullability_violations"].append({"column": col, "null_count": int(s.isna().sum())})

        # Ranges
        if check_ranges and constraint.type in ("integer", "float") and (constraint.min is not None or constraint.max is not None):
            s_num = pd.to_numeric(s, errors="coerce")
            too_low = s_num < (constraint.min if constraint.min is not None else -np.inf)
            too_high = s_num > (constraint.max if constraint.max is not None else np.inf)
            n_low, n_high = int(too_low.sum()), int(too_high.sum())
            if n_low or n_high:
                report["ok"] = False
                report["range_violations"].append({"column": col, "below_min": n_low, "above_max": n_high})

        # Allowed values
        if check_allowed and constraint.allowed is not None:
            invalid_mask = ~s.isna() & ~s.astype(object).isin(set(constraint.allowed))
            n_bad = int(invalid_mask.sum())
            if n_bad:
                report["ok"] = False
                sample = s[invalid_mask].astype(str).head(10).tolist()
                report["allowed_value_violations"].append({"column": col, "bad_count": n_bad, "sample_values": sample})

    # Primary key uniqueness
    if expected.primary_key:
        dup = df.duplicated(subset=expected.primary_key, keep=False).sum()
        nulls = df[expected.primary_key].isna().any(axis=1).sum()
        if dup or nulls:
            report["ok"] = False
            report["primary_key_violations"] = {"duplicate_rows": int(dup), "rows_with_null_key": int(nulls)}
        else:
            report["primary_key_violations"] = {"duplicate_rows": 0, "rows_with_null_key": 0}

    return report


def evolve_schema(old: TableSchema, new: TableSchema) -> Dict[str, Any]:
    """
    Compare two schemas and describe changes (add/remove/modify).
    """
    old_cols = {c.name: c for c in old.columns}
    new_cols = {c.name: c for c in new.columns}

    added = [c for c in new_cols if c not in old_cols]
    removed = [c for c in old_cols if c not in new_cols]
    modified: List[Dict[str, Any]] = []

    for name in set(old_cols).intersection(new_cols):
        a, b = old_cols[name], new_cols[name]
        diffs = {}
        for field in ["type", "nullable", "min", "max", "allowed", "unique"]:
            if getattr(a, field) != getattr(b, field):
                diffs[field] = {"old": getattr(a, field), "new": getattr(b, field)}
        if diffs:
            modified.append({"column": name, "changes": diffs})

    return {"added": added, "removed": removed, "modified": modified}


def enforce_contract(
    df: pd.DataFrame,
    contract: TableSchema,
    *,
    cast_strings_to: Optional[Dict[str, str]] = None,   # e.g., {"price": "float"}
    fill_defaults: Optional[Dict[str, Any]] = None,     # e.g., {"country": "US"}
    clip_numeric_to_range: bool = True,
) -> pd.DataFrame:
    """
    Apply a contract to a DataFrame:
    - Add missing columns with defaults (None if not supplied)
    - Cast columns to target types when safe
    - Clip numeric values to [min, max] if specified (optional)
    - Keep only columns present in the contract (in order)
    """
    out = df.copy()
    fill_defaults = fill_defaults or {}
    cast_strings_to = cast_strings_to or {}

    # Ensure columns exist
    for c in contract.columns:
        if c.name not in out.columns:
            out[c.name] = fill_defaults.get(c.name, np.nan if c.nullable else None)

    # Cast basic types
    for c in contract.columns:
        if c.name not in out.columns:
            continue
        t = c.type
        s = out[c.name]
        try:
            if t == "integer":
                out[c.name] = pd.to_numeric(s, errors="coerce").astype("Int64")
            elif t == "float":
                out[c.name] = pd.to_numeric(s, errors="coerce").astype(float)
            elif t == "boolean":
                out[c.name] = s.astype("boolean")
            elif t == "datetime":
                out[c.name] = pd.to_datetime(s, errors="coerce")
            elif t == "category":
                out[c.name] = s.astype("category")
            elif t == "string":
                # Optionally cast specific strings to numeric
                if c.name in cast_strings_to:
                    tgt = cast_strings_to[c.name]
                    if tgt in ("integer", "float"):
                        out[c.name] = pd.to_numeric(s, errors="coerce")
                        if tgt == "integer":
                            out[c.name] = out[c.name].astype("Int64")
                    else:
                        out[c.name] = s.astype("string")
                else:
                    out[c.name] = s.astype("string")
            else:
                out[c.name] = s  # leave as-is
        except Exception:
            # leave column unchanged if cast fails globally
            pass

        # Clip numerics if desired
        if clip_numeric_to_range and t in ("integer", "float") and (c.min is not None or c.max is not None):
            out[c.name] = pd.to_numeric(out[c.name], errors="coerce")
            if c.min is not None:
                out[c.name] = out[c.name].clip(lower=c.min)
            if c.max is not None:
                out[c.name] = out[c.name].clip(upper=c.max)

        # Enforce allowed set (if provided)
        if c.allowed is not None:
            mask = ~out[c.name].isna() & ~out[c.name].astype(object).isin(set(c.allowed))
            if mask.any():
                # Replace invalid with NaN to avoid bad values downstream
                out.loc[mask, c.name] = np.nan

    # Keep only contract columns (and in that order)
    col_order = [c.name for c in contract.columns]
    out = out[col_order]

    return out


In [4]:
# dq_checks.py
from __future__ import annotations
from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd

try:
    from scipy import stats
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False


def check_uniqueness(df: pd.DataFrame, keys: List[str]) -> Dict[str, int]:
    """
    Check if `keys` uniquely identify rows.
    Returns counts of duplicate key rows and null-key rows.
    """
    dup_rows = int(df.duplicated(subset=keys, keep=False).sum())
    null_key_rows = int(df[keys].isna().any(axis=1).sum())
    return {"duplicate_rows": dup_rows, "rows_with_null_key": null_key_rows}


def check_referential_integrity(
    child: pd.DataFrame,
    parent: pd.DataFrame,
    fk_cols: List[str],
    pk_cols: Optional[List[str]] = None,
) -> Dict[str, int]:
    """
    Check that child[fk_cols] exists in parent[pk_cols].
    If pk_cols not provided, uses fk_cols in parent as PK.
    Returns number of child rows violating FK constraint.
    """
    pk_cols = pk_cols or fk_cols
    child_keys = child[fk_cols].drop_duplicates()
    parent_keys = parent[pk_cols].drop_duplicates()

    merged = child_keys.merge(parent_keys, left_on=fk_cols, right_on=pk_cols, how="left", indicator=True)
    missing = int((merged["_merge"] == "left_only").sum())
    return {"missing_fk_rows": missing, "child_distinct_keys": int(len(child_keys)), "parent_distinct_keys": int(len(parent_keys))}


def detect_distribution_drift(
    df_new: pd.DataFrame,
    df_ref: pd.DataFrame,
    columns: Optional[List[str]] = None,
    *,
    alpha: float = 0.05,
) -> pd.DataFrame:
    """
    Compare distributions between df_new and df_ref.
    - Numeric columns: two-sample KS test (if SciPy available), else AD-hoc Wasserstein/quantile distance fallbacks.
    - Categorical columns: chi-square test on frequency tables.
    Returns a DataFrame with per-column statistic, p-value, and drift flag (p < alpha).
    """
    if columns is None:
        columns = sorted(set(df_new.columns).intersection(df_ref.columns))

    records = []

    for col in columns:
        a = df_ref[col].dropna()
        b = df_new[col].dropna()

        if a.empty or b.empty:
            records.append({"column": col, "type": "unknown", "test": None, "statistic": np.nan, "p_value": np.nan, "drift": False})
            continue

        if pd.api.types.is_numeric_dtype(a) and pd.api.types.is_numeric_dtype(b):
            if _HAS_SCIPY:
                stat, p = stats.ks_2samp(a, b)
                records.append({"column": col, "type": "numeric", "test": "KS", "statistic": float(stat), "p_value": float(p), "drift": bool(p < alpha)})
            else:
                # Fallback: simple Earth Mover's (Wasserstein) distance normalized by IQR
                from numpy import quantile
                emd = np.abs(np.mean(a) - np.mean(b))
                iqr = (quantile(a, 0.75) - quantile(a, 0.25)) + 1e-9
                stat = float(emd / iqr)
                # Heuristic p: not a true p-value; flag drift if stat > 0.5
                records.append({"column": col, "type": "numeric", "test": "heuristic_EMD", "statistic": stat, "p_value": np.nan, "drift": bool(stat > 0.5)})
        else:
            # Categorical: chi-square on value counts aligned
            vc_a = a.astype(str).value_counts()
            vc_b = b.astype(str).value_counts()
            idx = sorted(set(vc_a.index).union(vc_b.index))
            obs = np.vstack([vc_a.reindex(idx, fill_value=0).values, vc_b.reindex(idx, fill_value=0).values])
            if _HAS_SCIPY:
                chi2, p, dof, _ = stats.chi2_contingency(obs)
                records.append({"column": col, "type": "categorical", "test": "chi2", "statistic": float(chi2), "p_value": float(p), "drift": bool(p < alpha)})
            else:
                # Heuristic: Jensen-Shannon distance proxy via normalized counts
                pa = obs[0] / max(obs[0].sum(), 1)
                pb = obs[1] / max(obs[1].sum(), 1)
                stat = float(np.sqrt(((pa - pb) ** 2).sum()))
                records.append({"column": col, "type": "categorical", "test": "heuristic_L2", "statistic": stat, "p_value": np.nan, "drift": bool(stat > 0.25)})

    return pd.DataFrame.from_records(records)


In [5]:
# Build reference data
np.random.seed(0)
df_ref = pd.DataFrame({
    "id": range(1, 101),
    "age": np.random.normal(35, 8, 100).round(1),
    "country": np.random.choice(["US", "UK", "CA"], 100, p=[0.6, 0.25, 0.15]),
    "joined": pd.date_range("2024-01-01", periods=100, freq="D"),
})
schema_ref = infer_schema(df_ref, primary_key=["id"], description="Users ref v1", version="1.0")

print("Inferred schema (truncated):")
print(schema_ref.to_dict()["columns"][:3])

# Save / load
save_schema(schema_ref, "schema_users_v1.json")
schema_loaded = load_schema("schema_users_v1.json")

# New data with a couple issues: extra column, age spike, a bad country value, duplicate id
df_new = df_ref.copy()
df_new.loc[0, "id"] = 1   # duplicate id on purpose
df_new.loc[5:10, "age"] = 120
df_new.loc[20, "country"] = "DE"  # invalid category
df_new["extra_col"] = 1

report = validate_schema(df_new, schema_loaded, strict_types=False, check_ranges=True, check_allowed=True)
print("\nValidation report ok?:", report["ok"])
print("Missing columns:", report["missing_columns"])
print("Extra columns:", report["extra_columns"])
print("Type mismatches:", report["type_mismatches"])
print("Nullability violations:", report["nullability_violations"])
print("Range violations:", report["range_violations"])
print("Allowed value violations:", report["allowed_value_violations"])
print("PK violations:", report["primary_key_violations"])

# Enforce contract (coerce, clip, drop extra)
df_enforced = enforce_contract(df_new, schema_loaded)
print("\nEnforced DataFrame columns:", list(df_enforced.columns))
print("Age min/max after enforcement:", df_enforced["age"].min(), df_enforced["age"].max())

# Schema evolution example: add a column and widen allowed set
schema_v2 = infer_schema(df_enforced.assign(country="US", loyalty="silver"), primary_key=["id"], description="Users v2", version="2.0")
diff = evolve_schema(schema_loaded, schema_v2)
print("\nSchema evolution diff:")
print(diff)

# DQ checks
print("\nUniqueness check on ['id']:")
print(check_uniqueness(df_new, ["id"]))

# Referential integrity: build parent (countries) and child (users)
parent = pd.DataFrame({"country": ["US", "UK", "CA"]})
ri = check_referential_integrity(df_new, parent, fk_cols=["country"], pk_cols=["country"])
print("\nReferential integrity (country):", ri)

# Distribution drift (e.g., age/country changed)
drift = detect_distribution_drift(df_new, df_ref, columns=["age", "country"])
print("\nDistribution drift:")
print(drift)


Inferred schema (truncated):
[{'name': 'id', 'type': 'integer', 'nullable': False, 'min': 1.0, 'max': 100.0, 'allowed': None, 'unique': True}, {'name': 'age', 'type': 'float', 'nullable': False, 'min': 14.6, 'max': 53.2, 'allowed': None, 'unique': None}, {'name': 'country', 'type': 'string', 'nullable': False, 'min': None, 'max': None, 'allowed': ['US', 'UK', 'CA'], 'unique': None}]

Validation report ok?: False
Missing columns: []
Extra columns: ['extra_col']
Type mismatches: []
Nullability violations: []
Range violations: [{'column': 'age', 'below_min': 0, 'above_max': 6}]
Allowed value violations: [{'column': 'country', 'bad_count': 1, 'sample_values': ['DE']}]
PK violations: {'duplicate_rows': 0, 'rows_with_null_key': 0}

Enforced DataFrame columns: ['id', 'age', 'country', 'joined']
Age min/max after enforcement: 14.6 53.2

Schema evolution diff:
{'added': ['loyalty'], 'removed': [], 'modified': [{'column': 'country', 'changes': {'allowed': {'old': ['US', 'UK', 'CA'], 'new': ['US'