In [1]:
import numpy as np
import torch
from tqdm.notebook import trange
from tradezoo.agent import Action, Actor, Agent, Critic, Observation
from tradezoo.game import Game, Client, SineWave, Trader
from tradezoo.market import Account, Market
from tradezoo.plots import balance_plot, decision_plot, trades_plot, uncertainty_plot, utility_plot
from tradezoo.trainer import Experience, Trainer

In [2]:
actor = Actor()
critic = Critic()
agent = Agent(
    actor=actor,
    actor_optimizer=torch.optim.Adam(actor.parameters(), lr=1e-5),
    critic=critic,
    critic_optimizer=torch.optim.Adam(critic.parameters(), lr=1e-3),
    discount_factor=0.99,
)

In [3]:
def mock_experience(rng=None) -> Experience:
    """Sample one synthetic single-step trading experience for pre-training.

    A random quote pair (best ask / best bid around a random mid price), random
    cash and asset balances, and a random agent action are drawn. If the
    action's ask is at or below the best bid, the trader sells one unit and
    receives its own ask; if the action's bid is at or above the best ask, it
    buys one unit and pays its own bid. The quotes themselves are unchanged in
    the new observation.

    Args:
        rng: Source of randomness exposing ``uniform(low, high)``, e.g. a
            seeded ``numpy.random.Generator``. Defaults to NumPy's legacy
            global RNG (``np.random``); pass a generator to make the sampled
            experiences reproducible independently of global state.

    Returns:
        An ``Experience`` whose reward is the value of the new balances: cash
        plus assets valued at the geometric mean of best ask and best bid.
        Because the quotes are built as ``mid * (1 + spread)`` and
        ``mid / (1 + spread)``, that geometric mean equals ``mid_price``.
    """
    if rng is None:
        rng = np.random
    # Draw order (mid, spread, cash, assets, action mid, action spread) is kept
    # fixed so a given seed always yields the same experience.
    mid_price = rng.uniform(0.5, 1.5)
    spread = rng.uniform(0, 1)
    old_observation = Observation(
        cash_balance=rng.uniform(1, 4096),
        # Lower bound of 1 guarantees the balance stays >= 0 after selling one unit.
        asset_balance=rng.uniform(1, 4096),
        best_ask=mid_price * (1 + spread),
        best_bid=mid_price / (1 + spread),
    )
    action = Action(mid_price=rng.uniform(0.5, 1.5), spread=rng.uniform(0, 1))
    new_cash_balance = old_observation.cash_balance
    new_asset_balance = old_observation.asset_balance
    if action.ask <= old_observation.best_bid:
        # Trader's ask crosses the best bid: sell one unit at the trader's ask.
        new_cash_balance += action.ask
        new_asset_balance -= 1
    if action.bid >= old_observation.best_ask:
        # Trader's bid crosses the best ask: buy one unit at the trader's bid.
        # NOTE(review): nothing prevents new_cash_balance from going negative
        # when action.bid exceeds the sampled cash (which can be as low as 1).
        new_cash_balance -= action.bid
        new_asset_balance += 1
    new_observation = Observation(
        cash_balance=new_cash_balance,
        asset_balance=new_asset_balance,
        best_ask=old_observation.best_ask,
        best_bid=old_observation.best_bid,
    )
    return Experience(
        old_observation=old_observation,
        action=action,
        reward=new_observation.cash_balance
        + new_observation.asset_balance
        * (new_observation.best_ask * new_observation.best_bid) ** 0.5,
        new_observation=new_observation,
    )

mock_experience()

Experience(old_observation=Observation(cash_balance=600.2938881477946, asset_balance=1230.927208288691, best_ask=2.0276772116546082, best_bid=0.7543291371121806), action=Action(mid_price=0.7339079684606244, spread=0.784517049926441), reward=2122.6358667141376, new_observation=Observation(cash_balance=600.2938881477946, asset_balance=1230.927208288691, best_ask=2.0276772116546082, best_bid=0.7543291371121806))

In [4]:
# Pre-train the agent offline on synthetic experiences before it trades: each
# step samples a fresh batch of mock experiences and hands it to
# Trainer.train_ (presumably an in-place update of `agent`, given the trailing
# underscore). The loop variable is deliberately not `_`: assigning `_` at the
# top level of an IPython session stops IPython from updating its
# `_`/`__`/`___` output-history variables.
N_PRETRAIN_STEPS = 20_000
PRETRAIN_BATCH_SIZE = 16

for pretrain_step in trange(N_PRETRAIN_STEPS):
    experiences = [mock_experience() for _ in range(PRETRAIN_BATCH_SIZE)]
    Trainer.train_(agent, experiences=experiences)

  0%|          | 0/20000 [00:00<?, ?it/s]

In [5]:
# Let the pre-trained agent trade against a single client for a fixed number
# of turns. The client's quotes are fixed multiples of a shared price process,
# 1 + 0.2 * SineWave(period=256), which oscillates around 1.
INITIAL_BALANCE = 2048
PRICE_PERIOD = 256
PRICE_AMPLITUDE = 0.2
CLIENT_ASK_MULTIPLIER = 1.1
CLIENT_BID_MULTIPLIER = 0.9
N_TURNS = 4096

trader_account = Account(cash_balance=INITIAL_BALANCE, asset_balance=INITIAL_BALANCE)
# Infinite balances: presumably so the client never runs out of cash or assets.
client_account = Account(cash_balance=float("inf"), asset_balance=float("inf"))

price_process = 1 + SineWave(period=PRICE_PERIOD) * PRICE_AMPLITUDE
client = Client(
    account=client_account,
    for_account=trader_account,
    ask_process=price_process * CLIENT_ASK_MULTIPLIER,
    bid_process=price_process * CLIENT_BID_MULTIPLIER,
)
trader = Trader(agent=agent, account=trader_account, client=client)

market = Market.from_accounts([trader_account, client_account])
game = Game.new(market=market, traders=[trader])

# One result per turn; consumed by the plotting cells below.
turn_results = [game.turn_() for _turn in trange(N_TURNS)]

  0%|          | 0/4096 [00:00<?, ?it/s]

In [6]:
# Plot the per-turn results of the simulated game (trades view from tradezoo.plots).
trades_plot(turn_results)

In [7]:
# Plot the per-turn results of the simulated game (balances view from tradezoo.plots).
balance_plot(turn_results)

In [8]:
# Inspect the trained agent itself rather than the game results
# (utility view from tradezoo.plots; presumably the critic's estimates).
utility_plot(agent)

In [9]:
# Inspect the trained agent itself rather than the game results
# (decision view from tradezoo.plots; presumably the actor's chosen actions).
decision_plot(agent)