In [1]:
from collections.abc import Iterable, Sequence
from typing import Literal

import torch


def jacobian(input: Sequence[torch.Tensor], wrt: Sequence[torch.Tensor], create_graph=False):
    """Differentiate every element of ``input`` with respect to each tensor in ``wrt``.

    All tensors in ``input`` are flattened and concatenated into one vector of
    length ``N``. A single batched ``torch.autograd.grad`` call then
    back-propagates the ``N`` rows of an identity matrix, one row per output
    element.

    Args:
        input: tensors to differentiate; they must be attached to a graph
            that reaches the tensors in ``wrt``.
        wrt: tensors to differentiate with respect to.
        create_graph: forwarded to ``torch.autograd.grad``; set it to ``True``
            when the result itself has to be differentiable.

    Returns:
        A tuple with one entry per tensor in ``wrt``. Because the identity is
        passed as batched grad outputs, the leading dimension of each entry
        indexes the ``N`` flattened output elements. ``allow_unused=True`` is
        set, so a tensor of ``wrt`` that does not contribute to ``input`` does
        not raise.
    """
    stacked = torch.cat([tensor.reshape(-1) for tensor in input])

    # One one-hot cotangent per output element; dtype and device are taken
    # from the first input tensor.
    reference = input[0]
    basis = torch.eye(stacked.shape[0], device=reference.device, dtype=reference.dtype)

    # The graph is always retained so the caller can keep differentiating
    # through it after this call.
    return torch.autograd.grad(
        outputs=stacked,
        inputs=wrt,
        grad_outputs=basis,
        retain_graph=True,
        create_graph=create_graph,
        allow_unused=True,
        is_grads_batched=True,
    )


def make_newton_loss(loss_fn, tik_l: float | Literal['eig'] = 1e-2, use_torch_func = False):

    class NewtonLoss(torch.autograd.Function):

        @staticmethod
        def forward(ctx, preds: torch.Tensor, targets: torch.Tensor):
            with torch.enable_grad():
                # necessary to flatten preds FIRST so they are part of the graph
                preds_flat = preds.ravel()
                value = loss_fn(preds_flat.view_as(preds), targets)

                # caluclate gradient and hessian
                if use_torch_func:
                    H: torch.Tensor = torch.func.hessian(loss_fn, 0)(preds_flat, targets) # pyright:ignore[reportAssignmentType]
                    g = torch.autograd.grad(value, preds)[0]

                else:
                    g = torch.autograd.grad(value, preds_flat, create_graph=True)[0]
                    H: torch.Tensor = jacobian([g], [preds_flat])[0]

            # apply regularization
            if tik_l == 'eig':
                reg = torch.linalg.eigvalsh(H).neg().clip(min=0).max()
            else:
                reg = tik_l

            if reg != 0:
                H.add_(torch.eye(H.size(0), device=H.device, dtype=H.dtype).mul_(reg))

            # newton step
            newton_step, success = torch.linalg.solve_ex(H, g)
            ctx.save_for_backward(newton_step.view_as(preds))

            return value

        @staticmethod
        def backward(ctx, *grad_outputs):
            newton_step = ctx.saved_tensors[0] # inputs to loss
            return newton_step, None

    return NewtonLoss.apply

In [3]:
from monai.losses import DiceFocalLoss
# Smoke test: wrap MONAI's DiceFocalLoss so that `input.grad` receives the
# regularized Newton step instead of the plain gradient.
loss = make_newton_loss(DiceFocalLoss(softmax=True), tik_l=1e-2, use_torch_func=False)

# Fall back to the CPU so the cell also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

input = torch.randn(1,100, requires_grad=True, device=device)
l = loss(input, torch.randn(1,100, requires_grad=True, device=device))
print(l)
l.backward()
input.grad


tensor(1.5630, device='cuda:0', grad_fn=<NewtonLossBackward>)


tensor([[-4.9787e-03, -7.1071e-01, -1.5677e+00, -1.0821e+00,  4.2555e-03,
         -2.8702e-02, -3.5220e+00, -2.3090e+00, -6.7767e-01, -4.1828e-01,
         -7.1692e+01, -2.2520e-01,  1.2092e-02, -8.7981e-01, -3.4437e-01,
         -1.0390e-01, -6.4949e-01, -3.3275e-02, -2.7531e-01, -1.3260e+00,
         -2.1608e-01, -4.9054e-01, -3.6699e-01, -1.3111e+00, -5.2593e-01,
         -4.4454e+00, -8.3285e-01, -1.0195e+00,  6.5973e-02, -6.4688e-01,
         -3.4707e-01, -2.1931e-01, -4.2201e-01, -3.0579e-01, -1.2811e+00,
         -1.6409e-01, -1.0673e+00,  1.6529e-01, -3.9110e-01, -1.0196e+01,
         -3.1219e-01, -2.7379e+00, -5.4618e-02, -4.0282e-01, -6.6204e-01,
         -8.9593e-01, -5.6248e-01, -2.8347e+00, -1.1034e-01, -7.2864e-01,
         -1.0816e+00, -1.6863e+00,  2.2045e+02, -6.3646e-01, -2.9563e+00,
         -2.9134e-01, -4.3018e-01,  1.7629e-01, -4.8417e-01, -3.5092e-01,
         -4.5925e-01, -3.6215e-01, -2.0373e+00, -1.1797e+00, -5.9761e+00,
         -3.3980e+00,  5.6844e-01, -1.