In [None]:
import tensortrade.env.default as default
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.instruments import USDT, BTC
from tensortrade.oms.wallets import Wallet, Portfolio
from tensortrade.agents import DQNAgent
from feed import Feed

import pandas as pd
import features as ft

In [None]:
def extract_features(
    quotes: pd.DataFrame,
    *,
    rsi_period: int = 20,
    macd_fast: int = 10,
    macd_slow: int = 50,
    macd_signal: int = 5,
) -> pd.DataFrame:
    """Build the observation features from top-of-book quotes.

    The mid price is the mean of the best ask (``asks[0].price``) and best
    bid (``bids[0].price``), cast to float. Three indicators from the local
    ``features`` module (imported as ``ft``) are computed on that mid price.

    Args:
        quotes: Quote snapshots; must contain the ``asks[0].price`` and
            ``bids[0].price`` columns.
        rsi_period: ``period`` argument forwarded to ``ft.rsi``.
        macd_fast: ``fast`` argument forwarded to ``ft.macd``.
        macd_slow: ``slow`` argument forwarded to ``ft.macd``.
        macd_signal: ``signal`` argument forwarded to ``ft.macd``.

    Returns:
        DataFrame with columns ``["lr", "rsi", "macd"]`` (in that order),
        built by concatenating the indicator outputs along ``axis=1``.

    Raises:
        KeyError: if either price column is missing from ``quotes``.
    """
    # Cast so the indicator functions always receive a float series.
    # NOTE(review): the raw price dtype is not visible here — confirm whether the
    # cast is needed (e.g. object/Decimal columns from the loader).
    mid = ((quotes["asks[0].price"] + quotes["bids[0].price"]) / 2).astype(float)

    # Evaluated in insertion order: lr, rsi, macd.
    indicators = {
        "lr": ft.lr(mid),
        "rsi": ft.rsi(mid, period=rsi_period),
        "macd": ft.macd(mid, fast=macd_fast, slow=macd_slow, signal=macd_signal),
    }

    # Names are assigned after concat (not via dict keys) so the result keeps
    # flat column labels even if an indicator returns a single-column frame.
    feature_frame = pd.concat(list(indicators.values()), axis=1)
    feature_frame.columns = list(indicators)
    return feature_frame

In [None]:
def transform_trades(trades: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``trades`` with a parsed ``datetime`` column.

    ``timestamp`` is interpreted as microseconds since the Unix epoch
    (``unit="us"``). The input frame is not modified; the new column is
    appended after the existing ones.

    Args:
        trades: Trade records; must contain a numeric ``timestamp`` column.

    Returns:
        A new DataFrame with all original columns plus ``datetime``.

    Raises:
        KeyError: if ``timestamp`` is missing.
    """
    # `assign` returns a new frame, so re-running this cell (or calling this
    # twice on the same raw data) never mutates the caller's DataFrame.
    return trades.assign(datetime=pd.to_datetime(trades["timestamp"], unit="us"))

In [None]:
# Run configuration: tune these here rather than editing literals below.
DATA_DIR = "../data/"
N_ROWS = 10_000        # forwarded to Feed.load as `nrows`
INITIAL_USDT = 10_000  # starting USDT wallet balance
INITIAL_BTC = 10       # starting BTC wallet balance
WINDOW_SIZE = 20       # forwarded to default.create as `window_size`

# Load quotes/trades and derive features. The identity lambda leaves one input
# unchanged; transform_trades adds `datetime`; extract_features builds features.
# NOTE(review): which input each callable applies to is inferred from positional
# order only — confirm against Feed.load's signature.
feed = Feed.load(DATA_DIR, lambda df: df, transform_trades, extract_features, nrows=N_ROWS)

# Simulated exchange whose price stream is the feed's mid price.
# NOTE(review): "USDT-BTC" presumably follows tensortrade's "<base>-<quote>"
# stream naming so the exchange can match it to the USDT/BTC instruments — confirm.
binance = Exchange("binance", service=execute_order)(feed.get_mid_price().rename("USDT-BTC"))

# Two wallets on the same exchange; USDT is passed as Portfolio's first
# argument (presumably the base instrument net worth is measured in).
portfolio = Portfolio(USDT, [
    Wallet(binance, INITIAL_USDT * USDT),
    Wallet(binance, INITIAL_BTC * BTC),
])

# Observations come from the feature stream; the renderer gets candles.
env = default.create(
    portfolio=portfolio,
    action_scheme="managed-risk",
    reward_scheme="risk-adjusted",
    feed=feed.get_features(),
    renderer_feed=feed.get_candles(),
    renderer=default.renderers.PlotlyTradingChart(),
    window_size=WINDOW_SIZE,
)

In [None]:
# Training configuration, named so the run can be tuned without editing the call.
N_TRAIN_STEPS = 10_000
N_TRAIN_EPISODES = 10
AGENT_SAVE_PATH = "../agents/"

agent = DQNAgent(env)
# Kept as the cell's last expression so Jupyter displays whatever train() returns.
agent.train(n_steps=N_TRAIN_STEPS, n_episodes=N_TRAIN_EPISODES, save_path=AGENT_SAVE_PATH)
