## Install TensorTrade

In [1]:
# !python3 -m pip install git+https://github.com/nsarang/tensortrade.git

## Setup

In [87]:
# Put these at the top of every notebook, to get automatic reloading and inline plotting
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [88]:
import asyncio
import ccxt
# import ccxt.async_support as ccxt

apiKey = "jxlzo1mxQ1PDckz4aYgH2WDgFxpJjBu47r3OB4vyLyZkEeyJ4xjOM6m32mvsIgmu"
secret = "EffQgaLRPl52q0YEpVKcIHDeqyrFBQWm2K1Er99egbQ1c75X7fDREg4UtzhSaCJM"

exchange = ccxt.binance({
        "apiKey": apiKey,
        "secret": secret,
        "enableRateLimit": True,
        # 'options': {
        #     'defaultType': 'spot', // spot, future, margin
        # },
    }
)

In [89]:
import re
import sys
import time
import pandas as pd
import numpy as np
from datetime import datetime, timedelta, timezone
from tenacity import retry, retry_if_exception_type, stop_after_attempt
import pytz


@retry(retry=retry_if_exception_type(ccxt.NetworkError), stop=stop_after_attempt(3))
def get_historical_data(
    symbol, exchange, timeframe, start_date=None, limit=500, max_per_page=500
):
    """Get historical OHLCV for a symbol pair

    Decorators:
        retry

    Args:
        symbol (str): Contains the symbol pair to operate on i.e. BURST/BTC
        exchange (str): Contains the exchange to fetch the historical data from.
        timeframe (str): A string specifying the ccxt time unit i.e. 5m or 1d.
        start_date (int, optional): Timestamp in milliseconds.
        max_periods (int, optional): Defaults to 100. Maximum number of time periods
          back to fetch data for.

    Returns:
        list: Contains a list of lists which contain timestamp, open, high, low, close, volume.
    """

    try:
        if timeframe not in exchange.timeframes:
            raise ValueError(
                "{} does not support {} timeframe for OHLCV data. Possible values are: {}".format(
                    exchange, timeframe, list(exchange.timeframes)
                )
            )
    except AttributeError:
        self.logger.error(
            "%s interface does not support timeframe queries! We are unable to fetch data!",
            exchange,
        )
        raise AttributeError(sys.exc_info())

    timeframe_regex = re.compile("([0-9]+)([a-zA-Z])")
    timeframe_matches = timeframe_regex.match(timeframe)
    time_quantity = timeframe_matches.group(1)
    time_period = timeframe_matches.group(2)
    timedelta_values = {
        "m": "minutes",
        "h": "hours",
        "d": "days",
        "w": "weeks",
        "M": "months",
        "y": "years",
    }

    timedelta_args = {timedelta_values[time_period]: int(time_quantity)}
    single_frame = timedelta(**timedelta_args)

    if not start_date:
        start_datetime = datetime.now() - (limit * single_frame)
        start_date = int(start_datetime.timestamp() * 1000)
        total = limit

    else:
        total = (datetime.now() - start_date * 1000) // single_frame.total_seconds()
        if limit:
            total = min(limit, total)

    historical_data = []
    for cursor in range(0, total, max_per_page):
        curr_start_date = start_date + int(cursor * single_frame.total_seconds() * 1000)
        limit = min(total - cursor, max_per_page)
        historical_data += exchange.fetch_ohlcv(
            symbol, timeframe=timeframe, since=curr_start_date, limit=limit
        )

    if not historical_data:
        raise ValueError("No historical data provided returned by exchange.")

    if len(historical_data) != total:
        raise ValueError("Gaps detected in historical data.")

    # Sort by timestamp in ascending order
    historical_data.sort(key=lambda d: d[0])

    return historical_data


def timestamp_to_datetime(
    timestamp, timezone=pytz.timezone("America/Montreal"), to_str=False
):
    time = datetime.fromtimestamp(timestamp, timezone)
    if to_str:
        time = time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
    return time


def convert_to_dataframe(historical_data):
    """Converts historical data matrix to a pandas dataframe.

    Args:
        historical_data (list): A matrix of historical OHCLV data.

    Returns:
        pandas.DataFrame: Contains the historical data in a pandas dataframe.
    """

    dataframe = pd.DataFrame(historical_data)
    dataframe.transpose()

    dataframe.columns = ["timestamp", "open", "high", "low", "close", "volume"]
    dataframe["datetime"] = dataframe.timestamp.apply(
        lambda x: timestamp_to_datetime(x / 1000)
    )

    dataframe.set_index("datetime", inplace=True, drop=True)
    dataframe.drop("timestamp", axis=1, inplace=True)

    return dataframe

In [90]:
def shift(values: np.ndarray, periods: int, axis, fill_value) -> np.ndarray:
    new_values = values

    if periods == 0 or values.size == 0:
        return new_values.copy()

    # make sure array sent to np.roll is c_contiguous
    f_ordered = values.flags.f_contiguous
    if f_ordered:
        new_values = new_values.T
        axis = new_values.ndim - axis - 1

    if np.prod(new_values.shape):
        new_values = np.roll(new_values, periods, axis=axis)

    axis_indexer = [ slice(None)] * values.ndim
    if periods > 0:
        axis_indexer[axis] = slice(None, periods)
    else:
        axis_indexer[axis] = slice(periods, None)
    new_values[tuple(axis_indexer)] = fill_value

    # restore original order
    if f_ordered:
        new_values = new_values.T

    return new_values


def crossing(a, b):
    a_plus = shift(a, 1, axis=0, fill_value=0)
    b_plus = shift(b, 1, axis=0, fill_value=0)
    cross = np.where(
        (a <= b) & (a_plus >= b_plus),
        1,
        np.where(((a >= b) & (a_plus <= b_plus)), -1, 0),
    )
    return cross


def SWING_CALLS(df):
    ema = ta.EMA(df.close, 5)
    sma = ta.SMA(df.close, 50)
    rsi = ta.RSI(df.close, 14)

    color = np.where(
        (rsi >= 85) | (rsi <= 15),
        "YELLOW",
        np.where(df.low > sma, "LIME", np.where(df.high < sma, "RED", "YELLOW")),
    )

    buyexit = rsi > 80
    sellexit = rsi < 30

    sellcall = (crossing(sma, ema) > 0) & (df.open > df.close)
    buycall = (crossing(sma, ema) < 0) & (df.high > sma)

    return buyexit, sellexit, sellcall, buycall


def smooth_range(series, period, mult):
    wper = period * 2 - 1
    diff = (series - series.shift(1, fill_value=0)).abs()
    average = ta.EMA(diff, period)
    smoothed = ta.EMA(average, wper) * mult
    smoothed = pd.Series(smoothed, index=series.index)
    return smoothed


def filter_range(series, smoothrng):
    result = series.shift(1, fill_value=0)
    for time, (close, smth) in enumerate(zip(series, smoothrng)):
        prev = result.iloc[time]
        if time == 0 or ((close >= prev - smth) and (close <=  prev +smth)):
            continue
            
        if close > prev + smth:
            prev = close - smth
        else:
            prev = close + smth
        result.iloc[time] = prev
    return result


def Range_Filter_Buy_Sell(df, period=100, range_multiplier=3):
    # Smooth Average Range
    smoothed = smooth_range(df.close, period, range_multiplier)

    # Range Filter
    filtered = filter_range(df.close, smoothed)

    buycall = (df.close > filtered) & (df.close > df.close.shift(1)) & (filtered > filtered.shift(1))
    sellcall = (df.close < filtered) & (df.close < df.close.shift(1)) & (filtered < filtered.shift(1))
    return buycall, sellcall


def calculate_profit(ohlvc, buycall, sellcall, start_from=100, trade_fee=0.1):
    money = 1
    asset = 0
    last_buy = ohlvc.iloc[start_from]["close"]
    trade_cost = 0
    trade_fee /= 100
    for time, (buy, sell) in enumerate(zip(buycall, sellcall)):
        if time < start_from:
            continue

        if buy and money and (time != len(ohlvc) - 1):
            trade_cost += money * trade_fee
            money *= (1 - trade_fee)
            asset = money / ohlvc.iloc[time]["close"]
            money = 0
            last_buy = ohlvc.iloc[time]["close"]
        
        elif (sell or (time == len(ohlvc) - 1)) and asset:
            money = asset * ohlvc.iloc[time]["close"]
            trade_cost += money * trade_fee
            money *= (1 - trade_fee)
            asset = 0
    
    return money, trade_cost

## Data

### Load

In [91]:
import pandas as pd
import tensortrade.env.default as default

from tensortrade.data.cdd import CryptoDataDownload
from tensortrade.feed.core import Stream, DataFeed
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.instruments import USD, BTC, ETH
from tensortrade.oms.wallets import Wallet, Portfolio
from tensortrade.agents import DQNAgent


%matplotlib inline

In [92]:
# cdd = CryptoDataDownload()
# data = cdd.fetch("Coinbase", "USD", "BTC", "1h")


data = pd.read_csv("data/Coinbase_BTCUSD_1h.csv", skiprows=1)
data["date"] = pd.to_datetime(data["date"], format="%Y-%m-%d %I-%p")
data = data.sort_values("date")
data.head()

Unnamed: 0,date,symbol,open,high,low,close,volume_btc,volume
20110,2017-07-01 11:00:00,BTCUSD,2505.56,2513.38,2495.12,2509.17,114.6,287000.32
20109,2017-07-01 12:00:00,BTCUSD,2509.17,2512.87,2484.99,2488.43,157.36,393142.5
20108,2017-07-01 13:00:00,BTCUSD,2488.43,2488.43,2454.4,2454.43,280.28,693254.01
20107,2017-07-01 14:00:00,BTCUSD,2454.43,2473.93,2450.83,2459.35,289.42,712864.8
20106,2017-07-01 15:00:00,BTCUSD,2459.35,2475.0,2450.0,2467.83,276.82,682105.41


### Create features with the feed module

In [93]:
def rsi(price: Stream[float], period: float) -> Stream[float]:
    r = price.diff()
    upside = r.clamp_min(0).abs()
    downside = r.clamp_max(0).abs()
    rs = upside.ewm(alpha=1 / period).mean() / downside.ewm(alpha=1 / period).mean()
    return 100*(1 - (1 + rs) ** -1)


def macd(price: Stream[float], fast: float, slow: float, signal: float) -> Stream[float]:
    fm = price.ewm(span=fast, adjust=False).mean()
    sm = price.ewm(span=slow, adjust=False).mean()
    md = fm - sm
    signal = md - md.ewm(span=signal, adjust=False).mean()
    return signal

In [94]:
features = [
    Stream.source(list(data[c]), dtype="float").rename(data[c].name)
    for c in data.columns[1:]
]

In [95]:
close = Stream.select(features, lambda s: s.name == "close")

In [96]:
class Listener:
    def on_next(self, value):
        print(value)

close.attach(Listener())

<abc.Stream at 0x2b5e59ee7b00>

In [97]:
from tensortrade.feed.core import Stream
ss = Stream.source([1, 2, 3, 4, 5], dtype="float")

In [107]:
ff = DataFeed([ss.ewm(span=1).mean()])

abc.Stream # <abc._Stream object at 0x2b5e59ef37b8>
<class 'tuple'> # (<abc._Stream object at 0x2b5e59ef37b8>,)


In [111]:
ff.next()

{'stream:/207': 4.0}

In [53]:
features = [
    close.ewm(span=14).mean().rename("ema"),
    close.ewm(alpha=1).mean().rename("sma"),
    close.log().diff().rename("lr"),
    rsi(close, period=20).rename("rsi"),
    macd(close, fast=10, slow=50, signal=5).rename("macd")
]

feed = DataFeed(features)
feed.compile()

In [54]:
import json

for i in range(5):
    obsv = feed.next()
#     print(data)
    print(json.dumps(obsv, indent=4))

2509.17
{
    "ema": 2509.17,
    "sma": 2509.17,
    "lr": NaN,
    "rsi": NaN,
    "macd": 0.0
}
2488.43
{
    "ema": 2498.0592857142856,
    "sma": 2488.43,
    "lr": -0.008300031641449657,
    "rsi": 0.0,
    "macd": -1.9717171717171975
}
2454.43
{
    "ema": 2481.3927504244484,
    "sma": 2454.43,
    "lr": -0.01375743446296962,
    "rsi": 0.0,
    "macd": -6.082702245269603
}
2459.35
{
    "ema": 2474.649251269036,
    "sma": 2459.35,
    "lr": 0.0020025323250756344,
    "rsi": 8.795475693113076,
    "macd": -7.287625162566419
}
2467.83
{
    "ema": 2472.8701190470056,
    "sma": 2467.83,
    "lr": 0.00344213459739251,
    "rsi": 21.34663357024277,
    "macd": -6.522181201739986
}


## Setup Trading Environment

In [59]:
coinbase = Exchange("coinbase", service=execute_order)(
    Stream.source(list(data["close"]), dtype="float").rename("USD-BTC")
)

portfolio = Portfolio(USD, [
    Wallet(coinbase, 10000 * USD),
    Wallet(coinbase, 10 * BTC)
])


renderer_feed = DataFeed([
    Stream.source(list(data["date"])).rename("date"),
    Stream.source(list(data["open"]), dtype="float").rename("open"),
    Stream.source(list(data["high"]), dtype="float").rename("high"),
    Stream.source(list(data["low"]), dtype="float").rename("low"),
    Stream.source(list(data["close"]), dtype="float").rename("close"), 
    Stream.source(list(data["volume"]), dtype="float").rename("volume") 
])


env = default.create(
    portfolio=portfolio,
    action_scheme="managed-risk",
    reward_scheme="risk-adjusted",
    feed=feed,
    renderer_feed=renderer_feed,
#     renderer=default.renderers.PlotlyTradingChart(),
    window_size=20
)

In [60]:
env.observer.feed.next()

{'internal': {'coinbase:/USD-BTC': 2509.17,
  'coinbase:/USD:/free': 10000.0,
  'coinbase:/USD:/locked': 0.0,
  'coinbase:/USD:/total': 10000.0,
  'coinbase:/BTC:/free': 10.0,
  'coinbase:/BTC:/locked': 0.0,
  'coinbase:/BTC:/total': 10.0,
  'coinbase:/BTC:/worth': 25091.7,
  'net_worth': 35091.7},
 'external': {'lr': nan, 'rsi': nan, 'macd': 0.0},
 'renderer': {'date': Timestamp('2017-07-01 11:00:00'),
  'open': 2505.56,
  'high': 2513.38,
  'low': 2495.12,
  'close': 2509.17,
  'volume': 287000.32}}

## Setup and Train DQN Agent

In [61]:
agent = DQNAgent(env)

agent.train(n_steps=200, n_episodes=2, save_path="agents/")

====      AGENT ID: 419593c4-ddab-4681-b233-f28da56e5c04      ====


OSError: Unable to create file (unable to open file: name = 'agents/policy_network__419593c4-ddab-4681-b233-f28da56e5c04__001.hdf5', errno = 2, error message = 'No such file or directory', flags = 13, o_flags = 242)