In [None]:
from cil.optimisation.operators import LinearOperator, GradientOperator, FiniteDifferenceOperator
from cil.framework import ImageGeometry, BlockGeometry
from cil.optimisation.functions import IndicatorBox, MixedL21Norm

class Divergence2D(LinearOperator):
    """Divergence operator built from the finite-difference operator used by the gradient.

    Maps a vector field ``w`` (a BlockDataContainer living on ``domain_geometry``,
    one component per differentiated axis) to an image on
    ``domain_geometry.geometries[0]``::

        direct(w)  = -sum_i FD_i^T w_i      (divergence)
        adjoint(u) = -(FD_0 u, FD_1 u, ...)  (minus the gradient)

    It is designed so that ``<grad u, w> = <u, -div w>``.

    Despite the name, 2D or 3D images with or without a channel axis are handled.
    With a channel axis, the channel is never differentiated.

    Parameters
    ----------
    domain_geometry : BlockGeometry
        Geometry of the vector field, e.g. ``GradientOperator(ig).range``. Every
        component is assumed to share the image geometry ``geometries[0]``.
    method : str
        Finite-difference scheme forwarded to ``FiniteDifferenceOperator``.
    bnd_cond : str
        Boundary condition forwarded to ``FiniteDifferenceOperator``.
    """

    def __init__(self, domain_geometry, method="forward", bnd_cond="Neumann"):
        image_geometry = domain_geometry.geometries[0]

        # Number of axes of the *image* (e.g. 3 for 2D + channel). The
        # BlockGeometry's own shape is always (n_blocks, 1), so it cannot be used
        # to choose the expected axis order below.
        self.size_dom_gm = len(image_geometry.shape)
        self.bnd_cond = bnd_cond
        self.method = method

        # One FD operator is reused; its direction/voxel size are set per axis.
        self.FD = FiniteDifferenceOperator(image_geometry, direction=0,
                                           method=self.method, bnd_cond=self.bnd_cond)

        if image_geometry.channels > 1:
            if self.size_dom_gm == 4:
                # 3D + channel
                expected_order = [ImageGeometry.CHANNEL, ImageGeometry.VERTICAL,
                                  ImageGeometry.HORIZONTAL_Y, ImageGeometry.HORIZONTAL_X]
            else:
                # 2D + channel
                expected_order = [ImageGeometry.CHANNEL, ImageGeometry.HORIZONTAL_Y,
                                  ImageGeometry.HORIZONTAL_X]
            order = image_geometry.get_order_by_label(image_geometry.dimension_labels,
                                                      expected_order)
            # Differentiate along every axis except the channel axis.
            self.ind = order[1:]
        else:
            # No channel axis: differentiate along every array axis.
            self.ind = list(range(self.size_dom_gm))

        # Spacing per array axis (one entry per dimension label).
        self.voxel_size_order = image_geometry.spacing
        super(Divergence2D, self).__init__(domain_geometry=domain_geometry,
                                           range_geometry=image_geometry)

    def _set_axis(self, axis):
        """Point the shared finite-difference operator at array axis ``axis``."""
        self.FD.direction = axis
        # Look the spacing up by array axis, not by block-component index: with a
        # channel axis present the two differ.
        self.FD.voxel_size = self.voxel_size_order[axis]

    def direct(self, x, out=None):
        """Return ``-sum_i FD_i^T x_i``, the divergence of the vector field ``x``.

        If ``out`` is given, the result is written into it and ``None`` is returned.
        """
        if out is not None:
            tmp = self.range_geometry().allocate()
            for i, axis in enumerate(self.ind):
                self._set_axis(axis)
                self.FD.adjoint(x.get_item(i), out=tmp)
                if i == 0:
                    out.fill(tmp)
                else:
                    out += tmp
            out *= -1
        else:
            tmp = self.range_geometry().allocate()
            for i, axis in enumerate(self.ind):
                self._set_axis(axis)
                tmp += self.FD.adjoint(x.get_item(i))
            return -tmp

    def adjoint(self, x, out=None):
        """Return ``-(FD_i x)_i``, i.e. minus the forward-difference gradient of ``x``.

        If ``out`` is given (a BlockDataContainer on the domain geometry), the
        result is written into it and ``None`` is returned.
        """
        if out is not None:
            for i, axis in enumerate(self.ind):
                self._set_axis(axis)
                self.FD.direct(x, out=out.get_item(i))
            out *= -1
        else:
            tmp = self.domain_geometry().allocate()
            for i, axis in enumerate(self.ind):
                self._set_axis(axis)
                tmp.get_item(i).fill(self.FD.direct(x))
            return -tmp
                
        
# Adjoint check on a small 2-channel 3x4 image: Divergence2D is built so that
# <grad u, w> == <u, -div w>; the two printed values should agree to round-off.
ig1 = ImageGeometry(3, 4, channels=2)
Grad = GradientOperator(ig1)
x = ig1.allocate('random')
w = Grad.range.allocate('random')
res = Grad.direct(x)

Div = Divergence2D(Grad.range)
lhs = res.dot(w)                 # <grad u, w>
rhs = -x.dot(Div.direct(w))      # <u, -div w>
print("<grad u, w> =", lhs)
print("<u, -div w> =", rhs)
# Guard the denominator so an all-zero draw does not divide by zero.
print("relative difference =", abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-12))

# print(res.get_item(0).as_array())
# print(res1.get_item(0).as_array())
        
# Adjoint relation: Gradient.adjoint = -Divergence.direct, i.e.
#     $\langle \nabla u, w \rangle = \langle u, -\operatorname{div} w \rangle$


# x = ig.allocate('random')
# res = Grad.direct(x)

# Div = Divergence2D(Grad.range) 
# res1 = Div.direct(res)

# res2 = Div.adjoint(res1)

from cil.optimisation.functions import (OperatorCompositionFunction, ZeroFunction,
                                        LeastSquares, L2NormSquared)
from cil.optimisation.algorithms import FISTA

# Dual ROF denoising with FISTA: minimise over vector fields p
#     f(p) = 0.5 * ||div p + g||^2 - 0.5 * ||g||^2,   g = noisy_data
# NOTE(review): `ig` and `noisy_data` are not defined in this cell; assumes an
# earlier cell defines them, with noisy_data living on `ig`.
Grad = GradientOperator(ig)
Div = Divergence2D(Grad.range)
alpha = 0.1

# 0.5*||Div p - (-g)||^2 = 0.5*||div p + g||^2. The -0.5*||g||^2 term is a scalar
# constant: writing it as `- 0.5*L2NormSquared(b=noisy_data)` would make it a
# function of the block-valued iterate p and add a spurious gradient term.
FF = OperatorCompositionFunction(0.5 * L2NormSquared(b=-noisy_data), Div) \
    - 0.5 * noisy_data.squared_norm()

# NOTE(review): for the ROF dual, g would be the indicator of {p : ||p||_{2,inf} <= alpha};
# alpha * MixedL21Norm() defines a different (penalised) problem -- confirm intent.
GG = alpha * MixedL21Norm()

x_init = Grad.range.allocate()

# Div is already composed into FF, so FISTA needs no separate operator argument.
fista = FISTA(initial=x_init, f=FF, g=GG,
              max_iteration=100, update_objective_interval=10)
fista.run(verbose=1)