# Step 3. Find shifts. They are currently adjusted from the shifts Peter obtained.

In [None]:
import numpy as np
import cupy as cp
import h5py
import matplotlib.pyplot as plt
import scipy
from holotomocupy.utils import *

## Init

In [None]:
# Full set of projection indices that `ntheta` projections are subsampled from.
ntheta_full = 4500
ntheta = 4500
ids = np.arange(0, ntheta_full, ntheta_full / ntheta).astype('int')

# Planes (distances) to use; str_z1_ids is their 1-based concatenation, e.g. '1234',
# and is embedded in every input/output file name below.
z1_ids = np.array([0, 1, 2, 3])
str_z1_ids = ''.join(str(x) for x in z1_ids + 1)
ndist = len(z1_ids)

path_out = '/data2/vnikitin/brain_rec/20251115/Y350a'
file_out = f'data{str_z1_ids}.h5'

# Open read-only: this cell only reads metadata. An explicit 'r' avoids relying on
# h5py's version-dependent default mode ('a' before h5py 3.0).
with h5py.File(f'{path_out}/{file_out}', 'r') as fid:
    detector_pixelsize = fid['/exchange/detector_pixelsize'][0]
    focustodetectordistance = fid['/exchange/focusdetectordistance'][0]
    z1 = fid['/exchange/z1'][:]
    energy = fid['/exchange/energy'][0]
    shifts = fid['/exchange/shifts'][ids]  # only the selected projections
    shape = np.array(fid['/exchange/data0'].shape)

# NOTE(review): 1.24e-9 m*keV ~= h*c, so this assumes `energy` is stored in keV.
wavelength = 1.24e-09 / energy
z2 = focustodetectordistance - z1  # sample-to-detector distance per plane
magnifications = focustodetectordistance / z1
norm_magnifications = magnifications / magnifications[0]  # relative to plane 0
# Effective propagation distance z1*z2/(z1+z2), rescaled to the magnification of plane 0.
distances = (z1 * z2) / focustodetectordistance * norm_magnifications**2
voxelsize = detector_pixelsize / magnifications[0]

# Assumes data0 is laid out as (projection, rows, cols) with square frames — TODO confirm.
n = shape[1]
print(f'{energy=}')
print(f'{z1=}')
print(f'{focustodetectordistance=}')
print(f'{detector_pixelsize=}')
print(f'{magnifications=}')
print(f'{voxelsize=}')
print(f'{distances=}')

### All shifts are converted to object pixel coordinates at the highest magnification

In [None]:
# Reverse the two components of each stored shift and divide every plane k by
# norm_magnifications[k] (its magnification relative to plane 0), all in one
# broadcast expression; the result is cast to float32 on assignment.
random_shifts = np.zeros([ntheta, ndist, 2], dtype='float32')
random_shifts[:] = shifts[:, :ndist, ::-1] / norm_magnifications[:ndist, np.newaxis]

#### Rhapp: alignment between planes. Stored in a `rhapp_python*.mat` file (exported from MATLAB) and converted to Python conventions on load.

In [None]:
# Import explicitly: on older SciPy versions, a bare `import scipy` does not make
# the `scipy.io` submodule available.
import scipy.io

# MATLAB command used to export the array, kept for reference.
# NOTE(review): this path points at an older dataset (20240515/Y350c); the file
# actually loaded below is rhapp_python{str_z1_ids}.mat under path_out.
#save('/data2/vnikitin/brain_rec/20240515/Y350c/rhapp_python.mat','rhapp','-v7')
rhapp_shifts = scipy.io.loadmat(f'{path_out}/rhapp_python{str_z1_ids}.mat')['rhapp']
# Swap axes 0 and 2 of the MATLAB array, flip the sign, and keep only the
# projections in `ids`. Using `ids` rather than a fixed slice keeps the length
# equal to ntheta, so it lines up with random_shifts when summed later.
rhapp_shifts = -rhapp_shifts.swapaxes(0, 2)[ids]

### Motion shifts: alignment of a reference plane. The measured values already include that plane's random shifts, so we subtract them.

In [None]:
# Index (0-based) of the reference plane the motion correction was measured on.
ref_plane = 2
motion_path = '/data2/vnikitin/brain/20251115/Y350a_HT_20nm_8dist_3_/correct_motion.txt'

# Load the file once (it is the same for every plane), keep the projections in
# `ids`, reverse the two columns, and rescale by the reference plane's relative
# magnification, as done for random_shifts.
ref_motion = np.loadtxt(motion_path)[ids][:, ::-1] / norm_magnifications[ref_plane]

# Every plane receives the same reference-plane motion.
motion_shifts = np.zeros_like(rhapp_shifts)
motion_shifts[:] = ref_motion[:, np.newaxis]
# The measured motion already contains the reference plane's random shifts; remove them.
motion_shifts -= random_shifts[:, ref_plane][:, np.newaxis]

### Correct-3D shifts: alignment performed before tomographic reconstruction

In [None]:
# Per-projection 3D-alignment shifts for the selected projections (`ids`), with
# the two columns reversed. Selecting by `ids` keeps the length equal to ntheta.
correct3d_shifts = np.loadtxt(f'{path_out}/correct_correct3D_{str_z1_ids}.txt')[ids][:, ::-1]
# The same correction applies to every plane: repeat it along a new plane axis.
correct3d_shifts = np.tile(correct3d_shifts[:, np.newaxis], (1, ndist, 1))

### Visualization

In [None]:
# One figure per shift source. Each figure plots both components for every plane
# against the projection index, titled so the figures can be told apart.
for title, arr in [
    ('motion_shifts', motion_shifts),
    ('correct3d_shifts', correct3d_shifts),
    ('rhapp_shifts', rhapp_shifts),
    ('random_shifts', random_shifts),
]:
    plt.plot(arr[:, :, 1])
    plt.plot(arr[:, :, 0])
    plt.title(title)
    plt.xlabel('projection index')
    plt.show()


#### The final shifts are the sum of all components (the correct-3D and motion terms are currently disabled).

In [None]:
# Optional correction terms, currently disabled. Explicit flags are used instead
# of multiplying by 0: `x * 0` is NaN when x is NaN or inf, so a "disabled" term
# could still corrupt the result.
use_correct3d_shifts = False
use_motion_shifts = False

shifts_final = rhapp_shifts + random_shifts
if use_correct3d_shifts:
    shifts_final = shifts_final + correct3d_shifts
if use_motion_shifts:
    shifts_final = shifts_final + motion_shifts

## Save to file

In [None]:
# Write shifts_final (as float32) into the dataset file, replacing any copy from a
# previous run: h5py's create_dataset fails if the name already exists.
dset_name = '/exchange/cshifts_final_nocorr'
with h5py.File(f'{path_out}/data{str_z1_ids}.h5', 'a') as fid:
    if dset_name in fid:
        del fid[dset_name]
    fid.create_dataset(dset_name, data=shifts_final.astype('float32'))