In [None]:
import gym
from gym import spaces
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Module, Linear
from torch.distributions import Distribution, Normal
from torch.nn.functional import relu, logsigmoid

import matplotlib.pyplot as plt
import copy
from pyvirtualdisplay import Display
from IPython import display as disp
import sys

# Adaptively Calibrated Critics (ACC)

Implementation of the Adaptively Calibrated Critics (ACC) algorithm, a deep RL approach that dynamically adjusts the bias of Temporal Difference (TD) learning targets.

ACC builds on Truncated Quantile Critics (TQC), which drops a fixed number of the largest quantile atoms from the TD target. ACC instead adjusts the number of excluded quantile targets, $\beta$, dynamically, by comparing the critic's Q-estimates with the discounted returns actually observed in recent episodes. This gives more precise control over the balance between over- and underestimation, improving both the robustness and the accuracy of policy learning.

## Source
The code is heavily adapted from the original ACC implementation available at [Nicolinho/ACC](https://github.com/Nicolinho/ACC/tree/main)



In [None]:
# Global configuration shared by all cells below.
seed = 42                    # RNG seed for torch, random and numpy
EPISODE_LENGTH = 2000        # hard cap on steps per training episode
LOG_STD_MIN_MAX = (-20, 2)   # clamp range for the actor's log-std output
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [None]:


class RescaleAction(gym.ActionWrapper):
    """Expose a Box-action environment through the action range ``[a, b]``.

    ``a`` and ``b`` are broadcast to the wrapped action shape. ``action`` maps
    each component affinely from ``[a, b]`` onto the wrapped env's
    ``[low, high]`` and clips the result to that range.
    """

    def __init__(self, env, a, b):
        wrapped_space = env.action_space
        assert isinstance(wrapped_space, spaces.Box), (
            "expected Box action space, got {}".format(type(wrapped_space)))
        assert np.less_equal(a, b).all(), (a, b)
        super().__init__(env)
        shape, dtype = wrapped_space.shape, wrapped_space.dtype
        # Broadcast scalar (or array) bounds to the full action shape.
        self.a = np.zeros(shape, dtype=dtype) + a
        self.b = np.zeros(shape, dtype=dtype) + b
        self.action_space = spaces.Box(low=a, high=b, shape=shape, dtype=dtype)

    def action(self, action):
        """Translate an agent action in ``[a, b]`` to the wrapped env's range."""
        assert np.all(np.greater_equal(action, self.a)), (action, self.a)
        assert np.all(np.less_equal(action, self.b)), (action, self.b)
        inner_space = self.env.action_space
        low, high = inner_space.low, inner_space.high
        # Relative position of the action inside [a, b].
        unit = (action - self.a) / (self.b - self.a)
        return np.clip(low + (high - low) * unit, low, high)

class Mlp(Module):
    """Fully connected network: ReLU after every hidden layer, linear output.

    Hidden layers are registered as ``fc0``, ``fc1``, ... and the output layer
    as ``last_fc``; the hidden layers are also kept, in order, in ``self.fcs``.
    """

    def __init__(
            self,
            input_size,
            hidden_sizes,
            output_size
    ):
        super().__init__()
        widths = [input_size, *hidden_sizes]
        self.fcs = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layer = Linear(fan_in, fan_out)
            self.add_module(f'fc{idx}', layer)
            self.fcs.append(layer)
        self.last_fc = Linear(widths[-1], output_size)

    def forward(self, input):
        hidden = input
        for layer in self.fcs:
            hidden = relu(layer(hidden))
        return self.last_fc(hidden)

class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions held in NumPy arrays.

    Each field (``state``, ``action``, ``next_state``, ``reward``, ``not_done``)
    is a ``(max_size, dim)`` array. Once full, the oldest entries are overwritten.

    Fields are stored as float32. Every read converts to ``torch.float32``
    anyway, so float64 storage only doubled memory without changing any value
    handed out.
    """

    def __init__(self, state_dim, action_dim, max_size=int(1e6)):
        self.max_size = max_size
        self.ptr = 0   # next slot to write
        self.size = 0  # number of valid transitions (<= max_size)

        self.transition_names = ('state', 'action', 'next_state', 'reward', 'not_done')
        sizes = (state_dim, action_dim, state_dim, 1, 1)
        for name, size in zip(self.transition_names, sizes):
            setattr(self, name, np.empty((max_size, size), dtype=np.float32))

    def add(self, state, action, next_state, reward, done):
        """Store one transition; ``done`` is saved as ``not_done = 1 - done``."""
        values = (state, action, next_state, reward, 1. - done)
        for name, value in zip(self.transition_names, values):
            getattr(self, name)[self.ptr] = value

        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Return a generator of float32 tensors on ``DEVICE``, one per field.

        Indices are drawn uniformly with replacement from the stored transitions.
        """
        ind = np.random.randint(0, self.size, size=batch_size)
        return (torch.as_tensor(getattr(self, name)[ind], dtype=torch.float32).to(DEVICE)
                for name in self.transition_names)

    def states_by_ptr(self, ptr_list, cpu=False):
        """Return ``(state, action)`` tensors for the given index intervals.

        Each interval is ``[start, end)``. If ``start > end``, the interval wraps
        past the end of the buffer. If ``start == end``, it contributes nothing.
        Rows keep the order of ``ptr_list``. Tensors stay on the CPU when
        ``cpu`` is true and otherwise go to ``DEVICE``.
        """
        chunks = [np.array([], dtype='int64')]
        for interval in ptr_list:
            start, end = interval[0], interval[1]
            if start < end:
                chunks.append(np.arange(start, end))
            elif start > end:
                chunks.append(np.arange(start, self.max_size))
                chunks.append(np.arange(0, end))
        ind = np.concatenate(chunks)

        device = None if cpu else DEVICE
        return (torch.as_tensor(getattr(self, name)[ind], dtype=torch.float32, device=device)
                for name in ('state', 'action'))


class Critic(Module):
    """Ensemble of ``n_nets`` quantile Q-networks.

    Each member is an ``Mlp`` with three 512-unit hidden layers. It maps the
    concatenated (state, action) vector to ``n_quantiles`` outputs. ``forward``
    stacks the members along dim 1, giving ``batch x nets x quantiles``.
    """

    def __init__(self, state_dim, action_dim, n_quantiles, n_nets):
        super().__init__()
        self.n_quantiles = n_quantiles
        self.n_nets = n_nets
        self.nets = []
        input_width = state_dim + action_dim
        for member in range(n_nets):
            q_net = Mlp(input_width, [512, 512, 512], n_quantiles)
            self.add_module(f'qf{member}', q_net)
            self.nets.append(q_net)

    def forward(self, state, action):
        state_action = torch.cat([state, action], dim=1)
        per_net = [q_net(state_action) for q_net in self.nets]
        return torch.stack(per_net, dim=1)


class Actor(Module):
    """Tanh-squashed Gaussian policy.

    The MLP outputs a mean and a log-std per action dimension. The log-std is
    clamped to ``LOG_STD_MIN_MAX``.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim
        self.net = Mlp(state_dim, [256, 256], 2 * action_dim)

    def forward(self, obs):
        """Return ``(action, log_prob)`` for a batch of observations.

        In training mode the action is a reparameterised sample squashed by tanh,
        and ``log_prob`` is its log-density summed over action dims
        (``keepdim=True``). In eval mode the action is ``tanh(mean)`` and
        ``log_prob`` is ``None``.
        """
        mean, log_std = self.net(obs).split([self.action_dim, self.action_dim], dim=1)
        log_std = log_std.clamp(*LOG_STD_MIN_MAX)

        if self.training:
            std = torch.exp(log_std)
            tanh_normal = TanhNormal(mean, std)
            action, pre_tanh = tanh_normal.rsample()
            log_prob = tanh_normal.log_prob(pre_tanh).sum(dim=1, keepdim=True)
        else:  # deterministic eval without log_prob computation
            action = torch.tanh(mean)
            log_prob = None
        return action, log_prob

    @torch.no_grad()
    def select_action(self, obs):
        """Return a NumPy action for a single (unbatched) observation.

        Runs without autograd: the action is only fed to the environment, so
        building a graph on every environment step is wasted work.
        """
        device = next(self.parameters()).device  # follow the network, not a global
        obs = torch.FloatTensor(obs).to(device)[None, :]
        action, _ = self.forward(obs)
        return action[0].cpu().numpy()


class TanhNormal(Distribution):
    """Normal distribution pushed through ``tanh``, used for bounded actions.

    ``rsample`` returns ``(tanh(x), x)`` with ``x = mean + std * eps`` (the
    reparameterisation trick). ``log_prob`` takes the *pre-tanh* value ``x``.
    """

    def __init__(self, normal_mean, normal_std):
        # No arg_constraints are declared, so argument validation has nothing to check.
        super().__init__(validate_args=False)
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        # zeros_like/ones_like already follow the inputs' device. Forcing the
        # global DEVICE broke sampling for tensors on another device.
        self.standard_normal = Normal(torch.zeros_like(self.normal_mean),
                                      torch.ones_like(self.normal_std))
        self.normal = Normal(normal_mean, normal_std)

    def log_prob(self, pre_tanh):
        """Element-wise log-density of ``tanh(pre_tanh)`` (not summed over dims).

        Uses ``log(1 - tanh(x)^2) = 2*log(2) + logsigmoid(2x) + logsigmoid(-2x)``,
        which stays finite for large ``|x|``.
        """
        log_det = 2 * np.log(2) + logsigmoid(2 * pre_tanh) + logsigmoid(-2 * pre_tanh)
        return self.normal.log_prob(pre_tanh) - log_det

    def rsample(self):
        """Return ``(tanh(x), x)`` with ``x = mean + std * eps``, ``eps ~ N(0, 1)``."""
        pretanh = self.normal_mean + self.normal_std * self.standard_normal.sample()
        return torch.tanh(pretanh), pretanh

In [None]:
def eval_policy(policy, eval_env, max_episode_steps, eval_episodes=10):
    """Average undiscounted return of ``policy`` over ``eval_episodes`` episodes.

    The policy is switched to eval mode for the rollouts and back to train mode
    afterwards. An episode ends on ``done`` or after ``max_episode_steps`` steps.
    """
    policy.eval()
    total_return = 0.
    for _episode in range(eval_episodes):
        state = eval_env.reset()
        finished = False
        steps = 0
        while steps < max_episode_steps and not finished:
            state, reward, finished, _info = eval_env.step(policy.select_action(state))
            total_return += reward
            steps += 1
    mean_return = total_return / eval_episodes
    policy.train()
    return mean_return


def quantile_huber_loss_f(quantiles, samples):
    """Quantile Huber loss (threshold 1) between predicted quantiles and target samples.

    Args:
        quantiles: predicted quantiles, ``batch x nets x n_quantiles``.
        samples: target samples, ``batch x n_samples``.

    Returns:
        Scalar tensor: the weighted Huber loss averaged over every
        (batch, net, quantile, sample) pair.
    """
    pairwise_delta = samples[:, None, None, :] - quantiles[:, :, :, None]  # batch x nets x quantiles x samples
    abs_pairwise_delta = torch.abs(pairwise_delta)
    huber_loss = torch.where(abs_pairwise_delta > 1,
                             abs_pairwise_delta - 0.5,
                             pairwise_delta ** 2 * 0.5)

    n_quantiles = quantiles.shape[2]
    # Quantile midpoints tau_i = (2i + 1) / (2N), built on the inputs' device and
    # dtype (a global device breaks when the inputs live elsewhere).
    tau = (torch.arange(n_quantiles, device=quantiles.device, dtype=quantiles.dtype) / n_quantiles
           + 1 / 2 / n_quantiles)
    # Asymmetric quantile weight |tau - 1{delta < 0}|.
    weight = torch.abs(tau[None, None, :, None] - (pairwise_delta < 0).to(quantiles.dtype))
    return (weight * huber_loss).mean()

In [None]:
class Trainer(object):
    """TQC learner with optional Adaptively Calibrated Critics (ACC).

    Each ``train`` call does three things, in order:

    * ``num_critic_updates`` critic steps, each using a quantile Huber loss
      against a target whose largest pooled quantiles are dropped;
    * one actor step;
    * one entropy-temperature (alpha) step.

    Without ACC, ``top_quantiles_to_drop`` (a total over all nets) is dropped.
    With ``use_acc``, ``round(n_nets * beta)`` quantiles are dropped instead.
    Here ``beta`` (``adjusted_dropped_quantiles``) is a learnable scalar that
    ``update_beta`` tunes.
    """

    def __init__(
            self,
            *,
            actor,
            critic,
            critic_target,
            discount,
            tau,
            top_quantiles_to_drop,
            target_entropy,
            use_acc,
            lr_dropped_quantiles,
            adjusted_dropped_quantiles_init,
            adjusted_dropped_quantiles_max,
            diff_ma_coef,
            num_critic_updates,
    ):
        if num_critic_updates < 1:
            # train() updates the actor on the last sampled critic batch, so it
            # needs at least one critic step per call.
            raise ValueError(
                "num_critic_updates must be >= 1, got {}".format(num_critic_updates))

        self.actor = actor
        self.critic = critic
        self.critic_target = critic_target
        # Temperature is optimised in log space; alpha = exp(log_alpha).
        self.log_alpha = torch.zeros((1,), requires_grad=True, device=DEVICE)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=3e-4)

        self.discount = discount
        self.tau = tau
        self.top_quantiles_to_drop = top_quantiles_to_drop
        self.target_entropy = target_entropy

        self.quantiles_total = critic.n_quantiles * critic.n_nets

        self.total_it = 0

        self.use_acc = use_acc
        self.num_critic_updates = num_critic_updates
        if use_acc:
            # float() so an integer init still yields a tensor that can require grad.
            self.adjusted_dropped_quantiles = torch.tensor(
                float(adjusted_dropped_quantiles_init), requires_grad=True)
            self.adjusted_dropped_quantiles_max = adjusted_dropped_quantiles_max
            self.dropped_quantiles_dropped_optimizer = torch.optim.SGD(
                [self.adjusted_dropped_quantiles], lr=lr_dropped_quantiles)
            self.first_training = True
            self.diff_ma_coef = diff_ma_coef

    def _num_dropped_quantiles(self):
        """Total number of largest pooled target quantiles to discard."""
        if self.use_acc:
            return round(self.critic.n_nets * self.adjusted_dropped_quantiles.item())
        return self.top_quantiles_to_drop

    def train(self, replay_buffer, batch_size=256, ptr_list=None, disc_return=None, do_beta_update=False):
        """Run one training iteration; optionally update beta first.

        ``ptr_list`` / ``disc_return`` are forwarded to ``update_beta`` when
        ``do_beta_update`` is true and ``ptr_list`` is given.
        """
        if ptr_list is not None and do_beta_update:
            self.update_beta(replay_buffer, ptr_list, disc_return)

        for _ in range(self.num_critic_updates):
            state, action, next_state, reward, not_done = replay_buffer.sample(batch_size)
            alpha = torch.exp(self.log_alpha)

            # --- Q loss ---
            with torch.no_grad():
                # get policy action
                new_next_action, next_log_pi = self.actor(next_state)
                # pool all nets' quantiles at the next state, sort, cut the largest
                next_z = self.critic_target(next_state, new_next_action)  # batch x nets x quantiles
                sorted_z, _ = torch.sort(next_z.reshape(batch_size, -1))
                sorted_z_part = sorted_z[:, :self.quantiles_total - self._num_dropped_quantiles()]

                # compute target
                target = reward + not_done * self.discount * (sorted_z_part - alpha * next_log_pi)

            cur_z = self.critic(state, action)
            critic_loss = quantile_huber_loss_f(cur_z, target)

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Polyak averaging of the target critic.
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        # --- Policy and alpha loss (on the last critic batch) ---
        new_action, log_pi = self.actor(state)
        alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
        # alpha is a constant for the actor; detaching avoids a useless gradient
        # into log_alpha (it was discarded by alpha_optimizer.zero_grad() anyway).
        actor_loss = (alpha.detach() * log_pi
                      - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.total_it += 1

    def update_beta(self, replay_buffer, ptr_list=None, disc_return=None):
        """One SGD step on beta from the gap between observed returns and Q.

        The signed gap ``mean(disc_return) - mean(Q)`` is normalised by its own
        moving-average magnitude. Minimising ``beta * gap`` lowers beta when
        returns exceed Q (underestimation) and raises it otherwise. Beta is then
        clamped to ``[0, adjusted_dropped_quantiles_max]``.

        ``disc_return`` must hold one value per state selected by ``ptr_list``,
        in the same order.
        """
        state, action = replay_buffer.states_by_ptr(ptr_list)
        disc_return = torch.FloatTensor(disc_return).to(DEVICE)
        assert disc_return.shape[0] == state.shape[0]

        # Pure evaluation over possibly thousands of states: no autograd graph needed.
        with torch.no_grad():
            mean_Q_last_eps = self.critic(state, action).mean(2).mean(1, keepdim=True).mean()
        mean_return_last_eps = torch.mean(disc_return)
        gap = mean_return_last_eps - mean_Q_last_eps

        if self.first_training:
            self.diff_mvavg = torch.abs(gap)
            self.first_training = False
        else:
            self.diff_mvavg = (1 - self.diff_ma_coef) * self.diff_mvavg \
                              + self.diff_ma_coef * torch.abs(gap)

        diff_qret = gap / (self.diff_mvavg + 1e-8)
        aux_loss = self.adjusted_dropped_quantiles * diff_qret
        self.dropped_quantiles_dropped_optimizer.zero_grad()
        aux_loss.backward()
        self.dropped_quantiles_dropped_optimizer.step()
        with torch.no_grad():
            self.adjusted_dropped_quantiles.clamp_(min=0., max=self.adjusted_dropped_quantiles_max)


In [None]:
def _discounted_returns(rewards, discount):
    """Discounted return-to-go for every step of one episode.

    ``result[i] == sum(discount ** (j - i) * rewards[j] for j >= i)``, computed
    in a single backward pass.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for i in reversed(range(len(rewards))):
        running = rewards[i] + discount * running
        returns[i] = running
    return returns


def _trim_beta_batch(ptr_list, disc_return, size_limit, buffer_size):
    """Forget the oldest stored episodes until at most ``size_limit`` states remain.

    ``ptr_list`` holds the ``(start, end)`` replay-buffer interval of each stored
    episode; ``end < start`` means the interval wrapped. ``disc_return`` holds the
    per-state returns of those episodes in the same order. Both lists are
    mutated in place, and the most recent episode is always kept.
    """
    n_states = len(disc_return)
    while len(ptr_list) > 1 and n_states > size_limit:
        start, end = ptr_list.pop(0)
        episode_len = (end - start) % buffer_size
        del disc_return[:episode_len]
        n_states -= episode_len


def main(args, prefix):
    """Train a TQC agent on ``args.env``, adapting beta online when ``args.use_acc``.

    Each finished episode's return is written to ``agent-log.txt``. The gym
    ``Monitor`` records every 25th episode to ``./videoACC``.

    With ACC, the replay-buffer interval and per-state discounted returns of
    recent episodes are kept (at most ``size_limit_beta_update_batch`` states).
    Once ``init_num_steps_before_beta_updates`` steps have passed, a beta update
    is requested every ``beta_udate_rate`` steps. It runs with the next training
    call.

    Args:
        args: parsed command-line namespace (see the argparse cell).
        prefix: experiment-name prefix; currently unused.
    """
    # --- Init ---
    # Honour --seed (the config cell seeds with the module-level default).
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    env = gym.make(args.env)
    video_every = 25
    env = gym.wrappers.Monitor(env, "./videoACC", video_callable=lambda ep_id: ep_id % video_every == 0, force=True)

    env.seed(args.seed)
    env.action_space.seed(args.seed)

    env = RescaleAction(env, -1., 1.)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    replay_buffer = ReplayBuffer(state_dim, action_dim)
    actor = Actor(state_dim, action_dim).to(DEVICE)
    critic = Critic(state_dim, action_dim, args.n_quantiles, args.n_nets).to(DEVICE)
    critic_target = copy.deepcopy(critic)

    top_quantiles_to_drop = args.top_quantiles_to_drop_per_net * args.n_nets

    trainer = Trainer(actor=actor,
                      critic=critic,
                      critic_target=critic_target,
                      top_quantiles_to_drop=top_quantiles_to_drop,
                      discount=args.discount,
                      tau=args.tau,
                      target_entropy=-np.prod(env.action_space.shape).item(),
                      use_acc=args.use_acc,
                      lr_dropped_quantiles=args.lr_dropped_quantiles,
                      adjusted_dropped_quantiles_init=args.adjusted_dropped_quantiles_init,
                      adjusted_dropped_quantiles_max=args.adjusted_dropped_quantiles_max,
                      diff_ma_coef=args.diff_ma_coef,
                      num_critic_updates=args.num_critic_updates)

    # ACC bookkeeping for the episode in progress and for recently finished episodes.
    episode_rewards = []
    episode_start_ptr = replay_buffer.ptr
    ptr_list = []      # (start, end) buffer intervals of recent finished episodes
    disc_return = []   # discounted return of every state in ptr_list, same order
    beta_update_pending = False

    state, done = env.reset(), False
    episode_return = 0
    episode_timesteps = 0
    episode_num = 0

    actor.train()
    with open("agent-log.txt", "w+") as log_f:
        for t in range(int(args.max_timesteps)):

            print(f"T: {episode_timesteps}", end="\r")
            action = actor.select_action(state)
            next_state, reward, done, info = env.step(action)
            episode_timesteps += 1
            episode_return += reward
            episode_rewards.append(reward)

            # A TimeLimit cut is not a real terminal state: store it as
            # non-terminal so the critic keeps bootstrapping from next_state.
            terminal = done and not info.get("TimeLimit.truncated", False)
            replay_buffer.add(state, action, next_state, reward, terminal)

            state = next_state

            if (args.use_acc and t >= args.init_num_steps_before_beta_updates
                    and (t + 1) % args.beta_udate_rate == 0):
                beta_update_pending = True

            # Train agent after collecting sufficient data
            if t >= args.init_expl_steps:
                if beta_update_pending and ptr_list:
                    trainer.train(replay_buffer, args.batch_size, ptr_list, disc_return, do_beta_update=True)
                    beta_update_pending = False
                else:
                    trainer.train(replay_buffer, args.batch_size)

            if done or episode_timesteps >= EPISODE_LENGTH:
                print(f"Total T: {t + 1} Episode Num: {episode_num + 1} Episode T: {episode_timesteps} Reward: {episode_return:.3f}")

                if args.use_acc:
                    ptr_list.append((episode_start_ptr, replay_buffer.ptr))
                    disc_return.extend(_discounted_returns(episode_rewards, args.discount))
                    _trim_beta_batch(ptr_list, disc_return,
                                     args.size_limit_beta_update_batch, replay_buffer.max_size)
                episode_start_ptr = replay_buffer.ptr
                episode_rewards = []

                state, done = env.reset(), False

                log_f.write('episode: {}, reward: {}\n'.format(episode_num, episode_return))
                log_f.flush()

                episode_return = 0
                episode_timesteps = 0
                episode_num += 1

In [None]:

# Emulate a command line for the argparse cell below: the program name followed
# by the flag/value pairs, flattened in order.
sys.argv = ["notebook"] + [
    token
    for flag_value in (
        ("--env", "BipedalWalkerHardcore-v3"),
        ("--eval_freq", "5000"),
        ("--max_timesteps", "1000000"),
        ("--init_expl_steps", "5000"),
        ("--seed", "42"),
        ("--n_quantiles", "25"),
        ("--use_acc", "True"),
        ("--top_quantiles_to_drop_per_net", "2"),
        ("--beta_udate_rate", "1000"),
        ("--init_num_steps_before_beta_updates", "25000"),
        ("--size_limit_beta_update_batch", "5000"),
        ("--lr_dropped_quantiles", "0.1"),
        ("--adjusted_dropped_quantiles_init", "2.5"),
        ("--adjusted_dropped_quantiles_max", "5.0"),
        ("--diff_ma_coef", "0.05"),
        ("--num_critic_updates", "4"),
        ("--n_nets", "5"),
        ("--batch_size", "256"),
        ("--discount", "0.99"),
        ("--tau", "0.005"),
        ("--log_dir", "results"),
        ("--exp_name", "eval_run"),
        ("--prefix", ""),
    )
    for token in flag_value
]

In [None]:
import sys
import argparse
from pathlib import Path

def str2bool(v):
    """Parse a command-line boolean for ``argparse``.

    Case-insensitively accepts yes/true/t/y/1 and no/false/f/n/0. A value that
    is already a ``bool`` is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
# NOTE: argparse applies ``type`` only to *string* defaults, so numeric defaults
# of ``type=int`` options are written as ints (``1e3`` used to stay a float).
parser.add_argument("--env", default="BipedalWalkerHardcore-v3")              # OpenAI gym environment name
parser.add_argument("--eval_freq", default=1000, type=int)          # How often (time steps) we evaluate; not read by main yet
parser.add_argument("--max_timesteps", default=1_000_000, type=int) # Max time steps to run environment
parser.add_argument("--init_expl_steps", default=5000, type=int)    # num of exploration steps before training starts
parser.add_argument("--seed", default=42, type=int)                  # random seed
parser.add_argument("--n_quantiles", default=25, type=int)          # number of quantiles for TQC
parser.add_argument("--use_acc", default=True, type=str2bool)       # if acc for automatic tuning of beta shall be used, o/w top_quantiles_to_drop_per_net will be used
parser.add_argument("--top_quantiles_to_drop_per_net",
                    default=3, type=int)        # how many quantiles to drop per net. Parameter has no effect if: use_acc = True
parser.add_argument("--beta_udate_rate", default=1000, type=int)    # num of steps between beta/dropped_quantiles updates
parser.add_argument("--init_num_steps_before_beta_updates",
                    default=25000, type=int)    # num steps before updates to dropped_quantiles are started
parser.add_argument("--size_limit_beta_update_batch",
                    default=5000, type=int)     # size of most recent state-action pairs stored for dropped_quantiles updates
parser.add_argument("--lr_dropped_quantiles",
                    default=0.1, type=float)    # learning rate for dropped_quantiles
parser.add_argument("--adjusted_dropped_quantiles_init",
                    default=2.5, type=float)    # initial value of dropped_quantiles
parser.add_argument("--adjusted_dropped_quantiles_max",
                    default=5.0, type=float)    # maximal value for dropped_quantiles
parser.add_argument("--diff_ma_coef", default=0.05, type=float)     # moving average param. for normalization of dropped_quantiles loss
parser.add_argument("--num_critic_updates", default=1, type=int)    # number of critic updates per environment step
parser.add_argument("--n_nets", default=5, type=int)                # number of critic networks
parser.add_argument("--batch_size", default=256, type=int)          # Batch size for both actor and critic
parser.add_argument("--discount", default=0.99, type=float)         # Discount factor
parser.add_argument("--tau", default=0.005, type=float)             # Target network update rate
parser.add_argument("--log_dir", default='results')                 # results directory
parser.add_argument("--exp_name", default='eval_run')               # name of experiment
parser.add_argument("--prefix", default='')                         # optional prefix to the name of the experiments
args = parser.parse_args()

log_dir = Path(args.log_dir)

# Print the parsed arguments for confirmation
print(f"env: {args.env}")
print(f"eval_freq: {args.eval_freq}")
print(f"max_timesteps: {args.max_timesteps}")
print(f"seed: {args.seed}")
print(f"n_quantiles: {args.n_quantiles}")
print(f"top_quantiles_to_drop_per_net: {args.top_quantiles_to_drop_per_net}")
print(f"n_nets: {args.n_nets}")
print(f"batch_size: {args.batch_size}")
print(f"discount: {args.discount}")
print(f"tau: {args.tau}")
print(f"log_dir: {args.log_dir}")
print(f"prefix: {args.prefix}")

# Run training with the parsed arguments
main(args, args.prefix)