In [30]:
import polars as pl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import gc
from prj.config import DATA_DIR, EXP_DIR
from tqdm import tqdm
from prj.agents.factory import AgentsFactory
import os

In [None]:
agent_base_dir = EXP_DIR / 'saved' / 'oamp'

# Run directory (relative to agent_base_dir) of the saved model for each agent type;
# every run keeps its model under <run>/train/model.
agent_runs = {
    'lgbm': "lgbm_1seeds_1020_1529-0.2_20241218_104118",
    'catboost': "catboost_1seeds_1020_1529-0.2_20241218_104022",
    'xgb': "xgb_1seeds_1020_1529-0.2_20241218_104101",
}
agents_dict = [
    {'agent_type': agent_type, 'load_path': os.path.join(agent_base_dir, run_name, "train", "model")}
    for agent_type, run_name in agent_runs.items()
]
agents = [AgentsFactory.load_agent(spec) for spec in agents_dict]

len(agents)

In [None]:
from prj.config import DATA_DIR
from prj.data.data_loader import PARTITIONS_DATE_INFO, DataConfig, DataLoader

# zero_fill disabled — presumably leaves missing feature values as NaN (TODO confirm in DataConfig).
data_args = {'zero_fill': False}
config = DataConfig(**data_args)
loader = DataLoader(data_dir=DATA_DIR, config=config)

# Date range of partition 9.
partition_info = PARTITIONS_DATE_INFO[9]
start_dt = partition_info['min_date']
end_dt = partition_info['max_date']

X, y, w, info = loader.load_numpy(start_dt=start_dt, end_dt=end_dt)

# Keep only the first rows so the experiment stays fast.
n_rows = 100_000
X, y, w, info = X[:n_rows, :], y[:n_rows], w[:n_rows], info[:n_rows, :]

X.shape, y.shape, w.shape

In [None]:
def squared_weighted_error_loss_fn(y_true: np.ndarray, y_pred_agents: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Per-sample, per-agent weighted squared error.

    ``y_true`` and ``w`` are reshaped to column vectors so they broadcast across
    the columns (agents) of ``y_pred_agents``.
    """
    column_true = y_true.reshape((-1, 1))
    column_weights = w.reshape((-1, 1))
    squared_residuals = (column_true - y_pred_agents) ** 2
    return column_weights * squared_residuals

def absolute_weighted_error_loss_fn(y_true: np.ndarray, y_pred_agents: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Per-sample, per-agent weighted absolute error.

    ``y_true`` and ``w`` are reshaped to column vectors so they broadcast across
    the columns (agents) of ``y_pred_agents``.
    """
    column_weights = w.reshape((-1, 1))
    column_true = y_true.reshape((-1, 1))
    abs_residuals = np.abs(column_true - y_pred_agents)
    return column_weights * abs_residuals

def compute_loss(y_true: np.ndarray, y_pred_agents: np.ndarray, w: np.ndarray, loss_fn) -> np.ndarray:
    """Evaluate ``loss_fn`` on the predictions of one or more agents.

    Args:
        y_true: targets, one value per sample (first axis of length n).
        y_pred_agents: predictions with one row per sample, one column per agent.
            A 1-D array of length n is treated as a single agent, i.e. shape (n, 1).
        w: per-sample weights of length n, or None for uniform weights of 1.
        loss_fn: callable ``loss_fn(y_true, y_pred_agents, w)`` producing the losses.

    Returns:
        Whatever ``loss_fn`` returns (the loss functions in this notebook return
        an (n, n_agents) array).

    Raises:
        AssertionError: if the sample counts of ``y_true``, ``y_pred_agents`` and
            ``w`` disagree.
    """
    if y_pred_agents.ndim == 1:
        # The loss functions reshape y_true / w to (n, 1); a 1-D prediction would
        # otherwise broadcast against them into an (n, n) matrix instead of (n, 1).
        y_pred_agents = y_pred_agents.reshape(-1, 1)
    assert y_true.shape[0] == y_pred_agents.shape[0], "y_true and y_pred_agents differ in number of samples"
    if w is None:
        w = np.ones_like(y_true)
    assert y_true.shape[0] == w.shape[0], "y_true and w differ in number of samples"

    return loss_fn(y_true, y_pred_agents, w)


# Stack each agent's predictions as one column: shape (n_samples, n_agents).
agents_predictions = np.hstack([agent.predict(X).reshape(-1, 1) for agent in tqdm(agents)])
agents_losses = compute_loss(y, agents_predictions, w, absolute_weighted_error_loss_fn)

# Sanity check: number of rows containing at least one NaN prediction / loss.
nan_pred_rows = np.isnan(agents_predictions).any(axis=1).sum()
nan_loss_rows = np.isnan(agents_losses).any(axis=1).sum()
agents_predictions.shape, agents_losses.shape, nan_pred_rows, nan_loss_rows

In [None]:

from prj.metrics import weighted_r2
from prj.oamp.oamp import OAMP
from prj.oamp.oamp_config import ConfigOAMP

oamp_args = ConfigOAMP({'agents_weights_upd_freq':1, 'loss_fn_window': 100, 'agg_type': 'max'})
n_agents = len(agents)
oamp: OAMP = OAMP(n_agents, oamp_args)


def _step_oamp_rows(start: int, stop: int) -> None:
    """Feed the realized losses of rows [start, stop) to ``oamp``, one row at a time.

    ``is_new_group`` is True when ``info[:, 1]`` changes between consecutive rows
    inside the span; the first row of the span is never flagged.
    """
    for j in range(start, stop):
        is_new_group = j > start and info[j, 1] != info[j - 1, 1]
        oamp.step(agents_losses[j], is_new_group=is_new_group)


ensemble_preds = []
day_start = 0  # index of the first row of the current day, whose losses are not yet stepped

for i in tqdm(range(agents_predictions.shape[0])):
    is_new_day = i > 0 and info[i, 0] != info[i - 1, 0]
    if is_new_day:
        # The previous day's losses are only used once the day has ended
        # (presumably emulating targets revealed end-of-day), so predictions
        # never see losses from their own day.
        tqdm.write(f'New day {i}, doing steps of previous day')  # print() would break the progress bar
        _step_oamp_rows(day_start, i)
        day_start = i

    ensemble_preds.append(oamp.compute_prediction(agents_predictions[i]))

# All predictions are made; also step the final day so OAMP's state (and anything
# plot_stats() reports — TODO confirm) covers every row, not all but the last day.
_step_oamp_rows(day_start, agents_predictions.shape[0])

ensemble_preds = np.array(ensemble_preds)

ensemble_preds.shape

In [None]:
from prj.metrics import weighted_mae, weighted_mse, weighted_rmse

def metrics(y_true, y_pred, weights):
    """Weighted regression metrics of ``y_pred`` against ``y_true``.

    Returns a dict with keys 'r2_w', 'mae_w', 'mse_w', 'rmse_w' (in that order),
    each computed by the matching ``prj.metrics`` helper using ``weights``.
    """
    scorers = (
        ('r2_w', weighted_r2),
        ('mae_w', weighted_mae),
        ('mse_w', weighted_mse),
        ('rmse_w', weighted_rmse),
    )
    return {name: scorer(y_true, y_pred, weights=weights) for name, scorer in scorers}
    
# One row of weighted metrics per predictor: the OAMP ensemble, each individual
# agent, and the naive mean / median of the agents' predictions.
metrics_by_predictor = {'ensemble': metrics(y, ensemble_preds, w)}
for i in range(n_agents):
    metrics_by_predictor[f'agent_{i}'] = metrics(y, agents_predictions[:, i], w)
metrics_by_predictor['mean'] = metrics(y, np.mean(agents_predictions, axis=1), w)
metrics_by_predictor['median'] = metrics(y, np.median(agents_predictions, axis=1), w)

# Build the table directly from the per-predictor dicts: columns come from the
# metric names themselves, so no positional alignment or polars round trip is needed.
results = (
    pd.DataFrame.from_dict(metrics_by_predictor, orient='index')
    .rename_axis('Agent')
    .sort_values('r2_w', ascending=False)
)
results

In [None]:
# Plot the statistics recorded by `oamp` during the online loop
# (what is shown is defined by OAMP.plot_stats in prj.oamp.oamp — not visible here).
oamp.plot_stats()