In [53]:
import torch
import math

In [44]:
def d_p(x, p):
    '''
    Prox-function
    '''
    return 1 / (p + 1) * x**(p + 1)


def hessian_vector_product(f, x, v):
    '''
    Hessian-vector product: D^2(f) @ v
    '''
    grad_f = torch.autograd.grad([f(x)], [x], create_graph=True)[0]
    z = grad_f @ v
    z.backward()
    return x.grad


def BDGM(f, x_tilda_k, delta, L3):
    
    grad_f_x_tilda_k = torch.autograd.grad([f(x_tilda_k)], [x_tilda_k], create_graph=True)[0]
    
    z_0 = x_tilda_k
    tau = 3 * delta / (8 * (2 + math.sqrt(2)) * torch.norm(grad_f))
    
    def D2v(z):
        return hessian_vector_product(f, x_tilda_k, z - x_tilda_k)
    
    def rho_k(z):
        '''
        Scaling function
        '''
        return .5 * D2v(f, x_tilda_k, z - x_tilda_k) @ (z - x_tilda_k) \
               + L3 * d_p(z - x_tilda_k, 4)
    
    def beta_rho_k(z_i, z):
        '''
        Bregman distance
        '''
        grad_rho_k = torch.autograd.grad([rho_k(z)])
        return rho_k(z) - rho_k(z_i) - grad_rho_k @ (z - z_i)
    
    def g_x_tilda_k_tau(z):
        grad_g_p = torch.autograd.grad([f(x_tilda_k + tau * (z - x_tilda_k))], [x], create_graph=True)[0]
        grad_g_n = torch.autograd.grad([f(x_tilda_k - tau * (z - x_tilda_k))], [x], create_graph=True)[0]
        return 1 / tau**2 * (grad_g_p + grad_g_n - 2 * grad_f_x_tilda_k)
    
    def g_phi_k_tau(z):
        return grad_f_x_tilda_k + D2v(z) + g_x_tilda_k_tau(z) + L3 * ((z - x_tilda_k)**2).sum() * (z - x_tilda_k)
    
    i = 0
    z_i = z_0
    while True:
        g_phi_k_tau_z_i = g_phi_k_tau(z_i)
        grad_f_z_i = torch.autograd.grad([f(z_i)], [z_i], create_graph=True)[0]
        if torch.norm(g_phi_k_tau_z_i) < 1 / 6 * torch.norm(grad_f_z_i) - delta:
            break
        else:
            
        i += 1

In [45]:
v = torch.Tensor([1, 1])
v.requires_grad_()

x = torch.Tensor([0.1, 0.1])
x.requires_grad_()

tensor([0.1000, 0.1000], requires_grad=True)

In [46]:
def f(x):
    return 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] **2

In [47]:
D2v(f, x, v)

tensor([10.,  6.])

In [49]:
v = torch.Tensor([1, 1])
v.requires_grad_()

x = torch.Tensor([0.1, 0.1])
x.requires_grad_()

f = 3 * x[0] ** 2 + 4 * x[0] * x[1] + x[1] **2

grad_f, = torch.autograd.grad([f], [x], create_graph=True)
z = grad_f @ v
z.backward()
x.grad

tensor([10.,  6.])

In [30]:
f

tensor(0.0800, grad_fn=<AddBackward0>)

In [31]:
grad_f

(tensor([1.0000, 0.6000], grad_fn=<AddBackward0>),)

In [10]:
v

tensor([1., 1.], requires_grad=True)

In [11]:
z

tensor(1.6000, grad_fn=<DotBackward>)

In [52]:
v @ grad_f

tensor(1.6000, grad_fn=<DotBackward>)