# DQN from scratch

In [1]:
from dataclasses import dataclass, replace
from functools import cache
from typing import List, Optional

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
from tqdm.notebook import tqdm, trange

## Architecture

In [2]:
class Critic(torch.nn.Module):
    """Toy critic whose value estimate is one learnable scalar.

    The critic takes no input, so it predicts the same value for every state;
    the only thing it can learn is that constant (stored as ``bias``).
    """

    def __init__(self):
        super().__init__()
        # 0-dim float32 tensor holding 0.0, wrapped as a trainable parameter.
        starting_value = torch.zeros((), dtype=torch.float)
        self.bias = torch.nn.Parameter(starting_value, requires_grad=True)

    def forward(self):
        # Hands back the Parameter object itself (not a copy); callers that use
        # it as a fixed regression target must .detach() it themselves.
        return self.bias


@dataclass
class Agent:
    """Mutable bundle of the online critic, its target copy, and the discount."""

    critic: Critic  # online network; the only one the optimizer updates
    target_critic: Critic  # copy whose (detached) value forms the TD target
    discount: float  # gamma, multiplies the target critic's value in the TD target

## Training

In [3]:
@dataclass(frozen=True)
class Experiment:
    """Result of one training run of a constant-value critic with a target network.

    Every step regresses the critic's bias toward the TD target
    ``reward + discount * target_bias``. The target critic is re-synced to the
    online critic every ``steps_per_target_update`` steps, starting at step 0.
    """

    agent: Agent  # agent after training (online + target critic)
    reward: float  # constant reward used in every TD target
    td_errors: List[float]  # TD error at each step, recorded before the optimizer update
    biases: List[float]  # critic bias after each optimizer step
    bias_grads: List[float]  # d(loss)/d(bias) from each step's backward pass

    @classmethod
    def train(
        cls,
        learning_rate: float,
        num_steps: int,
        steps_per_target_update: int,
        reward: float,
        discount: float = 0.99,
    ) -> "Experiment":
        """Train a fresh agent and record per-step diagnostics.

        Args:
            learning_rate: Adam learning rate for the online critic.
            num_steps: number of optimizer steps to run.
            steps_per_target_update: period (in steps) between target syncs; must be >= 1.
            reward: constant reward added to every TD target.
            discount: gamma applied to the target critic's value. The default
                0.99 reproduces the previously hard-coded value.

        Returns:
            An ``Experiment`` with the trained agent and the recorded traces.

        Raises:
            ValueError: if ``steps_per_target_update`` is not positive (0 would
                otherwise crash with ZeroDivisionError inside the loop).
        """
        if steps_per_target_update < 1:
            raise ValueError(
                "steps_per_target_update must be a positive integer, "
                f"got {steps_per_target_update}"
            )

        agent = Agent(
            critic=Critic(),
            target_critic=Critic(),
            discount=discount,
        )
        optimizer = torch.optim.Adam(agent.critic.parameters(), lr=learning_rate)
        td_errors: List[float] = []
        biases: List[float] = []
        bias_grads: List[float] = []
        for step_idx in trange(num_steps, desc="Training"):
            if step_idx % steps_per_target_update == 0:
                # load_state_dict copies the values into the target's own
                # parameters, so the two critics stay independent afterwards.
                agent.target_critic.load_state_dict(agent.critic.state_dict())

            # Critic.forward returns its Parameter object directly, so detach()
            # (not torch.no_grad) is what stops gradients reaching the target.
            bootstrap_value = agent.target_critic().detach()
            td_error = reward + agent.discount * bootstrap_value - agent.critic()
            td_errors.append(td_error.item())

            critic_loss = td_error ** 2
            optimizer.zero_grad()
            critic_loss.backward()
            optimizer.step()
            biases.append(agent.critic.bias.item())
            # .grad still holds this step's gradient until the next zero_grad().
            bias_grads.append(agent.critic.bias.grad.item())

        return cls(
            agent=agent,
            reward=reward,
            td_errors=td_errors,
            biases=biases,
            bias_grads=bias_grads,
        )


# All three runs share training length and target-sync period; they differ only
# in learning rate and reward.
shared_run_settings = dict(num_steps=50_000, steps_per_target_update=5_000)

base_experiment = Experiment.train(learning_rate=1e-4, reward=0, **shared_run_settings)
offset_experiment = Experiment.train(learning_rate=1e-4, reward=2, **shared_run_settings)
fast_offset_experiment = Experiment.train(learning_rate=1e-3, reward=2, **shared_run_settings)


Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Training:   0%|          | 0/50000 [00:00<?, ?it/s]

Training:   0%|          | 0/50000 [00:00<?, ?it/s]

In [5]:
def theoretical_bound(
    num_updates: int,
    reward: float = 2,
    discount: Optional[float] = None,
):
    """Idealised critic bias after ``num_updates`` target-network syncs.

    Computes ``b_0 = reward`` and ``b_k = discount * b_{k-1} + reward``, i.e.
    ``reward * sum_{i=0}^{k} discount**i``. This is the value an idealised
    critic reaches if it fits the TD target exactly within every target period.

    Args:
        num_updates: number of target syncs after the initial one (k >= 0).
        reward: constant reward; the default 2 matches ``offset_experiment``.
        discount: gamma; if omitted, read from ``offset_experiment.agent.discount``
            (only when actually needed, i.e. ``num_updates > 0``).

    Raises:
        ValueError: if ``num_updates`` is negative.
    """
    if num_updates < 0:
        raise ValueError(f"num_updates must be non-negative, got {num_updates}")
    if num_updates and discount is None:
        discount = offset_experiment.agent.discount
    # Iterative form of the recurrence: same float operations as the original
    # recursion, but no recursion-depth limit for large num_updates.
    bound = reward
    for _ in range(num_updates):
        bound = bound * discount + reward
    return bound


# Must match the `steps_per_target_update` used to train the experiments above.
TARGET_UPDATE_PERIOD = 5_000

num_plot_steps = len(offset_experiment.biases)
# theoretical_bound is constant within a target period, so evaluate it once per
# period rather than once per training step (50,000 calls).
bound_per_period = [
    theoretical_bound(period)
    for period in range(num_plot_steps // TARGET_UPDATE_PERIOD + 1)
]
bound_curve = [
    bound_per_period[step // TARGET_UPDATE_PERIOD] for step in range(num_plot_steps)
]

# Legends name the reward actually used; the TD fixed point is reward / (1 - discount),
# not the reward itself.
go.Figure(
    layout=dict(
        title="Critic bias during training",
        xaxis_title="Training step",
        yaxis_title="Bias",
    ),
    data=[
        go.Scatter(name=f"Reward = {base_experiment.reward}", y=base_experiment.biases),
        go.Scatter(name=f"Reward = {offset_experiment.reward}", y=offset_experiment.biases),
        go.Scatter(
            name=f"Reward = {fast_offset_experiment.reward}, higher lr",
            y=fast_offset_experiment.biases,
        ),
        go.Scatter(name="Theoretical bound", y=bound_curve),
    ],
)


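Learning the bias is slow because the target network is synced only every 5,000 steps, and each sync moves the regression target only part of the way toward the true value. With reward $r = 2$ and discount $\gamma = 0.99$ the fixed point is $r / (1 - \gamma) = 200$. An idealised critic fits the TD target exactly within every period; even so, it reaches only $b_k = r \sum_{i=0}^{k} \gamma^i$ in period $k$ (the "Theoretical bound" curve). That is about 19.1 by the end of the 10 periods plotted here.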