In [None]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
%matplotlib inline

# Helper functions: real and imaginary parts of complex inner products

In [None]:
def reprod(a,b):
    """Elementwise real part of a * conj(b).

    Computed componentwise as Re(a)Re(b) + Im(a)Im(b).
    """
    re_part = a.real*b.real
    im_part = a.imag*b.imag
    return re_part + im_part

def redot(a,b,axis=None):
    """Real inner product Re<a, b> = sum of reprod(a, b).

    The sum runs over ``axis`` (all axes when ``axis`` is None).
    """
    return np.sum(reprod(a, b), axis=axis)

def improd(a,b):
    """Elementwise imaginary part of a * conj(b).

    Computed componentwise as Im(a)Re(b) - Re(a)Im(b).
    """
    return a.imag*b.real - a.real*b.imag

def imdot(a,b,axis=None):
    """Imaginary inner product Im<a, b> = sum of improd(a, b).

    The sum runs over ``axis`` (all axes when ``axis`` is None).
    """
    return np.sum(improd(a, b), axis=axis)

# Create a simple 2D test object

In [None]:
# n x n complex object: a constant square (value 2+1j) in the central half, zero elsewhere
n = 128
band = slice(n // 4, -n // 4)
psi = np.zeros([n, n], dtype='complex64')
psi[band, band] = 2 + 1j
plt.imshow(psi.real)
plt.colorbar()

# generate shifts 

In [None]:
# Fix the RNG seed so the generated shifts (and the later random perturbation
# of them) are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

npos = 16
# uniform random shifts in [-32, 32) pixels, one (axis-0, axis-1) pair per position
shifts = 64*(np.random.random([npos, 2])-0.5).astype('float32')
print(shifts)


# forward and adjoint shift operators

In [None]:
def S(psi,p):
    """Shift ``psi`` by every row of ``p`` using the Fourier shift theorem.

    Parameters
    ----------
    psi : 2-D array, shape (ny, nx)
        Object to shift.
    p : array, shape (npos, 2)
        Shifts in pixels. Column 0 shifts along axis 0 and column 1 shifts
        along axis 1. Sub-pixel values are allowed.

    Returns
    -------
    complex ndarray, shape (npos, ny, nx)
        ``res[k]`` is ``psi`` periodically shifted by ``p[k]``. For integer
        shifts this equals ``np.roll(psi, p[k], axis=(0, 1))``.
    """
    # Frequency grids come from psi itself, not the notebook-global ``n``,
    # so the operator works for any (also non-square) object size.
    xi0 = np.fft.fftfreq(psi.shape[-2]).astype('float32')[:, None]  # varies along axis 0
    xi1 = np.fft.fftfreq(psi.shape[-1]).astype('float32')[None, :]  # varies along axis 1
    # exp(-2*pi*i*<xi, p>) multiplies the spectrum to shift by +p
    pp = np.exp(-2*np.pi*1j * (xi1*p[:, 1, None, None]+xi0*p[:, 0, None, None])).astype('complex64')
    res = np.fft.ifft2(pp*np.fft.fft2(psi))
    return res

def ST(psi,p):
    """Inverse shift: apply ``S`` with the negated shifts ``-p``.

    This undoes ``S(., p)`` position by position. Note that the result is not
    summed over positions.
    """
    neg_p = -p
    return S(psi, neg_p)

# generate data

In [None]:
# simulated measurements: one shifted copy of psi per position
data = S(psi, shifts)
# show the first and the last shifted copy
for k in (0, npos - 1):
    plt.imshow(data[k].real)
    plt.colorbar()
    plt.show()

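# We estimate the shifts $x$ by solving
# $$ \operatorname*{argmin}_x F(x), \qquad F(x)=\|S_x\psi-d\|_2^2, $$
# where $S_x$ shifts $\psi$ by $x$ and $d$ is the measured data.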

# Gradients

In [None]:
def gradientF(psi, x, data):
    """Derivative of ||u - data||^2 with respect to u, evaluated at u = S(psi, x).

    Returns 2*(S(psi, x) - data), which has the same shape as ``data``.
    """
    residual = S(psi, x) - data
    return 2*residual
    
def gradientx(psi, x, gradF):
    """Gradient of F(x) = ||S(psi, x) - data||^2 with respect to the shifts x.

    Parameters
    ----------
    psi : 2-D array, shape (ny, nx)
        Object being shifted.
    x : array, shape (npos, 2)
        Current shifts (column 0 along axis 0, column 1 along axis 1).
    gradF : complex array, shape (npos, ny, nx)
        Derivative with respect to the shifted field, e.g. from ``gradientF``.

    Returns
    -------
    float32 ndarray, shape (npos, 2)
        ``gradx[k, j] = Re<gradF[k], dS/dx[k, j] psi>``.
    """
    # Frequencies come from psi and the number of positions from x, instead
    # of the notebook globals ``n`` and ``npos``.
    xi1 = np.fft.fftfreq(psi.shape[-2]).astype('float32')[:, None]  # axis-0 frequencies
    xi2 = np.fft.fftfreq(psi.shape[-1]).astype('float32')[None, :]  # axis-1 frequencies

    # shift multipliers in frequency space (same convention as S)
    w = np.exp(-2 * np.pi * 1j * (xi1 * x[:, 0, None, None]+xi2 * x[:, 1, None, None]))

    # dS/dx_j psi = -2*pi*i * ifft2(w * xi_j * fft2(psi)); the -2*pi*i factor is applied below
    tmp = np.fft.fft2(psi)
    dt1 = np.fft.ifft2(w*xi1*tmp)
    dt2 = np.fft.ifft2(w*xi2*tmp)

    # Re<a, -2*pi*i*b> = -2*pi*Im<a, b>, summed per position over the image axes
    gradx = np.zeros([x.shape[0], 2], dtype='float32')
    gradx[:, 0] = -2 * np.pi * imdot(gradF, dt1, axis=(1, 2))
    gradx[:, 1] = -2 * np.pi * imdot(gradF, dt2, axis=(1, 2))

    return gradx

# Hessians

In [None]:

def hessianF(psi,x,dt1,dt2,data):
    """Second derivative of ||u - data||^2 in the directions dt1, dt2: 2*Re<dt1, dt2>.

    ``psi``, ``x`` and ``data`` are accepted for signature symmetry but unused.
    """
    inner = redot(dt1, dt2)
    return 2*inner
    
def hessian(psi,x,dx1,dx2,data,gradF):
    """Second directional derivative d^2F(x)[dx1, dx2] of F(x) = ||S(psi, x) - data||^2.

    Parameters
    ----------
    psi : 2-D array, shape (ny, nx)
        Object being shifted.
    x : array, shape (npos, 2)
        Current shifts.
    dx1, dx2 : arrays, shape (npos, 2)
        Directions in shift space.
    data : array
        Measured data (passed through to ``hessianF``).
    gradF : complex array, shape (npos, ny, nx)
        Equal to 2*(S(psi, x) - data), e.g. from ``gradientF``.

    Returns
    -------
    scalar
        Re<gradF, d^2S psi> + 2*Re<dS[dx1] psi, dS[dx2] psi>, summed over
        all positions.
    """
    # frequency grids from psi (not the notebook-global ``n``)
    xi1 = np.fft.fftfreq(psi.shape[-2]).astype('float32')[:, None]  # axis-0 frequencies
    xi2 = np.fft.fftfreq(psi.shape[-1]).astype('float32')[None, :]  # axis-1 frequencies

    # shift multipliers in frequency space (same convention as S)
    w = np.exp(-2*np.pi*1j * (xi2*x[:, 1, None, None]+xi1*x[:, 0, None, None]))
    # directional weights <xi, dx> per position, shape (npos, ny, nx)
    w1 = xi1*dx1[:, 0, None, None]+xi2*dx1[:, 1, None, None]
    w2 = xi1*dx2[:, 0, None, None]+xi2*dx2[:, 1, None, None]

    tmp = np.fft.fft2(psi)
    # first derivatives: -2*pi*i*<xi, dx> in frequency space
    dt1 = -2*np.pi*1j*np.fft.ifft2(w*w1*tmp)
    dt2 = -2*np.pi*1j*np.fft.ifft2(w*w2*tmp)
    # second derivative: (-2*pi*i)^2 <xi, dx1><xi, dx2> = -4*pi^2 * w1*w2
    d2t1 = -4*np.pi**2*np.fft.ifft2(w*w1*w2*tmp)

    res = redot(gradF,d2t1)+hessianF(psi, x, dt1, dt2, data)
    return res


# Alpha, beta

In [None]:
def calc_alpha(psi,x,dx1,dx2,gradF,d):
    """Step length minimizing the quadratic model of F along ``dx2``.

    With ``dx1`` the gradient and ``dx2`` the search direction, the model
    F(x + t*dx2) ~ F(x) + t*<dx1, dx2> + t^2/2 * H(dx2, dx2) is minimal at
    t = -<dx1, dx2> / H(dx2, dx2).

    Returns (alpha, top, bottom), where top = -<dx1, dx2> and bottom = H(dx2, dx2).
    """
    curvature = hessian(psi, x, dx2, dx2, d, gradF)
    descent = -redot(dx1, dx2)
    return descent/curvature, descent, curvature

def calc_beta(psi,x,dx1,dx2,gradF,d):
    """Conjugate-gradient coefficient beta = H(dx1, dx2) / H(dx2, dx2).

    With this beta the new direction -dx1 + beta*dx2 satisfies
    H(-dx1 + beta*dx2, dx2) = 0, i.e. it is H-conjugate to dx2.
    """
    num, den = (hessian(psi, x, a, dx2, d, gradF) for a in (dx1, dx2))
    return num/den


# BH

In [None]:
def minf(psi,x,data):
    """Objective F(x) = ||S(psi, x) - data||^2, with the 2-norm taken over all entries."""
    residual = S(psi, x) - data
    return np.linalg.norm(residual)**2

def plot_debug(psi, x, eta, top, bottom, alpha, data):
    """Plot F along the search line next to the quadratic model used for the step.

    Samples the step length t at 34 evenly spaced points from 0 to
    ``alpha*33/16`` (slightly past ``2*alpha``) and plots:

    * the true objective ``minf(psi, x + t*eta, data)``, and
    * the model ``F(x) - top*t + 0.5*bottom*t**2``.

    When ``alpha = top/bottom`` (as returned by ``calc_alpha``), the model's
    minimum lies at ``t = alpha``.
    """
    npp = 17
    # step lengths k*alpha/(npp-1) for k = 0 .. 2*npp-1 (computed once, reused for plotting)
    t = alpha*np.arange(2*npp)/(npp-1)
    errt = np.array([minf(psi, x + tk*eta, data) for tk in t])
    errt2 = minf(psi, x, data) - top*t + 0.5*bottom*t**2
    plt.plot(t, errt, '.', label='F(x + t*eta)')
    plt.plot(t, errt2, '.', label='quadratic model')
    plt.xlabel('step length t')
    plt.ylabel('objective')
    plt.legend()
    plt.show()

def BH(data, psi, x, niter, debug=True):
    """Refine shifts by nonlinear conjugate gradient with a quadratic-model line search.

    Each iteration computes the gradient of F(x) = ||S(psi, x) - data||^2 and
    a Hessian-conjugate search direction (``calc_beta``). It then takes the
    step length that minimizes the local quadratic model (``calc_alpha``).

    Parameters
    ----------
    data : array
        Measured shifted copies of the object.
    psi : 2-D array
        Object being shifted.
    x : array, shape (npos, 2)
        Initial shifts. This array is NOT modified.
    niter : int
        Number of iterations.
    debug : bool, default True
        If True, call ``plot_debug`` every iteration (the original behavior).

    Returns
    -------
    ndarray
        The refined shifts, with the same dtype as ``x``.
    """
    # Work on a copy: the in-place update below would otherwise silently
    # overwrite the caller's array.
    x = x.copy()
    for i in range(niter):
        gradF = gradientF(psi, x, data)
        grad = gradientx(psi, x, gradF)

        if i == 0:
            eta = -grad
        else:
            beta = calc_beta(psi, x, grad, eta, gradF, data)
            eta = -grad + beta*eta

        alpha, top, bottom = calc_alpha(psi, x, grad, eta, gradF, data)

        if debug:
            plot_debug(psi, x, eta, top, bottom, alpha, data)

        # in-place on the private copy keeps x's dtype (float32 in this notebook)
        x += alpha*eta

        print('error', minf(psi, x, data))
    return x

# Perturb the true shifts and run the reconstruction.
# Per the original note, larger c leads to divergence.
c = 1  # amplitude: each shift is perturbed uniformly in [-c/2, c/2)
x = shifts+c*(np.random.random(shifts.shape)-0.5).astype('float32')

# keep BH's result explicitly instead of relying on in-place mutation of x
x_rec = BH(data, psi, x, 10)
# remaining error of the recovered shifts relative to the true ones
x_rec - shifts