In [None]:
# Imports

from pathlib import Path
import pandas as pd

from sklearn.pipeline import make_pipeline

import bioblue as bb
from bioblue import fibers
from bioblue.plot import cm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as plt_cm
# from skimage.transform import rotate
from scipy.ndimage import rotate
from skimage.filters import threshold_otsu
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from PIL import Image
from itkwidgets import view
from ipywidgets import interact, interactive
import ipywidgets as widgets
from scipy.interpolate import interpn
import scipy.signal as signal
from tqdm.auto import tqdm
import matplotlib.patches as patches

from numba import njit
plt.rcParams["figure.figsize"] = (10,10)
import plotly.graph_objects as go

In [None]:
# Function definitions used throughout the notebook

def orientation(crop, num=100):
    """ Find the orientation of sheets (or fibers) in a 3d volume.

        This happens in the following steps:
        (1) Compute the 3d discrete Fourier transform of `crop`
        (2) Suppress the zero-frequency planes (index 0 along each axis)
            by setting them to 1 (log-magnitude 0), then fftshift so the
            zero frequency sits at the centre of the volume
        (3) Use only the magnitude of the signal, and take its logarithm
        (4) Keep the `num` voxels with the largest log-magnitude
        (5) Fit a 3-component PCA to the coordinates of those voxels

        Parameters:
            crop: 3d volume of arbitrary size
            num: number of Fourier voxels to keep after filtering
        Returns:
            None: placeholder kept for backward compatibility
            components: pca components, 3x3 matrix, one axis per row
                (can be used as basis for rotation)
            variance: (explained_variance_, explained_variance_ratio_) of the pca fit
            ft_filtered: 3d volume shaped like crop with 1 at the kept frequencies
            ft_logabs: unfiltered (but plane-suppressed, shifted) log-magnitude spectrum
            crop_data: (num, 3) coordinates of the kept frequencies
        Raises:
            NotImplementedError: if no frequency is kept (e.g. num == 0)
    """
    ft = np.fft.fftn(crop, axes=(-3, -2, -1))
    # Zero-frequency planes dominate the spectrum; set them to 1 so their
    # log-magnitude is 0 and they are (normally) not among the strongest voxels.
    ft[0, :, :] = 1
    ft[:, 0, :] = 1
    ft[:, :, 0] = 1
    ft = np.fft.fftshift(ft)
    # Exact zeros in the spectrum give -inf (sorted last); silence the warning.
    with np.errstate(divide="ignore"):
        ft_logabs = np.log(np.abs(ft))

    strongest = np.argsort(-ft_logabs, axis=None)[:num]
    data = np.unravel_index(strongest, shape=crop.shape)
    if len(data[0]) == 0:
        raise NotImplementedError()

    ft_filtered = np.zeros_like(crop)
    ft_filtered[data] = 1

    # Centre the coordinates (no scaling) before the PCA.
    crop_data = np.vstack(data).T
    pipeline = make_pipeline(
        StandardScaler(with_mean=True, with_std=False), PCA(n_components=3)
    ).fit(crop_data)
    pca: PCA = pipeline.named_steps["pca"]
    return (
        None,
        pca.components_,
        (pca.explained_variance_, pca.explained_variance_ratio_),
        ft_filtered,
        ft_logabs,
        crop_data,
    )

def volume_slicer(volume):
    """ Interactively browse 2d slices of a 3d volume.

        Two sliders are shown: `axis` picks the slicing axis (0, 1 or 2) and
        `i` picks the slice index along that axis. The range of `i` follows the
        extent of the selected axis, so non-cubic volumes can be fully browsed.

        Parameters:
            volume: 3d array to display
    """
    axis_slider = widgets.IntSlider(value=2, min=0, max=2)
    index_slider = widgets.IntSlider(
        value=0, min=0, max=volume.shape[2] - 1, continuous_update=False
    )

    def _sync_index_range(change):
        # Keep the index slider within the extent of the newly selected axis
        # (ipywidgets clamps the current value when max shrinks).
        index_slider.max = volume.shape[change["new"]] - 1

    axis_slider.observe(_sync_index_range, names="value")

    @interact(axis=axis_slider, i=index_slider)
    def _vol_slicer(axis=0, i=0):
        # Defensive clamp in case this callback fires before the range update.
        i = min(i, volume.shape[axis] - 1)
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(np.take(volume, i, axis))
        ax.set_title(f"axis={axis}, slice {i}")
        plt.show()
        plt.close(fig)

In [None]:
# Read the full volume from disk. Run only once: this takes a long time.
# The .npz archive stores one 2d slice per entry; slices are stacked along axis 2.

ds_path = Path("../data/PA_fibers/train/image")
# sorted() makes the choice deterministic; the dataset is expected to hold one image.
image_paths = sorted(ds_path.iterdir())
assert image_paths, f"no image found in {ds_path}"
volume_path = image_paths[0]
volume_npz = np.load(volume_path)

# CHANGE ME : grid_size, the distance between points in the grid
grid_size = 200
# CHANGE ME : crop_size, the size of the 3d (cubic) crop
crop_size = 256
# CHANGE ME ? border, margin (in voxels) discarded on every side of the volume
border = 400

slice_names = volume_npz.files
zsize = len(slice_names)
xsize, ysize = volume_npz[slice_names[0]].shape
x0roi, x1roi = (border, xsize - border)
y0roi, y1roi = (border, ysize - border)
z0roi, z1roi = (border, zsize - border)
# Candidate crop centres: a regular grid keeping every crop inside the ROI.
xx, yy, zz = np.mgrid[
    x0roi + crop_size // 2 : x1roi - crop_size // 2 : grid_size,
    y0roi + crop_size // 2 : y1roi - crop_size // 2 : grid_size,
    z0roi + crop_size // 2 : z1roi - crop_size // 2 : grid_size,
]
xx, yy, zz = xx.flatten(), yy.flatten(), zz.flatten()
assert xx.size > 0, "no crop centre fits inside the ROI: reduce border or crop_size"

# Read the whole volume into memory (float64, one slice at a time).
volume = np.zeros((xsize, ysize, zsize))
for i, file in enumerate(tqdm(slice_names, desc="loading slices")):
    volume[:, :, i] = volume_npz[file]
print(volume.shape)

In [None]:
tuple(coord.size for coord in (xx, yy, zz))  # number of candidate crop centres

In [None]:
# CHANGE ME : Choose your crop centre; each index runs from 0 to xx.size - 1.
# Every coordinate is picked from its own flattened grid array.
x, y, z = xx[1], yy[1], zz[0]

# Preview the crop footprint on the slice through the crop centre (slice z).
fig, ax = plt.subplots()
ax.imshow(volume[:, :, z])
# imshow draws axis 1 (y) horizontally and axis 0 (x) vertically.
rect = patches.Rectangle(
    (y - crop_size // 2, x - crop_size // 2),
    crop_size,
    crop_size,
    facecolor="none",
    linewidth=1,
    edgecolor='r',
)
ax.scatter([y], [x], s=5, c='red')
ax.add_patch(rect)
ax.set_title(f"Crop centred at x={x}, y={y} on slice z={z}");

In [None]:
# Cut the cubic crop around (x, y, z) out of the in-memory volume.
# This gives the same values as re-reading the slices from the .npz archive,
# without the disk I/O.
half = crop_size // 2
crop_volume = np.array(
    volume[x - half : x + half, y - half : y + half, z - half : z + half],
    dtype=float,
)
# Fail loudly instead of silently keeping zero-filled slices at the borders.
assert crop_volume.shape == (crop_size, crop_size, crop_size), (
    f"crop around {(x, y, z)} does not fit inside the volume"
)

volume_slicer(crop_volume)


In [None]:
# Estimate the sheet/fiber orientation of the selected crop.

num = 1000  # number of strongest Fourier voxels used for the PCA fit
(
    angle,
    components,
    variance,
    ft_filtered,
    ft,
    crop_data,
) = orientation(crop_volume, num=num)

In [None]:
# Display the retained Fourier voxels and the three PCA axes in 3d.
# After fftshift the zero frequency sits at index crop_size // 2, so the axes
# are drawn from there.

c = crop_size // 2
axis_length = 10  # drawn length (in voxels) of each PCA axis
axis_colors = ("red", "green", "blue")  # first, second, third component

fourier_points = go.Scatter3d(
    x=crop_data[:, 0],
    y=crop_data[:, 1],
    z=crop_data[:, 2],
    mode="markers",
    marker=dict(size=2, opacity=0.5, color="black"),
)
axis_traces = [
    go.Scatter3d(
        x=[c, c + axis_length * component[0]],
        y=[c, c + axis_length * component[1]],
        z=[c, c + axis_length * component[2]],
        line=dict(width=4, color=color),
        mode="lines",
        marker=dict(size=0),
    )
    for component, color in zip(components, axis_colors)
]

plotly_fig = go.Figure(data=[fourier_points, *axis_traces])
plotly_fig.update_layout(title="Strongest Fourier voxels and PCA axes")
plotly_fig.show()

In [None]:
def rotate_crop(crop_shape, loc, volume, basis, method="linear"):
    """ Get a rotated crop from the complete volume by interpolation.

        A regular grid of integer offsets (from -(s // 2) to s // 2 - 1 along
        each axis for even sizes s) is mapped into volume coordinates by
        solving basis @ v = offset, shifted to `loc`, and `volume` is
        interpolated at those points.

        Inputs:
            crop_shape: 3-tuple for crop shape
            loc: location of the crop centre inside volume (3 voxel coordinates)
            volume: 3d array to sample from
            basis: 3x3 matrix whose rows are the new x, y, z axes expressed in
                volume coordinates (e.g. PCA components)
            method: interpolation method forwarded to scipy.interpolate.interpn
                (default "linear")
        Returns:
            rot_crop: rotated crop of shape crop_shape taken from volume
        Raises:
            ValueError: (from interpn) if a sample point falls outside volume
    """
    # Step 1: integer offsets of every crop voxel relative to the crop centre.
    sc = crop_shape
    xxc, yyc, zzc = np.mgrid[
        -sc[0] // 2 : sc[0] // 2, -sc[1] // 2 : sc[1] // 2, -sc[2] // 2 : sc[2] // 2
    ]
    crop_grid = np.stack([xxc.ravel(), yyc.ravel(), zzc.ravel()])  # (3, N)
    # Step 2: express the offsets in volume coordinates (basis @ v = offset).
    crop_grid_rot = np.linalg.solve(basis, crop_grid)
    # Step 3: move the grid to the crop centre (loc broadcast over the N points).
    crop_grid_rot = crop_grid_rot + np.asarray(loc, dtype=float)[:, None]
    # Step 4: interpolate the volume at the rotated sample points.
    grid_axes = tuple(np.arange(n) for n in volume.shape)
    rot_crop = interpn(
        grid_axes, volume, crop_grid_rot.T, method=method, bounds_error=True
    )
    return rot_crop.reshape(sc)

# Re-sample the crop in the coordinate system of its PCA axes.
rot_vol = rotate_crop(
    crop_shape=crop_volume.shape,
    loc=np.array([x, y, z]),
    volume=volume,
    basis=components,
)

In [None]:
def find_peaks_in_volume(vol, rel_height=0.5, height=110, flat=False, segment_valleys=False, **peak_kwargs):
    """ Mark 1d intensity peaks along the first axis of a 3d volume.

        For every (j, k) column, scipy.signal.find_peaks is run on vol[:, j, k]
        and each detected peak position is set to 1 in the output.

        Parameters:
            vol: 3d array; peaks are searched along axis 0
            rel_height: forwarded to find_peaks; only has an effect when a
                `width` condition is given through peak_kwargs
            height: minimal peak height (default 110)
            flat, segment_valleys: currently unused, kept for backward compatibility
            **peak_kwargs: extra find_peaks conditions (e.g. distance, prominence, width)
        Returns:
            fiber_points: array shaped and typed like vol, 1 at peaks and 0 elsewhere
    """
    fiber_points = np.zeros_like(vol)
    # np.ndindex visits the (j, k) columns in C order without building index grids.
    for j, k in np.ndindex(*vol.shape[1:]):
        peaks, _ = signal.find_peaks(
            vol[:, j, k], height=height, rel_height=rel_height, **peak_kwargs
        )
        fiber_points[peaks, j, k] = 1

    return fiber_points

In [None]:
volume_slicer(volume=rot_vol)  # browse the PCA-aligned crop

In [None]:
fiber_points = find_peaks_in_volume(vol=rot_vol)  # 1 at peaks along axis 0, else 0

In [None]:
volume_slicer(volume=fiber_points)  # browse the detected peak mask

In [None]:
# Animate a sliding window of axis-0 slices of the peak mask in a 3d viewer.
window = 5  # number of consecutive slices shown at once
start_slice = 5  # initial player position

v = view(
    fiber_points[start_slice : start_slice + window],
    size_limit_3d=fiber_points.shape,
)


def view_slice(stop):
    """Show slices stop .. stop + window - 1 of fiber_points in the viewer."""
    v.image = fiber_points[stop : stop + window]


last_slice = fiber_points.shape[0] - 1
play = widgets.Play(interval=250, value=start_slice, min=0, max=last_slice)
slice_view = interactive(view_slice, stop=play)
progress = widgets.IntProgress(value=start_slice, min=0, max=last_slice)
# Mirror the player position in the progress bar (linked in the browser).
widgets.jslink((play, 'value'), (progress, 'value'))
widgets.VBox([widgets.HBox([slice_view, progress]), v])

In [None]:
# Render the complete peak mask (all slices) in an itkwidgets 3d viewer.
view(fiber_points)

In [None]:
# Leftover `widgets.Play?` help lookup removed: it is IPython-only syntax and not
# part of the analysis (use help(widgets.Play) interactively when exploring).

In [None]:
volume_slicer(volume=crop_volume)  # browse the axis-aligned crop again