In [1]:
import numpy as np
import gym
from gym import spaces
from gym.utils import seeding
from enum import Enum
import matplotlib.pyplot as plt
import pandas as pd
from collections import deque
import itertools
from typing import List
import pickle as pkl

def safe_divide(a, b):
    return np.divide(a, b, out=np.zeros_like(a), where=b!=0)

def moving_average(iterable, n=3):
    # moving_average([40, 30, 50, 46, 39, 44]) --> 40.0 42.0 45.0 43.0
    # http://en.wikipedia.org/wiki/Moving_average
    it = iter(iterable)
    d = deque(itertools.islice(it, n-1))
    print(next(it))
    d.appendleft(0)
    s = sum(d)
    print(d)
    print(s)
    for elem in it:
        s += elem - d.popleft()
        d.append(elem)
        yield s / n

In [4]:
# temp data
price = np.load('price.pkl.npy')
df = pd.DataFrame()
df['SSI'] = price
df['HPG'] = price

In [None]:
n = 5000
h = 5
pe = 50
sig = 0.1

# trading params
tick_size = 0.1
lot_size = 100
n_action = 5
M = 10

# calculated params
dt = 1 / n
lmbda = np.log(2) / h
action_space = lot_size * np.arange(-n_action, n_action+1)
holdings = np.arange(-M, M+1)


In [24]:
action_space = spaces.MultiDiscrete(np.ones(5) * (5))
action_space.sample()

array([2, 1, 3, 3, 4], dtype=int64)

In [94]:
class TradingEnv(gym.Env):

    metadata = {'render.modes': ['human']}

    def __init__(
        self,
        df: pd.DataFrame,
        window_size: int = 1,
        n_action: int = 5,
        # max_asset: int = 5,
        tick_size: float = 0.1,
        lot_size: int = 100,
        start_nav: float = 1e6,
        kappa: float = 0.02,
    ):

        self.seed()
        self.df = df
        self.window_size = window_size
        self.window = deque(maxlen=self.window_size)
        self.max_asset = self.df.shape[1]
        self.shape = (window_size, self.max_asset)
        self.tick_size = tick_size
        self.lot_size = lot_size
        self.kappa = kappa
        self.start_nav = start_nav
        self.shares = np.arange(-200, 300, 100)

        # spaces
        self.n_action = n_action
        self.action_space = spaces.MultiDiscrete(np.ones(self.max_asset) * self.n_action)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float32)

        # episode
        self._start_tick = self.window_size - 1
        self._end_tick = self.df.shape[0] - 1
        self._current_tick = None
        self._last_trade_tick = None
        self._first_rendering = None
        self.done = None
        # self.position = None
        # self.position_history = None
        self.total_reward = None
        self.total_profit = None
        self.history = None

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        self._current_tick = self._start_tick
        self._first_rendering = True
        self.window.clear()
        self.window.extend(df.iloc[:self.window_size].to_numpy())
        self.done = False
        self.total_reward = 0
        self.total_profit = 0  # unit
        self.history = {
            'actions': (self.window_size-0) * [None],
            'shares': (self.window_size-0) * [np.zeros(self.max_asset, dtype=self.action_space.dtype)],
            'delta_vt': self.window_size * [0],
            'total_reward': self.window_size * [self.total_reward],
            'total_profit': self.window_size * [self.total_profit]
        }
        return self._get_observation()

    def step(self, action: List[int]):
        self._current_tick += 1
        new_prices = df.iloc[self._current_tick].to_numpy()
        
        self.done = False
        if self._current_tick == self._end_tick:
            self.done = True
        
        delta_vt = self.delta_vt(action, new_prices)
        
        step_reward = self._calculate_reward(delta_vt)
        self.total_reward += step_reward
        self.total_profit += delta_vt
        
        # always update history last
        self.window.append(new_prices)
        info = dict(
            actions = action,
            delta_vt = delta_vt,
            total_reward = self.total_reward,
            total_profit = self.total_profit,
            shares = self._decode_action(action)
        )
        self._update_history(info)

        return new_prices, step_reward, self.done, info

    def _get_observation(self):
        # process window
        return self.window[-1]

    def _update_history(self, info):
        if not self.history:
            self.history = {key: [] for key in info.keys()}

        for key, value in info.items():
            self.history[key].append(value)
    
    def spread_cost(self, dn: np.ndarray) -> float:
        return sum(dn * self.tick_size)

    def impact_cost(self, dn: np.ndarray) -> float:
        return sum(dn ** 2 * self.tick_size / self.lot_size)
    
    def total_cost(self, dn: np.ndarray) -> float:
        return self.spread_cost(dn) + self.impact_cost(dn)
    
    def delta_vt(
        self, 
        action: np.ndarray,
        prices: np.ndarray,
    ):
        shares = self._decode_action(action)
        prev_shares = self.history['shares'][-1]
        dn = shares - prev_shares
        rate = safe_divide(prices, self.window[-1]) - 1
        return np.sum(prev_shares * self.window[-1] * rate) - self.total_cost(dn)
    
    def _decode_action(self, action: np.ndarray) -> np.ndarray:
        return np.take(self.shares, action)

    def render(self, mode='human'):
        # def _plot_position(position, tick):
        #     color = None
        #     if position == Positions.Short:
        #         color = 'red'
        #     elif position == Positions.Long:
        #         color = 'green'
        #     if color:
        #         plt.scatter(tick, self.prices[tick], color=color)

        # if self._first_rendering:
        #     self._first_rendering = False
        #     plt.cla()
        #     plt.plot(self.prices)
        #     start_position = self.history['position_history'][self._start_tick]
        #     _plot_position(start_position, self._start_tick)

        # _plot_position(self._position, self._current_tick)

        # plt.suptitle(
        #     "Total Reward: %.6f" % self._total_reward + ' ~ ' +
        #     "Total Profit: %.6f" % self._total_profit
        # )

        # plt.pause(0.01)
        pass


    def render_all(self, mode='human'):
        # window_ticks = np.arange(len(self.history['position_history']))
        # plt.plot(self.prices)

        # short_ticks = []
        # long_ticks = []
        # for i, tick in enumerate(window_ticks):
        #     if self.history['position_history'][i] == Positions.Short:
        #         short_ticks.append(tick)
        #     elif self.history['position_history'][i] == Positions.Long:
        #         long_ticks.append(tick)

        # plt.plot(short_ticks, self.prices[short_ticks], 'ro')
        # plt.plot(long_ticks, self.prices[long_ticks], 'go')

        # plt.suptitle(
        #     "Total Reward: %.6f" % self._total_reward + ' ~ ' +
        #     "Total Profit: %.6f" % self._total_profit
        # )
        pass
        
        
    def close(self):
        # plt.close()
        pass


    def save_rendering(self, filepath):
        # plt.savefig(filepath)
        pass


    def pause_rendering(self):
        # plt.show()
        pass

    def _calculate_reward(self, delta_vt):
        return delta_vt - self.kappa * (delta_vt ** 2)

    def max_possible_profit(self):  # trade fees are ignored
        raise NotImplementedError

In [95]:

env = TradingEnv(df, window_size=5)
obs = env.reset()
done = False

while not done:
    action = env.action_space.sample()
    next_obs, reward, done, _ = env.step(action)
    

In [96]:
df

Unnamed: 0,SSI,HPG
0,50.000000,50.000000
1,50.124891,50.124891
2,50.153262,50.153262
3,50.222724,50.222724
4,50.382129,50.382129
...,...,...
4995,45.423907,45.423907
4996,45.417516,45.417516
4997,45.465620,45.465620
4998,45.525523,45.525523


In [97]:
history = pd.DataFrame(env.history)

In [98]:
pd.concat([df, history], axis=1).head(10)

Unnamed: 0,SSI,HPG,actions,shares,delta_vt,total_reward,total_profit
0,50.0,50.0,,"[0, 0]",0.0,0.0,0.0
1,50.124891,50.124891,,"[0, 0]",0.0,0.0,0.0
2,50.153262,50.153262,,"[0, 0]",0.0,0.0,0.0
3,50.222724,50.222724,,"[0, 0]",0.0,0.0,0.0
4,50.382129,50.382129,,"[0, 0]",0.0,0.0,0.0
5,50.515358,50.515358,"[2, 4]","[0, 200]",-60.0,-132.0,-60.0
6,50.445576,50.445576,"[0, 2]","[-200, 0]",-53.95628,-244.181882,-113.95628
7,50.513388,50.513388,"[0, 4]","[-200, 200]",-73.562455,-425.973033,-187.518735
8,50.502563,50.502563,"[2, 0]","[0, -200]",-180.0,-1253.973033,-367.518735
9,50.495178,50.495178,"[2, 4]","[0, 200]",-198.522921,-2240.722953,-566.041655
