In [1]:
from dataclasses import dataclass
import torch
from tqdm.notebook import trange
from tradezoo.agent import DecisionBatch, ObservationBatch
from tradezoo.game import Game, Client, SineWave, Trader
from tradezoo.market import Account, Market
from tradezoo.plots import balance_plot, trades_plot

In [2]:
@dataclass(frozen=True)
class MockAgent:
    """Mock agent with a fixed, observation-independent quoting policy.

    ``decide`` ignores the *contents* of the observations and always returns
    Normal distributions centred on constant values. By default the log
    mid-price is centred at 0 (presumably a mid-price of about 1) and the log
    spread is centred at -3 (presumably a spread of about e^-3 ≈ 0.05), both
    with a small fixed noise scale. ``evaluate`` returns a crude closed-form
    value estimate.

    Attributes:
        discount_factor: Discount used by ``evaluate``; must be in [0, 1).
        log_mid_price_loc: Mean of the log mid-price distribution.
        log_spread_loc: Mean of the log spread distribution.
        noise_scale: Standard deviation shared by both distributions.
    """

    discount_factor: float
    log_mid_price_loc: float = 0.0
    log_spread_loc: float = -3.0
    noise_scale: float = 0.01

    def __post_init__(self) -> None:
        # evaluate() divides by (1 - discount_factor); 1 would yield inf and
        # values outside [0, 1) are not meaningful discounts.
        if not 0.0 <= self.discount_factor < 1.0:
            raise ValueError(
                f"discount_factor must be in [0, 1), got {self.discount_factor!r}"
            )

    def _normal(self, loc: float, batch_size: int) -> torch.distributions.Normal:
        """Build a Normal with one (identical) component per observation."""
        return torch.distributions.Normal(
            loc=torch.full((batch_size,), loc, dtype=torch.float),
            scale=torch.full((batch_size,), self.noise_scale, dtype=torch.float),
        )

    def decide(self, observation_batch: ObservationBatch) -> DecisionBatch:
        """Return the fixed quoting distributions, one entry per observation.

        The batch dimension matches ``len(observation_batch.observations)``,
        consistent with ``evaluate``. For a single observation this is
        identical to the previous shape-``[1]`` output.
        """
        batch_size = len(observation_batch.observations)
        return DecisionBatch(
            log_mid_price=self._normal(self.log_mid_price_loc, batch_size),
            log_spread=self._normal(self.log_spread_loc, batch_size),
        )

    def evaluate(self, observation_batch: ObservationBatch) -> torch.Tensor:
        """Return one value estimate per observation.

        The estimate is ``log(cash_balance + asset_balance) / (1 - discount_factor)``.
        This is the discounted sum (geometric series) of receiving that
        log-wealth on every future step. Assets are counted one-for-one with
        cash, i.e. at a price of 1, which is a deliberate mock simplification.
        Zero total wealth gives ``-inf``; negative wealth gives ``nan``.
        """
        total_wealth = torch.tensor(
            [
                observation.cash_balance + observation.asset_balance
                for observation in observation_batch.observations
            ],
            dtype=torch.float,
        )
        return total_wealth.log() / (1 - self.discount_factor)

# Discount factor consumed by MockAgent.evaluate (named instead of an inline magic number).
DISCOUNT_FACTOR = 0.99
agent = MockAgent(discount_factor=DISCOUNT_FACTOR)

In [3]:
# Simulation parameters (previously inline magic numbers).
SEED = 0
INITIAL_BALANCE = 1024  # trader's starting cash and asset balance
PRICE_AMPLITUDE = 0.25  # multiplier applied to the SineWave before exponentiating
PRICE_PERIOD = 16  # SineWave period (presumably measured in turns — confirm)
CLIENT_MARKUP = 1.1  # client asks at price * markup and bids at price / markup
NUM_TURNS = 1024

# MockAgent.decide returns torch Normal distributions. Seeding torch's global RNG
# makes the sampled quotes, and therefore the plots below, reproducible on
# Restart & Run All. This assumes the trader samples with torch's global RNG.
torch.manual_seed(SEED)

trader_account = Account(cash_balance=INITIAL_BALANCE, asset_balance=INITIAL_BALANCE)
# Infinite balances so the client side never runs out of cash or assets.
client_account = Account(cash_balance=float("inf"), asset_balance=float("inf"))
# Reference price = exp(PRICE_AMPLITUDE * SineWave), i.e. it varies multiplicatively around 1.
price_process = (PRICE_AMPLITUDE * SineWave(period=PRICE_PERIOD)).exp()
trader = Trader(
    agent=agent,
    account=trader_account,
    client=Client(
        account=client_account,
        for_account=trader_account,
        # Ask 10% above the reference price, bid ~9.1% below it (1 / 1.1).
        ask_process=price_process * CLIENT_MARKUP,
        bid_process=price_process / CLIENT_MARKUP,
    ),
)
game = Game.new(
    market=Market.from_accounts([trader_account, client_account]),
    traders=[trader],
)
turn_results = [game.turn_() for _ in trange(NUM_TURNS)]

  0%|          | 0/1024 [00:00<?, ?it/s]

In [4]:
# Plot the trades from the simulated turns. The call is left as the cell's last
# expression so Jupyter rich-displays whatever trades_plot returns.
trades_plot(turn_results)

In [5]:
# Plot balances over the simulated turns (presumably the trader's cash/asset
# balances — confirm against tradezoo.plots). The call is left as the cell's
# last expression so Jupyter rich-displays whatever balance_plot returns.
balance_plot(turn_results)