# In [1]:
import numpy as np
import pylab
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy import ndimage
# general.py TODO: try to not make use of transform.resize
from skimage import transform ## commented


def crop(M, n=None):
    """ crop - crop an image to reduce its size

    Keep the central n x n window of a square, single-channel image M.
    When n is None the window is half of the original width (n0 / 2).
    When the margin n0 - n is odd, the extra row/column is removed from
    the end (bottom/right).

    Only crops square black and white images for now.
    """
    # Check the rank first so a 1D input fails the assertion instead of
    # raising IndexError on M.shape[1].
    assert len(M.shape) == 2
    assert M.shape[0] == M.shape[1]

    n0 = M.shape[0]

    if n is None:
        n = n0 / 2
    margin = (n0 - n) / 2
    # Start and end of selection. The end index is kept positive: a
    # negative end index of -0 would select nothing when n == n0.
    start_ind = int(np.floor(margin))
    end_ind = n0 - int(np.ceil(margin))

    return M[start_ind:end_ind, start_ind:end_ind]


def circshift(x, p):
    """
        Circular shift of an array.

        Returns a new array with
            result[i, j] == x[(i + p[0]) % n0, (j + p[1]) % n1]
        i.e. a positive shift moves content towards lower indices.
        Only the first axis is shifted when p has a single entry (or is a
        scalar) or when x is 1D. Shifts larger than the size wrap around.
    """
    p = np.atleast_1d(p)
    y = np.roll(x, -p[0], axis=0)
    if len(p) > 1 and np.ndim(x) > 1:
        # The second entry of p drives the column shift.
        y = np.roll(y, -p[1], axis=1)
    return y

def circshift1d(x, k):
    """
        Circularly shift a vector along its first axis so that
        result[i] == x[(i + k) % len(x)].
    """
    shift = -k
    return np.roll(x, shift, axis=0)

def clamp(x, a=None, b=None):
    """
     clamp - clamp a value

       y = clamp(x,a,b);

     Default is [a,b]=[0,1]. Passing None (or, for backward compatibility,
     an empty list) for a bound selects its default. Bounds may be scalars
     or arrays broadcastable against x.

       Copyright (c) 2004 Gabriel Peyre
    """

    def _is_unset(v):
        # Identity/type checks instead of `v == []`, which is an elementwise
        # comparison (and an error/ambiguous truth value) for NumPy arrays.
        return v is None or (isinstance(v, list) and len(v) == 0)

    if _is_unset(a):
        a = 0.0
    if _is_unset(b):
        b = 1.0
    return np.minimum(np.maximum(x, a), b)

def rescale(f, a=0, b=1):
    """
        Affinely map the values of f onto the interval [a, b]: the minimum
        goes to a and the maximum to b. A constant input maps to a.
    """
    lowest = f.min()
    span = f.max() - lowest
    shifted = f - lowest
    normalized = shifted / span if span > 0 else shifted
    return a + normalized * (b - a)

def reverse(x):
    """
        Return x with its first axis in reverse order.
    """
    backwards = slice(None, None, -1)
    return x[backwards]


def bilinear_interpolate(im, x, y):
    """
        Bilinearly interpolate image im at the (possibly fractional)
        positions (x, y), where x indexes columns and y indexes rows.

        Points outside the image use the nearest edge pixels (the neighbour
        indices are clipped). Extra trailing axes of im (e.g. color channels)
        are interpolated independently.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)

    # Fractional offsets are taken BEFORE clipping the indices, so the four
    # weights always sum to 1. Clipping first made them vanish on the
    # right/bottom border (x0 == x1 after clipping).
    fx = x - x0
    fy = y - y0

    x1 = np.clip(x0 + 1, 0, im.shape[1] - 1)
    y1 = np.clip(y0 + 1, 0, im.shape[0] - 1)
    x0 = np.clip(x0, 0, im.shape[1] - 1)
    y0 = np.clip(y0, 0, im.shape[0] - 1)

    Ia = im[y0, x0]
    Ib = im[y1, x0]
    Ic = im[y0, x1]
    Id = im[y1, x1]

    wa = (1 - fx) * (1 - fy)
    wb = (1 - fx) * fy
    wc = fx * (1 - fy)
    wd = fx * fy

    extra = im.ndim - 2
    if extra > 0:
        # Broadcast the per-point weights over trailing channel axes.
        wa, wb, wc, wd = (np.reshape(w, np.shape(w) + (1,) * extra)
                          for w in (wa, wb, wc, wd))

    return wa * Ia + wb * Ib + wc * Ic + wd * Id

def cconv(x, h, d):
    """
        Circular (periodic) convolution of x with a short filter h.

        d == 2 filters along the second axis (by transposing); any other
        value filters along the first axis. h should be small and of odd
        length so that its middle tap is aligned with the output sample.
    """
    if d == 2:
        return np.transpose(cconv(np.transpose(x), h, 1))
    taps = len(h)
    center = int(round((taps - 1) / 2))
    acc = np.zeros(x.shape)
    for offset, coeff in enumerate(h):
        # Same as circshift1d(x, offset - center).
        acc = acc + coeff * np.roll(x, center - offset, axis=0)
    return acc

def div(g):
    """
        Compute a finite difference approximation of the divergence of a 2D
        vector field, assuming periodic BC.

        g has shape (n0, n1, 2). Backward differences are used, with indices
        wrapping around:
            f[i, j] = g[i, j, 0] - g[i-1, j, 0] + g[i, j, 1] - g[i, j-1, 1]
        With these conventions, -div is the adjoint of the forward-difference
        gradient `grad`.
    """
    gx = g[:, :, 0]
    gy = g[:, :, 1]
    # np.roll(., 1, axis) brings element i-1 (periodically) to position i.
    return (gx - np.roll(gx, 1, axis=0)) + (gy - np.roll(gy, 1, axis=1))

def gaussian_blur(f, sigma):
    """ gaussian_blur - gaussian blurs an image

        M = gaussian_blur(f, sigma)

        f     : 2D array (extra trailing axes, e.g. color channels, are
                blurred independently). Need not be square.
        sigma : std of the Gaussian blur (in pixels). If sigma <= 0, f is
                returned unchanged.

        The blur is a circular convolution computed with the FFT, so the
        image is treated as periodic.

        Copyright (c) 2007 Gabriel Peyre
    """
    if sigma <= 0:
        return f
    n0, n1 = np.shape(f)[:2]

    def _circular_distance(n):
        # Distance of each index to 0 on a ring of n samples: 0,1,2,...,2,1.
        # This keeps the kernel symmetric and centered for any n.
        k = np.arange(n)
        return np.minimum(k, n - k)

    t0 = _circular_distance(n0)
    t1 = _circular_distance(n1)
    h = np.exp(-(t0[:, None] ** 2 + t1[None, :] ** 2) / (2.0 * float(sigma) ** 2))
    h = h / np.sum(h)
    H = np.fft.fft2(h)
    # Broadcast the 2D transfer function over any trailing channel axes.
    H = H.reshape(H.shape + (1,) * (np.ndim(f) - 2))
    return np.real(np.fft.ifft2(np.fft.fft2(f, axes=(0, 1)) * H, axes=(0, 1)))

def grad(f):
    """
        Forward-difference gradient of a 2D image with periodic boundaries.

        For a 2D f the result has shape f.shape + (2,):
            g[i, j, 0] = f[i+1, j] - f[i, j]
            g[i, j, 1] = f[i, j+1] - f[i, j]
        with indices wrapping around at the last row/column.
    """
    d_rows = np.roll(f, -1, axis=0) - f
    d_cols = np.roll(f, -1, axis=1) - f
    return np.dstack((d_rows, d_cols))

def imageplot(f, str='', sbpt=None):
    """
        Display f as a grayscale image without axes.

        Use nearest neighbor interpolation for the display.

        str  : optional title (the parameter name shadows the builtin but is
               kept for backward compatibility with keyword callers).
        sbpt : optional (nrows, ncols, index) passed to plt.subplot; None or
               an empty sequence draws in the current axes.
    """
    # `sbpt=None` avoids a shared mutable default; the length test avoids
    # comparing against [] (elementwise for NumPy arrays).
    if sbpt is not None and len(sbpt) > 0:
        plt.subplot(sbpt[0], sbpt[1], sbpt[2])
    imgplot = plt.imshow(f, interpolation='nearest')
    imgplot.set_cmap('gray')
    plt.axis('off')
    if str != '':
        plt.title(str)

def load_image(name, n=-1, flatten=1, resc=1, grayscale=1):
    """
        Load an image from a file, rescale its dynamic to [0,1], turn it into
        a grayscale image and resize it to size n x n.

        name      : path passed to plt.imread.
        n         : target size; n <= 0 keeps the original size.
        flatten   : with grayscale == 1, sum the color channels of a
                    multi-channel image into one channel.
        resc      : 1 to rescale the values to [0, 1].
        grayscale : 1 to allow the channel flattening above.
    """
    f = plt.imread(name)
    # turn into normalized grayscale image
    if grayscale == 1 and flatten == 1 and np.ndim(f) > 2:
        # Drop an alpha channel (LA or RGBA images) so transparency does not
        # leak into the summed intensity.
        if f.shape[2] in (2, 4):
            f = f[:, :, :-1]
        f = np.sum(f, axis=2)
    if resc == 1:
        f = rescale(f)
    # change the size of the image (third positional arg is the spline order)
    if n > 0:
        if np.ndim(f) == 2:
            f = transform.resize(f, [n, n], 1)
        elif np.ndim(f) == 3:
            f = transform.resize(f, [n, n, f.shape[2]], 1)
    return f

def perform_wavortho_transf(f, Jmin, dir, h):
    """
        perform_wavortho_transf - compute orthogonal wavelet transform

        fW = perform_wavortho_transf(f, Jmin, dir, h)

        f    : 2D array; its width f.shape[1] sets the number of scales
               (presumably square with a power-of-2 side — TODO confirm).
        Jmin : coarsest level; levels j = Jmax..Jmin are processed with
               Jmax = int(log2(n)) - 1.
        dir  : 1 for the forward transform, any other value for the inverse.
        h    : low-pass filter; the high-pass filter is derived from it.

        Filtering is circular (cconv), i.e. periodic boundary conditions.
        Works in 2D only.

        Copyright (c) 2014 Gabriel Peyre
    """
    size = f.shape[1]
    Jmax = int(np.log2(size)) - 1

    # High-pass filter: g[0] = 0 and g[k] = (-1)**k * h[len(h) - k], k >= 1.
    signs = np.power(-np.ones(len(h) - 1), range(1, len(h)))
    g = np.concatenate(([0], h[-1:0:-1] * signs))

    out = f.copy()
    if dir == 1:
        # Forward: from fine to coarse, split the current low-pass square
        # into coarse/detail halves along rows (d=1) then columns (d=2).
        for j in range(Jmax, Jmin - 1, -1):
            full = 2 ** (j + 1)
            A = out[:full, :full]
            for d in (1, 2):
                A = np.concatenate(
                    (subsampling(cconv(A, h, d), d),
                     subsampling(cconv(A, g, d), d)),
                    axis=d - 1)
            out[:full, :full] = A
        return out

    # Backward: from coarse to fine, merge coarse/detail halves back.
    for j in range(Jmin, Jmax + 1):
        half, full = 2 ** j, 2 ** (j + 1)
        A = out[:full, :full]
        for d in (1, 2):
            if d == 1:
                coarse, detail = A[:half, :], A[half:full, :]
            else:
                coarse, detail = A[:, :half], A[:, half:full]
            A = (cconv(upsampling(coarse, d), reverse(h), d)
                 + cconv(upsampling(detail, d), reverse(g), d))
        out[:full, :full] = A
    return out


def plot_wavelet(fW, Jmin=0):
    """
        plot_wavelet - plot wavelets coefficients.

        U = plot_wavelet(fW, Jmin)

        Each detail block of levels Jmax..Jmin is mapped to [0, 1] with zero
        at mid-gray, the coarse block is rescaled to [0, 1], and red lines
        delimit the blocks. Returns the displayed array U.

        Copyright (c) 2014 Gabriel Peyre
    """
    def rescaleWav(A):
        # Symmetric mapping of signed coefficients: 0 -> 0.5, +-max -> 1/0.
        v = abs(A).max()
        return .5 + .5 * A / v if v > 0 else A.copy()

    n = fW.shape[1]
    Jmax = int(np.log2(n)) - 1
    U = fW.copy()
    for j in range(Jmax, Jmin - 1, -1):
        lo, hi = 2 ** j, 2 ** (j + 1)
        U[:lo, lo:hi] = rescaleWav(U[:lo, lo:hi])
        U[lo:hi, :lo] = rescaleWav(U[lo:hi, :lo])
        U[lo:hi, lo:hi] = rescaleWav(U[lo:hi, lo:hi])
    # coarse scale: the approximation block left at level Jmin (computed
    # explicitly instead of relying on the loop variable leaking out).
    c = 2 ** Jmin
    U[:c, :c] = rescale(U[:c, :c])
    # plot underlying image
    imageplot(U)
    # display crosses
    for j in range(Jmax, Jmin - 1, -1):
        plt.plot([0, 2 ** (j + 1)], [2 ** j, 2 ** j], 'r')
        plt.plot([2 ** j, 2 ** j], [0, 2 ** (j + 1)], 'r')
    # display box
    plt.plot([0, n], [0, 0], 'r')
    plt.plot([0, n], [n, n], 'r')
    plt.plot([0, 0], [0, n], 'r')
    plt.plot([n, n], [0, n], 'r')
    return U

def psnr(x, y, vmax=-1):
    """
     psnr - compute the Peak Signal to Noise Ratio (in dB)

       p = psnr(x, y, vmax)

       defined by:
           p = 10*log10( vmax^2 / |x-y|^2 )
       where |x-y|^2 = mean( (x(:)-y(:)).^2 )
       If vmax is omitted (or negative), it is the largest absolute value
       found in x and y:
           vmax = max(max(abs(x(:))), max(abs(y(:))))

       Integer inputs are converted to float64 first so that x - y and its
       square cannot wrap around (e.g. for uint8 images).

       Copyright (c) 2014 Gabriel Peyre
    """

    def _as_inexact(a):
        a = np.asarray(a)
        return a if np.issubdtype(a.dtype, np.inexact) else a.astype(np.float64)

    x = _as_inexact(x)
    y = _as_inexact(y)
    if vmax < 0:
        vmax = max(np.abs(x).max(), np.abs(y).max())
    d = np.mean((x - y) ** 2)
    return 10 * np.log10(vmax ** 2 / d)

def snr(x, y):
    """
    snr - signal to noise ratio (in dB)

       v = snr(x,y);

     v = 20*log10( norm(x(:)) / norm(x(:)-y(:)) )

       x is the original clean signal (reference).
       y is the denoised signal.

     Integer inputs are converted to float64 first so that x - y cannot
     wrap around (e.g. for uint8 images).

    Copyright (c) 2014 Gabriel Peyre
    """

    def _as_inexact(a):
        a = np.asarray(a)
        return a if np.issubdtype(a.dtype, np.inexact) else a.astype(np.float64)

    x = _as_inexact(x)
    y = _as_inexact(y)
    # With ord=None and axis=None, np.linalg.norm is the 2-norm of the
    # flattened array, matching norm(x(:)).
    return 20 * np.log10(np.linalg.norm(x) / np.linalg.norm(x - y))

def subsampling(x, d):
    """
        Keep every other sample of a 2D array along axis d-1
        (d = 1: rows, d = 2: columns), starting from index 0.
        Any other d raises Exception.
    """
    step = 2
    if d == 1:
        return x[::step, :]
    if d == 2:
        return x[:, ::step]
    raise Exception('Not implemented')

def upsampling(x, d):
    """
        Zero-insertion up-sampling of a 2D array by a factor of 2 along axis
        d-1 (d = 1: rows, d = 2: columns). The samples of x land on the even
        indices of a new float array; any other d raises Exception.
    """
    factor = 2
    dims = x.shape
    if d == 1:
        out_shape = (factor * dims[0], dims[1])
        picks = (slice(None, None, factor), slice(None))
    elif d == 2:
        out_shape = (dims[0], factor * dims[1])
        picks = (slice(None), slice(None, None, factor))
    else:
        raise Exception('Not implemented')
    out = np.zeros(out_shape)
    out[picks] = x
    return out


def plot_dictionary(D, title='Dictionary'):
    ''' Plot a dictionary of shape (width*width, n_atoms).

    Each column of D is reshaped to a width x width patch, scaled so its max
    absolute value is 1, and the patches are tiled on a ceil(sqrt(n_atoms))
    square grid separated by 1-pixel borders of value 1 (unused grid cells
    are filled with 1 as well).
    '''
    # Check that D.shape == (width*width, n_atoms)
    assert len(D.shape) == 2
    assert int(np.sqrt(D.shape[0]))**2 == D.shape[0]
    n_atoms = D.shape[1]
    # Rescale values in each atom to have a max absolute value of 1
    # This gives brighter plots. All-zero atoms are left at zero instead of
    # turning into NaN through a 0/0 division.
    scale = np.max(abs(D), axis=0)
    scale[scale == 0] = 1
    D = D / scale

    # Reshape dictionary to patches
    width = int(np.sqrt(D.shape[0]))
    D = D.reshape((width, width, n_atoms))
    n = int(np.ceil(np.sqrt(n_atoms)))  # Size of the plot in number of atoms

    # Pad the atoms
    pad_size = 1
    missing_atoms = n ** 2 - n_atoms

    padding = (((pad_size, pad_size), (pad_size, pad_size),
                (0, missing_atoms)))
    D = np.pad(D, padding, mode='constant', constant_values=1)
    padded_width = width + 2*pad_size
    D = D.reshape(padded_width, padded_width, n, n)
    D = D.transpose(2, 0, 3, 1)  # Needed for the reshape
    big_image_size = n*padded_width
    D = D.reshape(big_image_size, big_image_size)
    imageplot(D)
    plt.title(title)
    plt.show()


def perform_wavelet_transf(f, Jmin, dir, filter = "9-7",separable = 0, ti = 0):
    """
    perform_wavelet_transf - perform fast lifting transform

    y = perform_wavelet_transf(x, Jmin, dir, filter="9-7", separable=0, ti=0)

    Implement 1D and 2D symmetric wavelets with symmetric boundary
    treatments, using a lifting implementation.

    dir = 1 computes the forward transform, dir = -1 the inverse.

    filter selects the lifting coefficients: "linear" (or "5-3") and
    "9-7" (or "7-9") select biorthogonal transforms with 2 and 4 vanishing
    moments respectively. Any other value raises ValueError.

    You can set ti=1 to compute a translation invariant wavelet transform
    (not available together with separable=1: a warning is issued and ti
    is ignored).

    You can set separable=1 to compute a separable 2D wavelet transform
    (1D transforms of every column, then of every row; reversed for the
    inverse).

    Integer inputs are converted to float64 so that lifted coefficients are
    not truncated when written back in place.

    Copyright (c) 2008 Gabriel Peyre
    """
    import warnings

    # copy f (as a floating point array)
    x = np.array(f)
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(np.float64)

    # convert Jmin to int
    Jmin = int(Jmin)

    # detect dimensionality
    d = np.ndim(x)
    # P/U/P/U/etc the last coefficient is scaling
    if filter in ["linear", "5-3"]:
        h = [1/2, 1/4, np.sqrt(2)]
    elif filter in ["9-7", "7-9"]:
        h = [1.586134342, -.05298011854, -.8829110762, .4435068522, 1.149604398]
    else:
        raise ValueError('Unknown filter')

    if d == 2 and separable == 1:
        # Warn before discarding the TI request (the check used to run after
        # ti was reset, and referenced an undefined `wrn` module).
        if ti == 1:
            warnings.warn("Separable does not works for translation invariant transform")
        ti = 0

        # perform a separable wavelet transform with 1D transforms; the
        # result is returned directly instead of falling through into the
        # non-separable 2D lifting below.
        n_rows, n_cols = np.shape(x)
        if dir == 1:
            for i in range(n_cols):
                x[:, i] = perform_wavelet_transf(x[:, i], Jmin, dir, filter, separable, ti)
            for i in range(n_rows):
                x[i, :] = perform_wavelet_transf(x[i, :], Jmin, dir, filter, separable, ti)
        else:
            for i in range(n_rows):
                x[i, :] = perform_wavelet_transf(x[i, :], Jmin, dir, filter, separable, ti)
            for i in range(n_cols):
                x[:, i] = perform_wavelet_transf(x[:, i], Jmin, dir, filter, separable, ti)
        return x

    # number of scales
    if d == 1:
        n = len(x)
    else:
        n = np.shape(x)[1]
    Jmax = int(np.log2(n) - 1)
    jlist = range(Jmax, Jmin - 1, -1)

    if dir == -1:
        jlist = range(Jmin, Jmax + 1, 1)

    if ti == 0:
        # subsampled
        for j in jlist:
            k = 2 ** (j + 1)
            if d == 1:
                # lifting_step lifts along the first axis of a 2D array: lift
                # the leading k samples as a column vector, then flatten back.
                x[:k] = lifting_step(x[:k].reshape(-1, 1), h, dir).ravel()
            else:
                x[:k, :k] = lifting_step(x[:k, :k], h, dir)
                x[:k, :k] = np.transpose(lifting_step(np.transpose(x[:k, :k]), h, dir))

    else:
        # TI
        # NOTE(review): the 1D TI path looks shape-inconsistent (tiling a 1D
        # vector gives shape (nJ+1, 1, n), which lifting_step_ti then indexes
        # on its size-1 axis) — TODO verify before relying on it.
        nJ = Jmax - Jmin + 1
        if dir == 1 and d == 1:
            x = np.tile(x,(nJ + 1,1,1))
        elif dir == 1 and d == 2:
            x = np.tile(x,(3*nJ + 1,1,1))
        for j in jlist:
            dist = 2**(Jmax - j)

            if d == 1:
                if dir == 1:
                    x[:(j-Jmin+2),:,:] = lifting_step_ti(x[0,:,:], h, dir, dist)
                else:
                    x[0,:,:] = lifting_step_ti(x[:(j-Jmin+2),:,:], h, dir, dist)
            else:
                dj = 3*(j-Jmin)

                if dir == 1:
                    x[[0,dj+1],:,:] = lifting_step_ti(x[0,:,:], h, dir, dist)

                    x[[0,dj+2],:,:] = lifting_step_ti(np.transpose(x[0,:,:]), h, dir, dist)
                    x[0,:,:] = np.transpose(x[0,:,:])
                    x[dj+2,:,:] = np.transpose(x[dj+2,:,:])

                    x[[1+dj,3+dj],:,:] = lifting_step_ti(np.transpose(x[dj+1,:,:]), h, dir, dist)
                    x[dj+1,:,:] = np.transpose(x[dj+1,:,:])
                    x[dj+3,:,:] = np.transpose(x[dj+3,:,:])
                else:

                    x[dj+1,:,:] = np.transpose(x[dj+1,:,:])
                    x[dj+3,:,:] = np.transpose(x[dj+3,:,:])

                    x[dj+1,:,:] = np.transpose(lifting_step_ti(x[[1+dj, 3+dj],:,:], h, dir, dist))

                    x[0,:,:] = np.transpose(x[0,:,:])
                    x[dj+2,:,:] = np.transpose(x[dj+2,:,:])
                    x[0,:,:] = np.transpose(lifting_step_ti(x[[0,dj+2],:,:], h, dir, dist))

                    x[0,:,:] = lifting_step_ti(x[[0,dj+1],:,:], h, dir, dist)

        if dir == -1:
            x = x[0,:,:]

    return x

###########################################################################
###########################################################################
###########################################################################

def lifting_step(x0, h, dir):
    """
    One level of a lifting wavelet transform along the first axis of x0.

    x0  : array whose first axis (of even length) is transformed; 2D arrays
          are lifted column by column, 1D vectors are also accepted.
    h   : lifting coefficients, alternating predict/update weights, with the
          scaling constant as the last entry.
    dir : 1 for the forward step (output = coarse coefficients * h[-1]
          stacked on top of detail coefficients / h[-1]); any other value
          inverts that step.

    At the ends of the axis, the boundary sample is repeated.
    """
    # copy x
    x = np.copy(x0)

    if x.ndim == 1:
        # The vstack-based neighbour shifts below need 2D slices: lift the
        # vector as a single column and flatten the result back.
        return lifting_step(x.reshape(-1, 1), h, dir).ravel()

    # number of lifting steps
    m = (len(h) - 1) // 2

    if dir == 1:
        # split into even (x) and odd (d) samples
        d = x[1::2,]
        x = x[0::2,]
        for i in range(m):
            # predict: neighbours x[k] and x[k+1] (last sample repeated)
            d = d - h[2*i] * (x + np.vstack((x[1:,], x[-1,])))
            # update: neighbours d[k] and d[k-1] (first sample repeated)
            x = x + h[2*i+1] * (d + np.vstack((d[0,], d[:-1,])))
        x = np.vstack((x*h[-1], d/h[-1]))

    else:
        # retrieve detail coefs
        end = len(x)
        d = x[end//2:,]*h[-1]
        x = x[:end//2,]/h[-1]
        # undo the predict/update pairs in reverse order
        for i in range(m, 0, -1):
            x = x - h[2*i-1] * (d + np.vstack((d[0,], d[:-1,])))
            d = d + h[2*i-2] * (x + np.vstack((x[1:,], x[-1,])))
        # merge: even samples from x, odd samples from d
        x1 = np.vstack((x, x))
        x1[::2,] = x
        x1[1::2,] = d
        x = x1

    return x

###########################################################################
###########################################################################
###########################################################################
def lifting_step_ti(x0, h, dir, dist):
    """
    One level of the undecimated (translation-invariant) lifting transform.

    Unlike lifting_step, samples are not split into even/odd halves: every
    sample is predicted/updated from the neighbours located +-dist away,
    with mirror reflection at the borders (see s1/s2 below).

    x0   : forward (dir == 1): a 2D array (assumed square — n is read from
           x[0], i.e. the second axis, while the shifts index the first
           spatial axis; TODO confirm for non-square input).
           inverse: a stack whose x0[0] holds the scaled coarse part and
           x0[1] the scaled detail part.
    h    : lifting coefficients, alternating predict/update weights, with
           the scaling constant as the last entry.
    dir  : 1 for forward, any other value for inverse.
    dist : neighbour offset in samples.

    Returns: forward — an array stacking coarse*h[-1] and detail/h[-1] along
    axis 0; inverse — the 2D average (x + d)/2 of the two reconstructions.
    """

    #copy x
    x = np.copy(x0)

    # number of lifting steps
    m = (len(h) - 1)//2
    n = np.shape(x[0])[0]

    # 1-based neighbour indices at distance +dist (s1) and -dist (s2)
    s1 = np.arange(1, n+1) + dist
    s2 = np.arange(1, n+1) - dist

    # boundary conditions: reflect indices that fall outside 1..n
    # (without repeating the border sample)
    s1[s1 > n] = 2*n - s1[s1 > n]
    s1[s1 < 1] = 2   - s1[s1 < 1]

    s2[s2 > n] = 2*n - s2[s2 > n]
    s2[s2 < 1] = 2   - s2[s2 < 1]

    #indices in python start from 0
    s1 = s1 - 1
    s2 = s2 - 1

    if dir == 1:
        # split: detail starts as the full signal (no subsampling). Note d is
        # bound before x is tiled below, so it keeps x0's rank and
        # broadcasts against the 3D x.
        d = x
        for i in range(m):
            # a 2D input is lifted to shape (1, n0, n1) on the first pass
            if np.ndim(x) == 2 :
                x = np.tile(x,(1,1,1))
            # predict then update, shifting along axis 1 of the 3D arrays
            d = d - h[2*i]   * (x[:,s1,:] + x[:,s2,:])
            x = x + h[2*i+1] * (d[:,s1,:] + d[:,s2,:])

        #merge: stack scaled coarse and detail along axis 0
        x = np.concatenate((x*h[-1],d/h[-1]))

    else:
        # retrieve detail coefs (undo the scaling)

        d = x[1,:,:]*h[-1]
        x = x[0,:,:]/h[-1]

        # undo the predict/update pairs in reverse order (2D slices, so the
        # shifts index axis 0 here)
        for i in range(m,0,-1):
            x = x - h[2*i-1] * (d[s1,:] + d[s2,:])
            d = d + h[2*i-2] * (x[s1,:] + x[s2,:])

        # merge: average the two reconstructions
        x = (x + d)/2

    return x
