# Proximal Operator - test

We compute with tensor the proximal operator associated to hyperslab constraint
in order to include it as an activation function.

In [1]:
import numpy as np
import os
import sys
# pytorch
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import torch.nn.functional as F
# matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import sys
sys.path.insert(0, '/Users/cdellava/Documents/phd/MyResNet')

for d in sys.path:
    print(d)

/Users/cdellava/Documents/phd/MyResNet
/Users/ceciledv/Documents/phd/OLD
/Library/Frameworks/Python.framework/Versions/3.9/lib/python39.zip
/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9
/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/lib-dynload

/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages
/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/IPython/extensions
/Users/ceciledv/.ipython


In [3]:
# local
# from IPsolver import cardan

In [4]:
class cardan(torch.autograd.Function):  

    @staticmethod
    def forward(ctx,gamma_mu,xtilde,A,im_range,mode_training=True):
        """
	    Finds the solution of the cubic equation involved in the computation of the proximity operator of the 
        logarithmic barrier of the hyperslab constraints (xmin< u^Tx <xmax) using the Cardano formula: x^3+ax^2+bx+c=0 
        is rewritten as x^3+px+q=0. Selects the solution x such that x-a/3 is real and belongs to ]xmin,xmax[.
        Parameters
        ----------
           gamma_mu (torch.FloatTensor): product of the barrier parameter and the stepsize, size n
           xtilde (torch.FloatTensor): point at which the proximity operator is applied, size n
           u_ker (torch.FloatTensor) : kernel conv u^T x_tilde, size n
           im_range (list): minimal and maximal pixel values
           device (string) : 
           mode_training (bool): indicates if the model is in training (True) or testing (False) (default is True)
        Returns
        -------
           sol (torch.FloatTensor): proximity operator of gamma_mu*barrier at xtilde, size n 
        """
        # Device CPU/GPU
        device="cpu"
        if device == "cuda":
            dtype = torch.cuda.FloatTensor
        else :
            dtype = torch.FloatTensor
        #initialize variables
        size              = xtilde.size()
        x1,x2,x3          = torch.zeros(1).type(dtype),torch.zeros(1).type(dtype),torch.zeros(1).type(dtype)   
        crit,crit_compare = torch.zeros(1).type(dtype),torch.zeros(1).type(dtype)
        sol               = torch.zeros(size).type(dtype)
        kappa             = torch.zeros(1).type(dtype)
        xmin,xmax         = im_range
        Atx               = torch.matmul(A,xtilde)
        #set coefficients
        a     = -(xmin+xmax+Atx)
        b     = xmin*xmax + Atx*(xmin+xmax) - 2*gamma_mu*torch.norm(A)**2
        c     = gamma_mu*(xmin+xmax) - Atx*xmin*xmax
        p     = b - (a**2)/3
        q     = c - a*b/3 + 2*(a**3)/27
        delta = (p/3)**3 + (q/2)**2  

        #three cases depending on the sign of delta
        #########################################################################
        #when delta is positive
        if delta>0:
            z1 = -q/2
            z2 = torch.sqrt(delta)
            u  = (z1+z2).sign() * torch.pow((z1+z2).abs(),1/3)
            v  = (z1-z2).sign() * torch.pow((z1-z2).abs(),1/3) 
            x1 = u+v   
            x2 = -(u + v)/2 ; #real part of the complex solution
            x3 = -(u + v)/2 ; #real part of the complex solution
        #########################################################################
        #when delta is 0
        elif delta==0:
            x1 = 3 *q / p 
            x2 = -1.5 * q / p
            x3 = -1.5 * q / p 
        #########################################################################
        #when delta is negative
        elif delta<0:
            cos = (-q/2) * ((27 / torch.pow(p,3)).abs()).sqrt() 
            cos[cos<-1] = 0*cos[cos<-1]-1
            cos[cos>1]  = 0*cos[cos>1]+1
            phi         = torch.acos(cos)
            tau         = 2 * ((p/3).abs()).sqrt() 
            x1     = tau * torch.cos(phi/3) 
            x2     = -tau * torch.cos((phi + np.pi)/3)
            x3     = -tau * torch.cos((phi - np.pi)/3)
        #########################################################################
        x1   = x1-a/3
        x2   = x2-a/3
        x3   = x3-a/3
        # when gamma_mu is very small there might be some numerical instabilities
        # in case there are nan values, we set the corresponding pixels equal to 2*xmax
        # these values will be replaced by valid values at least once
        if (x1!=x1).any():
            x1[x1!=x1]=2*xmax
        if (x2!=x2).any():
            x2[x2!=x2]=2*xmax
        if (x3!=x3).any():
            x3[x3!=x3]=2*xmax
        kappa = x1
        sol  = xtilde + (x1 - Atx)/torch.norm(A)**2*A
        #########################################################################
        #take x1
        p1 = sol
        Atp1 = torch.matmul(A,p1)
        if (Atp1>xmin)&(Atp1<xmax):
            crit[0] = -(torch.log(Atp1-xmin)+torch.log(xmax-Atp1))
            crit = 0.5*torch.norm(p1-xtilde)**2+gamma_mu*crit
        else:
            crit[0] = np.inf
        #########################################################################
        #test x2
        p2 = xtilde + (x2 - Atx)/torch.norm(A)**2*A
        Atp2 = torch.matmul(A,p2)
        if (Atp2 >xmin)&(Atp2 <xmax): 
            crit_compare[0]  = -(torch.log(Atp2-xmin)+torch.log(xmax-Atp2))
            crit_compare  = 0.5*torch.norm(p2-xtilde)**2+gamma_mu*crit_compare
        else:
            crit_compare[0] = np.inf
        if crit_compare<=crit:
            kappa = x2
            sol  = p2
            crit = crit_compare
        #########################################################################
        #test x3
        p3 = xtilde + (x3 - Atx)/torch.norm(A)**2*A
        Atp3 = torch.matmul(A,p3)
        if (Atp3>xmin)&(Atp3<xmax):
            crit_compare[0] = -(torch.log(Atp3-xmin)+torch.log(xmax-Atp3))
            crit_compare = 0.5*torch.norm(p3-xtilde)**2+gamma_mu*crit_compare
        else:
            crit_compare[0] = np.inf
        if crit_compare<=crit:
            kappa = x3
            sol  = p3
            crit = crit_compare
        #########################################################################
        # when gamma_mu is very small and xtilde is very close to one of the bounds,
        # the solution of the cubic equation is not very well estimated -> test xtilde
        # denom = (sol-xmin)*(sol-xmax)-2*gamma_mu -(sol-xtilde)*(xmin+xmax-2*sol)
        if (Atx>xmin)&(Atx<xmax):
            crit_compare = -(torch.log(xmax-Atx)+torch.log(Atx-xmin))
            crit_compare = gamma_mu*crit_compare
        else:
            crit_compare[0] = np.inf
        if crit_compare<crit :
            kappa = Atx
            sol = xtilde
        
        if mode_training==True:
            ctx.save_for_backward(gamma_mu,xtilde,kappa)
        return sol

    @staticmethod
    def backward(ctx, grad_output_var):
        """
        Computes the first derivatives of the proximity operator of the log barrier with respect to x and gamma_mu.
            This method is automatically called by the backward method of the loss function.
        Parameters
        ----------
           ctx (list): list of torch.FloatTensors, variable saved during the forward operation
           grad_output_var (torch.FloatTensor): gradient of the loss wrt the output of cardan
        Returns
        -------
           grad_input_gamma_mu (torch.FloatTensor): gradient of the prox wrt gamma_m 
           grad_input_u (torch.FloatTensor): gradient of the prox wrt x
           None: no gradient wrt u
           None: no gradient wrt the image range
           None: no gradient wrt the mode
        """
        # Device CPU/GPU
        device="cpu"
        if device == "cuda":
            dtype = torch.cuda.FloatTensor
        else :
            dtype = torch.FloatTensor
        #initialize variables
        xmin           = 0
        xmax           = 1
        grad_output    = grad_output_var.data
        gamma_mu, xtilde, kappa = ctx.saved_tensors
        size           = xtilde.size()
        u              = torch.ones(size).type(dtype)
        uTx            = torch.matmul(u,xtilde)
        Id             = torch.eye(size).type(dtype)
        denom          = (kappa-xmin)*(kappa-xmax)\
                         -2*gamma_mu*torch.norm(u)**2\
                         -(kappa-uTx)*(xmin+xmax-2*kappa)
        #
        if denom.abs()<1e-7:
            denom = denom +1
            grad_input_gamma_mu = (2*kappa-(xmin+xmax))/denom*u
            grad_input_x        = Id + ((kappa**2-kappa*(xmin+xmax)+xmin*xmax))/denom/torch.norm(u)**2*torch.matmul(u,u.T)
            # if denom is very small, it means that gamma_mu is very small and sol is very close to one of the bounds,
            # there is a discontinuity when gamma_mu tends to zero, if 0<sol<1 the derivative wrt x is approximately equal to 
            # 1 and the derivative wrt gamma_mu is approximated by 10^5 times the sign of 2*x[1-idx]-(xmin+xmax)
            grad_input_gamma_mu = 0*grad_input_gamma_mu+1e5*torch.sign(2*kappa-(xmin+xmax))
            grad_input_x        = 0*grad_input_x+1
        else:
            grad_input_gamma_mu = (2*kappa-(xmin+xmax))/denom*u
            grad_input_x        = Id + ((kappa**2-kappa*(xmin+xmax)+xmin*xmax))/denom/torch.norm(u)**2*torch.matmul(u,u.T)
        
        
        grad_input_gamma_mu = (grad_input_gamma_mu*grad_output).sum(1).sum(1).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
        grad_input_x        = grad_input_x*grad_output
        
        # safety check for numerical instabilities
        if (grad_input_gamma_mu!=grad_input_gamma_mu).any():
            print('there is a nan in grad_input_gamma_mu')
        if (grad_input_x!=grad_input_x).any():
            print('there is a nan in grad_input_u')
            sys.exit()
        
        grad_input_gamma_mu = Variable(grad_input_gamma_mu.type(dtype),requires_grad=True)
        grad_input_x        = Variable(grad_input_x.type(dtype),requires_grad=True)
        
        return grad_input_gamma_mu, grad_input_x, None, None, None

In [5]:
x_tilde = torch.FloatTensor([0.01,2,3.8,23,9,0,2])
x_tilde.requires_grad = True
u = torch.FloatTensor(0.01*np.linspace(0,1,7))
print(torch.matmul(u,x_tilde))

tensor(0.2110, grad_fn=<DotBackward>)


In [6]:
im_range = [0,1]
gamma = torch.tensor(2.0, requires_grad=True)
mu = torch.tensor(2.0, requires_grad=True)

out = cardan.apply(gamma*mu,x_tilde,u,im_range,"cpu")

In [7]:
out

tensor([1.0000e-02, 2.0000e+00, 3.8000e+00, 2.3000e+01, 9.0000e+00, 0.0000e+00,
        2.0000e+00], grad_fn=<cardanBackward>)

In [8]:
out.backward(torch.FloatTensor(0.01*np.linspace(0,1,7)))

ValueError: not enough values to unpack (expected 3, got 0)