# Convergence comparison of ptycho algorithms

## Introduction

In this notebook, we compare the convergence of three ptychography reconstruction algorithms: admm, combined, and divided. The admm approach solves the nearplane, farplane, probe, and object subproblems jointly using ADMM. The divided approach solves each of the subproblems sequentially, but without regularization. The combined approach solves the object and probe subproblems with the nearplane and farplane problems integrated.

We plot the convergence of each of these algorithms using several metrics:

* primal residuals
* dual residuals
* augmented Lagrangian
* relative error (to previous iteration) for all subproblems

## Convergence metrics

In this section, we define helper functions to compute all the convergence metrics.

In [2]:
import numpy as np

In [3]:
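def constraint_farplane(operator, nearplane, farplane, **kwargs):
    """Return the farplane constraint violation: propagated nearplane minus farplane."""
    return operator.propagation.fwd(nearplane) - farplane

def primal_farplane(operator, nearplane, farplane, **kwargs):
    """Return the squared norm of the farplane constraint violation."""
    return np.linalg.norm(constraint_farplane(operator, nearplane, farplane))**2

def constraint_nearplane(operator, probe, psi, nearplane, scan, **kwargs):
    """Return the nearplane constraint violation: diffracted probe and psi minus nearplane."""
    return operator.diffraction.fwd(probe=probe, psi=psi, scan=scan) - nearplane

def primal_nearplane(operator, probe, psi, nearplane, **kwargs):
    """Return the squared norm of the nearplane constraint violation.

    `scan` is not named here; it must arrive through **kwargs.
    """
    return np.linalg.norm(constraint_nearplane(operator, probe, psi, nearplane, **kwargs))**2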
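
The primal residual defined above is the squared norm of a constraint violation, so it should vanish when the farplane is exactly the propagated nearplane. The sketch below checks that on a toy problem. `ToyPropagation` and `ToyOperator` are hypothetical stand-ins that use an orthonormal FFT; they are not the `tike` operators, and whether the real propagation matches this is an assumption.

```python
import numpy as np


class ToyPropagation:
    """Hypothetical stand-in for operator.propagation (orthonormal 2D FFT)."""

    def fwd(self, nearplane):
        return np.fft.fft2(nearplane, norm='ortho')


class ToyOperator:
    """Hypothetical operator exposing only the attribute used below."""

    propagation = ToyPropagation()


def primal_farplane(operator, nearplane, farplane, **kwargs):
    # Same definition as in the notebook: squared norm of the violation.
    return np.linalg.norm(operator.propagation.fwd(nearplane) - farplane)**2


rng = np.random.default_rng(0)
nearplane = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
toy = ToyOperator()

# A farplane that satisfies the constraint exactly.
consistent = toy.propagation.fwd(nearplane)
residual_consistent = primal_farplane(toy, nearplane, consistent)

# Offsetting all 16 entries by 1 violates the constraint by a known amount.
residual_offset = primal_farplane(toy, nearplane, consistent + 1.0)

print(residual_consistent, residual_offset)
```

The offset case gives a residual of 16 because each of the 16 entries is off by exactly 1.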

In [4]:
# Are the dual residuals the same as the relative error for farplane and nearplane?
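
One way to approach this question: in the standard ADMM analysis, for a constraint of the form x - z = 0, the dual residual at iteration k+1 is -ρ (z^{k+1} - z^k). It is therefore proportional to the change in z between iterations, but it is not normalized the way a relative error is. The sketch below computes its squared norm. The function name `dual_residual` and the argument `penalty` are hypothetical, and how this maps onto the farplane and nearplane constraints here (including the exact penalty scaling) is an assumption, not something the notebook confirms.

```python
import numpy as np


def dual_residual(z0, z1, penalty):
    """Squared norm of the dual residual for a constraint x - z = 0.

    z0 and z1 are consecutive iterates of z; penalty is the ADMM penalty
    parameter. This follows the textbook form, not a tike API.
    """
    return penalty**2 * np.linalg.norm(z1 - z0)**2


z_previous = np.zeros(3)
z_current = np.ones(3)
value = dual_residual(z_previous, z_current, penalty=0.5)
print(value)
```

Under these assumptions the two quantities would track each other up to the penalty factor and the normalization by the norm of the previous iterate.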

In [5]:
# Measured diffraction data; read as a global by augmented_lagrangian below.
data_frames = np.load('data.npy')

def augmented_lagrangian(operator, farplane, nearplane, psi, probe, λ, μ, ρ, τ, **kwargs):
    return (
        + operator.propagation.cost(data_frames, farplane)
        + 2 * np.sum(np.real(np.conj(λ) * constraint_farplane(operator, nearplane, farplane, **kwargs)))
        + ρ * primal_farplane(operator, nearplane, farplane, **kwargs)
        + 2 * np.sum(np.real(np.conj(μ) * constraint_nearplane(operator, probe, psi, nearplane, **kwargs)))
        + τ * primal_nearplane(operator, probe, psi, nearplane, **kwargs)
    )
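
Written out, the function above appears to evaluate the following expression. The symbols are assumptions introduced only for this note: $d$ is `data_frames`, $\Phi$ is `farplane`, $\varphi$ is `nearplane`, $p$ is `probe`, $\psi$ is `psi`, $f$ is `operator.propagation.cost`, $\mathcal{P}$ is `operator.propagation.fwd`, $\mathcal{Q}$ is `operator.diffraction.fwd`, and $\langle a, b \rangle = \sum_j \bar{a}_j b_j$.

$$
L_{\rho,\tau} = f(d, \Phi)
+ 2\,\mathrm{Re}\,\langle \lambda, \mathcal{P}\varphi - \Phi \rangle
+ \rho\,\lVert \mathcal{P}\varphi - \Phi \rVert_2^2
+ 2\,\mathrm{Re}\,\langle \mu, \mathcal{Q}(p, \psi) - \varphi \rangle
+ \tau\,\lVert \mathcal{Q}(p, \psi) - \varphi \rVert_2^2
$$

The two squared-norm terms are the primal residuals `primal_farplane` and `primal_nearplane`, weighted by ρ and τ.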

In [6]:
def relative_error(x0, x1):
    return np.linalg.norm(x0 - x1)**2 / np.linalg.norm(x0)**2
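
As a quick sanity check of the definition above, the sketch below applies `relative_error` to two small vectors chosen so the arithmetic is easy to follow by hand. The vectors are illustrative values, not data from the reconstructions.

```python
import numpy as np


def relative_error(x0, x1):
    # Same definition as in the notebook: squared change over squared norm.
    return np.linalg.norm(x0 - x1)**2 / np.linalg.norm(x0)**2


x0 = np.array([3.0, 4.0])  # squared norm is 25
x1 = np.array([3.0, 0.0])  # differs from x0 by (0, 4), squared norm 16

unchanged = relative_error(x0, x0)
changed = relative_error(x0, x1)
print(unchanged, changed)
```

Here the second value is 16/25. Note that the metric is normalized by the previous iterate `x0`, so it is undefined when `x0` is all zeros.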

## Compute the metrics

In this section, we load the saved reconstructions from disk and compute the metrics for each one.

In [7]:
import tike.operators
op = tike.operators.Ptycho(
    detector_shape=15*2,
    probe_shape=15,
    nscan=21*21,
    nz=128,
    n=128,
)

In [8]:
import glob
from collections import defaultdict

import matplotlib.pyplot as plt
from tike.view import plot_phase

import tifffile

all_metrics = {}

for algorithm in ['admm', 'combined', 'divided']:
    metrics = defaultdict(list)
    files = sorted(glob.glob(f'{algorithm}.*.npz'))
    wall_time = 0
    data0 = []
    for i in range(len(files)):
        print(files[i])
        ii = int(files[i].split('.')[1])
        metrics['i'].append(ii)
        
        data = np.load(files[i])
        
#         tifffile.imsave(f'cg4/p/{algorithm}.{ii:03d}.tiff', np.angle(data['psi'][0]).astype('float32'))
#         tifffile.imsave(f'cg4/a/{algorithm}.{ii:03d}.tiff', np.abs(data['psi'][0]).astype('float32'))
        
        if i > 0:
            wall_time += data['time']
        metrics['time'].append(wall_time)
        
        if 'farplane' in data and 'nearplane' in data:
            metrics['primal_farplane'].append(primal_farplane(op, **data))
        else:
            metrics['primal_farplane'].append(None)
        
        if 'nearplane' in data:
            metrics['primal_nearplane'].append(primal_nearplane(op, **data))
        else:
            metrics['primal_nearplane'].append(None)
        
        if 'λ' in data and 'farplane' in data:
            metrics['augmented_lagrangian'].append(augmented_lagrangian(op, ρ=0.5, τ=0.5, **data))
        else:
            metrics['augmented_lagrangian'].append(None)
        
        for param in ['psi', 'probe']:
            if param in data and param in data0:
                metrics[f'Δ{param}'].append(relative_error(data0[param], data[param]))
            else:
                metrics[f'Δ{param}'].append(None)

        data0 = data
        
#         plt.figure()
#         plot_phase(data['psi'][0])
#         plt.show()

    all_metrics[algorithm] = metrics

data0.close()
data.close()

admm.000.npz
admm.001.npz
admm.002.npz
admm.003.npz
admm.004.npz
admm.005.npz
admm.006.npz
admm.007.npz
admm.008.npz
admm.009.npz
admm.010.npz
admm.012.npz
admm.013.npz
admm.015.npz
admm.016.npz
admm.018.npz
admm.021.npz
admm.023.npz
admm.026.npz
admm.029.npz
admm.033.npz
admm.037.npz
admm.041.npz
admm.046.npz
admm.052.npz
admm.058.npz
admm.065.npz
admm.073.npz
admm.082.npz
admm.092.npz
admm.103.npz
admm.115.npz
admm.129.npz
admm.145.npz
admm.162.npz
admm.182.npz
admm.204.npz
admm.228.npz
admm.256.npz
combined.000.npz
combined.001.npz
combined.002.npz
combined.003.npz
combined.004.npz
combined.005.npz
combined.006.npz
combined.007.npz
combined.008.npz
combined.009.npz
combined.010.npz
combined.012.npz
combined.013.npz
combined.015.npz
combined.016.npz
combined.018.npz
combined.021.npz
combined.023.npz
combined.026.npz
combined.029.npz
combined.033.npz
combined.037.npz
combined.041.npz
combined.046.npz
combined.052.npz
combined.058.npz
combined.065.npz
combined.073.npz
combined.082.npz


KeyError: 'time is not a file in the archive'
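
The run above stops with a `KeyError`, apparently while reading `combined.082.npz` (the last file printed), because that archive has no `time` entry. A minimal sketch of one possible guard follows. The helper name `elapsed_time` and its `default` argument are hypothetical, and the in-memory archives only imitate the saved files.

```python
import io

import numpy as np


def elapsed_time(archive, default=float('nan')):
    """Return the 'time' entry of an archive, or `default` if it is missing."""
    if 'time' in archive:
        return float(archive['time'])
    return default


def make_archive(**arrays):
    """Build an in-memory .npz archive for demonstration purposes."""
    buffer = io.BytesIO()
    np.savez(buffer, **arrays)
    buffer.seek(0)
    return np.load(buffer)


with make_archive(psi=np.ones(2), time=1.5) as with_time:
    time_present = elapsed_time(with_time)

with make_archive(psi=np.ones(2)) as without_time:
    time_missing = elapsed_time(without_time)

print(time_present, time_missing)
```

Returning NaN keeps the gap visible, but it also propagates through a running sum such as `wall_time`; passing `default=0.0` would instead silently skip the missing interval. Which behavior is appropriate depends on how the timing plots are meant to be read.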

## Plot the metrics

In [None]:
import matplotlib.pyplot as plt

for key in all_metrics['divided'].keys():
    if key in ['i', 'time']:
        continue
    plt.figure()
    ax1 = plt.subplot(1, 2, 1)
    ax2 = plt.subplot(1, 2, 2)
    l = []
    for algorithm in ['admm', 'divided', 'combined']:
        if key in all_metrics[algorithm]:
            ax1.semilogy(
                np.array(all_metrics[algorithm]['time']) / 60, all_metrics[algorithm][key],
            )
            ax2.semilogy(
                np.array(all_metrics[algorithm]['i']), all_metrics[algorithm][key],
            )
            l.append(algorithm)
    plt.legend(l)
    plt.suptitle(key)
    ax1.set_xlabel('time [min]')
    ax2.set_xlabel('iteration')
    plt.tight_layout()
    plt.subplots_adjust(top=0.9)
    plt.savefig(f'{key}.svg')
plt.show()