In [1]:
import torch as t

import sys 
sys.path.append('../tests')

import test_optimizers

In [2]:
from typing import Iterable


class LAMBSimple:

    def __init__(
        self, params: Iterable[t.nn.parameter.Parameter], lr: float,
        betas: tuple[float, float], eps: float, weight_decay: float):
        '''Implements Layer-wise Adaptive Moments optimizer for Batch training.
        
        Accepts parameter iterables.

        Args:

        Returns:
        
        '''
        self.eta = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lam = weight_decay
        self.thetas = list(params)
        #self.phi = scale_func
        self.previous_m_t = [t.zeros_like(p) for p in self.thetas]
        self.previous_v_t = [t.zeros_like(p) for p in self.thetas]
        self.t = 1

    def zero_grad(self):
        for param in self.thetas:
            param.grad = t.zeros_like(param)

    #@t.inference_mode()
    #def lamb(self, param_group):
    #    pass
    
    @t.inference_mode()
    def step(self):
        # each "param" represents a matrix of the parameters in a layer
        for i, param in enumerate(self.thetas):
            # maximize=false, so:
            g_t = param.grad

            m_t = self.beta1 * self.previous_m_t[i] + (1 - self.beta1) * g_t
            v_t = self.beta2 * self.previous_v_t[i] + (1 - self.beta2) * g_t ** 2

            m_t = m_t / (1 - self.beta1 ** self.t)
            v_t = v_t / (1 - self.beta2 ** self.t)

            r_t = m_t / (v_t ** 0.5 + self.eps)

            #print(f"param shape: {param.shape} m_t shape: {m_t.shape} v_t shape: {v_t.shape} r_t shape: {r_t.shape}")

            param_norm = t.linalg.norm(param.detach(), dim=None, ord=2)
            #print(f"param_norm shape: {param_norm.shape}")

            if self.lam != 0:
                param = self.lam * param
                #print(f"param shape after wd: {param.shape}")

            update_norm = t.linalg.norm((r_t + param), dim=None, ord=2)

            #print(f"update_norm shape: {update_norm.shape}")
            #print(f"param_norm: {param_norm} update_norm: {update_norm}")

            r = param_norm / update_norm

            #print(f"r shape: {r.shape}")

            eta = r * self.eta
            
            #print(f"eta shape: {eta.shape}")
            param -= eta * (r_t + param)
            
            self.previous_m_t[i] = m_t
            self.previous_v_t[i] = v_t
        self.t += 1

    def __repr__(self) -> str:
        # Should return something reasonable here, e.g. "SGD(lr=lr, ...)"
        return f"lr={self.gamma}, momentum={self.mu}, weight_decay={self.lam}"

In [3]:
test_optimizers.test_lamb(LAMBSimple)


Testing configuration:  {'lr': 0.1, 'betas': (0.8, 0.95), 'eps': 0.001, 'weight_decay': 0.0}
actual: Parameter containing:
tensor([[ 0.0476,  0.6076],
        [-0.7025,  0.3038],
        [-0.0951,  0.2789],
        [ 0.6028, -0.1554],
        [ 0.2378,  0.8149],
        [ 0.3952,  0.7227],
        [-0.4194,  0.5829],
        [ 0.7097, -0.2370],
        [ 0.0316, -0.6797],
        [ 0.4186, -0.7241],
        [-0.3617,  0.6916],
        [ 0.0441,  0.1000],
        [ 0.5072, -0.0614],
        [ 0.7546, -0.2747],
        [-0.7075, -0.3306],
        [ 0.0702,  0.2893],
        [-0.3142, -0.5099],
        [ 0.6492,  0.6795],
        [ 0.6989,  0.7466],
        [ 0.0683, -0.1927],
        [-0.2699, -0.7493],
        [ 0.3154, -0.3336],
        [-0.6454,  0.4441],
        [-0.0249,  0.0024],
        [ 0.1194,  0.0896],
        [ 0.8464, -0.0952],
        [ 0.4026, -0.0296],
        [-0.5598,  0.0045],
        [ 0.4658, -0.8837],
        [-0.2521,  0.3666],
        [ 0.1568,  0.3738],
        

AssertionError: Tensor-likes are not close!

Mismatched elements: 64 / 64 (100.0%)
Greatest absolute difference: 0.29106128215789795 at index (17, 1) (up to 1e-05 allowed)
Greatest relative difference: 4.133229074600336 at index (0, 0) (up to 0 allowed)