# Tutorial on PnP-LADMM

#### Imports

In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt

from skimage.io import imread, imsave

import torch
from torchmetrics.image import TotalVariation
from torchmetrics.functional.image import image_gradients

from metricas import *
from degradacion import *

#### Data-loading

In [2]:
def T_degradacion(x):

    return 2*x

In [3]:
y = torch.tensor(np.array([2,4]), dtype=torch.float64)

In [4]:
y

tensor([2., 4.], dtype=torch.float64)

#### Creation of the model

In [55]:
lam = 1
sigma_energia = 1/np.sqrt(2)

tv = TotalVariation()

def f_energia(x):

    return (1/(2*sigma_energia**2)) * torch.norm(y - T_degradacion(x))**2

def TV(x):

    x = x.unsqueeze(0).unsqueeze(0)
    
    return lam * tv(x)

def L1(x, eps=1e-5):

    resu = torch.sum( torch.square(x) )

    return lam * (resu + eps)**(1/2)

def L2(x):

    return lam * torch.norm(x, p=2)**2

In [56]:
def grad_L2(x):

    return 2*lam*x

def grad_L1(x, eps=1e-5):

    return (lam / L1(x, eps)) * x  

def grad_TV(x, eps = 0.1):
    
    x = x.unsqueeze(0).unsqueeze(0)

    grad_x = image_gradients(x)

    norma_gradx = torch.norm(grad_x[0][0,0,:,:])
    norma_grady = torch.norm(grad_x[1][0,0,:,:])

    denominador = torch.sqrt(eps**2 + norma_gradx**2 + norma_grady**2).item()

    cociente_x = grad_x[0] / denominador
    cociente_y = grad_x[1] / denominador
    
    derivada_cociente_x = image_gradients(cociente_x)
    derivada_cociente_y = image_gradients(cociente_y)

    gradiente_tv = derivada_cociente_x[0][0,0,:,:] + derivada_cociente_y[1][0,0,:,:]

    return lam * gradiente_tv

In [57]:
f = f_energia
h = L2

grad_h = grad_L2

cte_h_lipz = 8

In [58]:
def minimizeViaTorch(funcion_objetivo, x0, lr=0.1, max_iter=10_000, eps=1e-6):

    params = x0
    params.requires_grad_()
    optimizer = torch.optim.Adam([params], lr=lr)

    diff = torch.inf

    k = 0

    while (k < max_iter) and (diff > eps):

        print(f"k: {k}, diff: {diff}", end="\r", flush=True)

        params_ant = params.detach().clone()

        optimizer.zero_grad()
        loss = funcion_objetivo(params)
        loss.backward(retain_graph=True)
        optimizer.step()

        params_sig = params.detach().clone()

        diff = torch.norm(params_sig - params_ant)

        k += 1

    print("\n\n")

    return params.detach().clone()

In [59]:
def f_total(x):

    return f(x) + h(x)

In [60]:
f_total(torch.tensor([1,2], dtype=torch.float64))

tensor(5.0000, dtype=torch.float64)

In [61]:
def PROX_function(x, v, param, func):

    return func(x) + (1/(2*param)) * torch.norm(x - v)**2

#### Run model

ATENCIÓN: La siguiente celda puede llegar a demorar más de media hora dependiendo del hardware

In [62]:
x_aux = minimizeViaTorch(f_total, x0=y.clone(), lr=0.1)

k: 226, diff: 1.4810723371536859e-06




In [63]:
print(y)

print()

print(x_aux)

tensor([2., 4.], dtype=torch.float64)

tensor([0.8000, 1.6000], dtype=torch.float64)


In [64]:
eps = 1e-6
max_iter = 50

Lz = round(cte_h_lipz + cte_h_lipz**2 + 3) + 1
beta = max( cte_h_lipz + Lz + 2 , 3*(Lz**2 + cte_h_lipz**2)/((Lz + cte_h_lipz**2)/2) , 3*Lz**2) + 1
Lx = round(beta + 6*cte_h_lipz**2 + 1) + 1

diff_x = torch.inf
diff_z = torch.inf
diff_gamma = torch.inf

#x = torch.zeros_like(x_real)
#z = torch.ones_like(x_real)
x = y.clone()
z = x.clone()

Nz = z.shape[0]

gamma = torch.ones(x.shape)

mat_inv = torch.from_numpy( np.linalg.inv(Lz + beta*np.identity(Nz)))

nIter = 1
while (nIter <= max_iter) and (max(diff_x, diff_z, diff_gamma) > eps):

    print(f"Iteracion: {nIter}/{max_iter}", end="\r", flush=True)

    print(f"\ndiff_x: {diff_x}")
    print(f"diff_z: {diff_z}")
    print(f"diff_gamma: {diff_gamma}\n")

    if diff_x > 0:

        def PROX_function(x_param):

            return f(x_param) + (1/(2*(1/Lx))) * torch.norm(x_param - (x - (1/Lx)*(gamma + beta*(x - z))))**2
        
        
        def actualizacion_x(x_param):

            return f(x_param) + torch.sum(gamma * x_param) + (Lx/2) * torch.norm(x_param - x)**2 + torch.sum((x_param - x) * (beta*(x - z)))
    
        # Actualizacion de x. Elegir alguna de las dos funciones de arriba
        x_sig = minimizeViaTorch(actualizacion_x, x0=torch.ones(x.shape, dtype=torch.float64)*0.5, lr=0.01, max_iter=1e6, eps=1e-6)

    # Actualizacion de z
    z_sig = mat_inv @ (Lz*z - grad_h(z) + gamma + beta*x_sig)
    
    # Actualizacion de gamma
    gamma_sig = gamma + beta*(x_sig - z_sig)

    # Recalculo diferencias entre actualizaciones
    diff_x = torch.norm(x_sig - x)
    diff_z = torch.norm(z_sig - z)
    diff_gamma = torch.norm(gamma_sig - gamma)

    # Asigno la actualizacion
    gamma = gamma_sig.clone()
    x = x_sig.clone()
    z = z_sig.clone()

    nIter += 1

print("\nTermino!")

Iteracion: 1/50
diff_x: inf
diff_z: inf
diff_gamma: inf

k: 1369, diff: 1.0036574886740368e-06


Iteracion: 2/50
diff_x: 0.0011472382195004717
diff_z: 0.020717950243256355
diff_gamma: 342.84556147859973

k: 1362, diff: 1.0033751163263105e-06


Iteracion: 3/50
diff_x: 0.0396168464390419
diff_z: 0.01976631058757207
diff_gamma: 1.475379930272735

k: 1358, diff: 1.008126108814622e-066


Iteracion: 4/50
diff_x: 0.020119560087151978
diff_z: 0.02011282388226064
diff_gamma: 1.5825332614240002

k: 1354, diff: 1.0133272909662594e-06


Iteracion: 5/50
diff_x: 0.020027350350374524
diff_z: 0.02002961105313394
diff_gamma: 1.5491847227790452

k: 1351, diff: 1.0053058083769884e-06


Iteracion: 6/50
diff_x: 0.01994583326908467
diff_z: 0.019946262013961363
diff_gamma: 1.542564902646164

k: 1347, diff: 1.0112738819678668e-06


Iteracion: 7/50
diff_x: 0.019863575943715746
diff_z: 0.019863993508171765
diff_gamma: 1.5361749550130714

k: 1344, diff: 1.0039880833367931e-06


Iteracion: 8/50
diff_x: 0.01978140

In [67]:
x

tensor([1.1708, 3.6004], dtype=torch.float64)

In [68]:
z

tensor([1.1708, 3.6005], dtype=torch.float64)