In [None]:
import polars as pl
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import gc
from prj.config import DATA_DIR, EXP_DIR
from tqdm import tqdm
from prj.agents.factory import AgentsFactory
import os

In [None]:
# Sanity check: display the version of the TensorFlow package this kernel imports.
import tensorflow as tf
tf.__version__

In [None]:
# Smoke test: build an Mlp, save it to ./test, then reload it through
# TabularNNModel.load and display the reloaded object.
from prj.model.nn.mlp import Mlp
from prj.model.nn.neural import TabularNNModel

# input_dim=(79, ) is presumably the number of input features -- TODO confirm.
model = Mlp(input_dim=(79, ))

# NOTE(review): this calls the private _build() directly; presumably the model
# must be built before it can be saved -- confirm against the Mlp API.
model._build()

model.save('./test')

# Rebinds `model` to the instance read back from disk.
model = TabularNNModel.load('./test')
model

In [None]:
# Load a previously saved model (judging by the path, the best trial of a
# tuning run) and display it.
# NOTE(review): the path below is absolute and machine-specific, so this cell
# will fail on any other machine; consider deriving it from EXP_DIR.
from prj.model.nn.neural import TabularNNModel

model = TabularNNModel.load('/home/lorecampa/projects/jane_street_forecasting/experiments/tuning/best_trial/saved_model/seed_3724349352')

model

In [None]:
# Root directory that holds the saved experiment runs.
agent_base_dir = EXP_DIR / 'saved'

# LGBM runs to load; each one keeps its model under <run>/best_trial/saved_model.
lgbm_run_names = [
    "LGBMRegressor_1seeds_5_6-7_7_134811",
    "LGBMRegressor_1seeds_4_5-6_6_134802",
    "LGBMRegressor_1seeds_7_8-9_9_134923",
]

# One spec per agent, in the format expected by AgentsFactory.load_agent.
agents_dict = [
    {
        'agent_type': 'lgbm',
        'load_path': os.path.join(agent_base_dir, run_name, "best_trial", "saved_model"),
    }
    for run_name in lgbm_run_names
]
# Disabled XGBoost agent, kept for reference:
# {'agent_type': 'xgb', 'load_path': os.path.join(agent_base_dir, "xgb_5-5_20241119_222538", "models")},

agents = list(map(AgentsFactory.load_agent, agents_dict))

len(agents)

In [None]:
from prj.data_loader import DataLoader

# The arrays cached in x.npz were produced once with the snippet below
# (partition 3 only); uncomment it to regenerate the cache.
# data_args = {'ffill': False, 'include_symbol_id': False}
# loader = DataLoader(**data_args)

# X, y, w, info = loader.load_partitions(start_partition=3, end_partition=3)
# np.savez('x.npz', X=X, y=y, w=w, info=info)

# Number of leading entries (along axis 0) kept from every array, to keep the
# experiment small.
f = 20000

# np.load on an .npz archive returns an NpzFile that keeps the file open until
# it is closed; the context manager releases the handle as soon as the arrays
# have been read.
with np.load('x.npz') as data:
    X = data['X'][:f]
    y = data['y'][:f]
    w = data['w'][:f]
    # NOTE(review): `info` is truncated along axis 0 like the other arrays, but
    # the OAMP loop later reads it as info[1, i] (sample index on axis 1).
    # Confirm the layout of `info`; one of the two usages looks inconsistent.
    info = data['info'][:f]

gc.collect()

X.shape, y.shape, w.shape, info.shape

In [None]:
def squared_weighted_error_loss_fn(y_true: np.ndarray, y_pred_agents: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Weighted squared error of every agent on every sample.

    Args:
        y_true: Target values; reshaped to a single column.
        y_pred_agents: Agent predictions, broadcast against the target column
            (in this notebook: one row per sample, one column per agent).
        w: Sample weights; reshaped to a single column.

    Returns:
        ``w * (y_true - y_pred_agents) ** 2``, with the same shape as the
        broadcast of the target column and ``y_pred_agents``.
    """
    weight_col = w.reshape(-1, 1)
    target_col = y_true.reshape(-1, 1)
    residuals = target_col - y_pred_agents
    return weight_col * (residuals ** 2)

def absolute_weighted_error_loss_fn(y_true: np.ndarray, y_pred_agents: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Weighted absolute error of every agent on every sample.

    Args:
        y_true: Target values; reshaped to a single column.
        y_pred_agents: Agent predictions, broadcast against the target column
            (in this notebook: one row per sample, one column per agent).
        w: Sample weights; reshaped to a single column.

    Returns:
        ``w * |y_true - y_pred_agents|``, with the same shape as the broadcast
        of the target column and ``y_pred_agents``.
    """
    weight_col = w.reshape(-1, 1)
    target_col = y_true.reshape(-1, 1)
    abs_residuals = np.abs(target_col - y_pred_agents)
    return weight_col * abs_residuals

def compute_loss(y_true: np.ndarray, y_pred_agents: np.ndarray, w: np.ndarray, loss_fn) -> np.ndarray:
    """Validate the inputs and evaluate ``loss_fn`` on them.

    Args:
        y_true: Target values.
        y_pred_agents: Agent predictions; must have as many rows as ``y_true``.
        w: Sample weights, or ``None`` to weight every sample with 1.
        loss_fn: Callable invoked as ``loss_fn(y_true, y_pred_agents, w)``.

    Returns:
        Whatever ``loss_fn`` returns.

    Raises:
        AssertionError: If the first dimensions of ``y_true``,
            ``y_pred_agents`` and the weights do not all match.
    """
    n_samples = y_true.shape[0]
    assert n_samples == y_pred_agents.shape[0]

    # Missing weights default to uniform weights shaped like the targets.
    weights = np.ones_like(y_true) if w is None else w
    assert n_samples == weights.shape[0]

    return loss_fn(y_true, y_pred_agents, weights)


# One prediction column per agent, stacked side by side -> (n_samples, n_agents).
per_agent_columns = [agent.predict(X).reshape(-1, 1) for agent in tqdm(agents)]
agents_predictions = np.concatenate(per_agent_columns, axis=1)

# Per-sample, per-agent losses; the weighted absolute error is the loss used here.
agents_losses = compute_loss(
    y,
    agents_predictions,
    w,
    absolute_weighted_error_loss_fn,
)

agents_predictions.shape, agents_losses.shape

In [None]:
# Run the OAMP ensemble over the samples in order, one step() call per sample,
# and collect the ensemble prediction returned by each step.
# NOTE(review): weighted_r2 is imported here but is first used by the metrics()
# helper defined in a later cell.
from prj.metrics import weighted_r2
from prj.oamp.oamp import OAMP
from prj.oamp.oamp_config import ConfigOAMP

# NOTE(review): 5 and 20000 are magic numbers; 20000 matches the row cap `f`
# used when the data is loaded -- confirm whether the two are meant to be tied.
oamp_args = ConfigOAMP({'agents_weights_upd_freq': 5, 'loss_fn_window': 20000})
n_agents = len(agents)
oamp: OAMP = OAMP(n_agents, oamp_args)

ensemble_preds = []
for i in tqdm(range(agents_predictions.shape[0])):  
    # A new group starts whenever row 1 of `info` changes between consecutive
    # samples; the first sample never counts as a new group.
    # NOTE(review): this indexes `info` with the sample index on axis 1, while
    # the loading cell truncates `info` along axis 0 -- confirm the layout.
    is_new_group = info[1, i] != info[1, i - 1] if i > 0 else False
    # NOTE(review): step() receives sample i's losses together with sample i's
    # predictions; if step() uses those losses before producing its prediction
    # this is look-ahead -- confirm against the OAMP implementation.
    ensemble_prediction = oamp.step(agents_losses[i], agents_predictions[i], is_new_group=is_new_group)
    ensemble_preds.append(ensemble_prediction)


ensemble_preds = np.array(ensemble_preds)

ensemble_preds.shape

In [None]:
from prj.metrics import weighted_mae, weighted_mse, weighted_rmse

def metrics(y_true, y_pred, weights):
    """Evaluate the weighted regression metrics for one set of predictions.

    Args:
        y_true: Target values.
        y_pred: Predictions to score.
        weights: Sample weights, forwarded to every metric as ``weights=``.

    Returns:
        Dict with keys ``'r2_w'``, ``'mae_w'``, ``'mse_w'`` and ``'rmse_w'``
        (in that order) mapping to the corresponding metric value.
    """
    # (label, scorer) pairs; the order fixes the key order of the result.
    scorers = (
        ('r2_w', weighted_r2),
        ('mae_w', weighted_mae),
        ('mse_w', weighted_mse),
        ('rmse_w', weighted_rmse),
    )
    return {label: scorer(y_true, y_pred, weights=weights) for label, scorer in scorers}
    
# Score the OAMP ensemble first; its metric names become the table columns.
res = metrics(y, ensemble_preds, w)
columns = list(res.keys())

# Baselines to compare against: every individual agent, then the plain mean
# and median of the agents' predictions.
baseline_predictions = {f'agent_{i}': agents_predictions[:, i] for i in range(n_agents)}
baseline_predictions['mean'] = np.mean(agents_predictions, axis=1)
baseline_predictions['median'] = np.median(agents_predictions, axis=1)

metric_values_by_row = {'ensemble': res.values()}
for row_label, row_predictions in baseline_predictions.items():
    metric_values_by_row[row_label] = metrics(y, row_predictions, w).values()

# The frame starts with one column per predictor; transposing gives one row per
# predictor and one column per metric, sorted by weighted R2 (best first).
results = (
    pl.DataFrame(metric_values_by_row)
    .transpose(include_header=True, column_names=columns, header_name='Agent')
    .sort('r2_w', descending=True)
    .to_pandas()
    .set_index('Agent')
)
results

In [None]:
# Plot the statistics of the OAMP instance stepped in the ensemble loop
# (presumably collected during those steps -- confirm against OAMP.plot_stats).
oamp.plot_stats()