### Offline Reinforcement Learning

In offline RL, we are given a dataset of transitions

$$
\mathcal{D} = \{ (s_n, a_n, r_n, t_n, s^\prime_{n}) \}_n
$$

consisting of states, actions, rewards, termination signals and next states. The dataset was collected by a potentially unknown behaviour policy $\mu$. Our goal is to learn a policy $\pi$ using $\mathcal{D}$ without generating new data.

Off-policy algorithms such as DQN are able to learn from data that was produced by a different policy. In this exercise, we will first investigate whether double DQN (DDQN) is able to learn from offline data.

Subsequently, we will implement a generalized version called batch-constrained Q-learning (BCQ) that was constructed for the task of offline RL.

### Implementation

Make sure that the files `rl_gui.py` and `rl_tests.py` are in the same folder as the notebook.

In [1]:
import collections
import copy
import random

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import os

from torch import nn, optim

import rl_gui
import rl_tests

In [2]:
class ReplayMemory:
    """Bounded FIFO buffer of transitions with uniform random batch sampling.

    A transition is stored as a tuple. ``sample`` regroups the sampled tuples
    field by field and returns one tensor per field.
    """

    def __init__(self, capacity, rng):
        # create a queue that removes old transitions when capacity is reached
        self.transitions = collections.deque([], maxlen=capacity)

        # random number generator used for sampling batches
        self.rng = rng

    def append(self, transition):
        """Append one transition (a tuple) to the queue."""
        self.transitions.append(transition)

    def sample(self, batch_size):
        """Return a tuple of tensors, one per transition field.

        Indices are drawn without replacement, so ``batch_size`` must not
        exceed the number of stored transitions (a NumPy generator raises
        ``ValueError`` otherwise).
        """
        # randomly sample a list of indices
        idx = self.rng.choice(len(self.transitions), batch_size, replace=False)

        # select the transitions using the indices
        transitions = [self.transitions[i] for i in idx]

        # zip(*...) turns a list of tuples into one sequence per field
        batches = tuple(torch.as_tensor(np.array(batch)) for batch in zip(*transitions))
        return batches

    def save_transitions(self, suffix=""):
        """Serialize the whole queue to the file ``transitions_<suffix>``."""
        torch.save(self.transitions, f"transitions_{suffix}")

    def load_transitions(self, path_name):
        """Replace the stored queue with the object saved at ``path_name``.

        The loaded object replaces ``self.transitions`` wholesale, so the
        capacity passed to the constructor does not apply to loaded data.
        """
        # The file contains a pickled deque of arbitrary Python objects, not
        # just tensors. Recent PyTorch versions default to weights_only=True,
        # which refuses to unpickle such objects, so full unpickling has to be
        # requested explicitly.
        # NOTE: full unpickling can execute arbitrary code - only load
        # transition files from a trusted source.
        self.transitions = torch.load(path_name, weights_only=False)

We already implemented the double DQN algorithm that we saw in Exercise 9.

In [3]:
class DDQN(nn.Module):
    """Q-learning agent whose target uses two networks.

    The next action is chosen greedily by this (online) network and its value
    is read from a separate target network, see ``compute_loss``.
    """

    def __init__(self, state_dim, num_actions, learning_rate, gamma):
        super().__init__()
        hidden_units = 128

        # one hidden layer with ReLU, one output per action
        layers = [
            nn.Linear(state_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_actions),
        ]
        self.network = nn.Sequential(*layers)

        self.gamma = gamma
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def _q_all(self, states):
        """Return q[s, .] with shape (batch_size, num_actions)."""
        return self.network(states)

    def compute_q(self, states, actions):
        """Return q[s, a] for paired states (batch_size, state_dim) and actions (batch_size)."""
        action_index = actions.long().unsqueeze(1)
        picked = self._q_all(states).gather(1, action_index)
        return picked.squeeze(1)

    def compute_max_q(self, states):
        """Return max_a q[s, a] with shape (batch_size)."""
        return self._q_all(states).max(dim=1).values

    def compute_arg_max(self, states):
        """Return argmax_a q[s, a] with shape (batch_size)."""
        return torch.argmax(self._q_all(states), dim=1)

    def compute_loss(self, target_dqn, batches):
        """Mean squared error between q[s, a] and the bootstrapped target."""
        states, actions, rewards, terminations, next_states = batches

        # the target is treated as a constant, so no gradients are tracked here
        with torch.no_grad():
            greedy_next = self.compute_arg_max(next_states)
            bootstrap = target_dqn.compute_q(next_states, greedy_next)
            # no bootstrapping from next states of terminated transitions
            not_terminated = (terminations != 1).float()
            targets = rewards + self.gamma * not_terminated * bootstrap

        predictions = self.compute_q(states, actions)
        td_error = predictions - targets.detach()
        return td_error.pow(2).mean()

    def update(self, memory, batch_size, target_dqn):
        """Sample a batch from ``memory``, take one optimizer step and return the loss."""
        batches = memory.sample(batch_size)

        self.train()  # switch to training mode
        loss = self.compute_loss(target_dqn, batches)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

In [4]:
# helper function for generating datasets
def generate_data(dqn, env, epsilon, num_transitions, rng):
    """Collect ``num_transitions`` transitions with an epsilon-greedy policy.

    :param dqn: agent providing ``compute_arg_max`` for the greedy action
    :param env: environment following the Gymnasium ``reset``/``step`` API
    :param epsilon: probability of taking a uniformly random action
    :param num_transitions: number of environment steps to record
    :param rng: NumPy random generator (``np.random.Generator`` or the legacy
        ``RandomState``); also handed to the replay memory for sampling
    :return: a ReplayMemory holding tuples
        (state, action, reward, terminated, next_state)
    """
    replay_memory = ReplayMemory(num_transitions, rng=rng)
    num_actions = env.action_space.n

    # np.random.Generator calls its integer sampler `integers`; only the legacy
    # RandomState has `randint`. Support both so exploration works with either.
    draw_action = rng.integers if hasattr(rng, "integers") else rng.randint

    state, _ = env.reset()

    for _ in range(num_transitions):

        # epsilon-greedy policy
        if rng.random() < epsilon:
            # plain int, like the greedy branch below
            action = int(draw_action(num_actions))
        else:
            action = dqn.compute_arg_max(torch.as_tensor(state).unsqueeze(0)).item()

        next_state, reward, terminated, truncated, info = env.step(action)

        # store transition in replay memory; only `terminated` is recorded, so
        # a time-limit truncation still bootstraps from the next state
        replay_memory.append((state, action, reward, terminated, next_state))

        if terminated or truncated:
            state, _ = env.reset()
        else:
            state = next_state

    return replay_memory

### Implement the `train_offline` and `evaluate_agent` functions.

The `train_offline` method performs offline RL using a replay memory.

The `evaluate_agent` method runs $n$ episodes. In each episode $i$ it should calculate the sum of rewards $g_{0, i}$ and the approximated value of the starting state $V(s_0)_i$ (given by the maximum of the Q-function). It then returns the average over all episodes of these two quantities, i.e.

$$
1/n \sum_{i=1}^{n} g_{0, i}, \quad 1/n \sum_{i=1}^{n} V(s_0)_i
$$

In [14]:
def train_offline(dqn, target_dqn, target_interval, memory, num_updates, batch_size):
    """Run ``num_updates`` gradient updates of ``dqn`` on a fixed replay memory.

    Each update calls ``dqn.update(memory, batch_size, target_dqn)``. The
    target network is replaced by a deep copy of ``dqn`` after every update
    whose zero-based index is a positive multiple of ``target_interval``.

    :return: the target network in use after the last update (the caller's
        object itself if no synchronization took place)
    """
    step = 0
    while step < num_updates:
        dqn.update(memory, batch_size, target_dqn)

        sync_due = step % target_interval == 0 and step > 0
        if sync_due:
            # rebinding only changes the local name; callers must use the return value
            target_dqn = copy.deepcopy(dqn)

        step += 1

    return target_dqn


from itertools import count


def evaluate_agent(env, agent, num_episodes):
    '''
    Roll out the greedy policy of the agent and summarize its performance.

    An episode ends as soon as the environment reports termination or
    truncation (e.g. a time limit), so a policy that never terminates cannot
    stall the evaluation.

    :param env: environment for interaction (Gymnasium reset/step API)
    :param agent: RL agent providing compute_max_q and compute_arg_max
    :param num_episodes: Run num_episodes many episodes (must be positive)
    :return: the average return across episodes and the average value estimates of the starting states
    '''
    total_return = 0.0
    total_value = 0.0

    # pure inference: no gradients are needed for action selection
    with torch.no_grad():
        for _ in range(num_episodes):
            state, _ = env.reset()
            state = torch.as_tensor(state).unsqueeze(0)

            # V(s_0) is approximated by the maximum Q-value of the start state
            total_value += agent.compute_max_q(state).item()

            while True:
                action = agent.compute_arg_max(state).item()
                next_state, reward, terminated, truncated, _ = env.step(action)
                total_return += reward
                if terminated or truncated:
                    break
                state = torch.as_tensor(next_state).unsqueeze(0)

    avg_return = total_return / num_episodes
    avg_value = total_value / num_episodes

    return avg_return, avg_value

In [15]:
def create_env(seed):
    """Create the LunarLander-v3 environment and seed it through an initial reset."""
    # alternative environment id: "ALE/MsPacman-v5"
    env = gym.make('LunarLander-v3', render_mode='rgb_array')
    env.reset(seed=seed)
    return env

In [16]:
def test_train():
    """Generator-based test of ``train_offline`` against hard-coded Q-values.

    Yields the test name first, then for each case a ``(passed, message)``
    pair followed by ``None``. The expected numbers depend on the exact order
    of the random draws below, so the statement order must not be changed.
    """
    # fixed seeds for NumPy and torch so that weights, data and batches are reproducible
    rng = np.random.Generator(np.random.PCG64(seed=42))
    torch.manual_seed(42)
    env = create_env(42)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    agent = DDQN(env.observation_space.shape[0], env.action_space.n, learning_rate=0.1, gamma=0.99)
    target_agent = copy.deepcopy(agent)
    # epsilon = 0.0: the 12 transitions are collected with the purely greedy policy
    memory = generate_data(agent, env, 0.0, 12, rng)

    # random probe inputs; they share `rng` with the replay memory
    sample_states = lambda batch_size: torch.as_tensor(rng.standard_normal((batch_size, state_dim), dtype=np.float32))
    sample_actions = lambda batch_size: torch.as_tensor(rng.choice(action_dim, batch_size))

    yield "train_offline()"

    for expected_q_values in [
        [0.14601666, 0.20298843, 0.07980790, 0.20461105, 0.47405303],
        [0.32259068, -2.07440472, 0.01698574, -14.59191322, -2.82291675]
    ]:
        # probes are drawn before training, which itself consumes draws from `rng`
        states = sample_states(5)
        actions = sample_actions(5)
        # 5 updates with batch size 5 and target interval 2; the returned
        # target network is carried over into the next case
        target_agent = train_offline(agent, target_agent, 2, memory, 5, 5)
        q_values = agent.compute_q(states, actions).detach()

        yield torch.allclose(q_values, torch.as_tensor(
            expected_q_values)), f'Q-values are incorrect (error = {(torch.abs(q_values - torch.as_tensor(expected_q_values))).sum().item()})'
        yield None


def test_evaluate():
    """Generator-based test of ``evaluate_agent`` with an untrained agent.

    Yields the test name first, then for each case two ``(passed, message)``
    pairs (average return, average value) followed by ``None``.
    """
    torch.manual_seed(42)
    env = create_env(42)
    agent = DDQN(env.observation_space.shape[0], env.action_space.n, learning_rate=0.1, gamma=0.99)

    yield "evaluate_agent()"

    # (number of episodes, expected average return, expected average start value)
    cases = (
        (1, -591.34799665, 0.13210351),
        (3, -2014.67668242, 0.21220374),
        (5, -1145.47214997, 0.16889345),
    )

    for num, ret, val in cases:
        avg_return, avg_value = evaluate_agent(env, agent, num)

        return_ok = torch.allclose(torch.Tensor([ret]), torch.Tensor([avg_return]))
        yield return_ok, f'Average return is incorrect, (error = {abs(ret - avg_return)})'

        value_ok = torch.allclose(torch.Tensor([val]), torch.Tensor([avg_value]))
        yield value_ok, f'Average value is incorrect, (error = {abs(val - avg_value)})'

        yield None


# hand both test generators to the course-provided test runner
rl_tests.run_tests(test_train())
rl_tests.run_tests(test_evaluate())

Testing train_offline()...
1/2 tests passed!
Test #2 failed: Q-values are incorrect (error = 4.1093677282333374e-05)
Testing evaluate_agent()...
3/3 tests passed!


### We already collected data using the following expert agent.

In [None]:
# Visualize one rollout of the pre-trained expert agent.
env = create_env(seed=1)
render = rl_gui.create_renderer(env, fps=60, figsize=(4, 3))

# the learning rate is irrelevant here: the weights are loaded and never updated
dqn = DDQN(env.observation_space.shape[0], env.action_space.n, learning_rate=0.1, gamma=0.99)
dqn.load_state_dict(torch.load('data_and_models/expert_dqn.pt'))

state, _ = env.reset()
render()
reward_sum = 0.0
# act greedily for at most 300 steps, stopping early when the episode ends
for _ in range(300):
    action = dqn.compute_arg_max(torch.as_tensor(state).unsqueeze(0)).item()
    state, reward, terminated, truncated, _ = env.step(action)
    reward_sum += reward
    render(f'sum of rewards: {reward_sum:.2f}')
    if terminated or truncated:
        break

### We collected 3 datasets with 100, 1000, and 10000 transitions.

Let us run offline RL with the DDQN agent using the implemented `train_offline` method. We evaluate the trained agent every 10000 updates using the `evaluate_agent` method.

In [None]:
# Shared setup for the offline experiments below: seeding, environment, hyperparameters.
seed = 1
rng = np.random.default_rng(seed)
torch.manual_seed(seed)

env = create_env(seed)
state_dim = env.observation_space.shape[0]
num_actions = env.action_space.n

batch_size = 64  # number of transitions in a batch
# load_transitions() replaces the memory's queue with the loaded one, so this value has no effect
replay_capacity = int(1e3)  # number of transitions that are stored in memory, does not matter here
gamma = 0.99  # discount factor
learning_rate = 0.0001  # learning rate for DDQN
target_interval = 100  # synchronize the target network after this number of steps
num_offline_updates = 10000  # number of offline updates to perform

In [None]:
# Train one fresh DDQN agent per expert dataset and record its evaluation curve.
data_paths = ["transitions_expert_0.1k", "transitions_expert_1k", "transitions_expert_10k"]
data_paths = [os.path.join("data_and_models", path) for path in data_paths]

# one list of 20 evaluation results per dataset
average_returns_all = []
q_values_all = []

for path in data_paths:
    memory = ReplayMemory(replay_capacity, rng)
    memory.load_transitions(path)
    print(f"Number of transitions in {path} dataset: {len(memory.transitions)}")

    ddqn = DDQN(state_dim, num_actions, learning_rate, gamma)
    average_returns = []
    q_values = []

    print("Start training...")

    for i in range(20):
        # Train for num_offline_updates many steps
        # the target network is re-created from the online network at the
        # start of every round; the one returned by train_offline is discarded
        target_ddqn = copy.deepcopy(ddqn)
        train_offline(ddqn, target_ddqn, target_interval, memory,
                      num_offline_updates, batch_size)

        # Evaluate the agent
        performance, qs = evaluate_agent(env, ddqn, 20)
        average_returns.append(performance)
        q_values.append(qs)
        print(f"Iteration: {i + 1}/20, Average return: {performance}, Value-Estimates: {qs}")

    average_returns_all.append(average_returns)
    q_values_all.append(q_values)

Let's plot the results for better visualization.

In [None]:
# Plot average return (left) and start-state value estimates (right) per dataset.
labels = ["0.1k Expert data", "1k Expert data", "10k Expert data"]

x = list(range(1, 21))
fig, axes = plt.subplots(1, 2, sharex=True, sharey=False, figsize=(10, 4))

# (axis, y-label, y-limits, one curve per dataset)
panels = (
    (axes[0], 'Average Return', [-700, 0], average_returns_all),
    (axes[1], 'Value-Estimates', [0, 5000], q_values_all),
)

for ax, y_label, y_limits, curves in panels:
    ax.set_ylabel(y_label)
    ax.set_ylim(y_limits)
    ax.set_xlabel('Num updates / 10k')
    ax.set_xticks(x)
    for idx, curve in enumerate(curves):
        ax.plot(x, curve, label=labels[idx])
    ax.legend()

fig.tight_layout()

### What can you observe?

### BCQ

Batch-constrained Q-learning tries to overcome the overestimations that result from the target

$$
r_t + \gamma \cdot \max_{a^\prime} Q(s_{t+1}, a^\prime)
$$
if actions $a^\prime$ are out of distribution, i.e. $(s_{t+1}, a^\prime)$ is not in the dataset, so the corresponding Q-value is never updated.

The idea is to only allow actions that were likely used by the behaviour policy $\mu$ that created the dataset. Since these actions are likely to be in the dataset, the overestimation is reduced since we obtain more reliable estimates of the targets.

In BCQ, in addition to the Q-network, we learn an imitator network $G_\omega$ that tries to imitate the behaviour policy $\mu$. The imitator network is trained to maximize likelihood of the actions in the dataset by minimizing the loss

$$
\mathcal{L}_{imitator} = -\mathbb{E}_{(s, a) \sim \mathcal{D}} \log G_\omega(s, a).
$$

The BCQ agent selects actions, given a state $s$, according to

$$
\pi(s) = \operatorname*{argmax}_{a \,:\, G_\omega(s,a) / \max_{a^\prime} G_\omega(s,a^\prime) \geq \tau} \text{ } Q_\theta (s, a),
$$
where $\tau$ is a threshold value. We thus only consider actions that are "reliable" according to the imitator network. The Q-network parameterized by $\theta$ aims to minimize

$$
\left( r + \gamma Q_{\theta^\prime}(s^\prime, a^\prime) - Q_\theta(s, a) \right)^2, \quad a^\prime = \operatorname*{argmax}_{\tilde{a} \,:\, G_\omega(s^\prime,\tilde{a}) / \max_{\hat{a}} G_\omega(s^\prime,\hat{a}) \geq \tau} \text{ } Q_\theta (s^\prime, \tilde{a}),
$$
where $\theta^\prime$ are the parameters of the target network.

In [None]:
class BCQ(nn.Module):
    """Discrete batch-constrained Q-learning (BCQ) agent.

    ``network`` estimates Q-values and ``imitator`` produces the logits of a
    behaviour-cloning policy G. Greedy action selection and value estimates
    only consider actions a with G(s, a) / max_a' G(s, a') >= threshold.
    With the default threshold of 0.0 every action is admissible.
    """

    def __init__(self, state_dim, num_actions, learning_rate, gamma, threshold=0.0):
        super().__init__()

        # Q-network: one hidden layer with ReLU, one output per action
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions)
        )

        # imitator: same shape, its outputs are logits over actions
        self.imitator = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions)
        )

        # a single optimizer trains both networks
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.threshold = threshold

    def compute_log_probs(self, states):
        """Return log G(s, .) with shape (batch_size, num_actions).

        ``states`` has shape (batch_size, state_dim).
        """
        logits = self.imitator(states)
        # log-softmax: subtract the log of the normalization constant
        return logits - torch.logsumexp(logits, dim=1, keepdim=True)

    def imitator_loss(self, batches):
        """Return the mean negative log-likelihood of the dataset actions.

        Only the states and actions of the batch are used.
        """
        states, actions, _, _, _ = batches

        log_probs = self.compute_log_probs(states)
        # log G(s, a) for the action actually taken, shape (batch_size)
        taken = torch.gather(log_probs, dim=1, index=actions.long().unsqueeze(1)).squeeze(1)
        loss = -taken.mean()
        return loss

    def compute_q(self, states, actions):
        """Return q[s, a] for paired states and actions, shape (batch_size)."""
        q_all = self.network(states)
        q = torch.gather(q_all, dim=1, index=actions.long().unsqueeze(1)).squeeze(1)
        return q

    def compute_loss(self, target_dqn, batches):
        """Mean squared error between q[s, a] and the bootstrapped target.

        The next action is the constrained greedy action of this network and
        is evaluated by ``target_dqn``.
        """
        states, actions, rewards, terminations, next_states = batches

        # turn off gradient computation
        with torch.no_grad():
            arg_max = self.compute_arg_max(next_states)
            targets = target_dqn.compute_q(next_states, arg_max)
            targets = rewards + self.gamma * (terminations != 1).float() * targets

        # compute predictions q[s,a]
        q = self.compute_q(states, actions)

        # compute mean squared error between q[s,a] and targets
        loss = torch.mean((q - targets.detach()) ** 2)
        return loss

    def _allowed_actions(self, states):
        """Return a boolean mask (batch_size, num_actions) of admissible actions."""
        # the mask is a hard constraint, no gradient flows through it
        with torch.no_grad():
            log_probs = self.compute_log_probs(states)
            best_log_prob = log_probs.max(dim=1, keepdim=True)[0]
            # G(s,a) / max_a' G(s,a'), evaluated in log space so that very
            # small probabilities are never divided by each other
            ratio = torch.exp(log_probs - best_log_prob)
            allowed = ratio >= self.threshold
            # the imitator's most likely action always stays admissible, so
            # a threshold above 1 cannot empty the action set
            allowed = allowed | (log_probs == best_log_prob)
        return allowed

    def _constrained_q(self, states):
        """Return q[s, .] with inadmissible actions set to -inf."""
        q_all = self.network(states)
        return q_all.masked_fill(~self._allowed_actions(states), float('-inf'))

    def compute_max_q(self, states):
        """Return the maximum Q-value over admissible actions, shape (batch_size)."""
        max_q = self._constrained_q(states).max(dim=1)[0]
        return max_q

    def compute_arg_max(self, states):
        """Return the admissible action with the largest Q-value, shape (batch_size)."""
        actions = self._constrained_q(states).argmax(dim=1)
        return actions

    def update(self, memory, batch_size, target_dqn):
        """Sample a batch, take one optimizer step on Q-loss + imitator loss, return the loss."""
        batches = memory.sample(batch_size)
        # minimize the loss function using SGD
        self.train()  # switch to training mode
        # the two networks share no parameters, so the summed loss trains them independently
        loss = self.compute_loss(target_dqn, batches) + self.imitator_loss(batches)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss


In [None]:
def test_bcq():
    """Generator-based test of the BCQ methods against hard-coded values.

    Yields a section name before each group of cases, then per case a
    ``(passed, message)`` pair followed by ``None``. The expected numbers
    depend on the order of the random draws, so the order must be kept.
    """
    torch.manual_seed(42)
    rng = np.random.Generator(np.random.PCG64(seed=42))
    state_dim = 5
    num_actions = 3
    gamma = 0.8
    learning_rate = 0.1
    bcq = BCQ(state_dim, num_actions, learning_rate, gamma, threshold=0.3)

    # overwrite the random initialization with reproducible NumPy-drawn weights
    with torch.no_grad():
        for layer in (bcq.network[0], bcq.network[2], bcq.imitator[0], bcq.imitator[2]):
            mean = rng.uniform(-0.5, 0.5)
            layer.weight[:] = torch.as_tensor(rng.normal(mean, 0.1, layer.weight.shape))
            nn.init.zeros_(layer.bias)

    sample_states = lambda batch_size: torch.as_tensor(rng.standard_normal((batch_size, state_dim), dtype=np.float32))
    sample_actions = lambda batch_size: torch.as_tensor(rng.choice(num_actions, batch_size))

    yield 'compute_log_probs()'

    for expected_log_softmax in [
        [[-1.29940438, -0.79732537, -1.28455567],
         [-1.01962435, -1.08546841, -1.19889891]],

        [[-1.08598757, -1.12744570, -1.08301783],
         [-1.10690618, -1.07192361, -1.11757934]]
    ]:
        batch_size = len(expected_log_softmax)
        states = sample_states(batch_size)
        log_softmax = bcq.compute_log_probs(states)
        yield torch.allclose(log_softmax, torch.as_tensor(
            expected_log_softmax)), f'log_softmaxs are incorrect (error = {torch.sum(torch.abs(log_softmax - torch.as_tensor(expected_log_softmax))).item()})'
        yield None

    yield 'imitator_loss()'

    for expected_loss in [
        1.097770094871521,
        1.147094488143921,
        1.0680086612701416
    ]:
        states = sample_states(5)
        actions = sample_actions(5)
        # imitator_loss unpacks five fields but ignores the last three; use
        # explicit placeholders instead of relying on a global named `_`
        batch = (states, actions, None, None, None)
        loss = bcq.imitator_loss(batch)
        yield torch.allclose(loss, torch.Tensor([expected_loss])), f'Imitator loss is incorrect (error = {torch.sum(torch.abs(loss - expected_loss)).item()})'
        yield None

    yield 'compute_max_q()'

    for expected_value in [
        [4.21411991, 0.00000000, 28.97923660, 38.94561005, 33.75335693],
        [35.10310364, 15.52004814, 1.47012889, 37.41002274, 14.30604076]
    ]:
        states = sample_states(5)
        value = bcq.compute_max_q(states)

        yield torch.allclose(value, torch.as_tensor(
            expected_value)), f'Values are incorrect (error = {torch.sum(torch.abs(value - torch.as_tensor(expected_value))).item()})'
        yield None

    yield 'compute_argmax_q'

    for expected_actions in [
        [2, 0, 2, 2, 2],
        [0, 2, 2, 1, 2]
    ]:
        states = sample_states(5)
        actions = bcq.compute_arg_max(states)
        yield torch.all(actions == torch.as_tensor(expected_actions)).item(), 'actions are incorrect'
        yield None


# hand the BCQ test generator to the course-provided test runner
rl_tests.run_tests(test_bcq())

We run BCQ agents with a threshold of $0.3$ on the datasets with 100 and 1000 transitions.

In [None]:
# Train one fresh BCQ agent (threshold 0.3) per expert dataset and record its evaluation curve.
data_paths = ["transitions_expert_0.1k", "transitions_expert_1k"]
data_paths = [os.path.join("data_and_models", path) for path in data_paths]

# one list of 10 evaluation results per dataset
average_returns_all_bcq = []
q_values_all_bcq = []

for path in data_paths:
    memory = ReplayMemory(replay_capacity, rng)
    memory.load_transitions(path)
    print(f"Number of transitions in {path}: {len(memory.transitions)}")

    bcq = BCQ(state_dim, num_actions, learning_rate, gamma, threshold=0.3)
    average_returns = []
    q_values = []

    print("Start training...")

    for i in range(10):
        # Train for num_offline_updates many steps
        # the target network is re-created from the online network at the
        # start of every round; the one returned by train_offline is discarded
        target_bcq = copy.deepcopy(bcq)
        train_offline(bcq, target_bcq, target_interval, memory,
                      num_offline_updates, batch_size)

        # Evaluate the agent
        performance, qs = evaluate_agent(env, bcq, 20)
        average_returns.append(performance)
        q_values.append(qs)
        print(f"Iteration: {i + 1}/10, Average return: {performance}, Value-Estimates: {qs}")

    average_returns_all_bcq.append(average_returns)
    q_values_all_bcq.append(q_values)

In [None]:
# Plot average return (left) and start-state value estimates (right) of the BCQ runs.
labels = ["0.1k Expert data", "1k Expert data"]

x = list(range(1, 11))
fig, axes = plt.subplots(1, 2, sharex=True, sharey=False, figsize=(10, 4))

# (axis, y-label, one curve per dataset); both panels use the same y-range
panels = (
    (axes[0], 'Average return', average_returns_all_bcq),
    (axes[1], 'Value-Estimates', q_values_all_bcq),
)

for ax, y_label, curves in panels:
    ax.set_ylabel(y_label)
    ax.set_ylim([-300, 250])
    ax.set_xlabel('Num updates / 10k')
    ax.set_xticks(x)
    for idx, curve in enumerate(curves):
        ax.plot(x, curve, label=labels[idx])
    ax.legend()

fig.tight_layout()

### What can you observe when you compare the performance and corresponding Value-estimates?

Lastly, we train a BCQ agent with a dataset that was collected with the sub-optimal policy below.

In [None]:
# Load the stored sub-optimal policy and measure its greedy performance over 20 episodes.
suboptimal_ddqn = DDQN(state_dim, num_actions, learning_rate, gamma)
suboptimal_ddqn.load_state_dict(torch.load("data_and_models/suboptimal_dqn.pt"))

performance, qs = evaluate_agent(env, suboptimal_ddqn, 20)
print(f"Performance: {performance}, Value-Estimates: {qs}")

In [None]:
# Train BCQ on the dataset stored for the sub-optimal policy and evaluate after every round.
memory = ReplayMemory(replay_capacity, rng)
memory.load_transitions("data_and_models/transitions_suboptimal_1k_eps0.2")
print("Number of transitions in memory:", len(memory.transitions))

# written as a loop so that further thresholds can be added to the list
for threshold in [0.3]:

    bcq = BCQ(state_dim, num_actions, learning_rate, gamma, threshold=threshold)
    target_bcq = copy.deepcopy(bcq)

    for i in range(5):
        # the target network is re-created from the online network at the
        # start of every round; the one returned by train_offline is discarded
        target_bcq = copy.deepcopy(bcq)
        train_offline(bcq, target_bcq, target_interval, memory,
                      num_updates=10000, batch_size=batch_size)

        # prints: round number, average return, average start-state value estimate
        performance, qs = evaluate_agent(env, bcq, 20)
        print(i + 1, performance, qs)


### Can the Offline Agent outperform the Behaviour policy?