# Step 2. Preprocess the data in the h5 file: remove outliers and adjust amplitudes.

In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import cupyx.scipy.ndimage as ndimage
from holotomocupy.utils import *


# Init

In [None]:
# Number of projections to process, picked evenly from the index range [0, 4500).
# NOTE(review): 4500 is presumably the number of acquired projections - confirm.
ntheta = 4500
# Exact integer arithmetic always yields ntheta indices; np.arange with a
# floating-point step can return one element too many when ntheta does not
# divide 4500 evenly.
ids = (np.arange(ntheta) * 4500) // ntheta

# Indices of the distances to process; the file name carries them 1-based
# (e.g. [0, 1, 2, 3] -> 'data1234.h5').
z1_ids = np.array([0, 1, 2, 3])
str_z1_ids = ''.join(str(x) for x in z1_ids + 1)
ndist = len(z1_ids)

path_out = '/data2/vnikitin/brain_rec/20251115/Y350a'
file_out = f'data{str_z1_ids}.h5'

# Nothing is written in this step, so the file is opened read-only explicitly
# instead of relying on the h5py default mode.
with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focustodetectordistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    # Per-projection entries, restricted to the selected projections.
    shifts = fid['/exchange/shifts'][ids]
    attrs = fid['/exchange/attrs'][ids]
    # Only the dataset shapes are needed here; the data are read later.
    shape = np.array(fid['/exchange/data0'].shape)
    shape_ref = fid['/exchange/data_white_start0'].shape
    shape_dark = fid['/exchange/data_dark0'].shape

## sizes
n = shape[1]            # second dimension of the projection dataset
ndark = shape_dark[0]   # number of dark frames
nref = shape_ref[0]     # number of reference (white) frames

print(f'{z1=}')
print(f'{focustodetectordistance=}')
print(f'{detector_pixelsize=}')


### Remove outliers function

In [None]:
def remove_outliers(data, dezinger, dezinger_threshold):
    """Replace outlier pixels of every slice by the local median value.

    Each slice ``data[k]`` is copied to the GPU and median-filtered with a
    ``dezinger`` x ``dezinger`` window. A pixel counts as an outlier when its
    absolute difference from the median exceeds ``dezinger_threshold`` times
    the median; such pixels are replaced by the median, all others are kept.

    Args:
        data: host array that is processed slice by slice along its first
            axis (each slice is filtered with a two-element window size).
        dezinger: side length of the median-filter window.
        dezinger_threshold: relative deviation from the median above which a
            pixel is replaced.

    Returns:
        A copy of ``data`` with the outliers replaced; ``data`` itself is not
        modified.
    """
    cleaned = data.copy()
    window = [dezinger, dezinger]
    for idx, frame in enumerate(data):
        frame_gpu = cp.array(frame)
        median = ndimage.median_filter(frame_gpu, window)
        is_outlier = cp.abs(frame_gpu - median) > median * dezinger_threshold
        # Bring the corrected slice back to host memory.
        cleaned[idx] = cp.where(is_outlier, median, frame_gpu).get()
    return cleaned

### Read ref and dark

In [None]:
# Host buffers laid out as [frame, distance, row, column]:
#   ref0 - white frames stored as 'data_white_start' in the file,
#   ref1 - white frames stored as 'data_white_end' in the file,
#   dark - dark frames.
# NOTE(review): the (n, n) frame size assumes square detector images - confirm.
ref0 = np.empty((nref, ndist, n, n), dtype='float32')
ref1 = np.empty((nref, ndist, n, n), dtype='float32')
dark = np.empty((ndark, ndist, n, n), dtype='float32')

with h5py.File(f'{path_out}/{file_out}') as fid:
    # Every distance has its own set of datasets, suffixed with its index.
    for k in range(ndist):
        dark[:, k] = fid[f'/exchange/data_dark{k}'][:]
        ref0[:, k] = fid[f'/exchange/data_white_start{k}'][:]
        ref1[:, k] = fid[f'/exchange/data_white_end{k}'][:]
        

### Remove outliers

In [None]:
# Average the dark frames and the white frames over the frame axis. Only the
# 'start' white frames (ref0) enter the reference; ref1 is not used.
# NOTE: `dark` is rebound from the frame stack to its mean, so this cell is
# meant to be executed once per run of the reading cell.
dark = np.mean(dark, axis=0)
ref = np.mean(ref0, axis=0)

# Dark-field correction; negative values left by the subtraction are set to 0.
ref -= dark
ref[ref < 0] = 0

# Dezinger settings, shared with the projection-processing cells.
radius = 3
threshold = 0.9
ref[:] = remove_outliers(ref, radius, threshold)

#### Normalization

In [None]:
# Mean intensity of the first selected projection at every distance, measured
# after dark subtraction and outlier removal.
mean_data_ref = np.zeros(ndist, dtype='float32')

# This step only reads raw data, so the file is opened read-only. The output
# datasets /exchange/pdata{k} are (re)created by the processing step; dropping
# and re-creating them here without writing would only discard existing results.
with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    for k in range(ndist):
        data = fid[f'/exchange/data{k}'][ids[0]].astype('float32')
        data -= dark[k]
        data[data < 0] = 0
        # remove_outliers works on a stack of frames: add and strip a leading axis.
        data = remove_outliers(data[None], radius, threshold)[0]
        mean_data_ref[k] = np.mean(data)
            

In [None]:
# Mean reference intensity at each distance (counts before the scan).
mmr = np.mean(ref, axis=(1, 2))

# Bring the first-projection means to the reference level of distance 0,
# then express them relative to that level.
mean_data_ref *= mmr[0] / mmr
mean_data_ref /= mmr[0]

# Same two steps for the reference images themselves.
ref *= mmr[0] / mmr[:, None, None]
ref /= mmr[0]

# Per-distance mean of the rescaled reference, printed as a check.
print(np.mean(ref, axis=(1, 2)))


In [None]:
# Store the processed reference images in the file as /exchange/pref.
with h5py.File(f'{path_out}/{file_out}','a') as fid:
    # Drop the result of a previous run first so the dataset can be created anew.
    if '/exchange/pref' in fid:
        del fid['/exchange/pref']
    fid.create_dataset('/exchange/pref', data=ref)
    

### Process data 

In [None]:
# Process the selected projections of every distance and store them as
# /exchange/pdata{k} next to the raw data.
with h5py.File(f'{path_out}/{file_out}','a') as fid:
    for k in range(ndist):
        # Start from a fresh output dataset with the shape of the raw data.
        if f'/exchange/pdata{k}' in fid:
            del fid[f'/exchange/pdata{k}']
        data_out = fid.create_dataset(f'/exchange/pdata{k}', shape=shape)

        for j in range(ntheta):
            # One projection at a time: read, dark-correct, clip at zero.
            data = fid[f'/exchange/data{k}'][ids[j]].astype('float32')
            data -= dark[k]
            data[data < 0] = 0
            # remove_outliers works on a stack of frames: add and strip a leading axis.
            data = remove_outliers(data[None], radius, threshold)[0]
            # Rescale so that the projection mean equals mean_data_ref[k].
            data = data / np.mean(data) * mean_data_ref[k]
            data_out[j] = data
            # Progress report every 100 projections.
            if not j % 100:
                print(j, k, np.mean(data))