In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import numpy as np
import plotly
import plotly.graph_objects as go
import torch
from tqdm.notebook import trange
from tradezoo.agent import Actor, Agent, Critic, Observation
from tradezoo.game import Game, Client, SineWave, Trader
from tradezoo.market import Account, Market
from tradezoo.trainer import Trainer

<IPython.core.display.Javascript object>

In [3]:
def make_agent(actor_lr=2e-5, critic_lr=1e-3, discount_factor=0.99):
    """Build an ``Agent`` from a fresh actor and critic, each with its own Adam optimizer.

    The defaults are the hyperparameters this notebook was originally written
    with, so a bare ``make_agent()`` call behaves exactly as before.

    Args:
        actor_lr: Learning rate of the Adam optimizer over the actor's parameters.
        critic_lr: Learning rate of the Adam optimizer over the critic's parameters.
        discount_factor: Value forwarded unchanged to ``Agent(discount_factor=...)``.

    Returns:
        A new ``Agent`` wrapping a newly constructed ``Actor`` and ``Critic``.
    """
    actor = Actor()
    critic = Critic()
    return Agent(
        actor=actor,
        actor_optimizer=torch.optim.Adam(actor.parameters(), lr=actor_lr),
        critic=critic,
        critic_optimizer=torch.optim.Adam(critic.parameters(), lr=critic_lr),
        discount_factor=discount_factor,
    )


# Fix the global torch and numpy random state so that "Restart & Run All"
# starts every run from the same point.
# NOTE(review): if tradezoo draws randomness from another source (e.g. Python's
# `random` module), seed that here as well.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

NUM_TURNS = 20_000  # number of trainer turns played at the end of this cell
REPLAY_BUFFER_CAPACITY = 16
BATCH_SIZE = 16  # same value as the replay buffer capacity

num_traders = 1
# Every trader starts with 4096 of cash and 4096 of the asset.
trader_accounts = [
    Account(cash_balance=4096, asset_balance=4096) for _ in range(num_traders)
]
# One client account per trader, with infinite cash and asset balances.
client_accounts = [
    Account(cash_balance=float("inf"), asset_balance=float("inf"))
    for _ in range(num_traders)
]
# Shared reference price: 1 plus a SineWave of period 256 scaled by 0.2.
price_process = 1 + SineWave(period=256) * 0.2
traders = [
    Trader(
        agent=make_agent(),
        account=trader_account,
        client=Client(
            account=client_account,
            for_account=trader_account,
            # Client quotes are the reference price scaled by 1.1 (ask) and 0.9 (bid).
            ask_process=price_process * 1.1,
            bid_process=price_process * 0.9,
        ),
    )
    for trader_account, client_account in zip(trader_accounts, client_accounts)
]
game = Game.new(
    market=Market.from_accounts(trader_accounts + client_accounts),
    traders=traders,
)
trainer = Trainer.new(
    game=game,
    replay_buffer_capacity=REPLAY_BUFFER_CAPACITY,
    batch_size=BATCH_SIZE,
)
# Keep every turn's result; the plotting cells below read from this list.
turn_results = [trainer.turn_() for _ in trange(NUM_TURNS)]

  0%|          | 0/20000 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

In [4]:
# Price overview on a log axis: executed trades, the client's best quotes as
# seen in each observation, and the trader's mean quotes.
palette = plotly.colors.qualitative.Plotly

# One (turn index, price) pair per executed trade; turns without trades add nothing.
trade_points = [
    (turn_number, trade.price)
    for turn_number, turn_result in enumerate(turn_results)
    for trade in turn_result.trades
]
# Means of the mid-price and spread distributions recorded with each turn's decision.
mean_mid_prices = [
    turn_result.decision_batch.mid_price.torch_distribution.mean.item()
    for turn_result in turn_results
]
mean_spread_factors = [
    1 + turn_result.decision_batch.spread.torch_distribution.mean.item()
    for turn_result in turn_results
]

go.Figure(
    data=[
        go.Scatter(
            name="Trade",
            mode="markers",
            marker=dict(color=palette[0]),
            x=[turn_number for turn_number, _ in trade_points],
            y=[price for _, price in trade_points],
        ),
        go.Scatter(
            name="Client ask",
            marker=dict(color=palette[1]),
            y=[turn_result.observation.best_ask for turn_result in turn_results],
        ),
        go.Scatter(
            name="Client bid",
            marker=dict(color=palette[1]),
            y=[turn_result.observation.best_bid for turn_result in turn_results],
        ),
        go.Scatter(
            name="Mean trader ask",
            marker=dict(color=palette[2]),
            # Ask is the mean mid price multiplied by (1 + mean spread).
            y=[
                mid * factor
                for mid, factor in zip(mean_mid_prices, mean_spread_factors)
            ],
        ),
        go.Scatter(
            name="Mean trader bid",
            marker=dict(color=palette[2]),
            # Bid is the mean mid price divided by (1 + mean spread).
            y=[
                mid / factor
                for mid, factor in zip(mean_mid_prices, mean_spread_factors)
            ],
        ),
    ],
    layout=dict(
        xaxis_title="Turn number",
        yaxis_title="Price",
        yaxis_type="log",
    ),
)

<IPython.core.display.Javascript object>

In [5]:
# Per-turn `underlying_stds` of the two decision components, on a log axis.
std_traces = [
    go.Scatter(
        name=label,
        y=[
            getattr(turn_result.decision_batch, component).underlying_stds.item()
            for turn_result in turn_results
        ],
    )
    # (legend label, attribute name on decision_batch)
    for label, component in [("Mid-price", "mid_price"), ("Spread", "spread")]
]

go.Figure(
    data=std_traces,
    layout=dict(
        xaxis_title="Turn number",
        yaxis_title="Underlying standard deviation",
        yaxis_type="log",
    ),
)

<IPython.core.display.Javascript object>

In [6]:
first_agent = traders[0].agent

# The critic is only read for plotting here, so run its forward passes without
# autograd tracking. The estimates use whatever parameters the agent has when
# this cell runs, not the parameters it had at each recorded turn.
with torch.no_grad():
    expected_future_rewards = [
        # Scaled by (1 - discount_factor); presumably this turns a discounted-sum
        # estimate into a per-turn figure comparable with the "Reward" trace
        # -- TODO confirm.
        first_agent.evaluate(turn_result.observation.batch).item()
        * (1 - first_agent.discount_factor)
        for turn_result in turn_results
    ]

go.Figure(
    layout=dict(
        xaxis_title="Turn number",
        yaxis_title="Balance",
    ),
    data=[
        go.Scatter(
            name="Cash",
            y=[turn_result.observation.cash_balance for turn_result in turn_results],
        ),
        go.Scatter(
            name="Asset",
            y=[turn_result.observation.asset_balance for turn_result in turn_results],
        ),
        go.Scatter(
            name="Reward",
            y=[turn_result.reward for turn_result in turn_results],
        ),
        go.Scatter(
            name="Expected future reward",
            y=expected_future_rewards,
        ),
    ],
)

<IPython.core.display.Javascript object>

In [7]:
# Critic output over a 16x16 grid of balances, with the quotes held fixed at
# best_ask=2 and best_bid=0.5.
cash_balances = np.linspace(0, 256, 16)
asset_balances = np.linspace(0, 256, 16)
# Rows follow asset_balances (y axis) and columns follow cash_balances (x axis),
# matching the z[row][column] layout go.Surface uses for the given x and y.
observations = [
    [
        Observation(
            cash_balance=cash_balance,
            asset_balance=asset_balance,
            best_ask=2,
            best_bid=0.5,
        )
        for cash_balance in cash_balances
    ]
    for asset_balance in asset_balances
]

first_agent = traders[0].agent
# Plot-only forward passes: no gradients are needed, so skip autograd tracking.
with torch.no_grad():
    utilities = [
        [first_agent.evaluate(observation.batch)[0].item() for observation in row]
        for row in observations
    ]

go.Figure(
    layout=dict(
        scene=dict(
            xaxis_title="Cash balance",
            yaxis_title="Asset balance",
            zaxis_title="Utility",
        )
    ),
    data=[
        go.Surface(
            x=cash_balances,
            y=asset_balances,
            z=utilities,
        )
    ],
)

<IPython.core.display.Javascript object>

In [8]:
# Actor's mid price (exp of `underlying_means`) over a 16x16 grid of client
# quotes, with both balances held fixed at 2048.
# NOTE(review): the grid starts at 0 and covers every (ask, bid) pair, so it
# includes zero prices and crossed quotes (bid above ask) -- confirm the agent
# is meant to be probed there.
best_asks = np.linspace(0, 4, 16)
best_bids = np.linspace(0, 4, 16)
# Rows follow best_bids (y axis) and columns follow best_asks (x axis),
# matching the z[row][column] layout go.Surface uses for the given x and y.
observations = [
    [
        Observation(
            cash_balance=2048,
            asset_balance=2048,
            best_ask=best_ask,
            best_bid=best_bid,
        )
        for best_ask in best_asks
    ]
    for best_bid in best_bids
]

first_agent = traders[0].agent
# Plot-only forward passes: no gradients are needed, so skip autograd tracking.
with torch.no_grad():
    mid_prices = [
        [
            first_agent.decide(observation.batch).mid_price.underlying_means.exp().item()
            for observation in row
        ]
        for row in observations
    ]

go.Figure(
    layout=dict(
        scene=dict(
            xaxis_title="Best ask",
            yaxis_title="Best bid",
            zaxis_title="Price",
        )
    ),
    data=[
        go.Surface(
            name="Mid price",
            x=best_asks,
            y=best_bids,
            z=mid_prices,
        ),
    ],
)

<IPython.core.display.Javascript object>