In [1]:
from drl4dypm.agent import *
from drl4dypm.env import *
import time, copy

from matplotlib import pyplot as plt


# Base agent

## setup

In [2]:
# ---- environment settings ----
trading_days = 252                       # handed to the environment as num_steps
asset_names = [
    'AAL', 'AMZN', 'GOOG', 'FB', 'TSLA', 'CVS', 'FDX',
]
k = 10                                   # middle axis of the (1, k, -1) state tensors
cost_bps = 1e-3
path_to_data = 'data/stock_price.csv'

# ---- agent settings ----
num_assets = len(asset_names)
state_dim = num_assets * 3
action_dim = num_assets + 1

# The actor learns at half the critic's rate.
critic_learning_rate = 0.1**3
actor_learning_rate = 0.5 * critic_learning_rate

# Actor and critic share the same architecture; each gets its own
# independent dict (and its own 'fc' list) so one can be tuned without
# touching the other.
network_params = {
    role: {
        'lstm': {'hidden_dim': 20, 'num_layers': 2},
        'fc': [128, 64, 32],
        'dropout': 0.5,
    }
    for role in ('actor', 'critic')
}

# ---- training settings ----
max_episode = 10
# The agent is updated only for episode indices strictly greater than
# this value; -1 therefore enables updates from episode 0 onward.
min_episode_to_train = -1

In [3]:
# trading environment
env = TradingEnvironment(num_steps=trading_days, 
                         asset_names=asset_names, 
                         k=k, 
                         cost_bps=cost_bps,
                         agent_names=['base'],
                         path_to_data=path_to_data
                        )

In [4]:
# agent
# RL agent: dimensions and window length first, then the network
# architecture, then the actor/critic learning rates.
agent = RLAgent(
    state_dim, action_dim, k,
    network_params,
    actor_learning_rate, critic_learning_rate,
    replay_capacity=1000,
)



## training

In [None]:
# Reset the environment ahead of the first training episode.
# NOTE(review): this cell has no execution count, and the training loop
# calls env.reset() itself at the end of every episode — confirm whether
# this initial reset is actually required on a fresh kernel.
env.reset()

In [5]:
# Training loop for the 'base' agent.
#
# For each of `max_episode` episodes:
#   1. step the environment until it signals the end, storing every
#      transition in the agent's replay storage;
#   2. call agent.update() once if the episode index exceeds
#      `min_episode_to_train`;
#   3. print one table row with the episode reward, a bias-corrected
#      exponential moving average of the reward, the latest losses and
#      timing information;
#   4. reset the environment for the next episode.

reward_sm = 0          # exponential moving average of episode reward (decay 0.9)
critic_loss = []       # per-episode critic loss history
actor_loss = []        # per-episode actor loss history

# Reported as the losses until the first agent.update() has run.
actor_loss_i = np.inf
critic_loss_i = np.inf

# Print a table row every `log_every` episodes.
log_every = 1

elp = 0                # cumulative elapsed seconds across episodes
# perf_counter() is monotonic, so the measured intervals are not affected
# by system clock adjustments the way time.time() differences can be.
start_time = time.perf_counter()

cols = ['episode','reward','reward_sm','critic_loss','actor_loss','elp','elp_sum']
line = '|'.join([f'{col:<12}' for col in cols])
print(line)


def _as_state_tensor(observation):
    """Return `observation` as a float32 tensor viewed as (1, k, -1)."""
    return torch.tensor(observation, dtype=torch.float32).view(1, k, -1)


for e in range(max_episode):
    state, end = env.init_step()
    # Not read anywhere in this cell; kept in case other cells rely on it.
    last_actions = env.simulator.last_actions

    while not end:
        # Rollout only: no gradients are needed while collecting experience.
        with torch.no_grad():
            # generate action
            action = agent.get_action(_as_state_tensor(state[1]))

            # execute action and move to next step
            actions = {'base': action.numpy().reshape(-1)}
            rewards, next_state, end = env.take_step(actions, state[0])

            # store experience
            agent.store_transition({
                'state': {'state': _as_state_tensor(state[1])},
                'action': {'action': torch.tensor(actions['base'], dtype=torch.float32).view(1,-1)},
                'next_state': {'state': _as_state_tensor(next_state[1])},
                'reward': rewards['base'],
                'terminal': end
            })

        state = next_state
        last_actions = actions

    # update ddpg
    if e > min_episode_to_train:
        actor_loss_i, critic_loss_i = agent.update(return_loss=True)

    actor_loss.append(actor_loss_i)
    critic_loss.append(critic_loss_i)

    # From here on `rewards` holds the episode totals, not per-step rewards.
    rewards = env.get_total_rewards()
    reward_sm = 0.9*reward_sm + 0.1*rewards['base']
    # Bias correction: reward_sm starts at 0, so divide by (1 - 0.9^(e+1)).
    # This corrected value is what appears under the 'reward_sm' column.
    reward_corr = reward_sm/(1-0.9**(e+1))

    # Timing is tracked every episode, independent of the logging interval,
    # and a single clock reading is used so no time is lost between the
    # measurement and the restart of the timer.
    now = time.perf_counter()
    elp_episode = now - start_time
    elp += elp_episode
    start_time = now

    if e % log_every == 0:
        row_values = [rewards['base'], reward_corr,
                      critic_loss[-1], actor_loss[-1],
                      elp_episode, elp]
        line = f'{e:<12}|' + '|'.join([f'{col:<12.4f}' for col in row_values])
        print(line)

    # reset environment
    env.reset()
    
        

The framework is not responsible for any un-matching device issues caused by this operation.[0m


episode     |reward      |reward_sm   |critic_loss |actor_loss  |elp         |elp_sum     


The framework is not responsible for any un-matching device issues caused by this operation.[0m
The framework is not responsible for any un-matching device issues caused by this operation.[0m
The framework is not responsible for any un-matching device issues caused by this operation.[0m


0           |9.9145      |9.9145      |0.0010      |-0.1437     |2.7792      |2.7792      
1           |9.9166      |9.9156      |0.0003      |-0.1204     |2.6235      |5.4027      
2           |9.8152      |9.8786      |0.0002      |-0.1012     |2.6453      |8.0480      
3           |9.8457      |9.8690      |0.0003      |-0.1003     |2.6152      |10.6632     
4           |9.8618      |9.8672      |0.0002      |-0.1049     |2.6152      |13.2784     
5           |9.4830      |9.7852      |0.0003      |-0.1123     |2.6257      |15.9041     
6           |10.0350     |9.8331      |0.0001      |-0.1176     |2.6221      |18.5262     
7           |10.0262     |9.8670      |0.0003      |-0.1235     |2.6349      |21.1610     
8           |9.9147      |9.8748      |0.0001      |-0.1262     |2.7098      |23.8708     
9           |9.8325      |9.8683      |0.0000      |-0.1290     |2.6881      |26.5589     
