# In [None]:  (notebook cell marker left over from export)
import os, time
from collections import deque, defaultdict
from dataclasses import dataclass, field
from datetime import timedelta
from urllib.parse import quote_plus

import numpy as np
import pandas as pd
from dateutil import tz
from sqlalchemy import create_engine, text
import joblib

# ===================== CONFIG =====================
# Database connection settings, read from the environment. DB_HOST, DB_PORT,
# DB_NAME and DB_USER have no default: when a variable is unset os.getenv
# returns None, which the f-string below interpolates as the text "None".
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
# The password is URL-escaped so special characters cannot break the URL.
# NOTE(review): DB_USER is not escaped the same way — confirm user names never
# contain URL-reserved characters.
DB_PASS = quote_plus(os.getenv("DB_PASSWORD", ""))

# The engine is created at import time and shared by every SQL helper below.
ENGINE = create_engine(f"postgresql+psycopg2://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}")

MODEL_PATH   = os.getenv("MODEL_PATH", "xgb_bps_h1s_v3.joblib")  # file handed to joblib.load in main()
POLL_SECONDS = int(os.getenv("POLL_SECONDS", "1"))            # seconds slept between polls
WARMUP_SEC   = int(os.getenv("WARMUP_SEC", "10"))             # how far back main() starts reading
# Only the exact (case-insensitive) value "true" enables persistence.
SAVE_TO_DB   = os.getenv("SAVE_TO_DB", "true").lower() == "true"
# Stored in the model_version column of every saved prediction.
MODEL_VERSION= os.getenv("MODEL_VERSION", "xgb_bps_h1s_v3")

# Same FEATURES that were used at training time; the order of this list is the
# column order of the vector returned by FlowState.feature_vector().
FEATURES = [
  "throughput_bps_t","pps_t",
  "thr_lag1","thr_lag2","thr_lag3","thr_lag5",
  "pps_lag1","pps_lag2","pps_lag3","pps_lag5",
  "thr_ma_5","thr_std_5","pps_ma_5","thr_slope5"
]

# ================== PER-FLOW STATE ==================
@dataclass
class FlowState:
    thr: deque = field(default_factory=lambda: deque(maxlen=5))   # throughput_bps_t
    pps: deque = field(default_factory=lambda: deque(maxlen=5))   # pps_t
    ts : deque = field(default_factory=lambda: deque(maxlen=5))
    last_packets: int | None = None
    last_ts: pd.Timestamp | None = None

    def update_from_row(self, ts, throughput_bps_t, packets):
        """Actualiza pps desde 'packets' acumulativo usando deltat para robustez."""
        ts = pd.Timestamp(ts)  # tz-aware OK
        if self.last_packets is not None and self.last_ts is not None:
            dt_s = (ts - self.last_ts).total_seconds()
            if dt_s > 0 and packets is not None:
                dp = max(packets - self.last_packets, 0)
                pps_t = dp / dt_s
            else:
                pps_t = None
        else:
            pps_t = None

        self.last_packets = packets
        self.last_ts = ts

        # Guardar valores
        self.ts.append(ts)
        self.thr.append(float(throughput_bps_t) if throughput_bps_t is not None else 0.0)
        self.pps.append(float(pps_t) if pps_t is not None else np.nan)

    def feature_vector(self):
        """Construye el vector FEATURES si hay suficiente historia; si no, devuelve None."""
        if len(self.thr) < 5 or len(self.pps) < 5:
            return None

        # Convertir a arrays para cálculos
        thr_arr = np.array(self.thr, dtype=float)
        pps_arr = np.array(self.pps, dtype=float)

        # lags (1,2,3,5) → posiciones desde el final
        thr_lag1 = thr_arr[-2]
        thr_lag2 = thr_arr[-3]
        thr_lag3 = thr_arr[-4]
        thr_lag5 = thr_arr[0]

        pps_lag1 = pps_arr[-2]
        pps_lag2 = pps_arr[-3]
        pps_lag3 = pps_arr[-4]
        pps_lag5 = pps_arr[0]

        # ventanas sobre TODO el buffer (5 seg)
        thr_ma_5  = float(np.nanmean(thr_arr))
        thr_std_5 = float(np.nanstd(thr_arr, ddof=1)) if np.isfinite(thr_arr).sum() >= 2 else 0.0
        pps_ma_5  = float(np.nanmean(pps_arr))

        # pendiente (slope) en 5 s: diff(5)/5
        thr_slope5 = (thr_arr[-1] - thr_arr[0]) / 5.0

        current = {
            "throughput_bps_t": thr_arr[-1],
            "pps_t":            pps_arr[-1],
            "thr_lag1": thr_lag1, "thr_lag2": thr_lag2, "thr_lag3": thr_lag3, "thr_lag5": thr_lag5,
            "pps_lag1": pps_lag1, "pps_lag2": pps_lag2, "pps_lag3": pps_lag3, "pps_lag5": pps_lag5,
            "thr_ma_5": thr_ma_5, "thr_std_5": thr_std_5, "pps_ma_5": pps_ma_5, "thr_slope5": thr_slope5,
        }

        # Respeta el orden de FEATURES
        return np.array([current[k] for k in FEATURES], dtype=float)

# ================== SQL HELPERS ==================
def fetch_rows_since(since_ts: pd.Timestamp, limit=5000):
    """Fetch new rows of flow_metrics_logs with ``ts > since_ts``, ordered by ts.

    Only the columns needed for the features are selected: ts, the 5-tuple,
    throughput and packets.

    Args:
        since_ts: Exclusive lower bound on ``ts``.
        limit: Maximum number of rows requested from the database.

    Returns:
        A DataFrame whose ``ts`` column is timezone-aware (values without a
        timezone are interpreted as UTC).

    Callers page through the table with a ``ts > last_seen`` cursor. If the
    LIMIT cuts through a group of rows that share the same ``ts``, the rest of
    that group would never be returned by a later call. To avoid that, when the
    limit is reached the trailing group (which may be incomplete) is dropped so
    that the next call fetches it in full. If every returned row shares a single
    ``ts`` the rows are returned untrimmed, since trimming would stall the
    cursor; rows beyond ``limit`` for that timestamp are then missed.
    """
    sql = text("""
        SELECT
          ts,
          src_ip::text AS src_ip,
          dst_ip::text AS dst_ip,
          src_port, dst_port, protocol,
          throughput::double precision AS throughput_bps_t,
          packets::bigint AS packets
        FROM flow_metrics_logs
        WHERE ts > :since
        ORDER BY ts
        LIMIT :lim;
    """)
    with ENGINE.connect() as con:
        df = pd.read_sql(sql, con, params={"since": since_ts, "lim": limit})

    # Make sure ts is tz-aware. The isinstance check replaces the deprecated
    # pd.api.types.is_datetime64tz_dtype helper.
    if not isinstance(df["ts"].dtype, pd.DatetimeTZDtype):
        df["ts"] = pd.to_datetime(df["ts"], utc=True)

    if limit is not None and len(df) >= limit:
        # The batch was truncated: the last timestamp may have more rows in the
        # table than were returned, so leave that whole group for the next call.
        last_ts = df["ts"].iloc[-1]
        complete = df[df["ts"] < last_ts]
        if not complete.empty:
            df = complete
    return df

def save_prediction(row, yhat):
    """Insert one forecast into throughput_forecast_h1s.

    Does nothing when SAVE_TO_DB is false. ``row`` must provide ``ts`` and the
    5-tuple fields by key; ``yhat`` is the predicted value, stored for the
    instant one second after ``row["ts"]``.
    """
    if not SAVE_TO_DB:
        return

    insert_stmt = text("""
        INSERT INTO throughput_forecast_h1s
        (ts, pred_for_ts, src_ip, dst_ip, src_port, dst_port, protocol, yhat_bps_next_1s, model_version)
        VALUES (:ts, :pred_ts, :src_ip, :dst_ip, :src_port, :dst_port, :protocol, :yhat, :model_version)
    """)

    # ENGINE.begin() commits on success and rolls back if anything raises.
    with ENGINE.begin() as con:
        observed_at = row["ts"]
        params = dict(
            ts=observed_at,
            pred_ts=observed_at + pd.Timedelta(seconds=1),
            src_ip=row["src_ip"],
            dst_ip=row["dst_ip"],
            src_port=int(row["src_port"]),
            dst_port=int(row["dst_port"]),
            protocol=int(row["protocol"]),
            yhat=float(yhat),
            model_version=MODEL_VERSION,
        )
        con.execute(insert_stmt, params)

# ================== MAIN LOOP ==================
def _flow_key(r):
    """Return the string that identifies the flow (5-tuple) of row *r*."""
    return f"{r['src_ip']}:{r['src_port']}>{r['dst_ip']}:{r['dst_port']}/{r['protocol']}"


def _handle_row(model, states, r):
    """Feed row *r* into its flow state and, when possible, predict and persist."""
    flow_id = _flow_key(r)
    st = states[flow_id]
    # Update the state with the current sample.
    st.update_from_row(r["ts"], r["throughput_bps_t"], r["packets"])

    # Build the features only once the buffer is full and every value is finite.
    x = st.feature_vector()
    if x is None or not np.isfinite(x).all():
        return

    yhat = float(model.predict(x.reshape(1, -1))[0])

    # Console log (in Mbps).
    print(f"{r['ts']} {flow_id}  thr_now={st.thr[-1]/1e6:.2f} Mbps  "
          f"pps={st.pps[-1]:.0f}  yhat(+1s)={yhat/1e6:.2f} Mbps")

    # Persist (optional).
    save_prediction(r, yhat)


def main():
    """Poll the database forever, predicting next-second throughput per flow.

    Starts WARMUP_SEC seconds in the past so the per-flow buffers can fill,
    then repeatedly fetches rows newer than the last processed timestamp.
    Stops on KeyboardInterrupt; any other error is logged and the loop goes on.
    """
    print("Cargando modelo:", MODEL_PATH)
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # MODEL_PATH must point to a trusted artifact.
    model = joblib.load(MODEL_PATH)

    # Per-flow state. NOTE: entries are never evicted, so this grows with the
    # number of distinct flows seen during the lifetime of the process.
    states: dict[str, FlowState] = defaultdict(FlowState)

    # Warmup: start a few seconds back so the buffers can fill.
    now = pd.Timestamp.now(tz="UTC")
    since_ts = now - pd.Timedelta(seconds=WARMUP_SEC)

    print(f"Arrancando desde {since_ts.isoformat()} (warmup {WARMUP_SEC}s).")
    last_seen = since_ts

    while True:
        try:
            df = fetch_rows_since(last_seen)
            if not df.empty:
                # Process in chronological order. A failure on one row is
                # logged and must not abort the batch: the cursor below only
                # moves forward, so rows left unprocessed would otherwise be
                # either skipped or applied to the flow state twice.
                for _, r in df.iterrows():
                    try:
                        _handle_row(model, states, r)
                    except Exception as e:
                        print("Error en fila:", repr(e))

                # Advance the cursor once the whole batch has been consumed.
                last_seen = max(last_seen, df["ts"].max())

            time.sleep(POLL_SECONDS)

        except KeyboardInterrupt:
            print("Detenido por usuario.")
            break
        except Exception as e:
            # Do not fail silently: log the error and keep polling.
            print("Error en loop:", repr(e))
            time.sleep(POLL_SECONDS)

if __name__ == "__main__":
    main()
