In [7]:
import torch
from torch.optim import Optimizer
import time

In [9]:
class MarquardtLevenberg(Optimizer):
    def __init__(self, params, lr=1e-3, lambd=1e-3):
        defaults = dict(lr=lr, lambd=lambd)
        super(MarquardtLevenberg, self).__init__(params, defaults)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lambd = group['lambd']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('MarquardtLevenberg does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['prev_grad'] = torch.zeros_like(p.data)

                prev_grad = state['prev_grad']
                state['step'] += 1

                # Marquardt-Levenberg update
                A = grad.matmul(grad.t()) + lambd * torch.eye(grad.size(0), device=grad.device)
                g = grad

                # Measure time for matrix inversion
                start_inv = time.time()
                H_inv = torch.inverse(A)
                end_inv = time.time()

                # Measure time for matrix multiplication
                start_mul = time.time()
                delta = H_inv.matmul(g)
                end_mul = time.time()

                # Update parameter
                p.data.add_(-lr * delta)

                # Save timing information
                if 'time_inv' not in state:
                    state['time_inv'] = 0
                    state['time_mul'] = 0
                    state['time_other'] = 0

                state['time_inv'] += (end_inv - start_inv)
                state['time_mul'] += (end_mul - start_mul)
                state['time_other'] += (end_inv - start_inv) + (end_mul - start_mul)

                state['prev_grad'].copy_(grad)

        return loss

    def profile(self):
        inv_time = 0
        mul_time = 0
        other_time = 0
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if 'time_inv' in state:
                    inv_time += state['time_inv']
                    mul_time += state['time_mul']
                    other_time += state['time_other']

        total_time = inv_time + mul_time + other_time
        print(f"Matrix inversion time: {100 * inv_time / total_time:.2f}%")
        print(f"Matrix multiplication time: {100 * mul_time / total_time:.2f}%")
        print(f"Other operations time: {100 * other_time / total_time:.2f}%")

In [12]:

model = torch.nn.Linear(10, 2)
optimizer = MarquardtLevenberg(model.parameters(), lr=1e-3, lambd=1e-2)

# data
input = torch.randn(5, 10)
target = torch.randn(5, 2)

criterion = torch.nn.MSELoss()

for epoch in range(10):
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        return loss

    optimizer.step(closure)

optimizer.profile()

Matrix inversion time: 42.60%
Matrix multiplication time: 7.40%
Other operations time: 50.00%
