# Notebook to visualise the ptychography result
Created by Jinseok Ryu (jinseok.ryu@diamond.ac.uk)

In [None]:
import glob
import json
import os
import sys

import h5py as h5
import hyperspy.api as hs
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndimage
import tifffile

# Load data

In [None]:
# Path of the reconstruction file to visualise (fill in before running).
img_adr = ''
# The metadata file shares the image path with the extension replaced by
# ".json". splitext is used instead of slicing off three characters so that
# extensions of any length (.h5, .hdf, .hdf5, ...) map to the same sidecar.
meta_adr = os.path.splitext(img_adr)[0] + '.json'
print(img_adr)
print(meta_adr)
# File name without directory, cut at the first dot.
data_name = img_adr.split("/")[-1].split('.')[0]
print(data_name)

In [None]:
# Load the reconstruction.
# First attempt: PtyREX layout (HDF5 result + JSON metadata next to it).
# If anything in that attempt fails, fall back to loading the file with
# HyperSpy and read the geometry from the signal metadata instead.
try:
    with h5.File(img_adr, 'r') as f:
        output = f["entry_1"]['process_1']['output_1']
        img = np.squeeze(output['object'][:])
        probe_img = np.squeeze(output['probe'][:])
        R_pixel_size = f['/entry_1/process_1/common_1/dx'][0][0]

    print(img.shape)
    print(probe_img.shape)

    # Reduce the probe stack to a single 2-D image.
    # NOTE(review): the meaning of the leading axes (modes / slices /
    # iterations) is not visible here -- confirm against the PtyREX output.
    if probe_img.ndim == 4:
        probe_img = np.sum(probe_img, axis=0)
    elif probe_img.ndim == 5:
        probe_img = np.sum(probe_img, axis=0)
        probe_img = probe_img[:, -1]

    # Only index when there is still a leading axis; an already 2-D probe
    # would otherwise be cut down to its last row.
    if probe_img.ndim > 2:
        probe_img = probe_img[-1]
    print(probe_img.shape)

    with open(meta_adr, 'r') as f:
        metadata = json.load(f)

    common = metadata['process']['common']
    HT = common['source']['energy'][0]
    defocus_exp = metadata['experiment']['optics']['lens']['defocus'][0]
    scan_step = common['scan']['dR'][0]
    Ry_extent_px = common['scan']['N'][0]
    Rx_extent_px = common['scan']['N'][1]
    Ry_extent = Ry_extent_px * scan_step
    Rx_extent = Rx_extent_px * scan_step
    rotation = common['scan']['rotation']
    print(Ry_extent, Rx_extent, scan_step)

    package = "ptyrex"
    ptyrex_flag = True

except Exception as err:
    # Exception (not a bare except) so Ctrl-C still interrupts the cell; the
    # reason is reported so a broken PtyREX file is not silently mistaken for
    # a HyperSpy one.
    print(f"PtyREX loading failed ({type(err).__name__}: {err}); falling back to hs.load")

    img = hs.load(img_adr)
    # NOTE(review): eval() executes whatever string is stored in the file
    # metadata. Only open files from a trusted source; consider
    # ast.literal_eval if the stored values are plain literals.
    Ry_extent = eval(img.metadata["Ry_extent"])
    Rx_extent = eval(img.metadata["Rx_extent"])
    R_pixel_size = eval(img.metadata["R_pixel_size"])
    print(Ry_extent, Rx_extent, R_pixel_size)

    package = img.metadata["package"]
    recon_type = img.metadata["reconstruction_type"]
    print(package, recon_type)

    img = img.data
    print(img.shape)
    print(img.dtype)
    print(np.min(img), np.max(img))

    ptyrex_flag = False

# Object

In [None]:
# Size of the scanned area expressed in reconstruction pixels (x first, then y)
Rx_limit_ind, Ry_limit_ind = (
    int(np.around(extent/R_pixel_size)) for extent in (Rx_extent, Ry_extent)
)
print(Rx_limit_ind, Ry_limit_ind)

rotation = False # correct the scan rotation
if rotation:
    # The same value doubles as the rotation angle passed to ndimage.rotate
    img = ndimage.rotate(img, angle=rotation, axes=(-2, -1), reshape=False)

correct_result = True # cropping

# center cropping is used for py4dstem and ptyrex results only
center = package in ("py4dstem", "ptyrex")

if center:
    # Middle of the last two image axes, then step back half of the crop size
    yc_ind, xc_ind = (int(np.around(n_px/2)) for n_px in img.shape[-2:])

    top_ind = yc_ind - int(np.around(Ry_limit_ind/2))
    left_ind = xc_ind - int(np.around(Rx_limit_ind/2))

In [None]:
%matplotlib widget
# default -> result of the last iteration
iter_selected = -1

# A 4-D array carries an extra leading axis that is indexed by iteration;
# picking one entry leaves at most a 3-D stack. 3-D and 2-D inputs are used
# as they are.
if img.ndim == 4:
    selected = img[iter_selected]
else:
    selected = img

# Complex reconstructions are displayed as phase; real ones are copied so the
# loaded array is left untouched.
if np.iscomplexobj(selected):
    result_slice = np.angle(selected)
else:
    result_slice = selected.copy()

if correct_result:
    # The ellipsis keeps any leading axes and crops only the last two (y, x),
    # so the same expression serves 2-D and 3-D data. The 2-D, non-centred
    # case previously indexed a single row (missing ':') instead of cropping.
    if center:
        result_slice = result_slice[..., top_ind:top_ind+Ry_limit_ind, left_ind:left_ind+Rx_limit_ind]
    else:
        result_slice = result_slice[..., :Ry_limit_ind, :Rx_limit_ind]

# Wrap in a HyperSpy signal and calibrate both image axes with the pixel size
result_slice = hs.signals.Signal2D(result_slice)
result_slice.axes_manager[-1].scale = R_pixel_size
result_slice.axes_manager[-1].unit = 'Å'
result_slice.axes_manager[-2].scale = R_pixel_size
result_slice.axes_manager[-2].unit = 'Å'
result_slice.plot(cmap="inferno")

In [None]:
%matplotlib inline
# Close the open figures, then draw a static overview in which any leading
# axes of the object are collapsed by summation.
plt.close("all")
if result_slice.data.ndim == 3:
    result_slice.sum(axis=0).plot(cmap="inferno")

elif result_slice.data.ndim == 4:
    result_slice.sum(axis=(0, 1)).plot(cmap="inferno")

else:
    # Nothing to collapse: plot the signal as it is
    result_slice.plot(cmap="inferno")

In [None]:
# Write the displayed object to a TIFF in the current working directory.
# data_name[:-20] drops the last 20 characters of the file name --
# presumably a fixed-length suffix added by the reconstruction; TODO confirm.
tifffile.imwrite(f"{data_name[:-20]}_object.tif", result_slice.data)

# Probe

In [None]:
# Obtain a single probe image and decide where it will be saved.
if package in ('abtem_legacy', 'py4dstem'):
    # Both packages keep the probe in a separate "...probe.hspy" file whose
    # name is built by replacing the last 10 characters of the image path.
    probe_adr = img_adr[:-10]+"probe.hspy"
    probe_img = hs.load(probe_adr).data
    print(probe_img.shape)
    print(probe_img.dtype)

    if package == 'abtem_legacy':
        # Pick single entries along the leading axes
        if probe_img.ndim == 4:
            probe_img = probe_img[-1, 0]
        elif probe_img.ndim == 3:
            probe_img = probe_img[0]
    else:
        # py4dstem: sum over one axis instead of picking an entry from it
        if probe_img.ndim == 4:
            probe_img = np.sum(probe_img, axis=1)[-1]
        elif probe_img.ndim == 3:
            probe_img = np.sum(probe_img, axis=0)

    probe_save = img_adr[:-10]+"probe.tif"

else:
    # Any other package: reuse the probe_img that is already in memory
    print(probe_img.shape)
    print(probe_img.dtype)
    probe_save = img_adr[:-4]+"_probe.tif"

print("final shape")
print(probe_img.shape)

In [None]:
%matplotlib widget
# Display the magnitude of the probe
result_probe = np.absolute(probe_img)
print(np.shape(result_probe))

# Wrap the array as a HyperSpy 2-D signal for plotting
result_probe = hs.signals.Signal2D(result_probe)
result_probe.plot(cmap="inferno")

In [None]:
# Write the probe magnitude held in result_probe to the probe_save path
tifffile.imwrite(probe_save, result_probe.data)

# Depth view

In [None]:
%matplotlib inline
# Depth view: sample the object stack along a line drawn on the image and
# show the values found under that line for every slice.
sx, sy = 80, 320  # line start (x, y) in pixels
ex, ey = 160, 320  # line end (x, y) in pixels

if ex != sx:
    # One sample per x pixel, y from the straight-line equation
    slope = (ey-sy) / (ex-sx)
    b = sy - slope*sx
    x_point = np.arange(sx, ex, 1)
    y_point = x_point * slope + b
else:
    # Vertical line: the slope is undefined, so step along y instead
    y_point = np.arange(sy, ey, 1)
    x_point = np.full_like(y_point, sx)

# Round to the nearest pixel; plain truncation would pull values such as
# 319.9999 (floating-point error in the line equation) down to 319.
x_point = np.rint(x_point).astype(int)
y_point = np.rint(y_point).astype(int)

# The object prepared in the "Object" section; the first axis is read as depth
volume = result_slice.data
if volume.ndim != 3:
    raise ValueError(f"depth view needs a 3-D object stack, got shape {volume.shape}")

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(np.sum(volume, axis=0), cmap="gray")
ax[0].plot(x_point, y_point, 'r-')

# Integer-array indexing picks every (y, x) pair in one step; the transpose
# gives one row per point along the line and one column per slice.
depth_profile = volume[:, y_point, x_point].T
print(depth_profile.shape)

ax[1].imshow(depth_profile.T, cmap='inferno', aspect='auto')
ax[1].axis("off")

fig.tight_layout()
plt.show()