In [1]:
import plotly.graph_objects as go
import torch
from tqdm.notebook import trange
from tradezoo.agent import Action, DecisionBatch

In [2]:
class Normal(torch.nn.Module):
    """Learnable normal distribution whose standard deviation has a floor.

    The module owns two one-element float parameters, ``loc`` (initialised
    to 0) and ``scale`` (initialised to 1). The distribution it produces
    uses ``uncertainty + |scale|`` as its standard deviation.
    """

    def __init__(self, uncertainty: float):
        """Create the parameters and remember the additive std offset.

        Args:
            uncertainty: Constant added to ``|scale|`` when the
                distribution is built.
        """
        super().__init__()
        self.uncertainty = uncertainty
        self.loc = torch.nn.Parameter(torch.zeros(1, dtype=torch.float))
        self.scale = torch.nn.Parameter(torch.ones(1, dtype=torch.float))

    def torch_distribution(self):
        """Build a ``torch.distributions.Normal`` from the current parameters.

        Returns:
            A normal distribution with mean ``loc`` and standard deviation
            ``uncertainty + |scale|``; it stays attached to the parameters,
            so gradients can flow back into them.
        """
        # abs() lets the raw parameter take either sign while keeping the
        # standard deviation non-negative.
        std = torch.abs(self.scale) + self.uncertainty
        return torch.distributions.Normal(loc=self.loc, scale=std)

# Smoke test: draw one sample from a freshly constructed distribution.
Normal(uncertainty=0.001).torch_distribution().sample()

tensor([-0.5903])

In [3]:
class MockAgent(torch.nn.Module):
    """Minimal policy made of two independent learnable normals.

    One ``Normal`` models the log mid price and the other the log spread;
    both receive the same ``uncertainty``.
    """

    def __init__(self, uncertainty: float):
        """Register one ``Normal`` submodule per decision component.

        Args:
            uncertainty: Forwarded unchanged to both ``Normal`` submodules.
        """
        super().__init__()
        # Assigning through setattr still goes through Module.__setattr__,
        # so the submodules are registered in this order.
        for component in ("log_mid_price", "log_spread"):
            setattr(self, component, Normal(uncertainty=uncertainty))

    def decide(self):
        """Package the current distributions into a ``DecisionBatch``.

        Returns:
            A ``DecisionBatch`` built from the log-mid-price and log-spread
            distributions.
        """
        mid_price_distribution = self.log_mid_price.torch_distribution()
        spread_distribution = self.log_spread.torch_distribution()
        return DecisionBatch(
            log_mid_price=mid_price_distribution,
            log_spread=spread_distribution,
        )


# Module-level policy instance; later cells read it by this name.
agent = MockAgent(uncertainty=0.001)
# Smoke test: sample from the untrained policy and show the first action.
agent.decide().sample()[0]


Action(log_mid_price=-0.3743041455745697, log_spread=1.7203329801559448)

In [4]:
def make_reward(action: Action, target_ask: float = 1.5, target_bid: float = 0.5):
    """Score an action by how close its quotes are to the target quotes.

    The reward is the negated sum of squared errors, so it is never
    positive and equals zero only when both quotes hit their targets.

    Args:
        action: Object exposing numeric ``ask`` and ``bid`` attributes.
        target_ask: Ask price that earns the maximum reward.
        target_bid: Bid price that earns the maximum reward.

    Returns:
        ``-((action.ask - target_ask) ** 2 + (action.bid - target_bid) ** 2)``.
    """
    ask_error = action.ask - target_ask
    bid_error = action.bid - target_bid
    return -(ask_error ** 2 + bid_error ** 2)

# Sanity check: evaluate the reward on one hand-picked action.
make_reward(Action(log_mid_price=0, log_spread=-1))

-0.07084390882368852

In [5]:
def make_loss(action: Action):
    """Return the score-function surrogate loss for a single action.

    The loss is ``-reward * log_probability``. Note that it reads the
    module-level ``agent`` rather than taking the policy as an argument,
    and asks it for a fresh decision on every call.

    Args:
        action: The action to score.

    Returns:
        The product of the negated reward and whatever
        ``log_probabilities([action])`` returns for the current policy.
    """
    reward = make_reward(action)
    log_probability = agent.decide().log_probabilities([action])
    return -reward * log_probability

# Sanity check: evaluate the loss on one hand-picked action.
make_loss(Action(log_mid_price=0, log_spread=-1))

tensor([-0.1657], grad_fn=<MulBackward0>)

In [6]:
# Hyperparameters of the fit, named so they are easy to find and tune.
SEED = 0
LEARNING_RATE = 5e-2
N_TRAINING_STEPS = 1024

# Seed torch's global RNG so a fresh Restart & Run All repeats the same fit.
# NOTE(review): assumes DecisionBatch.sample draws from torch's global
# generator -- confirm against tradezoo.
torch.manual_seed(SEED)

optimizer = torch.optim.Adam(agent.parameters(), lr=LEARNING_RATE)
for step_id in trange(N_TRAINING_STEPS):
    # One update per step: sample a single action from the current policy,
    # then take a gradient step on its surrogate loss.
    (action,) = agent.decide().sample()
    optimizer.zero_grad()
    make_loss(action).backward()
    optimizer.step()

  0%|          | 0/1024 [00:00<?, ?it/s]

In [7]:
# Number of actions drawn from the agent to estimate the quote distributions.
N_EXAMPLE_ACTIONS = 4096

example_actions = [agent.decide().sample()[0] for _ in range(N_EXAMPLE_ACTIONS)]

# One histogram per quoted price; the title and axis labels let the figure
# stand on its own when the notebook is skimmed.
go.Figure(
    data=[
        go.Histogram(
            name="Mid",
            x=[action.mid_price for action in example_actions],
        ),
        go.Histogram(
            name="Ask",
            x=[action.ask for action in example_actions],
        ),
        go.Histogram(
            name="Bid",
            x=[action.bid for action in example_actions],
        ),
    ],
    layout=dict(
        title=dict(text="Prices quoted by the agent after training"),
        xaxis=dict(title=dict(text="Price")),
        yaxis=dict(title=dict(text="Number of sampled actions")),
    ),
)