In [10]:
from datetime import timedelta
from dateutil import parser

from coinbase_train import constants as c
from coinbase_train import utils
from coinbase_train.environment import MockEnvironment
from coinbase_train.model import build_actor, build_critic
from coinbase_train.train import create_agent

In [11]:
# Architecture / training hyper-parameters as raw keyword arguments; the next cell
# wraps them in `utils.HyperParameters`. Actor and critic keep separate keys so they
# can be tuned independently even though they currently share the same values.
hyper_params = dict(
    actor_account_funds_attention_dim=100,
    actor_account_funds_hidden_dim=100,
    actor_account_orders_hidden_dim=100,
    actor_account_orders_attention_dim=100,
    actor_matches_attention_dim=100,
    actor_matches_hidden_dim=100,
    actor_merged_branch_attention_dim=100,
    actor_merged_branch_hidden_dim=100,
    actor_order_book_num_filters=100,
    actor_order_book_kernel_size=4,
    actor_orders_attention_dim=100,
    actor_orders_hidden_dim=100,
    batch_size=1,
    critic_account_funds_attention_dim=100,
    critic_account_funds_hidden_dim=100,
    critic_account_orders_hidden_dim=100,
    critic_account_orders_attention_dim=100,
    critic_matches_attention_dim=100,
    critic_matches_hidden_dim=100,
    critic_merged_branch_attention_dim=100,
    critic_merged_branch_hidden_dim=100,
    critic_order_book_num_filters=100,
    critic_order_book_kernel_size=4,
    critic_orders_attention_dim=100,
    critic_orders_hidden_dim=100,
    critic_output_branch_hidden_dim=100,
    num_time_steps=c.NUM_TIME_STEPS)

# Train on one window of data, then test on the window that immediately follows it.
# The test window is derived from the end of the train window, so the two can never
# overlap (leaking training data into evaluation) or leave an unexamined gap.
# NOTE(review): these are naive datetimes; presumably the data is stored in UTC — confirm.
TRAIN_START_DT = parser.parse('2019-01-17 05:19:59.264')
WINDOW_LENGTH = timedelta(days=1)


def _make_environment_configs(start_dt, window_length=WINDOW_LENGTH):
    """Build raw environment-config kwargs for the window [start_dt, start_dt + window_length].

    Args:
        start_dt: first timestamp of the window.
        window_length: timedelta spanned by the window.

    Returns:
        A dict suitable for ``utils.EnvironmentConfigs(**...)``.
    """
    return dict(
        end_dt=start_dt + window_length,
        initial_btc=0,
        initial_usd=10000,
        num_episodes=2,
        start_dt=start_dt,
        time_delta=timedelta(seconds=10))


train_environment_configs = _make_environment_configs(TRAIN_START_DT)
test_environment_configs = _make_environment_configs(train_environment_configs['end_dt'])

In [12]:
def _as_config(value, config_cls):
    """Wrap a raw kwargs dict in ``config_cls``; return already-converted values unchanged.

    This cell rebinds the same names it reads. Without the guard, a second run would
    try to ``**``-unpack an already-built config object instead of the original dict.

    Args:
        value: either the raw kwargs dict from the previous cell or an already-built config.
        config_cls: config class to construct from the dict's keyword arguments.

    Returns:
        A ``config_cls`` instance (or ``value`` itself if it was not a dict).
    """
    if isinstance(value, dict):
        return config_cls(**value)
    return value


hyper_params = _as_config(hyper_params, utils.HyperParameters)
test_environment_configs = _as_config(test_environment_configs, utils.EnvironmentConfigs)
train_environment_configs = _as_config(train_environment_configs, utils.EnvironmentConfigs)

In [13]:
# Both networks take the same branch hyper-parameters. On `hyper_params` each one is
# stored as `<network>_<name>` (e.g. `actor_matches_hidden_dim`, `critic_matches_hidden_dim`),
# so they are collected by prefix rather than spelled out twice.
_BRANCH_PARAM_NAMES = (
    'account_funds_attention_dim',
    'account_funds_hidden_dim',
    'account_orders_attention_dim',
    'account_orders_hidden_dim',
    'matches_attention_dim',
    'matches_hidden_dim',
    'merged_branch_attention_dim',
    'merged_branch_hidden_dim',
    'order_book_kernel_size',
    'order_book_num_filters',
    'orders_attention_dim',
    'orders_hidden_dim',
)


def _prefixed_kwargs(params, prefix, names=_BRANCH_PARAM_NAMES):
    """Return ``{name: getattr(params, prefix + '_' + name)}`` for every name in ``names``.

    Args:
        params: object exposing the prefixed hyper-parameters as attributes.
        prefix: network prefix, e.g. ``'actor'`` or ``'critic'``.
        names: un-prefixed keyword names expected by the network builder.

    Raises:
        AttributeError: if ``params`` lacks one of the prefixed attributes.
    """
    return {name: getattr(params, prefix + '_' + name) for name in names}


actor = build_actor(**_prefixed_kwargs(hyper_params, 'actor'))

# The critic takes every actor argument plus one extra for its output branch.
critic = build_critic(**_prefixed_kwargs(
    hyper_params, 'critic', _BRANCH_PARAM_NAMES + ('output_branch_hidden_dim',)))

train_environment = MockEnvironment(
    end_dt=train_environment_configs.end_dt,
    initial_usd=train_environment_configs.initial_usd,
    initial_btc=train_environment_configs.initial_btc,
    num_workers=c.NUM_DATABASE_WORKERS,
    num_time_steps=hyper_params.num_time_steps,
    start_dt=train_environment_configs.start_dt,
    time_delta=train_environment_configs.time_delta)

agent = create_agent(
    actor=actor,
    critic=critic,
    hyper_params=hyper_params)

# Upper bound on steps per training episode, derived from the train window.
# NOTE(review): the logged run ("Training for 36 steps" with num_episodes=2) implies this
# evaluated to 18, far fewer than a one-day window at 10 s steps would suggest (8640).
# Verify utils.calc_nb_max_episode_steps (e.g. use of timedelta.seconds vs total_seconds()).
nb_max_episode_steps = utils.calc_nb_max_episode_steps(
    end_dt=train_environment_configs.end_dt,
    start_dt=train_environment_configs.start_dt,
    time_delta=train_environment_configs.time_delta)

In [14]:
# Total training budget in environment steps. This is a step budget, not an episode
# count: episodes may end before nb_max_episode_steps (the logged run finished 3
# episodes for num_episodes=2 because the first one stopped after 3 steps).
train_step_budget = train_environment_configs.num_episodes * nb_max_episode_steps

# NOTE(review): the log format matches keras-rl's episode logger. If this is keras-rl,
# `log_interval` only takes effect with verbose=1 and is ignored at verbose=2 — confirm.
# TODO: `test_environment_configs` is not used in the cells above; no held-out evaluation
# is run on the test window yet.
history = agent.fit(
    env=train_environment,
    nb_steps=train_step_budget,
    nb_max_episode_steps=nb_max_episode_steps,
    log_interval=nb_max_episode_steps,
    verbose=2)

Training for 36 steps ...
  3/36: episode: 1, duration: 53.688s, episode steps: 3, steps per second: 0, episode reward: 0.000, mean reward: 0.000 [0.000, 0.000], mean action: 0.275 [-0.153, 0.527], mean observation: 0.000 [0.000, 0.000], loss: 0.000980, mean_q: 0.075135
 18/36: episode: 2, duration: 45.986s, episode steps: 15, steps per second: 0, episode reward: 0.000, mean reward: 0.000 [0.000, 0.000], mean action: 0.297 [-0.273, 0.684], mean observation: 0.000 [0.000, 0.000], loss: 0.001504, mean_q: 0.105653
 36/36: episode: 3, duration: 57.273s, episode steps: 18, steps per second: 0, episode reward: 0.000, mean reward: 0.000 [0.000, 0.000], mean action: 0.250 [-0.159, 0.633], mean observation: 0.000 [0.000, 0.000], loss: 0.001963, mean_q: 0.098116
done, took 156.953 seconds
