In [2]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)
import feature_manager as fma
import rl.env_simple_crypto_trade as env
import importlib
import rl.models as rla
import config as cf
import tr_utils

In [3]:
importlib.reload(fma)

fm = fma.FeatureManager()
fm.import_data(symbol="BTCUSDT",timeframes=["1d","1w","1mo"])

fm.build_features(
    lags = [1,1,1],
    features=cf.FOR_1D_INDICATORS
)

Imported data 1d from ../data/BTCUSDT-1d.csv with 1963 rows
Imported data 1w from ../data/BTCUSDT-1w.csv with 247 rows
Imported data 1mo from ../data/BTCUSDT-1mo.csv with 65 rows
Calculating external features ...
sma_3_10, sma_7_30, rsi7, rsi14, rsi30, cci7, cci14, cci30, dx7, dx14, dx30, hashrate, fed_rate, gold, nasdaq, sp500, google_trend, 
sma_3_10, sma_7_30, rsi7, rsi14, rsi30, cci7, cci14, cci30, dx7, dx14, dx30, 
sma_3_10, rsi7, rsi14, cci7, cci14, dx7, dx14, 

Normalizing features with MinMax: sma_3_10_level0_lag_1, sma_7_30_level0_lag_1, rsi7_level0_lag_1, rsi14_level0_lag_1, rsi30_level0_lag_1, cci7_level0_lag_1, cci14_level0_lag_1, cci30_level0_lag_1, dx7_level0_lag_1, dx14_level0_lag_1, dx30_level0_lag_1, hashrate_level0_lag_1, fed_rate_level0_lag_1, gold_level0_lag_1, nasdaq_level0_lag_1, sp500_level0_lag_1, google_trend_level0_lag_1, sma_3_10_level1_lag_1, sma_7_30_level1_lag_1, rsi7_level1_lag_1, rsi14_level1_lag_1, rsi30_level1_lag_1, cci7_level1_lag_1, cci14_level1_lag

In [4]:
importlib.reload(env)
importlib.reload(cf)

env_kwargs = cf.TRADE_ENV_PARAMETER
state_space = 10 + len(fm.cols)

full_env = env.CryptoTradingEnv(
    trade_timeframe="1d",
    df = fm.df, 
    state_space=state_space,
    indicators=fm.cols,
    **env_kwargs)

In [5]:
importlib.reload(rla)
agent = rla.DRLTradeAgent(env=full_env)

In [6]:
catalog_name = "dqn_4layers_relu"

timestep = 960_000

selected_model = agent.load_model_from_checkpoint(
    "dqn",f"{catalog_name}/rl_model_{timestep}_steps")

Successfully load model from ../saved_models/checkpoint/dqn_4layers_relu/rl_model_960000_steps


In [7]:
selected_model.replay_buffer_class

stable_baselines3.common.buffers.ReplayBuffer

In [9]:
replay_buffer = selected_model.replay_buffer

In [26]:
selected_model.q_net_target

QNetwork(
  (features_extractor): FlattenExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (q_net): Sequential(
    (0): Linear(in_features=45, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=128, bias=True)
    (7): ReLU()
    (8): Linear(in_features=128, out_features=3, bias=True)
  )
)

In [18]:
selected_model.save_replay_buffer()

array([[0.],
       [0.],
       [0.],
       ...,
       [0.],
       [0.],
       [0.]], dtype=float32)

In [None]:
importlib.reload(rla)
agent = rla.DRLTradeAgent(env=full_env)

policy_kwargs = dict(
    net_arch = [256,128,256,128]
)

alpha_0 = 1e-6
alpha_end = 1e-9

def learning_rate_f(process_remaining):
    initial = alpha_0
    final = alpha_end
    interval = initial-final
    return final+interval*process_remaining

MODEL_PARAMS = {
    "learning_rate": learning_rate_f,
    "buffer_size": 100_000,  
    "learning_starts": 50_000,
    "batch_size": 64,
    "tau": 1.0,
    "gamma": 0.999,
    "train_freq": 4,
    "target_update_interval": 10_000,
    "exploration_fraction": 0.025,
    "exploration_initial_eps": 1.0,
    "exploration_final_eps": 0.05
}

catalog_name = tr_utils.get_name_with_kwargs(
    name="dqn",
    kwargs=MODEL_PARAMS,
    excludes=["learning_rate"]
)

dqn_model = agent.get_model(
    model_name="dqn",
    model_kwargs = MODEL_PARAMS,
    tensorboard_log=catalog_name,
    seed = 100
)

In [1]:
full_env

NameError: name 'full_env' is not defined

In [None]:
catalog_name

In [None]:
selected_model = dqn_model

selected_model = agent.train_model(
    model = selected_model,
    total_timesteps = 5_000_000,
    checkpoint = True,
    catalog_name = catalog_name,
    save_frequency = 20_000,
    progress_bar = True
)