In [1]:
# Reload edited modules automatically before each cell executes, so changes to
# the project package (e.g. `featurologists`) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import logging

# Reset the root logger so that re-running this cell does not stack handlers.
# Iterate over a *copy* of the handler list: removing items from
# `logging.root.handlers` while iterating it directly skips every other
# handler, and any survivor makes `basicConfig` below a silent no-op
# (it does nothing when the root logger already has handlers).
for handler in list(logging.root.handlers):
    logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO)

In [3]:
# Resolve the project root as the parent of the notebook's working directory,
# then sanity-check it by requiring a `featurologists` entry inside it.
from pathlib import Path

PROJECT_ROOT = Path("..").resolve()
PROJECT_ROOT_LS = list(map(lambda entry: entry.name, PROJECT_ROOT.iterdir()))
assert "featurologists" in PROJECT_ROOT_LS, (
    f"Not a project root? {PROJECT_ROOT}, pwd: {Path().resolve()}"
)

In [4]:
from typing import List
import pandas as pd
import pickle
from customer_segmentation_toolkit.data_zoo import download_data_csv

from featurologists.data_transforms import build_client_clusters, clean_client_clusters
from featurologists.models.customer_segmentation import (
    calc_score_roc_auc,
    calc_score_accuracy,
    train_test_split,
    save_model,
    train_lightgbm,
    predict,
    train_xgboost,
)

In [5]:
from typing import Iterable, Union

# Columns parsed as datetimes by default. A tuple, so the shared default value
# cannot be mutated by callers.
DATETIME_COLUMNS = ('InvoiceDate',)

def preprocess_input(df, datetime_columns: Union[str, Iterable[str]] = DATETIME_COLUMNS):
    """Return a copy of ``df`` with the given columns parsed as datetimes.

    Args:
        df: input DataFrame. It is not modified.
        datetime_columns: a single column name, or an iterable of column
            names, to convert with ``pd.to_datetime``. Defaults to
            ``DATETIME_COLUMNS``.

    Returns:
        A new DataFrame with the same columns in the same order, where each
        listed column has been converted with ``pd.to_datetime``.

    Raises:
        KeyError: if a listed column is not present in ``df``.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(datetime_columns, str):
        datetime_columns = (datetime_columns,)
    # Work on a copy so a frame displayed/used in an earlier cell is not
    # silently changed underneath the reader.
    result = df.copy()
    for column in datetime_columns:
        result[column] = pd.to_datetime(result[column])
    return result

In [6]:
# Load the raw transaction records and parse their datetime columns.
# NOTE(review): the CSV carries an `Unnamed: 0` column that looks like a saved
# index — consider `index_col=0` once it is confirmed nothing downstream needs it.
df = preprocess_input(pd.read_csv(PROJECT_ROOT / 'data' / 'output' / 'online_raw.csv'))
df

Unnamed: 0,InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
0,569203,79321,CHILLI LIGHTS,48,2011-10-02 10:32:00,4.95,16353.0,United Kingdom
1,569203,21154,RED RETROSPOT OVEN GLOVE,20,2011-10-02 10:32:00,1.25,16353.0,United Kingdom
2,569204,21790,VINTAGE SNAP CARDS,4,2011-10-02 10:43:00,0.85,16591.0,United Kingdom
3,569204,23284,DOORMAT KEEP CALM AND COME IN,15,2011-10-02 10:43:00,7.08,16591.0,United Kingdom
4,569204,23355,HOT WATER BOTTLE KEEP CALM,4,2011-10-02 10:43:00,4.95,16591.0,United Kingdom
...,...,...,...,...,...,...,...,...
170973,581587,22613,PACK OF 20 SPACEBOY NAPKINS,12,2011-12-09 12:50:00,0.85,12680.0,France
170974,581587,22899,CHILDREN'S APRON DOLLY GIRL,6,2011-12-09 12:50:00,2.10,12680.0,France
170975,581587,23254,CHILDRENS CUTLERY DOLLY GIRL,4,2011-12-09 12:50:00,4.15,12680.0,France
170976,581587,23255,CHILDRENS CUTLERY CIRCUS PARADE,4,2011-12-09 12:50:00,4.15,12680.0,France


In [7]:
# Which trained model to load. Alternative: MODEL_NAME = 'lightgbm'
MODEL_NAME = 'xgboost'

MODELS_DIR = PROJECT_ROOT / "models" / "customer_segmentation"
MODEL_PATH = MODELS_DIR / MODEL_NAME / 'model.pkl'

# NOTE: unpickling can execute arbitrary code — only load trusted model files.
with open(MODEL_PATH, 'rb') as model_file:
    model = pickle.load(model_file)
print(f'Loaded model: {MODEL_PATH}')

Loaded model: /plain/github/opensource/Featurologists/models/customer_segmentation/xgboost/model.pkl


In [14]:
# # For a small batch (<250), it fails with: ValueError: n_samples=2 should be >= n_clusters=11.
# predict(model, df[:10])

In [15]:
# # For a larger batch (>=250), it fails with: Number of columns in data must equal to trained model.
# predict(model, df[:250])

In [16]:
from typing import Any, List, Optional, Tuple, Union

class BatchBuilder:
    """Buffer rows one at a time and hand them out in fixed-size batches.

    Rows are kept in insertion order. Once at least ``batch_size`` rows are
    buffered, :meth:`get_batch` removes the oldest ``batch_size`` of them and
    returns them as a DataFrame.
    """

    def __init__(self, batch_size: int = 10):
        """Create an empty builder.

        Args:
            batch_size: number of rows per batch; must be at least 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        # A non-positive size would make `is_inference_ready` permanently true
        # and `get_batch` return empty frames forever.
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self._current: List[pd.Series] = []
        self._batch_size = batch_size

    def add_input(self, row: Union[pd.Series, Tuple[Any, pd.Series]]):
        """Buffer one row.

        Args:
            row: a row ``Series``, or an ``(index, Series)`` pair as yielded
                by ``DataFrame.iterrows()``.
        """
        # `iterrows()` yields (index, Series) pairs; unwrap them so that a
        # plain Series is buffered as-is instead of indexing into it.
        if isinstance(row, tuple):
            _, row = row
        self._current.append(row)

    @property
    def is_inference_ready(self) -> bool:
        """True once at least ``batch_size`` rows are buffered."""
        return len(self._current) >= self._batch_size

    def get_batch(self) -> Optional[pd.DataFrame]:
        """Remove and return the oldest ``batch_size`` buffered rows.

        Returns:
            A DataFrame with one row per buffered Series (insertion order),
            or ``None`` if fewer than ``batch_size`` rows are buffered.
        """
        if not self.is_inference_ready:
            return None
        batch = self._current[:self._batch_size]
        del self._current[:self._batch_size]
        return pd.DataFrame(batch)

In [17]:
# BATCH_SIZE = 250
# builder = BatchBuilder(BATCH_SIZE)

# for i, row in enumerate(df.iterrows(), 1):
#     builder.add_input(row)
#     if i % BATCH_SIZE != 0:
#         assert not builder.is_inference_ready
#         batch = builder.get_batch()
#         assert batch is None, batch
#     else:
#         print(i)
#         assert builder.is_inference_ready
#         batch = builder.get_batch()
#         batch = preprocess_input(batch)
#         #print(batch)
#         r = predict(model, batch)
#         print(f'Inference result: {type(r)} {r}')
#         assert r is not None
#         if i//BATCH_SIZE >= 1:
#             break