In [None]:
import numpy as np
import cupy as cp
from holotomocupy.holo import G, GT
from holotomocupy.magnification import M, MT
from holotomocupy.shift import S, ST
from holotomocupy.recon_methods import multiPaganin
from holotomocupy.utils import *
from holotomocupy.proc import remove_outliers
import holotomocupy.chunking as chunking
import sys

# Chunk size for holotomocupy's chunked processing; the same value is pushed
# into the module-level setting `chunking.global_chunk`.
chunking.global_chunk = chunk = 5


In [None]:
# ---------------------------------------------------------------------------
# Reconstruction parameters
# ---------------------------------------------------------------------------
n = 2048          # object size in each dimension
pad = 0           # extra padding (was n//8 at one point)
ndist = 4         # number of propagation distances used
lam = 0.1
show = True
# Batch-job overrides, currently disabled:
# ntheta = int(sys.argv[1])
# st = int(sys.argv[2])
# gpu = int(sys.argv[3])
ntheta = 150
st = 0

# Run tag used in output directory / file names.
flg = f'{n}_{ntheta}_{pad}_{lam}_{st}'

# ---------------------------------------------------------------------------
# Acquisition geometry
# ---------------------------------------------------------------------------
detector_pixelsize = 3.03751e-6
energy = 33.35                                  # [keV] xray energy
wavelength = 1.2398419840550367e-09/energy      # [m] wave length
focusToDetectorDistance = 1.28                  # [m]
sx0 = 1.286e-3
z1 = np.array([4.236e-3, 4.3625e-3, 4.86850e-3, 5.91950e-3])[:ndist] - sx0
z2 = focusToDetectorDistance - z1
magnifications = focusToDetectorDistance/z1
norm_magnifications = magnifications/magnifications[0]
# Effective propagation distances, rescaled because the probes are magnified.
distances = (z1*z2)/focusToDetectorDistance*norm_magnifications**2
voxelsize = detector_pixelsize/magnifications[0]*2048/n  # object voxel size

# Probe plane for reconstruction sits at the first distance; propagate from
# there to every detector position.
z1p = z1[0]
z2p = z1 - z1p
# Magnification from the probe plane to the detector.
magnifications2 = (z1p+z2p)/z1p
norm_magnifications2 = magnifications2/(z1p/z1[0])  # normalized magnifications
# Point-source -> plane-wave distances, scaled for magnified probes and then
# for the ratio z1p/z1.
distances2 = (z1p*z2p)/(z1p+z2p)*norm_magnifications2**2*(z1p/z1)**2

# Sample size after demagnification, rounded up to a multiple of 32.
ne = int(np.ceil((n+2*pad)/norm_magnifications[-1]/32))*32

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
path = '/data/vnikitin/ESRF/ID16A/20240924/AtomiumS2/'
pfile = 'AtomiumS2_HT_007nm'
path_out = f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}/{flg}'
print(f'{voxelsize=}')

# Keep copies of the original sizes.
n0, ne0, pad0, voxelsize0 = n, ne, pad, voxelsize
print(norm_magnifications)

In [None]:
# Compare the initial shifts with the refined shifts saved by each
# per-chunk reconstruction run (one run per block of `ntheta` projections).
# Values from shifts.npy are halved -- presumably to match the binned grid
# used here; TODO confirm.
shifts_init = np.load('shifts.npy')/2
shifts = shifts_init.copy()
# NOTE: shadows the built-in iter(); name kept because the commented-out
# loader cell formats `iter` into its path as well.
iter = 96
nproj_total = 1800  # total number of projections covered by the chunked runs
# Use dedicated loop names so the config globals `st` / `flg` are not
# clobbered by the last chunk.
for st_chunk in range(0, nproj_total, ntheta):
    flg_chunk = f'{n}_{ntheta}_{pad}_{lam}_{st_chunk}'
    shifts[st_chunk:st_chunk+ntheta] = np.load(
        f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}/{flg_chunk}/crec_shift{flg_chunk}_{iter:03}.npy')

# Difference init - refined, one series per distance, for each shift component.
fig, axs = plt.subplots(1, 2, figsize=(9, 3))
for k in range(ndist):
    axs[0].plot(shifts_init[:, k, 0]-shifts[:, k, 0], '.', label=f"{k}")
    axs[1].plot(shifts_init[:, k, 1]-shifts[:, k, 1], '.', label=f"{k}")
for comp, ax in enumerate(axs):
    ax.set(title=f'init - refined, component {comp}', xlabel='projection index')
    ax.legend(title='distance')
plt.show()

In [None]:


# data = np.zeros([1800,ne//2,ne//2],dtype='float32')
# for st in range(0,1800,150):
#     print(st)
#     flg = f'{n}_{ntheta}_{pad}_{lam}_{st}'
#     data[st:st+150] = dxchange.read_tiff(f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}/{flg}/crec_psi_angle{flg}/{iter:03}.tiff')


In [None]:
# sdata=data[:].copy()

# # rdata = np.fft.rfft2(rdata)
# # rdata[:,0,0]=0
# # rdata = np.fft.irfft2(rdata)
# vsum = np.sum(sdata,axis=2)
# vsum-=np.min(vsum,axis=1)[:,np.newaxis]
# vsum=vsum.swapaxes(0,1)
# # plt.plot(vsum)
# # plt.show()
# shifts = np.zeros(sdata.shape[0],dtype='int')
# st = 410*2
# end = 460*2
# plt.plot(vsum[st:end])
# plt.show()
# a=vsum[st:end,0].copy()
# for k in range(sdata.shape[0]):    
#     b=vsum[st:end,k]
#     af = np.fft.fft(a)
#     bf = np.fft.fft(b)
#     c = np.fft.ifft(af * np.conj(bf))
#     shifts[k] = np.argmax(abs(c))
    
#     vsum[:,k] = np.roll(vsum[:,k],shifts[k])
#     sdata[k] = np.roll(data[k],shifts[k],axis=0)
# st = 410*2+10
# end = 460*2-20
# plt.plot(vsum[st:end])
# plt.show()    

# print(shifts)
# # plt.plot(vsum)
# # plt.show()
# plt.plot(shifts)
# sdata-=np.mean(sdata[:,400:400+128,1000:1000+128],axis=(1,2))[:,np.newaxis,np.newaxis]
# dxchange.write_tiff_stack(sdata,f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}/sdata',overwrite=True)


In [None]:
# Load the 1800-projection stack (presumably the aligned stack written to
# .../{pfile}/sdata by the commented-out alignment cell; verify).
sdata = dxchange.read_tiff_stack(
    f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}/sdata_00000.tiff',
    ind=np.arange(1800),
)

In [None]:
import tomoalign

# Rotation-centre search: reconstruct two central rows with CG for a sweep of
# candidate centres around `center` and write each result to test_center/.
ngpus = 1
pnz = 1
nitercg = 256
center = 500
crop = 256*2  # half-width (pixels) of the horizontal window around the image centre
nproj_all = sdata.shape[0]
theta = (np.arange(nproj_all)/nproj_all*np.pi).astype('float32')

# Axis 1 is rows, axis 2 is columns: take the width from shape[2]
# (the original used shape[1], correct only for square frames).
mid_z = sdata.shape[1]//2
mid_x = sdata.shape[2]//2
# A negative start index would silently wrap around instead of failing.
assert crop <= mid_x, f'crop={crop} exceeds half the projection width ({mid_x})'
data_rec = np.ascontiguousarray(sdata[:, mid_z:mid_z+2, mid_x-crop:mid_x+crop])

for k in np.arange(center-2, center+2, 0.25):
    print(k)
    res = tomoalign.cg(data_rec, theta, pnz, k, ngpus, nitercg, padding=True)
    mshow(res['u'][0], show)
    dxchange.write_tiff(res['u'][0], f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}/test_center/r{k}', overwrite=True)


        

In [None]:
import os

import tomoalign

# Multi-level ADMM reconstruction of the cropped volume; results (u, psi,
# flow) are written under <rec_dir>/results_admm.
rec_dir = f'/data/vnikitin/ESRF/ID16A/20240924_rec/AtomiumS2/{pfile}'
nproj = 1800
# theta assumes nproj equally spaced angles over pi; the data must match.
assert sdata.shape[0] == nproj, f'expected {nproj} projections, got {sdata.shape[0]}'
center = 499.75  # presumably chosen from the centre-search sweep (one of its candidates)
ngpus = 1
pnz = 16
ptheta = 20
theta = (np.arange(nproj)/nproj*np.pi).astype('float32')

# NOTE: `crop` (half-size of the reconstructed window) comes from the
# centre-search cell, which must run first.
mid_z = sdata.shape[1]//2
mid_x = sdata.shape[2]//2  # width from shape[2], not shape[1]
assert crop <= mid_z and crop <= mid_x, f'crop={crop} exceeds half the projection size'
data_rec = np.ascontiguousarray(sdata[:, mid_z-crop:mid_z+crop, mid_x-crop:mid_x+crop])

# Per-level settings for admm_of_levels (one entry per level).
niteradmm = [384, 256]
startwin = [1024, 512]
stepwin = [2, 2]
res = tomoalign.admm_of_levels(
    data_rec, theta, pnz, ptheta, center, ngpus, niteradmm, startwin, stepwin,
    f'{rec_dir}/tmp_admm', padding=True)

dxchange.write_tiff_stack(res['u'], f'{rec_dir}/results_admm/u/r', overwrite=True)
dxchange.write_tiff_stack(res['psi'], f'{rec_dir}/results_admm/psi/r', overwrite=True)
# np.save does not create missing directories.
os.makedirs(f'{rec_dir}/results_admm', exist_ok=True)
np.save(f'{rec_dir}/results_admm/flow.npy', res['flow'])