In [None]:
import cv2
import numpy as np
from skimage import io, morphology
import matplotlib.pyplot as plt
import imageio

In [None]:
def normalize_img(image, dtype=np.uint8):
    """Linearly rescale ``image`` so its minimum maps to 0 and its maximum to 255.

    Parameters
    ----------
    image : array_like
        Input image of any numeric dtype. It is copied, never modified.
    dtype : numpy dtype, optional
        Output dtype (default ``np.uint8``). Scaled values are truncated
        toward zero by ``astype``, not rounded.

    Returns
    -------
    numpy.ndarray
        Rescaled image of ``dtype``. A constant image has no range to stretch
        and is returned as all zeros (instead of the NaN produced by 0/0).
    """
    scaled = np.array(image, dtype=float)  # always a copy: the caller's array is untouched
    if scaled.size == 0:
        return scaled.astype(dtype)

    scaled -= scaled.min()
    peak = scaled.max()
    if peak == 0:
        # Constant image: dividing by zero would give NaN, whose integer cast is undefined
        return np.zeros(scaled.shape, dtype=dtype)

    # Same operation order as a plain "divide, then multiply by 255" to keep rounding identical
    scaled /= peak
    scaled *= 255
    return scaled.astype(dtype)

In [None]:
def set_black(image, dtype=np.uint8, window_size=250):
    """Shift an image so its estimated background becomes 0, then re-stretch to 255.

    The black point is the mean of the top-left ``window_size`` x ``window_size``
    corner. NOTE(review): this assumes that corner contains only background;
    confirm for every acquisition. Pixels at or below the black point become 0,
    and the remaining range ``[black_pt, 255]`` is stretched linearly onto
    ``[0, 255]``.

    Parameters
    ----------
    image : numpy.ndarray
        2-D image whose values are expected in 0-255 (e.g. ``normalize_img`` output).
    dtype : numpy dtype, optional
        Output dtype (default ``np.uint8``).
    window_size : int, optional
        Side length of the top-left background window (default 250).

    Returns
    -------
    numpy.ndarray
        Adjusted image clipped to 0-255. If the black point is 255 or more
        (saturated corner), nothing lies above it and an all-zero image is returned.
    """
    window = image[:window_size, :window_size]
    black_pt = np.mean(window)

    if black_pt >= 255:
        # 255 - black_pt would be <= 0: the scale factor is infinite or negative
        return np.zeros(np.shape(image), dtype=dtype)

    # Clamp everything below the black point up to it, so it lands on 0 after the shift
    thresholded_image = np.maximum(image, black_pt)

    # Stretch the remaining range [black_pt, 255] back onto [0, 255]
    scale_factor = 255 / (255 - black_pt)
    adjusted_image = (thresholded_image - black_pt) * scale_factor
    return np.clip(adjusted_image, 0, 255).astype(dtype)

In [None]:
def sigmoid_adjustment(image, std_mult=1.5, alpha=.5):
    """Boost contrast with a logistic curve centred above the image mean.

    The sigmoid's midpoint is ``mean + std_mult * std`` of ``image``, so pixels
    well above the average intensity are pushed toward white. The sigmoid
    output is then min-max rescaled to 0-255.

    Parameters
    ----------
    image : numpy.ndarray
        Input image of any real dtype.
    std_mult : float, optional
        Number of standard deviations above the mean at which the midpoint
        is placed (default 1.5).
    alpha : float, optional
        Steepness of the sigmoid; larger values give a harder cut (default 0.5).

    Returns
    -------
    numpy.ndarray of uint8
        Contrast-adjusted image. If the sigmoid output is flat (e.g. a blank
        slice), an all-zero image is returned instead of dividing 0 by 0.
    """
    # Midpoint: mean plus std_mult standard deviations
    mean_val = np.mean(image)
    std_dev = np.std(image)
    midpoint = mean_val + std_dev * std_mult

    # For pixels far below the midpoint exp() may overflow to inf; 1/(1+inf) = 0
    # is the correct limit, so the overflow warning is only noise.
    with np.errstate(over='ignore'):
        adjusted_image = 1 / (1 + np.exp(-alpha * (image - midpoint)))

    lo = adjusted_image.min()
    hi = adjusted_image.max()
    if hi == lo:
        # Flat response: no contrast to stretch, and 0/0 would produce NaN
        return np.zeros(np.shape(image), dtype=np.uint8)

    # Scale back to 0-255 and convert to uint8 (values are truncated, not rounded)
    return np.uint8(255 * (adjusted_image - lo) / (hi - lo))

In [None]:
def mask_slice(image, mask):
    """Return ``image * mask`` as a new array with ``image``'s dtype.

    The input is not modified. Multiplying in place would mutate the caller's
    array, and slices handed out by ``process_stack`` are views into its input
    stack. In-place multiplication also fails for masks whose product dtype
    cannot be cast back to ``image.dtype`` under 'same_kind' rules (e.g. an
    int64 mask on a uint8 image).

    Parameters
    ----------
    image : numpy.ndarray
        Image to mask.
    mask : array_like
        Broadcastable to ``image``; typically 0/1 or boolean.

    Returns
    -------
    numpy.ndarray
        Masked copy of ``image``, cast back to ``image.dtype``.
    """
    masked = np.multiply(image, mask)
    return masked.astype(image.dtype, copy=False)

In [None]:
def threshold_slice(image, threshold=0, cleaned=False, min_size=4):
    """Binarize ``image``: 1 where ``image > threshold``, else 0.

    Parameters
    ----------
    image : numpy.ndarray
        2-D slice to threshold.
    threshold : scalar, optional
        Strict lower bound for foreground (default 0).
    cleaned : bool, optional
        If True, also drop foreground objects smaller than ``min_size`` pixels
        using ``skimage.morphology.remove_small_objects`` with 8-connectivity.
    min_size : int, optional
        Minimum object size kept when ``cleaned`` is True; ignored otherwise.

    Returns
    -------
    numpy.ndarray
        0/1 array with the same dtype as ``image`` in both modes.
    """
    foreground = image > threshold

    if cleaned:
        # connectivity=2 is the full 8-neighbourhood for a 2-D slice
        foreground = morphology.remove_small_objects(foreground, min_size=min_size, connectivity=2)

    # Cast back so cleaned and uncleaned results share the input's dtype
    return foreground.astype(image.dtype)

In [None]:
def process_stack(input_stack, process_func, *args, **kwargs):
    """Run ``process_func`` on every ``input_stack[z, :, :]`` and collect the results.

    Extra positional and keyword arguments are forwarded to ``process_func``.
    Each slice passed in is a view into ``input_stack``, not a copy. Results are
    written into an array built with ``np.zeros_like(input_stack)``, so they are
    cast to the input stack's dtype.
    """
    result = np.zeros_like(input_stack)
    n_planes = input_stack.shape[0]
    for index in range(n_planes):
        plane = input_stack[index, :, :]
        result[index, :, :] = process_func(plane, *args, **kwargs)
    return result

In [None]:
def visualize_slices(processed_stack, preview_slices, figsize=(12, 12), cmap='gray'):
    """Show selected slices of a stack in a near-square grid of subplots.

    Parameters
    ----------
    processed_stack : numpy.ndarray
        Stack indexed as ``processed_stack[z, :, :]``.
    preview_slices : iterable of int
        Indices along axis 0 to display, in order.
    figsize : tuple of float, optional
        Figure size in inches (default ``(12, 12)``).
    cmap : str, optional
        Matplotlib colormap name (default ``'gray'``).

    Returns
    -------
    None
        Nothing is drawn when ``preview_slices`` is empty.
    """
    preview_slices = list(preview_slices)
    num_slices = len(preview_slices)
    if num_slices == 0:
        return

    # Near-square grid: ceil(sqrt(n)) rows, then just enough columns
    num_rows = int(np.ceil(np.sqrt(num_slices)))
    num_cols = int(np.ceil(num_slices / num_rows))

    # squeeze=False always yields a 2-D array of Axes, so a single slice
    # (where plt.subplots would otherwise return a bare Axes) still works
    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axs = axs.ravel()

    for ax, slice_idx in zip(axs, preview_slices):
        ax.imshow(processed_stack[slice_idx, :, :], cmap=cmap)
        ax.set_title(f'Slice {slice_idx}')
        ax.axis('off')

    # Hide leftover grid cells (e.g. 3 slices in a 2x2 grid) instead of drawing empty frames
    for ax in axs[num_slices:]:
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
# Load image stack
# NOTE(review): the TIFF is indexed as 4-D and only entry 0 of axis 0 is kept
# (presumably a channel or timepoint) -- confirm against the acquisition metadata.
stack = io.imread('stacks/exp240202_01_E.tif')
# `stack[0]` alone is a view that keeps the whole array alive, so `del stack`
# would free nothing; copying the slice lets the rest of the data be released.
input_stack = stack[0, :, :, :].copy()
del stack

In [None]:
# Rescale every slice of the raw stack independently onto 0-255
norm_stack = process_stack(input_stack, process_func=normalize_img)


In [None]:
# Move each slice's black point to the mean of its background corner
black_stack = process_stack(norm_stack, process_func=set_black)

# Sanity-check the intensity range of one slice
sample_slice = black_stack[16, :, :]
print(np.amin(sample_slice))
print(np.amax(sample_slice))

In [None]:
# Pull out signal: sigmoid centred 3 standard deviations above each slice's mean
processed_stack = process_stack(
    black_stack,
    sigmoid_adjustment,
    std_mult=3,
    alpha=1,
)

In [None]:
# Preview a few representative slices of the contrast-adjusted stack
preview_slices = [16, 32, 64, 96]
visualize_slices(processed_stack, preview_slices=preview_slices)


### Create MIP Mask

In [None]:
# Maximum-intensity projection of the raw stack along axis 0
mip_image = input_stack.max(axis=0)
canvas = np.ones_like(mip_image)
# Pixels whose normalized (0-255) MIP intensity reaches 4 count as signal
is_signal = normalize_img(mip_image) >= 4
mip_mask = np.where(is_signal, canvas, 0)
mip_mask.max()

In [None]:
# Show the binary MIP mask (trailing ';' suppresses the stray Artist repr)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(mip_mask, cmap='gray')
ax.set_title('MIP mask')
ax.axis('off');

In [None]:
# Remove small objects (optional)
from skimage import morphology

binary_img = mip_mask > 0
# skimage's `connectivity` is the maximum number of orthogonal hops (1..ndim),
# not a neighbour count: for a 2-D image, 1 = 4-neighbourhood, 2 = 8-neighbourhood.
cleaned_mask = morphology.remove_small_objects(binary_img, min_size=48, connectivity=2)
# NOTE(review): `cleaned_mask` is only displayed; the masking step under
# "Final Threshold" uses `mip_mask`. Pass `mask=cleaned_mask` there if intended.

# Display the result
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(cleaned_mask, cmap='gray')
ax.set_title('Cleaned MIP mask')
ax.axis('off');

### Final Threshold

In [None]:
masked_stack = process_stack(processed_stack, process_func=mask_slice, mask=mip_mask)  # multiply every slice by the MIP mask

In [None]:
visualize_slices(masked_stack, preview_slices=preview_slices)  # preview the masked slices

In [None]:
# Binarize each masked slice: values above 64 become 1, everything else 0.
# `min_size` only has an effect together with `cleaned=True`, so it is not passed.
# NOTE(review): if 2-D small-object removal was intended, add
# `cleaned=True, min_size=4` (3-D removal is applied to the whole stack later).
thresh_stack = process_stack(masked_stack, threshold_slice, threshold=64)

In [None]:
visualize_slices(thresh_stack, preview_slices=preview_slices)  # preview the per-slice threshold

### Remove Floating Regions

In [None]:
from scipy.ndimage import label

def filter_fuzzy_regions(image):
    """Keep only connected components holding more than 95% of the pixels equal to 1.

    Components come from ``scipy.ndimage.label`` with its default structure;
    every non-zero pixel is foreground. Each component's share of all
    value-1 pixels is computed, and components with a share above 0.95 are
    copied from ``image``. Everything else is set to 0. Since the shares sum to
    at most 1, at most one component can pass.

    NOTE(review): the name suggests removing "fuzzy" regions, but the effective
    rule is "keep the dominant component". A criterion on background pixels
    cannot apply here because labelled pixels are never 0. Confirm the
    intended criterion.

    Parameters
    ----------
    image : numpy.ndarray
        Binary (0/1) image or stack.

    Returns
    -------
    numpy.ndarray
        Array like ``image`` containing only the kept component(s).
    """
    # Import locally: another notebook cell rebinds the global name `label`
    # to skimage.measure.label, which returns a single array by default.
    from scipy import ndimage

    labeled, num_features = ndimage.label(image)
    output = np.zeros_like(image)

    is_one = image == 1
    n_ones = np.count_nonzero(is_one)
    if num_features == 0 or n_ones == 0:
        # No components, or no value-1 pixels to take a share of
        return output

    # Count each component's value-1 pixels in a single pass instead of one
    # full-array comparison per component.
    ones_per_label = np.bincount(labeled[is_one], minlength=num_features + 1)
    keep = ones_per_label / n_ones > 0.95
    keep[0] = False  # label 0 is background
    kept = keep[labeled]
    output[kept] = image[kept]
    return output


In [None]:
# Remove fuzzy regions from each slice
#thresh_stack = process_stack(thresh_stack, filter_fuzzy_regions)

In [None]:
# Remove small 3-D connected components from the whole binary stack.
# skimage's `connectivity` is the maximum number of orthogonal hops (1..ndim),
# not a neighbour count: in 3-D, 1 = 6-connectivity (faces),
# 2 = 18 (faces + edges), 3 = 26 (faces + edges + corners).
# `min_size` is measured in voxels.
thresh_stack = morphology.remove_small_objects(thresh_stack > 0, min_size=48000, connectivity=3)

In [None]:
visualize_slices(thresh_stack, preview_slices=preview_slices)  # preview after 3-D cleanup

In [None]:
# Store the boolean stack as 0/1 uint8 values
thresh_stack = np.asarray(thresh_stack, dtype=np.uint8)
thresh_stack.max()

In [None]:
# Export tiff stack: one page per slice, foreground written as 255
with imageio.get_writer('thresh_stack.tif', format='TIFF', mode='I') as writer:
    # Loop variable is `plane`, not `slice`: a notebook-level `slice` would
    # shadow the builtin for every later cell in the kernel.
    for plane in thresh_stack:
        # Yields 0/255 uint8 whether the stack holds bools or 0/1 integers
        writer.append_data((plane > 0).astype(np.uint8) * 255)

In [None]:
from scipy.ndimage import label

# Label 3-D components with face (6-)connectivity only, so branches touching
# merely along an edge or corner are counted separately.
structure = np.array([
    [[0, 0, 0], [0, 1, 0], [0, 0, 0]],
    [[0, 1, 0], [1, 1, 1], [0, 1, 0]],
    [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
])
labeled_array, num_features = label(thresh_stack, structure=structure)

print("Number of distinct branches:", num_features)
# Summarize instead of dumping the array repr, which NumPy truncates to a few corner values
print("Labeled array shape:", labeled_array.shape, "dtype:", labeled_array.dtype)

In [None]:
labels_fig, labels_ax = plt.subplots()
labels_ax.imshow(labeled_array[16, :, :])  # Change slice index as needed
labels_ax.set_title("Sample Slice")
plt.show()

In [None]:
from skimage import measure, morphology

# Reduce the binary stack to a one-voxel-wide skeleton. `skeletonize` handles
# 3-D input with Lee's method; the separate `skeletonize_3d` alias is
# deprecated in recent scikit-image releases.
skeleton = morphology.skeletonize(thresh_stack, method='lee')

# Label skeleton components with full 3-D connectivity (faces, edges and corners).
# Called through `measure` so this cell does not rebind the global `label`
# (scipy.ndimage.label), whose two-value return other cells unpack.
labeled_skeleton = measure.label(skeleton, connectivity=3)

In [None]:
np.shape(labeled_skeleton)

In [None]:
skeleton_fig, skeleton_ax = plt.subplots()
skeleton_ax.imshow(labeled_skeleton[32, :, :])  # Change slice index as needed
skeleton_ax.set_title("Sample Slice")
plt.show()