In [10]:
a = set([28 * i for i in range(1, 10000)])
b = set([39 * i for i in range(1, 10000)])

sorted(list(a & b))[0]

1092

In [75]:
from prj.config import DATA_DIR
from prj.data.data_loader import DataConfig, DataLoader


data_args = {}
config = DataConfig(**data_args)
loader = DataLoader(data_dir=DATA_DIR, config=config)
test_ds = loader.load(1600, 1605).collect()

In [3]:
import numpy as np

np.random.rand(10).clip(-5, 5)

array([0.19342295, 0.40621698, 0.20757085, 0.35800211, 0.65921724,
       0.36698247, 0.27399136, 0.81355452, 0.19647356, 0.63540672])

In [76]:
from tqdm import tqdm
import polars as pl

def online_iterator(df: pl.DataFrame, show_progress: bool = True):
    assert df.select('date_id').n_unique() > 1, 'Dataset must contain at least 2 days'
    
    df_date_time_id = df.select('date_id', 'time_id').unique().sort('date_id', 'time_id').with_row_index('date_time_id')
    df = df.join(df_date_time_id, on=['date_id', 'time_id'], how='left', maintain_order='left')
    
    max_date_time_id = df_date_time_id['date_time_id'].max()
    min_date_id = df.select('date_id').min().item()
    
    responders = [f'responder_{i}' for i in range(9)]
    
    curr_idx:int = df_date_time_id.filter(pl.col('date_id').eq(min_date_id + 1))['date_time_id'].min()
    old_day = min_date_id

    
    with tqdm(total=max_date_time_id - curr_idx + 1, disable=not show_progress) as pbar:
        while curr_idx <= max_date_time_id:
            curr_day = df_date_time_id[curr_idx]['date_id'].item()
            is_new_day = curr_day != old_day
            lags = None
            if is_new_day:
                lags = df.filter(pl.col('date_id').eq(old_day)).select(pl.col('date_id').add(1), 'time_id', 'symbol_id', *[pl.col(r).alias(f'{r}_lag_1') for r in responders])
            
            old_day = curr_day

            batch = df.filter(pl.col('date_time_id').eq(curr_idx)).with_columns(pl.lit(True).alias('is_scored')).drop('date_time_id')
            
            yield batch, lags if lags is not None else None
            
            curr_idx += 1
            pbar.update(1)


In [77]:
for batch, lags in online_iterator(test_ds):
    break

  0%|          | 0/4840 [00:00<?, ?it/s]


In [78]:
from collections import defaultdict
import polars as pl
from river.drift import ADWIN

class DriftDetector:
    def __init__(self, features, drift_detector_factory=ADWIN, delta=0.05, drift_threshold=0.5):
        self.detectors = defaultdict(lambda: drift_detector_factory(delta=delta))
        self.drift_detector_factory = drift_detector_factory
        self.delta = delta
        self.features = features
        self.drift_threshold=drift_threshold

    def update(self, data: pl.DataFrame):
        drift_results = defaultdict(dict)
        data = data.select('date_id', 'time_id', 'symbol_id', *self.features)
        for row in data.iter_rows(named=True):
            symbol = row['symbol_id']

            for feature in self.features:
                value = row[feature]
                if value is None:
                    continue
            
                detector = self.detectors[(symbol, feature)]
                detector.update(value)
                
                drift_detected = detector.drift_detected
                drift_results[symbol][feature] = drift_detected

        return drift_results

    def check_drift(self, drift_results):
        drift_summary = {}

        for symbol, features in drift_results.items():
            drift_summary[symbol] = False
            total_features = len(features)
            drifted_features = sum(drift_detected for drift_detected in features.values())
            if drifted_features / total_features >= self.drift_threshold:
                drift_summary[symbol] = True

        return drift_summary


In [79]:
detector = DriftDetector(features=loader.features)
i = 0
for batch, lags in online_iterator(test_ds):
    if lags is not None:
        print(batch['date_id'].min())
    drift_result = detector.update(batch)
    drift_summary = detector.check_drift(drift_result)
    symbols_drifted = [k for k, v in drift_summary.items() if v]
    if len(symbols_drifted) > 0:
        print(symbols_drifted, f"Date: {batch['date_id'].min()}")

  0%|          | 0/4840 [00:00<?, ?it/s]

  0%|          | 13/4840 [00:00<00:37, 129.77it/s]

1601


 20%|██        | 979/4840 [00:14<00:56, 68.22it/s]

1602


 21%|██        | 1005/4840 [00:15<01:00, 63.90it/s]

[7, 35, 36] Date: 1602


 40%|████      | 1946/4840 [00:29<00:39, 72.58it/s]

1603


 41%|████      | 1962/4840 [00:29<00:45, 62.90it/s]

[7, 11, 14, 20] Date: 1603
[16] Date: 1603


 60%|██████    | 2912/4840 [00:44<00:31, 61.24it/s]

1604


 80%|████████  | 3885/4840 [00:59<00:18, 53.02it/s]

1605


 81%|████████  | 3915/4840 [01:00<00:15, 59.27it/s]

[24] Date: 1605


100%|██████████| 4840/4840 [01:15<00:00, 64.05it/s]
