# Divergence with Rosenbrock

In [None]:
import os

import matplotlib.pyplot as plt
import torch
from mpl_toolkits.mplot3d import Axes3D

The global minimum for the Rosenbrock function, given by the formula:

$f(x, y) = (1 - x)^2 + 100(y - x^2)^2$

occurs at:

$x = 1, \, y = 1$

At this point, the function value is:

$f(1, 1) = (1 - 1)^2 + 100(1 - 1^2)^2 = 0$

So, the global minimum is:

$f(1, 1) = 0$

In [None]:
def rosenbrock(x):
    """Evaluate the Rosenbrock function f(x, y) = (1 - x)^2 + 100 (y - x^2)^2.

    Args:
        x: Indexable point; ``x[0]`` is used as the x coordinate and ``x[1]``
            as the y coordinate.

    Returns:
        The function value at that point, produced by ordinary arithmetic on
        the two entries (so its type follows the type of the entries).
    """
    first, second = x[0], x[1]
    offset_term = (1 - first) ** 2
    valley_term = 100 * (second - first ** 2) ** 2
    return offset_term + valley_term

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def gadam(x_init, decay=0.0, lr=0.001, iterations=5000, sample_every=100, objective=None):
    """Minimise an objective with Adam under an inverse-time learning-rate decay.

    The base learning rate is multiplied by ``1 / (1 + decay * t)``, where
    ``t`` is the step index that ``LambdaLR`` passes to the schedule;
    ``decay=0.0`` therefore keeps the learning rate constant.

    Args:
        x_init: Starting point as a tensor; entries ``[0]`` and ``[1]`` are
            recorded as the x and y coordinates. The tensor is copied, so the
            caller's tensor is not modified and can seed several runs.
        decay: Coefficient of the inverse-time learning-rate decay.
        lr: Base learning rate handed to Adam.
        iterations: Number of optimisation steps to perform.
        sample_every: A sample is recorded on every step ``t`` (counted from
            1) for which ``t % sample_every == 0``.
        objective: Callable mapping the parameter tensor to a scalar loss.
            Defaults to ``rosenbrock`` when ``None``.

    Returns:
        A tuple ``(x_vals, y_vals, losses)`` of lists of Python floats with
        one entry per recorded sample.
    """
    if objective is None:
        objective = rosenbrock

    # Adam updates its parameters in place. Optimising ``x_init`` directly
    # would move the caller's tensor, so a second run seeded with the same
    # tensor would start where the first one stopped. Work on a private leaf.
    x = x_init.detach().clone().requires_grad_(True)

    optimizer = torch.optim.Adam([x], lr=lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda t: 1 / (1 + decay * t))

    x_vals, y_vals, losses = [], [], []

    for t in range(1, iterations + 1):
        optimizer.zero_grad()
        loss = objective(x)
        loss.backward()
        optimizer.step()
        scheduler.step()

        if t % sample_every == 0:
            # The coordinates are read after the update, while ``loss`` was
            # evaluated just before it.
            x_vals.append(x[0].item())
            y_vals.append(x[1].item())
            losses.append(loss.item())
    return x_vals, y_vals, losses

In [None]:
# Experiment settings: total optimisation steps and the sampling stride.
iterations = int(1e5)
sample_every = int(1e2)

# One random starting point shared by every run of the comparison.
x_init = torch.randn(2, device=device, requires_grad=True)

# Learning-rate decay coefficient for each labelled run.
decay_by_label = {"decay=0": 0.0, "decay=5e-5": 5e-5, "decay=1e-4": 1e-4}

# Give each run its own copy of the starting point so that no run can be
# affected by in-place parameter updates made during an earlier one.
results_torch = {
    label: gadam(
        x_init.detach().clone().requires_grad_(True),
        decay=decay,
        lr=0.0005,
        iterations=iterations,
        sample_every=sample_every,
    )
    for label, decay in decay_by_label.items()
}

In [None]:
fig = plt.figure(figsize=(16, 8), dpi=150)

# 3D Plot for x and y trajectories
ax1 = fig.add_subplot(1, 2, 1, projection="3d")
for label, (x_vals, y_vals, _) in results_torch.items():
    # The vertical axis is the index of each recorded sample.
    ax1.plot(x_vals, y_vals, range(len(x_vals)), label=label, lw=1.5)
# Mark the known global minimum of the Rosenbrock function at (1, 1).
ax1.scatter(1, 1, 0, color="red", s=100, label="Global Minima (x=1, y=1)")
ax1.set_title("Trajectory in x-y space")
ax1.set_xlabel("x value")
ax1.set_ylabel("y value")
ax1.set_zlabel("Iterations")
ax1.legend()

# 2D Plot for loss over iterations
ax2 = fig.add_subplot(1, 2, 2)
for label, (_, _, losses) in results_torch.items():
    # Only the first 25 recorded samples are drawn.
    ax2.plot(losses[:25], label=label)
ax2.set_title("Loss over iterations")
ax2.set_xlabel("Iteration (1e2)")
ax2.set_ylabel("Loss")
ax2.grid()
ax2.legend()

plt.tight_layout()

# savefig does not create missing directories, so make sure the output
# directory exists before writing the figure.
output_dir = "results"
os.makedirs(output_dir, exist_ok=True)
plt.savefig(os.path.join(output_dir, "rosenbrock_adam.png"))
plt.show()