In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
class Experiment:
    """
    Run a streaming optimization experiment with a second-order method.

    The experiment feeds the samples of ``dataset`` one at a time to
    ``optimizer.step`` and, when reference values are supplied, records how far
    the current estimates are from them after every update.

    Attributes
    ----------
    theta : initial / current parameter vector (``None`` until ``set_theta``).
    hessian : current Hessian estimate (``None`` until ``run`` is called).
    theta_error : squared Euclidean distance to ``true_theta`` per step.
    hessian_error : Frobenius distance to ``true_hessian`` per step.
    """

    def __init__(
        self,
        g,
        g_grad,
        g_grad_and_hessian,
        dataset,
        optimizer,
        true_theta=None,
        true_hessian=None,
    ):
        """
        Initialize the experiment.

        Parameters
        ----------
        g, g_grad, g_grad_and_hessian : callables forwarded untouched to
            ``optimizer.step``.
        dataset : iterable of ``(x, y)`` pairs consumed by ``run``.
        optimizer : object exposing ``reset_lr()`` and
            ``step(theta, hessian, x, y, g, g_grad, g_grad_and_hessian)``
            returning the updated ``(theta, hessian)``.
        true_theta : optional reference parameter used for error tracking.
        true_hessian : optional reference Hessian used for error tracking.
        """
        self.true_theta = true_theta
        self.true_hessian = true_hessian
        self.g = g
        self.g_grad = g_grad
        self.g_grad_and_hessian = g_grad_and_hessian
        self.optimizer = optimizer
        self.dataset = dataset
        # Neither estimate exists yet: theta comes from set_theta(), the Hessian
        # is created in run(). Errors can therefore only be logged from run().
        self.theta = None
        self.hessian = None
        self.theta_error = []
        self.hessian_error = []

    def _log_errors(self):
        """Append the current parameter / Hessian errors to their histories."""
        if self.true_theta is not None:
            diff = self.theta - self.true_theta
            self.theta_error.append(np.dot(diff, diff))
        if self.true_hessian is not None:
            self.hessian_error.append(
                np.linalg.norm(self.hessian - self.true_hessian, ord="fro")
            )

    def run(self):
        """
        Run the experiment.

        Starts from the theta given to ``set_theta`` and an identity Hessian,
        performs one optimizer step per sample, plots the recorded errors and
        finally clears ``theta`` so that a new starting point must be set
        before the next run.

        Raises
        ------
        ValueError
            If no initial theta has been set.
        """
        if self.theta is None:
            raise ValueError("Initial theta not set")
        # np.eye needs an integer size, not the shape tuple.
        self.hessian = np.eye(np.shape(self.theta)[0])
        # Each run gets its own error curves, starting with the initial point.
        self.theta_error = []
        self.hessian_error = []
        self._log_errors()
        self.optimizer.reset_lr()
        for x, y in tqdm(self.dataset):
            self.theta, self.hessian = self.optimizer.step(
                self.theta,
                self.hessian,
                x,
                y,
                self.g,
                self.g_grad,
                self.g_grad_and_hessian,
            )
            # Log the errors after each update
            self._log_errors()
        self.plot_errors()
        self.theta = None  # Reset theta

    def plot_errors(self):
        """
        Plot the error histories for which a reference value was provided.
        """
        if self.true_theta is not None:
            plt.plot(self.theta_error)
            plt.title("Parameter error")
            plt.show()
        if self.true_hessian is not None:
            plt.plot(self.hessian_error)
            plt.title("Hessian error (Frobenius norm)")
            plt.show()

    def set_theta(self, theta):
        """
        Set the initial theta used by the next call to ``run``.
        """
        self.theta = theta

In [None]:
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), applied element-wise by NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def create_dataset_logistic(n, theta):
    """
    Sample a synthetic logistic-regression dataset.

    Draws ``n`` standard-normal feature rows with ``len(theta) - 1`` columns,
    prepends a column of ones (so ``theta[0]`` acts as the intercept) and
    samples one Bernoulli label per row with success probability
    ``sigmoid(phi @ theta)``.

    Returns the features WITHOUT the ones column, together with the labels.
    """
    n_features = len(theta) - 1
    X = np.random.randn(n, n_features)
    intercept = np.ones((n, 1))
    success_prob = sigmoid(np.hstack([intercept, X]) @ theta)
    Y = np.random.binomial(1, success_prob)
    return X, Y


# Build a demo dataset. `n` and `true_theta` are assigned here (same values as
# in the usage-example cell) so this cell also runs on a fresh kernel, where no
# earlier cell has defined them yet.
n = 10_000
true_theta = np.array([0, 3, -9, 4, -9, 15, 0, -7, 1, 0])
X, Y = create_dataset_logistic(n, true_theta)
# Materialize the (x, y) pairs: a bare zip() is exhausted after a single pass,
# which would silently yield an empty dataset on any second iteration.
dataset = list(zip(X, Y))


def g(h, X, Y):
    """
    Per-sample logistic loss ``log(1 + exp(phi @ h)) - (phi @ h) * Y``.

    Parameters
    ----------
    h : parameter vector of length ``d + 1``; ``h[0]`` multiplies the intercept.
    X : features without the intercept column, either one sample of shape
        ``(d,)`` or a batch of shape ``(n, d)``.
    Y : label(s) matching ``X`` (a scalar for a single sample).

    Returns
    -------
    A scalar for a single sample, otherwise an array of shape ``(n,)``.
    """
    X = np.asarray(X, dtype=float)
    single_sample = X.ndim == 1
    X = np.atleast_2d(X)
    n = X.shape[0]
    # The shape must be passed as a tuple; a second positional argument of
    # np.ones is the dtype.
    phi = np.hstack([np.ones((n, 1)), X])
    dot_product = phi @ h
    labels = np.asarray(Y, dtype=float).reshape(n)
    # logaddexp(0, z) == log(1 + exp(z)) without overflowing for large z.
    loss = np.logaddexp(0.0, dot_product) - dot_product * labels
    return loss[0] if single_sample else loss


def g_grad(h, X, Y):
    """
    Per-sample gradient of the logistic loss with respect to ``h``.

    For each sample the gradient is ``(sigmoid(phi @ h) - y) * phi`` where
    ``phi = [1, x]``, so it has the same length as ``h`` (intercept included).

    Parameters
    ----------
    h : parameter vector of length ``d + 1``; ``h[0]`` multiplies the intercept.
    X : features without the intercept column, either one sample of shape
        ``(d,)`` or a batch of shape ``(n, d)``.
    Y : label(s) matching ``X`` (a scalar for a single sample).

    Returns
    -------
    An array of shape ``(d + 1,)`` for a single sample, otherwise
    ``(n, d + 1)`` with one gradient per row.
    """
    X = np.asarray(X, dtype=float)
    single_sample = X.ndim == 1
    X = np.atleast_2d(X)
    n = X.shape[0]
    # The shape must be passed as a tuple; a second positional argument of
    # np.ones is the dtype.
    phi = np.hstack([np.ones((n, 1)), X])
    p = sigmoid(phi @ h)
    residual = p - np.asarray(Y, dtype=float).reshape(n)
    # The residual needs an explicit trailing axis to scale each row of phi;
    # a plain (n,) * (n, d) product does not broadcast per sample.
    grad = residual[:, np.newaxis] * phi
    return grad[0] if single_sample else grad


def g_grad_and_hessian(h, X, Y):
    """
    Per-sample gradient and Hessian of the logistic loss with respect to ``h``.

    With ``phi = [1, x]`` and ``p = sigmoid(phi @ h)`` the gradient of one
    sample is ``(p - y) * phi`` and its Hessian is ``p * (1 - p) * phi phi^T``.
    The einsum builds one outer product per sample, so batches are supported.

    Parameters
    ----------
    h : parameter vector of length ``d + 1``; ``h[0]`` multiplies the intercept.
    X : features without the intercept column, either one sample of shape
        ``(d,)`` or a batch of shape ``(n, d)``.
    Y : label(s) matching ``X`` (a scalar for a single sample).

    Returns
    -------
    ``(grad, hessian)`` with shapes ``(d + 1,)`` and ``(d + 1, d + 1)`` for a
    single sample, otherwise ``(n, d + 1)`` and ``(n, d + 1, d + 1)``.
    """
    X = np.asarray(X, dtype=float)
    single_sample = X.ndim == 1
    X = np.atleast_2d(X)
    n = X.shape[0]
    # The shape must be passed as a tuple; a second positional argument of
    # np.ones is the dtype.
    phi = np.hstack([np.ones((n, 1)), X])
    p = sigmoid(phi @ h)
    residual = p - np.asarray(Y, dtype=float).reshape(n)
    grad = residual[:, np.newaxis] * phi
    hessian = np.einsum("i,ij,ik->ijk", p * (1 - p), phi, phi)
    if single_sample:
        return grad[0], hessian[0]
    return grad, hessian

In [None]:
# Scratch checks of the NumPy behaviours used by the loss functions.

# Broadcasting with `*`: a column vector against a batch of rows
a = np.array([1, 2]).reshape(2, 1)  # like (p - Y) for a batch of 2 samples
x = np.arange(1, 9).reshape(2, 4)
# print(a * x)

# Outer product
# print(np.outer(x, x))

# Promoting a 1-D vector to a single-row matrix
x = np.atleast_2d(np.array([1, 2, 3]))
# print(x.shape)

# Matrix-vector product with np.dot
x = np.arange(1, 7).reshape(2, 3)
y = np.arange(1, 4).reshape(3, 1)
# print(np.dot(x, y))

# Batched outer products with np.einsum, followed by the same quantity
# computed sample by sample for visual comparison
p = np.array([0.5, 0.6])
x = np.array([[1, 2], [3, 4]])
print(np.einsum("i,ij,ik->ijk", p * (1 - p), x, x))
print(*(p[k] * (1 - p[k]) * np.outer(x[k], x[k]) for k in range(2)), sep="\n")

In [None]:
# Usage example: repeat the streaming experiment on N freshly sampled datasets,
# each starting from a random initial theta.
N = 100
n = 10_000
true_theta = np.array([0, 3, -9, 4, -9, 15, 0, -7, 1, 0])
# NOTE(review): `optimizer` is not created in any cell shown in this notebook;
# it must expose reset_lr() and step(...) as used by Experiment.run — TODO
# confirm where it is defined.
for i in range(N):
    X, Y = create_dataset_logistic(n, true_theta)
    # A list (unlike a bare zip) can be iterated more than once and has a length.
    dataset = list(zip(X, Y))
    exp = Experiment(
        g, g_grad, g_grad_and_hessian, dataset, optimizer, true_theta=true_theta
    )

    # The starting point must have the same dimension as the reference theta.
    exp.set_theta(np.random.randn(true_theta.shape[0]))
    exp.run()

In [None]:
# NOTE(review): `experiment` is first assigned in a later cell (the torch usage
# example); on a fresh top-to-bottom run nothing named `experiment` exists yet
# at this point — confirm the intended cell order.
experiment.run()

In [None]:
import torch
from torch import nn
from torch.autograd import functional as autograd_f


class NewtonOptim(torch.optim.Optimizer):
    """
    Damped Newton optimizer with one exact Hessian block per parameter tensor.

    For every parameter ``p`` the update is
    ``p <- p - lr * (H_p + damping * I)^-1 grad_p`` where ``grad_p`` and
    ``H_p`` are the gradient and Hessian of the loss with respect to the
    flattened ``p``. Cross-parameter second derivatives are ignored.

    Parameters
    ----------
    params : iterable of parameters or parameter groups.
    lr : step size applied to the Newton direction.
    damping : multiple of the identity added to the Hessian before solving.
    """

    def __init__(self, params, lr=1, damping=1e-5):
        defaults = dict(lr=lr, damping=damping)
        super().__init__(params, defaults)

    @staticmethod
    def _hessian_block(grad_flat, p):
        """Differentiate each gradient entry w.r.t. ``p`` to build the Hessian."""
        rows = []
        for grad_entry in grad_flat:
            row = None
            # A gradient entry that is constant in p has no graph to
            # differentiate; its Hessian row is zero.
            if grad_entry.requires_grad:
                (row,) = torch.autograd.grad(
                    grad_entry, p, retain_graph=True, allow_unused=True
                )
            if row is None:
                row = torch.zeros(p.numel(), dtype=p.dtype, device=p.device)
            rows.append(row.reshape(-1))
        return torch.stack(rows)

    def step(self, closure=None):
        """
        Perform one Newton update of every parameter.

        Parameters
        ----------
        closure : callable that re-evaluates the model and returns the scalar
            loss with its autograd graph intact (it must not free the graph,
            e.g. by calling ``backward()`` without ``retain_graph=True``).

        Returns
        -------
        The loss returned by ``closure``, evaluated before the update.

        Raises
        ------
        TypeError
            If no closure is given; the Hessian cannot be computed without it.
        """
        if closure is None:
            raise TypeError(
                "NewtonOptim.step() requires a closure that returns the loss"
            )

        # All directions are computed from the same loss graph first and
        # applied afterwards, so no parameter is modified while the graph is
        # still being differentiated.
        updates = []
        with torch.enable_grad():
            loss = closure()
            for group in self.param_groups:
                for p in group["params"]:
                    if not p.requires_grad:
                        continue
                    (grad,) = torch.autograd.grad(
                        loss, p, create_graph=True, allow_unused=True
                    )
                    if grad is None:  # parameter does not influence the loss
                        continue
                    grad_flat = grad.reshape(-1)
                    hessian = self._hessian_block(grad_flat, p)
                    damped = hessian + group["damping"] * torch.eye(
                        p.numel(), dtype=hessian.dtype, device=hessian.device
                    )
                    direction = torch.linalg.solve(damped, grad_flat.detach())
                    # Reshape the flat direction back to the parameter's shape
                    # before it is subtracted.
                    updates.append((p, group["lr"] * direction.view_as(p)))

        with torch.no_grad():
            for p, delta in updates:
                p.sub_(delta)

        return loss

In [None]:
# Usage example: noisy linear regression fitted with the Newton optimizer.
theta_true = torch.tensor([1.5, -2.0, 1.0, 0.5, 3.0])
X = torch.randn(10000, 5)
noise = 0.5 * torch.randn(10000)
y = X @ theta_true + noise
# TensorDataset is reached through torch.utils.data because no cell shown in
# this notebook imports the bare name.
dataset = torch.utils.data.TensorDataset(X, y)
g = nn.Linear(5, 1, bias=False)
criterion = nn.MSELoss()
# NOTE(review): these arguments (lr=, theta_true=, criterion=) do not match the
# `Experiment` class defined earlier in this notebook — presumably a torch-based
# `Experiment` from another cell or module is intended; TODO confirm.
experiment = Experiment(
    g, dataset, NewtonOptim, lr=0.001, theta_true=theta_true, criterion=criterion
)

In [None]:
def closure():
    """Evaluate `criterion` on the predictions `g(X)` against `y` as a column."""
    return criterion(g(X), y.view(-1, 1))

In [None]:
# One Newton step per mini-batch. The loop variables are deliberately NOT named
# `x` / `y`: rebinding the global `y` here would change what the module-level
# `closure` reads and pair the full `X` with a single batch of targets.
for x_batch, y_batch in experiment.dataloader:

    def batch_closure(x_batch=x_batch, y_batch=y_batch):
        """Loss of the model `g` on the current mini-batch."""
        return criterion(g(x_batch), y_batch.view(-1, 1))

    experiment.optimizer.step(batch_closure)

In [None]:
# Presumably plots the parameter-error history recorded by `experiment`.
# NOTE(review): the `Experiment` class defined in this notebook exposes
# `plot_errors`, not `plot_param_errors` — confirm which class `experiment`
# is built from.
experiment.plot_param_errors()