In [7]:
import numpy as np
from PIL import Image
from tqdm import tqdm

def divergence(f):
    """
    Computes the divergence of the vector field f, corresponding to dFx/dx + dFy/dy + ...
    :param f: List of ndarrays, where every item of the list is one dimension of the vector field
    :return: Single ndarray of the same shape as each of the items in f, which corresponds to a scalar field
    """
    num_dims = len(f.shape)-1
    return np.ufunc.reduce(np.add, [np.gradient(f[...,i], axis=i) for i in range(num_dims)])

def step_dual_chan_vese(u,z,image_arr, c1, c2, lambd, step_u,step_z):
    def clip_z(z):
        """ 
        Helper functions to truncate z to have at most unit pointwise norm.
        If z has norm ≤ 1, we do nothing, otherwise we replace it by z/norm(z)
        """
        def criterion(v):
            norm = np.linalg.norm(v)
            if norm > 1:
                return v/norm
            else:
                return v

        return np.apply_along_axis(criterion, -1, z)
        #return z / ((1 + np.maximum(0, np.apply_along_axis(np.linalg.norm, -1, z) - 1))[...,np.newaxis])

    z_update = clip_z(z + step_z * np.stack(np.gradient(u), axis=-1))
    tmp = lambd * ((image_arr - c1)**2 - (image_arr-c2)**2)
    u_update = np.clip(u + step_u*(divergence(z_update) - tmp), 0 , 1)

    return (u_update, z_update)

#def energy(u,c1,c2, image_arr):
#    TV_energy = np.sum(np.apply_along_axis(np.linalg.norm, -1))

In [8]:
bitmap_array = np.array(Image.open("images/dirty.bmp"), dtype=float)

u = np.zeros(bitmap_array.shape, dtype=float)
z = np.zeros(bitmap_array.shape+(len(u.shape),), dtype=float)
c1 = 50
c2 = 200

In [9]:
for i in tqdm(range(1000)):
    (u,z) = step_dual_chan_vese(u,z,bitmap_array, c1,c2, 0.0001, 0.2,0.4)

100%|██████████| 1000/1000 [01:48<00:00,  9.19it/s]


In [11]:
x = np.array(255 * (1-u), dtype=np.uint8)
Image.fromarray(x).show()
Image.open("images/dirty.bmp").show()

In [5]:
x = np.array(   [[[1,2],[3,4]], 
                 [[1,3],[2,4]]])
y = np.apply_along_axis(np.linalg.norm, -1, x)
print(y)
print(x*y[...,np.newaxis])

[[2.23606798 5.        ]
 [3.16227766 4.47213595]]
[[[ 2.23606798  4.47213595]
  [15.         20.        ]]

 [[ 3.16227766  9.48683298]
  [ 8.94427191 17.88854382]]]
