Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RangerVA with GC #39

Open
ryancinsight opened this issue Aug 5, 2020 · 0 comments
Open

RangerVA with GC #39

ryancinsight opened this issue Aug 5, 2020 · 0 comments

Comments

@ryancinsight
Copy link

Hello,

Thank you for your work on these optimizers btw. I was testing a couple out and was performing quite well with the RangerVA originally. Then, when your gradient centralization was added I got further improvements but it also seemed to be overtraining the train set more easily despite using the same parameters. Therefore, I tried to implement combining the gradient centralization into the RangerVA algorithm and so far it seems to be performing quite well and faster since it seems I can use larger batch sizes. I was wondering if you could quickly check, whenever you have some free time, if I implemented correctly in the code below since you are so used to this optimizer.

Best

``
class RangerVA(Optimizer):

def __init__(self, params, lr=1e-3, 
             alpha=0.5, k=6, n_sma_threshhold=5, betas=(.95,0.999), 
             eps=1e-5, weight_decay=0, amsgrad=True, transformer='softplus', smooth=50,
             grad_transformer='square',use_gc=True, gc_conv_only=False):
    #parameter checks
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f'Invalid slow update rate: {alpha}')
    if not 1 <= k:
        raise ValueError(f'Invalid lookahead steps: {k}')
    if not lr > 0:
        raise ValueError(f'Invalid Learning Rate: {lr}')
    if not eps > 0:
        raise ValueError(f'Invalid eps: {eps}')

    #prep defaults and init torch.optim base
    defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, 
                    n_sma_threshhold=n_sma_threshhold, eps=eps, weight_decay=weight_decay,
                    smooth=smooth, transformer=transformer, grad_transformer=grad_transformer,
                   amsgrad=amsgrad,use_gc=use_gc, gc_conv_only=gc_conv_only )
    super().__init__(params,defaults)

    #adjustable threshold
    self.n_sma_threshhold = n_sma_threshhold   

    #look ahead params
    self.alpha = alpha
    self.k = k 

    #radam buffer for state
    self.radam_buffer = [[None,None,None] for ind in range(10)]
    
    #gc on or off
    self.use_gc=use_gc
    #level of gradient centralization
    self.gc_gradient_threshold = 3 if gc_conv_only else 1
    print(f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
    if (self.use_gc and self.gc_gradient_threshold==1):
        print(f"GC applied to both conv and fc layers")
    elif (self.use_gc and self.gc_gradient_threshold==3):
        print(f"GC applied to conv layers only")


def __setstate__(self, state):
    print("set state called")
    super(RangerVA, self).__setstate__(state)


def step(self, closure=None):
    loss = None
    #Evaluate averages and grad, update param tensors
    for group in self.param_groups:
        for p in group['params']:
            if p.grad is None:
                continue
            grad = p.grad.data.double()
            if grad.is_sparse:
                raise RuntimeError('Ranger optimizer does not support sparse gradients')
            
            amsgrad = group['amsgrad']
            smooth = group['smooth']
            grad_transformer = group['grad_transformer']

            p_data_fp32 = p.data.double()

            state = self.state[p]  #get state dict for this param

            if len(state) == 0:   
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p_data_fp32)
                state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state['max_exp_avg_sq'] = torch.zeros_like(p.data)                    

                #look ahead weight storage now in state dict 
                state['slow_buffer'] = torch.empty_like(p.data)
                state['slow_buffer'].copy_(p.data)

            else:
                state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
                                  

            #begin computations 
            exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
            beta1, beta2 = group['betas']
            if amsgrad:
                max_exp_avg_sq = state['max_exp_avg_sq']  
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                denomc = max_exp_avg_sq.clone()
            else:
                denomc = exp_avg_sq.clone()
            #GC operation for Conv layers and FC layers       
            if grad.dim() > self.gc_gradient_threshold:                    
                grad.add_(-grad.mean(dim = tuple(range(1,grad.dim())), keepdim = True))

            state['step'] += 1              

            #compute variance mov avg
            exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
            #compute mean moving avg
            exp_avg.mul_(beta1).add_(1 - beta1, grad)
            buffered = self.radam_buffer[int(state['step'] % 10)]
            if state['step'] == buffered[0]:
                N_sma, step_size = buffered[1], buffered[2]
            else:
                buffered[0] = state['step']
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                buffered[1] = N_sma
                if N_sma > self.n_sma_threshhold:
                    step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                else:
                    step_size = 1.0 / (1 - beta1 ** state['step'])
                buffered[2] = step_size

            
            ##transformer
            if grad_transformer == 'square':
                grad_tmp = grad**2
                denomc.sqrt_() 
            elif grad_transformer == 'abs':
                grad_tmp = grad.abs()


            exp_avg_sq.mul_(beta2).add_((1 - beta2)*grad_tmp)

            if group['weight_decay'] != 0:
                p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
            bias_correction1 = 1 - beta1 ** state['step']
            bias_correction2 = 1 - beta2 ** state['step']
            step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1                

            
            # ...let's use calibrated alr 
            if N_sma > self.n_sma_threshhold:
                if  group['transformer'] =='softplus':
                    sp = torch.nn.Softplus( smooth)
                    denomf = sp( denomc)
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denomf )
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
            else:
                p_data_fp32.add_(-step_size * group['lr'], exp_avg)
            p.data.copy_(p_data_fp32)

            #integrated look ahead...
            #we do it at the param level instead of group level
            if state['step'] % group['k'] == 0:
                slow_p = state['slow_buffer'] #get access to slow param tensor
                slow_p.add_(self.alpha, p.data - slow_p)  #(fast weights - slow weights) * alpha
                p.data.copy_(slow_p)  #copy interpolated weights to RAdam param tensor

    return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant