# DQN from scratch

In [1]:
from dataclasses import dataclass, replace
from functools import cache
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from tqdm.notebook import tqdm, trange
from typing import List

## Architecture

In [2]:
@dataclass(frozen=True)
class Observation:
    """Immutable snapshot of the agent's balances and the current best quotes."""

    cash_balance: float
    asset_balance: float
    best_ask: float
    best_bid: float

    @property
    def total_balance(self):
        """Cash balance plus asset balance."""
        cash, asset = self.cash_balance, self.asset_balance
        return cash + asset

    @property
    def array(self) -> np.ndarray:
        """Float32 vector ordered as (cash, asset, best ask, best bid)."""
        features = (
            self.cash_balance,
            self.asset_balance,
            self.best_ask,
            self.best_bid,
        )
        return np.asarray(features, dtype=np.float32)

    @property
    def batch(self) -> "ObservationBatch":
        """This observation wrapped in a batch of size one."""
        singleton = [self]
        return ObservationBatch(singleton)


@dataclass(frozen=True)
class ObservationBatch:
    """Ordered list of observations that can be turned into a single tensor."""

    observations: List[Observation]

    @property
    def tensor(self) -> torch.Tensor:
        """Stack every observation's array as one row of a tensor."""
        rows = [item.array for item in self.observations]
        stacked = np.stack(rows)
        return torch.tensor(stacked)


@dataclass(frozen=True)
class Action:
    """Immutable pair of prices quoted by the agent: an ask and a bid."""

    ask: float
    bid: float

    @property
    def array(self) -> np.ndarray:
        """Float32 vector ordered as (ask, bid)."""
        quotes = (self.ask, self.bid)
        return np.asarray(quotes, dtype=np.float32)

    @property
    def batch(self) -> "ActionBatch":
        """This action wrapped in a batch of size one."""
        singleton = [self]
        return ActionBatch(singleton)


@dataclass(frozen=True)
class ActionBatch:
    """Ordered list of actions that can be turned into a single tensor."""

    actions: List[Action]

    @property
    def tensor(self) -> torch.Tensor:
        """Stack every action's array as one row of a tensor."""
        rows = [item.array for item in self.actions]
        stacked = np.stack(rows)
        return torch.tensor(stacked)


@dataclass(frozen=True)
class Experience:
    """One recorded transition: observation, action taken, reward, next observation."""
    # Observation the action was chosen from.
    old_observation: Observation
    # Action the agent took on `old_observation`.
    action: Action
    # Scalar reward assigned to the transition.
    reward: float
    # Observation produced after the action was applied.
    new_observation: Observation


In [3]:
class Critic(torch.nn.Module):
    """Two-layer MLP that scores (observation, action) pairs with a scalar value."""

    def __init__(self):
        super().__init__()
        # 6 inputs = 4 observation features + 2 action features.
        layers = [
            torch.nn.Linear(6, 256),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(256, 1),
        ]
        self.network = torch.nn.Sequential(*layers)

    def forward(self, tensor):
        """Run the raw feature tensor through the network."""
        output = self.network(tensor)
        return output

    def evaluate(
        self, observations: ObservationBatch, actions: ActionBatch
    ) -> torch.Tensor:
        """Return one value per (observation, action) row of the two batches.

        The observation and action tensors are concatenated column-wise, fed
        through the network, and the trailing size-1 dimension is dropped.
        """
        features = torch.cat((observations.tensor, actions.tensor), dim=1)
        values = self(features)
        return values.squeeze(1)


@dataclass(frozen=False)
class Agent:
    """Epsilon-greedy agent that chooses actions by scoring random candidates.

    The action space is continuous, so the greedy action is approximated by
    drawing `num_tried_actions` random actions and keeping the one the critic
    values most.
    """

    # Network used to choose actions in `act`.
    critic: Critic
    # Second critic instance; this class only stores it for callers to use.
    target_critic: Critic
    # Number of random candidates scored per greedy decision.
    num_tried_actions: int
    # Probability that `act` returns a purely random action.
    random_action_probability: float
    # Discount factor; stored here for callers, not used by this class itself.
    discount: float

    def act(self, observation: Observation) -> Action:
        """Return a random action with probability `random_action_probability`,
        otherwise the best candidate according to `self.critic`."""
        if np.random.uniform() < self.random_action_probability:
            return self.random_action()
        return self.best_action(observation, critic=self.critic)

    def best_action(self, observation: Observation, critic: Critic) -> Action:
        """Return the highest-valued of `num_tried_actions` random actions.

        Args:
            observation: State the candidates are evaluated in.
            critic: Network used to score the candidates.

        Raises:
            ValueError: If `num_tried_actions` is smaller than 1.
        """
        if self.num_tried_actions < 1:
            raise ValueError("num_tried_actions must be at least 1")
        candidates = [self.random_action() for _ in range(self.num_tried_actions)]
        # Only the argmax index is needed, so skip building an autograd graph.
        with torch.no_grad():
            evaluations = critic.evaluate(
                ObservationBatch([observation] * self.num_tried_actions),
                ActionBatch(candidates),
            )
        best_index = torch.argmax(evaluations, dim=0).item()
        return candidates[best_index]

    @classmethod
    def random_action(cls) -> Action:
        """Return an action whose ask and bid are each uniform on [0, 2)."""
        # Draw ask first, then bid, as two separate scalar samples.
        ask = np.random.uniform(0, 2)
        bid = np.random.uniform(0, 2)
        return Action(ask=ask, bid=bid)

    def optimal(self, num_tried_actions=64):
        """Return a copy with exploration disabled.

        The copy shares the same critic objects; only `num_tried_actions` and
        `random_action_probability` (set to 0) differ.
        """
        return replace(
            self, num_tried_actions=num_tried_actions, random_action_probability=0
        )


## Problem formulation

In [4]:
def sample_balance():
    """Draw a balance uniformly from [0, 2)."""
    return np.random.uniform(0, 2)


def sample_price():
    """Draw a price uniformly from [0.5, 1.5)."""
    return np.random.uniform(0.5, 1.5)

In [5]:
def mock_experience(agent: Agent) -> Experience:
    """Simulate one transition from a randomly sampled starting state.

    Balances and quotes are sampled with `sample_balance` / `sample_price`,
    the agent picks an action, and its quotes are matched against the sampled
    market:

    * If the agent's ask is at or below the market's best bid, up to one unit
      of the asset is sold at the agent's ask price.
    * If the agent's bid is at or above the market's best ask, up to one unit
      (or as much as the cash balance affords) is bought at the agent's bid
      price. The sale is settled first, so its proceeds can fund the purchase.

    The next observation carries the updated balances and freshly sampled
    quotes. The reward is the sum of the new cash and asset balances.

    Args:
        agent: Agent whose `act` method chooses the action.

    Returns:
        The resulting `Experience`.
    """
    old_observation = Observation(
        cash_balance=sample_balance(),
        asset_balance=sample_balance(),
        best_ask=sample_price(),
        best_bid=sample_price(),
    )
    action = agent.act(old_observation)

    cash = old_observation.cash_balance
    assets = old_observation.asset_balance

    # Sell side: the agent's ask crosses the market's best bid.
    if action.ask <= old_observation.best_bid:
        sold_assets = min(assets, 1)
        cash += action.ask * sold_assets
        assets -= sold_assets

    # Buy side: the agent's bid crosses the market's best ask. The bid is
    # strictly positive here because sampled best asks are at least 0.5.
    if action.bid >= old_observation.best_ask:
        bought_assets = min(1, cash / action.bid)
        cash -= action.bid * bought_assets
        assets += bought_assets

    new_observation = Observation(
        cash_balance=cash,
        asset_balance=assets,
        best_ask=sample_price(),
        best_bid=sample_price(),
    )
    return Experience(
        old_observation=old_observation,
        action=action,
        reward=new_observation.total_balance,
        new_observation=new_observation,
    )


## Training

In [6]:
@dataclass(frozen=True)
class Experiment:
    """Outcome of one training run: the trained agent and its TD-error history."""

    # Agent holding the trained critic and its target copy.
    agent: Agent
    # One TD error per training step, in step order.
    td_errors: np.ndarray

    @classmethod
    def train(
        cls,
        learning_rate,
        num_tried_actions,
        num_steps=10_000,
        steps_per_target_update=1000,
        discount=0.99,
        random_action_decay=0.999,
    ):
        """Train a fresh agent with one-step TD learning on mocked experiences.

        Each step samples a single experience with `mock_experience`, forms
        the bootstrap target from the target critic, and takes one Adam step
        on the squared TD error of the online critic.

        Args:
            learning_rate: Adam learning rate for the online critic.
            num_tried_actions: Random candidates scored per greedy decision.
            num_steps: Number of training steps (one experience per step).
            steps_per_target_update: Interval, in steps, at which the target
                critic is overwritten with the online critic's weights.
            discount: Discount factor applied to the bootstrapped next value.
            random_action_decay: Factor multiplied into the agent's
                `random_action_probability` after every step.

        Returns:
            An `Experiment` with the trained agent and the per-step TD errors.
        """
        agent = Agent(
            critic=Critic(),
            target_critic=Critic(),
            num_tried_actions=num_tried_actions,
            random_action_probability=1,  # start fully exploratory
            discount=discount,
        )
        optimizer = torch.optim.Adam(agent.critic.parameters(), lr=learning_rate)
        td_errors = []
        for step_idx in trange(num_steps, desc="Training"):
            # Also fires at step 0, so both critics start from equal weights.
            if step_idx % steps_per_target_update == 0:
                agent.target_critic.load_state_dict(agent.critic.state_dict())

            experience = mock_experience(agent)

            # The bootstrap target is a constant w.r.t. the online critic, so
            # compute it without tracking gradients.
            with torch.no_grad():
                best_next_action = agent.best_action(
                    experience.new_observation, critic=agent.target_critic
                )
                next_value = agent.target_critic.evaluate(
                    experience.new_observation.batch, best_next_action.batch
                )
            target = experience.reward + agent.discount * next_value
            predicted = agent.critic.evaluate(
                experience.old_observation.batch, experience.action.batch
            )
            td_error = target - predicted

            # Single-element tensor, so backward() needs no explicit reduction.
            critic_loss = td_error ** 2
            optimizer.zero_grad()
            critic_loss.backward()
            optimizer.step()

            td_errors.append(td_error.item())
            agent.random_action_probability *= random_action_decay

        return cls(agent=agent, td_errors=np.array(td_errors))


# Run training with the default step count and target-update interval; the
# result is reused by the plotting cells below.
experiment = Experiment.train(learning_rate=1e-3, num_tried_actions=64)

Training:   0%|          | 0/10000 [00:00<?, ?it/s]

In [7]:
def plot_training_progress(td_errors):
    """Plot the absolute TD error per training step on a log-scaled y axis.

    Two traces are drawn: the raw absolute errors (faded) and their rolling
    mean over a 64-step window.
    """
    magnitudes = list(map(abs, td_errors))
    smoothed = pd.Series(magnitudes).rolling(window=64).mean()

    layout = dict(
        xaxis_title="Training step",
        yaxis_title="Absolute TD Error",
        yaxis_type="log",
    )
    raw_trace = go.Scatter(name="Actual", y=magnitudes, opacity=0.25)
    smoothed_trace = go.Scatter(name="Smoothed", y=smoothed)
    return go.Figure(layout=layout, data=[raw_trace, smoothed_trace])


# Visualize the TD errors recorded during the training run above.
plot_training_progress(experiment.td_errors)

In [8]:
def plot_action_value(agent: Agent, observation: Observation):
    """Plot the critic's value surface over a grid of (ask, bid) actions.

    Every combination of 50 evenly spaced asks and bids in [0, 2] is scored
    by `agent.critic` for the given observation. The surface is drawn together
    with a marker on the highest-valued grid action.

    Args:
        agent: Agent whose `critic` is queried.
        observation: Observation held fixed while the action is varied.

    Returns:
        A plotly figure with the value surface and the best grid action.
    """

    # Memoized so the best-action search below reuses the surface's values.
    @cache
    def action_value(action):
        # Plotting only needs the number, not an autograd graph.
        with torch.no_grad():
            return agent.critic.evaluate(
                observations=observation.batch, actions=action.batch
            ).item()

    asks = np.linspace(0, 2)
    bids = np.linspace(0, 2)
    # Rows follow bids (y) and columns follow asks (x), as Surface expects.
    values = [[action_value(Action(ask=ask, bid=bid)) for ask in asks] for bid in bids]
    best_action = max(
        (Action(ask=ask, bid=bid) for ask in asks for bid in bids), key=action_value
    )
    return go.Figure(
        layout=dict(
            scene=dict(xaxis_title="Ask", yaxis_title="Bid", zaxis_title="Action value")
        ),
        data=[
            go.Surface(
                name="Action value",
                x=asks,
                y=bids,
                z=values,
            ),
            go.Scatter3d(
                name="Best action",
                x=[best_action.ask],
                y=[best_action.bid],
                z=[action_value(best_action)],
            ),
        ],
    )


# Inspect the learned action values for one hand-picked observation:
# balances of 1 each, best ask 0.8, best bid 0.6.
plot_action_value(experiment.agent, observation=Observation(cash_balance=1, asset_balance=1, best_ask=0.8, best_bid=0.6))


In [9]:
def plot_action_distribution(agent: Agent, precision=1024):
    """Plot histograms of the asks and bids the agent picks on mocked states.

    Args:
        agent: Agent whose actions are sampled through `mock_experience`.
        precision: Number of experiences to sample.

    Returns:
        A plotly figure with one histogram for asks and one for bids.
    """
    asks = []
    bids = []
    for _ in trange(precision, desc="Sampling"):
        sampled = mock_experience(agent).action
        asks.append(sampled.ask)
        bids.append(sampled.bid)

    ask_histogram = go.Histogram(name="Ask", x=asks)
    bid_histogram = go.Histogram(name="Bid", x=bids)
    return go.Figure(data=[ask_histogram, bid_histogram])


# Histogram the greedy agent's choices: exploration disabled, 256 candidate
# actions per decision, 256 sampled experiences.
plot_action_distribution(experiment.agent.optimal(num_tried_actions=256), precision=256)

Sampling:   0%|          | 0/256 [00:00<?, ?it/s]