In [None]:
# Standard library
import sys
import time
import warnings
from types import SimpleNamespace

# Third-party
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Silence warnings whose message contains "peer".
# NOTE(review): presumably CuPy GPU peer-access warnings in multi-GPU runs — confirm.
warnings.filterwarnings("ignore", message=".*peer.*")

# Project modules live one directory up; the path must be inserted before importing them.
sys.path.insert(0, '..')
from utils import *  # presumably provides read_tiff, mshow, mshow_polar used below
from rec import Rec

In [None]:
# Geometry for the code-only propagator built in the next cell; these names are
# redefined in the "Model data" section.
energy = 33.5  # [keV] NOTE(review): the model-data cell uses 33.35 keV and hc = 1.2398419840550367e-09 — confirm which is intended here
wavelength = 1.24e-09 / energy  # [m]
z1 = -17.75e-3  # [m] position of the sample
detector_pixelsize = 3.03751e-6
focusToDetectorDistance = 1.28  # [m]
# adjustments for the cone beam (Fresnel scaling: distance = z1*z2 / (z1 + z2))
z2 = focusToDetectorDistance - z1
distance = (z1 * z2) / focusToDetectorDistance
magnification = focusToDetectorDistance / z1
# z1 < 0 makes the magnification negative, hence abs(). This is plain Python float
# arithmetic, so there is no need to launch a CuPy kernel for it.
voxelsize = abs(detector_pixelsize / magnification)

In [None]:
# Single-chunk propagator on the full 8704 x 8704 grid. In this notebook it is only used
# for cl_rec.D on the code (code_data below) before cl_rec is rebuilt for the model data.
args = SimpleNamespace()

args.ngpus = 1  # int(sys.argv[1])
args.n = 8704
args.ncode = 8704
args.npsi = 8704
args.pad = 0
args.nq = 8704
args.ex = 0
args.npatch = 8704
args.npos = 1
args.nchunk = 1

args.voxelsize = voxelsize
args.wavelength = wavelength
args.distance = distance
args.distancec = distance
# Effective value; an earlier eps = 1e-8 assignment was always overridden by this one.
args.eps = 1e-12
args.show = True

# the settings below do not matter for the code propagation
args.lam = 0
args.rho = [1, 0.0, 0.0]
args.crop = 0
args.path_out = ""
args.niter = 2049
args.err_step = 1
args.vis_step = -1
args.method = "BH-CG"

# create class
cl_rec = Rec(args)

In [None]:
# Directory holding the reconstructed code/probe TIFF snapshots, and the snapshot index.
path_code = "/data/vnikitin/ESRF/ID16A/20240924_rec0224/SiemensLH/code2um_nfp18x18_01/bets_final_256_0.1_2"
iter = 768  # NOTE(review): shadows the builtin iter(); kept because the next cell formats it

In [None]:
# Load the reconstructed code (rec_psi_*) and probe (rec_prb_*) for iteration `iter`.
cthickness = 1  # multiplies the code's log-transmission (1 = as reconstructed)

code_abs = read_tiff(f"{path_code}/rec_psi_abs/{iter:04}.tiff")
code_angle = read_tiff(f"{path_code}/rec_psi_angle/{iter:04}.tiff")
# exp(c * (log|T| + i*phase)): scales both absorption and phase of the code by c.
code = np.exp(cthickness * (np.log(code_abs) + 1j * code_angle))
mshow_polar(code, args.show)

q_abs = read_tiff(f"{path_code}/rec_prb_abs/{iter:04}.tiff")
q_angle = read_tiff(f"{path_code}/rec_prb_angle/{iter:04}.tiff")
q = q_abs * np.exp(1j * q_angle)
mshow_polar(q, args.show)


In [None]:
# Intensity of the code alone (no probe, no object) after cl_rec.D, copied back to host.
# To skip the full-size propagation, a previously saved copy can be used instead:
#     code_data = np.load("code_data.npy")
code_data = (cp.abs(cl_rec.D(cp.array(code[np.newaxis]))[0]) ** 2).get()
mshow(code_data, args.show, vmax=3)

In [None]:
# Same contrast window [0.7, 1.7] for the full field and for a 500 x 500 crop.
mshow(code_data, args.show, vmin=0.7, vmax=1.7)
mshow(code_data[2000:2500, 2000:2500], args.show, vmin=0.7, vmax=1.7)

## Model data



In [None]:
# --- Beam ---
energy = 33.35  # [keV] xray energy
wavelength = 1.2398419840550367e-09 / energy  # [m] wave length

# --- Detector ---
detector_pixelsize = 3.03751e-6
focusToDetectorDistance = 1.28  # [m]

# --- Plane positions [m] (presumably measured from the focus) ---
z1c = -17.75e-3  # code/probe plane
sx0 = 1.286e-3
z1 = 5.5e-3 - sx0  # sample plane
z2 = focusToDetectorDistance - z1  # sample -> detector

# Cone-beam to parallel-beam conversion for the sample -> detector propagation.
distance = (z1 * z2) / focusToDetectorDistance
magnifications = focusToDetectorDistance / z1
voxelsize = np.abs(detector_pixelsize / magnifications)  # object voxel size

# Geometric magnification between the code/probe plane and the sample plane (z1 / z1c).
magnifications2 = z1 / z1c
# Code -> sample propagation distance. This equals (z1 - z1c) * magnifications2.
# NOTE(review): the original trailing comment suggested dividing by magnifications2;
# confirm which plane's coordinates Dc expects.
distancec = (z1 - z1c) / (z1c / z1)

show = True
print(distance, distancec)

In [None]:
# Reconstruction settings for the simulated dataset.
args = SimpleNamespace(
    npos=8,  # placeholder; replaced by the number of selected shifts in the loading cell
    ngpus=4,
    lam=0.2,
    n=2048,
    ncode=8192,
)

# Grid sizes derived from the detector size n.
args.pad = args.n // 8
args.npsi = args.n + 2 * args.pad
args.nq = args.n + 2 * args.pad
args.ex = 8
args.npatch = args.nq + 2 * args.ex

# Remaining settings, in the same order as before. __dict__ is used (not vars()) because
# the name `vars` is rebound to a dict later in this notebook.
args.__dict__.update(
    nchunk=2,
    voxelsize=voxelsize,
    wavelength=wavelength,
    distance=distance,
    distancec=distancec,
    eps=1e-12,
    rho=[1, 0, 0],
    crop=args.npsi // 2,  # alternative tried before: 2 * args.pad
    path_out=f"/data/vnikitin/ESRF/ID16A/20240924_rec0224/SiemensLH/SiemensLH_010nm_code2um_nfp9x9_01/syn_{args.lam}_{cthickness}",
    niter=2049,
    err_step=1,
    vis_step=32,
    method="BH-CG",
    show=True,
)

In [None]:
import scipy.ndimage as ndimage

# Ground-truth object: upsample the stored psi 4x and keep the central npsi x npsi window.
psi = ndimage.zoom(np.load("psi.npy"), 4)
# Center each axis on its own length so a non-square psi is cropped correctly too.
psi = psi[
    psi.shape[0] // 2 - args.npsi // 2 : psi.shape[0] // 2 + args.npsi // 2,
    psi.shape[1] // 2 - args.npsi // 2 : psi.shape[1] // 2 + args.npsi // 2,
]
# Divide the phase by m (weaker phase object); the amplitude is unchanged.
m = 10
psi = np.abs(psi) * np.exp(1j * np.angle(psi) / m)

# Seeds the global NumPy RNG used for the Poisson noise in the data-simulation cell.
np.random.seed(100)
# Integer code shifts, kept where |shifts_code[:, 0]| < 670 and |shifts_code[:, 1]| < 650.
shifts_code = np.load("shifts_code.npy").astype('int32')
ids = np.where((np.abs(shifts_code[:, 0]) < 670) & (np.abs(shifts_code[:, 1]) < 650))[0]
print(len(ids))

args.npos = len(ids)

# Append the suffix only once: an unconditional += grows path_out on every re-run of this cell.
path_suffix = f'_{m}_{args.npos}_n1000'
if not args.path_out.endswith(path_suffix):
    args.path_out += path_suffix

shifts_code = shifts_code[ids]
# NOTE: shifts_code is int32 here, so adding random sub-pixel offsets to it in place raises
# a casting error; cast to float32 first if perturbed shifts are wanted.

ri = shifts_code.astype("int32")
r = (shifts_code - ri).astype("float32")  # fractional part (zero for int32 shifts)
print(psi.shape)
mshow_polar(psi, show)


# create class
cl_rec = Rec(args)

In [None]:
# Disabled: zero-pad the probe q to args.nq x args.nq.
# NOTE(review): assumes q is square and no larger than args.nq — confirm before enabling.
# qpad = args.nq//2-q.shape[-1]//2
# q = np.pad(q,((qpad,qpad),(qpad,qpad)))
# mshow_polar(q,True)

In [None]:
# Forward model: shifted code (S) times probe, propagated with Dc (code -> sample plane,
# presumably over args.distancec), multiplied by the object psi, then propagated with D.
t = cl_rec.S(ri, r, code)
t = cl_rec.Dc(t * q) * psi
t = cl_rec.D(t)
data = np.abs(t) ** 2

# Reference intensity: the probe alone through both propagators.
t = cl_rec.Dc(q[np.newaxis])
t = cl_rec.D(t)[0]
ref = np.abs(t) ** 2

# Poisson noise at 1000 counts per unit intensity (global RNG seeded in the loading cell).
ndata = np.random.poisson(data * 1000).astype('float32') / 1000
# Plot the noise realization BEFORE overwriting `data`; after `data = ndata` the
# difference data[0] - ndata[0] would be identically zero.
mshow(data[0] - ndata[0], show, vmax=0.3, vmin=-0.3)
mshow(ref, show)
data = ndata

In [None]:
# Flat-field-normalized data; only referenced by the disabled per-position Paganin cell.
rdata = np.divide(data, ref)

In [None]:
# Normalize the data by the intensity of the shifted code * probe propagated without the object.
# ri / r are recomputed exactly as in the loading cell.
ri = shifts_code.astype("int32")
r = (shifts_code - ri).astype("float32")
scode = cl_rec.S(ri, r, code)
Dscode = cl_rec.Dc(scode * q)
DDscode = cl_rec.D(Dscode)

# Pass the display flag `show`; passing the `mshow` function only worked because it is truthy.
mshow_polar(DDscode[0], show)
rrdata = data / np.abs(DDscode) ** 2
mshow(rrdata[0], show, vmax=1.3, vmin=0.7)
# np.mean handles NumPy arrays and dispatches to CuPy for CuPy arrays.
mshow(np.mean(rrdata, axis=0), show, vmax=1.1, vmin=0.8)

In [None]:
def Paganin(data, wavelength, voxelsize, delta_beta, alpha, distance=None):
    """Single-distance Paganin phase retrieval with a regularized inverse filter.

    Parameters
    ----------
    data : cupy.ndarray
        Normalized intensity; the filter acts on the last two axes.
    wavelength : float
        X-ray wavelength, in the same length unit as `voxelsize` and `distance`.
    voxelsize : float
        Pixel size of `data`.
    delta_beta : float
        Ratio delta/beta of the material.
    alpha : float
        Regularization added to the squared filter in the denominator.
    distance : float, optional
        Propagation distance. When omitted, the notebook-level global ``distance``
        is used, as in the original version of this function.

    Returns
    -------
    cupy.ndarray
        ``delta_beta / 2 * log(filtered intensity)``, same shape as `data`.
    """
    if distance is None:
        # Backward compatibility: callers relied on the notebook-level `distance`.
        distance = globals()["distance"]
    # Separate frequency axes so non-square images are handled; identical for square ones.
    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype("float32")
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype("float32")
    [fx, fy] = cp.meshgrid(fx, fy)
    rad_freq = cp.fft.fft2(data)
    taylorExp = 1 + wavelength * distance * cp.pi * (delta_beta) * (fx**2 + fy**2)
    # taylorExp / (taylorExp**2 + alpha) is a regularized 1 / taylorExp.
    numerator = taylorExp * (rad_freq)
    denominator = taylorExp**2 + alpha
    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    phase = delta_beta * 0.5 * phase
    return phase

# Initial object guess: Paganin phase of the position-averaged, code-normalized data.
rrdata0 = np.mean(rrdata, axis=0)
rrdata0 /= np.mean(rrdata0)
rrdata0 = cp.array(np.minimum(rrdata0, 1.5))  # cap bright outliers at 1.5

psi_init = Paganin(rrdata0, wavelength, voxelsize, 4.05, 2e-2)
# Zero-pad from n x n to npsi x npsi (same margin on every side), then form exp(i*phase).
psi_init = cp.pad(psi_init, args.npsi // 2 - args.n // 2)
psi_init = cp.exp(1j * psi_init)
mshow_polar(psi_init, args.show)

# smooth borders: Gaussian low-pass in Fourier space (frequencies normalized to [-0.5, 0.5))
v = cp.arange(-args.npsi // 2, args.npsi // 2) / args.npsi
[vx, vy] = cp.meshgrid(v, v)
v = cp.exp(-100 * (vx**2 + vy**2)).astype("float32")

psi_init = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(psi_init)))
psi_init = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(psi_init * v))).astype("complex64")
mshow_polar(psi_init, args.show)

# Drop references to large temporaries.
rdata = v = []

In [None]:
# Disabled alternative initialization, kept for reference (nothing in this cell executes):
#   - a copy of the Paganin filter defined in the previous cell,
#   - rec_init: per-position Paganin (delta/beta 25.05, alpha 5e-7) averaged over overlaps,
#   - a whole-field variant (delta/beta 25.05, alpha 5e-4) with a stronger Gaussian (-1000).
# NOTE(review): consider deleting this cell or moving its parameters to a markdown note.
# def Paganin(data, wavelength, voxelsize, delta_beta, alpha):
#     fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype("float32")
#     [fx, fy] = cp.meshgrid(fx, fx)
#     rad_freq = cp.fft.fft2(data)
#     taylorExp = 1 + wavelength * distance * cp.pi * (delta_beta) * (fx**2 + fy**2)
#     numerator = taylorExp * (rad_freq)
#     denominator = taylorExp**2 + alpha
#     phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
#     phase = delta_beta * 0.5 * phase
#     return phase


# def rec_init(rdata):
#     recMultiPaganin = cp.zeros([args.npsi, args.npsi], dtype="float32")
#     recMultiPaganinr = cp.zeros(
#         [args.npsi, args.npsi], dtype="float32"
#     )  # to compensate for overlap
#     for j in range(0, args.npos):
#         r = cp.array(rdata[j])
#         r = Paganin(r, wavelength, voxelsize, 25.05, 5e-7)
#         rr = r * 0 + 1  # to compensate for overlap
#         rpsi = cp.zeros([args.npsi, args.npsi], dtype="float32")
#         rrpsi = cp.zeros([args.npsi, args.npsi], dtype="float32")
#         stx = args.npsi // 2 - args.n // 2
#         endx = stx + args.n
#         sty = args.npsi // 2 - args.n // 2
#         endy = sty + args.n
#         rpsi[sty:endy, stx:endx] = r
#         rrpsi[sty:endy, stx:endx] = rr

#         recMultiPaganin += rpsi
#         recMultiPaganinr += rrpsi

#     recMultiPaganinr[np.abs(recMultiPaganinr) < 5e-2] = 1
#     recMultiPaganin /= recMultiPaganinr
#     recMultiPaganin = np.exp(1j * recMultiPaganin)
#     return recMultiPaganin


# print(wavelength, voxelsize)
# rrdata0 = np.mean(rrdata, axis=0)
# rrdata0 /= np.mean(rrdata0)
# rrdata0 = cp.array(rrdata0)
# psi_init = Paganin(rrdata0, wavelength, voxelsize, 25.05, 5e-4)
# psi_init = np.pad(
#     psi_init,
#     (
#         (args.npsi // 2 - args.n // 2, args.npsi // 2 - args.n // 2),
#         (args.npsi // 2 - args.n // 2, args.npsi // 2 - args.n // 2),
#     ),
# )
# psi_init = np.exp(1j * psi_init)
# mshow_polar(psi_init, args.show)

# # smooth borders
# v = cp.arange(-args.npsi // 2, args.npsi // 2) / args.npsi
# [vx, vy] = cp.meshgrid(v, v)
# v = cp.exp(-1000 * (vx**2 + vy**2)).astype("float32")

# psi_init = cp.fft.fftshift(cp.fft.fftn(cp.fft.fftshift(psi_init)))
# psi_init = cp.fft.fftshift(cp.fft.ifftn(cp.fft.fftshift(psi_init * v))).astype(
#     "complex64"
# )
# mshow_polar(psi_init, args.show)

# rdata = v = []

In [None]:
# Initial state for cl_rec.BH: code, object guess, probe, integer code shifts,
# fractional shifts and an empty convergence table.
# NOTE(review): `vars` shadows the builtin vars(); kept because BH's result is stored under it.
vars = {
    "code": cp.array(code),
    "psi": cp.array(psi_init),
    "q": cp.array(q),
    "ri": np.floor(shifts_code).astype("int32"),
}
# Fractional part of the shifts (zero when shifts_code holds integers). Random sub-pixel
# offsets can be added here, e.g. + (np.random.random([args.npos, 2]).astype("float32") - 0.5).
vars["r"] = np.array(shifts_code - vars["ri"]).astype("float32")
vars["table"] = pd.DataFrame(columns=["iter", "err", "time"])

# reconstruction
vars = cl_rec.BH(data, vars)