In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
class Experiment:
    """
    Experiment class for training a model on a dataset using a given optimizer.

    ``run`` makes one ordered pass over the dataset (``shuffle=False``), taking
    one ``optimizer.step()`` per batch. When ``true_params`` is supplied, it
    records the squared parameter-estimation error after every step.
    ``run`` calls ``step()`` without a closure, so optimizers whose ``step``
    requires a closure are not driven by ``run``.
    """

    def __init__(
        self,
        model,
        dataset,
        optimizer_class,
        lr,
        criterion,
        true_params=None,
        batch_size=1,
    ):
        """
        Parameters:
        model: PyTorch model
        dataset: PyTorch dataset yielding (input, target) pairs
        optimizer_class: PyTorch optimizer class, instantiated as
            ``optimizer_class(model.parameters(), lr=lr)``
        lr: Learning rate
        criterion: Loss function, called as ``criterion(prediction, target)``
        true_params: True parameters of the model (optional), used to compute the
            parameter estimation error. Either a tensor or an iterable of
            tensors/numbers. Their flattened, concatenated values must line up
            with the flattened ``model.parameters()``.
        batch_size: Batch size for training, default is 1 for online learning
        """
        self.model = model
        self.dataset = dataset
        self.batch_size = batch_size
        self.dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        self.optimizer = optimizer_class(model.parameters(), lr=lr)
        self.true_params = true_params
        self.param_errors = []  # squared parameter error after each optimizer step
        self.criterion = criterion

    def param_error(self):
        """
        Compute the squared error between the estimated parameters and the true parameters.

        Returns:
        float: sum of squared differences over all flattened model parameters.

        Raises:
        ValueError: if ``true_params`` was not provided, or if its number of
            entries differs from the model's parameter count. Without this check,
            a single-entry ``true_params`` would silently broadcast.
        """
        if self.true_params is None:
            raise ValueError("True parameters not provided")
        # Pure measurement: keep it out of the autograd graph.
        with torch.no_grad():
            estimated_params = torch.cat([p.reshape(-1) for p in self.model.parameters()])
            if isinstance(self.true_params, torch.Tensor):
                true_params = self.true_params.reshape(-1)
            else:
                true_params = torch.cat(
                    [torch.as_tensor(p).reshape(-1) for p in self.true_params]
                )
            true_params = true_params.to(
                device=estimated_params.device, dtype=estimated_params.dtype
            )
            if true_params.numel() != estimated_params.numel():
                raise ValueError(
                    f"true_params has {true_params.numel()} entries but the model has "
                    f"{estimated_params.numel()} parameters"
                )
            error = estimated_params - true_params
            return torch.dot(error, error).item()

    def compute_hessian(self, x, y):
        """
        Compute the Hessian of the loss function with respect to the model parameters.

        The loss ``criterion(model(x), y.reshape(-1, 1))`` is evaluated with the
        model in eval mode. The model's previous train/eval mode is restored
        afterwards.

        Parameters:
        x: model input. A 0-D/1-D tensor (one sample) is reshaped to a batch of
            one; higher-rank inputs are passed to the model unchanged.
        y: target(s), reshaped to a column ``(-1, 1)`` as in ``run``.

        Returns:
        A ``(P, P)`` tensor, where ``P`` is the number of trainable parameter
        entries. Rows and columns are ordered like
        ``torch.cat([p.reshape(-1) for p in model.parameters() if p.requires_grad])``.
        """
        params = [p for p in self.model.parameters() if p.requires_grad]
        if sum(p.numel() for p in params) == 0:
            return torch.zeros(0, 0)

        was_training = self.model.training
        self.model.eval()
        try:
            if x.dim() <= 1:
                x = x.reshape(1, -1)  # add the batch dimension the model expects
            loss = self.criterion(self.model(x), y.reshape(-1, 1))

            # First derivatives keep their graph so they can be differentiated again.
            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
            flat_grad = torch.cat(
                [
                    (torch.zeros_like(p) if g is None else g).reshape(-1)
                    for g, p in zip(grads, params)
                ]
            )

            rows = []
            for g_i in flat_grad:
                if g_i.requires_grad:
                    second = torch.autograd.grad(
                        g_i, params, retain_graph=True, allow_unused=True
                    )
                else:
                    # This gradient entry is constant in the parameters.
                    second = [None] * len(params)
                rows.append(
                    torch.cat(
                        [
                            (torch.zeros_like(p) if s is None else s).reshape(-1)
                            for s, p in zip(second, params)
                        ]
                    )
                )
            return torch.stack(rows).detach()
        finally:
            self.model.train(was_training)

    def run(self):
        """
        Train for one pass over the dataloader, then plot the recorded errors.

        The squared parameter error is logged after every step only when
        ``true_params`` was provided, so the experiment also runs without
        ground truth. In that case nothing is recorded and no plot is drawn.
        Targets are reshaped to ``(-1, 1)``, which presumably assumes one scalar
        target per sample (TODO confirm for other datasets).
        """
        self.model.train()
        track_error = self.true_params is not None
        for x, y in tqdm(self.dataloader):
            y_pred = self.model(x)
            loss = self.criterion(y_pred, y.view(-1, 1))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Log parameter error after each update
            if track_error:
                self.param_errors.append(self.param_error())
        if self.param_errors:
            self.plot_param_errors()

    def plot_param_errors(self):
        """Plot the recorded parameter errors, one point per optimizer step."""
        fig, ax = plt.subplots()
        ax.plot(self.param_errors, label="Parameter Estimation Error")
        ax.set_xlabel(f"Optimizer step (batch size {self.batch_size})")
        ax.set_ylabel("Error")
        ax.set_title("Parameter estimation error during training")
        ax.legend()
        plt.show()

In [None]:
# Usage example: online SGD on a noisy linear-regression problem
torch.manual_seed(0)  # make the synthetic data and the model initialisation reproducible
N_SAMPLES = 10_000
NOISE_STD = 0.5
true_params = torch.tensor([1.5, -2.0, 1.0, 0.5, 3.0])
n_features = true_params.numel()
X = torch.randn(N_SAMPLES, n_features)
noise = NOISE_STD * torch.randn(N_SAMPLES)
y = X @ true_params + noise
dataset = TensorDataset(X, y)
model = nn.Linear(n_features, 1, bias=False)  # no bias: the generating model has no intercept
criterion = nn.MSELoss()
experiment = Experiment(
    model, dataset, optim.SGD, lr=0.001, true_params=true_params, criterion=criterion
)
)

In [None]:
# Train with SGD for one ordered pass over the data; run() then plots the
# recorded parameter-error curve.
Experiment.run(experiment)

In [None]:
import torch
from torch import nn
from torch.autograd import functional as autograd_f


class NewtonOptim(torch.optim.Optimizer):
    """
    Damped Newton optimizer using a block-diagonal Hessian.

    Each parameter tensor ``p`` is updated independently of the others:

        p <- p - lr * (H_p + damping * I)^{-1} g_p

    Here ``g_p`` and ``H_p`` are the gradient and Hessian of the loss with
    respect to the flattened ``p``. Cross-parameter Hessian terms are ignored.

    ``step`` requires a closure that re-evaluates the model and *returns* the
    loss without calling ``backward()``. The optimizer differentiates that loss
    itself (twice), so its autograd graph must still be intact. ``p.grad`` is
    neither read nor written.
    """

    def __init__(self, params, lr=1, damping=1e-5):
        """
        Parameters:
        params: iterable of parameters or parameter-group dicts
        lr: step size multiplying the Newton direction (1 = full Newton step)
        damping: value added to the Hessian diagonal before solving. It keeps
            the system solvable when the Hessian is singular (e.g. a loss on a
            single sample). With ``damping=0`` a singular Hessian makes
            ``step`` raise.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if damping < 0.0:
            raise ValueError(f"Invalid damping: {damping}")
        defaults = dict(lr=lr, damping=damping)
        super().__init__(params, defaults)

    @staticmethod
    def _hessian_block(flat_grad, param):
        """Hessian w.r.t. flattened ``param``, from a gradient built with ``create_graph=True``."""
        rows = []
        for g_i in flat_grad:
            row = None
            if g_i.requires_grad:
                (row,) = torch.autograd.grad(g_i, param, retain_graph=True, allow_unused=True)
            if row is None:
                # This gradient entry does not depend on `param`.
                row = torch.zeros(param.numel(), dtype=flat_grad.dtype, device=flat_grad.device)
            rows.append(row.reshape(-1))
        return torch.stack(rows)

    def step(self, closure=None):
        """
        Perform one damped Newton step.

        Parameters:
        closure: callable with no arguments returning the loss tensor (no backward()).

        Returns:
        The loss returned by ``closure``, evaluated before the update.

        Raises:
        RuntimeError: if no closure is given.
        """
        if closure is None:
            raise RuntimeError(
                "NewtonOptim requires a closure that re-evaluates the model and returns the loss"
            )
        with torch.enable_grad():
            loss = closure()
            entries = [
                (p, group)
                for group in self.param_groups
                for p in group["params"]
                if p.requires_grad
            ]
            if not entries or not loss.requires_grad:
                return loss
            params = [p for p, _ in entries]
            grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)

            # Compute every direction before touching any parameter: updating in
            # place first would corrupt the graph still needed for later blocks.
            updates = []
            for (p, group), g in zip(entries, grads):
                if g is None or g.numel() == 0:
                    continue  # the loss does not depend on this parameter
                flat_grad = g.reshape(-1)
                hessian = self._hessian_block(flat_grad, p)
                eye = torch.eye(flat_grad.numel(), dtype=hessian.dtype, device=hessian.device)
                direction = torch.linalg.solve(
                    hessian + group["damping"] * eye, flat_grad.detach()
                )
                updates.append((p, group["lr"] * direction.reshape(p.shape)))

        with torch.no_grad():
            for p, delta in updates:
                p.sub_(delta)
        return loss

In [None]:
# Usage example: the same linear-regression setup, optimised with NewtonOptim
torch.manual_seed(0)  # make the synthetic data and the model initialisation reproducible
N_SAMPLES = 10_000
NOISE_STD = 0.5
true_params = torch.tensor([1.5, -2.0, 1.0, 0.5, 3.0])
n_features = true_params.numel()
X = torch.randn(N_SAMPLES, n_features)
noise = NOISE_STD * torch.randn(N_SAMPLES)
y = X @ true_params + noise
dataset = TensorDataset(X, y)
model = nn.Linear(n_features, 1, bias=False)  # no bias: the generating model has no intercept
criterion = nn.MSELoss()
experiment = Experiment(
    model, dataset, NewtonOptim, lr=0.001, true_params=true_params, criterion=criterion
)

In [None]:
def closure(inputs=None, targets=None):
    """
    Evaluate the loss for `NewtonOptim.step` and return it without calling backward().

    By default the whole dataset is used. The data is read from `dataset.tensors`
    rather than the globals `X`/`y`. Rebinding those names in a later cell (e.g. a
    `for x, y in ...` loop) therefore cannot silently swap in a single-sample
    target, which `MSELoss` would broadcast against every prediction.
    Pass `inputs` and/or `targets` to evaluate a specific batch instead.
    """
    full_inputs, full_targets = dataset.tensors
    inputs = full_inputs if inputs is None else inputs
    targets = full_targets if targets is None else targets
    y_pred = model(inputs)
    return criterion(y_pred, targets.view(-1, 1))

In [None]:
# Online Newton training: one optimizer step per mini-batch, using that batch's loss,
# and the parameter error logged after each step (as Experiment.run does).
# Distinct loop names keep the global full-data `y` from being overwritten.
experiment.model.train()
for x_batch, y_batch in tqdm(experiment.dataloader):

    def batch_closure(inputs=x_batch, targets=y_batch):
        # Only evaluates the loss on this batch; no backward() call here.
        return experiment.criterion(experiment.model(inputs), targets.view(-1, 1))

    experiment.optimizer.step(batch_closure)
    experiment.param_errors.append(experiment.param_error())

In [None]:
# Plot only when the training loop actually recorded errors; an empty figure
# would look like a result when nothing was measured.
if experiment.param_errors:
    experiment.plot_param_errors()
else:
    print("No parameter errors recorded - run the training loop above first.")