# DQN from scratch

In [1]:
from dataclasses import dataclass, replace
from functools import cache
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
from tqdm.notebook import tqdm, trange
from typing import List

## Architecture

In [2]:
@dataclass(frozen=True)
class Observation:
    """Immutable snapshot of the agent's balances and the top-of-book quotes."""

    cash_balance: float
    asset_balance: float
    best_ask: float
    best_bid: float

    @property
    def total_balance(self):
        """Plain sum of the cash and asset balances (no price is applied)."""
        return self.cash_balance + self.asset_balance

    @property
    def batch(self):
        """This observation wrapped in a one-element ``ObservationBatch``."""
        return ObservationBatch([self])

    @property
    def array(self):
        """Fields as a float32 vector: cash, asset, best ask, best bid."""
        features = (
            self.cash_balance,
            self.asset_balance,
            self.best_ask,
            self.best_bid,
        )
        return np.array(features, dtype=np.float32)


@dataclass(frozen=True)
class ObservationBatch:
    """Ordered group of observations that is fed to the critic together."""

    observations: List[Observation]

    @property
    def tensor(self):
        """Stack every observation's ``array`` into one tensor, one row each."""
        rows = [observation.array for observation in self.observations]
        stacked = np.stack(rows)
        return torch.tensor(stacked)


@dataclass(frozen=True)
class Action:
    """Immutable pair of quotes chosen by the agent: an ask and a bid price."""

    ask: float
    bid: float

    @property
    def batch(self):
        """This action wrapped in a one-element ``ActionBatch``."""
        return ActionBatch([self])

    @property
    def array(self):
        """Quotes as a float32 vector ordered ``[ask, bid]``."""
        quotes = (self.ask, self.bid)
        return np.array(quotes, dtype=np.float32)


@dataclass(frozen=True)
class ActionBatch:
    """Ordered group of actions that is fed to the critic together."""

    actions: List[Action]

    @property
    def tensor(self):
        """Stack every action's ``array`` into one tensor, one row each."""
        rows = [action.array for action in self.actions]
        stacked = np.stack(rows)
        return torch.tensor(stacked)


@dataclass(frozen=True)
class Experience:
    """One immutable environment transition: (state, action, reward, next state)."""

    # Observation the agent saw before acting.
    old_observation: Observation
    # Action the agent chose for `old_observation`.
    action: Action
    # Scalar reward credited for this transition.
    reward: float
    # Observation produced after the action was applied.
    new_observation: Observation


In [3]:
class Critic(torch.nn.Module):
    """Action-value network: scores (observation, action) pairs with one scalar.

    The input width of 6 matches the 4 observation features concatenated with
    the 2 action features.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 256
        layers = [
            torch.nn.Linear(6, hidden_units),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_units, 1),
        ]
        self.network = torch.nn.Sequential(*layers)

    def forward(self, tensor):
        """Run the raw feature tensor through the network."""
        return self.network(tensor)

    def evaluate(
        self, observations: ObservationBatch, actions: ActionBatch
    ) -> torch.Tensor:
        """Score each observation/action pair of two equally long batches.

        Returns:
            A 1-D tensor holding one value per pair (the trailing unit
            dimension of the network output is squeezed away).
        """
        # Join observation and action features column-wise, row by row.
        features = torch.cat([observations.tensor, actions.tensor], dim=1)
        values = self(features)
        return values.squeeze(1)


@dataclass(frozen=False)
class Agent:
    """Epsilon-greedy agent that maximises a critic by random search.

    The class is deliberately mutable so the exploration probability can be
    adjusted while training.
    """

    # Critic used for acting.
    critic: Critic
    # Second critic instance; `act` never consults it, callers pass it to
    # `best_action` explicitly when they want it.
    target_critic: Critic
    # Number of random candidate actions scored per greedy decision.
    num_tried_actions: int
    # Probability that `act` returns a purely random action.
    random_action_probability: float
    # Discount factor; stored here for use by the training code.
    discount: float

    def act(self, observation: Observation) -> Action:
        """Return a random action with the exploration probability, else the
        best candidate according to ``self.critic``."""
        if np.random.uniform() < self.random_action_probability:
            return self.random_action()
        return self.best_action(observation, critic=self.critic)

    def best_action(self, observation: Observation, critic: Critic) -> Action:
        """Return the best of ``num_tried_actions`` random candidate actions.

        Args:
            observation: State for which the candidates are scored.
            critic: Network used to score the candidates.

        Raises:
            ValueError: If ``num_tried_actions`` is not positive.
        """
        if self.num_tried_actions < 1:
            raise ValueError(
                f"num_tried_actions must be positive, got {self.num_tried_actions}"
            )
        random_actions = [self.random_action() for _ in range(self.num_tried_actions)]
        # Only the index of the best candidate is needed, so no autograd graph
        # has to be recorded for this forward pass.
        with torch.no_grad():
            evaluations = critic.evaluate(
                ObservationBatch([observation] * self.num_tried_actions),
                ActionBatch(random_actions),
            )
        best_index = torch.argmax(evaluations, dim=0).item()
        return random_actions[best_index]

    @classmethod
    def random_action(cls) -> Action:
        """Draw an action whose ask and bid are each uniform on [0, 2)."""
        # The ask is drawn before the bid to keep the random stream order.
        ask = np.random.uniform(0, 2)
        bid = np.random.uniform(0, 2)
        return Action(ask=ask, bid=bid)

    def optimal(self, num_tried_actions=64):
        """Return a copy of this agent that never explores randomly.

        The copy is made with ``dataclasses.replace``, so it refers to the same
        critic objects as ``self``.
        """
        return replace(
            self, num_tried_actions=num_tried_actions, random_action_probability=0
        )


## Problem formulation

In [4]:
def sample_observation(cash_balance: float, asset_balance: float) -> "Observation":
    """Build an observation with the given balances and a random order book.

    The mid price is uniform on [0.5, 1.5) and the spread is uniform on
    [0.1, 0.3); the best ask and best bid sit half a spread above and below
    the mid price.
    """
    # Draw order (mid price first, then spread) is part of the behaviour
    # because both draws consume numpy's global random stream.
    mid_price = np.random.uniform(0.5, 1.5)
    half_spread = 0.5 * np.random.uniform(0.1, 0.3)
    return Observation(
        cash_balance=cash_balance,
        asset_balance=asset_balance,
        best_ask=mid_price + half_spread,
        best_bid=mid_price - half_spread,
    )

In [5]:
def mock_experience(agent: Agent) -> Experience:
    """Simulate one transition of the toy market for ``agent``.

    A fresh observation with one unit of cash and one unit of asset is
    sampled, the agent quotes an ask and a bid, and each quote that crosses
    the book trades exactly one unit at the agent's own price. The reward is
    the change in ``total_balance`` between the two observations.
    """
    before = sample_observation(cash_balance=1, asset_balance=1)
    action = agent.act(before)

    # Trades are always for one unit; balances are intentionally not capped,
    # so cash or assets may go negative.
    trade_size = 1
    cash = before.cash_balance
    assets = before.asset_balance

    # The agent's ask is at or below the best bid: it sells one unit.
    if action.ask <= before.best_bid:
        cash += action.ask * trade_size
        assets -= trade_size

    # The agent's bid is at or above the best ask: it buys one unit.
    if action.bid >= before.best_ask:
        cash -= action.bid * trade_size
        assets += trade_size

    after = sample_observation(cash_balance=cash, asset_balance=assets)
    reward = after.total_balance - before.total_balance
    return Experience(
        old_observation=before,
        action=action,
        reward=reward,
        new_observation=after,
    )


## Training

In [6]:
@dataclass(frozen=True)
class Experiment:
    """Result of one training run: the trained agent plus its history."""

    # Agent whose critic was trained.
    agent: Agent
    # Every experience generated during training, in order.
    experiences: List[Experience]
    # Signed TD error of every training step, in order.
    td_errors: List[float]

    @classmethod
    def train(
        cls, learning_rate, num_tried_actions, num_steps, steps_per_target_update
    ):
        """Train a fresh agent online, one gradient step per experience.

        Args:
            learning_rate: Learning rate of the Adam optimizer.
            num_tried_actions: Random candidates scored per greedy decision.
            num_steps: Number of experiences / gradient steps.
            steps_per_target_update: Interval, in steps, at which the target
                critic is overwritten with the live critic's weights.

        Returns:
            An ``Experiment`` with the agent, its experiences and TD errors.
        """
        agent = Agent(
            critic=Critic(),
            target_critic=Critic(),
            num_tried_actions=num_tried_actions,
            # Overwritten at the start of every step; 1.0 is the step-0 value,
            # so the agent is usable even when num_steps is 0.
            random_action_probability=1.0,
            discount=0.99,
        )
        optimizer = torch.optim.Adam(agent.critic.parameters(), lr=learning_rate)
        experiences = []
        td_errors = []
        for step_idx in trange(num_steps, desc="Training"):
            if step_idx % steps_per_target_update == 0:
                agent.target_critic.load_state_dict(agent.critic.state_dict())

            # Exploration decays as 1 / (step + 1).
            agent.random_action_probability = 1 / (step_idx + 1)
            experience = mock_experience(agent)
            experiences.append(experience)

            # The bootstrap target is treated as a constant, so it is computed
            # without recording an autograd graph.
            with torch.no_grad():
                best_next_action = agent.best_action(
                    experience.new_observation, critic=agent.target_critic
                )
                new_evaluation = agent.target_critic.evaluate(
                    experience.new_observation.batch, best_next_action.batch
                )
            old_evaluation = agent.critic.evaluate(
                experience.old_observation.batch, experience.action.batch
            )
            td_error = (
                experience.reward + agent.discount * new_evaluation - old_evaluation
            )
            td_errors.append(td_error.item())

            # Squared TD error, reduced to a scalar for backward().
            critic_loss = (td_error ** 2).sum()
            optimizer.zero_grad()
            critic_loss.backward()
            optimizer.step()

        return cls(agent=agent, experiences=experiences, td_errors=td_errors)


# Run the training at notebook level: 50,000 online steps with Adam at a
# learning rate of 1e-4, 64 random candidate actions per greedy decision, and
# a target-critic sync every 1,000 steps.
experiment = Experiment.train(
    learning_rate=1e-4,
    num_tried_actions=64,
    num_steps=50_000,
    steps_per_target_update=1000,
)


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

In [7]:
def plot_training_progress(td_errors):
    """Plot the magnitude of the TD error per training step on a log axis.

    Args:
        td_errors: Signed TD errors, one per training step.

    Returns:
        A plotly figure with the raw magnitudes (faded) and their 64-step
        rolling mean.
    """
    magnitudes = list(map(abs, td_errors))
    rolling_mean = pd.Series(magnitudes).rolling(window=64).mean()

    layout = dict(
        xaxis_title="Training step",
        yaxis_title="Absolute TD Error",
        yaxis_type="log",
    )
    raw_trace = go.Scatter(name="Actual", y=magnitudes, opacity=0.25)
    smoothed_trace = go.Scatter(name="Smoothed", y=rolling_mean)
    return go.Figure(layout=layout, data=[raw_trace, smoothed_trace])


# Build the TD-error figure for the training run stored in `experiment`.
plot_training_progress(experiment.td_errors)

In [8]:
def plot_earnings(experiences):
    """Plot the change in total balance produced by each experience.

    Args:
        experiences: Experiences in the order they were generated.

    Returns:
        A plotly figure with the per-experience earnings (faded) and their
        256-step rolling mean.
    """
    earnings = []
    for experience in experiences:
        balance_before = experience.old_observation.total_balance
        balance_after = experience.new_observation.total_balance
        earnings.append(balance_after - balance_before)

    smoothed = pd.Series(earnings).rolling(window=256).mean()
    layout = dict(xaxis_title="Training step", yaxis_title="Earnings per experience")
    traces = [
        go.Scatter(name="Actual", opacity=0.25, y=earnings),
        go.Scatter(name="Smoothed", y=smoothed),
    ]
    return go.Figure(layout=layout, data=traces)


# Build the earnings figure from the experiences collected during training.
plot_earnings(experiment.experiences)


In [11]:
def plot_actions(agent: Agent, cash_balance=1, asset_balance=1):
    """Plot the agent's quoted ask and bid over a grid of order-book states.

    The grid spans best ask and best bid values from 0.5 to 1.5. Grid points
    where the best ask lies below the best bid are not shown to the agent and
    are filled with NaN instead.

    Args:
        agent: Agent whose ``act`` method is queried at every grid point.
        cash_balance: Cash balance used in every queried observation.
        asset_balance: Asset balance used in every queried observation.

    Returns:
        A plotly figure with two 3-D subplots: the quoted ask next to the
        best-bid plane ("Sale"), and the quoted bid next to the best-ask
        plane ("Purchase").
    """
    best_asks = np.linspace(0.5, 1.5)
    best_bids = np.linspace(0.5, 1.5)

    def choose(best_ask, best_bid):
        # A best ask below the best bid is not a valid book: leave a gap.
        if best_ask < best_bid:
            return Action(np.nan, np.nan)
        return agent.act(
            Observation(
                cash_balance=cash_balance,
                asset_balance=asset_balance,
                best_bid=best_bid,
                best_ask=best_ask,
            )
        )

    # Rows follow best_bid (y axis), columns follow best_ask (x axis).
    actions = []
    for best_bid in tqdm(best_bids):
        actions.append([choose(best_ask, best_bid) for best_ask in best_asks])

    ask_surface = [[action.ask for action in row] for row in actions]
    bid_surface = [[action.bid for action in row] for row in actions]
    # Reference planes: a quote trades once it crosses these levels.
    sell_boundary = [[best_bid for _ in best_asks] for best_bid in best_bids]
    buy_boundary = [[best_ask for best_ask in best_asks] for _ in best_bids]

    figure = make_subplots(
        rows=1,
        cols=2,
        specs=[[dict(is_3d=True)] * 2],
        subplot_titles=["Sale", "Purchase"],
    )
    figure.update_layout(
        title="Chosen actions",
        scene=dict(
            xaxis_title="Best ask",
            yaxis_title="Best bid",
            zaxis_title="Ask price",
        ),
        scene2=dict(
            xaxis_title="Best ask",
            yaxis_title="Best bid",
            zaxis_title="Bid price",
        ),
    )

    # (trace name, colorscale, z values, subplot column), in drawing order.
    surfaces = [
        ("Agent's choice", "thermal", ask_surface, 1),
        ("Sell boundary", "gray", sell_boundary, 1),
        ("Agent's choice", "aggrnyl", bid_surface, 2),
        ("Buy boundary", "gray", buy_boundary, 2),
    ]
    for name, colorscale, z_values, column in surfaces:
        surface = go.Surface(
            name=name,
            colorscale=colorscale,
            showscale=False,
            x=best_asks,
            y=best_bids,
            z=z_values,
        )
        figure.add_trace(surface, 1, column)
    return figure


# Plot the trained agent's policy with random exploration switched off,
# scoring 1024 random candidate actions per decision.
plot_actions(experiment.agent.optimal(num_tried_actions=1024))


  0%|          | 0/50 [00:00<?, ?it/s]

In [10]:
def plot_action_value(agent: Agent, observation: Observation):
    """Plot the critic's value of every (ask, bid) action for one observation.

    Args:
        agent: Agent whose ``critic`` is evaluated.
        observation: Observation held fixed while the action is varied.

    Returns:
        A plotly figure with the action-value surface over a 50 x 50 grid of
        asks and bids in [0, 2], plus a marker at the grid action with the
        highest value.
    """

    @cache
    def action_value(action):
        # Memoised because every grid action is scored twice: once for the
        # surface and once while searching for the maximum.
        # Only the scalar value is needed, so no autograd graph is recorded.
        with torch.no_grad():
            return agent.critic.evaluate(
                observations=observation.batch, actions=action.batch
            ).item()

    asks = np.linspace(0, 2)
    bids = np.linspace(0, 2)
    # Rows follow bids (y axis), columns follow asks (x axis).
    values = [[action_value(Action(ask=ask, bid=bid)) for ask in asks] for bid in bids]
    best_action = max(
        (Action(ask=ask, bid=bid) for ask in asks for bid in bids), key=action_value
    )
    return go.Figure(
        layout=dict(
            scene=dict(xaxis_title="Ask", yaxis_title="Bid", zaxis_title="Action value")
        ),
        data=[
            go.Surface(
                name="Action value",
                x=asks,
                y=bids,
                z=values,
            ),
            go.Scatter3d(
                name="Best action",
                x=[best_action.ask],
                y=[best_action.bid],
                z=[action_value(best_action)],
            ),
        ],
    )


# Inspect the learned action values for one fixed book: best ask 0.8, best bid
# 0.6, with one unit each of cash and asset.
plot_action_value(experiment.agent, observation=Observation(cash_balance=1, asset_balance=1, best_ask=0.8, best_bid=0.6))