<a href="https://colab.research.google.com/github/jadechip/rl-trading/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Baselines, a Fork of OpenAI Baselines - Training, Saving and Loading

Github Repo: [https://github.com/hill-a/stable-baselines](https://github.com/hill-a/stable-baselines)

Medium article: [https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-df87c4b2fc82](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-df87c4b2fc82)

[RL Baselines Zoo](https://github.com/araffin/rl-baselines-zoo) is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines.

It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos.

Documentation is available online: [https://stable-baselines.readthedocs.io/](https://stable-baselines.readthedocs.io/)

## Install Dependencies and Stable Baselines Using Pip

List of full dependencies can be found in the [README](https://github.com/hill-a/stable-baselines).

```
sudo apt-get update && sudo apt-get install cmake libopenmpi-dev zlib1g-dev
```


```
pip install stable-baselines[mpi]
```

In [6]:
# Stable Baselines only supports tensorflow 1.x for now
%tensorflow_version 1.x
!apt install swig cmake libopenmpi-dev zlib1g-dev
!pip install stable-baselines[mpi]==2.10.2 box2d box2d-kengz


TensorFlow 1.x selected.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
zlib1g-dev is already the newest version (1:1.2.11.dfsg-0ubuntu2).
zlib1g-dev set to manually installed.
libopenmpi-dev is already the newest version (2.1.1-8).
cmake is already the newest version (3.10.2-1ubuntu2.18.04.2).
The following additional packages will be installed:
  swig3.0
Suggested packages:
  swig-doc swig-examples swig3.0-examples swig3.0-doc
The following NEW packages will be installed:
  swig swig3.0
0 upgraded, 2 newly installed, 0 to remove and 37 not upgraded.
Need to get 1,100 kB of archives.
After this operation, 5,822 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 swig3.0 amd64 3.0.12-1 [1,094 kB]
Get:2 http://archive.ubuntu.com/ubuntu bionic/universe amd64 swig amd64 3.0.12-1 [6,460 B]
Fetched 1,100 kB in 1s (832 kB/s)
Selecting previously unselected package swig3.0.
(Reading database ..

In [7]:
from stable_baselines.common.policies import MlpPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines import PPO2

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



  "stable-baselines is in maintenance mode, please use [Stable-Baselines3 (SB3)](https://github.com/DLR-RM/stable-baselines3) for an up-to-date version. You can find a [migration guide](https://stable-baselines3.readthedocs.io/en/master/guide/migration.html) in SB3 documentation."


## Environment

In [8]:
import logging
import pandas as pd
import numpy as np
from numpy import inf
import gym
from gym import spaces
from sklearn import preprocessing
from statsmodels.tsa.statespace.sarimax import SARIMAX

  import pandas.util.testing as tm


### Indicators

In [9]:
!pip install ta==0.4.4

Collecting ta==0.4.4
  Downloading ta-0.4.4.tar.gz (12 kB)
Building wheels for collected packages: ta
  Building wheel for ta (setup.py) ... [?25l[?25hdone
  Created wheel for ta: filename=ta-0.4.4-py3-none-any.whl size=16709 sha256=6f1b639c3c4dc58495164eae487fb36da817526b21ff6b9633010405b41ceb42
  Stored in directory: /root/.cache/pip/wheels/41/d1/5b/e665e6917d563061679aa7ce2fa2b161fb8363ed0fc4da0662
Successfully built ta
Installing collected packages: ta
Successfully installed ta-0.4.4


In [10]:
import ta

def add_indicators(df):
    df['RSI'] = ta.rsi(df["Close"])
    df['MFI'] = ta.money_flow_index(
        df["High"], df["Low"], df["Close"], df["Volume BTC"])
    df['TSI'] = ta.tsi(df["Close"])
    df['UO'] = ta.uo(df["High"], df["Low"], df["Close"])
    df['AO'] = ta.ao(df["High"], df["Low"])

    df['MACD_diff'] = ta.macd_diff(df["Close"])
    df['Vortex_pos'] = ta.vortex_indicator_pos(
        df["High"], df["Low"], df["Close"])
    df['Vortex_neg'] = ta.vortex_indicator_neg(
        df["High"], df["Low"], df["Close"])
    df['Vortex_diff'] = abs(
        df['Vortex_pos'] -
        df['Vortex_neg'])
    df['Trix'] = ta.trix(df["Close"])
    df['Mass_index'] = ta.mass_index(df["High"], df["Low"])
    df['CCI'] = ta.cci(df["High"], df["Low"], df["Close"])
    df['DPO'] = ta.dpo(df["Close"])
    df['KST'] = ta.kst(df["Close"])
    df['KST_sig'] = ta.kst_sig(df["Close"])
    df['KST_diff'] = (
        df['KST'] -
        df['KST_sig'])
    df['Aroon_up'] = ta.aroon_up(df["Close"])
    df['Aroon_down'] = ta.aroon_down(df["Close"])
    df['Aroon_ind'] = (
        df['Aroon_up'] -
        df['Aroon_down']
    )

    df['BBH'] = ta.bollinger_hband(df["Close"])
    df['BBL'] = ta.bollinger_lband(df["Close"])
    df['BBM'] = ta.bollinger_mavg(df["Close"])
    df['BBHI'] = ta.bollinger_hband_indicator(
        df["Close"])
    df['BBLI'] = ta.bollinger_lband_indicator(
        df["Close"])
    df['KCHI'] = ta.keltner_channel_hband_indicator(df["High"],
                                                    df["Low"],
                                                    df["Close"])
    df['KCLI'] = ta.keltner_channel_lband_indicator(df["High"],
                                                    df["Low"],
                                                    df["Close"])
    df['DCHI'] = ta.donchian_channel_hband_indicator(df["Close"])
    df['DCLI'] = ta.donchian_channel_lband_indicator(df["Close"])

    df['ADI'] = ta.acc_dist_index(df["High"],
                                  df["Low"],
                                  df["Close"],
                                  df["Volume BTC"])
    df['OBV'] = ta.on_balance_volume(df["Close"],
                                     df["Volume BTC"])
    df['CMF'] = ta.chaikin_money_flow(df["High"],
                                      df["Low"],
                                      df["Close"],
                                      df["Volume BTC"])
    df['FI'] = ta.force_index(df["Close"],
                              df["Volume BTC"])
    df['EM'] = ta.ease_of_movement(df["High"],
                                   df["Low"],
                                   df["Close"],
                                   df["Volume BTC"])
    df['VPT'] = ta.volume_price_trend(df["Close"],
                                      df["Volume BTC"])
    df['NVI'] = ta.negative_volume_index(df["Close"],
                                         df["Volume BTC"])

    df['DR'] = ta.daily_return(df["Close"])
    df['DLR'] = ta.daily_log_return(df["Close"])

    df.fillna(method='bfill', inplace=True)

    return df

### Utils

In [11]:
def difference(df, columns):
    transformed_df = df.copy()

    for column in columns:
        transformed_df[column] = transformed_df[column] - \
            transformed_df[column].shift(1)

    transformed_df = transformed_df.fillna(method='bfill')

    return transformed_df


def log_and_difference(df, columns):
    transformed_df = df.copy()

    for column in columns:
        transformed_df[column] = np.log(
            transformed_df[column]) - np.log(transformed_df[column]).shift(1)

    transformed_df = transformed_df.fillna(method='bfill')

    return transformed_df

### Visualization

In [9]:
!pip install mplfinance

Collecting mplfinance
  Downloading mplfinance-0.12.7a17-py3-none-any.whl (62 kB)
[?25l[K     |█████▎                          | 10 kB 26.1 MB/s eta 0:00:01[K     |██████████▌                     | 20 kB 19.7 MB/s eta 0:00:01[K     |███████████████▊                | 30 kB 10.9 MB/s eta 0:00:01[K     |█████████████████████           | 40 kB 9.0 MB/s eta 0:00:01[K     |██████████████████████████▎     | 51 kB 5.1 MB/s eta 0:00:01[K     |███████████████████████████████▌| 61 kB 5.5 MB/s eta 0:00:01[K     |████████████████████████████████| 62 kB 769 kB/s 
Installing collected packages: mplfinance
Successfully installed mplfinance-0.12.7a17


In [12]:
import sys
import matplotlib
import matplotlib.pyplot as plt

from matplotlib import style
from datetime import datetime
from pandas.plotting import register_matplotlib_converters

style.use('ggplot')
register_matplotlib_converters()

VOLUME_CHART_HEIGHT = 0.33


class TradingChart:
    """An OHLCV trading visualization using matplotlib made to render gym environments"""

    def __init__(self, df):
        self.df = df

        # Create a figure on screen and set the title
        self.fig = plt.figure()

        # Create top subplot for net worth axis
        self.net_worth_ax = plt.subplot2grid((6, 1), (0, 0), rowspan=2, colspan=1)

        # Create bottom subplot for shared price/volume axis
        self.price_ax = plt.subplot2grid((6, 1), (2, 0), rowspan=8, colspan=1, sharex=self.net_worth_ax)

        # Create a new axis for volume which shares its x-axis with price
        self.volume_ax = self.price_ax.twinx()

        # Add padding to make graph easier to view
        plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)

        # Show the graph without blocking the rest of the program
        plt.show(block=False)

    def _render_net_worth(self, step_range, times, current_step, net_worths, benchmarks):
        # Clear the frame rendered last step
        self.net_worth_ax.clear()

        # Plot net worths
        self.net_worth_ax.plot(times, net_worths[step_range], label='Net Worth', color="g")

        self._render_benchmarks(step_range, times, benchmarks)

        # Show legend, which uses the label we defined for the plot above
        self.net_worth_ax.legend()
        legend = self.net_worth_ax.legend(loc=2, ncol=2, prop={'size': 8})
        legend.get_frame().set_alpha(0.4)

        last_time = self.df['Date'].values[current_step]
        last_net_worth = net_worths[current_step]

        # Annotate the current net worth on the net worth graph
        self.net_worth_ax.annotate('{0:.2f}'.format(last_net_worth), (last_time, last_net_worth),
                                   xytext=(last_time, last_net_worth),
                                   bbox=dict(boxstyle='round',
                                             fc='w', ec='k', lw=1),
                                   color="black",
                                   fontsize="small")

        # Add space above and below min/max net worth
        self.net_worth_ax.set_ylim(min(net_worths) / 1.25, max(net_worths) * 1.25)

    def _render_benchmarks(self, step_range, times, benchmarks):
        colors = ['orange', 'cyan', 'purple', 'blue',
                  'magenta', 'yellow', 'black', 'red', 'green']

        for i, benchmark in enumerate(benchmarks):
            self.net_worth_ax.plot(times, benchmark['values'][step_range],
                                   label=benchmark['label'], color=colors[i % len(colors)], alpha=0.3)

    def _render_price(self, step_range, times, current_step):
        self.price_ax.clear()

        # Plot price using candlestick graph from mpl_finance
        self.price_ax.plot(times, self.df['Close'].values[step_range], color="black")

        last_time = self.df['Date'].values[current_step]
        last_close = self.df['Close'].values[current_step]
        last_high = self.df['High'].values[current_step]

        # Print the current price to the price axis
        self.price_ax.annotate('{0:.2f}'.format(last_close), (last_time, last_close),
                               xytext=(last_time, last_high),
                               bbox=dict(boxstyle='round',
                                         fc='w', ec='k', lw=1),
                               color="black",
                               fontsize="small")

        # Shift price axis up to give volume chart space
        ylim = self.price_ax.get_ylim()
        self.price_ax.set_ylim(ylim[0] - (ylim[1] - ylim[0]) * VOLUME_CHART_HEIGHT, ylim[1])

    def _render_volume(self, step_range, times):
        self.volume_ax.clear()

        volume = np.array(self.df['Volume'].values[step_range])

        self.volume_ax.plot(times, volume,  color='blue')
        self.volume_ax.fill_between(times, volume, color='blue', alpha=0.5)

        self.volume_ax.set_ylim(0, max(volume) / VOLUME_CHART_HEIGHT)
        self.volume_ax.yaxis.set_ticks([])

    def _render_trades(self, step_range, trades):
        for trade in trades:
            if trade['step'] in range(sys.maxsize)[step_range]:
                date = self.df['Date'].values[trade['step']]
                close = self.df['Close'].values[trade['step']]

                if trade['type'] == 'buy':
                    color = 'g'
                else:
                    color = 'r'

                self.price_ax.annotate(' ', (date, close),
                                       xytext=(date, close),
                                       size="large",
                                       arrowprops=dict(arrowstyle='simple', facecolor=color))

    def render(self, current_step, net_worths, benchmarks, trades, window_size=200):
        net_worth = round(net_worths[-1], 2)
        initial_net_worth = round(net_worths[0], 2)
        profit_percent = round((net_worth - initial_net_worth) / initial_net_worth * 100, 2)

        self.fig.suptitle('Net worth: $' + str(net_worth) + ' | Profit: ' + str(profit_percent) + '%')

        window_start = max(current_step - window_size, 0)
        step_range = slice(window_start, current_step + 1)
        times = self.df['Date'].values[step_range]

        self._render_net_worth(step_range, times, current_step, net_worths, benchmarks)
        self._render_price(step_range, times, current_step)
        self._render_volume(step_range, times)
        self._render_trades(step_range, trades)

        date_col = pd.to_datetime(self.df['Date'], unit='s').dt.strftime('%m/%d/%Y %H:%M')
        date_labels = date_col.values[step_range]

        self.price_ax.set_xticklabels(date_labels, rotation=45, horizontalalignment='right')

        # Hide duplicate net worth date labels
        plt.setp(self.net_worth_ax.get_xticklabels(), visible=False)

        # Necessary to view frames before they are unrendered
        plt.pause(0.001)

    def close(self):
        plt.close()

In [17]:
class CryptoTradingEnv(gym.Env):
    """A Bitcoin trading environment for OpenAI gym"""
    metadata = {'render.modes': ['human', 'system', 'none']}
    scaler = preprocessing.MinMaxScaler()
    viewer = None

    def __init__(self, df, initial_balance=10000, commission=0.0003, **kwargs):
        super(CryptoTradingEnv, self).__init__()

        self.initial_balance = initial_balance
        self.commission = commission

        self.df = df.fillna(method='bfill')
        self.df = add_indicators(self.df.reset_index())
        self.stationary_df = log_and_difference(
            self.df, ['Open', 'High', 'Low', 'Close', 'Volume BTC'])

        self.n_forecasts = kwargs.get('n_forecasts', 10)
        self.confidence_interval = kwargs.get('confidence_interval', 0.95)
        self.obs_shape = (1, 5 + len(self.df.columns) -
                          2 + (self.n_forecasts * 3))

        # Actions of the format Buy 1/4, Sell 3/4, Hold (amount ignored), etc.
        self.action_space = spaces.Discrete(12)

        # Observes the price action, indicators, account action, price forecasts
        self.observation_space = spaces.Box(
            low=0, high=1, shape=self.obs_shape, dtype=np.float16)

    def _next_observation(self):
        features = self.stationary_df[self.stationary_df.columns.difference([
            'index', 'Date'])]

        scaled = features[:self.current_step + self.n_forecasts].values
        scaled[abs(scaled) == inf] = 0
        scaled = self.scaler.fit_transform(scaled.astype('float64'))
        scaled = pd.DataFrame(scaled, columns=features.columns)

        obs = scaled.values[-1]

        past_df = self.stationary_df['Close'][:
                                              self.current_step + self.n_forecasts]
        forecast_model = SARIMAX(past_df.values)
        model_fit = forecast_model.fit(
            method='bfgs', disp=False)
        forecast = model_fit.get_forecast(
            steps=self.n_forecasts, alpha=(1 - self.confidence_interval))

        obs = np.insert(obs, len(obs), forecast.predicted_mean, axis=0)
        obs = np.insert(obs, len(obs), forecast.conf_int().flatten(), axis=0)

        scaled_history = self.scaler.fit_transform(
            self.account_history.astype('float64'))

        obs = np.insert(
            obs, len(obs), scaled_history[:, self.current_step], axis=0)

        obs = np.reshape(obs.astype('float16'), self.obs_shape)

        return obs

    def _get_current_price(self):
        return self.df['Close'].values[self.current_step + self.n_forecasts]

    def _take_action(self, action, current_price):
        action_type = int(action / 4)
        amount = 1 / (action % 4 + 1)

        btc_bought = 0
        btc_sold = 0
        cost = 0
        sales = 0

        if action_type == 0:
            price = current_price * (1 + self.commission)
            btc_bought = min(self.balance * amount /
                             price, self.balance / price)
            cost = btc_bought * price

            self.btc_held += btc_bought
            self.balance -= cost
        elif action_type == 1:
            price = current_price * (1 - self.commission)
            btc_sold = self.btc_held * amount
            sales = btc_sold * price

            self.btc_held -= btc_sold
            self.balance += sales

        if btc_sold > 0 or btc_bought > 0:
            self.trades.append({'step': self.current_step,
                                'amount': btc_sold if btc_sold > 0 else btc_bought, 'total': sales if btc_sold > 0 else cost,
                                'type': "sell" if btc_sold > 0 else "buy"})

        self.net_worth = self.balance + self.btc_held * current_price

        self.account_history = np.append(self.account_history, [
            [self.balance],
            [btc_bought],
            [cost],
            [btc_sold],
            [sales]
        ], axis=1)

    def reset(self):
        self.balance = self.initial_balance
        self.net_worth = self.initial_balance
        self.btc_held = 0
        self.current_step = 0

        self.account_history = np.array([
            [self.balance],
            [0],
            [0],
            [0],
            [0]
        ])
        self.trades = []

        return self._next_observation()

    def step(self, action):
        current_price = self._get_current_price() + 0.01

        prev_net_worth = self.net_worth

        self._take_action(action, current_price)

        self.current_step += 1

        obs = self._next_observation()
        reward = self.net_worth - prev_net_worth
        done = self.net_worth < self.initial_balance / \
            10 or self.current_step == len(self.df) - self.n_forecasts

        return obs, reward, done, {}

    def render(self, mode='human', **kwargs):
        if mode == 'system':
            print('Price: ' + str(self._get_current_price()))
            print(
                'Bought: ' + str(self.account_history[2][self.current_step]))
            print(
                'Sold: ' + str(self.account_history[4][self.current_step]))
            print('Net worth: ' + str(self.net_worth))

        elif mode == 'human':
            if self.viewer is None:
                self.viewer = BitcoinTradingGraph(
                    self.df, kwargs.get('title', None))

            self.viewer.render(self.current_step, self.net_worth, self.trades)

        elif mode == 'human':
            if self.viewer is None:
                self.viewer = TradingChart(self.df)

            self.viewer.render(self.current_step,
                               self.net_worth,
                               self.render_benchmarks,
                               self.trades)



    def close(self):
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

### Second

### Strategies

In [19]:
from abc import ABCMeta, abstractmethod
from typing import List, Callable, Tuple, Iterable

In [14]:
class BaseRewardStrategy(object, metaclass=ABCMeta):
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def reset_reward(self):
        raise NotImplementedError()

    @abstractmethod
    def get_reward(self,
                   current_step: int,
                   current_price: Callable[[str], float],
                   observations: pd.DataFrame,
                   account_history: pd.DataFrame,
                   net_worths: List[float]) -> float:
        raise NotImplementedError()


class IncrementalProfit(BaseRewardStrategy):
    last_bought: int = 0
    last_sold: int = 0

    def __init__(self):
        pass

    def reset_reward(self):
        pass

    def get_reward(self,
                   current_step: int,
                   current_price: Callable[[str], float],
                   observations: pd.DataFrame,
                   account_history: pd.DataFrame,
                   net_worths: List[float]) -> float:
        reward = 0

        curr_balance = account_history['balance'].values[-1]
        prev_balance = account_history['balance'].values[-2] if len(account_history['balance']) > 1 else curr_balance

        if curr_balance > prev_balance:
            reward = net_worths[-1] - net_worths[self.last_bought]
            self.last_sold = current_step
        elif curr_balance < prev_balance:
            reward = observations['Close'].values[self.last_sold] - current_price()
            self.last_bought = current_step

        return reward


class WeightedUnrealizedProfit(BaseRewardStrategy):
    def __init__(self, **kwargs):
        self.decay_rate = kwargs.get('decay_rate', 1e-2)
        self.decay_denominator = np.exp(-1 * self.decay_rate)

        self.reset_reward()

    def reset_reward(self):
        self.rewards = deque(np.zeros(1, dtype=float))
        self.sum = 0.0

    def calc_reward(self, reward):
        self.sum = self.sum - self.decay_denominator * self.rewards.popleft()
        self.sum = self.sum * self.decay_denominator
        self.sum = self.sum + reward

        self.rewards.append(reward)

        return self.sum / self.decay_denominator

    def get_reward(self,
                   current_step: int,
                   current_price: Callable[[str], float],
                   observations: pd.DataFrame,
                   account_history: pd.DataFrame,
                   net_worths: List[float]) -> float:
        if account_history['asset_sold'].values[-1] > 0:
            reward = self.calc_reward(account_history['sale_revenue'].values[-1])
        else:
            reward = self.calc_reward(account_history['asset_held'].values[-1] * current_price())

        return reward

In [15]:
class BaseTradeStrategy(object, metaclass=ABCMeta):
    @abstractmethod
    def __init__(self,
                 commissionPercent: float,
                 maxSlippagePercent: float,
                 base_precision: int,
                 asset_precision: int,
                 min_cost_limit: float,
                 min_amount_limit: float):
        pass

    @abstractmethod
    def trade(self,
              action: int,
              n_discrete_actions: int,
              balance: float,
              asset_held: float,
              current_price: Callable[[str], float]) -> Tuple[float, float, float, float]:
        raise NotImplementedError()


class LiveTradeStrategy(BaseTradeStrategy):
    def __init__(self,
                 commissionPercent: float,
                 maxSlippagePercent: float,
                 base_precision: int,
                 asset_precision: int,
                 min_cost_limit: float,
                 min_amount_limit: float):
        self.commissionPercent = commissionPercent
        self.maxSlippagePercent = maxSlippagePercent
        self.base_precision = base_precision
        self.asset_precision = asset_precision
        self.min_cost_limit = min_cost_limit
        self.min_amount_limit = min_amount_limit

    def trade(self,
              buy_amount: float,
              sell_amount: float,
              balance: float,
              asset_held: float,
              current_price: Callable[[str], float]) -> Tuple[float, float, float, float]:
        raise NotImplementedError()


class SimulatedTradeStrategy(BaseTradeStrategy):
    def __init__(self,
                 commissionPercent: float,
                 maxSlippagePercent: float,
                 base_precision: int,
                 asset_precision: int,
                 min_cost_limit: float,
                 min_amount_limit: float):
        self.commissionPercent = commissionPercent
        self.maxSlippagePercent = maxSlippagePercent
        self.base_precision = base_precision
        self.asset_precision = asset_precision
        self.min_cost_limit = min_cost_limit
        self.min_amount_limit = min_amount_limit

    def trade(self,
              buy_amount: float,
              sell_amount: float,
              balance: float,
              asset_held: float,
              current_price: Callable[[str], float]) -> Tuple[float, float, float, float]:
        current_price = current_price('Close')
        commission = self.commissionPercent / 100
        slippage = np.random.uniform(0, self.maxSlippagePercent) / 100

        asset_bought, asset_sold, purchase_cost, sale_revenue = buy_amount, sell_amount, 0, 0

        if buy_amount > 0 and balance >= self.min_cost_limit:
            price_adjustment = (1 + commission) * (1 + slippage)
            buy_price = round(current_price * price_adjustment, self.base_precision)
            purchase_cost = round(buy_price * buy_amount, self.base_precision)
        elif sell_amount > 0 and asset_held >= self.min_amount_limit:
            price_adjustment = (1 - commission) * (1 - slippage)
            sell_price = round(current_price * price_adjustment, self.base_precision)
            sale_revenue = round(sell_amount * sell_price, self.base_precision)

        return asset_bought, asset_sold, purchase_cost, sale_revenue

### Data Providers

In [28]:
from enum import Enum
from datetime import datetime

class ProviderDateFormat(Enum):
    TIMESTAMP_UTC = 1
    TIMESTAMP_MS = 2
    DATE = 3
    DATETIME_HOUR_12 = 4
    DATETIME_HOUR_24 = 5
    DATETIME_MINUTE_12 = 6
    DATETIME_MINUTE_24 = 7
    CUSTOM_DATIME = 8

class BaseDataProvider(object, metaclass=ABCMeta):
    columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
    in_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
    custom_datetime_format = None

    @abstractmethod
    def __init__(self, date_format: ProviderDateFormat, **kwargs):
        self.date_format = date_format

        self.custom_datetime_format: str = kwargs.get('custom_datetime_format', None)

        data_columns: Dict[str, str] = kwargs.get('data_columns', None)

        if data_columns is not None:
            self.data_columns = data_columns
            self.columns = list(data_columns.keys())
            self.in_columns = list(data_columns.values())
        else:
            self.data_columns = dict(zip(self.columns, self.in_columns))

    @abstractmethod
    def split_data_train_test(self, train_split_percentage: float = 0.8) -> Tuple:
        raise NotImplementedError

    @abstractmethod
    def historical_ohlcv(self) -> pd.DataFrame:
        raise NotImplementedError

    @abstractmethod
    def reset_ohlcv_index(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def has_next_ohlcv(self) -> bool:
        raise NotImplementedError

    @abstractmethod
    def next_ohlcv(self) -> pd.DataFrame:
        raise NotImplementedError

    def prepare_data(self, data_frame: pd.DataFrame, inplace: bool = True) -> pd.DataFrame:
        column_map = dict(zip(self.in_columns, self.columns))

        formatted = data_frame[self.in_columns]
        formatted = formatted.rename(index=str, columns=column_map)

        formatted = self._format_date_column(formatted, inplace=inplace)
        formatted = self._sort_by_date(formatted, inplace=inplace)

        return formatted

    def _sort_by_date(self, data_frame: pd.DataFrame, inplace: bool = True) -> pd.DataFrame:
        if inplace is True:
            formatted = data_frame
        else:
            formatted = data_frame.copy()

        formatted = formatted.sort_values(self.data_columns['Date'])

        return formatted

    def _format_date_column(self, data_frame: pd.DataFrame, inplace: bool = True) -> pd.DataFrame:
        if inplace is True:
            formatted = data_frame
        else:
            formatted = data_frame.copy()

        date_col = self.data_columns['Date']
        date_frame = formatted.loc[:, date_col]

        if self.date_format is ProviderDateFormat.TIMESTAMP_UTC:
            formatted[date_col] = date_frame.apply(
                lambda x: datetime.utcfromtimestamp(x).strftime('%Y-%m-%d %H:%M'))
            formatted[date_col] = pd.to_datetime(date_frame, format='%Y-%m-%d %H:%M')
        elif self.date_format is ProviderDateFormat.TIMESTAMP_MS:
            formatted[date_col] = pd.to_datetime(date_frame, unit='ms')
        elif self.date_format is ProviderDateFormat.DATETIME_HOUR_12:
            formatted[date_col] = pd.to_datetime(date_frame, format='%Y-%m-%d %I-%p')
        elif self.date_format is ProviderDateFormat.DATETIME_HOUR_24:
            formatted[date_col] = pd.to_datetime(date_frame, format='%Y-%m-%d %H')
        elif self.date_format is ProviderDateFormat.DATETIME_MINUTE_12:
            formatted[date_col] = pd.to_datetime(date_frame, format='%Y-%m-%d %I:%M-%p')
        elif self.date_format is ProviderDateFormat.DATETIME_MINUTE_24:
            formatted[date_col] = pd.to_datetime(date_frame, format='%Y-%m-%d %H:%M')
        elif self.date_format is ProviderDateFormat.DATE:
            formatted[date_col] = pd.to_datetime(date_frame, format='%Y-%m-%d')
        elif self.date_format is ProviderDateFormat.CUSTOM_DATIME:
            formatted[date_col] = pd.to_datetime(
                date_frame, format=self.custom_datetime_format, infer_datetime_format=True)
        else:
            raise NotImplementedError

        formatted[date_col] = formatted[date_col].values.astype(np.int64) // 10 ** 9

        return formatted


class StaticDataProvider(BaseDataProvider):
    _current_index = 0

    def __init__(self, date_format: ProviderDateFormat, data_frame: pd.DataFrame = None, csv_data_path: str = None,
                 skip_prepare_data: bool = False, **kwargs):
        BaseDataProvider.__init__(self, date_format, **kwargs)

        self.kwargs = kwargs

        if data_frame is not None:
            self.data_frame = data_frame
        elif csv_data_path is not None:
            if not os.path.isfile(csv_data_path):
                raise ValueError(
                    'Invalid "csv_data_path" argument passed to StaticDataProvider, file could not be found.')

            self.data_frame = pd.read_csv(csv_data_path)
        else:
            raise ValueError(
                'StaticDataProvider requires either a "data_frame" or "csv_data_path argument".')

        if not skip_prepare_data:
            self.data_frame = self.prepare_data(self.data_frame)

    @staticmethod
    def from_prepared(data_frame: pd.DataFrame, date_format: ProviderDateFormat, **kwargs):
        return StaticDataProvider(date_format=date_format, data_frame=data_frame, skip_prepare_data=True, **kwargs)

    def split_data_train_test(self, train_split_percentage: float = 0.8) -> Tuple[BaseDataProvider, BaseDataProvider]:
        train_len = int(train_split_percentage * len(self.data_frame))

        train_df = self.data_frame[:train_len].copy()
        test_df = self.data_frame[train_len:].copy()

        train_provider = StaticDataProvider.from_prepared(
            data_frame=train_df, date_format=self.date_format, **self.kwargs)
        test_provider = StaticDataProvider.from_prepared(
            data_frame=test_df, date_format=self.date_format, **self.kwargs)

        return train_provider, test_provider

    def historical_ohlcv(self) -> pd.DataFrame:
        return self.data_frame

    def has_next_ohlcv(self) -> bool:
        return self._current_index < len(self.data_frame)

    def reset_ohlcv_index(self) -> int:
        self._current_index = 0

    def next_ohlcv(self) -> pd.DataFrame:
        frame = self.data_frame[self.columns].values[self._current_index]
        frame = pd.DataFrame([frame], columns=self.columns)

        self._current_index += 1

        return frame

### Transforms

In [20]:
@abstractmethod
def transform(iterable: Iterable, inplace: bool = True, columns: List[str] = None, transform_fn: Callable[[Iterable], Iterable] = None):
    if inplace is True:
        transformed_iterable = iterable
    else:
        transformed_iterable = iterable.copy()

    if isinstance(transformed_iterable, pd.DataFrame):
        is_list = False
    else:
        is_list = True
        transformed_iterable = pd.DataFrame(transformed_iterable, columns=columns)

    transformed_iterable.fillna(0, inplace=True)

    if transform_fn is None:
        raise NotImplementedError()

    if columns is None:
        columns = transformed_iterable.columns

    for column in columns:
        transformed_iterable[column] = transform_fn(transformed_iterable[column])

    transformed_iterable.fillna(method="bfill", inplace=True)
    transformed_iterable[np.bitwise_not(np.isfinite(transformed_iterable))] = 0

    if is_list:
        transformed_iterable = transformed_iterable.values

    return transformed_iterable


def max_min_normalize(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    return transform(iterable, inplace, columns, lambda t_iterable: (t_iterable - t_iterable.min()) / (t_iterable.max() - t_iterable.min()))


def mean_normalize(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    return transform(iterable, inplace, columns, lambda t_iterable: (t_iterable - t_iterable.mean()) / t_iterable.std())


def difference(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    return transform(iterable, inplace, columns, lambda t_iterable: t_iterable - t_iterable.shift(1))


def log_and_difference(iterable: Iterable, inplace: bool = True, columns: List[str] = None):
    return transform(iterable, inplace, columns, lambda t_iterable: np.log(t_iterable) - np.log(t_iterable).shift(1))

### Logger

In [22]:
!pip install colorlog

Collecting colorlog
  Downloading colorlog-6.4.1-py2.py3-none-any.whl (11 kB)
Installing collected packages: colorlog
Successfully installed colorlog-6.4.1


In [40]:
import os
import logging
import colorlog

def init_logger(dunder_name, show_debug=False) -> logging.Logger:
    log_format = (
        '%(asctime)s - '
        '%(name)s - '
        '%(funcName)s - '
        '%(levelname)s - '
        '%(message)s'
    )
    bold_seq = '\033[1m'
    colorlog_format = (
        f'{bold_seq} '
        '%(log_color)s '
        f'{log_format}'
    )
    colorlog.basicConfig(format=colorlog_format)
    logging.getLogger('tensorflow').disabled = True
    logger = logging.getLogger(dunder_name)

    if show_debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    return logger

### Environment

In [50]:
from gym import spaces
from enum import Enum
from typing import List, Dict


class TradingEnvAction(Enum):
    BUY = 0
    SELL = 1
    HOLD = 2


class TradingEnv(gym.Env):
    '''A reinforcement trading environment made for use with gym-enabled algorithms'''
    metadata = {'render.modes': ['human', 'system', 'none']}
    viewer = None

    def __init__(self,
                 data_provider: BaseDataProvider,
                 reward_strategy: BaseRewardStrategy = IncrementalProfit,
                 trade_strategy: BaseTradeStrategy = SimulatedTradeStrategy,
                 initial_balance: int = 10000,
                 commissionPercent: float = 0.25,
                 maxSlippagePercent: float = 2.0,
                 **kwargs):
        super(TradingEnv, self).__init__()

        self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))

        self.base_precision: int = kwargs.get('base_precision', 2)
        self.asset_precision: int = kwargs.get('asset_precision', 8)
        self.min_cost_limit: float = kwargs.get('min_cost_limit', 1E-3)
        self.min_amount_limit: float = kwargs.get('min_amount_limit', 1E-3)

        self.initial_balance = round(initial_balance, self.base_precision)
        self.commissionPercent = commissionPercent
        self.maxSlippagePercent = maxSlippagePercent

        self.data_provider = data_provider
        self.reward_strategy = reward_strategy()
        self.trade_strategy = trade_strategy(commissionPercent=self.commissionPercent,
                                             maxSlippagePercent=self.maxSlippagePercent,
                                             base_precision=self.base_precision,
                                             asset_precision=self.asset_precision,
                                             min_cost_limit=self.min_cost_limit,
                                             min_amount_limit=self.min_amount_limit)

        self.render_benchmarks: List[Dict] = kwargs.get('render_benchmarks', [])
        self.normalize_obs: bool = kwargs.get('normalize_obs', True)
        self.stationarize_obs: bool = kwargs.get('stationarize_obs', True)
        self.normalize_rewards: bool = kwargs.get('normalize_rewards', False)
        self.stationarize_rewards: bool = kwargs.get('stationarize_rewards', True)

        self.n_discrete_actions: int = kwargs.get('n_discrete_actions', 24)
        self.action_space = spaces.Discrete(self.n_discrete_actions)

        self.n_features = 6 + len(self.data_provider.columns)
        self.obs_shape = (1, self.n_features)
        self.observation_space = spaces.Box(low=0, high=1, shape=self.obs_shape, dtype=np.float16)

        self.observations = pd.DataFrame(None, columns=self.data_provider.columns)

    def _current_price(self, ohlcv_key: str = 'Close'):
        return float(self.current_ohlcv[ohlcv_key])

    def _get_trade(self, action: int):
        n_action_types = 3
        n_amount_bins = int(self.n_discrete_actions / n_action_types)

        action_type: TradingEnvAction = TradingEnvAction(action % n_action_types)
        action_amount = float(1 / (action % n_amount_bins + 1))

        amount_asset_to_buy = 0
        amount_asset_to_sell = 0

        if action_type == TradingEnvAction.BUY and self.balance >= self.min_cost_limit:
            price_adjustment = (1 + (self.commissionPercent / 100)) * (1 + (self.maxSlippagePercent / 100))
            buy_price = round(self._current_price() * price_adjustment, self.base_precision)
            amount_asset_to_buy = round(self.balance * action_amount / buy_price, self.asset_precision)
        elif action_type == TradingEnvAction.SELL and self.asset_held >= self.min_amount_limit:
            amount_asset_to_sell = round(self.asset_held * action_amount, self.asset_precision)

        return amount_asset_to_buy, amount_asset_to_sell

    def _take_action(self, action: int):
        amount_asset_to_buy, amount_asset_to_sell = self._get_trade(action)

        asset_bought, asset_sold, purchase_cost, sale_revenue = self.trade_strategy.trade(buy_amount=amount_asset_to_buy,
                                                                                          sell_amount=amount_asset_to_sell,
                                                                                          balance=self.balance,
                                                                                          asset_held=self.asset_held,
                                                                                          current_price=self._current_price)

        if asset_bought:
            self.asset_held += asset_bought
            self.balance -= purchase_cost

            self.trades.append({'step': self.current_step,
                                'amount': asset_bought,
                                'total': purchase_cost,
                                'type': 'buy'})
        elif asset_sold:
            self.asset_held -= asset_sold
            self.balance += sale_revenue

            self.reward_strategy.reset_reward()

            self.trades.append({'step': self.current_step,
                                'amount': asset_sold,
                                'total': sale_revenue,
                                'type': 'sell'})

        current_net_worth = round(self.balance + self.asset_held * self._current_price(), self.base_precision)
        self.net_worths.append(current_net_worth)
        self.account_history = self.account_history.append({
            'balance': self.balance,
            'asset_held': self.asset_held,
            'asset_bought': asset_bought,
            'purchase_cost': purchase_cost,
            'asset_sold': asset_sold,
            'sale_revenue': sale_revenue,
        }, ignore_index=True)

    def _done(self):
        lost_90_percent_net_worth = float(self.net_worths[-1]) < (self.initial_balance / 10)
        has_next_frame = self.data_provider.has_next_ohlcv()

        return lost_90_percent_net_worth or not has_next_frame

    def _reward(self):
        reward = self.reward_strategy.get_reward(current_step=self.current_step,
                                                 current_price=self._current_price,
                                                 observations=self.observations,
                                                 account_history=self.account_history,
                                                 net_worths=self.net_worths)

        reward = float(reward) if np.isfinite(float(reward)) else 0

        self.rewards.append(reward)

        if self.stationarize_rewards:
            rewards = difference(self.rewards, inplace=False)
        else:
            rewards = self.rewards

        if self.normalize_rewards:
            mean_normalize(rewards, inplace=True)

        rewards = np.array(rewards).flatten()

        return float(rewards[-1])

    def _next_observation(self):
        self.current_ohlcv = self.data_provider.next_ohlcv()
        self.timestamps.append(pd.to_datetime(self.current_ohlcv.Date.item(), unit='s'))
        self.observations = self.observations.append(self.current_ohlcv, ignore_index=True)

        if self.stationarize_obs:
            observations = log_and_difference(self.observations, inplace=False)
        else:
            observations = self.observations

        if self.normalize_obs:
            observations = max_min_normalize(observations)

        obs = observations.values[-1]

        if self.stationarize_obs:
            scaled_history = log_and_difference(self.account_history, inplace=False)
        else:
            scaled_history = self.account_history

        if self.normalize_obs:
            scaled_history = max_min_normalize(scaled_history, inplace=False)

        obs = np.insert(obs, len(obs), scaled_history.values[-1], axis=0)

        obs = np.reshape(obs.astype('float16'), self.obs_shape)
        obs[np.bitwise_not(np.isfinite(obs))] = 0

        return obs

    def reset(self):
        self.data_provider.reset_ohlcv_index()

        self.balance = self.initial_balance
        self.net_worths = [self.initial_balance]
        self.timestamps = []
        self.asset_held = 0
        self.current_step = 0

        self.reward_strategy.reset_reward()

        self.account_history = pd.DataFrame([{
            'balance': self.balance,
            'asset_held': self.asset_held,
            'asset_bought': 0,
            'purchase_cost': 0,
            'asset_sold': 0,
            'sale_revenue': 0,
        }])
        self.trades = []
        self.rewards = [0]

        return self._next_observation()

    def step(self, action):
        self._take_action(action)

        self.current_step += 1

        obs = self._next_observation()
        reward = self._reward()
        done = self._done()

        return obs, reward, done, {'net_worths': self.net_worths, 'timestamps': self.timestamps}

    def render(self, mode='human'):
        if mode == 'system':
            self.logger.info('Price: ' + str(self._current_price()))
            self.logger.info('Bought: ' + str(self.account_history['asset_bought'][self.current_step]))
            self.logger.info('Sold: ' + str(self.account_history['asset_sold'][self.current_step]))
            self.logger.info('Net worth: ' + str(self.net_worths[-1]))

        elif mode == 'human':
            if self.viewer is None:
                self.viewer = TradingChart(self.data_provider.data_frame)

            self.viewer.render(self.current_step,
                               self.net_worths,
                               self.render_benchmarks,
                               self.trades)

    def close(self):
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

## Data

In [27]:
!wget https://raw.githubusercontent.com/jadechip/rl-trading/main/data/BTCUSDT-1m-data.csv

--2021-10-05 12:54:39--  https://raw.githubusercontent.com/jadechip/rl-trading/main/data/BTCUSDT-1m-data.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42827776 (41M) [text/plain]
Saving to: ‘BTCUSDT-1m-data.csv’


2021-10-05 12:54:41 (195 MB/s) - ‘BTCUSDT-1m-data.csv’ saved [42827776/42827776]



In [43]:
data_columns = {'Date': 'Date', 'Open': 'Open', 'High': 'High',
                'Low': 'Low', 'Close': 'Close', 'Volume': 'Volume'}

data_provider = StaticDataProvider(date_format=ProviderDateFormat.DATETIME_MINUTE_24,
                                        csv_data_path='./BTCUSDT-1m-data.csv',
                                        data_columns=data_columns)

In [44]:
train_provider, test_provider = data_provider.split_data_train_test(0.8)

In [51]:
train_env = DummyVecEnv([lambda: TradingEnv(train_provider)])
train_env.seed(42)
test_env = DummyVecEnv([lambda: TradingEnv(test_provider)])
test_env.seed(42)

[None]

### Training

In [46]:
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy

In [47]:
model = PPO2(MlpLstmPolicy, train_env, verbose=1, nminibatches=1)

In [48]:
model.learn(total_timesteps=2000)

  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)


--------------------------------------
| approxkl           | 2.43817e-05   |
| clipfrac           | 0.0           |
| explained_variance | 0.000179      |
| fps                | 13            |
| n_updates          | 1             |
| policy_entropy     | 3.1770954     |
| policy_loss        | -0.0016880697 |
| serial_timesteps   | 128           |
| time_elapsed       | 4.24e-05      |
| total_timesteps    | 128           |
| value_loss         | 210.43909     |
--------------------------------------
---------------------------------------
| approxkl           | 3.087093e-05   |
| clipfrac           | 0.0            |
| explained_variance | 0.000146       |
| fps                | 21             |
| n_updates          | 2              |
| policy_entropy     | 3.1766505      |
| policy_loss        | -0.00037155347 |
| serial_timesteps   | 256            |
| time_elapsed       | 9.63           |
| total_timesteps    | 256            |
| value_loss         | 325.9423       |
-------------

<stable_baselines.ppo2.ppo2.PPO2 at 0x7f29e7952750>

### Testing

In [52]:
state = None
rewards = []
obs = test_env.reset()
for i in range(2000):
  action, state = model.predict(obs, state=state)
  obs, reward, done, info = test_env.step(action)
  rewards.append(reward)
  test_env.render(mode='system')

  if done:
    print(f"Rewards: {np.sum(rewards)}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[1m [32m 2021-10-05 13:53:22,863 - __main__ - render - INFO - Sold: 0.0[0m
[1m [32m 2021-10-05 13:53:22,873 - __main__ - render - INFO - Net worth: 5150.21[0m
[1m [32m 2021-10-05 13:53:22,920 - __main__ - render - INFO - Price: 8831.31[0m
[1m [32m 2021-10-05 13:53:22,923 - __main__ - render - INFO - Bought: 0.0[0m
[1m [32m 2021-10-05 13:53:22,931 - __main__ - render - INFO - Sold: 0.0[0m
[1m [32m 2021-10-05 13:53:22,940 - __main__ - render - INFO - Net worth: 5147.72[0m
[1m [32m 2021-10-05 13:53:22,981 - __main__ - render - INFO - Price: 8810.28[0m
[1m [32m 2021-10-05 13:53:22,986 - __main__ - render - INFO - Bought: 0.0[0m
[1m [32m 2021-10-05 13:53:22,988 - __main__ - render - INFO - Sold: 0.02879023[0m
[1m [32m 2021-10-05 13:53:22,992 - __main__ - render - INFO - Net worth: 5143.69[0m
[1m [32m 2021-10-05 13:53:23,039 - __main__ - render - INFO - Price: 8800.0[0m
[1m [32m 2021-10-05 13:

Rewards: -7.829931259155273


[1m [32m 2021-10-05 13:54:59,477 - __main__ - render - INFO - Sold: 0[0m
[1m [32m 2021-10-05 13:54:59,483 - __main__ - render - INFO - Net worth: 10000.0[0m
[1m [32m 2021-10-05 13:54:59,543 - __main__ - render - INFO - Price: 8582.61[0m
[1m [32m 2021-10-05 13:54:59,545 - __main__ - render - INFO - Bought: 0.1898299[0m
[1m [32m 2021-10-05 13:54:59,548 - __main__ - render - INFO - Sold: 0.0[0m
[1m [32m 2021-10-05 13:54:59,551 - __main__ - render - INFO - Net worth: 9983.86[0m
[1m [32m 2021-10-05 13:54:59,613 - __main__ - render - INFO - Price: 8577.15[0m
[1m [32m 2021-10-05 13:54:59,617 - __main__ - render - INFO - Bought: 0.0[0m
[1m [32m 2021-10-05 13:54:59,620 - __main__ - render - INFO - Sold: 0.04745747[0m
[1m [32m 2021-10-05 13:54:59,622 - __main__ - render - INFO - Net worth: 9979.99[0m
[1m [32m 2021-10-05 13:54:59,676 - __main__ - render - INFO - Price: 8575.97[0m
[1m [32m 2021-10-05 13:54:59,680 - __main__ - render - INFO - Bought: 0.0[0m
[1m 

In [19]:
df = pd.read_csv('./BTCUSDT-1m-data.csv')
df.rename(columns={'Volume': 'Volume BTC'}, inplace=True)
features = ["Date", "Open", "High", "Low", "Close", "Volume BTC"]
df = df.filter(features)
df.head()

Unnamed: 0,Date,Open,High,Low,Close,Volume BTC
0,2017-08-17 04:00:00,4261.48,4261.48,4261.48,4261.48,1.775183
1,2017-08-17 04:01:00,4261.48,4261.48,4261.48,4261.48,0.0
2,2017-08-17 04:02:00,4280.56,4280.56,4280.56,4280.56,0.261074
3,2017-08-17 04:03:00,4261.48,4261.48,4261.48,4261.48,0.012008
4,2017-08-17 04:04:00,4261.48,4261.48,4261.48,4261.48,0.140796


In [26]:
test_len = int(len(df) * 0.2)
train_len = int(len(df)) - test_len

train_df = df[:train_len]
test_df = df[train_len:]

In [20]:
env = DummyVecEnv([lambda: CryptoTradingEnv(train_df)])
test_env = DummyVecEnv([lambda: BitcoinTradingEnv(test_df)])

  result = getattr(ufunc, method)(*inputs, **kwargs)


In [None]:
test_len = int(len(df) * 0.2)
train_len = int(len(df)) - test_len

In [22]:
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy

In [23]:
model = PPO2(MlpLstmPolicy, env, verbose=1, nminibatches=1)





Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Please use `layer.__call__` method instead.





Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where





In [24]:
model.learn(total_timesteps=2000)



--------------------------------------
| approxkl           | 5.4507167e-05 |
| clipfrac           | 0.0           |
| explained_variance | 0.00136       |
| fps                | 8             |
| n_updates          | 1             |
| policy_entropy     | 2.483169      |
| policy_loss        | -0.0020612897 |
| serial_timesteps   | 128           |
| time_elapsed       | 3.27e-05      |
| total_timesteps    | 128           |
| value_loss         | 62.22         |
--------------------------------------




--------------------------------------
| approxkl           | 6.508232e-05  |
| clipfrac           | 0.0           |
| explained_variance | -0.000217     |
| fps                | 16            |
| n_updates          | 2             |
| policy_entropy     | 2.4826193     |
| policy_loss        | -0.0006701797 |
| serial_timesteps   | 256           |
| time_elapsed       | 14.6          |
| total_timesteps    | 256           |
| value_loss         | 286.40555     |
--------------------------------------




--------------------------------------
| approxkl           | 3.8321687e-05 |
| clipfrac           | 0.0           |
| explained_variance | 0.000923      |
| fps                | 14            |
| n_updates          | 3             |
| policy_entropy     | 2.4821532     |
| policy_loss        | 3.9421022e-05 |
| serial_timesteps   | 384           |
| time_elapsed       | 22.5          |
| total_timesteps    | 384           |
| value_loss         | 393.97342     |
--------------------------------------
--------------------------------------
| approxkl           | 7.546334e-05  |
| clipfrac           | 0.0           |
| explained_variance | -0.000665     |
| fps                | 15            |
| n_updates          | 4             |
| policy_entropy     | 2.4804926     |
| policy_loss        | -0.0012065694 |
| serial_timesteps   | 512           |
| time_elapsed       | 31.4          |
| total_timesteps    | 512           |
| value_loss         | 193.01077     |
-------------------------



--------------------------------------
| approxkl           | 6.074497e-05  |
| clipfrac           | 0.0           |
| explained_variance | 0.00752       |
| fps                | 11            |
| n_updates          | 9             |
| policy_entropy     | 2.464864      |
| policy_loss        | 0.00040089898 |
| serial_timesteps   | 1152          |
| time_elapsed       | 77.4          |
| total_timesteps    | 1152          |
| value_loss         | 148.61348     |
--------------------------------------
--------------------------------------
| approxkl           | 0.0001782569  |
| clipfrac           | 0.0           |
| explained_variance | 0.00654       |
| fps                | 11            |
| n_updates          | 10            |
| policy_entropy     | 2.4603672     |
| policy_loss        | 0.00029854476 |
| serial_timesteps   | 1280          |
| time_elapsed       | 88.8          |
| total_timesteps    | 1280          |
| value_loss         | 362.08475     |
-------------------------



-------------------------------------
| approxkl           | 0.0011362231 |
| clipfrac           | 0.0          |
| explained_variance | 0.01         |
| fps                | 6            |
| n_updates          | 15           |
| policy_entropy     | 2.412603     |
| policy_loss        | 0.0024978295 |
| serial_timesteps   | 1920         |
| time_elapsed       | 149          |
| total_timesteps    | 1920         |
| value_loss         | 98.06912     |
-------------------------------------


<stable_baselines.ppo2.ppo2.PPO2 at 0x7f8163433290>

In [25]:
obs = env.reset()
for i in range(2000):
  action, _states = model.predict(obs)
  obs, rewards, done, info = test_env.step(action)
  env.render('system')

Price: 4261.48
Bought: 0
Sold: 0
Net worth: 10000.0
Price: 4261.48
Bought: 0.0
Sold: 0.0
Net worth: 10000.0
Price: 4261.48
Bought: 2500.0
Sold: 0.0
Net worth: 9999.25022493252
Price: 4261.48
Bought: 0.0
Sold: 0.0
Net worth: 9999.25022493252




Price: 4261.48
Bought: 0.0
Sold: 0.0
Net worth: 9999.25022493252
Price: 4261.48
Bought: 1875.0
Sold: 0.0
Net worth: 9998.68789363191




Price: 4264.88
Bought: 0.0
Sold: 1457.458595754607
Net worth: 9998.250524842548
Price: 4264.88
Bought: 0.0
Sold: 0.0
Net worth: 10000.576869045843




Price: 4261.48
Bought: 3541.2292978773035
Sold: 0.0
Net worth: 9999.514818871532
Price: 4264.88
Bought: 0.0
Sold: 0.0
Net worth: 9994.366228691859
Price: 4266.29
Bought: 885.3073244693259
Sold: 0.0
Net worth: 9999.249306327953




Price: 4266.29
Bought: 663.9804933519944
Sold: 0.0
Net worth: 10001.477923055567
Price: 4266.29
Bought: 0.0
Sold: 0.0
Net worth: 10001.477923055567
Price: 4266.29
Bought: 0.0
Sold: 8007.1335820666845
Net worth: 9999.075062122667
Price: 4266.29
Bought: 0.0
Sold: 0.0
Net worth: 9999.075062122667
Price: 4266.29
Bought: 0.0
Sold: 0.0
Net worth: 9999.075062122667
Price: 4266.29
Bought: 9999.075062122667
Sold: 0.0
Net worth: 9996.076239250893
Price: 4266.29
Bought: 0.0
Sold: 4996.538708189559
Net worth: 9994.576827815006
Price: 4261.45
Bought: 0.0
Sold: 0.0
Net worth: 9994.576827815006
Price: 4280.0
Bought: 0.0
Sold: 1247.717568115315
Net worth: 9988.532263446961
Price: 4274.67
Bought: 0.0
Sold: 1253.148843051262
Net worth: 10004.454920398972
Price: 4274.67
Bought: 0.0
Sold: 834.3921776529133
Net worth: 10001.08243785713




Price: 4267.99
Bought: 2777.265765669683
Sold: 0.0
Net worth: 10000.249508006385
Price: 4296.63
Bought: 0.0
Sold: 0.0
Net worth: 9993.30222865819
Price: 4300.38
Bought: 0.0
Sold: 0.0
Net worth: 10023.088168857632
Price: 4300.38
Bought: 0.0
Sold: 1117.7787362644794
Net worth: 10026.652779150747




Price: 4300.38
Bought: 0.0
Sold: 838.3340521983596
Net worth: 10026.401203462381
Price: 4300.38
Bought: 7510.644319802204
Sold: 0.0
Net worth: 10024.148685921702
Price: 4300.38
Bought: 0.0
Sold: 0.0
Net worth: 10024.148685921702




Price: 4300.38
Bought: 0.0
Sold: 0.0
Net worth: 10024.148685921702
Price: 4300.38
Bought: 0.0
Sold: 10021.141441315927
Net worth: 10021.141441315927
Price: 4300.38
Bought: 0.0
Sold: 0.0
Net worth: 10021.141441315927




Price: 4310.07
Bought: 0.0
Sold: 0.0
Net worth: 10021.141441315927
Price: 4310.07
Bought: 0.0
Sold: 0.0
Net worth: 10021.141441315927




Price: 4310.07
Bought: 0.0
Sold: 0.0
Net worth: 10021.141441315927
Price: 4310.07
Bought: 0.0
Sold: 0.0
Net worth: 10021.141441315927
Price: 4310.07
Bought: 2505.285360328982
Sold: 0.0
Net worth: 10020.39008111589




Price: 4310.07
Bought: 0.0
Sold: 0.0
Net worth: 10020.39008111589
Price: 4292.01
Bought: 0.0
Sold: 0.0
Net worth: 10020.39008111589
Price: 4310.07
Bought: 0.0
Sold: 1246.6456731925693
Net worth: 10009.521533305822
Price: 4313.62
Bought: 0.0
Sold: 417.29710665481747
Net worth: 10014.643527543982
Price: 4313.6
Bought: 0.0
Sold: 835.2816273384348
Net worth: 10015.080488172767




Price: 4313.6
Bought: 10015.080488172767
Sold: 0.0
Net worth: 10012.076865113233
Price: 4313.6
Bought: 0.0
Sold: 0.0
Net worth: 10012.076865113233




Price: 4291.37
Bought: 0.0
Sold: 0.0
Net worth: 10012.076865113233
Price: 4313.6
Bought: 0.0
Sold: 0.0
Net worth: 9960.480065979451
Price: 4308.83
Bought: 0.0
Sold: 10009.0732420537
Net worth: 10009.0732420537
Price: 4308.83
Bought: 2502.268310513425
Sold: 0.0
Net worth: 10008.322786697154




Price: 4308.83
Bought: 3753.4024657701375
Sold: 0.0
Net worth: 10007.197103662333
Price: 4308.83
Bought: 0.0
Sold: 0.0
Net worth: 10007.197103662333




Price: 4308.83
Bought: 0.0
Sold: 0.0
Net worth: 10007.197103662333
Price: 4308.83
Bought: 0.0
Sold: 0.0
Net worth: 10007.197103662333




Price: 4308.83
Bought: 0.0
Sold: 0.0
Net worth: 10007.197103662333
Price: 4304.31
Bought: 0.0
Sold: 6251.918499500827
Net worth: 10005.320965270965
Price: 4328.69
Bought: 0.0
Sold: 0.0
Net worth: 10005.320965270965
Price: 4328.69
Bought: 0.0
Sold: 0.0
Net worth: 10005.320965270965




Price: 4328.69
Bought: 0.0
Sold: 0.0
Net worth: 10005.320965270965
Price: 4328.69
Bought: 0.0
Sold: 0.0
Net worth: 10005.320965270965
Price: 4328.69
Bought: 0.0
Sold: 0.0
Net worth: 10005.320965270965




Price: 4328.69
Bought: 0.0
Sold: 0.0
Net worth: 10005.320965270965
Price: 4328.69
Bought: 10005.320965270965
Sold: 0.0
Net worth: 10002.320269190208
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 10002.320269190208




Price: 4304.31
Bought: 0.0
Sold: 9979.245627793234
Net worth: 9979.245627793234
Price: 4304.31
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234




Price: 4319.99
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4319.99
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4319.22
Bought: 0.0
Sold: 0.0
Net worth: 9979.245627793234
Price: 4320.0
Bought: 2494.8114069483086
Sold: 0.0
Net worth: 9978.497408836836
Price: 4320.0
Bought: 7484.434220844925
Sold: 0.0
Net worth: 9976.70314921128
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 9976.70314921128
Price: 4320.0
Bought: 0.0
Sold: 3324.570046088839
Net worth: 9975.70547889636
Price: 4320.0
Bought: 0.0
Sold: 0.0
Net worth: 



Price: 4291.37
Bought: 0.0
Sold: 0.0
Net worth: 9968.319394941434
Price: 4291.37
Bought: 0.0
Sold: 0.0
Net worth: 9968.319394941434
Price: 4291.37
Bought: 0.0
Sold: 0.0
Net worth: 9968.319394941434
Price: 4297.04
Bought: 0.0
Sold: 9965.328899122953
Net worth: 9965.328899122953
Price: 4297.04
Bought: 0.0
Sold: 0.0
Net worth: 9965.328899122953
Price: 4297.04
Bought: 4982.664449561476
Sold: 0.0
Net worth: 9963.834548093393
Price: 4297.04
Bought: 0.0
Sold: 0.0
Net worth: 9963.834548093393
Price: 4297.04
Bought: 0.0
Sold: 0.0
Net worth: 9963.834548093393
Price: 4297.04
Bought: 0.0
Sold: 1659.8919158341191
Net worth: 9963.33643108354
Price: 4297.04
Bought: 0.0
Sold: 3319.7838316682387
Net worth: 9962.340197063835
Price: 4315.32
Bought: 0.0
Sold: 0.0
Net worth: 9962.340197063835




Price: 4315.32
Bought: 0.0
Sold: 0.0
Net worth: 9962.340197063835
Price: 4315.32
Bought: 0.0
Sold: 0.0
Net worth: 9962.340197063835
Price: 4315.32
Bought: 0.0
Sold: 0.0
Net worth: 9962.340197063835
Price: 4315.32
Bought: 0.0
Sold: 0.0
Net worth: 9962.340197063835
Price: 4315.32
Bought: 4981.170098531918
Sold: 0.0
Net worth: 9960.846294205134
Price: 4315.32
Bought: 0.0
Sold: 4978.182292814514
Net worth: 9959.35239134643
Price: 4315.32
Bought: 4979.676195673215
Sold: 0.0
Net worth: 9957.858936524175
Price: 4315.32
Bought: 1244.9190489183038
Sold: 0.0
Net worth: 9957.485572818612
Price: 4315.32
Bought: 0.0
Sold: 0.0
Net worth: 9957.485572818612
Price: 4315.32
Bought: 1867.3785733774557
Sold: 0.0
Net worth: 9956.925527260268
Price: 4315.32
Bought: 0.0
Sold: 2021.7800224491616
Net worth: 9956.318811238725
Price: 4330.29
Bought: 0.0
Sold: 0.0
Net worth: 9956.318811238725




Price: 4318.39
Bought: 1944.5792979133087
Sold: 0.0
Net worth: 9976.782760234248
Price: 4318.39
Bought: 486.14482447832717
Sold: 0.0
Net worth: 9954.563842913345
Price: 4318.39
Bought: 729.2172367174908
Sold: 0.0
Net worth: 9954.345143352199
Price: 4318.39
Bought: 182.3043091793727
Sold: 0.0
Net worth: 9954.29046846191
Price: 4318.39
Bought: 0.0
Sold: 0.0
Net worth: 9954.29046846191
Price: 4318.39
Bought: 0.0
Sold: 3134.8517758871712
Net worth: 9953.349730707818
Price: 4330.0
Bought: 0.0
Sold: 2089.901183924781
Net worth: 9952.72257220509
Price: 4330.0
Bought: 0.0
Sold: 0.0
Net worth: 9963.96332528752
Price: 4330.0
Bought: 0.0
Sold: 0.0
Net worth: 9963.96332528752
Price: 4330.0
Bought: 0.0
Sold: 1397.0132495686892
Net worth: 9963.544095543726
Price: 4330.0
Bought: 0.0
Sold: 931.342166379126
Net worth: 9963.264609047863
Price: 4330.0
Bought: 2025.0053258244716
Sold: 0.0
Net worth: 9962.657289645937
Price: 4330.0
Bought: 0.0
Sold: 0.0
Net worth: 9962.657289645937
Price: 4330.0
Bought: 0.



Price: 4331.65
Bought: 4964.8899933026905
Sold: 0.0
Net worth: 9928.290966313478
Price: 4331.65
Bought: 0.0
Sold: 0.0
Net worth: 9912.528553720209
Price: 4345.45
Bought: 2482.4449966513453
Sold: 0.0
Net worth: 9911.784043574258
Price: 4345.45
Bought: 0.0
Sold: 2483.5906239290243
Net worth: 9934.707469993242
Price: 4345.45
Bought: 0.0
Sold: 1655.7270826193494
Net worth: 9934.2106028083
Price: 4345.45
Bought: 0.0
Sold: 0.0
Net worth: 9934.2106028083
Price: 4345.45
Bought: 3310.881351599859
Sold: 0.0
Net worth: 9933.217636292775
Price: 4345.45
Bought: 0.0
Sold: 0.0
Net worth: 9933.217636292775
Price: 4345.45
Bought: 0.0
Sold: 0.0
Net worth: 9933.217636292775
Price: 4345.45
Bought: 1103.6271171999529
Sold: 0.0
Net worth: 9932.886647454266
Price: 4345.45
Bought: 0.0
Sold: 3861.657361665222
Net worth: 9931.727802592308
Price: 4330.12
Bought: 0.0
Sold: 3861.657361665222
Net worth: 9930.568957730351
Price: 4330.12
Bought: 9930.568957730351
Sold: 0.0
Net worth: 9927.590680526195
Price: 4330.12




Price: 4377.85
Bought: 597.4885168950042
Sold: 0.0
Net worth: 9921.48706153686
Price: 4377.85
Bought: 0.0
Sold: 0.0
Net worth: 9921.48706153686
Price: 4360.71
Bought: 0.0
Sold: 2708.8609347995307
Net worth: 9920.674159385775




Price: 4360.71
Bought: 2250.663242742272
Sold: 0.0
Net worth: 9898.78157734939
Price: 4360.71
Bought: 0.0
Sold: 3822.9119495533682
Net worth: 9897.6343595992
Price: 4360.71
Bought: 0.0
Sold: 3822.9119495533682
Net worth: 9896.48714184901
Price: 4360.71
Bought: 2474.1217854622523
Sold: 0.0
Net worth: 9895.745127917551
Price: 4360.71
Bought: 0.0
Sold: 618.1594393998336
Net worth: 9895.559624434687
Price: 4360.71
Bought: 0.0
Sold: 463.6195795498752
Net worth: 9895.420496822537
Price: 4360.71
Bought: 0.0
Sold: 347.7146846624064
Net worth: 9895.316151113424
Price: 4360.71
Bought: 0.0
Sold: 521.5720269936096
Net worth: 9895.159632549758
Price: 4360.71
Bought: 0.0
Sold: 130.3930067484024
Net worth: 9895.12050290884
Price: 4360.71
Bought: 2375.9560234352207
Sold: 0.0
Net worth: 9894.40792987372
Price: 4360.71
Bought: 2375.9560234352202
Sold: 0.0
Net worth: 9893.6953568386
Price: 4360.71
Bought: 0.0
Sold: 0.0
Net worth: 9893.6953568386
Price: 4360.71
Bought: 0.0
Sold: 0.0
Net worth: 9893.695356



Price: 4360.7
Bought: 0.0
Sold: 0.0
Net worth: 9890.92244969514
Price: 4360.7
Bought: 0.0
Sold: 0.0
Net worth: 9890.92244969514




Price: 4360.7
Bought: 2472.730612423785
Sold: 0.0
Net worth: 9890.180852990423
Price: 4360.0
Bought: 2472.730612423785
Sold: 0.0
Net worth: 9889.439256285707




Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9888.64562758258
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9889.427918732805




Price: 4360.69
Bought: 1236.3653062118924
Sold: 0.0
Net worth: 9889.057120380447
Price: 4360.69
Bought: 1236.3653062118924
Sold: 0.0
Net worth: 9888.68632202809




Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9888.68632202809
Price: 4360.69
Bought: 618.1826531059462
Sold: 0.0
Net worth: 9888.50092285191




Price: 4360.69
Bought: 0.0
Sold: 2007.8856944112529
Net worth: 9887.898376379646
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9887.898376379646




Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9887.898376379646




Price: 4360.7
Bought: 0.0
Sold: 0.0
Net worth: 9887.898376379646
Price: 4360.7
Bought: 1931.2168268645457
Sold: 0.0
Net worth: 9887.333002743759
Price: 4360.69
Bought: 0.0
Sold: 3976.8646705132246
Net worth: 9886.139585317378
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9886.130462816482
Price: 4360.69
Bought: 0.0
Sold: 1988.4277753745396
Net worth: 9885.533755471664
Price: 4360.69
Bought: 0.0
Sold: 1988.4277753745396
Net worth: 9884.937048126849
Price: 4360.69
Bought: 4942.468524063424
Sold: 0.0
Net worth: 9883.454752258389
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9883.454752258389
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9883.454752258389
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9883.454752258389
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net worth: 9883.454752258389
Price: 4360.69
Bought: 1235.617131015856
Sold: 0.0
Net worth: 9883.084178291276
Price: 4360.69
Bought: 0.0
Sold: 6174.3799154081335
Net worth: 9881.231308455703
Price: 4360.69
Bought: 0.0
Sold: 0.0
Net 



Price: 4367.93
Bought: 0.0
Sold: 3496.038506769441
Net worth: 9328.057548282528
Price: 4367.93
Bought: 5830.96991522317
Sold: 0.0
Net worth: 9332.669296003236




Price: 4367.93
Bought: 0.0
Sold: 3109.9564984048116
Net worth: 9331.736029073636




Price: 4367.93
Bought: 0.0
Sold: 0.0
Net worth: 9331.736029073636
Price: 4360.93
Bought: 0.0
Sold: 0.0
Net worth: 9331.736029073636




Price: 4367.93
Bought: 1554.9782492024058
Sold: 0.0
Net worth: 9321.298738011601
Price: 4360.93
Bought: 388.74456230060144
Sold: 0.0
Net worth: 9333.648325127842
Price: 4348.34
Bought: 0.0
Sold: 0.0
Net worth: 9320.55933982411
Price: 4349.33
Bought: 0.0
Sold: 4064.172471336323
Net worth: 9295.798247201139
Price: 4349.33
Bought: 0.0
Sold: 2032.5488859546635
Net worth: 9296.113877797536
Price: 4349.33
Bought: 0.0
Sold: 0.0
Net worth: 9296.113877797536
Price: 4349.33
Bought: 0.0
Sold: 0.0
Net worth: 9296.113877797536
Price: 4322.65
Bought: 0.0
Sold: 0.0
Net worth: 9296.113877797536
Price: 4349.33
Bought: 0.0
Sold: 1010.040347197596
Net worth: 9283.338841623054
Price: 4349.33
Bought: 0.0
Sold: 0.0
Net worth: 9289.574808192761
Price: 4349.33
Bought: 0.0
Sold: 338.7581476591106
Net worth: 9289.47315025108
Price: 4349.32
Bought: 4305.87676952475
Sold: 0.0
Net worth: 9288.18177463291
Price: 4349.32
Bought: 2152.938384762375
Sold: 0.0
Net worth: 9287.524631511053
Price: 4349.32
Bought: 0.0
Sold

KeyboardInterrupt: ignored

## Import policy, RL agent, ...

In [None]:
import gym
import numpy as np

from stable_baselines import DQN

## Create the Gym env and instantiate the agent

For this example, we will use Lunar Lander environment.

"Landing outside landing pad is possible. Fuel is infinite, so an agent can learn to fly and then land on its first attempt. Four discrete actions available: do nothing, fire left orientation engine, fire main engine, fire right orientation engine. "

Lunar Lander environment: [https://gym.openai.com/envs/LunarLander-v2/](https://gym.openai.com/envs/LunarLander-v2/)

![Lunar Lander](https://cdn-images-1.medium.com/max/960/1*f4VZPKOI0PYNWiwt0la0Rg.gif)

Note: vectorized environments allow to easily multiprocess training. In this example, we are using only one process, hence the DummyVecEnv.

We chose the MlpPolicy because input of CartPole is a feature vector, not images.

The type of action to use (discrete/continuous) will be automatically deduced from the environment action space



In [None]:
env = gym.make('LunarLander-v2')


model = DQN('MlpPolicy', env, learning_rate=1e-3, prioritized_replay=True, verbose=1)

We create a helper function to evaluate the agent:

In [None]:
def evaluate(model, num_steps=1000):
  """
  Evaluate a RL agent
  :param model: (BaseRLModel object) the RL Agent
  :param num_steps: (int) number of timesteps to evaluate it
  :return: (float) Mean reward for the last 100 episodes
  """
  episode_rewards = [0.0]
  obs = env.reset()
  for i in range(num_steps):
      # _states are only useful when using LSTM policies
      action, _states = model.predict(obs)

      obs, reward, done, info = env.step(action)
      
      # Stats
      episode_rewards[-1] += reward
      if done:
          obs = env.reset()
          episode_rewards.append(0.0)
  # Compute mean reward for the last 100 episodes
  mean_100ep_reward = round(np.mean(episode_rewards[-100:]), 1)
  print("Mean reward:", mean_100ep_reward, "Num episodes:", len(episode_rewards))
  
  return mean_100ep_reward

Let's evaluate the un-trained agent, this should be a random agent.

In [None]:
# Random Agent, before training
mean_reward_before_train = evaluate(model, num_steps=10000)

Mean reward: -895.1 Num episodes: 88


## Train the agent and save it

Warning: this may take a while

In [None]:
# Train the agent
model.learn(total_timesteps=int(2e4), log_interval=10)
# Save the agent
model.save("dqn_lunar")
del model  # delete trained model to demonstrate loading

## Load the trained agent

In [None]:
model = DQN.load("dqn_lunar")

Loading a model without an environment, this model cannot be trained until it has a valid environment.


In [None]:
# Evaluate the trained agent
mean_reward = evaluate(model, num_steps=10000)