In [16]:
import asyncio
import json
import os
import signal
import time
import uuid
from dataclasses import dataclass
from typing import Optional, List

import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field, conint, confloat
from sdv.utils import load_synthesizer
from sdv.sampling import Condition

from aiokafka import AIOKafkaProducer

from datetime import datetime, timezone
import uuid

### Sampling helper (Bernoulli per row; conditional features)

In [10]:
def sample_with_base_rate(synthesizer, n_rows, fraud_rate=0.001727, rng=None, shuffle=True, output_path=None):
    """
    Sample n_rows synthetic rows whose fraud count follows Binomial(n_rows, fraud_rate).

    The number of fraud rows is drawn once from a binomial distribution (equivalent
    in distribution to an independent Bernoulli(fraud_rate) draw per row). Features
    are then sampled conditionally on Class so the per-class structure is preserved.
    Every returned row gets a ``transaction_id`` (first column) and an ``event_time``
    (last column); all rows of one call share the same ``event_time``.

    Parameters
    ----------
    synthesizer : fitted SDV synthesizer exposing ``sample_from_conditions``
    n_rows      : int >= 0
    fraud_rate  : float in [0, 1]  (P(Class=1)). The observed rate in the original
                  data set is 0.001727, which is why it is the default value.
    rng         : None | int | np.random.Generator  (for reproducibility of the
                  fraud count and the shuffle; transaction ids come from
                  ``uuid.uuid4`` and are not controlled by ``rng``)
    shuffle     : bool  (shuffle final rows)
    output_path : optional CSV path (written only when rows were sampled)

    Returns
    -------
    pandas.DataFrame
        Columns: ``transaction_id``, the synthesizer's columns, ``event_time``.
        For ``n_rows == 0`` an empty frame with that same column layout.

    Raises
    ------
    ValueError
        If ``fraud_rate`` is outside [0, 1] or ``n_rows`` is negative.
    """
    fraud_rate = float(fraud_rate)
    if not (0.0 <= fraud_rate <= 1.0):
        raise ValueError("fraud_rate must be in [0, 1].")
    if n_rows < 0:
        raise ValueError("n_rows must be >= 0.")

    # RNG setup: accept a seed (or None) and turn it into a Generator
    if isinstance(rng, (int, np.integer)) or rng is None:
        rng = np.random.default_rng(rng)

    n1 = int(rng.binomial(n_rows, fraud_rate))  # number of frauds
    n0 = n_rows - n1

    # Edge case: n_rows == 0 -> return an empty frame whose columns match the
    # non-empty result exactly (transaction_id first, event_time last).
    if n_rows == 0:
        # NOTE(review): this metadata accessor and the 'creditcard' table name are
        # carried over unchanged; confirm both against the installed SDV version.
        feature_cols = list(synthesizer.get_metadata().get_table_columns()['creditcard'].keys())
        return pd.DataFrame(columns=["transaction_id", *feature_cols, "event_time"])

    # Build conditions; skip zero-sized ones
    conditions = []
    if n0:
        conditions.append(Condition(num_rows=n0, column_values={"Class": 0}))
    if n1:
        conditions.append(Condition(num_rows=n1, column_values={"Class": 1}))

    out = synthesizer.sample_from_conditions(conditions=conditions)

    # Ensure 0/1 dtype
    out["Class"] = out["Class"].astype(int)

    # Add unique ID and creation timestamp (same "now" for the whole batch),
    # formatted as UTC ISO-8601 with millisecond precision and a trailing 'Z'.
    _now = datetime.now(timezone.utc)
    _iso = _now.isoformat(timespec='milliseconds').replace('+00:00', 'Z')
    out.insert(0, "transaction_id", [uuid.uuid4().hex for _ in range(len(out))])
    out["event_time"] = _iso

    # Optional shuffle; the seed is drawn from rng so the order is reproducible
    if shuffle and len(out) > 1:
        seed = int(rng.integers(0, 2**32 - 1))
        out = out.sample(frac=1, random_state=seed).reset_index(drop=True)

    if output_path:
        out.to_csv(output_path, index=False)

    return out

### Test function

In [11]:
# Smoke test: load the fitted synthesizer from disk and draw 1000 rows using the
# default fraud rate. No seed is passed, so the result differs between runs.
synth = load_synthesizer('artifacts/creditcard_fraud_gc.pkl')
aaa  = sample_with_base_rate(synth, 1000)

Sampling conditions: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 9285.58it/s]


In [12]:
# Display the sampled frame to eyeball the added transaction_id / event_time columns.
aaa

Unnamed: 0,transaction_id,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V22,V23,V24,V25,V26,V27,V28,Amount,Class,event_time
0,fd024c43643642a9b65d87fe518774f7,0.290854,-8.604322,-1.682051,-0.768385,2.944125,-0.372607,0.441515,-0.216832,-0.431189,...,0.627002,-0.260430,-0.668087,0.256132,0.234887,-0.130587,-0.199543,88.28,0,2025-11-01T19:00:49.061Z
1,eaa4d111fab843f5883d03ffe8d00804,-0.732830,11.366916,0.218906,1.717499,-0.191105,-0.655167,-0.608125,1.269205,-0.306390,...,0.666837,0.245971,-1.192919,-0.737294,-0.279269,-0.195855,0.023919,1.26,0,2025-11-01T19:00:49.061Z
2,a9dad64fe2854d92b1bd67394b016271,-0.825033,1.034379,-0.904274,-1.722980,-0.921776,1.282451,2.067691,0.009513,-0.062712,...,-0.387561,-0.617610,0.236919,0.222657,-0.673210,0.185754,0.375049,109.72,0,2025-11-01T19:00:49.061Z
3,db811b2dd4144d5492cbed415128b72c,1.247421,-3.469088,-0.927453,-1.802013,2.041977,-2.843599,-0.925241,0.489958,-0.761986,...,-0.635226,0.214382,0.118246,-0.570335,0.315467,0.437640,-0.151240,22.09,0,2025-11-01T19:00:49.061Z
4,e922729cf2f144cfa8eeb6dd64be3e5d,-0.580632,5.770076,1.566895,-1.532929,0.884028,-2.068375,-2.120938,-1.653823,-0.304299,...,-0.650926,-0.708991,0.005008,0.784977,0.003713,-0.361025,0.305938,8.59,0,2025-11-01T19:00:49.061Z
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,469bece0602349219889128f8d4c77df,1.090209,1.012349,-0.171674,1.546281,1.974564,-0.443722,-1.166414,-1.522947,-0.751120,...,-0.632999,1.189753,0.697905,-0.948299,-0.351684,0.382411,0.259882,0.78,0,2025-11-01T19:00:49.061Z
996,737eb3ea751e44ff8838b6f0315a8120,2.100561,0.633500,-0.810998,-1.357212,-1.323528,-1.627730,-2.055878,2.232007,-3.207830,...,-0.577021,-0.278050,0.547898,-0.479639,0.283861,-0.177710,-0.431317,24.90,0,2025-11-01T19:00:49.061Z
997,61c7808b751d4322819734fccec0e2fa,-0.003531,7.624974,-2.296916,-0.020521,-2.718087,0.276351,-1.291421,0.641015,1.051000,...,0.241381,0.832332,-0.316236,0.409541,0.400891,-0.046193,0.602697,111.20,0,2025-11-01T19:00:49.061Z
998,c924c2a03f864cd1a6004a1c8c7f7f57,-2.252084,7.891367,-1.283046,-0.059365,0.940391,0.464934,0.689080,0.404670,0.594212,...,1.238146,0.888837,-0.372390,-0.862529,0.323151,-0.329247,-0.159937,106.25,0,2025-11-01T19:00:49.061Z


### Settings

In [13]:
@dataclass
class Settings:
    """Runtime configuration for the Kafka pusher.

    Every default is read from the environment variable of the same name, once,
    when this class body is executed; changing the environment afterwards does
    not affect already-computed defaults. Instances are plain mutable
    dataclasses, so fields can be overridden per instance.
    """

    # Kafka connection and destination topic
    KAFKA_BOOTSTRAP: str = os.environ.get("KAFKA_BOOTSTRAP", "localhost:9092")
    KAFKA_TOPIC: str = os.environ.get("KAFKA_TOPIC", "creditcard-transactions")

    # Path of the pickled, fitted synthesizer
    SYNTH_PATH: str = os.environ.get("SYNTH_PATH", "artifacts/creditcard_fraud_gc.pkl")

    # P(Class=1) used when sampling; set from your data
    FRAUD_RATE: float = float(os.environ.get("FRAUD_RATE", "0.001727"))

    # Pause between batches (seconds) and inclusive bounds on the batch size
    INTERVAL_SECS: int = int(os.environ.get("INTERVAL_SECS", "40"))
    BATCH_MIN: int = int(os.environ.get("BATCH_MIN", "3"))
    BATCH_MAX: int = int(os.environ.get("BATCH_MAX", "30"))

    # Optional seed; an unset or empty RNG_SEED means "no seed"
    RNG_SEED: Optional[int] = None if not os.environ.get("RNG_SEED") else int(os.environ.get("RNG_SEED"))

    # Truthy spellings (case-insensitive): "1", "true", "yes"
    AUTO_START: bool = os.environ.get("AUTO_START", "true").lower() in {"1", "true", "yes"}


# Module-wide settings instance
settings = Settings()

### Kafka (aiokafka): JSON-serialization helper for outgoing records

In [14]:
def to_jsonable_records(df: pd.DataFrame) -> List[dict]:
    """Convert DataFrame rows to dicts that ``json.dumps`` can emit as strict JSON.

    Each row becomes one ``{column: value}`` dict. Values are normalised as follows:

    * NumPy integer / floating / bool scalars -> Python ``int`` / ``float`` / ``bool``
    * non-finite floats (NaN, +inf, -inf), ``pd.NaT`` and ``pd.NA`` -> ``None``
      (``json.dumps`` would otherwise write ``NaN`` / ``Infinity``, which are not
      valid JSON, or fail outright on ``NaT`` / ``NA``)
    * ``datetime`` instances (including ``pd.Timestamp``) -> ISO-8601 strings
    * anything else is passed through unchanged

    Parameters
    ----------
    df : pandas.DataFrame

    Returns
    -------
    list of dict
        One dict per row, in row order; an empty list for an empty frame.
    """
    def pyify(x):
        # Missing-value sentinels first: NaT is itself a datetime subclass, so it
        # must be caught before the datetime branch below.
        if x is pd.NaT or x is pd.NA:
            return None
        if isinstance(x, np.bool_):
            return bool(x)
        if isinstance(x, np.integer):
            return x.item()
        if isinstance(x, (float, np.floating)):
            value = float(x)
            return value if np.isfinite(value) else None
        if isinstance(x, datetime):
            return x.isoformat()
        return x

    return [
        {key: pyify(value) for key, value in row.items()}
        for row in df.to_dict(orient="records")
    ]


### FastAPI app

In [None]:
# ASGI application object. Routes are registered on it further down, and the
# startup/shutdown handler is attached later through app.router.lifespan_context.
app = FastAPI(title="SDV: Kafka Pusher", version="1.0.0")

class StartRequest(BaseModel):
    """Optional overrides accepted by ``POST /start``.

    A field left as ``None`` keeps the value currently held in the settings.
    """
    # Seconds to sleep between batches (>= 1)
    interval_secs: Optional[conint(ge=1)] = Field(None, description="Override interval in seconds")
    # Probability of Class=1, within [0, 1]
    fraud_rate: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description="Override fraud base rate")
    # Inclusive bounds for the random batch size (each >= 1); the
    # batch_min <= batch_max relation is checked by Streamer.start, not here.
    batch_min: Optional[conint(ge=1)] = None
    batch_max: Optional[conint(ge=1)] = None


class StatusResponse(BaseModel):
    """Snapshot of the streamer's state and effective configuration."""
    # True while the background push loop is considered active
    running: bool
    # time.time() recorded after the most recent batch was sent; None before the first one
    last_sent_at_epoch: Optional[float]
    # Number of records in the most recent batch; None before the first one
    last_batch_size: Optional[int]
    # Effective settings currently in use
    interval_secs: int
    fraud_rate: float
    batch_min: int
    batch_max: int
    topic: str
    bootstrap: str


class Streamer:
    """Background task that periodically samples synthetic rows and publishes them to Kafka.

    One ``Streamer`` owns at most one ``AIOKafkaProducer`` and one asyncio task.
    ``start`` creates both, ``stop`` tears both down, and ``status`` reports the
    current state. The ``Settings`` object passed in is mutated by ``start`` when
    overrides are supplied.
    """

    def __init__(self, settings: Settings):
        """Load the synthesizer from ``settings.SYNTH_PATH`` and initialise idle state."""
        self.settings = settings
        self._task: Optional[asyncio.Task] = None
        self._producer: Optional[AIOKafkaProducer] = None
        self._rng = np.random.default_rng(settings.RNG_SEED)
        self._synth = load_synthesizer(self.settings.SYNTH_PATH)
        self._running = False
        self._last_sent_at: Optional[float] = None
        self._last_batch_size: Optional[int] = None

    async def start(self,
                    interval_secs: Optional[int] = None,
                    fraud_rate: Optional[float] = None,
                    batch_min: Optional[int] = None,
                    batch_max: Optional[int] = None):
        """Apply optional overrides, connect the producer and launch the push loop.

        Does nothing when the loop is already running (overrides are ignored then).

        Raises
        ------
        ValueError
            If the effective batch_min exceeds the effective batch_max. The
            settings are left untouched in that case.
        Exception
            Whatever the producer raises while connecting; the half-started
            producer is closed before the error propagates.
        """
        if self._running:
            return

        # Validate the *effective* bounds before touching shared settings, so a
        # rejected request cannot leave an invalid configuration behind.
        effective_min = self.settings.BATCH_MIN if batch_min is None else batch_min
        effective_max = self.settings.BATCH_MAX if batch_max is None else batch_max
        if effective_min > effective_max:
            raise ValueError("BATCH_MIN must be <= BATCH_MAX")

        # A previous loop that died on an error leaves its task and producer
        # behind (_running is False but both are still set); release them first.
        await self._release()

        # apply overrides
        if interval_secs is not None:
            self.settings.INTERVAL_SECS = interval_secs
        if fraud_rate is not None:
            self.settings.FRAUD_RATE = fraud_rate
        self.settings.BATCH_MIN = effective_min
        self.settings.BATCH_MAX = effective_max

        producer = AIOKafkaProducer(
            bootstrap_servers=self.settings.KAFKA_BOOTSTRAP,
            linger_ms=5,
            acks="all",
            enable_idempotence=True,  # at-least-once w/ dedup on broker
            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
            key_serializer=lambda k: k.encode("utf-8"),
        )
        try:
            await producer.start()
        except Exception:
            # Best-effort close of the half-started producer; the connection
            # error is the one the caller needs to see.
            try:
                await producer.stop()
            except Exception:
                pass
            raise

        self._producer = producer
        self._running = True
        self._task = asyncio.create_task(self._run_loop(), name="kafka-push-loop")

    async def stop(self):
        """Stop the push loop and close the producer; safe to call when idle."""
        self._running = False
        await self._release()

    async def _release(self):
        """Cancel and await the loop task (if any), then stop the producer (if any)."""
        task, self._task = self._task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        producer, self._producer = self._producer, None
        if producer is not None:
            await producer.stop()

    async def _run_loop(self):
        """Sample a batch, publish one message per row, sleep, repeat until stopped."""
        try:
            while self._running:
                # random batch size in [BATCH_MIN, BATCH_MAX]
                n = int(self._rng.integers(self.settings.BATCH_MIN, self.settings.BATCH_MAX + 1))
                # NOTE(review): this call is synchronous and runs on the event loop.
                batch = sample_with_base_rate(
                    synthesizer=self._synth,
                    fraud_rate=self.settings.FRAUD_RATE,
                    n_rows=n,
                    rng=self._rng,
                    shuffle=True,
                )
                records = to_jsonable_records(batch)

                # publish one message per row (each with a unique key)
                sends = []
                now = int(time.time() * 1000)
                for rec in records:
                    rec["_event_time_ms"] = now
                    rec["_source"] = "sdv"
                    rec["_schema_version"] = 1
                    key = str(uuid.uuid4())
                    sends.append(
                        self._producer.send_and_wait(
                            topic=self.settings.KAFKA_TOPIC,
                            key=key,
                            value=rec,
                        )
                    )
                if sends:
                    await asyncio.gather(*sends)

                self._last_sent_at = time.time()
                self._last_batch_size = len(records)

                await asyncio.sleep(self.settings.INTERVAL_SECS)
        except asyncio.CancelledError:
            # normal shutdown
            pass
        except Exception as e:
            # log then stop running so /status reports correctly; the producer
            # is released by the next start() or stop() call
            print(f"[streamer] error: {e}", flush=True)
            self._running = False

    def status(self) -> StatusResponse:
        """Return the running flag, last-batch bookkeeping and effective settings."""
        return StatusResponse(
            running=self._running,
            last_sent_at_epoch=self._last_sent_at,
            last_batch_size=self._last_batch_size,
            interval_secs=self.settings.INTERVAL_SECS,
            fraud_rate=self.settings.FRAUD_RATE,
            batch_min=self.settings.BATCH_MIN,
            batch_max=self.settings.BATCH_MAX,
            topic=self.settings.KAFKA_TOPIC,
            bootstrap=self.settings.KAFKA_BOOTSTRAP,
        )


import contextlib
# Single module-wide streamer; constructing it loads the synthesizer from disk.
streamer = Streamer(settings)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: optionally start the streamer, always stop it on shutdown."""
    # Startup: begin pushing right away only when AUTO_START is enabled
    if settings.AUTO_START:
        await streamer.start()
    try:
        yield
    finally:
        # Shutdown: runs whether or not the streamer was auto-started
        await streamer.stop()

# `app` was created before `lifespan` existed, so the handler is attached to the
# router here instead of being passed to the FastAPI constructor.
app.router.lifespan_context = lifespan


# ======================
# API
# ======================

@app.get("/health")
def health():
    """Liveness probe; returns a constant payload and does not inspect the streamer."""
    return {"ok": True}

@app.get("/status", response_model=StatusResponse)
def status():
    """Report the streamer's running flag, last-batch info and effective settings."""
    return streamer.status()

@app.post("/start", response_model=StatusResponse)
async def start(req: StartRequest):
    """Start the streamer, applying any overrides from the request body.

    Returns the streamer status after the start attempt. Responds with 400 when
    the request is rejected as invalid (``ValueError`` from the streamer, e.g.
    batch_min > batch_max) and with 503 when starting fails for any other
    reason, such as the producer being unable to connect.
    """
    try:
        await streamer.start(
            interval_secs=req.interval_secs,
            fraud_rate=req.fraud_rate,
            batch_min=req.batch_min,
            batch_max=req.batch_max,
        )
    except ValueError as e:
        # Caller supplied an inconsistent configuration
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Not the caller's fault: report as service unavailable, not bad request
        raise HTTPException(status_code=503, detail=str(e)) from e
    return streamer.status()

@app.post("/stop", response_model=StatusResponse)
async def stop():
    """Stop the streamer (a no-op when it is idle) and return the resulting status."""
    await streamer.stop()
    return streamer.status()
