In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import plotly.graph_objects as go
import torch
from tqdm.notebook import trange
from tradezoo.agent import Actor, Agent, Critic
from tradezoo.game import Game, Constant, MarketMaker, Trader
from tradezoo.market import Account, Market
from tradezoo.trainer import Trainer

<IPython.core.display.Javascript object>

In [3]:
def make_agent(actor_lr=1e-4, critic_lr=1e-3, discount_factor=0.99):
    """Build a new Agent with freshly constructed actor and critic networks.

    Each network gets its own Adam optimizer. The defaults reproduce the
    values this notebook originally hard-coded, so ``make_agent()`` behaves
    exactly as before; pass keyword arguments to try other settings.

    Args:
        actor_lr: Learning rate for the actor's Adam optimizer.
        critic_lr: Learning rate for the critic's Adam optimizer.
        discount_factor: Value forwarded to ``Agent(discount_factor=...)``.

    Returns:
        An ``Agent`` wired to the new actor, critic, and their optimizers.
    """
    actor = Actor()
    critic = Critic()
    return Agent(
        actor=actor,
        actor_optimizer=torch.optim.Adam(actor.parameters(), lr=actor_lr),
        critic=critic,
        critic_optimizer=torch.optim.Adam(critic.parameters(), lr=critic_lr),
        discount_factor=discount_factor,
    )


def _new_trader():
    """Create one Trader with a fresh agent, account, and market maker."""
    return Trader(
        agent=make_agent(),
        # Every trader starts with 256 cash and 256 stock.
        account=Account(cash_balance=256, stock_balance=256),
        # Market maker built from a constant ask process (2) and a constant
        # bid process (0.5).
        market_maker=MarketMaker.inexhaustible(
            ask_process=Constant(value=2),
            bid_process=Constant(value=0.5),
        ),
    )


num_traders = 1
traders = [_new_trader() for _ in range(num_traders)]

# The market is built from the traders' own accounts only.
trader_accounts = [trader.account for trader in traders]
game = Game.new(
    market=Market.from_accounts(trader_accounts),
    traders=traders,
)

trainer = Trainer.new(game=game, replay_buffer_capacity=64)

<IPython.core.display.Javascript object>

In [4]:
# Play 4096 turns, recording every turn result for later plotting. After each
# turn, the agent that just acted takes one training step, but only once its
# replay buffer holds at least one full batch of experiences.
turn_results = []
batch_size = 16
for step_id in trange(4096):
    turn_result = trainer.turn_()
    agent = turn_result.trader.agent
    turn_results.append(turn_result)

    num_experiences = len(trainer.replay_buffers[agent].experiences)
    if num_experiences < batch_size:
        # Not enough experiences collected yet to sample a batch.
        continue
    trainer.train_step_(agent=agent, batch_size=batch_size)

  0%|          | 0/4096 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

In [6]:
def trades_with_turns(buyer: Account):
    """List every recorded trade bought by ``buyer`` with its turn index.

    Scans the global ``turn_results`` in order and keeps the trades whose
    ``buyer`` attribute is the given account (identity comparison).

    Returns:
        A list of ``(turn_idx, trade)`` pairs. A turn that produced several
        matching trades contributes one pair per trade, so nothing is lost.
    """
    return [
        (turn_idx, trade)
        for turn_idx, turn_result in enumerate(turn_results)
        for trade in turn_result.trades
        if trade.buyer is buyer
    ]


def trades(buyer: Account):
    """Map turn index to a trade bought by ``buyer``.

    Kept for backward compatibility. Because the result is keyed by turn
    index, a turn with several matching trades keeps only the last one; use
    ``trades_with_turns`` when every trade matters.
    """
    return dict(trades_with_turns(buyer))


def trade_scatter(name: str, buyer: Account):
    """Build a marker trace of trade price against turn number for ``buyer``.

    The trade list is computed once and reused for both axes, and it keeps
    every trade of a turn instead of collapsing them by turn index.
    """
    points = trades_with_turns(buyer)
    return go.Scatter(
        name=name,
        mode="markers",
        x=[turn_idx for turn_idx, _ in points],
        y=[trade.price for _, trade in points],
    )


go.Figure(
    layout=dict(
        xaxis_title="Turn number",
        yaxis_title="Trade price",
        yaxis_type="log",
    ),
    data=[
        # Trades in which the first trader's own account is the buyer.
        trade_scatter("Buys", traders[0].account),
        # Trades in which that trader's market maker account is the buyer,
        # presumably meaning the trader was the seller; verify against the
        # market implementation.
        trade_scatter("Sells", traders[0].market_maker.account),
    ],
)

<IPython.core.display.Javascript object>