In [1]:
import logging
import os
import sys
from datetime import date, datetime, timezone

import pandas as pd
import torch

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

logging.basicConfig(
    level=logging.INFO,  # Set the logging level
    format='%(asctime)s - %(levelname)s - %(message)s',  # Format for the log messages
    handlers=[
        logging.StreamHandler()  # Log to the console
    ]
)

%reload_ext autoreload
%autoreload 2
from data.raw.retrievers.alpaca_markets_retriever import AlpacaMarketsRetriever
from config.constants import *
from data.processed.dataset_creation import DatasetCreator
from data.processed.indicators import *
from data.processed.targets import Balanced3ClassClassification
from data.processed.normalization import ZScoreOverWindowNormalizer, ZScoreNormalizer, MinMaxNormalizer
from data.processed.dataset_pytorch import DatasetPytorch
from modeling.trainer import Trainer
from modeling.evaluate import evaluate_lgb_regressor, evaluate_torch_regressor, evaluate_torch_regressor_multiasset
from observability.mlflow_integration import log_experiment

from modeling.rl.environment import PortfolioEnvironment
from modeling.rl.state import State
from modeling.rl.agent import RlAgent
from modeling.rl.algorithms.policy_gradient import PolicyGradient
from modeling.rl.actors.actor import RlActor

from config.experiments.cur_experiment import config

# Seed torch's RNGs (CPU and every CUDA device) so torch-based randomness, such as any
# freshly initialised layers or sampled actions, is repeatable under Restart & Run All.
# TODO: read this from the experiment config if it gains a seed field.
SEED = 42
torch.manual_seed(SEED)

# cuDNN autotuning trades determinism for speed: benchmark mode may select different,
# non-deterministic kernels between runs, so results can still vary slightly when enabled.
torch.backends.cudnn.benchmark = config.train_config.cudnn_benchmark


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch bars and quotes for the configured symbols over the configured start/end window.
# NOTE(review): the exact meaning of download_from_gdrive=False lives in
# AlpacaMarketsRetriever; presumably it skips a Google Drive cached copy — confirm.
retriever = AlpacaMarketsRetriever(download_from_gdrive=False)
retrieval_result = retriever.bars_with_quotes(
    symbol_or_symbols=config.data_config.symbol_or_symbols,
    start=config.data_config.start,
    end=config.data_config.end,
)

In [3]:
def get_trading_days(retrieval_result) -> list[date]:
    """Return the sorted, de-duplicated calendar days present in *retrieval_result*.

    Parameters
    ----------
    retrieval_result
        Mapping whose values are DataFrames with a ``"date"`` column in any format
        ``pd.to_datetime`` accepts. Presumably keyed by symbol — TODO confirm against
        the retriever. Keys are ignored here.

    Returns
    -------
    list[date]
        Every distinct day appearing in at least one frame, in ascending order.
        An empty mapping yields an empty list. For tz-aware timestamps the day is
        taken in the column's own timezone (``Series.dt.date`` semantics).
    """
    days: set[date] = set()
    for df in retrieval_result.values():
        days.update(pd.to_datetime(df["date"]).dt.date.unique())
    return sorted(days)

trading_days = get_trading_days(retrieval_result)
# Fail fast on an empty/failed download instead of surfacing later inside
# PortfolioEnvironment, which receives trading_days.
assert trading_days, "retrieval_result contains no dates — check symbols/start/end in the data config"
len(trading_days)

270

In [4]:
# Turn the raw retrieval result into train/test numpy arrays (features, targets,
# next-period returns, spreads), all parameterised by the experiment's data config.
dataset_creator = DatasetCreator(
    features=config.data_config.features,
    target=config.data_config.target,
    normalizer=config.data_config.normalizer,
    missing_values_handler=config.data_config.missing_values_handler,
    train_set_last_date=config.data_config.train_set_last_date,
    in_seq_len=config.data_config.in_seq_len,
    multi_asset_prediction=config.data_config.multi_asset_prediction,
)

(
    X_train, y_train, next_return_train, spread_train,
    X_test, y_test, next_return_test, spread_test,
) = dataset_creator.create_dataset_numpy(retrieval_result)

# Show the shape of every split, in the same order as the unpacking above.
tuple(arr.shape for arr in (
    X_train, y_train, next_return_train, spread_train,
    X_test, y_test, next_return_test, spread_test,
))

2025-07-18 15:21:31,793 - INFO - Processing AAPL …
2025-07-18 15:21:32,700 - INFO - Imputing 496 NaN rows out of 97359 with forward fill..
2025-07-18 15:21:33,271 - INFO - Imputing 39 NaN rows with 0.5 sentinel value
2025-07-18 15:21:33,304 - INFO - Processing AMD …
2025-07-18 15:21:33,920 - INFO - Imputing 214 NaN rows out of 97359 with forward fill..
2025-07-18 15:21:34,516 - INFO - Imputing 39 NaN rows with 0.5 sentinel value
2025-07-18 15:21:34,547 - INFO - Processing BABA …
2025-07-18 15:21:35,173 - INFO - Imputing 874 NaN rows out of 97359 with forward fill..
2025-07-18 15:21:35,756 - INFO - Imputing 39 NaN rows with 0.5 sentinel value
2025-07-18 15:21:35,790 - INFO - Processing BITU …
2025-07-18 15:21:36,391 - INFO - Imputing 6493 NaN rows out of 97359 with forward fill..
2025-07-18 15:21:36,988 - INFO - Imputing 39 NaN rows with 0.5 sentinel value
2025-07-18 15:21:37,020 - INFO - Processing CSCO …
2025-07-18 15:21:37,601 - INFO - Imputing 3929 NaN rows out of 97359 with forward

((79909, 50, 120, 15),
 (79909, 50),
 (79909, 50),
 (79909, 50),
 (7251, 50, 120, 15),
 (7251, 50),
 (7251, 50),
 (7251, 50))

In [5]:
# Environment over both splits plus the trading calendar.
# NOTE(review): transaction_fee=0.0 means training ignores fees — presumably a
# frictionless baseline; confirm this is intended before comparing to live costs.
env = PortfolioEnvironment(
    X_train, y_train, next_return_train, spread_train,
    X_test, y_test, next_return_test, spread_test,
    trading_days,
    transaction_fee=0.0,
)

In [16]:
# Load the pretrained signal predictor onto the GPU when one is available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
signal_predictor = config.model_config.model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint loaded on a
# CPU-only machine). weights_only=True restricts unpickling to tensors and plain
# containers — all a state_dict needs — and avoids torch's FutureWarning on the default.
signal_predictor.load_state_dict(
    torch.load('best_model.pth', map_location=device, weights_only=True)
)
signal_predictor

  signal_predictor.load_state_dict(torch.load('best_model.pth'))


TemporalSpatial(
  (asset_embed): Embedding(50, 32)
  (asset_proj): Linear(in_features=32, out_features=512, bias=False)
  (lstm): LSTM(15, 256, num_layers=4, batch_first=True, dropout=0.2, bidirectional=True)
  (spatial_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (fc): Linear(in_features=512, out_features=1, bias=True)
  (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [28]:
# Build the RL actor around the signal predictor and place it on the predictor's device.
# `.to()` needs an explicit target: called with no arguments it moves nothing, which would
# leave any layers RlActor adds of its own (not visible here) on the CPU.
actor = RlActor(
    signal_predictor,
    n_assets=len(config.data_config.symbol_or_symbols),
).to(next(signal_predictor.parameters()).device)

In [31]:
# Build an RlAgent from the actor and the environment, then a PolicyGradient trainer
# around that agent. NOTE(review): the division of responsibilities between RlAgent and
# PolicyGradient is inferred from their names — confirm against modeling.rl.
rl_agent = RlAgent(actor, env)
policy_gradient = PolicyGradient(rl_agent)

In [32]:
# Train with epochs=1.
# NOTE(review): the saved output of this cell ends in KeyboardInterrupt, so the run was
# stopped early. Re-run to completion (or add checkpointing) before relying on the actor.
policy_gradient.train(epochs=1)

2025-07-18 16:18:40,158 - INFO - loss: 2.2913924112799577e-05, rewards_t: -2.288945142936427e-05
2025-07-18 16:18:55,436 - INFO - loss: 5.428933036455419e-06, rewards_t: -5.427430096460739e-06
2025-07-18 16:19:11,699 - INFO - loss: 1.3680518122782814e-06, rewards_t: -1.3631030242322595e-06
2025-07-18 16:19:27,887 - INFO - loss: 1.7615773685975e-05, rewards_t: -1.759733095241245e-05
2025-07-18 16:19:44,070 - INFO - loss: 8.950911251304206e-06, rewards_t: -8.935324331105221e-06
2025-07-18 16:20:00,306 - INFO - loss: -7.663102587684989e-05, rewards_t: 7.772999379085377e-05
2025-07-18 16:20:16,875 - INFO - loss: 1.603580244591285e-06, rewards_t: -1.5848758039282984e-06
2025-07-18 16:20:34,511 - INFO - loss: -1.4868192010908388e-05, rewards_t: 1.4914343410055153e-05
2025-07-18 16:20:53,203 - INFO - loss: 1.4591849321732298e-05, rewards_t: -1.456500103813596e-05
2025-07-18 16:21:09,417 - INFO - loss: -1.7759693946572952e-05, rewards_t: 1.7798522094381042e-05


KeyboardInterrupt: 