In [2]:
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from torch.nn.modules.loss import _Loss
from math import ceil
import os

In [3]:
# Physical data
l = 200      # length of the spatial domain (x_grid spans [0, l])
tau = 100    # total time (dt = tau/nt)
dep = 2      # NOTE(review): not used in this cell
# Numerical data
nx = 200
dx = l/(nx+1)
nt = 200
dt = tau/nt
# Lower-triangular (cumulative-sum) operator of size nt*nx, scaled by dx/100: y = T_operator @ x.
# NOTE(review): dx = l/(nx+1) differs from the spacing of x_grid below (l/(nx-1)) — confirm which is intended.
T_operator = 1/100*dx*np.tri(nt, nx, 0, dtype=int)
# Data sample
nsamp = 400
x_grid = np.linspace(0, l, nx)
# Ground truth: the same Gaussian bump, centred at l/2, for every sample, normalised to unit sum
# (the normalisation makes the usual 1/(sigma*sqrt(2*pi)) prefactor irrelevant).
# The previous formula exp(-(x-mu)**2/2*sigma**2) with sigma = 0.1 multiplied by sigma**2 instead of
# dividing by 2*sigma**2; the standard formula with sigma = 10 yields exactly the same profile.
mu = l/2
sigma = 10.0
x_profile = np.exp(-(x_grid - mu)**2 / (2*sigma**2))
x_profile /= x_profile.sum()
x_dagger = np.tile(x_profile, (nsamp, 1))
# Noise-free measurements, one row per sample: y_clean[i] = T_operator @ x_dagger[i].
y_clean = x_dagger @ T_operator.T
# Additive uniform noise rescaled so that ||noise_i|| = noise_level * ||y_clean_i||.
# NOTE(review): the original loop drew xi ~ U(-0.005, 0.005) and rescaled it by ||y||/||xi||, which
# cancels the 0.005 amplitude and yields 100 % relative noise (noise_level = 1.0). That level is kept to
# reproduce the original data; set e.g. noise_level = 0.005 if 0.5 % noise was intended.
noise_level = 1.0
rng = np.random.default_rng(0)  # fixed seed so Restart & Run All reproduces the same samples
xi = rng.uniform(-0.005, 0.005, (nsamp, nt))
noise_scale = np.linalg.norm(y_clean, axis=1) / np.linalg.norm(xi, axis=1)
y = y_clean + noise_level * xi * noise_scale[:, None]
# Back-projection of the noisy data: x_sample[i] = T_operator.T @ y[i].
x_sample = y @ T_operator

In [4]:
# Regularisation operator
# Second-order finite-difference matrix of size nx*nx: -2 on the diagonal, 1 on the first super- and
# sub-diagonals. The first/last rows are truncated, which amounts to zero values outside the grid.
D_op = np.diag(np.ones(nx-1), 1) + np.diag(np.ones(nx-1), -1) - 2*np.eye(nx)
# Forward operator: reuse the one built with the synthetic data rather than re-deriving it with a
# duplicated scale constant, so the two definitions cannot drift apart. Copied to avoid aliasing.
T_op = T_operator.copy()

In [5]:
class IPIter(torch.nn.Module):
    """
    One proximal interior point iteration for a 1-D inverse problem with quadratic regularization:

        reg     = reg_mul * std_approx / (std(D x) + reg_constant)        (one value per sample)
        x_tilde = x - gamma * (HtH x - Ht_y + reg * Dt D x)
        x_next  = cardan.apply(gamma * mu, x_tilde, im_range, mode)

    `cardan` is not defined in this class; presumably it is the barrier proximity operator
    (a torch.autograd.Function) defined elsewhere in the notebook — TODO confirm.

    Attributes
    ----------
        D   (torch.FloatTensor): regularization operator (e.g. a finite-difference matrix); presumably
                                 size l*l so that torch.matmul(D, x) works on x of size n*l*1
        Dt  (torch.FloatTensor): transpose of D
        HtH (torch.FloatTensor): operator HtH with H the forward operator, applied with torch.matmul
        im_range          (list): minimal and maximal values, forwarded to cardan
    """
    def __init__(self, im_range, HtH, D, Dt, dtype):
        """
        Parameters
        ----------
        im_range          (list): minimal and maximal values
        HtH (torch.FloatTensor): operator HtH with H the forward operator
        D   (torch.FloatTensor): regularization operator
        Dt  (torch.FloatTensor): transpose of D
        dtype                  : data type (currently unused, kept for interface compatibility)
        """
        super(IPIter, self).__init__()
        self.D        = D
        self.Dt       = Dt
        self.HtH      = HtH
        self.im_range = im_range

    def Grad(self, reg, x, Ht_y):
        """
        Computes the gradient of the smooth term of the objective (data fidelity + regularization):
        HtH x - Ht_y + reg * Dt D x, i.e. the gradient of 1/2||Hx - y||^2 + reg/2 ||Dx||^2 when
        HtH = H^T H and Ht_y = H^T y.

        Parameters
        ----------
            reg  (torch.FloatTensor): regularization parameter, broadcastable against x (e.g. size n*1*1)
            x    (torch.FloatTensor): current signals, size n*l*1
            Ht_y (torch.FloatTensor): Ht applied to the measurements, size n*l*1
        Returns
        -------
            (torch.FloatTensor): gradient of the smooth term, same size as x
        """
        Dx   = torch.matmul(self.D, x)
        DtDx = torch.matmul(self.Dt, Dx)
        return torch.matmul(self.HtH, x) - Ht_y + reg * DtDx

    def _reg_parameter(self, reg_mul, reg_constant, x, std_approx):
        """Estimates one regularization parameter per sample, returned with size n*1*1."""
        Dx = torch.matmul(self.D, x)
        # per-sample mean and standard deviation of D x over the signal dimensions
        avg = Dx.mean(dim=(-2, -1), keepdim=True)
        spread = torch.sqrt(((Dx - avg)**2).mean(dim=(-2, -1), keepdim=True))
        if torch.is_tensor(std_approx):
            # size n*1 (or scalar) -> n*1*1 so it pairs with `spread` sample by sample
            # instead of broadcasting to n*n
            std_approx = std_approx.reshape(-1, 1, 1)
        return reg_mul*std_approx/(spread + reg_constant)

    @staticmethod
    def _append_values(folder, filename, values):
        """Appends the values of `values` as one new '%.3e'-formatted line to folder/filename."""
        flat = torch.as_tensor(values).detach().cpu().reshape(-1).tolist()
        with open(os.path.join(folder, filename), "a") as file:
            file.write('\n' + ' '.join('%.3e' % v for v in flat))

    def forward(self, gamma, mu, reg_mul, reg_constant, delta, x, Ht_y,
                std_approx, mode, save_gamma_mu_lambda='no'):
        """
        Computes the proximal interior point iteration.

        Parameters
        ----------
            gamma                 (torch.FloatTensor): gradient descent stepsize, size 1
            mu                    (torch.FloatTensor): barrier parameter, size n*1*1
            reg_mul, reg_constant (torch.FloatTensor): parameters of the hidden layer used to estimate
                                                       the regularization parameter, size 1
            delta                                    : unused by this quadratic-regularization variant
                                                       (Grad takes no delta); kept for interface compatibility
            x                     (torch.FloatTensor): signals from the previous iteration, size n*l*1
            Ht_y                  (torch.FloatTensor): Ht applied to the measurements, size n*l*1
            std_approx            (torch.FloatTensor): approximation of the noise standard deviation, size n*1
            mode                               (bool): True in training mode, False otherwise (forwarded to cardan)
            save_gamma_mu_lambda                (str): folder in which gamma, mu and the regularization
                                                       parameter are appended to gamma.txt, mu.txt and
                                                       lambda.txt; 'no' (default) disables logging
        Returns
        -------
            (torch.FloatTensor): next proximal interior point iterate, size n*l*1
        """
        reg     = self._reg_parameter(reg_mul, reg_constant, x, std_approx)
        x_tilde = x - gamma*self.Grad(reg, x, Ht_y)
        if save_gamma_mu_lambda != 'no':
            # stepsize, barrier parameter and regularization parameter, one line per call
            self._append_values(save_gamma_mu_lambda, 'gamma.txt', gamma)
            self._append_values(save_gamma_mu_lambda, 'mu.txt', mu)
            self._append_values(save_gamma_mu_lambda, 'lambda.txt', reg)
        return cardan.apply(gamma*mu, x_tilde, self.im_range, mode)