# Step 4. Create binned data for processing at multiple levels, and adjust amplitudes after stitching data from different distances.

In [None]:
import numpy as np
import cupy as cp
import h5py
from holotomocupy.shift import Shift
from holotomocupy.utils import *

## Init

In [None]:
ntheta = 4500  # number of projections to process
show = True  # passed to the mshow* preview helpers
rotation_center_shift = -8.780113000000028
# Evenly subsample ntheta projection indices from the index range [0, 4500).
# Integer arithmetic yields exactly ntheta strictly increasing indices (for ntheta <= 4500),
# which h5py fancy indexing below requires; a float-step np.arange can return ntheta + 1
# entries because of rounding in its length computation.
ids = np.arange(ntheta) * 4500 // ntheta


z1_ids = np.array([0, 1, 2, 3])  # distance indices; their 1-based form names the data file
str_z1_ids = ''.join(map(str, z1_ids + 1))
ndist = len(z1_ids)


path_out = '/data2/vnikitin/brain_rec/20251115/Y350a'
file_out = f'data{str_z1_ids}.h5'

# Read acquisition geometry, per-projection metadata and dataset shapes (read-only access).
with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focusToDetectorDistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    energy = fid['/exchange/energy'][0]
    shifts = fid['/exchange/shifts'][ids]
    attrs = fid['/exchange/attrs'][ids]
    shape = np.array(fid['/exchange/data0'].shape)
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape

z2 = focusToDetectorDistance - z1
magnifications = focusToDetectorDistance / z1
norm_magnifications = magnifications / magnifications[0]  # relative to the first distance
# z1 * z2 / (z1 + z2), rescaled by the squared normalised magnification of each distance.
distances = (z1 * z2) / focusToDetectorDistance * norm_magnifications**2
voxelsize = detector_pixelsize / magnifications[0]

n = shape[1]
# Common padded grid size: the last distance's footprint n / norm_magnifications[-1],
# rounded up to a multiple of 64.
npsi = int(np.ceil(n / norm_magnifications[-1] / 64)) * 64

print(f'{energy=}')
print(f'{z1=}')
print(f'{focusToDetectorDistance=}')
print(f'{detector_pixelsize=}')
print(f'{magnifications=}')
print(f'{voxelsize=}')
print(f'{distances=}')

### Read data and write binned reference images

In [None]:
nlevels = 4  # number of bin levels; level b is stored at 1 / 2**b of the full size per axis

with h5py.File(f'{path_out}/{file_out}', 'a') as fid:
    ref = fid['/exchange/pref'][:ndist]
    # Write the reference at every bin level. Each level averages neighbouring pixel pairs
    # along both trailing axes; every step creates a new array, so `ref` itself is never modified.
    ref0 = ref
    for level in range(nlevels):
        name = f'/exchange/pref_{level}'
        if name in fid:
            del fid[name]
        fid.create_dataset(name, data=ref0)
        ref0 = 0.5 * (ref0[..., ::2] + ref0[..., 1::2])
        ref0 = 0.5 * (ref0[..., ::2, :] + ref0[..., 1::2, :])

    # Per-projection, per-distance shifts rescaled by n / 2048.
    # NOTE(review): presumably cshifts_final was estimated at a 2048-pixel width — confirm.
    r = (fid['/exchange/cshifts_final'][ids] * n / 2048).astype('float32')

    ### compensate for the rotation center shift (component 1 is the x/column shift)
    s = rotation_center_shift
    r[..., 1] += s

In [None]:
# Report the largest absolute shift component in r (after rotation-center compensation).
print(np.abs(r).max())

### Alignment and intensity correction; save to HDF5

In [None]:
# Align every distance onto the common npsi x npsi grid and stitch the distances together.
# Then rescale each distance so its mean intensity in a central window matches distance 0,
# and store the corrected data at nlevels bin levels as /exchange/pdata{k}_{level}.
cl_shift = Shift(n, npsi, n, npsi, 1 / norm_magnifications)
distances_pag = distances / norm_magnifications**2
npad = n // 16  # width (pixels) of the blending ramp at each stitching seam

# Loop invariants, computed once instead of per projection / per distance.
cref = cp.array(ref)
cref_safe = cref + 1e-5  # small offset avoids division by zero in the reference division
# Smooth polynomial ramp rising from 0 toward 1 over npad samples (degree-9 smoothstep).
ramp = cp.linspace(0, 1, npad, endpoint=False)
ramp = ramp**5 * (126 - 420 * ramp + 540 * ramp**2 - 315 * ramp**3 + 70 * ramp**4)
# Central window used to compare mean intensities between distances.
st, end = npsi // 4 + npsi // 8, 3 * npsi // 4 - npsi // 8
# Zoomed preview window, applied identically to every distance so the overlay is comparable.
preview = (slice(1000, 1000 + 512), slice(0, 512))


def _seam_pads(k, shift_yx):
    """Return (pady0, pady1, padx0, padx1): margins outside distance k's footprint in the grid.

    The footprint of size n / norm_magnifications[k] is centred in the npsi grid and moved by
    the integer part of shift_yx = (y, x). Margins are clamped to [0, npsi] and then widened
    by 2 pixels, which also guarantees p1 > 0 so ``a[p0:-p1]`` never degenerates to ``a[p0:0]``.
    """
    base = int((npsi - n / norm_magnifications[k]) / 2)
    sy, sx = int(shift_yx[0]), int(shift_yx[1])
    margins = (base - sy, base + sy, base - sx, base + sx)
    return tuple(min(npsi, max(0, m)) + 2 for m in margins)


def _taper(lo, hi):
    """1-D blending weight of length npsi: 0 in the lo/hi margins, ramping over npad pixels, 1 inside."""
    w = cp.ones([npsi], dtype='float32')
    w[:lo] = 0
    w[lo:lo + npad] = ramp
    w[-hi - npad:-hi] = 1 - ramp
    w[-hi:] = 0
    return w


with h5py.File(f'{path_out}/{file_out}', 'a') as fid:
    # (Re)create one output dataset per bin level and distance.
    data_out = [None] * nlevels
    for level in range(nlevels):
        nb = n // 2**level
        data_out[level] = []
        for k in range(ndist):
            name = f'/exchange/pdata{k}_{level}'
            if name in fid:
                del fid[name]
            data_out[level].append(fid.create_dataset(name, shape=[ntheta, nb, nb]))

    srdata = cp.zeros([ndist, npsi, npsi], dtype='float32')
    for j in range(ntheta):
        data = cp.empty([ndist, n, n], dtype='float32')
        for k in range(ndist):
            data[k] = cp.array(fid[f'/exchange/pdata{k}'][ids[j]])

        # Two passes: the intensity scaling applied in the first pass is refined in the second.
        for t in range(2):
            rdata = data / cref_safe

            # Stitch from the last distance back to the first: distance k is blended over the
            # already-stitched result of distance k + 1 with a smooth taper at the seams.
            for k in range(ndist - 1, -1, -1):
                tmp = rdata[k].astype('complex64')
                tmp = cl_shift.curlySadj(tmp[None], cp.array(r[j:j+1, k]), k)[0].real
                tmp /= norm_magnifications[k]**2
                pady0, pady1, padx0, padx1 = _seam_pads(k, r[j, k])
                # Fill outside the footprint: replicate edges along y, ramp linearly to 1 along x.
                tmp = cp.pad(tmp[pady0:-pady1], ((pady0, pady1), (0, 0)), 'edge')
                tmp = cp.pad(tmp[:, padx0:-padx1], ((0, 0), (padx0, padx1)), 'linear_ramp', end_values=((1, 1), (1, 1)))
                if k < ndist - 1:
                    w = cp.outer(_taper(pady0, pady1), _taper(padx0, padx1))
                    tmp = tmp * w + srdata[k + 1] * (1 - w)
                srdata[k] = tmp

            # Ratio of each distance's central-window mean to that of distance 0.
            mmean = cp.mean(srdata[:, st:end, st:end], axis=(1, 2)) / cp.mean(srdata[0, st:end, st:end])
            data /= mmean[..., None, None]
            if j % 100 == 0 and t == 1:
                print(j, mmean)
                mshow_complex(srdata[0] + 1j * srdata[ndist - 1], show)
                mshow_complex(srdata[0][preview] + 1j * srdata[ndist - 1][preview], show)

        # Write once, after both passes: only the final normalisation is kept in the file.
        for k in range(ndist):
            datak = data[k]
            for level in range(nlevels):
                data_out[level][k][j] = datak.get()
                datak = 0.5 * (datak[::2, :] + datak[1::2, :])
                datak = 0.5 * (datak[:, ::2] + datak[:, 1::2])
