In [7]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

import logging
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import torch

# Notebook-wide logging: INFO and above, timestamped, written to the console.
logging.basicConfig(
    handlers=[logging.StreamHandler()],  # a stream handler is the only sink
    format='%(asctime)s - %(levelname)s - %(message)s',  # time - level - text
    level=logging.INFO,  # DEBUG records are dropped
)

%reload_ext autoreload
%autoreload 2
from core_data_prep.core_data_prep import DataPreparer
from core_data_prep.validations import Validator

from data.raw.retrievers.alpaca_markets_retriever import AlpacaMarketsRetriever
from data.raw.retrievers.stooq_retriever import StooqRetriever
from config.constants import *
from data.processed.dataset_creation import DatasetCreator
from data.processed.indicators import *
from data.processed.targets import Balanced3ClassClassification
from data.processed.normalization import ZScoreOverWindowNormalizer, ZScoreNormalizer, MinMaxNormalizer
from data.processed.dataset_pytorch import DatasetPytorch
from modeling.trainer import Trainer
from modeling.evaluate import evaluate_lgb_regressor, evaluate_torch_regressor, evaluate_torch_regressor_multiasset
from modeling.modeling_utils import print_model_parameters

from modeling.rl.environment import PortfolioEnvironment
from modeling.rl.state import State
from modeling.rl.agent import RlAgent
from modeling.rl.algorithms.policy_gradient import PolicyGradient
from modeling.rl.actors.actor import RlActor
from modeling.rl.actors.signal_predictor_actor import SignalPredictorActor
from modeling.rl.actors.high_energy_low_friction_actor import HighEnergyLowFrictionActor
from modeling.rl.actors.xsmom_actor import XSMomActor
from modeling.rl.actors.tsmom_actor import TSMomActor
from modeling.rl.actors.blsw_actor import BLSWActor
from modeling.rl.actors.allocation_propogation_actor import AllocationPropogationActor
from modeling.rl.actors.market_actor import MarketActor
from modeling.rl.trajectory_dataset import TrajectoryDataset
from modeling.rl.metrics import MetricsCalculator, DEFAULT_METRICS
from modeling.rl.reward import EstimatedReturnReward
from modeling.rl.loss import SumLogReturnLoss, ReinforceLoss
from modeling.rl.visualization.wealth_plot import plot_cumulative_wealth
from modeling.rl.visualization.position_plot import plot_position_heatmap
from config.experiments.cur_experiment import config

# Apply the experiment's cudnn.benchmark setting before any model is built.
torch.backends.cudnn.benchmark = config.train_config.cudnn_benchmark

# Prefer CUDA when PyTorch reports it as available; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device


device(type='cuda')

In [2]:
# Fetch bars together with quotes for every configured symbol over the
# experiment's [start, end] window, using the retriever named in the config.
retriever = config.data_config.retriever
retrieval_result = retriever.bars_with_quotes(
    symbol_or_symbols=config.data_config.symbol_or_symbols, 
    start=config.data_config.start, 
    end=config.data_config.end
)

In [3]:
# Build the DataPreparer from the experiment config. This same instance is
# reused further down for daily slicing and by the backtest Trader.
data_preparer = DataPreparer(
    normalizer=config.data_config.normalizer,
    missing_values_handler=config.data_config.missing_values_handler_polars,
    in_seq_len=config.data_config.in_seq_len,
    # The config stores a frequency object; DataPreparer is given its string form.
    frequency=str(config.data_config.frequency),
    validator=config.data_config.validator
)

In [4]:
# Produce the three chronological splits in a single call. Each split is
# unpacked below as an (inputs, targets, statistics) triple.
train_split, val_split, test_split = data_preparer.get_experiment_data(
    data=retrieval_result,
    start_date=config.data_config.start,
    end_date=config.data_config.end,
    features=config.data_config.features_polars,
    statistics=config.data_config.statistics,
    target=config.data_config.target,
    train_set_last_date=config.data_config.train_set_last_date,
    val_set_last_date=config.data_config.val_set_last_date,
    backend='loky',
)
X_train, y_train, statistics_train = train_split
X_val, y_val, statistics_val = val_split
X_test, y_test, statistics_test = test_split

# Display the shapes of inputs, targets and the 'next_return' statistic for
# train, then validation, then test.
tuple(
    array.shape
    for array in (
        X_train, y_train, statistics_train['next_return'],
        X_val, y_val, statistics_val['next_return'],
        X_test, y_test, statistics_test['next_return'],
    )
)

2026-01-05 12:08:51,102 - INFO - Skipping day 2024-09-03 00:00:00-04:00 because it has less than 50 assets
2026-01-05 12:08:51,229 - INFO - Skipping day 2024-09-04 00:00:00-04:00 because it has less than 50 assets
2026-01-05 12:09:34,910 - INFO - Found 280 daily slices
2026-01-05 12:09:34,936 - INFO - Trained per-asset targets


((83674, 50, 60, 16),
 (83674, 50),
 (83674, 50),
 (8993, 50, 60, 16),
 (8993, 50),
 (8993, 50),
 (16813, 50, 60, 16),
 (16813, 50),
 (16813, 50))

In [5]:
# from observability.mlflow_integration import log_experiment


# log_experiment(
#     config=config, 
#     validator_snapshots=data_preparer.validator.snapshots
#     # model=model, 
#     # history=history,
# )

In [6]:
# Pull the three per-split statistics into flat names. The generator walks the
# splits in train -> val -> test order and, within each split, the keys in
# next_return -> spread -> volatility order, matching the targets on the left.
(
    next_return_train, spread_train, volatility_train,
    next_return_val, spread_val, volatility_val,
    next_return_test, spread_test, volatility_test,
) = (
    split_statistics[statistic_name]
    for split_statistics in (statistics_train, statistics_val, statistics_test)
    for statistic_name in ('next_return', 'spread', 'volatility')
)

In [7]:
# Training split: mean absolute next return, mean spread, mean volatility.
np.abs(next_return_train).mean(), spread_train.mean(), volatility_train.mean()

(0.00068538164, 0.0002124158, 0.00086798106)

In [8]:
# Validation split: mean absolute next return, mean spread, mean volatility.
np.abs(next_return_val).mean(), spread_val.mean(), volatility_val.mean()

(0.0004901743, 0.00021241582, 0.00062929775)

In [9]:
# Test split: mean absolute next return, mean spread, mean volatility.
np.abs(next_return_test).mean(), spread_test.mean(), volatility_test.mean()

(0.00051087514, 0.00021241585, 0.0006510256)

In [10]:
# Both loaders take exactly the same DataLoader options from the training
# config, so collect them once instead of repeating the argument list.
# NOTE(review): this means the validation loader also inherits `shuffle` and
# `drop_last` from the training config -- confirm that is intended.
loader_kwargs = {
    option_name: getattr(config.train_config, option_name)
    for option_name in (
        "batch_size",
        "shuffle",
        "num_workers",
        "prefetch_factor",
        "pin_memory",
        "persistent_workers",
        "drop_last",
    )
}

train_dataset = DatasetPytorch(X_train, y_train, learning_task='regression')
train_loader = train_dataset.as_dataloader(**loader_kwargs)

val_dataset = DatasetPytorch(X_val, y_val, learning_task='regression')
val_loader = val_dataset.as_dataloader(**loader_kwargs)

In [11]:
# The model instance is owned by the experiment config; echo it so the
# architecture is shown in the cell output.
model = config.model_config.model
model

TemporalSpatial(
  (asset_embed): Embedding(50, 16)
  (asset_proj): Linear(in_features=16, out_features=256, bias=False)
  (lstm): LSTM(16, 128, num_layers=2, batch_first=True, dropout=0.2, bidirectional=True)
  (spatial_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
  )
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [12]:
# Print a per-module parameter-count table for the model.
print_model_parameters(model)

Module                                   Params
------------------------------------------------------------
[ROOT]                                   813601
asset_embed                              800
asset_proj                               4096
lstm                                     544768
spatial_attn                             263168
spatial_attn.out_proj                    65792
fc                                       257
norm                                     512


In [13]:
# Echo the full experiment configuration so the run's settings are recorded
# in the notebook output.
config

ExperimentConfig(data_config=DataConfig(retriever=<data.raw.retrievers.alpaca_markets_retriever.AlpacaMarketsRetriever object at 0x7a7e6eda1510>, symbol_or_symbols=['AAPL', 'AMD', 'BABA', 'BITU', 'C', 'CSCO', 'DAL', 'DIA', 'GLD', 'GOOG', 'IJR', 'MARA', 'MRVL', 'MU', 'NEE', 'NKE', 'NVDA', 'ON', 'PLTR', 'PYPL', 'QLD', 'QQQ', 'QQQM', 'RKLB', 'RSP', 'SMCI', 'SMH', 'SOXL', 'SOXX', 'SPXL', 'SPY', 'TMF', 'TNA', 'TQQQ', 'TSLA', 'UBER', 'UDOW', 'UPRO', 'VOO', 'WFC', 'XBI', 'XLC', 'XLE', 'XLI', 'XLK', 'XLU', 'XLV', 'XLY', 'XOM', 'XRT'], frequency=<alpaca.data.timeframe.TimeFrame object at 0x7a7e6ecc14d0>, start=datetime.datetime(2024, 9, 1, 0, 0, tzinfo=zoneinfo.ZoneInfo(key='America/New_York')), end=datetime.datetime(2025, 10, 1, 0, 0, tzinfo=zoneinfo.ZoneInfo(key='America/New_York')), train_set_last_date=datetime.datetime(2025, 7, 1, 0, 0, tzinfo=zoneinfo.ZoneInfo(key='America/New_York')), val_set_last_date=datetime.datetime(2025, 8, 1, 0, 0, tzinfo=zoneinfo.ZoneInfo(key='America/New_York')), 

In [14]:
# Wire the model and both loaders into the Trainer; every training
# hyper-parameter (loss, optimizer, scheduler, epochs, patience, device,
# metrics, checkpoint path) is read from config.train_config.
# NOTE(review): the Trainer is given config.train_config.device, while the
# notebook also defines its own `device` variable -- confirm they agree.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_fn=config.train_config.loss_fn,
    optimizer=config.train_config.optimizer,
    scheduler=config.train_config.scheduler,
    num_epochs=config.train_config.num_epochs,
    early_stopping_patience=config.train_config.early_stopping_patience,
    device=config.train_config.device,
    metrics=config.train_config.metrics,
    save_path=config.train_config.save_path
)

2026-01-05 12:11:08,707 - INFO - Model compiled with torch.compile()


In [15]:
# Run training. `model` is rebound to whatever the trainer returns, alongside
# the training history.
model, history = trainer.train()

2026-01-05 12:11:08,749 - INFO - Epoch 1/20
2026-01-05 12:11:29,223 - INFO - Train Loss: 0.2527        
2026-01-05 12:11:29,224 - INFO - Train Rmse: 0.5012
2026-01-05 12:11:29,224 - INFO - Val   Loss: 0.2355
2026-01-05 12:11:29,224 - INFO - Val   Rmse: 0.4853
2026-01-05 12:11:29,225 - INFO - New best model found! Updating best state dict.
2026-01-05 12:11:29,227 - INFO - 
2026-01-05 12:11:29,227 - INFO - Epoch 2/20
2026-01-05 12:11:43,387 - INFO - Train Loss: 0.2376         
2026-01-05 12:11:43,388 - INFO - Train Rmse: 0.4874
2026-01-05 12:11:43,388 - INFO - Val   Loss: 0.2311
2026-01-05 12:11:43,389 - INFO - Val   Rmse: 0.4807
2026-01-05 12:11:43,389 - INFO - New best model found! Updating best state dict.
2026-01-05 12:11:43,391 - INFO - 
2026-01-05 12:11:43,391 - INFO - Epoch 3/20
2026-01-05 12:11:57,536 - INFO - Train Loss: 0.2355         
2026-01-05 12:11:57,536 - INFO - Train Rmse: 0.4852
2026-01-05 12:11:57,537 - INFO - Val   Loss: 0.2313
2026-01-05 12:11:57,537 - INFO - Val   R

In [29]:
import copy  # Local import to avoid polluting global namespace unnecessarily

# Peel off wrapper modules so the checkpoint holds the bare model's parameter
# names. DataParallel exposes the wrapped model as `.module`; a
# torch.compile()-wrapped model exposes it as `._orig_mod` and would otherwise
# save every key with an "_orig_mod." prefix (which the loading cell then has
# to strip). Wrappers may be nested in either order, hence the loop.
unwrapped_model = model
while True:
    if isinstance(unwrapped_model, torch.nn.DataParallel):
        unwrapped_model = unwrapped_model.module
    elif hasattr(unwrapped_model, "_orig_mod"):
        unwrapped_model = unwrapped_model._orig_mod
    else:
        break
state_dict = unwrapped_model.state_dict()

# Keep an independent in-memory copy of the weights: state_dict() hands back
# references to the live parameters, which later training would overwrite.
best_model_state = copy.deepcopy(state_dict)

# Persist to disk under a fixed file name in the notebook's working directory.
# NOTE(review): the loading cell further down reads
# "../modeling/checkpoints/best_model.pth", not this file -- confirm which
# checkpoint is meant to be used for the backtest.
torch.save(state_dict, "best_model.pth")

In [None]:
from config.experiments.cur_experiment import config
from core_data_prep.core_data_prep import DataPreparer
from core_inference.bars_response_handler import BarsResponseHandler
from core_inference.quotes_response_handler import QuotesResponseHandler
from core_inference.trader import Trader
from core_inference.brokerage_proxies.alpaca_brokerage_proxy import AlpacaBrokerageProxy
from core_inference.brokerage_proxies.backtest_brokerage_proxy import BacktestBrokerageProxy
from core_inference.repository import Repository
from core_inference.allocators.signal_predictor_allocator import SignalPredictorAllocator


# Cut the retrieved data into per-day slices covering the period from the end
# of the validation set to the end of the experiment window.
# NOTE(review): `_get_daily_slices` is a private DataPreparer method, and
# `Constants` is not imported by name in the visible cells (presumably it
# arrives via `from config.constants import *`) -- confirm both.
daily_slices = data_preparer._get_daily_slices(
    retrieval_result,
    start_date=config.data_config.val_set_last_date,
    end_date=config.data_config.end,
    # Rows per slice: one trading day of minutes plus the model's input
    # sequence length, the normalizer window and a further 30 rows. The same
    # "in_seq_len + window + 30" depth is passed to Repository in the backtest
    # cell. NOTE(review): the meaning of the extra 30 is not documented here.
    slice_length=Constants.Data.TRADING_DAY_LENGTH_MINUTES + config.data_config.in_seq_len + config.data_config.normalizer.get_window() + 30,
    verbose=True
)

2026-01-05 14:27:47,879 - INFO - Last timestamp counts across all slices and stocks:
2026-01-05 14:27:47,880 - INFO -   2025-08-01 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,881 - INFO -   2025-08-04 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,882 - INFO -   2025-08-05 15:59:00-04:00: 1 occurrences
2026-01-05 14:27:47,884 - INFO -   2025-08-05 16:00:00-04:00: 49 occurrences
2026-01-05 14:27:47,885 - INFO -   2025-08-06 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,886 - INFO -   2025-08-07 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,887 - INFO -   2025-08-08 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,888 - INFO -   2025-08-11 15:59:00-04:00: 1 occurrences
2026-01-05 14:27:47,889 - INFO -   2025-08-11 16:00:00-04:00: 49 occurrences
2026-01-05 14:27:47,893 - INFO -   2025-08-12 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,894 - INFO -   2025-08-13 16:00:00-04:00: 50 occurrences
2026-01-05 14:27:47,895 - INFO -   2025-08-14 16:00:00-04:00: 50 occur

In [63]:
# Sanity check: across every day and every asset there must be exactly one
# distinct slice length.
assert len({
    len(asset_frame)
    for day_slices in daily_slices
    for asset_frame in day_slices.values()
}) == 1

# (number of daily slices, assets in the first day, rows in AAPL's first slice)
len(daily_slices), len(daily_slices[0]), len(daily_slices[0]['AAPL'])

(43, 50, 541)

In [64]:
# Load the saved weights onto the active device. `weights_only=True` restricts
# unpickling to tensors and plain containers, which is all a state dict needs;
# it also avoids executing arbitrary pickled code from the checkpoint file and
# silences the torch.load FutureWarning about the default changing.
state_dict = torch.load(
    "../modeling/checkpoints/best_model.pth",
    map_location=device,
    weights_only=True,
)

# A checkpoint written from a torch.compile()-wrapped model has every key
# prefixed with "_orig_mod.". Strip that prefix only when it leads the key
# (str.replace would also rewrite an occurrence in the middle of a name).
new_state_dict = {
    key.removeprefix("_orig_mod."): value
    for key, value in state_dict.items()
}

model.load_state_dict(new_state_dict)

# The model contains dropout layers, so switch to inference mode before it is
# used for backtesting; otherwise predictions would be stochastic.
# NOTE(review): a no-op if the allocator or trainer already does this.
model.eval()

# Long-only allocator driven by the model's predicted signal.
allocator = SignalPredictorAllocator(
    signal_predictor=model,
    trade_asset_count=config.rl_config.trade_asset_count,
    allow_short_positions=False
).to(device)

  state_dict = torch.load(


In [None]:
# Minute-by-minute backtest, one day at a time. Cash starts at 1.0 and the
# closing balance of each day becomes the opening balance of the next.
cur_cash = 1.
for day_i, daily_slice in enumerate(daily_slices):
    trading_minutes = Constants.Data.TRADING_DAY_LENGTH_MINUTES
    history_depth = config.data_config.in_seq_len + config.data_config.normalizer.get_window() + 30

    # Seed data: every row of each asset's slice except the last
    # `trading_minutes` rows.
    cur_day_initialization = {
        symbol: frame.iloc[:-trading_minutes].copy()
        for symbol, frame in daily_slice.items()
    }
    # Replay data: one {symbol: row} mapping per remaining row, oldest first
    # (negative offsets -trading_minutes .. -1).
    remaining_updates = [
        {symbol: frame.iloc[offset] for symbol, frame in daily_slice.items()}
        for offset in range(-trading_minutes, 0)
    ]

    # Fresh repository / broker / trader for each day.
    repository = Repository(
        trading_symbols=config.data_config.symbol_or_symbols,
        required_history_depth=history_depth,
        bars_and_quotes=cur_day_initialization,
    )
    backtest_proxy = BacktestBrokerageProxy(repository, config.rl_config.spread_multiplier, cur_cash)
    trader = Trader(
        data_preparer=data_preparer,
        features=config.data_config.features_polars,
        brokerage_proxy=backtest_proxy,
        repository=repository,
        portfolio_allocator=allocator,
    )

    for update_i, remaining_update in enumerate(remaining_updates):
        # Feed this minute's bar for every symbol before trading on it.
        for stock_name, stock_data_series in remaining_update.items():
            repository.add_bar({**stock_data_series.to_dict(), 'symbol': stock_name})

        trader.perform_trading_cycle()

        logging.info(f"Day {day_i} update {update_i} ended with cash {backtest_proxy.get_cash_balance()}")

    # Flatten the book at the end of the day and carry the cash forward.
    backtest_proxy.close_all_positions()
    cur_cash = backtest_proxy.get_cash_balance()
    logging.info(f"Day {day_i} ended with cash {cur_cash}")

2026-01-06 17:17:38,509 - INFO - Starting trading cycle...
2026-01-06 17:17:38,515 - INFO - Transforming data for inference...
2026-01-06 17:17:39,352 - INFO - Running portfolio allocator...
2026-01-06 17:17:39,407 - INFO - Calculating position difference...
2026-01-06 17:17:39,409 - INFO - Starting order execution...
2026-01-06 17:17:39,412 - INFO - Order execution completed!
2026-01-06 17:17:39,415 - INFO - Day 0 update 380 ended with cash 0.5000002086162567
2026-01-06 17:17:39,535 - INFO - Starting trading cycle...
2026-01-06 17:17:39,540 - INFO - Transforming data for inference...
2026-01-06 17:17:40,304 - INFO - Running portfolio allocator...
2026-01-06 17:17:40,346 - INFO - Calculating position difference...
2026-01-06 17:17:40,347 - INFO - Starting order execution...
2026-01-06 17:17:40,351 - INFO - Order execution completed!
2026-01-06 17:17:40,354 - INFO - Day 0 update 381 ended with cash 0.49950759521105503
2026-01-06 17:17:40,474 - INFO - Starting trading cycle...
2026-01-06

KeyboardInterrupt: 

ZeroDivisionError: division by zero

In [58]:
# Number of states recorded by the most recently created Trader (the one left
# over from the last, possibly interrupted, day of the backtest loop).
len(trader.states_history)

89

In [66]:
# Inspect the last recorded state of that Trader.
trader.states_history[-1]

State(desired_position={'AAPL': 0.0, 'AMD': 0.0, 'BABA': 0.0, 'BITU': 0.0, 'C': 0.0, 'CSCO': 0.0, 'DAL': 0.0, 'DIA': 0.0, 'GLD': 0.0, 'GOOG': 0.0, 'IJR': 0.0, 'MARA': 0.0, 'MRVL': 0.0, 'MU': 0.0, 'NEE': 0.0, 'NKE': 0.0, 'NVDA': 0.0, 'ON': 0.0, 'PLTR': 0.0, 'PYPL': 0.0, 'QLD': 0.0, 'QQQ': 0.0, 'QQQM': 0.0, 'RKLB': 0.9999999, 'RSP': 0.0, 'SMCI': 0.0, 'SMH': 0.0, 'SOXL': 0.0, 'SOXX': 0.0, 'SPXL': 0.0, 'SPY': 0.0, 'TMF': 0.0, 'TNA': 0.0, 'TQQQ': 0.0, 'TSLA': 0.0, 'UBER': 0.0, 'UDOW': 0.0, 'UPRO': 0.0, 'VOO': 0.0, 'WFC': 0.0, 'XBI': 0.0, 'XLC': 0.0, 'XLE': 0.0, 'XLI': 0.0, 'XLK': 0.0, 'XLU': 0.0, 'XLV': 0.0, 'XLY': 0.0, 'XOM': 0.0, 'XRT': 0.0}, position={'AAPL': 0.0, 'AMD': 0.0, 'BABA': 0.0, 'BITU': 0.0, 'C': 0.0, 'CSCO': 0.0, 'DAL': 0.0, 'DIA': 0.0, 'GLD': 0.0, 'GOOG': 0.0, 'IJR': 0.0, 'MARA': 0.0, 'MRVL': 0.0, 'MU': 0.0, 'NEE': 0.0, 'NKE': 0.0, 'NVDA': 0.0, 'ON': 0.0, 'PLTR': 0.0, 'PYPL': 0.0, 'QLD': 0.0, 'QQQ': 0.0, 'QQQM': 0.0, 'RKLB': 1.0, 'RSP': 0.0, 'SMCI': 0.0, 'SMH': 0.0, 'SOXL': 0