In [1]:
import numpy as np
import plotly.graph_objects as go
import torch
from tqdm.notebook import trange
from tradezoo.agent import Action, Actor, Agent, Critic, Observation
from tradezoo.trainer import Experience, Trainer

In [2]:
# Seed the global RNGs before anything random happens so that
# Restart Kernel -> Run All reproduces the same run. numpy's global RNG feeds
# mock_experience(); torch's global RNG is presumably what Actor/Critic use
# for weight initialisation (NOTE(review): confirm against tradezoo.agent).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters, named so they can be tuned in one place.
# The critic's learning rate is 100x the actor's.
ACTOR_LEARNING_RATE = 1e-5
CRITIC_LEARNING_RATE = 1e-3
DISCOUNT_FACTOR = 0.99

actor = Actor()
critic = Critic()
agent = Agent(
    actor=actor,
    actor_optimizer=torch.optim.Adam(actor.parameters(), lr=ACTOR_LEARNING_RATE),
    critic=critic,
    critic_optimizer=torch.optim.Adam(critic.parameters(), lr=CRITIC_LEARNING_RATE),
    discount_factor=DISCOUNT_FACTOR,
)

In [3]:
def mock_experience() -> Experience:
    """Fabricate one random single-step transition to feed the trainer.

    A random quote (best bid/ask placed around a random mid price) and random
    balances make up the old observation. A random action is then matched
    against that quote: an ask at or below the best bid sells one unit at the
    ask price, and a bid at or above the best ask buys one unit at the bid
    price. The quote itself is carried over unchanged into the new
    observation.

    Returns:
        An Experience whose reward is the new cash balance plus the new asset
        balance valued at the geometric mean of best ask and best bid.
    """
    # Quote seen before acting. The draw order of the random numbers is kept
    # fixed so that a seeded run yields the same samples.
    quote_mid = np.random.uniform(0.5, 1.5)
    quote_spread = np.random.uniform(0, 1)
    before = Observation(
        cash_balance=np.random.uniform(1, 4096),
        asset_balance=np.random.uniform(1, 4096),
        best_ask=quote_mid * (1 + quote_spread),
        best_bid=quote_mid / (1 + quote_spread),
    )

    # The action is drawn independently of the quote.
    action_mid = np.random.uniform(0.5, 1.5)
    action_spread = np.random.uniform(0, 1)
    action = Action(mid_price=action_mid, spread=action_spread)

    cash = before.cash_balance
    assets = before.asset_balance
    if action.ask <= before.best_bid:
        # Our ask crosses the best bid: sell one unit at our ask.
        cash += action.ask
        assets -= 1
    if action.bid >= before.best_ask:
        # Our bid crosses the best ask: buy one unit at our bid.
        cash -= action.bid
        assets += 1

    after = Observation(
        cash_balance=cash,
        asset_balance=assets,
        best_ask=before.best_ask,
        best_bid=before.best_bid,
    )

    # Value the asset position at the geometric mean of the quote.
    # NOTE(review): the reward is the absolute portfolio value after the step,
    # not the change caused by the action -- confirm this is intended.
    reference_price = (after.best_ask * after.best_bid) ** 0.5
    net_worth = after.cash_balance + after.asset_balance * reference_price

    return Experience(
        old_observation=before,
        action=action,
        reward=net_worth,
        new_observation=after,
    )

# Display one sample so its fields can be eyeballed before training on them.
mock_experience()

Experience(old_observation=Observation(cash_balance=1097.2500867620674, asset_balance=2190.102228643431, best_ask=1.3171265233011231, best_bid=1.0893193796978105), action=Action(mid_price=1.3027424224249726, spread=0.28471351779789655), reward=3720.5971451905893, new_observation=Observation(cash_balance=1097.2500867620674, asset_balance=2190.102228643431, best_ask=1.3171265233011231, best_bid=1.0893193796978105))

In [4]:
# Number of optimisation steps and number of mock experiences per step.
N_TRAINING_STEPS = 4096
BATCH_SIZE = 16

for _ in trange(N_TRAINING_STEPS):
    # A fresh random batch is generated for every training step.
    batch = [mock_experience() for _ in range(BATCH_SIZE)]
    Trainer.train_(agent, experiences=batch)

  0%|          | 0/4096 [00:00<?, ?it/s]

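TODO: Verify that training actually changed the agent's behaviour, e.g. by comparing the critic's value estimates and the actor's actions on a fixed set of mock experiences before and after training.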