In [None]:
import math

import torch
from torch.optim.optimizer import Optimizer


class CRTAdamW(Optimizer):
    r"""
    Centralized-Rectified Time-decay AdamW (CRT-AdamW)

    One optimizer step applies, per parameter and in this order:

    1. Gradient Centralization: the mean over every dimension except the
       first is subtracted from the gradient.
    2. Adam first/second moment updates with bias correction.
    3. RAdam-style rectification: the adaptive denominator is used only
       while the rectification term ρ_t is greater than 4; otherwise the
       step is the bias-corrected first moment alone.
    4. Decoupled weight decay with a time-decayed coefficient
       λ_t = λ₀ / (1 + α t)^β, applied to the already-updated parameter.
       ``t`` is the per-parameter step counter kept in the optimizer state.

    Args:
        params (iterable): model parameters.
        lr (float): learning-rate η.
        betas (Tuple[float,float]): (β1, β2).
        eps (float): ε for numerical stability.
        weight_decay (float): initial λ₀.
        decay_alpha (float): α in λ_t = λ₀ / (1 + α t)^β.
        decay_beta  (float): β in λ_t = λ₀ / (1 + α t)^β.
        gc_conv_only (bool): if True, apply Gradient Centralization only to
                             parameters with dim > 1 (recommended).

    Raises:
        ValueError: if ``lr`` is not positive, ``eps``, ``weight_decay`` or
            ``decay_alpha`` is negative, or a beta lies outside ``[0, 1)``.
            Only the constructor arguments are checked; values overridden
            in individual parameter groups are not.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=1e-2,
                 decay_alpha=1e-4, decay_beta=0.5,
                 gc_conv_only=True):
        if lr <= 0.0:
            raise ValueError(f"Invalid lr: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps: {eps}")
        # beta == 1 would make the bias-correction terms 1 - beta**t zero.
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta1: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta2: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay: {weight_decay}")
        # A negative alpha can drive (1 + alpha * t) below zero, and a
        # negative float raised to a fractional power is complex in Python.
        if decay_alpha < 0.0:
            raise ValueError(f"Invalid decay_alpha: {decay_alpha}")
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        decay_alpha=decay_alpha, decay_beta=decay_beta,
                        gc_conv_only=gc_conv_only)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns
                the loss. It is called with gradient tracking enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure typically calls
            # backward(), so autograd has to be switched back on for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr, eps = group['lr'], group['eps']
            lam0 = group['weight_decay']
            alpha, beta = group['decay_alpha'], group['decay_beta']
            gc_conv_only = group['gc_conv_only']

            # Depends only on beta2, so compute it once per group.
            rho_inf = 2 / (1 - beta2) - 1

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                # --------- Gradient Centralization -------------
                if grad.ndim > 1:
                    # Centre each slice along dim 0; this builds a new
                    # tensor, p.grad itself is left untouched.
                    reduce_dims = tuple(range(1, grad.ndim))
                    grad = grad - grad.mean(dim=reduce_dims, keepdim=True)
                elif not gc_conv_only:
                    # 0-D / 1-D tensors have no trailing dims: centre over
                    # all elements explicitly instead of relying on how an
                    # empty ``dim`` tuple is interpreted.
                    grad = grad - grad.mean()

                state = self.state[p]
                if len(state) == 0:   # state initialization
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                # moments
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # bias corrections
                beta2_t = beta2 ** t
                bias_c1 = 1 - beta1 ** t
                bias_c2 = 1 - beta2_t
                m_hat = exp_avg / bias_c1

                # ---------- Rectification (RAdam) ---------------
                rho_t = rho_inf - 2 * t * beta2_t / bias_c2
                if rho_t > 4:
                    # rho_t < rho_inf always holds, so rho_inf > 4 here and
                    # the denominator below is strictly positive.
                    r_t = math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf /
                                    ((rho_inf - 4) * (rho_inf - 2) * rho_t))
                    v_hat = exp_avg_sq / bias_c2
                    update = m_hat * r_t / (torch.sqrt(v_hat) + eps)
                else:
                    # Too few steps for a reliable second-moment estimate:
                    # fall back to the bias-corrected first moment.
                    update = m_hat

                # ----------- Time-decay weight-decay ------------
                lam_t = lam0 / (1.0 + alpha * t) ** beta

                # parameter update
                p.add_(update, alpha=-lr)
                if lam_t != 0.0:
                    # Decoupled decay, applied to the updated parameter.
                    p.mul_(1 - lr * lam_t)

        return loss