In [3]:
import cv2
import numpy as np
from skimage import io, morphology
import matplotlib.pyplot as plt
import imageio

In [4]:
# Normalize between 0 and 1, then multiply by 255
def normalize_img(image, dtype=np.uint8):
    image = image.astype(float)
    image -= np.amin(image.astype(float))
    image /= np.amax(image)
    image *= 255
    return image.astype(dtype)

In [5]:
def set_black(image, dtype=np.uint8):
    
    window = image[:250,:250]
    black_pt = np.mean(window)

    # Set all pixels below the black point to zero
    thresholded_image = np.maximum(image, black_pt)

    # Adjust the scale of other pixels
    scale_factor = 255 / (255 - black_pt)
    adjusted_image = (thresholded_image - black_pt) * scale_factor
    adjusted_image = np.clip(adjusted_image, 0, 255).astype(np.uint8)

    return adjusted_image

In [6]:
def sigmoid_adjustment(image, std_mult=1.5, alpha=.5):

    # Get mean and standard deviation of image
    mean_val = np.mean(image)
    std_dev = np.std(image)

    # Set the midpoint for the sigmoid function to 1 standard deviation above the mean * arbitrary multiplier
    midpoint = mean_val + std_dev * std_mult

    # Apply the sigmoid function to adjust image contrast
    # 'alpha' controls the steepness of the sigmoid curve
    adjusted_image = 1 / (1 + np.exp(-alpha * (image - midpoint)))

    # Scale back to 0-255 and convert to uint8
    adjusted_image = np.uint8(255 * (adjusted_image - adjusted_image.min()) / (adjusted_image.max() - adjusted_image.min()))

    return adjusted_image

In [7]:
def mask_slice(image, mask):
    image *= mask
    return image

In [8]:
def threshold_slice(image, threshold=0, cleaned=False, min_size=4):
    canvas = np.ones_like(image)
    image = np.where(image > threshold, canvas, 0)

    if cleaned:
        binary_img = image > 0
        image = morphology.remove_small_objects(binary_img, min_size, connectivity=2)
        
    return image

In [9]:
def resize_stack(input_stack, new_size=(512, 512), kernel_size=5):
    assert len(new_size) == 2, "new_size must be a len=2 tuple or list."
    assert len(input_stack.shape) == 3, "input_array array must be 3-dimensional"

    # Pre-allocate the resized array with the new dimensions
    resized_stack = np.zeros((input_stack.shape[0], new_size[1], new_size[0]), dtype=input_stack.dtype)

    # Adjust pad_width to account for larger kernels properly
    pad_width = kernel_size // 2

    # Iterate over each timepoint and each slice
    for z in range(input_stack.shape[0]):  # Loop over slices
        # COnvert to np.uint8, then resize slice and add it to new array

        slice = input_stack[z,:,:].astype(np.uint8)
        slice = cv2.resize(input_stack[z,:,:], new_size, interpolation=cv2.INTER_LINEAR)
        resized_stack[z,:,:] = slice

    return resized_stack

In [10]:
# Function to apply an arbitrary process to each slice in a stack
def process_stack(input_stack, process_func, *args, **kwargs):

    output_stack = np.zeros_like(input_stack)

    for z in range(input_stack.shape[0]):
        slice = input_stack[z,:,:]
        output_stack[z,:,:] = process_func(slice, *args, **kwargs)

    return output_stack

In [11]:
def visualize_slices(processed_stack, preview_slices):

    # Determine the number of subplots needed
    num_slices = len(preview_slices)
    num_rows = int(np.ceil(np.sqrt(num_slices)))
    num_cols = int(np.ceil(num_slices / num_rows))
    
    # Create subplots
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(12, 12))
    axs = axs.ravel()  # Flatten the axis array if necessary

    # Display each slice in its subplot
    for i, slice_idx in enumerate(preview_slices):
        slice_img = processed_stack[slice_idx, :, :]
        axs[i].imshow(slice_img, cmap='gray')
        axs[i].set_title(f'Slice {slice_idx}')
        axs[i].axis('off')  # Optional: remove axes for cleaner look

    # Adjust layout to prevent overlap
    plt.tight_layout()
    plt.show()

In [12]:
# Load image stack
stack = io.imread('stacks/exp240202_01_E.tif')
input_stack = stack[0,:,:,:]
del stack

In [13]:
# Normalize between 0-255
norm_stack = process_stack(input_stack, normalize_img)
#input_stack = process_stack(input_stack, cv2.convertScaleAbs, alpha=1, beta=-50)


In [14]:
# Set black point to average background color
black_stack = process_stack(norm_stack, set_black)
print(np.amin(black_stack[16,:,:]))
print(np.amax(black_stack[16,:,:]))

0
255


In [15]:
# Pull out signal using sigmoid
processed_stack = process_stack(black_stack, sigmoid_adjustment, std_mult=3, alpha=1) 

In [16]:
# Visualize
# preview_slices = [16, 32, 64, 96]
# visualize_slices(processed_stack, preview_slices)


### Create MIP Mask

In [17]:
mip_image = np.max(input_stack, axis=0)
canvas = np.ones_like(mip_image)
mip_mask = np.where(normalize_img(mip_image) >= 4, canvas, 0)
np.amax(mip_mask)

1

In [18]:
# fig, ax = plt.subplots(figsize=(6, 6))
# ax.imshow(mip_mask, cmap='gray')

In [19]:
# Remove small objects (optional)
from skimage import morphology

binary_img = mip_mask > 0
cleaned_mask = morphology.remove_small_objects(binary_img, min_size=48, connectivity=4)

# Display the result
# fig, ax = plt.subplots(figsize=(12, 12))
# ax.imshow(cleaned_mask, cmap='gray')

### Final Threshold

In [20]:
masked_stack = process_stack(processed_stack, mask_slice, mask=mip_mask)

In [21]:
# visualize_slices(masked_stack, preview_slices)

In [22]:
thresh_stack = process_stack(masked_stack, threshold_slice, threshold=64, min_size=4)

In [23]:
# visualize_slices(thresh_stack, preview_slices)

### Remove Floating Regions

In [24]:
from scipy.ndimage import label

def filter_fuzzy_regions(image):
    labeled, num_features = label(image)
    output = np.zeros_like(image)

    # Analyze each region
    for i in range(1, num_features + 1):
        region = (labeled == i)
        if np.mean(region[image == 1]) > 0.95 or np.mean(region[image == 0]) > 0.95:
            output[region] = image[region]

    return output


In [25]:
# Remove fuzzy regions from each slice
#thresh_stack = process_stack(thresh_stack, filter_fuzzy_regions)

In [26]:
# Try removing small objects on full 3d stack
thresh_stack = morphology.remove_small_objects(thresh_stack > 0, min_size=48000, connectivity=26) # 6 = face connectivity, 18 = face + edge, 26 = face + edge + corner

In [27]:
# visualize_slices(thresh_stack, preview_slices)

In [28]:
# Convert bool to uint8
thresh_stack = thresh_stack.astype(np.uint8)
np.amax(thresh_stack)

1

In [29]:
# Export tiff stack
# with imageio.get_writer('thresh_stack.tif', format='TIFF', mode='I') as writer:
#     for slice in thresh_stack:
#         writer.append_data(slice*255)

In [30]:
small_stack = resize_stack(thresh_stack, (256,256))

In [31]:
small_stack.shape

(188, 256, 256)

In [38]:
import numpy as np
from skimage import io, filters, measure, morphology
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
from scipy import ndimage as ndi

# Load your 3D data
# Assuming 'data' is a 3D numpy array with 1s for the neuron structure and 0s for the background
data = small_stack  # Replace with your actual data loading method

# Step 1: Preprocessing
# Apply Gaussian smoothing to reduce noise
smoothed_data = filters.gaussian(data, sigma=1)

# Step 2: Compute the Distance Transform
distance = ndi.distance_transform_edt(smoothed_data)

# Step 3: Identify Markers
# Find local maxima in the distance transform
local_maxi = peak_local_max(distance, indices=False, footprint=np.ones((3, 3, 3)), labels=smoothed_data)

# Convert boolean mask to integer labels
markers = measure.label(local_maxi)

# Step 4: Apply the Watershed Algorithm
labels = watershed(-distance, markers, mask=smoothed_data.astype(bool))

# Visualize the result (for example, using a maximum intensity projection)
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(np.max(data, axis=0), cmap='gray')
ax[0].set_title('Original Data')
ax[1].imshow(np.max(distance, axis=0), cmap='gray')
ax[1].set_title('Distance Transform')
ax[2].imshow(np.max(labels, axis=0), cmap='nipy_spectral')
ax[2].set_title('Watershed Segmentation')
plt.show()


TypeError: peak_local_max() got an unexpected keyword argument 'indices'

Number of features: 724


In [None]:
import plotly.graph_objects as go

# Prepare data for plotting
labeled_skeleton = np.transpose(labeled_skeleton, (1, 2, 0))

# Step 5: Plot the result using Plotly
def plot_3d_structure(labeled_skeleton):
    x, y, z = np.where(labeled_skeleton > 0)
    labels = labeled_skeleton[x, y, z]

    fig = go.Figure()

    for label in np.unique(labels):
        mask = (labels == label)
        fig.add_trace(go.Scatter3d(
            x=x[mask],
            y=y[mask],
            z=z[mask],
            mode='markers',
            marker=dict(
                size=4,
                color=label,  # Color by label
                colorscale='Viridis',  # Use a colorscale to distinguish branches
                opacity=0.8
            ),
            name=f'Branch {label}'
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z'
        ),
        title='3D Branched Structure with Labeled Branches'
    )

    fig.show()

plot_3d_structure(labeled_skeleton)