In [9]:
import numpy as np
import scipy.io as io
import astra
import cupy as cp
import sys
import ptychocg as pt
import tomoalign as tm
import time
import matplotlib.pyplot as plt

from pycallgraph import PyCallGraph
from pycallgraph import Config
from pycallgraph.output import GraphvizOutput

# Call-graph profiling of the reconstruction loop (see the PyCallGraph block below).
config = Config(max_depth=100)
graphviz = GraphvizOutput(output_file='3dptycho_linear_faast.png')

from skimage.transform import resize
# Ground-truth volume, expressed as a complex refractive index:
# real part -1e-5*volume, imaginary part 1e-7*volume.
volume = io.loadmat('layer_het300.mat')['volume']
# For a quick, smaller test run use a downsampled volume instead:
# volume = resize(io.loadmat('layer_het300.mat')['volume'],(100,100,100))
volume = (-1e-5*volume + 1j*1e-7*volume).astype('complex64')
# NOTE(review): unpacking assumes the last two axes have equal length (square slices).
(nz, n, n) = volume.shape
probe0 = io.loadmat('probe_100x100.mat')['probe']

# Acquisition geometry: nscan random probe positions per angle, ntheta angles over [0, pi].
nx = 10
ny = 10
nscan = nx*ny
ntheta = 50
theta = np.linspace(0, np.pi, ntheta).astype('float32')
nprb = probe0.shape[-1]

# Simulated source / detector parameters.
# NOTE(review): 12.389e-9 m corresponds to ~100 eV photons; confirm the exponent
# (and digit order - hc is ~12.398 keV*Angstrom) are what was intended.
wavelength = 12.389e-9
dist = 5          # presumably metres, same unit as wavelength - confirm
pxsize = 172e-6   # detector pixel size, same unit as dist
# Reconstruction pixel size lambda*dist/(pxsize*N), N = probe width in pixels.
rec_pxsize = wavelength*dist/(pxsize*probe0.shape[0])
# Wavenumber expressed per reconstruction pixel (dimensionless in pixel units).
wavenum = 2*np.pi/wavelength*rec_pxsize

# Random scan positions, seeded so that Restart & Run All reproduces the same data.
# scan[0] is drawn in [0, n - nprb - 1), scan[1] in [0, nz - nprb - 1).
SEED = 0
rng = np.random.default_rng(SEED)
scan = np.empty([2, ntheta, nscan], dtype='float32')
scan[0] = rng.random((ntheta, nscan)) * (n - nprb - 1)
scan[1] = rng.random((ntheta, nscan)) * (nz - nprb - 1)

# The same probe is used for every angle: stack ntheta contiguous complex64 copies.
probe = np.repeat(probe0[np.newaxis].astype('complex64'), ntheta, axis=0)

def norm(vf):
    """Euclidean (L2) norm of all entries of ``vf``, returned as a real float64.

    The input is flattened, so any array shape is accepted (for a 2-D array this
    is the Frobenius norm). The magnitudes are accumulated in float64 to avoid
    the precision loss of summing tens of millions of complex64 products.

    Parameters
    ----------
    vf : array_like
        Real or complex values.

    Returns
    -------
    numpy.float64
        sqrt(sum(|vf|**2)); 0.0 for an empty input.
    """
    magnitudes = np.abs(np.ravel(vf)).astype(np.float64, copy=False)
    return np.sqrt(np.dot(magnitudes, magnitudes))

def psi(volume, theta, ntheta, k):
    """Linearised transmission function ``1 + i*k*R(volume)`` for every angle.

    ``R`` is the tomographic projection computed by tomoalign's
    ``SolverTomo.fwd_tomo_batch``. The result is the first-order expansion of
    ``exp(i*k*R(volume))``.

    Parameters
    ----------
    volume : ndarray
        Complex volume; axis 0 has length nz, axis 1 has length n.
    theta : ndarray
        Projection angles forwarded to SolverTomo.
    ntheta : int
        Number of projection angles.
    k : float
        Wavenumber scaling the projections.

    Returns
    -------
    ndarray, complex64
        ``1 + 1j*k*R(volume)``.
    """
    nz = volume.shape[0]
    n = volume.shape[1]
    center = n/2
    # NOTE(review): callers pass complex128 arrays (e.g. the reconstruction);
    # cast to complex64 to match this function's output dtype - confirm SolverTomo
    # expects complex64 input.
    volume = np.asarray(volume, dtype='complex64')
    # nz // 2 replaces np.int(nz/2): the np.int alias was removed in NumPy 1.24.
    with tm.SolverTomo(theta, ntheta, nz, n, nz // 2, center) as tmslv:
        projections = tmslv.fwd_tomo_batch(volume)
    return (1 + 1j*k*projections).astype('complex64')

def jacobian(volume, theta, probe,scan,hn,k):
    """Apply the Jacobian of the simulated measurement at ``volume`` to ``hn``.

    With F = fwd_ptycho(psi(volume)) and dF = fwd_ptycho(psi(hn) - 1), this
    returns 2*Re(conj(F) * dF), i.e. the directional derivative of |F|**2
    along the linear object perturbation dF.

    Parameters
    ----------
    volume : ndarray
        Point of linearisation (axis 0 length nz, axis 1 length n).
    theta : ndarray
        Projection angles.
    probe : ndarray
        Probe stack; its last axis gives the probe size.
    scan : ndarray
        Scan positions; axis 1 = angles, axis 2 = positions per angle.
    hn : ndarray
        Volume perturbation to which the Jacobian is applied.
    k : float
        Wavenumber.

    Returns
    -------
    ndarray, complex64
        Real-valued derivative (stored as complex64) with NaN/inf replaced by
        finite numbers.
    """
    nz, n = volume.shape[0], volume.shape[1]
    probe_size = probe.shape[-1]
    n_positions = scan.shape[2]
    n_angles = scan.shape[1]

    obj = psi(volume, theta, n_angles, k)
    # psi returns 1 + i*k*R(.), so subtracting 1 leaves only the linear part i*k*R(hn).
    obj_direction = psi(hn, theta, n_angles, k) - 1

    with pt.CGPtychoSolver(n_positions, probe_size, probe_size, probe_size, n_angles, nz, n, ptheta=1, igpu=1) as solver:
        field = solver.fwd_ptycho_batch(obj, scan, probe)
        field_direction = solver.fwd_ptycho_batch(obj_direction, scan, probe)

    derivative = 2*np.real(np.conj(field)*field_direction)
    return np.nan_to_num(derivative).astype('complex64')

def jacobian_adjoint(volume,theta,probe,scan,hng,k):
    """Apply the adjoint of ``jacobian`` (at ``volume``) to data-space ``hng``.

    Steps:
      1. Forward-model the field Psi = fwd_ptycho(psi(volume)).
      2. Weight it by the real part of ``hng``.
      3. Apply the ptychographic adjoint with the probe scaled by i*k (the factor
         from psi's derivative).
      4. Apply the tomographic adjoint and multiply by 2.

    Parameters
    ----------
    volume : ndarray
        Point of linearisation (axis 0 length nz, axis 1 length n).
    theta : ndarray
        Projection angles.
    probe : ndarray
        Probe stack; its last axis gives the probe size.
    scan : ndarray
        Scan positions; axis 1 = angles, axis 2 = positions per angle.
    hng : ndarray
        Data-space vector (only its real part is used).
    k : float
        Wavenumber.

    Returns
    -------
    ndarray, complex64
        Volume-space result with NaN/inf replaced by finite numbers.
    """
    nz = volume.shape[0]
    n = volume.shape[1]
    center = n/2
    nprb = probe.shape[-1]
    nscan = scan.shape[2]
    ntheta = scan.shape[1]
    Obj = psi(volume, theta, ntheta, k)
    with pt.CGPtychoSolver(nscan, nprb, nprb, nprb, ntheta, nz, n, ptheta=1, igpu=1) as ptslv:
        Psi = ptslv.fwd_ptycho_batch(Obj,scan,probe)
        Psi_hng = Psi*np.real(hng)
        # The i*k factor is folded into the probe handed to the adjoint.
        # NOTE(review): whether adj_ptycho_batch conjugates the probe (which would
        # turn this into conj(i*k) = -i*k, as an exact adjoint requires) is not
        # visible here - confirm against ptychocg.
        ptycho_adj = ptslv.adj_ptycho_batch(Psi_hng,scan,probe*1j*k)
    # nz // 2 replaces np.int(nz/2): the np.int alias was removed in NumPy 1.24.
    with tm.SolverTomo(theta, ntheta, nz, n, nz // 2, center) as tmslv:
        JA_image = 2*(tmslv.adj_tomo_batch(ptycho_adj))
    return np.nan_to_num(JA_image).astype('complex64')

In [None]:
# Initial guess: an all-zero reconstruction. It is complex64 like the rest of the
# pipeline (psi and the Jacobians return complex64); complex128 would double the
# memory of the volume for no benefit.
recon = np.zeros(volume.shape, dtype='complex64')
# Simulate the "measured" data from the ground-truth volume, and the data
# predicted by the initial guess.
# NOTE(review): `data` is the complex output of fwd_ptycho_batch, while
# `jacobian` linearises |F|**2 (2*Re(conj(F)*dF)). Confirm whether intensities
# (np.abs(...)**2) were intended as the measurement model here.
with pt.CGPtychoSolver(nscan, nprb, nprb, nprb, ntheta, nz, n, ptheta=1, igpu=1) as ptslv:
    data = ptslv.fwd_ptycho_batch(psi(volume, theta, ntheta, wavenum), scan, probe)
    guess_data = ptslv.fwd_ptycho_batch(psi(recon, theta, ntheta, wavenum), scan, probe)

def _cg_solve(apply_A, rhs, n_iter):
    """Approximately solve ``apply_A(h) = rhs`` with ``n_iter`` conjugate-gradient steps.

    Starts from h = 0. Because ``apply_A`` is linear, apply_A(0) == 0 and the
    initial residual is ``rhs`` itself; this saves one full (expensive) operator
    evaluation. beta uses the Polak-Ribiere formula.

    Returns the approximate solution h (same shape and dtype as ``rhs``).
    """
    h = np.zeros_like(rhs)
    r = rhs
    p = r
    for j in range(n_iter):
        # The operator application is the costly part, so compute it once per iteration.
        Ap = apply_A(p)
        rr = np.vdot(r, r)  # vdot conjugates its first argument and flattens both
        alpha = rr/(np.real(np.vdot(p, Ap)) + 1e-16)
        h = h + alpha*p
        r_new = r - alpha*Ap
        beta = np.vdot(r_new, r_new - r)/(rr + 1e-16)  # Polak-Ribiere
        r = r_new
        p = r + beta*p
        print("Did iteration {} of {} in CGM".format(j+1, n_iter))
    return h


with PyCallGraph(output=graphviz, config=config):
    # Levenberg-Marquardt (LMA) parameters
    mu0 = -1     # damping; set negative to initialise it from ||g|| on the first iteration
    ni_LMA = 2   # outer LMA iterations
    ni_CGM = 2   # inner CG iterations per LMA step

    mu = mu0
    k = 0
    v = 2  # factor by which mu grows after a rejected step
    r_LMA = data - guess_data  # residual at the current reconstruction

    while k < ni_LMA:
        try:
            t_start = time.time()
            # Right-hand side of the damped normal equations at the current reconstruction.
            g = jacobian_adjoint(recon, theta, probe, scan, r_LMA, wavenum)
            if mu < 0:
                # np.real keeps mu a real scalar so the ordering tests below are well defined.
                mu = np.real(norm(g.flatten()))
            print("Mu is", mu)

            def apply_A(h1):
                """Damped normal operator h1 -> J^T(J h1) + mu*h1 at the current reconstruction."""
                return jacobian_adjoint(recon, theta, probe, scan, jacobian(recon, theta, probe, scan, h1, wavenum), wavenum) + mu*h1

            h = _cg_solve(apply_A, g, ni_CGM)

            # Candidate reconstruction. The constraint forces Re <= 0 and Im >= 0.
            # NOTE(review): -|Re| reflects positive real parts instead of clamping them;
            # np.minimum(real, 0) would be the projection onto the constraint set.
            recon_trial = recon + h
            recon_trial = -np.abs(np.real(recon_trial)) + 1j*np.maximum(np.imag(recon_trial), 0)
            with pt.CGPtychoSolver(nscan, nprb, nprb, nprb, ntheta, nz, n, ptheta=1, igpu=1) as ptslv:
                data_guess = ptslv.fwd_ptycho_batch(psi(recon_trial, theta, ntheta, wavenum), scan, probe)
            r_LMA2 = data - data_guess

            # Gain ratio: actual decrease of ||residual||^2 over the decrease predicted
            # by the linear model. Squared norms are taken as real numbers.
            actual = np.real(np.vdot(r_LMA, r_LMA)) - np.real(np.vdot(r_LMA2, r_LMA2))
            predicted = np.real(np.vdot(h, mu*h + g))
            rho = actual/predicted
            del h  # memory management

            if rho > 0:
                # Accept the step and relax the damping (Madsen-Nielsen update).
                recon = recon_trial
                r_LMA = r_LMA2
                mu = mu*np.max([1/3, 1-(2*rho-1)**3])
                v = 2
            else:
                # Reject the step: keep the previous reconstruction and residual,
                # and increase the damping for the next attempt.
                print("Step rejected (rho = {}); keeping the previous reconstruction.".format(rho))
                mu = mu*v
                v = 2*v
            print("Mu is", mu)
            del recon_trial, r_LMA2, data_guess  # memory management

            np.save("recon.npy", recon)
            print("Did iteration {} of {} in LMA.".format(k+1, ni_LMA))
            print("This iteration took {} s. Reconstruction has been saved as recon.npy".format(np.round(time.time()-t_start)))
            k = k + 1
        except KeyboardInterrupt:
            # Save before reporting, so the message below is true even if no LMA
            # iteration has completed yet.
            np.save("recon.npy", recon)
            print("KEYBOARD INTERUPT DETECTED! Current reconstruction is saved as recon.npy")
            raise