In [19]:
import sys
import json
from datetime import date, datetime
from dateutil.relativedelta import relativedelta

import pandas as pd
import mlflow
import mlflow.sklearn
from mlflow.models.signature import infer_signature
from sklearn.decomposition import PCA
from domino.data_sources import DataSourceClient

# ─── PCA WITH PREDICT ─────────────────────────────────────────────────────────
class PCAWithPredict(PCA):
    """scikit-learn ``PCA`` that also answers ``predict``.

    MLflow's sklearn ``python_function`` wrapper scores a model through its
    ``predict`` method, which plain ``PCA`` does not define. This subclass
    exposes the projection onto the fitted components under that name; all
    fitting behaviour is inherited unchanged.
    """

    def predict(self, X):
        """Return the PCA scores of *X*, i.e. exactly ``self.transform(X)``."""
        projected = self.transform(X)
        return projected

# ─── DEFAULTS & ARG PARSING ─────────────────────────────────────────────────
# Optional positional CLI arguments, in this order:
#   [0] window_months  int    look-back window (months) ending at as_of_date
#   [1] n_components   int    number of principal components to fit
#   [2] svd_solver     str    one of VALID_SVD_SOLVERS
#   [3] whiten         bool   "true"/"1"/"yes" (any case) -> True, else False
#   [4] tol            float  passed to PCA(tol=...)
#   [5] curve_type     str    value matched against rate_curves.curve_type
#   [6] as_of_date     str    YYYY-MM-DD
# A missing or unparseable argument falls back to its default. That lenience is
# deliberate: inside a notebook kernel sys.argv typically carries the kernel's
# own flags, which must not abort the run.
default_window_months   = 3*12
default_n_components    = 5
default_svd_solver      = "auto"
default_whiten          = False
default_tol             = 0.0
default_curve_type      = "US Treasury Par"
default_as_of_date      = date.today()

VALID_SVD_SOLVERS = ("auto", "full", "arpack", "randomized")


def parse_cli_arg(argv, index, parser, default):
    """Return ``parser(argv[index])``, or *default* if it is absent or rejected.

    Only a missing position (``IndexError``) or a parser ``ValueError`` selects
    the default. Any other exception — including ``KeyboardInterrupt`` and
    ``SystemExit``, which the previous bare ``except:`` clauses swallowed —
    propagates to the caller.
    """
    try:
        raw = argv[index]
    except IndexError:
        return default
    try:
        return parser(raw)
    except ValueError:
        return default


def parse_svd_solver(raw):
    """Return *raw* if it is a supported PCA ``svd_solver``; raise ValueError otherwise."""
    if raw not in VALID_SVD_SOLVERS:
        raise ValueError(f"unsupported svd_solver: {raw!r}")
    return raw


def parse_bool_flag(raw):
    """Interpret "true"/"1"/"yes" (case-insensitive) as True; anything else is False."""
    return str(raw).lower() in ("true", "1", "yes")


def parse_iso_date(raw):
    """Parse a ``YYYY-MM-DD`` string into a ``date``; raise ValueError if malformed."""
    return datetime.strptime(raw, '%Y-%m-%d').date()


args = sys.argv[1:]
window_months = parse_cli_arg(args, 0, int, default_window_months)
n_components = parse_cli_arg(args, 1, int, default_n_components)
svd_solver = parse_cli_arg(args, 2, parse_svd_solver, default_svd_solver)
whiten = parse_cli_arg(args, 3, parse_bool_flag, default_whiten)
tol = parse_cli_arg(args, 4, float, default_tol)
curve_type = parse_cli_arg(args, 5, str, default_curve_type)
as_of_date = parse_cli_arg(args, 6, parse_iso_date, default_as_of_date)

# ─── MAIN PIPELINE ──────────────────────────────────────────────────────────
def _sql_literal(value) -> str:
    """Render *value* as a single-quoted SQL string literal.

    ``ds.query`` is called with a fully rendered SQL string (no bind
    parameters), so any text spliced into it — notably ``curve_type``, which
    comes from the command line — is escaped by doubling embedded single
    quotes. Unescaped, a ``'`` would break the statement or allow SQL injection.
    NOTE(review): assumes PostgreSQL with the default
    ``standard_conforming_strings = on`` (backslashes literal); the ``::jsonb``
    and ``ON CONFLICT`` syntax used below suggests PostgreSQL — confirm.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _load_curve_matrix(ds, curve_type: str, start_date: date, end_date: date) -> pd.DataFrame:
    """Fetch one curve type's history as a date x tenor matrix.

    Rows are ``curve_date`` values in ``[start_date, end_date]`` sorted ascending;
    columns are ``tenor_num`` values and cells are ``rate``. Gaps inside a curve
    are interpolated across tenors. Leading/trailing gaps take the nearest quoted
    tenor's rate, and dates with no usable rate at all are dropped. ``df.pivot``
    raises ValueError if a (curve_date, tenor_num) pair is returned twice.
    """
    sql = f"""
        SELECT curve_date, tenor_num, rate
          FROM rate_curves
         WHERE curve_type = {_sql_literal(curve_type)}
           AND curve_date BETWEEN '{start_date}' AND '{end_date}'
         ORDER BY curve_date, tenor_num;
    """
    df = ds.query(sql).to_pandas()

    pivot = df.pivot(index="curve_date", columns="tenor_num", values="rate").sort_index()
    # Interpolate against the tenor values themselves, not column position, so
    # unevenly spaced tenors (e.g. 2 -> 5 -> 10) are filled proportionally.
    # method="index" needs a numeric axis; otherwise keep positional "linear".
    method = "index" if pd.api.types.is_numeric_dtype(pivot.columns) else "linear"
    return (
        pivot
        .interpolate(method=method, axis=1)
        .ffill(axis=1)
        .bfill(axis=1)
        .dropna(axis=0)
    )


def _build_insert_statements(run_id, curve_type, explained_variance_ratio,
                             components_json, dates, scores, batch_size=500):
    """Return INSERT statements for ``curve_pca_results`` of at most *batch_size* rows each.

    One row is produced per (curve_date, component) pair, in date order and then
    component order, with a 1-based ``component_index``. The run-level columns
    (run id, curve type, component count, explained-variance array, component
    loadings JSON) repeat on every row. ``scores`` is indexed as
    ``scores[date_position, component_position]``.
    """
    n_components = scores.shape[1]
    evr_sql = "ARRAY[" + ", ".join(repr(float(v)) for v in explained_variance_ratio) + "]"
    # These columns are identical for every row, so render (and escape) them once.
    shared = (
        f"{_sql_literal(run_id)},{_sql_literal(curve_type)},CURRENT_TIMESTAMP,"
        f"{n_components},{evr_sql},{_sql_literal(components_json)}::jsonb,"
    )
    rows = [
        f"({shared}{_sql_literal(dt.isoformat())},{comp_idx + 1},{float(scores[row_idx, comp_idx])})"
        for row_idx, dt in enumerate(dates)
        for comp_idx in range(n_components)
    ]

    statements = []
    for start in range(0, len(rows), batch_size):
        values = ",".join(rows[start:start + batch_size])
        statements.append(f"""
            INSERT INTO curve_pca_results
              (run_id, curve_type, run_timestamp, n_components,
               explained_variance_ratio, components, curve_date,
               component_index, score)
            VALUES {values}
            ON CONFLICT (run_id, curve_type, curve_date, component_index)
              DO NOTHING;
            """)
    return statements


def main(
    window_months: int,
    n_components: int,
    svd_solver: str,
    whiten: bool,
    tol: float,
    curve_type: str,
    as_of_date: date
):
    """Fit a PCA on one curve type's recent history and record the results.

    Steps:
      1. Load ``curve_type`` curves for the ``window_months`` months ending on
         ``as_of_date`` from the ``market_data`` data source, one feature per tenor.
      2. Fit ``PCAWithPredict`` inside a new MLflow run.
      3. Log hyperparameters, explained-variance metrics and the model to MLflow.
      4. Insert one score row per (curve_date, component) into ``curve_pca_results``.

    Raises:
        ValueError: if the window yields no usable curves, or fewer curves or
            tenors than ``n_components``. This is checked before the MLflow
            run is opened, so no failed run is left behind.
    """
    ds = DataSourceClient().get_datasource("market_data")
    end_date = as_of_date
    start_date = end_date - relativedelta(months=window_months)

    # 1) Fetch curve data for the specified type as a date x tenor matrix
    pivot = _load_curve_matrix(ds, curve_type, start_date, end_date)
    dates = pivot.index.to_list()
    X = pivot.values

    n_samples, n_tenors = X.shape
    if n_samples == 0 or n_components > min(n_samples, n_tenors):
        raise ValueError(
            f"Need at least n_components={n_components} curves and tenors for "
            f"{curve_type!r} between {start_date} and {end_date}; "
            f"got {n_samples} curve(s) x {n_tenors} tenor(s)"
        )

    mlflow.set_experiment("Curve PCA history test")
    with mlflow.start_run() as run:
        run_id = run.info.run_id

        # Log hyperparameters
        mlflow.log_params({
            "window_months": window_months,
            "n_components": n_components,
            "svd_solver": svd_solver,
            "whiten": whiten,
            "tol": tol,
            "curve_type": curve_type,
            "as_of_date": str(as_of_date)
        })

        # 2) Fit our subclassed PCA
        pca = PCAWithPredict(
            n_components=n_components,
            svd_solver=svd_solver,
            whiten=whiten,
            tol=tol
        )
        scores = pca.fit_transform(X)

        # 3) Log explained-variance metrics (metric suffix is the 1-based component index)
        evr = pca.explained_variance_ratio_.tolist()
        for i, ratio in enumerate(evr, start=1):
            mlflow.log_metric(f"var_deg_{i}", float(ratio))
        mlflow.log_metric("var_total", float(sum(evr)))

        # 4) Build input example (first curve) & signature
        input_example = pd.DataFrame(X[:1, :], columns=pivot.columns)
        signature = infer_signature(input_example, pca.transform(input_example))

        # 5) Log the model (with signature & python_function flavor)
        mlflow.sklearn.log_model(
            sk_model=pca,
            artifact_path="pca_model",
            input_example=input_example,
            signature=signature
        )

        # 6-7) Batch-insert per-date component scores into curve_pca_results
        comps_json = json.dumps(pca.components_.tolist())
        for insert_sql in _build_insert_statements(
            run_id, curve_type, evr, comps_json, dates, scores, batch_size=500
        ):
            ds.query(insert_sql)

    print(f"✅ PCA pipeline done (run_id={run_id}, curve_type={curve_type}, as_of_date={as_of_date})")

# ─── ENTRY POINT ─────────────────────────────────────────────────────────────
# Run the pipeline with the module-level values parsed from sys.argv (or their
# defaults). A notebook cell also executes with __name__ == "__main__", so this
# runs there as well.
if __name__ == "__main__":
    main(
        window_months=window_months,
        n_components=n_components,
        svd_solver=svd_solver,
        whiten=whiten,
        tol=tol,
        curve_type=curve_type,
        as_of_date=as_of_date,
    )


🏃 View run trusting-gnu-927 at: http://127.0.0.1:8768/#/experiments/1441/runs/eae5da2fa68e4812b5932581e765fc08
🧪 View experiment at: http://127.0.0.1:8768/#/experiments/1441
✅ PCA pipeline done (run_id=eae5da2fa68e4812b5932581e765fc08, curve_type=US Treasury Par, as_of_date=2025-05-19)
