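# DQN from scratch

A DQN-style critic is trained for a toy market-making agent. An action is a continuous (ask, bid) pair. The greedy action is therefore approximated by sampling a batch of random candidate quotes, scoring each with the critic, and keeping the best one.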

In [1]:
from dataclasses import dataclass, replace
import numpy as np
from observation import Observation
import plotly.graph_objects as go
import torch
from tqdm.notebook import tqdm, trange

from tradezoo.game import SineWave
from tradezoo.market import Account

## Architecture

In [2]:
@dataclass(frozen=True)
class Observation:
    """Immutable snapshot of the trader's balances and the current best quotes."""

    cash_balance: float
    asset_balance: float
    best_ask: float
    best_bid: float

    @property
    def tensor(self):
        """Encode the observation as a (1, 4) float32 tensor of ``log(1 + feature)``.

        Features appear in field order: cash, assets, best ask, best bid.
        """
        features = [self.cash_balance, self.asset_balance, self.best_ask, self.best_bid]
        raw = torch.tensor([features], dtype=torch.float32)
        # log(1 + x) compresses large balances while mapping 0 to 0.
        return torch.log(1 + raw)


@dataclass(frozen=True)
class Action:
    """Immutable pair of quotes submitted by the agent."""

    ask: float
    bid: float

    @property
    def tensor(self):
        """Encode the action as a (1, 2) float32 tensor ``[[ask, bid]]``."""
        row = [self.ask, self.bid]
        return torch.tensor([row], dtype=torch.float32)


@dataclass(frozen=True)
class Experience:
    """One immutable transition (state, action, reward, next state) used to train the critic."""

    old_observation: Observation  # state in which ``action`` was chosen
    action: Action  # quotes the agent submitted
    reward: float  # scalar reward received for ``action``
    new_observation: Observation  # state observed after ``action``


In [3]:
class Critic(torch.nn.Module):
    """Action-value network: maps an (observation, action) pair to a scalar value.

    The input has 6 columns: the 4 observation features followed by the 2
    action features. They pass through one hidden layer of 64 LeakyReLU units
    to a single output.
    """

    def __init__(self):
        super().__init__()
        layers = [
            torch.nn.Linear(6, 64),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(64, 1),
        ]
        self.network = torch.nn.Sequential(*layers)

    def forward(self, tensor):
        """Apply the network to an input whose last dimension is 6."""
        return self.network(tensor)

    def evaluate(self, observation: Observation, action: Action) -> torch.Tensor:
        """Score one (observation, action) pair.

        Returns the network output with all size-1 dimensions squeezed away.
        """
        joint_input = torch.cat((observation.tensor, action.tensor), dim=1)
        return self(joint_input).squeeze()


@dataclass(frozen=False)
class Agent:
    """Epsilon-greedy DQN-style agent that quotes continuous (ask, bid) prices.

    The action space is continuous, so the greedy action is approximated:
    ``num_tried_actions`` random quotes are sampled, and the one the online
    critic scores highest is kept.
    """

    critic: Critic  # online Q-network, updated by ``train_``
    critic_optimizer: torch.optim.Optimizer  # optimizer stepping ``critic``'s parameters
    target_critic: Critic  # lagged copy used for bootstrapped targets
    num_tried_actions: int  # random candidates scored per greedy decision
    random_action_probability: float  # epsilon for epsilon-greedy exploration
    discount: float  # discount factor applied to the bootstrapped next-state value

    def act(self, observation: Observation) -> Action:
        """Return a random action with probability epsilon, otherwise the approximate best one."""
        if np.random.uniform() < self.random_action_probability:
            return self.random_action()
        return self.best_action(observation)

    def best_action(self, observation: Observation) -> Action:
        """Return the highest-scoring of ``num_tried_actions`` random candidate actions.

        Candidates are scored with the online ``critic``.

        Raises:
            ValueError: if ``num_tried_actions`` is less than 1, because there is
                nothing to choose from.
        """
        if self.num_tried_actions < 1:
            raise ValueError(
                f"num_tried_actions must be at least 1, got {self.num_tried_actions}"
            )
        candidates = [self.random_action() for _ in range(self.num_tried_actions)]
        # Scoring is pure inference: skip building an autograd graph per candidate.
        with torch.no_grad():
            return max(
                candidates,
                key=lambda action: self.critic.evaluate(observation, action).item(),
            )

    def random_action(self) -> Action:
        """Sample an action whose ask and bid are each uniform on [0, 2)."""
        random_price = lambda: np.random.uniform(0, 2)
        return Action(ask=random_price(), bid=random_price())

    def train_(self, experience: Experience) -> float:
        """Take one optimizer step on the squared TD error of ``experience``.

        The next action is selected with the online critic and valued with the
        target critic (the Double-DQN decoupling). The bootstrapped target
        carries no gradient.

        Returns:
            The TD error computed before the update, as a Python float.
        """
        # The target is a constant for this update, so compute it without autograd.
        with torch.no_grad():
            best_next_action = self.best_action(experience.new_observation)
            next_value = self.target_critic.evaluate(
                experience.new_observation, best_next_action
            )
        td_error = (
            experience.reward
            + self.discount * next_value
            - self.critic.evaluate(experience.old_observation, experience.action)
        )

        critic_loss = td_error ** 2
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        return td_error.item()

    def update_target_critic_(self):
        """Copy the online critic's weights into the target critic."""
        self.target_critic.load_state_dict(self.critic.state_dict())

    def optimal(self, num_tried_actions=64):
        """Return a greedy copy of this agent (no exploration).

        ``num_tried_actions`` sets the number of candidates per decision. The
        copy is shallow: it shares the critic, optimizer and target critic
        objects with ``self``.
        """
        return replace(self, num_tried_actions=num_tried_actions, random_action_probability=0)

## Training

In [4]:
# Reference account. mock_experience samples starting balances uniformly between 0
# and twice these values.
trader_account = Account(cash_balance=64, asset_balance=64)
# Mid price: exp(0.25 * SineWave(period=16)).
# NOTE(review): assumes SineWave oscillates in [-1, 1] and that `*` / `.exp()` act
# pointwise on the process; confirm in tradezoo.game.
price_process = (SineWave(period=16) * 0.25).exp()
# Best ask / best bid: the mid price scaled up / down by a constant factor of 1.1
# (presumably pointwise, as above).
ask_process = price_process * 1.1
bid_process = price_process / 1.1

In [5]:
def mock_experience(agent: Agent) -> Experience:
    """Simulate one transition for ``agent`` in a toy single-step market.

    How the transition is built:

    * A random time step in [0, 1024) and random balances form the starting
      observation. Each balance is drawn uniformly between 0 and twice the
      corresponding ``trader_account`` balance.
    * If the agent's ask is at or below the market's best bid, up to one asset
      is sold at the agent's ask.
    * If the agent's bid is at or above the market's best ask, up to one asset
      is bought at the agent's bid, limited by available cash.

    The next observation uses the market quotes at the following time step.
    The reward ignores balances: it is minus the squared distance of the quotes
    from (ask=1.5, bid=0.5).
    """
    t = np.random.randint(0, 1024)
    start = Observation(
        cash_balance=np.random.uniform(0, trader_account.cash_balance * 2),
        asset_balance=np.random.uniform(0, trader_account.asset_balance * 2),
        best_ask=ask_process.value(t),
        best_bid=bid_process.value(t),
    )
    quote = agent.act(start)
    cash, assets = start.cash_balance, start.asset_balance

    # Our ask crosses the market bid: sell up to one unit at our own ask price.
    if quote.ask <= start.best_bid:
        units = min(assets, 1)
        cash += quote.ask * units
        assets -= units

    # Our bid crosses the market ask: buy up to one unit (cash permitting) at our bid.
    if quote.bid >= start.best_ask:
        units = min(1, cash / quote.bid)
        cash -= quote.bid * units
        assets += units

    end = Observation(
        cash_balance=cash,
        asset_balance=assets,
        best_ask=ask_process.value(t + 1),
        best_bid=bid_process.value(t + 1),
    )
    penalty = (quote.ask - 1.5) ** 2 + (quote.bid - 0.5) ** 2
    return Experience(
        old_observation=start,
        action=quote,
        reward=-penalty,
        new_observation=end,
    )

In [6]:
@dataclass(frozen=True)
class Experiment:
    """A trained agent together with the TD errors recorded during its training run."""

    agent: Agent
    td_errors: np.ndarray  # one TD error (from ``Agent.train_``) per training step

    @classmethod
    def train(
        cls,
        critic_lr,
        num_tried_actions,
        num_steps=10_000,
        steps_per_target_update=100,
        discount=0.99,
        initial_random_action_probability=1,
        random_action_probability_decay=0.999,
        seed=None,
    ):
        """Train a fresh agent on ``num_steps`` single mock experiences.

        Args:
            critic_lr: Adam learning rate for the online critic.
            num_tried_actions: random candidates scored per greedy decision.
            num_steps: number of training steps (one experience each).
            steps_per_target_update: copy the online critic into the target
                critic every this many steps. This includes step 0, so both
                critics start identical.
            discount: discount factor passed to the agent.
            initial_random_action_probability: starting exploration rate.
            random_action_probability_decay: factor the exploration rate is
                multiplied by after every step.
            seed: if not ``None``, seed NumPy's global RNG and torch's RNG
                before the networks are built, so runs can be repeated.
                ``None`` leaves RNG state untouched.

        Returns:
            An ``Experiment`` holding the trained agent and a NumPy array of TD errors.
        """
        if seed is not None:
            # Network init uses torch's RNG; action and experience sampling use
            # NumPy's global RNG.
            np.random.seed(seed)
            torch.manual_seed(seed)
        critic = Critic()
        agent = Agent(
            critic=critic,
            critic_optimizer=torch.optim.Adam(critic.parameters(), lr=critic_lr),
            target_critic=Critic(),
            num_tried_actions=num_tried_actions,
            random_action_probability=initial_random_action_probability,
            discount=discount,
        )
        td_errors = []
        for step_idx in trange(num_steps, desc="Training"):
            if step_idx % steps_per_target_update == 0:
                agent.update_target_critic_()
            experience = mock_experience(agent)
            td_errors.append(agent.train_(experience))
            agent.random_action_probability *= random_action_probability_decay
        return cls(agent=agent, td_errors=np.array(td_errors))


# Train with the default schedule spelled out: 10k steps, target sync every 100.
experiment = Experiment.train(
    critic_lr=1e-3,
    num_tried_actions=64,
    num_steps=10_000,
    steps_per_target_update=100,
)

Training:   0%|          | 0/10000 [00:00<?, ?it/s]

In [7]:
# Squared TD error per training step, on a log scale because it spans orders of magnitude.
go.Figure(
    layout=dict(
        title="Squared TD error during training",
        xaxis_title="Training step",
        yaxis_title="TD error squared",
        yaxis_type="log",
    ),
    # Vectorized square of the recorded TD errors (one point per training step).
    data=[go.Scatter(name="TD error squared", y=experiment.td_errors ** 2)],
)

In [8]:
def plot_action_value(agent: Agent, observation=None, resolution=50, max_price=2.0):
    """Plot the online critic's value surface over a grid of (ask, bid) quotes.

    Args:
        agent: agent whose ``critic`` is evaluated.
        observation: state in which every action is scored. Defaults to an
            ``Observation`` with every field equal to 1.
        resolution: number of grid points along each price axis.
        max_price: upper end of both price axes; the lower end is 0.

    Returns:
        A plotly ``Figure`` containing one ``Surface`` trace.
    """
    if observation is None:
        observation = Observation(cash_balance=1, asset_balance=1, best_ask=1, best_bid=1)
    asks = np.linspace(0, max_price, resolution)
    bids = np.linspace(0, max_price, resolution)
    # Rows of `values` follow `bids` and columns follow `asks`, matching y=bids and
    # x=asks below. Evaluation is inference only, so no autograd graph is needed.
    with torch.no_grad():
        values = [
            [
                agent.critic.evaluate(observation, action=Action(ask=ask, bid=bid)).item()
                for ask in asks
            ]
            for bid in bids
        ]
    return go.Figure(
        layout=dict(
            title="Critic action value by quote",
            scene=dict(
                xaxis_title="Ask",
                yaxis_title="Bid",
                zaxis_title="Action value",
            ),
        ),
        data=[
            go.Surface(
                x=asks,
                y=bids,
                z=values,
            )
        ],
    )


# Value surface of the trained critic at the default reference observation.
plot_action_value(agent=experiment.agent)

In [9]:
def plot_action_distribution(agent: Agent, precision=1024):
    """Histogram the asks and bids that ``agent`` quotes on freshly sampled observations.

    Args:
        agent: agent whose policy is sampled (each sample comes from one
            ``mock_experience`` call).
        precision: number of actions to sample.

    Returns:
        A plotly ``Figure`` with one histogram for asks and one for bids.
    """
    actions = [mock_experience(agent).action for _ in trange(precision, desc="Sampling")]
    return go.Figure(
        layout=dict(
            title="Distribution of quoted prices",
            xaxis_title="Price",
            yaxis_title="Count",
        ),
        data=[
            go.Histogram(
                name="Ask",
                x=[action.ask for action in actions],
            ),
            go.Histogram(
                name="Bid",
                x=[action.bid for action in actions],
            ),
        ],
    )


# Greedy policy (no exploration, 256 candidates per decision) sampled 256 times.
plot_action_distribution(
    agent=experiment.agent.optimal(num_tried_actions=256),
    precision=256,
)

Sampling:   0%|          | 0/256 [00:00<?, ?it/s]