In [None]:
#@title 🌐 🐍 Imports and Installs
# Environment bootstrap (shell escapes, so this cell only works inside IPython/Jupyter/Colab).
# Step 1: install the `uv` package installer with pip.
!pip install uv
# Step 2: use uv to install the package from the eigenP/utils GitHub repository.
# Presumably this is what provides the `eigenp_utils` module imported below — confirm.
# NOTE(review): no tag or commit is pinned, so each run installs whatever is on the
# default branch at that moment; pin a ref if reproducibility matters.
!uv pip install git+https://github.com/eigenP/utils.git

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.data import cells3d
from eigenp_utils.plotting_utils import hist_imshow
from eigenp_utils.intensity_rescaling import adjust_gamma_per_slice

In [None]:
# Load the 3D sample volume and keep only the nuclei channel.
# cells3d() is laid out as (Z, C, Y, X); indexing axis 1 leaves (Z, Y, X).
NUCLEI_CHANNEL = 1
image_raw = cells3d()[:, NUCLEI_CHANNEL]
print(f"Loaded image with shape: {image_raw.shape}")

In [None]:
#@title ## Test `hist_imshow`
# Smoke test: run hist_imshow on the untouched nuclei volume.
print("Testing hist_imshow...")
# hist_imshow now returns a dictionary {'fig': ..., 'axes': ...}
res = hist_imshow(image_raw)
# Render through pyplot rather than res['fig'].show(): Figure.show() needs a GUI
# backend, and under the inline (non-GUI) notebook backend it only emits a
# "cannot show the figure" UserWarning. plt.show() works in notebooks and scripts.
# NOTE(review): assumes hist_imshow creates its figure via pyplot — confirm.
plt.show()

### Test gamma correction by degrading the `cells3d` volume and then restoring it

In [None]:
# Simulate signal decay along Z axis: attenuate each slice exponentially,
# add Gaussian noise, then convert back to the original dtype.
print("Simulating signal decay...")
z_slices = image_raw.shape[0]

# Create exponential decay curve
# Decay from 1.0 (first slice) down to exp(-1.5) ~= 0.22 (last slice)
decay_factor = np.exp(-np.linspace(0, 1.5, z_slices))

# Apply decay: one factor per Z slice, broadcast over (Y, X).
# Multiplying by the float decay curve promotes the result to floating point.
image_attenuated = image_raw * decay_factor[:, None, None]

# Add zero-mean Gaussian noise (sigma = 10 intensity counts); fixed seed keeps
# the degraded volume identical across runs.
rng = np.random.default_rng(42)
noise = rng.normal(0, 10, image_attenuated.shape)
image_noisy = image_attenuated + noise

# Clip to the valid range and cast back to the original dtype.
if image_raw.dtype.kind in "iu":
    # Take the upper bound from the dtype itself instead of hard-coding the
    # uint16 maximum, and round first: astype() on floats truncates toward
    # zero, which would bias every voxel downward.
    dtype_max = np.iinfo(image_raw.dtype).max
    image_noisy = np.clip(np.rint(image_noisy), 0, dtype_max)
else:
    # Floating-point input has no integer ceiling; only remove the negative
    # intensities introduced by the noise.
    image_noisy = np.clip(image_noisy, 0, None)
image_degraded = image_noisy.astype(image_raw.dtype)

print("Displaying degraded image...")
res_degraded = hist_imshow(image_degraded)
# Render through pyplot rather than res_degraded['fig'].show(): Figure.show()
# needs a GUI backend and only emits a UserWarning under the inline backend.
# NOTE(review): assumes hist_imshow creates its figure via pyplot — confirm.
plt.show()

In [None]:
# Restore signal with adjust_gamma_per_slice, using the degraded volume as input
print("Restoring signal with adjust_gamma_per_slice (exponential fit)...")

# 'gamma_fit_func'='exponential' fits the decay and adjusts gamma per slice to compensate
image_restored = adjust_gamma_per_slice(image_degraded, gamma_fit_func='exponential')

print("Displaying restored image...")
res_restored = hist_imshow(image_restored)
# Render through pyplot rather than res_restored['fig'].show(): Figure.show()
# needs a GUI backend and only emits a UserWarning under the inline backend.
# NOTE(review): assumes hist_imshow creates its figure via pyplot — confirm.
plt.show()

In [None]:
# Z-intensity comparison: overlay the per-slice mean of all three volumes
print("Plotting Z-intensity profiles...")

# Average over axes 1 and 2 so that a single mean remains per Z slice
mean_raw, mean_degraded, mean_restored = (
    volume.mean(axis=(1, 2))
    for volume in (image_raw, image_degraded, image_restored)
)

# Explicit figure/axes handles instead of the implicit pyplot state machine
profile_fig, profile_ax = plt.subplots(figsize=(10, 6))
profile_ax.plot(mean_raw, label='Original', linewidth=2, alpha=0.7)
profile_ax.plot(mean_degraded, label='Degraded (Simulated)', linestyle='--', linewidth=2)
profile_ax.plot(mean_restored, label='Restored (Corrected)', linestyle='-.', linewidth=2)

profile_ax.legend()
profile_ax.set_title('Z-Axis Mean Intensity Profile')
profile_ax.set_xlabel('Z Slice')
profile_ax.set_ylabel('Mean Intensity')
profile_ax.grid(True, alpha=0.3)
plt.show()