In [None]:
import numpy as np
import cupy as cp
from types import SimpleNamespace
import h5py
from holotomocupy.utils import *
from holotomocupy.tomo_batched import TomoBatched
from holotomocupy.shift import Shift 


### Initialization

In [None]:
ntheta = 900
st = 0
bin = 0
show = True
rotation_center_shift = -8.780113000000028

paganin = 20

z1_ids = np.array([0,1,2,3])
str_z1_ids = ''.join(map(str, z1_ids + 1))
ndist = len(z1_ids)

# Number of projections the angular subset is drawn from.
# NOTE(review): hard-coded; presumably the projection count stored in the file - confirm.
nproj_total = 4500
# Exactly ntheta equally spaced projection indices starting at st. Built from an
# integer range so the length is always ntheta: a float-step np.arange can return
# one element more or fewer (and fewer still when st > 0), which would break every
# array sized by ntheta below.
ids = (st + np.arange(ntheta) * (nproj_total / ntheta)).astype('int')
assert ids[-1] < nproj_total, f'st={st} pushes the last projection index past {nproj_total}'

path_out = '/data2/vnikitin/brain_rec/20251115/Y350a'
file_out = f'data{str_z1_ids}.h5'

with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    energy = fid['/exchange/energy'][0]
    shifts = fid['/exchange/shifts'][ids]
    theta = fid['/exchange/theta'][ids, 0]
    theta = -theta / 180 * np.pi  # degrees -> radians, with the sign flipped
    attrs = fid['/exchange/attrs'][ids]
    shape = np.array(fid['/exchange/data0'].shape)
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape

# NOTE(review): 1.24e-9 / energy assumes energy in keV and gives the wavelength in metres - confirm.
wavelength = 1.24e-09 / energy
z2 = focusToDetectorDistance - z1
magnifications = focusToDetectorDistance / z1
norm_magnifications = magnifications / magnifications[0]
# z1 * z2 / (z1 + z2) per distance (Fresnel-scaling effective distance),
# rescaled to the magnification of the first distance.
distances = (z1 * z2) / focusToDetectorDistance * norm_magnifications**2
voxelsize = detector_pixelsize / magnifications[0]

n = shape[1]
# Object grid: frame size scaled by the last normalized magnification, rounded up to a multiple of 64.
nobj = int(np.ceil((n) / norm_magnifications[-1] / 64)) * 64
## change to the current bin level
n //= (2**bin)
nobj //= (2**bin)
voxelsize = voxelsize*2**bin

print(f'{energy=}')
print(f'{z1=}')
print(f'{focusToDetectorDistance=}')
print(f'{detector_pixelsize=}')
print(f'{magnifications=}')
print(f'{voxelsize=}')
print(f'{distances=}')

### Read data

In [None]:
# Only the reference frames and the per-projection shifts are needed up front.
# The alignment cell streams pdata{k}_{bin} from the file one projection at a
# time, so the full (ntheta, ndist, n, n) stack is deliberately not loaded here
# (it was never used and can take tens of GB of host memory).
with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    ref = fid[f'/exchange/pref_{bin}'][:ndist]
    # Shifts rescaled to the current bin level.
    r = (fid['/exchange/cshifts_final'][ids] / 2**bin).astype('float32')

# Compensate for the rotation-centre shift by applying (s - 0.5) / 2 once per
# binning level, then adding it to the second shift component.
s = rotation_center_shift
for _ in range(bin):
    s = (s - 0.5) / 2
r[..., 1] += s

### Alignment and intensity correction

In [None]:
cl_shift = Shift(n, nobj, n, nobj, 1 / norm_magnifications)
distances_pag = distances / norm_magnifications**2
npad = n // 16  # width of the blending ramps, in pixels
cref = cp.array(ref)
cref_eps = cref + 1e-5  # regularized reference divisor, identical for every projection

# Order-4 smoothstep polynomial sampled on [0, 1): rises smoothly from 0 towards 1
# over npad samples. It does not depend on the projection or the distance, so it
# is built once here instead of inside the inner loop.
ramp = cp.linspace(0, 1, npad, endpoint=False)
ramp = ramp**5 * (126 - 420 * ramp + 540 * ramp**2 - 315 * ramp**3 + 70 * ramp**4)


def _clamp_pad(width):
    """Clamp a margin width to [0, nobj] and add a fixed 2-pixel safety margin."""
    return min(nobj, max(0, width)) + 2


srdata_out = np.empty([ntheta, ndist, nobj, nobj], dtype='float32')

# Nothing is written to the file in this cell, so read-only mode suffices.
with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    srdata = cp.zeros([ndist, nobj, nobj], dtype='float32')
    for j in range(ntheta):
        # Stream the ndist frames of projection ids[j] and divide by the reference.
        frames = cp.empty([ndist, n, n], dtype='float32')
        for k in range(ndist):
            frames[k] = cp.array(fid[f'/exchange/pdata{k}_{bin}'][ids[j]])
        rdata = frames / cref_eps

        # Walk distances from last to first so srdata[k + 1] (already updated for
        # this projection) is available to fill the margins of distance k.
        for k in range(ndist - 1, -1, -1):
            # NOTE(review): curlySadj presumably applies the adjoint of the
            # shift/magnification operator (n x n frame -> nobj x nobj grid) - confirm.
            tmp = rdata[k].astype('complex64')
            tmp = cl_shift.curlySadj(tmp[None], cp.array(r[j:j+1, k]), k)[0].real
            tmp /= norm_magnifications[k]**2

            # Margins (top/left = *0, bottom/right = *1) between the nobj grid and
            # the n / norm_magnifications[k] footprint of the shifted frame.
            half = int((nobj - n / norm_magnifications[k]) / 2)
            shift_x = int(r[j, k, 1])
            shift_y = int(r[j, k, 0])
            padx0 = _clamp_pad(half - shift_x)
            pady0 = _clamp_pad(half - shift_y)
            padx1 = _clamp_pad(half + shift_x)
            pady1 = _clamp_pad(half + shift_y)

            # Replace the margins: replicate edge rows in y, ramp linearly to 1 in x.
            tmp = cp.pad(tmp[pady0:-pady1], ((pady0, pady1), (0, 0)), 'edge')
            tmp = cp.pad(tmp[:, padx0:-padx1], ((0, 0), (padx0, padx1)), 'linear_ramp', end_values=((1, 1), (1, 1)))

            if k < ndist - 1:
                # Separable weight: 0 in the margins, smooth ramps of width npad,
                # 1 inside. Where it is below 1, blend in the distance-(k + 1) result.
                wx = cp.ones([nobj], dtype='float32')
                wy = cp.ones([nobj], dtype='float32')
                wx[:padx0] = 0
                wx[padx0:padx0 + npad] = ramp
                wx[-padx1 - npad:-padx1] = 1 - ramp
                wx[-padx1:] = 0
                wy[:pady0] = 0
                wy[pady0:pady0 + npad] = ramp
                wy[-pady1 - npad:-pady1] = 1 - ramp
                wy[-pady1:] = 0

                w = cp.outer(wy, wx)
                tmp = tmp * w + srdata[k + 1] * (1 - w)
            srdata[k] = tmp
        srdata_out[j] = srdata.get()

srdata = srdata_out

### Masking data

In [None]:
# Soft-edged projection mask: a hard window blurred by a narrow Gaussian via
# FFT-based convolution, then normalized so its peak is 1.
mask_r = 0.9
axis = np.linspace(-1, 1, nobj)
x, y = np.meshgrid(axis, axis)


def _centered_fft2(arr):
    """2-D FFT with fftshift applied before and after (origin at the centre)."""
    return np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr)))


# NOTE(review): this window depends on x only (|x| < sqrt(mask_r)), i.e. a vertical
# band rather than a disk; confirm that is intended.
window = (x**2 < mask_r).astype('float32')
gauss = np.exp(-30**2 * (x**2 + y**2))
spectrum = _centered_fft2(window) * _centered_fft2(gauss)
mask = np.fft.fftshift(np.fft.ifft2(np.fft.fftshift(spectrum))).real.astype('float32')
mask /= np.amax(mask)
mshow(mask, True)

### Paganin phase retrieval

In [None]:
def multiPaganin(data, distances, wavelength, voxelsize, delta_beta, alpha):
    """Multi-distance Paganin phase retrieval.

    For each distance z_j the Fourier-space filter
    T_j = 1 + wavelength * z_j * pi * delta_beta * |f|^2 is applied to the
    frame's spectrum. The frames are combined as
    mean(T_j * F_j) / (mean(T_j^2) + alpha), a Tikhonov-regularized
    least-squares estimate, transformed back, and 0.5 * delta_beta * log(.)
    of its real part is returned.

    Parameters
    ----------
    data : array, shape (ndist, ny, nx)
        Normalized intensities, one frame per distance. Frames may be
        non-square: the frequency grid is built per axis.
    distances : sequence of float
        Propagation distance for each frame (same length units as
        wavelength and voxelsize).
    wavelength, voxelsize : float
        Wavelength and pixel size.
    delta_beta : float
        delta/beta ratio.
    alpha : float
        Regularization added to the averaged denominator.

    Returns
    -------
    array, shape (ny, nx)
        Retrieved phase. NaN where the filtered intensity is not positive.

    Raises
    ------
    ValueError
        If ``data`` contains no frames.
    """
    nterms = data.shape[0]
    if nterms == 0:
        raise ValueError("multiPaganin needs at least one frame")

    fx = cp.fft.fftfreq(data.shape[-1], d=voxelsize).astype('float32')
    fy = cp.fft.fftfreq(data.shape[-2], d=voxelsize).astype('float32')
    fx, fy = cp.meshgrid(fx, fy)  # shapes (ny, nx); identical to before for square frames
    freq2 = fx**2 + fy**2  # |f|^2, independent of the distance

    numerator = 0
    denominator = 0
    for j in range(nterms):
        rad_freq = cp.fft.fft2(data[j])
        taylorExp = 1 + wavelength * distances[j] * cp.pi * delta_beta * freq2
        numerator += taylorExp * rad_freq
        denominator += taylorExp**2

    # Average over the terms actually summed so alpha has a consistent scale.
    numerator /= nterms
    denominator = (denominator / nterms) + alpha

    phase = cp.log(cp.real(cp.fft.ifft2(numerator / denominator)))
    phase *= delta_beta * 0.5

    return phase

def rec_init(rdata, paganin):
    """Initial guess: multi-distance Paganin retrieval for every projection.

    Each projection is padded by nobj // 8 on both sides with the mean of its
    first 32 * n // 512 rows, passed through ``multiPaganin``, and cropped
    back to nobj x nobj.

    Parameters
    ----------
    rdata : array, shape (ntheta, ndist, nobj, nobj)
        Aligned, normalized projections (``srdata`` in this notebook).
    paganin : float
        delta/beta ratio passed to ``multiPaganin``.

    Returns
    -------
    numpy.ndarray, shape (ntheta, nobj, nobj), complex
        Real part: retrieved phase, offset so the median over the first
        16 * n // 512 columns of all projections is 0. Imaginary part: the
        same phase divided by ``paganin``.

    Uses notebook globals: ntheta, nobj, n, distances, norm_magnifications,
    wavelength, voxelsize.
    """
    pad = nobj // 8
    # Loop-invariant: distances rescaled by the normalized magnifications.
    distances_pag = distances / norm_magnifications**2
    recMultiPaganin = np.zeros([ntheta, nobj, nobj], dtype="float32")
    for j in range(ntheta):
        r = cp.array(rdata[j])
        mm = np.mean(r[:, :32 * n // 512]).get()
        r = cp.pad(r, ((0, 0), (pad, pad), (pad, pad)), 'constant', constant_values=mm)
        r = multiPaganin(r, distances_pag, wavelength, voxelsize, paganin, 1e-5)
        # Explicit crop: `-nobj // 8` parses as `(-nobj) // 8`, which removes one
        # extra row/column whenever nobj is not a multiple of 8.
        recMultiPaganin[j] = r[pad:pad + nobj, pad:pad + nobj].get()

    recMultiPaganin -= np.median(recMultiPaganin[:, :, :16 * n // 512])
    recMultiPaganin = recMultiPaganin + 1j * (recMultiPaganin / paganin).astype('float32')

    return recMultiPaganin
# Paganin initial guess for all projections, tapered by the soft edge mask.
psi_data = rec_init(srdata, paganin) * mask
mshow_complex(psi_data[-1], show)

### Create the tomography solver

In [None]:
# Configuration handed to TomoBatched.
args = SimpleNamespace(
    ngpus=cp.cuda.runtime.getDeviceCount(),
    ntheta=ntheta,
    nobj=nobj,
    nzobj=nobj,
    nchunk=2,
    # Follow the notebook-wide display toggle instead of hard-coding True.
    show=show,
    theta=theta,
    mask_r=1.1,
)

cl_rec = TomoBatched(args)

### Iterative reconstruction

In [None]:
# NOTE(review): the second argument (21) is presumably the iteration count - confirm in TomoBatched.rec_tomo.
rec = cl_rec.rec_tomo(psi_data, 21)

# Central slice: full view, then the central (nobj/4 x nobj/4) region.
center_slice = rec[rec.shape[0] // 2]
c, hw = nobj // 2, nobj // 8
mshow_complex(center_slice, args.show)
mshow_complex(center_slice[c - hw:c + hw, c - hw:c + hw], args.show)

### Save to HDF5

In [None]:
# Store the initial object as separate real/imaginary datasets, replacing any
# left over from an earlier run with the same paganin/bin settings.
init_datasets = {
    f'/exchange/obj_init_re{paganin}_{bin}': rec.real,
    f'/exchange/obj_init_imag{paganin}_{bin}': rec.imag,
}
with h5py.File(f'{path_out}/data{str_z1_ids}.h5', 'a') as fid:
    stale = [key for key in init_datasets if key in fid]
    for key in stale:
        del fid[key]
    for key, values in init_datasets.items():
        fid.create_dataset(key, data=values)