In [1]:
import numpy as np
import torch
from tqdm.notebook import trange
from tradezoo.agent import Action, Agent, Critic, Observation
from tradezoo.game import Game, Client, SineWave, Trader
from tradezoo.market import Account, Market
from tradezoo.plots import balance_plot, td_error_plot, training_plot, trades_plot
from tradezoo.trainer import Experience, Trainer

In [2]:
class MockActor(torch.nn.Module):
    """Input-independent stand-in for the agent's actor network.

    Holds one learnable 4-element parameter vector and, for a batch of
    observations, returns that vector once per batch row, i.e. a tensor of
    shape ``(batch_size, 4)``. The observation values themselves are ignored;
    only ``observations.shape[0]`` is read.
    """

    def __init__(self):
        super().__init__()
        # Fixed starting output. How Agent interprets each of the four entries
        # is defined in tradezoo.agent — TODO confirm before changing them.
        initial_outputs = torch.tensor([0, 0.01, -3, 0.01])
        self.outputs = torch.nn.Parameter(initial_outputs)  # requires_grad defaults to True

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        batch_size = observations.shape[0]
        # (4,) -> (batch_size, 4); repeat prepends the missing leading dimension.
        return self.outputs.repeat(batch_size, 1)


mock_actor = MockActor()

In [3]:
class MockOptimizer:
    """Do-nothing replacement for a torch optimizer.

    Provides ``step`` and ``zero_grad`` with no effect at all, so whatever
    model it is paired with (here, the mock actor) never has its parameters
    updated or its gradients cleared by it.
    """

    def step(self):
        """Intentionally a no-op."""
        return None

    def zero_grad(self):
        """Intentionally a no-op."""
        return None


mock_optimizer = MockOptimizer()

In [4]:
# Seed the RNGs this notebook draws from so Restart & Run All reproduces the
# results: np.random is used by mock_experience, and torch randomness is
# presumably used by Critic initialisation and by agent.decide(...).sample().
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

critic = Critic()
agent = Agent(
    actor=mock_actor,
    # No-op optimizer: the mock actor's parameters are never stepped.
    actor_optimizer=mock_optimizer,
    critic=critic,
    critic_optimizer=torch.optim.Adam(critic.parameters(), lr=1e-3),
    discount_factor=0.99,
    uncertainty=1e-3,
)

In [5]:
# The trader starts with a finite stake; the client's balances are infinite,
# so the client side is effectively never constrained by cash or assets.
trader_account = Account(cash_balance=1024, asset_balance=1024)
client_account = Account(cash_balance=np.inf, asset_balance=np.inf)

# Reference price: exp of a scaled SineWave(period=16).
price_process = (0.25 * SineWave(period=16)).exp()

# The client quotes around the reference price with a multiplicative factor
# of 1.1: ask = price * 1.1, bid = price / 1.1.
trader_client = Client(
    account=client_account,
    for_account=trader_account,
    ask_process=price_process * 1.1,
    bid_process=price_process / 1.1,
)
trader = Trader(agent=agent, account=trader_account, client=trader_client)

In [6]:
def mock_experience(agent) -> Experience:
    """Synthesize one transition without running the real market.

    Draws a random time step and random account balances, asks ``agent`` for
    an action, then settles that action against the client's quotes using a
    simplified fill model of at most one asset per side:

    * the ask fills when ``action.ask <= best_bid``: up to one held asset is
      sold at ``action.ask``;
    * the bid fills when ``action.bid >= best_ask``: up to one asset is bought
      at ``action.bid``, limited by the cash on hand after any sale.

    Quotes come from the module-level ``trader.client``. Balance ranges are
    scaled from the current ``trader_account`` balances. The reward is
    ``trader.utility`` of the post-trade observation.

    Args:
        agent: Agent whose ``decide(...)`` distribution supplies the action.

    Returns:
        An ``Experience`` holding the old observation, the action, the reward
        and the new observation (quotes taken at ``step + 1``).
    """
    step = np.random.randint(0, 4096)
    client = trader.client
    old_observation = Observation(
        # Lower bound 2 keeps sampled balances strictly positive.
        cash_balance=np.random.uniform(2, trader_account.cash_balance * 2),
        asset_balance=np.random.uniform(2, trader_account.asset_balance * 2),
        best_ask=client.ask_process.value(step),
        best_bid=client.bid_process.value(step),
    )
    (action,) = agent.decide(old_observation.batch).sample()

    new_cash_balance = old_observation.cash_balance
    new_asset_balance = old_observation.asset_balance
    if action.ask <= old_observation.best_bid:
        # Mirror the buy side: never sell more assets than are held. With the
        # sampled lower bound of 2 this cap does not bind, but it keeps the
        # model valid if that bound changes.
        sold_assets = min(1, new_asset_balance)
        new_cash_balance += action.ask * sold_assets
        new_asset_balance -= sold_assets
    if action.bid >= old_observation.best_ask:
        # Spend at most the cash on hand, including any sale proceeds above.
        bought_assets = min(1, new_cash_balance / action.bid)
        new_cash_balance -= action.bid * bought_assets
        new_asset_balance += bought_assets

    new_observation = Observation(
        cash_balance=new_cash_balance,
        asset_balance=new_asset_balance,
        best_ask=client.ask_process.value(step + 1),
        best_bid=client.bid_process.value(step + 1),
    )
    return Experience(
        old_observation=old_observation,
        action=action,
        reward=trader.utility(new_observation),
        new_observation=new_observation,
    )


# 4096 training iterations, each on a fresh batch of 32 synthetic experiences.
# The generator is lazy, so each batch is sampled only after the previous
# Trainer.train_ call has run (which presumably updates the agent), i.e. the
# same interleaving as a nested list comprehension.
experience_batches = (
    [mock_experience(agent) for _ in range(32)]
    for _ in trange(4096)
)
train_results = [
    Trainer.train_(agent, experiences=batch)
    for batch in experience_batches
]


  0%|          | 0/4096 [00:00<?, ?it/s]

In [7]:
# Visualize the list of per-iteration results returned by Trainer.train_ in the
# training cell; left as the cell's last expression so its return value renders.
training_plot(train_results)

In [8]:
# The market holds both the trader's and the client's accounts; only the
# trader is driven by the agent.
market = Market.from_accounts([trader_account, client_account])
game = Game.new(market=market, traders=[trader])

# Play 1024 turns, keeping every turn's result for the plots that follow.
turn_results = [game.turn_() for _turn in trange(1024)]

  0%|          | 0/1024 [00:00<?, ?it/s]

In [9]:
# Plot trade activity derived from the per-turn results of the game run.
trades_plot(turn_results)

In [10]:
# Plot the critic's temporal-difference error across game turns (presumably
# recorded in each turn result — see tradezoo.plots.td_error_plot).
td_error_plot(turn_results)

In [11]:
# Plot account balances over the game, as extracted from the per-turn results
# by tradezoo.plots.balance_plot.
balance_plot(turn_results)