In [1]:
import numpy as np

class NesterovOptimizer:
    def __init__(self, learning_rate, momentum):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.v = None
        self.x = None

    def init_params(self, params, verbose):
        self.v = np.zeros_like(params)
        self.x = np.zeros_like(params)
        self.verbose = verbose

    def update(self, params, grad):
        self.v = self.momentum * self.v - self.learning_rate * grad
        x_new = self.x + self.v
        self.x = x_new
        if self.verbose:
            print(f'params: {params}, x: {self.x}, v: {self.v}')
        return self.x

In [5]:

# Define a simple quadratic function
def f(x):
  return np.sum(x**2 + 2*x)

# Initialize the Nesterov optimizer
optimizer = NesterovOptimizer(learning_rate=0.1, momentum=0.9)

# Initialize the parameters and gradient
params = np.random.rand(3)  # start with random values
grad = np.ones_like(params)  # start with constant gradient

# Initialize the optimizer
optimizer.init_params(params, verbose=True)

# Perform optimization
for i in range(100):
  params = optimizer.update(params, grad)
  grad = -2 * params  # update the gradient

print(f'Final params: {params}')

params: [0.31179588 0.69634349 0.37775184], x: [-0.1 -0.1 -0.1], v: [-0.1 -0.1 -0.1]
params: [-0.1 -0.1 -0.1], x: [-0.21 -0.21 -0.21], v: [-0.11 -0.11 -0.11]
params: [-0.21 -0.21 -0.21], x: [-0.351 -0.351 -0.351], v: [-0.141 -0.141 -0.141]
params: [-0.351 -0.351 -0.351], x: [-0.5481 -0.5481 -0.5481], v: [-0.1971 -0.1971 -0.1971]
params: [-0.5481 -0.5481 -0.5481], x: [-0.83511 -0.83511 -0.83511], v: [-0.28701 -0.28701 -0.28701]
params: [-0.83511 -0.83511 -0.83511], x: [-1.260441 -1.260441 -1.260441], v: [-0.425331 -0.425331 -0.425331]
params: [-1.260441 -1.260441 -1.260441], x: [-1.8953271 -1.8953271 -1.8953271], v: [-0.6348861 -0.6348861 -0.6348861]
params: [-1.8953271 -1.8953271 -1.8953271], x: [-2.84579001 -2.84579001 -2.84579001], v: [-0.95046291 -0.95046291 -0.95046291]
params: [-2.84579001 -2.84579001 -2.84579001], x: [-4.27036463 -4.27036463 -4.27036463], v: [-1.42457462 -1.42457462 -1.42457462]
params: [-4.27036463 -4.27036463 -4.27036463], x: [-6.40655472 -6.40655472 -6.4065547