<a href="https://colab.research.google.com/github/nikitinvv/ptychodistrib/blob/main/ptychodistrib_admm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the Miniconda3 4.7.12 installer (-c resumes a partial download),
# make it executable and run it non-interactively (-b), installing into
# /usr/local (-p) even if that prefix already exists (-f).
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-4.7.12-Linux-x86_64.sh
!chmod +x Miniconda3-4.7.12-Linux-x86_64.sh
!bash ./Miniconda3-4.7.12-Linux-x86_64.sh -b -f -p /usr/local

In [None]:
# Pin the interpreter of the conda environment at /usr/local to Python 3.7.10
# (-q: quiet, -y: do not ask for confirmation).
!conda install -q -y --prefix /usr/local python=3.7.10 

Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... done

## Package Plan ##

  environment location: /usr/local

  added / updated specs:
    - python=3.7.10


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2021.1.19  |       h06a4308_1         118 KB
    certifi-2020.12.5          |   py37h06a4308_0         141 KB
    conda-4.9.2                |   py37h06a4308_0         2.9 MB
    ld_impl_linux-64-2.33.1    |       h53a641e_7         568 KB
    libffi-3.3                 |       he6710b0_2          50 KB
    ncurses-6.2                |       he6710b0_1         817 KB
    openssl-1.1.1k             |       h27cfd23_0         2.5 MB
    python-3.7.10              |       hdb3f193_0        45.2 MB
    readline-8.1               |       h27cfd23_0         362 KB
    sqlite-3.35.2              |       hdfb4753_0      

In [None]:
import sys

# Packages installed through conda/pip above land in the site-packages of the
# Python 3.7 environment at /usr/local; expose them to the running kernel.
# The membership check keeps the cell idempotent, so re-running it does not
# pile up duplicate entries in sys.path.
CONDA_SITE_PACKAGES = "/usr/local/lib/python3.7/site-packages"
if CONDA_SITE_PACKAGES not in sys.path:
    sys.path.append(CONDA_SITE_PACKAGES)

In [None]:
# Create the working directory; -p makes the cell safe to re-run when the
# directory already exists.
!mkdir -p code

In [None]:
cd code

In [None]:
# Fetch the ptychodistrib sources into ./ptychodistrib.
!git clone https://github.com/nikitinvv/ptychodistrib

Cloning into 'ptychodistrib'...
remote: Enumerating objects: 47, done.[K
remote: Counting objects: 100% (47/47), done.[K
remote: Compressing objects: 100% (39/39), done.[K
remote: Total 47 (delta 5), reused 39 (delta 5), pack-reused 0[K
Unpacking objects: 100% (47/47), done.


In [None]:
# Install dependencies from conda-forge: dxchange is used below for TIFF I/O;
# swig and scikit-build are presumably needed to build the package's compiled
# extension (confirm against the repository's setup.py).
!conda install -y -c conda-forge dxchange swig scikit-build

In [None]:
cd ptychodistrib/

/content/code/ptychodistrib


In [None]:
# Build and install the package from the current directory using the pip found
# on the shell PATH.  NOTE(review): the recorded output shows a cp37 wheel, i.e.
# this pip belongs to the conda Python 3.7, not necessarily to the kernel's
# interpreter, so do not swap this for %pip without checking.
!pip install .

Processing /content/code/ptychodistrib
Building wheels for collected packages: ptychodistrib
  Building wheel for ptychodistrib (setup.py) ... [?25l[?25hdone
  Created wheel for ptychodistrib: filename=ptychodistrib-0.1.0-cp37-cp37m-linux_x86_64.whl size=43370 sha256=c627cf2eb66761ce02c40b79d8d0ed6d27e31f9f3c04727032cd72ca975e13ce
  Stored in directory: /tmp/pip-ephem-wheel-cache-4bw1dl5i/wheels/b9/b3/dd/5a68fea2e231b8335932e5221c3cacf1b41a8daae86059a055
Successfully built ptychodistrib
Installing collected packages: ptychodistrib
Successfully installed ptychodistrib-0.1.0


# Solver class

In [None]:
import cupy as cp
import numpy as np
from ptychodistrib.ptychofft import ptychofft


class SolverPtycho(ptychofft):
    """Ptychography solver class.

    This class is a context manager which provides the basic operators
    required to implement a ptychography solver: the forward and adjoint
    ptychography transforms, a gradient solver for the ptychography
    subproblem, and helpers for an ADMM scheme in which the scan positions
    are split evenly between ``nnodes`` nodes.  Leaving the ``with`` block
    calls ``free()`` so that resources are released on normal exit as well
    as on interruptions.

    Attributes
    ----------
    nz, n : int
        The pixel height and width of the projection.
    nscan : int
        Number of scanning positions
    ndet : int
        Detector size
    nprb : int
        Probe size
    nnodes : int
        Number of nodes
    """

    def __init__(self, nz, n, nscan, ndet, nprb, nnodes):
        """Please see help(SolverPtycho) for more info.

        Raises
        ------
        ValueError
            If ``nnodes`` is not positive or ``nscan`` is not a multiple of
            ``nnodes`` (the scan positions could not be split evenly).
        """
        # Fail with an exception instead of print()+exit(): exit() would tear
        # down the interpreter/kernel rather than report a recoverable error.
        if nnodes < 1 or nscan % nnodes != 0:
            raise ValueError(
                f'nscan ({nscan}) should be a multiple of the number of '
                f'nodes ({nnodes})')
        self.nnodes = nnodes
        # each node handles nscan//nnodes positions; ntheta==1, ngpu==1
        super().__init__(1, nz, n, nscan//nnodes, ndet, nprb, 1)

    def __enter__(self):
        """Return self at start of a with-block."""
        return self

    def __exit__(self, type, value, traceback):
        """Free GPU memory on interruptions or at with-block exit."""
        self.free()

    def fwd_ptycho(self, psi, prb, scan):
        """Ptychography transform (FQ).

        Returns a new complex64 array of shape [self.nscan, ndet, ndet]
        filled by the C wrapper.  The wrapper receives raw device pointers,
        so the inputs are made C-contiguous first.
        NOTE(review): the wrapper presumably expects complex64 psi/prb and
        float32 scan; no dtype conversion is done here - confirm in callers.
        """
        res = cp.zeros([self.nscan, self.ndet, self.ndet], dtype='complex64')
        # convert to C-contiguous arrays if needed
        psi = cp.ascontiguousarray(psi)
        prb = cp.ascontiguousarray(prb)
        scan = cp.ascontiguousarray(scan)
        # run C wrapper
        self.fwd(res.data.ptr, psi.data.ptr,
                 prb.data.ptr, scan.data.ptr, 0)  # igpu = 0
        return res

    def adj_ptycho(self, data, prb, scan):
        """Adjoint ptychography transform (Q*F*).

        Returns a new complex64 array of shape [nz, n] filled by the C
        wrapper; inputs are made C-contiguous because raw device pointers
        are passed.
        """
        res = cp.zeros([self.nz, self.n], dtype='complex64')
        # convert to C-contiguous arrays if needed
        data = cp.ascontiguousarray(data)
        prb = cp.ascontiguousarray(prb)
        scan = cp.ascontiguousarray(scan)
        # run C wrapper
        self.adj(res.data.ptr, data.data.ptr,
                 prb.data.ptr, scan.data.ptr, 0)  # igpu = 0
        return res

    def update_penalty(self, psi, z, z0, rho):
        """Update rho for a faster convergence.

        Compares r = ||psi - z||^2 with s = ||rho*(z - z0)||^2, where z0 is
        z from the previous iteration: rho is doubled when r > 10*s, halved
        when s > 10*r, and returned unchanged otherwise.
        """
        r = cp.linalg.norm(psi - z)**2
        s = cp.linalg.norm(rho*(z-z0))**2
        if (r > 10*s):
            rho *= 2
        elif (s > 10*r):
            rho *= 0.5
        return rho

    def grad_ptycho(self, data, psi, prb, scan, zlamd, rho, niter):
        """Gradient solver for the ptychography problem
        |||FQpsi|-sqrt(data)||^2_2 + rho||psi-zlamd||^2_2

        Runs ``niter`` fixed-step gradient iterations starting from ``psi``
        and returns the updated object.  The penalty term is skipped when
        ``rho`` is not positive.
        """
        # minimization functional (only used by the optional debug print)
        def minf(fpsi, psi):
            f = cp.linalg.norm(cp.abs(fpsi) - cp.sqrt(data))**2
            if(rho > 0):
                f += rho*cp.linalg.norm(psi-zlamd)**2
            return f

        for i in range(niter):
            # compute the gradient
            fpsi = self.fwd_ptycho(psi, prb, scan)

            # 1e-32 guards against division by zero where |fpsi| vanishes
            gradpsi = self.adj_ptycho(
                fpsi - cp.sqrt(data)*fpsi/(cp.abs(fpsi)+1e-32), prb, scan)

            # normalization coefficient for skipping the line search procedure
            afpsi = self.adj_ptycho(fpsi, prb, scan)
            norm_coeff = cp.real(cp.sum(psi*cp.conj(afpsi)) /
                                 (cp.sum(afpsi*cp.conj(afpsi))+1e-32))

            if(rho > 0):
                gradpsi += rho*(psi-zlamd)
                gradpsi *= min(1/rho, norm_coeff)/2
            else:
                gradpsi *= norm_coeff/2
            # update psi
            psi = psi - 0.5*gradpsi
            # check convergence
            # print(f'{i}) {minf(fpsi, psi).get():.2e} ')

        return psi

    def grad_ptycho_batch(self, data, psi, prb, scan, zlamd, rho, piter):
        """Gradient solver with splitting by nodes.

        Node k works on the k-th contiguous chunk of self.nscan scan
        positions (rows of ``data``, columns of ``scan``) and updates its
        own copy psi[k] in place; ``psi`` is returned for convenience.
        """
        for k in range(self.nnodes):
            ids = cp.arange(k*self.nscan, (k+1)*self.nscan)
            psi[k] = self.grad_ptycho(
                data[ids], psi[k], prb, scan[:, ids], zlamd[k], rho, piter)
        return psi

    def take_lagr(self, data, psi, prb, scan, z, lamd, rho):
        """Compute Lagrangian.

        Returns a host (NumPy) float32 array of 4 values:
        [0] data-fidelity term summed over the nodes,
        [1] 2*Re<lamd, psi-z>,
        [2] rho*||psi-z||^2,
        [3] the sum of the three terms above.
        """
        lagr = np.zeros(4, dtype='float32')
        # GPU scalars are converted with float() explicitly so the host
        # accumulator never relies on implicit CuPy -> NumPy conversion.
        for k in range(self.nnodes):
            ids = cp.arange(k*self.nscan, (k+1)*self.nscan)
            lagr[0] += float(cp.linalg.norm(cp.abs(self.fwd_ptycho(
                psi[k], prb, scan[:, ids]))-cp.sqrt(data[ids]))**2)
        lagr[1] = float(2*cp.sum(cp.real(cp.conj(lamd)*(psi-z))))
        lagr[2] = float(rho*cp.linalg.norm(psi-z)**2)
        # lagr lives on the host, so sum it with NumPy rather than cp.sum
        lagr[3] = lagr[:3].sum()
        return lagr


/content/code/ptychodistrib/tests


# **ADMM solver**

In [None]:
import numpy as np
import cupy as cp
import dxchange
from random import sample
import matplotlib.pyplot as plt


n = 384  # object size in x
nz = 384  # object size in z
ndet = 128  # detector size
nprb = 128  # probe size
nscan = 256  # number of scan positions (max 4554)
nnodes = 1  # number of nodes (nscan must be a multiple of nnodes)

# Load object (amplitude and phase are stored as separate TIFF images)
amp = dxchange.read_tiff('data/object_ampe.tiff')
angle = dxchange.read_tiff('data/object_anglee.tiff')
psiinit = amp*np.exp(1j*angle)

# Load probe (amplitude and phase are stored as separate TIFF images)
probe_amp = dxchange.read_tiff('data/probe_amp.tiff')
probe_angle = dxchange.read_tiff('data/probe_angle.tiff')
prb = probe_amp*np.exp(1j*probe_angle)

# Load scan positions
scan = np.load('data/scan.npy')
# pick nscan positions at random, without replacement
# (no seed is set, so the selected subset differs from run to run)
scan = scan[:, sample(range(scan.shape[1]), nscan)]
plt.plot(scan[1], scan[0], 'r.')
plt.savefig('data/scan.png')

# copy to gpu
# The solver passes raw device pointers to its C wrapper and allocates its own
# outputs as complex64, so object and probe are cast explicitly instead of
# relying on whatever dtype the TIFF files happen to produce.
# NOTE(review): scan is copied with its on-disk dtype; confirm it matches what
# the wrapper expects.
psiinit = cp.array(psiinit, dtype='complex64')
prb = cp.array(prb, dtype='complex64')
scan = cp.array(scan)

# compute data with a single node so that all scan positions are processed at once
with SolverPtycho(nz, n, nscan, ndet, nprb, 1) as pslv:
    # data = |FQpsi|^2
    data = cp.abs(pslv.fwd_ptycho(psiinit, prb, scan))**2

# ADMM solver: every node keeps its own object estimate psi[k] and multiplier
# lamd[k]; z is the single consensus object shared by all nodes.
with SolverPtycho(nz, n, nscan, ndet, nprb, nnodes) as pslv:
    # init variables: one psi/lamd per node, one shared z
    psi = cp.ones([nnodes, *psiinit.shape], dtype='complex64')
    z = cp.ones(psiinit.shape, dtype='complex64')
    lamd = cp.zeros([nnodes, *psiinit.shape], dtype='complex64')

    niter = 128  # number of outer iterations
    piter = 4  # number of inner iterations in ptychography

    rho = 0.5  # initial penalty parameter, adapted in the loop below
    for m in range(niter):
        # keep z from the previous iteration for penalty updates
        z0 = z.copy()
        # 1) ptycho problem (many nodes): each psi[k] is pulled towards z-lamd[k]
        psi = pslv.grad_ptycho_batch(
            data, psi, prb, scan, z-lamd, rho, piter)
        # 2) regularization problem (one node): average over the nodes
        z = cp.mean(psi+lamd, axis=0)
        # 3) lambda update
        lamd = lamd + (psi - z)
        # update the penalty rho for a faster convergence
        rho = pslv.update_penalty(psi, z, z0, rho)
        # evaluate and print the Lagrangian terms and their sum for this iteration
        lagr = pslv.take_lagr(data, psi, prb, scan, z, lamd, rho)
        print("%d/%d) rho=%.2e, Lagrangian terms:  %.2e %.2e %.2e, Sum: %.2e" %
              (m, niter, rho, *lagr))

# save phase and amplitude of the consensus object z (copied back to the host)
dxchange.write_tiff(cp.angle(z).get(),
                    'rec_admm/object_angle.tiff', overwrite=True)
dxchange.write_tiff(cp.abs(z).get(),
                    'rec_admm/object_amp.tiff', overwrite=True)