## Install TensorTrade

In [1]:
## Setup Data Fetching

In [2]:
import inspect
import sys
import os
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, "{}".format(parentdir))

In [3]:
import ssl
import pandas as pd

from tensortrade.utils import CryptoDataDownload

ssl._create_default_https_context = ssl._create_unverified_context # Only used if pandas gives a SSLError

cdd = CryptoDataDownload()

In [5]:
data = pd.concat([
    cdd.fetch("Coinbase", "USD", "BTC", "1h")
    #cdd.fetch("Coinbase", "USD", "ETH", "1h").add_prefix("ETH:")
], axis=1)
#data = data.drop(["ETH:date"], axis=1)
#data = data.rename({"BTC:date": "date"}, axis=1)

In [6]:
data.head()

Unnamed: 0,date,unix timestamp,open,high,low,close,volume
0,2017-07-01 11:00:00,1498907000.0,2505.56,2513.38,2495.12,2509.17,287000.32
1,2017-07-01 12:00:00,1498910000.0,2509.17,2512.87,2484.99,2488.43,393142.5
2,2017-07-01 13:00:00,1498914000.0,2488.43,2488.43,2454.4,2454.43,693254.01
3,2017-07-01 14:00:00,1498918000.0,2454.43,2473.93,2450.83,2459.35,712864.8
4,2017-07-01 15:00:00,1498921000.0,2459.35,2475.0,2450.0,2467.83,682105.41


## Create features with the data module

In [8]:
import ta
dataset = ta.add_all_ta_features(data, 'open', 'high', 'low', 'close', 'volume', fillna=True)
dataset.head()

Unnamed: 0,date,unix timestamp,open,high,low,close,volume,volume_adi,volume_obv,volume_cmf,...,momentum_uo,momentum_stoch,momentum_stoch_signal,momentum_wr,momentum_ao,momentum_kama,momentum_roc,others_dr,others_dlr,others_cr
0,2017-07-01 11:00:00,1498907000.0,2505.56,2513.38,2495.12,2509.17,287000.32,154659.5,287000.32,0.538883,...,0.0,76.94414,76.94414,-23.05586,0.0,2509.17,0.0,-67.089655,0.0,0.0
1,2017-07-01 12:00:00,1498910000.0,2509.17,2512.87,2484.99,2488.43,393142.5,-141466.4,-106142.18,-0.207995,...,7.45557,12.116943,44.530541,-87.883057,0.0,2499.529881,0.0,-0.826568,-0.830003,-0.826568
2,2017-07-01 13:00:00,1498914000.0,2488.43,2488.43,2454.4,2454.43,693254.01,-833498.1,-799396.19,-0.606888,...,4.328302,0.050865,29.703982,-99.949135,0.0,2478.933617,0.0,-1.366323,-1.375743,-2.181598
3,2017-07-01 14:00:00,1498918000.0,2454.43,2473.93,2450.83,2459.35,712864.8,-1020509.0,-86531.39,-0.489157,...,11.610342,13.621103,8.596303,-86.378897,0.0,2470.163427,0.0,0.200454,0.200253,-1.985517
4,2017-07-01 15:00:00,1498921000.0,2459.35,2475.0,2450.0,2467.83,682105.41,-729659.7,595574.02,-0.26357,...,23.247837,28.131903,13.934624,-71.868097,0.0,2469.130885,0.0,0.344807,0.344213,-1.647557


In [9]:
from tensortrade.data import Node, Module, DataFeed, Stream, Select


def rsi(price: Node, period: float):
    r = price.diff()
    upside = r.clamp_min(0).abs()
    downside = r.clamp_max(0).abs()
    rs = upside.ewm(alpha=1 / period).mean() / downside.ewm(alpha=1 / period).mean()
    return 100*(1 - (1 + rs) ** -1)


def macd(price: Node, fast: float, slow: float, signal: float) -> Node:
    fm = price.ewm(span=fast, adjust=False).mean()
    sm = price.ewm(span=slow, adjust=False).mean()
    md = fm - sm
    signal = md - md.ewm(span=signal, adjust=False).mean()
    return signal


features = []
for c in data.columns[1:]:
    s = Stream(list(data[c])).rename(data[c].name)
    features += [s]

btc_close = Select("close")(*features)
#eth_close = Select("ETH:close")(*features)

features += [
    rsi(btc_close, period=20).rename("rsi"),
    macd(btc_close, fast=10, slow=50, signal=5).rename("macd")
    #rsi(eth_close, period=20).rename("ETH:rsi"),
    #macd(eth_close, fast=10, slow=50, signal=5).rename("ETH:macd")
]
        


In [12]:
# Get Features
features = []
# exclude date column
for c in dataset.columns[1:]:
    s = Stream(list(dataset[c])).rename(dataset[c].name)
    features += [s]
feed = DataFeed(features)
feed.compile()

In [15]:
feed.next()

{'unix timestamp': 1498914000.0,
 'open': 2488.43,
 'high': 2488.43,
 'low': 2454.4,
 'close': 2454.43,
 'volume': 693254.01,
 'volume_adi': -833498.1482762388,
 'volume_obv': -799396.19,
 'volume_cmf': -0.6068880676506576,
 'volume_fi': -10356184.14857151,
 'momentum_mfi': 0.0,
 'volume_em': -135063.8346830462,
 'volume_sma_em': -86395.51005917703,
 'volume_vpt': -12721.682056263691,
 'volume_nvi': 1000.0,
 'volume_vwap': 2482.6358501672094,
 'volatility_atr': 0.0,
 'volatility_bbm': 2484.01,
 'volatility_bbh': 2529.1400107098007,
 'volatility_bbl': 2438.8799892902,
 'volatility_bbw': 3.633641628640822,
 'volatility_bbp': 0.17228015754074724,
 'volatility_bbhi': 0.0,
 'volatility_bbli': 0.0,
 'volatility_kcc': 2489.0244444444447,
 'volatility_kch': 2515.7477777777776,
 'volatility_kcl': 2462.301111111111,
 'volatility_kcw': 2.14729376346467,
 'volatility_kcp': -0.14727038376782592,
 'volatility_kchi': 0.0,
 'volatility_kcli': 1.0,
 'volatility_dcl': 2.374386576543573,
 'volatility_dch

## Setup Trading Environment

In [17]:
from tensortrade.exchanges import Exchange
from tensortrade.exchanges.services.execution.simulated import execute_order
from tensortrade.data import Stream, DataFeed, Module
from tensortrade.instruments import USD, BTC, ETH
from tensortrade.wallets import Wallet, Portfolio
from tensortrade.environments import TradingEnvironment


coinbase = Exchange("coinbase", service=execute_order)(
    Stream(list(data["close"])).rename("USD-BTC"),
    #Stream(list(data["ETH:close"])).rename("USD-ETH")
)

portfolio = Portfolio(USD, [
    Wallet(coinbase, 10000 * USD),
    Wallet(coinbase, 10 * BTC),
    #Wallet(coinbase, 5 * ETH),
])

## Example Data Feed Observation

Even though this observation contains data from the internal data feed, since `use_internal=False` this data will not be provided as input to the observation history. The data that will be added to observation history of the environment will strictly be the nodes that have been included into the data feed that has been provided as a parameter to the trading environment.

In [18]:
from tensortrade.environments.render import PlotlyTradingChart
from tensortrade.environments.render import FileLogger

chart_renderer = PlotlyTradingChart(
    display=True,  # show the chart on screen (default)
    height=800,  # affects both displayed and saved file height. None for 100% height.
    save_format='html',  # save the chart to an HTML file
    auto_open_html=True,  # open the saved HTML chart in a new browser tab
)

file_logger = FileLogger(
    filename='example.log',  # omit or None for automatic file name
    path='training_logs'  # create a new directory if doesn't exist, None for no directory
)

env = TradingEnvironment(
    feed=feed,
    portfolio=portfolio,
    use_internal=False,
    action_scheme="managed-risk",
    reward_scheme="risk-adjusted",
    window_size=20
)

env.feed.next()

{'unix timestamp': 1498906800.0,
 'open': 2505.56,
 'high': 2513.38,
 'low': 2495.12,
 'close': 2509.17,
 'volume': 287000.32,
 'volume_adi': 154659.5371741516,
 'volume_obv': 287000.32,
 'volume_cmf': 0.5388828039430464,
 'volume_fi': 0.0,
 'momentum_mfi': 50.0,
 'volume_em': 0.0,
 'volume_sma_em': 0.0,
 'volume_vpt': -193179.4152008041,
 'volume_nvi': 1000.0,
 'volume_vwap': 2505.89,
 'volatility_atr': 0.0,
 'volatility_bbm': 2509.17,
 'volatility_bbh': 2509.17,
 'volatility_bbl': 2509.17,
 'volatility_bbw': 0.0,
 'volatility_bbp': 0.0,
 'volatility_bbhi': 0.0,
 'volatility_bbli': 0.0,
 'volatility_kcc': 2505.89,
 'volatility_kch': 2524.15,
 'volatility_kcl': 2487.6299999999997,
 'volatility_kcw': 1.457366444656407,
 'volatility_kcp': 0.5898138006571786,
 'volatility_kchi': 0.0,
 'volatility_kcli': 0.0,
 'volatility_dcl': 0.7277306838516409,
 'volatility_dch': 0.7694414019715232,
 'trend_macd': 0.0,
 'trend_macd_signal': 0.0,
 'trend_macd_diff': 0.0,
 'trend_sma_fast': 2509.17,
 'tre

## Setup and Train DQN Agent

In [None]:
import gym

from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common.policies import MlpPolicy, MlpLnLstmPolicy
from stable_baselines import DQN, PPO2, A2C

In [None]:
# DQN-Model
from stable_baselines.deepq.policies import MlpPolicy
agent = DQN(MlpPolicy, env, verbose=1, tensorboard_log=os.path.join(currentdir,"tf_board_log","DQN"))
agent.learn(total_timesteps=25000)
agent.save(save_path=os.path.join(currentdir, "agents","DQN_MlpPolicys.zip"))

In [None]:
# PPO2-Model
agent = PPO2(MlpPolicy, env, verbose=1)
agent.learn(total_timesteps=25000)
agent.save(save_path=os.path.join(currentdir, "agents","PPO2_MlpPolicy.zip"))

In [None]:
# A2C-Model
agent = A2C(MlpPolicy, env, verbose=1)
agent.learn(total_timesteps=25000)
agent.save(save_path=os.path.join(currentdir, "agents","A2C_MlpPolicy.zip"))

In [None]:
from stable_baselines.common.policies import MlpLnLstmPolicy
from stable_baselines import PPO2

model = PPO2
policy = MlpLnLstmPolicy
params = { "learning_rate": 1e-5 }

agent = model(policy, env, learning_rate=1e-5, nminibatches=1)
agent.learn(total_timesteps=25000)
agent.save(save_path = os.path.join(currentdir, "agents","MlpLnLstmPolicy.zip"))

## Plot Performance

In [None]:
%matplotlib inline

portfolio.performance.plot()

In [None]:
portfolio.performance.net_worth.plot()

## Setup and Train Parallel DQN Agent

In [None]:
from tensortrade.agents import ParallelDQNAgent

def create_env():
    
    env = TradingEnvironment(
        feed=feed,
        portfolio=portfolio,
        action_scheme='managed-risk',
        reward_scheme='risk-adjusted',
        window_size=20
    )
    
    return env

agent = ParallelDQNAgent(create_env)

agent.train(n_envs=4, n_steps=200, save_path=currentdir + "/examples/agents/ParallelDQNAgent.zip")

## Test Agent


### Environment with Multiple Renderers
Create PlotlyTradingChart and FileLogger renderers. Configuring renderers is optional as they can be used with their default settings.

In [61]:
# test agent
# multiprocess environment
agent = PPO2.load(load_path=os.path.join(currentdir, "agents","MlpPolicy.zip"), env=env)
done = False
obs = env.reset()
count = 0
while not done:
    action, _states = agent.predict(obs)
    obs, reward, done, info = env.step(action)
    count += 1
    done = count > 1000

AssertionError: Error: the environment passed is not a vectorized environment, however PPO2 requires it