# DQN from scratch

In [1]:
from dataclasses import dataclass
import numpy as np
from observation import Observation
import pandas as pd
import plotly.graph_objects as go
import torch
from tqdm.notebook import tqdm, trange

## Architecture

In [2]:
@dataclass(frozen=True)
class Action:
    """An (ask, bid) price pair chosen by the agent.

    Frozen, so instances are immutable and hashable.
    """

    ask: float  # ask price component of the action
    bid: float  # bid price component of the action

    @property
    def tensor(self):
        """Return the action as a float32 tensor of shape (1, 2), ordered (ask, bid)."""
        row = torch.tensor([self.ask, self.bid], dtype=torch.float32)
        return row.unsqueeze(0)


@dataclass(frozen=True)
class Experience:
    """One immutable transition: observation, action taken, reward, next observation."""

    old_observation: Observation  # state in which the action was chosen
    action: Action  # action the agent took in ``old_observation``
    reward: float  # scalar reward received for ``action``
    new_observation: Observation  # state observed after the action

In [3]:
class Critic(torch.nn.Module):
    """State-action value network: maps observation+action features to one value.

    Args:
        input_size: Width of the concatenated feature row. The default 6 is the
            observation feature count plus the 2 action columns (ask, bid).
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, input_size: int = 6, hidden_size: int = 64):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Map a (batch, input_size) tensor to (batch, 1) values."""
        return self.network(tensor)

    def evaluate(self, observation: Observation, action: Action) -> torch.Tensor:
        """Return the value of taking ``action`` in ``observation``.

        The observation and action tensors are joined column-wise (dim 1), and
        all size-1 dimensions of the result are squeezed away. For single-row
        inputs the result is a 0-d tensor.
        """
        features = torch.cat([observation.tensor, action.tensor], dim=1)
        return self(features).squeeze()


@dataclass(frozen=True)
class Agent:
    """Epsilon-greedy Q-learning agent over a continuous (ask, bid) action space.

    There is no closed-form argmax over continuous actions. "Best" actions are
    approximated by sampling ``num_tried_actions`` uniformly random actions and
    keeping the one the online critic scores highest.

    Raises:
        ValueError: if ``num_tried_actions`` < 1 (no candidates to choose from).
    """

    critic: Critic  # online network: scores candidate actions, trained by ``train_``
    critic_optimizer: torch.optim.Optimizer  # stepped in ``train_``; expected to hold ``critic``'s params
    target_critic: Critic  # lagged copy of ``critic``, refreshed by ``update_target_critic_``
    num_tried_actions: int  # candidate actions sampled per approximate argmax
    random_action_probability: float  # epsilon: probability that ``act`` explores at random
    discount: float  # TD discount factor applied to the next-state value
    min_price: float = 0.0  # lower bound of the uniform range for ask/bid prices
    max_price: float = 2.0  # upper bound of the uniform range for ask/bid prices

    def __post_init__(self):
        # Fail at construction instead of with an opaque ``max()`` error mid-training.
        if self.num_tried_actions < 1:
            raise ValueError(
                f"num_tried_actions must be >= 1, got {self.num_tried_actions}"
            )

    def act(self, observation: Observation) -> Action:
        """Choose an action epsilon-greedily.

        With probability ``random_action_probability`` return a random action;
        otherwise return ``best_action(observation)``.
        """
        if np.random.uniform() < self.random_action_probability:
            return self.random_action()
        return self.best_action(observation)

    def best_action(self, observation: Observation) -> Action:
        """Return the highest-valued of ``num_tried_actions`` random candidate actions."""
        candidates = [self.random_action() for _ in range(self.num_tried_actions)]
        # Action selection needs no gradients; ``no_grad`` avoids building an
        # autograd graph for every candidate evaluation.
        with torch.no_grad():
            # ``Critic.evaluate`` concatenates ``observation.tensor`` with a (1, 2)
            # action tensor along dim 1, so the observation tensor has a leading
            # dim of 1. It can be broadcast to one row per candidate so that all
            # candidates are scored in a single forward pass.
            observations = observation.tensor.expand(len(candidates), -1)
            actions = torch.cat([candidate.tensor for candidate in candidates], dim=0)
            values = self.critic(torch.cat([observations, actions], dim=1)).reshape(-1)
        # ``argmax`` returns the first maximal index, matching ``max``'s tie-breaking.
        return candidates[int(values.argmax())]

    def random_action(self) -> Action:
        """Return an action with ask and bid drawn uniformly from [min_price, max_price)."""
        ask = np.random.uniform(self.min_price, self.max_price)
        bid = np.random.uniform(self.min_price, self.max_price)
        return Action(ask=ask, bid=bid)

    def train_(self, experience: Experience) -> float:
        """Take one squared-TD-error gradient step on ``critic``; return the TD error.

        The next action is selected with the online critic and valued with the
        target critic (Double-DQN style). The TD target is treated as a
        constant. Every transition bootstraps from the next state; there is no
        terminal-state masking.
        """
        best_next_action = self.best_action(experience.new_observation)
        with torch.no_grad():
            next_value = self.target_critic.evaluate(
                experience.new_observation, best_next_action
            )
        td_target = experience.reward + self.discount * next_value
        td_error = td_target - self.critic.evaluate(
            experience.old_observation, experience.action
        )

        critic_loss = td_error ** 2
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return td_error.item()

    def update_target_critic_(self):
        """Copy the online critic's weights into the target critic (in place)."""
        self.target_critic.load_state_dict(self.critic.state_dict())

## Training

In [4]:
def _random_observation() -> Observation:
    """Return an Observation whose four fields are independent standard-normal draws."""
    return Observation(
        cash_balance=np.random.normal(),
        asset_balance=np.random.normal(),
        best_ask=np.random.normal(),
        best_bid=np.random.normal(),
    )


def mock_experience(
    agent: Agent, target_ask: float = 1.5, target_bid: float = 0.5
) -> Experience:
    """Generate one synthetic transition for exercising ``agent``.

    Both observations are pure noise, independent of the action. The only
    learnable signal is therefore the reward: the negative squared distance
    from (ask, bid) to (target_ask, target_bid), maximal (0) at the target.

    Args:
        agent: Chooses the action via ``agent.act`` on the old observation.
        target_ask: Ask price that maximizes the reward.
        target_bid: Bid price that maximizes the reward.
    """
    old_observation = _random_observation()
    action = agent.act(old_observation)
    # Drawn after ``agent.act`` to keep the original order of RNG consumption.
    new_observation = _random_observation()
    reward = -((action.ask - target_ask) ** 2 + (action.bid - target_bid) ** 2)
    return Experience(
        old_observation=old_observation,
        action=action,
        reward=reward,
        new_observation=new_observation,
    )


In [5]:
@dataclass(frozen=True)
class Experiment:
    """A trained ``Agent`` plus the TD error recorded at every training step."""

    agent: Agent  # the agent after training
    td_errors: np.ndarray  # TD error of each training step, in step order

    @classmethod
    def train(
        cls,
        critic_lr,
        num_tried_actions,
        num_steps=10_000,
        steps_per_target_update=100,
        random_action_probability=0.25,
        discount=0.99,
        seed=None,
    ):
        """Train a fresh agent on ``num_steps`` mock experiences.

        Args:
            critic_lr: Adam learning rate for the online critic.
            num_tried_actions: Candidate actions sampled per approximate argmax.
            num_steps: Number of training steps (one mock experience each).
            steps_per_target_update: The target critic is synced to the online
                critic every this many steps, starting at step 0; must be >= 1.
            random_action_probability: Exploration epsilon passed to the agent.
            discount: TD discount factor passed to the agent.
            seed: If given, seeds NumPy's global RNG and torch before building
                the networks, for reproducible runs.

        Returns:
            An ``Experiment`` holding the trained agent and per-step TD errors.

        Raises:
            ValueError: If ``steps_per_target_update`` < 1.
        """
        if steps_per_target_update < 1:
            raise ValueError(
                f"steps_per_target_update must be >= 1, got {steps_per_target_update}"
            )
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)

        critic = Critic()
        agent = Agent(
            critic=critic,
            critic_optimizer=torch.optim.Adam(critic.parameters(), lr=critic_lr),
            target_critic=Critic(),
            num_tried_actions=num_tried_actions,
            random_action_probability=random_action_probability,
            discount=discount,
        )
        td_errors = np.empty(num_steps)
        for step_idx in trange(num_steps, desc="Training"):
            # Step 0 also copies the online weights into the freshly built target.
            if step_idx % steps_per_target_update == 0:
                agent.update_target_critic_()
            td_errors[step_idx] = agent.train_(mock_experience(agent))
        return cls(agent=agent, td_errors=td_errors)

    def example_actions(self, num: int):
        """Sample ``num`` actions from the trained agent on random mock observations.

        Actions come from ``agent.act``, so roughly a ``random_action_probability``
        fraction of them are uniform exploration actions rather than greedy ones.
        """
        return [mock_experience(self.agent).action for _ in trange(num, desc="Sampling")]


# Train the agent; every other hyperparameter keeps Experiment.train's default.
experiment = Experiment.train(
    num_tried_actions=64,
    critic_lr=1e-3,
)

Training:   0%|          | 0/10000 [00:00<?, ?it/s]

In [6]:
# Squared TD error per training step, on a log y-axis.
go.Figure(
    layout=dict(
        yaxis_type="log",
        xaxis_title="Training step",
        yaxis_title="Squared TD error",
    ),
    data=[go.Scatter(y=np.square(experiment.td_errors))],
)

In [7]:
# Sample once and plot both marginals from the same actions. This halves the
# sampling work, and the ask and bid histograms then describe one sample.
sampled_actions = experiment.example_actions(1024)
go.Figure(
    layout=dict(xaxis_title="Price"),
    data=[
        go.Histogram(
            name="Ask",
            x=[action.ask for action in sampled_actions],
        ),
        go.Histogram(
            name="Bid",
            x=[action.bid for action in sampled_actions],
        ),
    ],
)

Sampling:   0%|          | 0/1024 [00:00<?, ?it/s]

Sampling:   0%|          | 0/1024 [00:00<?, ?it/s]