<a href="https://colab.research.google.com/github/geekpradd/Reinforcement-Learning-Stock-Trader/blob/master/Env.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Module-level defaults consumed by StockEnv.
# Divisor applied to the cash balance when building an observation.
Balance_Normal = 10 ** 6
# Divisor applied to the per-stock share holdings when building an observation.
Shares_Normal = 10 ** 4
# Default starting cash passed to StockEnv.reset (identifier spelling kept
# as-is because other code refers to it by this name).
Intial_Balance = 10 ** 4
class StockEnv(gym.Env):
    """Gym environment for trading one or more stocks from OHLCV frames.

    Each data frame in ``df`` must expose ``Open``, ``High``, ``Low``,
    ``Close`` and ``Volume`` columns. Rows are read with
    ``.loc[current_step, ...]`` (label-based), so every frame is expected to
    carry a ``0..n-1`` integer index.

    Observation: a ``(7, num_stocks)`` array whose rows are
    ``Open, High, Low, Close`` (each divided by the running highest price),
    ``Volume`` (divided by ``MAX_Shares``), the cash balance (divided by
    ``Balance_Normal``) and the shares held (divided by ``Shares_Normal``).

    Action: one value per stock in ``[-1, 1]``; it is multiplied by
    ``MAX_Shares`` to obtain a share count. Positive values buy, negative
    values sell.

    Optional keyword overrides (accepted by ``__init__`` and ``reset``):
    ``Balance_Normal`` (alias ``Balance``), ``MAX_Shares`` (alias
    ``Max_Shares``), ``Broke_limit`` and ``Broke_rate``.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, df, train=True, num_stocks=1, **kwargs):
        """Build the environment.

        Args:
            df: sequence of price data frames, one per stock.
            train: stored on the instance; not otherwise used by this class.
            num_stocks: number of stocks; must equal ``len(df)``.
            **kwargs: optional overrides, see the class docstring.
        """
        super(StockEnv, self).__init__()

        self.train = train
        self.MAX_Shares = 2147483647
        self.Min_Brokerage = 30
        self.Brokerage_rate = 0.001
        self.num_stocks = num_stocks
        # Read the module-level defaults; they are never assigned locally
        # here, so they resolve as globals.
        self.Balance_Normal = Balance_Normal
        self.Shares_Normal = Shares_Normal
        self._apply_overrides(kwargs)

        self.dfs = df
        assert len(df) == num_stocks, "Size of database not equal to number of stocks"
        self.max_step = min([len(d.loc[:, 'Open']) for d in self.dfs])
        # One action component per stock: validate/_take_action index the
        # action as act[i] and add it to the (num_stocks,) holdings vector.
        self.action_space = spaces.Box(low=-1, high=1, shape=(self.num_stocks,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0, high=1, shape=(7, self.num_stocks), dtype=np.float32)

    def _apply_overrides(self, options):
        """Copy recognised keyword overrides from ``options`` onto ``self``."""
        accepted = (
            ("Balance_Normal", ("Balance_Normal", "Balance")),
            ("MAX_Shares", ("MAX_Shares", "Max_Shares")),
            ("Min_Brokerage", ("Broke_limit",)),
            ("Brokerage_rate", ("Broke_rate",)),
        )
        for attribute, keys in accepted:
            for key in keys:
                if key in options:
                    setattr(self, attribute, options[key])
                    break

    def _get_price(self):
        """Draw one price per stock uniformly between the current Open and Close."""
        return np.array([np.random.uniform(df.loc[self.current_step, "Open"], df.loc[self.current_step, "Close"]) for df in self.dfs])

    def _get_high(self):
        """Return the current row's High for every stock."""
        return np.array([df.loc[self.current_step, "High"] for df in self.dfs])

    def _set_high(self, high):
        """Fold ``high`` into the running per-stock maximum price."""
        self.highest_price = np.maximum(self.highest_price, high)

    def validate(self, act):
        """Check whether the share-count vector ``act`` can be executed.

        Args:
            act: per-stock share counts; positive buys, negative sells.

        Returns:
            ``(ok, cash_delta)``. ``ok`` is False when a sale exceeds the
            shares held or when the trade would make the balance negative;
            ``cash_delta`` is the signed change to the cash balance
            (sale proceeds minus purchase cost minus brokerage) and is
            ``0.0`` when ``ok`` is False.
        """
        cash_delta = 0.0
        for i in range(self.num_stocks):
            if act[i] < 0 and self.shares_held[i] < -act[i]:
                return False, 0.0
            # No trade in this stock means no brokerage (otherwise the
            # minimum fee would be charged for doing nothing).
            if act[i] != 0:
                cash_delta -= self._broke(self.current_price[i] * abs(act[i]))

        # Buying (act > 0) spends cash, selling (act < 0) brings cash in.
        cash_delta -= np.sum(self.current_price * act)
        if cash_delta + self.balance < 0:
            return False, 0.0
        return True, cash_delta

    def _observe(self):
        """Build the normalised ``(7, num_stocks)`` frame and an info dict.

        Returns:
            ``(frame, info)``; see the class docstring for the row layout.
        """
        frame = np.zeros((7, self.num_stocks))
        columns = ['Open', 'High', 'Low', 'Close', 'Volume']
        for i in range(self.num_stocks):
            frame[:5, i] = np.array([self.dfs[i].loc[self.current_step, column] for column in columns])
        frame[:4, :] = frame[:4, :] / self.highest_price
        frame[4, :] = frame[4, :] / self.MAX_Shares
        frame[5, :] = self.balance / self.Balance_Normal
        frame[6, :] = self.shares_held / self.Shares_Normal
        info = {
            'highest_price': self.highest_price,
            'current_price': self.current_price,
            'max_worth': self.max_net_worth,
            'broke_limit': self.Min_Brokerage,
            'broke_rate': self.Brokerage_rate,
            'max_shares': self.MAX_Shares,
        }
        return frame, info

    def reset(self, balance=Intial_Balance, initial_shares=None, **kwargs):
        """Restart the episode at step 0.

        Args:
            balance: starting cash.
            initial_shares: starting holdings, a scalar or one value per
                stock; ``None`` means no shares.
            **kwargs: optional overrides, see the class docstring.

        Returns:
            The ``(frame, info)`` pair produced by ``_observe``.
        """
        self._apply_overrides(kwargs)

        self.current_step = 0
        self.balance = balance
        if initial_shares is None:
            self.shares_held = np.zeros(self.num_stocks)
        else:
            # broadcast_to yields a read-only view; copy so that later
            # in-place updates of the holdings work.
            self.shares_held = np.broadcast_to(
                np.asarray(initial_shares, dtype=float), (self.num_stocks,)
            ).copy()
        self.current_price = self._get_price()
        # Net worth is computed from the holdings actually stored, so the
        # first step's reward is consistent with it.
        self.net_worth = self.balance + np.sum(self.shares_held * self.current_price)
        self.initial_worth = self.net_worth
        self.max_net_worth = self.net_worth
        self.highest_price = self._get_high()
        return self._observe()

    def _broke(self, amount):
        """Return the brokerage for a trade of value ``amount``.

        The fee is ``amount * Brokerage_rate`` with a floor of
        ``Min_Brokerage``.
        """
        return max(amount * self.Brokerage_rate, self.Min_Brokerage)

    def update(self, reward):
        """Add ``reward`` to the net worth and refresh the recorded maximum."""
        self.net_worth += reward
        self.max_net_worth = max(self.max_net_worth, self.net_worth)

    def _take_action(self, act):
        """Execute a normalised action.

        Args:
            act: per-stock values, scaled here by ``MAX_Shares`` into share
                counts.

        Returns:
            ``(reward, executed)``. An invalid trade returns
            ``(-5000, False)`` and leaves balance and holdings untouched;
            otherwise ``reward`` is the change in net worth.
        """
        act = np.asarray(act, dtype=float) * self.MAX_Shares
        self.current_price = self._get_price()
        is_valid, cash_delta = self.validate(act)
        if not is_valid:
            return -5000, False
        self._set_high(self._get_high())
        self.balance += cash_delta
        self.shares_held += act
        reward = self.balance + np.sum(self.shares_held * self.current_price) - self.net_worth
        self.update(reward)
        return reward, True

    def step(self, action):
        """Apply ``action`` and advance one row when the trade was executed.

        Returns:
            ``(obs, reward, done, info)``; ``done`` is True once the net
            worth is no longer positive.
        """
        reward, status = self._take_action(action)
        if status:
            # Wrap around before the last row; the max() guards against a
            # modulo by zero when only a single row is available.
            self.current_step = (self.current_step + 1) % max(self.max_step - 1, 1)
        done = bool(self.net_worth <= 0)
        obs, info = self._observe()
        return obs, reward, done, info

    def render(self, mode='human', close=False):
        """Print the current step, net worth and profit since reset."""
        profit = self.net_worth - self.initial_worth
        print('Step: {}'.format(self.current_step))
        print('Net Worth: {}'.format(self.net_worth))
        print('Profit: {}'.format(profit))
        
def create_stock_env(locations, train=True):
    """Load one CSV per stock and wrap them in a ``StockEnv``.

    Args:
        locations: iterable of CSV paths, one per stock; each file must
            contain a ``Date`` column.
        train: forwarded unchanged to ``StockEnv``.

    Returns:
        A ``StockEnv`` over the loaded frames, with ``num_stocks`` equal to
        ``len(locations)``.
    """
    # StockEnv reads rows by label (``.loc[step, ...]``). sort_values keeps
    # the original labels, so the index is rebuilt to make label order match
    # the chronological order produced by the sort.
    dfs = [
        pd.read_csv(location).sort_values('Date').reset_index(drop=True)
        for location in locations
    ]
    return StockEnv(dfs, train, len(locations))