Skip to content

Commit

Permalink
feat: add dedup logic (#299)
Browse files Browse the repository at this point in the history
1) Introduces dedup logic
2) Add delay param while fetching data from druid

---------

Signed-off-by: s0nicboOm <i.kushalbatra@gmail.com>
  • Loading branch information
s0nicboOm authored Sep 27, 2023
1 parent 2d99849 commit 2973dd2
Show file tree
Hide file tree
Showing 11 changed files with 402 additions and 120 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.9" ]
python-version: [ "3.11" ]

name: Publish to PyPi
steps:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

#script output
.btoutput/

# Mac related
*.DS_Store

Expand Down
4 changes: 2 additions & 2 deletions numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ class TrainerConf:
train_hours: int = 24 * 8 # 8 days worth of data
min_train_size: int = 2000
retrain_freq_hr: int = 24
model_expiry_sec: int = 86400 # 24 hrs # TODO: revisit this
dedup_expiry_sec: int = 1800 # 30 min # TODO: revisit this
model_expiry_sec: int = 172800 # 48 hrs # TODO: revisit this
retry_secs: int = 600 # 10 min # TODO: revisit this
batch_size: int = 64
pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf)

Expand Down
10 changes: 10 additions & 0 deletions numalogic/connectors/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,15 @@ def __post_init__(self):

@dataclass
class DruidConf(ConnectorConf):
    """
    Settings used to configure the Druid connector.

    Args:
    ----
        endpoint: Druid endpoint.
        delay_hrs: Delay, in hours, applied when fetching data from Druid
            (defaults to 3.0).
        fetcher: Fetcher settings (a DruidFetcherConf); declared as MISSING,
            i.e. no concrete default is provided here.
    """

    endpoint: str
    delay_hrs: float = 3.0
    fetcher: DruidFetcherConf = MISSING
60 changes: 34 additions & 26 deletions numalogic/connectors/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def build_params(
filter_pairs: dict,
granularity: str,
hours: float,
delay: float,
) -> dict:
"""
Expand All @@ -48,6 +49,7 @@ def build_params(
data to include in the query
granularity: Time bucket to aggregate data by hour, day, minute, etc.,
hours: Hours from now to skip training.
delay: Added delay to the fetch query from current time.
Returns: a dict of parameters
Expand All @@ -56,7 +58,9 @@ def build_params(
type="and",
fields=[Filter(type="selector", dimension=k, value=v) for k, v in filter_pairs.items()],
)
end_dt = datetime.now(pytz.utc)
end_dt = datetime.now(pytz.utc) - timedelta(hours=delay)
_LOGGER.debug("Querying with end_dt: %s, that is with delay of %s hrs", end_dt, delay)

start_dt = end_dt - timedelta(hours=hours)

intervals = [f"{start_dt.isoformat()}/{end_dt.isoformat()}"]
Expand Down Expand Up @@ -98,6 +102,7 @@ def fetch(
filter_keys: list[str],
filter_values: list[str],
dimensions: list[str],
delay: float = 3.0,
granularity: str = "minute",
aggregations: Optional[dict] = None,
group_by: Optional[list[str]] = None,
Expand All @@ -107,33 +112,36 @@ def fetch(
_start_time = time.perf_counter()
filter_pairs = make_filter_pairs(filter_keys, filter_values)
query_params = build_params(
aggregations, datasource, dimensions, filter_pairs, granularity, hours
aggregations, datasource, dimensions, filter_pairs, granularity, hours, delay
)

response = self.client.groupby(**query_params)
df = response.export_pandas()

if df is None or df.shape[0] == 0:
logging.warning("No data found for keys %s", filter_pairs)
try:
response = self.client.groupby(**query_params)
except Exception:
_LOGGER.exception("Problem with getting response from client")
return pd.DataFrame()

df["timestamp"] = pd.to_datetime(df["timestamp"]).astype("int64") // 10**6

if group_by:
df = df.groupby(by=group_by).sum().reset_index()

if pivot.columns:
df = df.pivot(
index=pivot.index,
columns=pivot.columns,
values=pivot.value,
)
df.columns = df.columns.map("{0[1]}".format)
df.reset_index(inplace=True)

_end_time = time.perf_counter() - _start_time
_LOGGER.debug("Druid query latency: %.6fs", _end_time)
return df
else:
df = response.export_pandas()
if df.empty or df.shape[0] == 0:
logging.warning("No data found for keys %s", filter_pairs)
return pd.DataFrame()

df["timestamp"] = pd.to_datetime(df["timestamp"]).astype("int64") // 10**6

if group_by:
df = df.groupby(by=group_by).sum().reset_index()

if pivot.columns:
df = df.pivot(
index=pivot.index,
columns=pivot.columns,
values=pivot.value,
)
df.columns = df.columns.map("{0[1]}".format)
df.reset_index(inplace=True)

_end_time = time.perf_counter() - _start_time
_LOGGER.debug("Druid query latency: %.6fs", _end_time)
return df

def raw_fetch(self, *args, **kwargs) -> pd.DataFrame:
    """
    Placeholder for raw fetching, which is not implemented.

    Raises
    ------
        NotImplementedError: always, whatever arguments are passed.
    """
    raise NotImplementedError
101 changes: 100 additions & 1 deletion numalogic/udfs/tools.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
from dataclasses import replace
import time
from typing import Optional

import numpy as np
import pandas as pd
from pandas import DataFrame
from redis import RedisError

from numalogic.registry import ArtifactManager, ArtifactData
from numalogic.tools.exceptions import RedisRegistryError
from numalogic.tools.types import KEYS
from numalogic.tools.types import KEYS, redis_client_t
from numalogic.udfs._config import StreamConf
from numalogic.udfs.entities import StreamPayload

Expand Down Expand Up @@ -144,3 +146,100 @@ def _load_artifact(
},
)
return artifact, payload


class TrainMsgDeduplicator:
    """
    Deduplicate train messages using timestamps kept in a Redis hash.

    For each composite key, a hash named ``TRAIN::<k1>:<k2>:...`` holds up to two
    fields, both written as ``str(time.time())``:
      - ``_msg_read_ts``: set by ``ack_read`` when a training request is accepted.
      - ``_msg_train_ts``: set by ``ack_train`` once training has been acknowledged.

    Args:
    ----
        r_client: Redis client.
    """

    __slots__ = ("client",)

    def __init__(self, r_client: "redis_client_t"):
        self.client = r_client

    @staticmethod
    def __construct_key(keys: "KEYS") -> str:
        """Build the Redis hash name for the given composite keys."""
        return f"TRAIN::{':'.join(keys)}"

    @staticmethod
    def __to_str(value) -> str:
        """Return ``value`` as str, decoding it first when the client returned bytes."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode()
        return str(value)

    def __fetch_ts(self, key: str) -> tuple[Optional[str], Optional[str]]:
        """
        Read the stored timestamps for a hash.

        Args:
        ----
            key: Redis hash name, as produced by ``__construct_key``.

        Returns
        -------
            Tuple ``(_msg_read_ts, _msg_train_ts)``; an element is None when the field
            is absent, and both are None when the Redis call fails.
        """
        try:
            raw = self.client.hgetall(key)
        except RedisError:
            _LOGGER.exception("Problem fetching ts information for the key: %s", key)
            return None, None

        # Clients created with decode_responses=True already return str, so only
        # decode what is actually bytes instead of calling .decode() unconditionally.
        data = {
            self.__to_str(field): self.__to_str(value) for field, value in (raw or {}).items()
        }
        return data.get("_msg_read_ts"), data.get("_msg_train_ts")

    def ack_read(self, key: "KEYS", uuid: str, retrain_freq: int = 24, retry: int = 600) -> bool:
        """
        Acknowledge the read message. Return True when the msg has to be trained.

        Args:
        ----
            key: composite keys identifying the model.
            uuid: uuid, used as a prefix in the log messages.
            retrain_freq: retrain frequency for the model in hrs; training is skipped
                while the last ``_msg_train_ts`` is more recent than this.
            retry: time in secs that must have passed since the last ``_msg_read_ts``
                before another training request for the same key is accepted.

        Returns
        -------
            bool: True if ``_msg_read_ts`` was updated and training should proceed,
            False if the request is a duplicate or the Redis update failed.
        """
        _key = self.__construct_key(key)
        _msg_read_ts, _msg_train_ts = self.__fetch_ts(key=_key)
        now = time.time()

        # NOTE: the read above and the hset below are separate Redis calls, so this
        # check-and-set is not atomic.
        if _msg_read_ts and now - float(_msg_read_ts) < retry:
            _LOGGER.info("%s - Model with key : %s is being trained by another process", uuid, key)
            return False

        # This check is needed if there is backpressure in the pl.
        retrain_freq_secs = retrain_freq * 60 * 60
        if _msg_train_ts and now - float(_msg_train_ts) < retrain_freq_secs:
            # retrain_freq is expressed in hours, so report it with the matching unit.
            _LOGGER.info(
                "%s - Model was saved for the key: %s in less than %s hrs, skipping training",
                uuid,
                key,
                retrain_freq,
            )
            return False

        try:
            self.client.hset(name=_key, key="_msg_read_ts", value=str(time.time()))
        except RedisError:
            _LOGGER.exception(
                "%s - Problem while updating msg_read_ts information for the key: %s",
                uuid,
                key,
            )
            return False
        _LOGGER.info("%s - Acknowledging request for Training for key : %s", uuid, key)
        return True

    def ack_train(self, key: "KEYS", uuid: str) -> bool:
        """
        Acknowledge the train message is trained and saved. Return True when
        _msg_train_ts is updated.

        Args:
        ----
            key: composite keys identifying the model.
            uuid: uuid, used as a prefix in the log messages.

        Returns
        -------
            bool: True if ``_msg_train_ts`` was written, False if the Redis update failed.
        """
        _key = self.__construct_key(key)
        try:
            self.client.hset(name=_key, key="_msg_train_ts", value=str(time.time()))
        except RedisError:
            _LOGGER.exception(
                " %s - Problem while updating msg_train_ts information for the key: %s",
                uuid,
                key,
            )
            return False
        _LOGGER.info("%s - Acknowledging model saving complete for for the key: %s", uuid, key)
        return True
32 changes: 19 additions & 13 deletions numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from numalogic.udfs import NumalogicUDF
from numalogic.udfs._config import StreamConf, PipelineConf
from numalogic.udfs.entities import TrainerPayload
from numalogic.udfs.tools import TrainMsgDeduplicator

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
self._model_factory = ModelFactory()
self._preproc_factory = PreprocessFactory()
self._thresh_factory = ThresholdFactory()
self.train_msg_deduplicator = TrainMsgDeduplicator(r_client)

def register_conf(self, config_id: str, conf: StreamConf) -> None:
"""
Expand Down Expand Up @@ -163,7 +165,17 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

# Construct payload object
payload = TrainerPayload(**orjson.loads(datum.value))
if not self._is_new_request(payload):
_conf = self.get_conf(payload.config_id)

# set the retry and retrain_freq
retrain_freq_ts = _conf.numalogic_conf.trainer.retrain_freq_hr
retry_ts = _conf.numalogic_conf.trainer.retry_secs
if not self.train_msg_deduplicator.ack_read(
key=payload.composite_keys,
uuid=payload.uuid,
retrain_freq=retrain_freq_ts,
retry=retry_ts,
):
return Messages(Message.to_drop())

# Fetch data
Expand All @@ -182,7 +194,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

# Construct feature array
x_train = self.get_feature_arr(df, payload.metrics)
_conf = self.get_conf(payload.config_id)

# Initialize artifacts
preproc_clf = self._construct_preproc_clf(_conf)
Expand All @@ -207,6 +218,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:
model_registry=self.model_registry,
payload=payload,
)
if self.train_msg_deduplicator.ack_train(key=payload.composite_keys, uuid=payload.uuid):
_LOGGER.info(
"%s - Model trained and saved successfully.",
payload.uuid,
)

_LOGGER.debug(
"%s - Time taken in trainer: %.4f sec", payload.uuid, time.perf_counter() - _start_time
Expand Down Expand Up @@ -266,17 +282,6 @@ def _is_data_sufficient(self, payload: TrainerPayload, df: pd.DataFrame) -> bool
_conf = self.get_conf(payload.config_id)
return len(df) > _conf.numalogic_conf.trainer.min_train_size

# TODO: improve the dedup logic; this is too naive
def _is_new_request(self, payload: TrainerPayload) -> bool:
    """
    Claim the training request for the payload's composite keys.

    A marker key ``train::<k1>:<k2>:...`` is written to Redis with a TTL of
    ``trainer.dedup_expiry_sec`` taken from the payload's stream config.

    Args:
    ----
        payload: trainer payload; its ``config_id`` and ``composite_keys`` are read.

    Returns
    -------
        bool: True if the marker did not exist and was created by this call,
        False if a marker for the same keys is already present.
    """
    _conf = self.get_conf(payload.config_id)
    r_key = f"train::{':'.join(payload.composite_keys)}"
    # A separate GET followed by SETEX lets two workers both see "no key" and both
    # claim the request; SET with NX + EX checks and creates the marker in one command.
    created = self.r_client.set(
        r_key,
        value=1,
        ex=_conf.numalogic_conf.trainer.dedup_expiry_sec,
        nx=True,
    )
    return bool(created)

@staticmethod
def get_feature_arr(
raw_df: pd.DataFrame, metrics: list[str], fill_value: float = 0.0
Expand Down Expand Up @@ -309,6 +314,7 @@ def fetch_data(self, payload: TrainerPayload) -> pd.DataFrame:
filter_keys=_conf.composite_keys,
filter_values=payload.composite_keys,
dimensions=list(self.druid_conf.fetcher.dimensions),
delay=self.druid_conf.delay_hrs,
granularity=self.druid_conf.fetcher.granularity,
aggregations=dict(self.druid_conf.fetcher.aggregations),
group_by=list(self.druid_conf.fetcher.group_by),
Expand Down
Loading

0 comments on commit 2973dd2

Please sign in to comment.