In [1]:
import numpy as np
import torch

from stable_baselines3.ppo import PPO

from cppo import CPPO_Policy
from env import SlidingAntEnv

import gym

import copy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Environment for the reference agent whose collected rollout is shared by both runs below.
env = gym.make('CartPole-v1')
env.seed(42)  # classic gym seeding API; the `[42]` shown below is its return value

[42]

In [3]:
# Reference agent: only the rollout it collects is reused later; its own weights are never compared.
ppo = PPO("MlpPolicy", env, seed=42)

In [4]:
# Run SB3's learn loop for 2048 timesteps; the rollout left in `ppo.rollout_buffer`
# is snapshotted in the next cell and fed to both agents under comparison.
# NOTE(review): presumably exactly one rollout with SB3's default n_steps — confirm.
ppo.learn(2048)

<stable_baselines3.ppo.ppo.PPO at 0x20859b184c0>

In [5]:
# Frozen copy of the rollout; each agent below trains on its own deepcopy of it,
# so neither training run can alter the data the other one sees.
buffer = copy.deepcopy(ppo.rollout_buffer)

In [6]:
# Plain-PPO agent built with the same seed as the CPPO agent below, so both start
# from identical weights (checked further down).
env1 = gym.make('CartPole-v1')
env1.seed(42)
ppo_normal = PPO("MlpPolicy", env1, seed=42)
# Private SB3 API, called so that `train()` can be run directly (below) without `learn()`.
# NOTE(review): in SB3 1.x the second positional parameter of `_setup_learn` is `eval_env`
# (the CallbackList in the output hints that an eval callback was built) — confirm that
# passing the training env here is intended.
ppo_normal._setup_learn(2048, env1)

(2048, <stable_baselines3.common.callbacks.CallbackList at 0x20859d17250>)

In [7]:
# Flat numpy snapshot of the untrained weights, compared with the CPPO agent's later.
ppo_original_params = ppo_normal.policy.parameters_to_vector()

In [8]:
# One PPO `train()` pass with plain Adam on a fresh copy of the shared rollout.
# NOTE(review): SB3's RolloutBuffer presumably shuffles minibatches with numpy's global
# RNG, which torch.manual_seed does not reset — consider also seeding numpy here and in
# the matching CPPO cell so minibatch order is pinned explicitly.
torch.manual_seed(42)
ppo_normal.rollout_buffer = copy.deepcopy(buffer)
ppo_normal.train()

In [9]:
# Weights after the Adam `train()` pass.
ppo_trained_params = ppo_normal.policy.parameters_to_vector()

In [482]:
import torch
from torch import Tensor
from typing import List, Optional
from torch.optim import Adam


class CAdam(Adam):
    r"""Adam whose ``step`` counter is tracked per weight rather than per tensor.

    ``state['step']`` is allocated with ``torch.zeros_like(p)``, i.e. it has the
    same shape as the parameter, so the bias corrections are applied
    element-wise by :func:`cadam`. Moment buffers, hyper-parameters and the
    param-group layout are the ones inherited from :class:`torch.optim.Adam`.

    NOTE(review): a later notebook cell runs ``from cadam import CAdam``, which
    rebinds this name to the module's version on a top-to-bottom Run All.
    """

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            Whatever ``closure`` returned (evaluated with grad enabled), or
            ``None`` when no closure is given.

        Raises:
            RuntimeError: if a parameter carries a sparse gradient.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            use_amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']

            live_params = []
            live_grads = []
            first_moments = []
            second_moments = []
            max_second_moments = []
            step_counters = []

            for param in group['params']:
                if param.grad is None:
                    continue  # parameters without a gradient take no part in this step

                live_params.append(param)
                if param.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                live_grads.append(param.grad)

                param_state = self.state[param]
                if not param_state:
                    # Lazy initialisation on the first step that sees a gradient.
                    param_state.update(self._fresh_elementwise_state(param, use_amsgrad))

                first_moments.append(param_state['exp_avg'])
                second_moments.append(param_state['exp_avg_sq'])
                if use_amsgrad:
                    max_second_moments.append(param_state['max_exp_avg_sq'])
                step_counters.append(param_state['step'])

            cadam(live_params,
                  live_grads,
                  first_moments,
                  second_moments,
                  max_second_moments,
                  step_counters,
                  amsgrad=use_amsgrad,
                  beta1=beta1,
                  beta2=beta2,
                  lr=group['lr'],
                  weight_decay=group['weight_decay'],
                  eps=group['eps'],
                  maximize=group['maximize'],
                  foreach=group['foreach'],
                  capturable=group['capturable'])

        return loss

    @staticmethod
    def _fresh_elementwise_state(param, amsgrad):
        """Build zero-initialised state for ``param``; ``step`` has the parameter's shape."""
        def zeros_on_device():
            return torch.zeros_like(param, memory_format=torch.preserve_format, device=param.device)

        # Key order (step, exp_avg, exp_avg_sq[, max_exp_avg_sq]) matches what state_dict() will show.
        fresh = {
            'step': zeros_on_device(),            # one step count per weight (the CAdam change)
            'exp_avg': zeros_on_device(),         # EMA of gradients
            'exp_avg_sq': zeros_on_device(),      # EMA of squared gradients
        }
        if amsgrad:
            # Running max of the squared-gradient EMA.
            fresh['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format)
        return fresh


def cadam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[Tensor],
         # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
         # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
         foreach: Optional[bool] = None,
         capturable: bool = False,
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         maximize: bool):
    r"""Functional API that performs the CAdam update (Adam with element-wise step counts).

    Follows the contract of the functional Adam behind :class:`~torch.optim.Adam`,
    except that each entry of ``state_steps`` is a tensor holding one step count
    per element of the matching parameter rather than a single counter.
    The per-parameter tensors are updated in place by the single-tensor kernel.

    Args:
        foreach: only ``None``/``False`` are supported; there is no multi-tensor path.
        capturable: forwarded to the kernel, which then asserts that params and
            step tensors are CUDA tensors.

    Raises:
        RuntimeError: if an entry of ``state_steps`` is not a tensor, or if
            ``foreach`` is requested while scripting.
        NotImplementedError: if ``foreach`` is truthy.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of tensors "
                           "holding one step count per parameter element")

    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach:
        # Only the single-tensor kernel has been adapted to element-wise step counts.
        raise NotImplementedError("CAdam has no foreach (multi-tensor) implementation; "
                                  "use foreach=None or foreach=False")

    _single_tensor_cadam(params,
                         grads,
                         exp_avgs,
                         exp_avg_sqs,
                         max_exp_avg_sqs,
                         state_steps,
                         amsgrad=amsgrad,
                         beta1=beta1,
                         beta2=beta2,
                         lr=lr,
                         weight_decay=weight_decay,
                         eps=eps,
                         maximize=maximize,
                         capturable=capturable)

@torch.jit.script
def _single_tensor_cadam(params: List[Tensor],
                        grads: List[Tensor],
                        exp_avgs: List[Tensor],
                        exp_avg_sqs: List[Tensor],
                        max_exp_avg_sqs: List[Tensor],
                        state_steps: List[Tensor],
                        *,
                        amsgrad: bool,
                        beta1: float,
                        beta2: float,
                        lr: float,
                        weight_decay: float,
                        eps: float,
                        maximize: bool,
                        capturable: bool):
    
    for i, param in enumerate(params):

        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        if capturable:
            assert param.is_cuda and step.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."
            
        step.add_(1)
        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)
            
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
        
        bias_correction1 = 1 - torch.pow(beta1, step)
        bias_correction2 = 1 - torch.pow(beta2, step)
        bias_correction2_sqrt = bias_correction2.sqrt()
        
        step_size = lr / bias_correction1
        
        denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add(eps)
        param.sub_((exp_avg / denom).mul(step_size))

In [13]:
from cadam import CAdam

In [40]:
# CPPO agent: same env and seed as `ppo_normal`, but with CPPO_Policy, whose optimizer is
# displayed as CBP at the end of the notebook.
# NOTE(review): rho=0 / m=1 are presumably chosen so CBP degenerates to plain Adam — confirm in cppo.
env2 = gym.make('CartPole-v1')
env2.seed(42)
ppo_cppo = PPO(CPPO_Policy, env2, seed=42, policy_kwargs={'optimizer_kwargs':{'eps':1e-5, 'rho': 0, 'm': 1}}) # SB3 sets custom eps for Adam
# Private SB3 API, called so that `train()` can be run directly without `learn()`.
# NOTE(review): in SB3 1.x the second positional parameter is `eval_env` — confirm this is intended.
ppo_cppo._setup_learn(2048, env2)

(2048, <stable_baselines3.common.callbacks.CallbackList at 0x208012fa0a0>)

In [41]:
# Untrained CPPO weights; should equal `ppo_original_params` (checked below).
cppo_original_params = ppo_cppo.policy.parameters_to_vector()

In [42]:
# One `train()` pass with the CBP optimizer on a fresh copy of the same rollout.
# NOTE(review): minibatch shuffling presumably uses numpy's global RNG, which
# torch.manual_seed does not reset — consider seeding numpy too, identically to the Adam cell.
torch.manual_seed(42)
ppo_cppo.rollout_buffer = copy.deepcopy(buffer)
ppo_cppo.train()

In [43]:
# Weights after the CBP `train()` pass.
cppo_trained_params = ppo_cppo.policy.parameters_to_vector()

In [44]:
# Sanity check: both agents start from bit-identical weights (True in the recorded run).
(ppo_original_params == cppo_original_params).all()

True

In [45]:
# Do the trained weights agree within np.allclose's default tolerances?
# False in the recorded run; the size of the mismatch is measured two cells below.
np.allclose(ppo_trained_params, cppo_trained_params)

False

In [46]:
# Peek at the Adam-trained weights (numpy elides the middle of the array).
ppo_trained_params

array([-0.04164303,  0.32055363,  0.25014526, ...,  0.20720258,
       -0.16821611,  0.05669835], dtype=float32)

In [47]:
# Peek at the CBP-trained weights for a side-by-side comparison with the cell above.
cppo_trained_params

array([-0.04164306,  0.32055348,  0.2501452 , ...,  0.20720248,
       -0.16821593,  0.05669831], dtype=float32)

In [48]:
# Largest absolute deviation among the weights that np.isclose flags as different.
# abs() matters: a plain .max() over signed differences under-reports negative deviations,
# and the guard avoids ValueError from .max() on an empty selection when everything matches.
param_abs_diff = np.abs(ppo_trained_params - cppo_trained_params)
param_mismatch = ~np.isclose(ppo_trained_params, cppo_trained_params)
param_abs_diff[param_mismatch].max() if param_mismatch.any() else 0.0

1.1846423e-06

In [None]:
# Optimizer of the plain-PPO agent: Adam with the eps=1e-5 that SB3 configures.
ppo_normal.policy.optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-05
    foreach: None
    lr: 0.0003
    maximize: False
    weight_decay: 0
)

In [None]:
# Optimizer of the CPPO agent (class CBP); its listed hyper-parameters match the Adam above.
ppo_cppo.policy.optimizer

CBP (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-05
    foreach: None
    lr: 0.0003
    maximize: False
    weight_decay: 0
)