We scanned 10 fish for Carolina; let's explore the data.

In [None]:
import platform
import os
import glob
import pandas
import imageio
import numpy
import scipy
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
import seaborn
import dask
import dask.array
import dask_image.imread
import dask_image.ndfilters
from dask.distributed import Client
from numcodecs import Blosc
import skimage
from tqdm import notebook
import statsmodels
import scipy.signal
import sklearn.cluster
from skimage.segmentation import random_walker

In [None]:
# Set the dask temporary folder
# Do this before creating a client: https://stackoverflow.com/a/62804525/323100
import tempfile
if 'Linux' in platform.system():
    tmp = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')
elif 'Darwin' in platform.system():
    tmp = tempfile.gettempdir()
else:
    # Windows: 'anaklin*' hosts use F:\, all other hosts D:\
    if 'anaklin' in platform.node():
        tmp = os.path.join('F:\\')
    else:
        tmp = os.path.join('D:\\')
dask.config.set({'temporary_directory': os.path.join(tmp, 'tmp')})
print('Dask temporary files go to %s' % dask.config.get('temporary_directory'))

In [None]:
# Start a dask.distributed client with default settings
client = Client()
# Then go to http://localhost:8787/status (8787 is only the default dashboard port, the actual one is printed below)

In [None]:
# Print where the dask dashboard can be found; the scheduler reports the port it serves the dashboard on
print('You can see what DASK is doing at "http://localhost:%s/status"' % client.scheduler_info()['services']['dashboard'])

In [None]:
# Silence all warnings for the rest of the notebook: https://stackoverflow.com/a/62242245/323100
# NOTE(review): this also hides potentially useful warnings (e.g. from dask, pandas or skimage)
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Figure defaults: b&w images with 'nearest' interpolation, bigger figures with a higher dpi
plt.rcParams.update({'image.cmap': 'gray',
                     'image.interpolation': 'nearest',
                     'figure.figsize': (16, 9),
                     'figure.dpi': 300})

In [None]:
# Scale bar defaults: white, frameless, in the lower right corner
plt.rcParams.update({'scalebar.location': 'lower right',
                     'scalebar.frameon': False,
                     'scalebar.color': 'white'})

In [None]:
# Display all plots identically: number of subplot rows when showing all samples in one figure
lines = 2
# And then do something like
# plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), c + 1)
# (subplot needs integer grid sizes, hence the int() around numpy.ceil)

In [None]:
# Show the host name; the data location chosen below depends on it
platform.node()

In [None]:
# Choose the data location depending on the operating system and the machine we run on.
# 'anaklin25' has a fast SSD, which speeds things up significantly.
FastSSD = 'anaklin25' in platform.node()
if 'Linux' in platform.system():
    BasePath = (os.path.join(os.sep, 'media', 'habi', 'Fast_SSD') if FastSSD
                else os.path.join(os.sep, 'home', 'habi', '1272'))
elif 'Darwin' in platform.system():
    BasePath = os.path.join('/Users/habi/Data')
elif FastSSD or 'anaklin' in platform.node():
    # Windows: every 'anaklin*' machine keeps the data on F:
    BasePath = os.path.join('F:\\')
else:
    BasePath = os.path.join('D:\\Results')
Root = os.path.join(BasePath, 'Zebrafish_Carolina_Muscles')
print('We are loading all the data from %s' % Root)

In [None]:
def get_git_hash():
    '''
    Return the short hash of the HEAD commit of the git repository in the
    current working directory.
    Based on http://stackoverflow.com/a/949391/323100 and
    http://stackoverflow.com/a/18283905/323100
    If git reports an error, its message goes to stderr; stdout is then
    empty, so an empty string is returned.
    '''
    import subprocess
    import os
    command = ['git', '--git-dir', os.path.join(os.getcwd(), '.git'),
               'rev-parse', '--short', '--verify', 'HEAD']
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    return completed.stdout.strip().decode("utf-8")

In [None]:
# Collect all output in a folder named after the current git commit,
# so results of different code versions do not overwrite each other
OutPutDir = os.path.join(os.getcwd(), os.path.join('Output', get_git_hash()))
print('We are saving all the output to %s' % OutPutDir)
os.makedirs(OutPutDir, exist_ok=True)

In [None]:
# Make us an (empty) dataframe for collecting all information about the scans; the cells below add columns
Data = pandas.DataFrame()

In [None]:
# Get *all* log files below the root folder, sorted by path
# (sorted() already returns a list, no need to copy it with a comprehension)
Data['LogFile'] = sorted(glob.glob(os.path.join(Root, '**', '*.log'), recursive=True))

In [None]:
# Folder containing each log file
Data['Folder'] = list(map(os.path.dirname, Data['LogFile']))

In [None]:
# Get rid of all non-rec logfiles: keep only logs in folders containing 'rec',
# skip 'SubScan' folders and temporary 'rectmp' logs.
# Select with a boolean mask instead of dropping rows while iterating over the dataframe.
Data = Data.loc[[('rec' in folder) and ('SubScan' not in folder) and ('rectmp' not in logfile)
                 for folder, logfile in zip(Data['Folder'], Data['LogFile'])]]
# Reset dataframe to something that we would get if we only would have loaded the 'rec' files
Data = Data.reset_index(drop=True)

In [None]:
# Report how many reconstruction folders are left after filtering
print('We found %s subfolders in %s' % (Data.shape[0], Root))

In [None]:
# The sample name is the first folder below Root
Data['Sample'] = [logfile[len(Root) + 1:].split(os.sep, 1)[0] for logfile in Data['LogFile']]

In [None]:
def whichexperiment(i):
    '''
    Categorize a sample name into 'KO' or 'WT' (presumably knockout / wild type).
    'ko' is checked before 'wt'; None is returned if neither occurs in the name.
    '''
    for tag, experiment in (('ko', 'KO'), ('wt', 'WT')):
        if tag in i:
            return experiment
    return None

In [None]:
# 'KO' or 'WT' for every sample
Data['Experiment'] = list(map(whichexperiment, Data['Sample']))

In [None]:
def whichfish(i):
    '''Give each fish a number: everything after the two-letter prefix of the sample name, as int'''
    prefix_length = 2
    return int(i[prefix_length:])

In [None]:
# Fish number of every sample
Data['Fish'] = list(map(whichfish, Data['Sample']))

In [None]:
# Sorted file names of the reconstructed slices in each folder, and how many there are
Data['Reconstructions'] = [sorted(glob.glob(os.path.join(folder, '*rec0*.png')))
                           for folder in Data['Folder']]
Data['Number of reconstructions'] = [len(slices) for slices in Data.Reconstructions]

In [None]:
def get_pixelsize(logfile):
    """
    Get the pixel size from the scan log file.

    Every line containing 'Image Pixel' but not 'Scaled' is parsed (value after
    the '='); the value of the last such line is returned, numpy.nan if none exists.
    """
    with open(logfile, 'r') as logcontent:
        values = [float(line.split('=')[1]) for line in logcontent
                  if 'Image Pixel' in line and 'Scaled' not in line]
    return values[-1] if values else numpy.nan

In [None]:
# Pixel size of every scan, read from its log file
Data['Voxelsize'] = list(map(get_pixelsize, Data['LogFile']))

In [None]:
# Show the dataframe collected so far
Data

In [None]:
# Load the reconstructions lazily from the '<sample>_rec.zarr' arrays next to each reconstruction folder
Data['OutputNameRec'] = [os.path.join(os.path.dirname(folder), '%s_rec.zarr' % sample)
                         for folder, sample in zip(Data['Folder'], Data['Sample'])]
Reconstructions = list(map(dask.array.from_zarr, Data['OutputNameRec']))

In [None]:
# Shape of every (uncropped) reconstruction
Data['Size'] = [reconstruction.shape for reconstruction in Reconstructions]

In [None]:
# The three cardinal directions, in the order of the array axes used below
# Names adapted to fishes: https://en.wikipedia.org/wiki/Fish_anatomy#Body
directions = 'Anteroposterior Lateral Dorsoventral'.split()

In [None]:
# Read or calculate the middle slices, put them into the dataframe and save them to disk.
# Each direction takes the middle slice across the axis with the same index as in `directions`.
middle_slice = {'Anteroposterior': lambda rec, size: rec[size[0] // 2],
                'Lateral': lambda rec, size: rec[:, size[1] // 2, :],
                'Dorsoventral': lambda rec, size: rec[:, :, size[2] // 2]}
for direction in directions:
    Data['Mid_' + direction] = ''
for c, row in notebook.tqdm(Data.iterrows(), desc='Middle images', total=len(Data)):
    for d, direction in notebook.tqdm(enumerate(directions),
                                      desc='Fish %s' % row['Sample'],
                                      leave=False,
                                      total=len(directions)):
        outfilepath = os.path.join(os.path.dirname(row['Folder']),
                                   '%s.Middle.%s.png' % (row['Sample'], direction))
        if not os.path.exists(outfilepath):
            # Compute the requested view and cache it on disk
            Data.at[c, 'Mid_' + direction] = middle_slice[direction](Reconstructions[c],
                                                                     Data['Size'][c]).compute()
            imageio.imwrite(outfilepath, Data.at[c, 'Mid_' + direction])
        else:
            Data.at[c, 'Mid_' + direction] = imageio.imread(outfilepath)

In [None]:
# Read or calculate the directional maximum intensity projections (MIPs),
# put them into the dataframe and cache them as 8 bit PNGs on disk.
# NOTE(review): the projection axis is -d, i.e. axes 0, 2, 1 for Anteroposterior, Lateral, Dorsoventral,
# while the middle slices cut across axes 0, 1, 2. Confirm that this swap is intended.
for direction in directions:
    Data['MIP_' + direction] = ''
for c, row in notebook.tqdm(Data.iterrows(), desc='MIPs', total=len(Data)):
    for d, direction in notebook.tqdm(enumerate(directions),
                                      desc=row['Sample'],
                                      leave=False,
                                      total=len(directions)):
        column = 'MIP_' + direction
        outfilepath = os.path.join(os.path.dirname(row['Folder']), '%s.MIP.%s.png' % (row['Sample'], direction))
        if not os.path.exists(outfilepath):
            Data.at[c, column] = Reconstructions[c].max(axis=-d).compute()
            imageio.imwrite(outfilepath, Data.at[c, column].astype('uint8'))
        else:
            Data.at[c, column] = imageio.imread(outfilepath)

In [None]:
# Transpose the lateral and dorsoventral views, so the fishes are horizontal.
# Each view is transposed from its *own* column (middle slice from 'Mid_', MIP from 'MIP_').
# NOTE: this transposes in place, so run this cell only once after loading/calculating the images.
for c, row in notebook.tqdm(Data.iterrows(), desc='Transpose lateral and dorsoventral images', total=len(Data)):
    for direction in notebook.tqdm(directions[1:],
                                   desc=row['Sample'],
                                   leave=False,
                                   total=len(directions[1:])):
        Data.at[c, 'Mid_' + direction] = Data.at[c, 'Mid_' + direction].transpose()
        Data.at[c, 'MIP_' + direction] = Data.at[c, 'MIP_' + direction].transpose()

In [None]:
# Detect the 'center' of the fish
# For this we select the sagittal center between the otoliths

In [None]:
def otolither(img, sigma=5, threshold=180, verbose=False):
    '''
    Detect the otoliths in a MIP by looking for peaks in the gray values.

    The image is Gaussian-smoothed and thresholded. The fraction of
    above-threshold pixels is then projected onto both image axes, and the
    peaks of these two profiles mark candidate otolith positions.

    Returns [column peak indices, row peak indices] as two plain lists.
    '''
    # Smooth image for less noise, then keep only the bright pixels
    bright = scipy.ndimage.gaussian_filter(img, sigma=sigma, order=0) > threshold
    # Fraction of bright pixels per column (x) and per row (y)
    profile_x, profile_y = bright.mean(axis=0), bright.mean(axis=1)
    peaks_x = scipy.signal.find_peaks(profile_x)[0]
    peaks_y = scipy.signal.find_peaks(profile_y)[0]
    if verbose:
        plt.imshow(img)
        plt.imshow(numpy.ma.masked_less(img, threshold), cmap='viridis', alpha=0.618)
        for peak in peaks_x:
            plt.axvline(peak, alpha=0.5)
        for peak in peaks_y:
            plt.axhline(peak, alpha=0.5)
        plt.axvline(numpy.mean(peaks_x))
        plt.axhline(numpy.mean(peaks_y))
        plt.show()
    return [peaks_x.tolist(), peaks_y.tolist()]

In [None]:
# Estimate the otolith position(s) in every MIP.
# NOTE(review): the head crop further down is computed by headcutter, which does not read these columns.
for direction in directions:
    Data['Otholith_' + direction] = ''
for c, row in notebook.tqdm(Data.iterrows(), desc='Find Otolith position', total=len(Data)):
    for d, direction in notebook.tqdm(enumerate(directions),
                                      desc=row['Sample'],
                                      leave=False,
                                      total=len(directions)):
        projection = row['MIP_' + direction]
        Data.at[c, 'Otholith_' + direction] = otolither(projection, threshold=180, verbose=False)

In [None]:
# Preview the MIPs with the detected otolith peaks (red: columns, green: rows), one figure per direction
for direction in directions:
    for c, row in Data.iterrows():
        plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), c + 1)
        plt.imshow(row['MIP_' + direction], vmax=150)
        plt.gca().add_artist(ScaleBar(row['Voxelsize'], 'um'))
        plt.title(row['Sample'])
        plt.axis('off')
        peaks_x, peaks_y = row['Otholith_' + direction]
        for peak in peaks_x:
            plt.axvline(peak, c='r')
        for peak in peaks_y:
            plt.axhline(peak, c='g')
    plt.show()

In [None]:
def get_oto(curve, verbose=False):
    '''
    Return the position of the maximum of the LOWESS-smoothed `curve`.

    headcutter uses this on the bright-pixel profile of the lateral MIP to
    decide where to crop the head off.
    Smoothing adapted from the 'detect_minima' function from the ZMK tooth cohort
    notebook (https://git.io/J3qqL), ultimately based on
    https://stackoverflow.com/a/28541805/323100 and some manual tweaking
    '''
    from statsmodels.nonparametric.smoothers_lowess import lowess
    positions = range(len(curve))
    smoothed = lowess(curve, positions, return_sorted=False, frac=0.025)
    peak = numpy.argmax(smoothed)
    if verbose:
        plt.plot(curve, alpha=0.6, label='Input curve')
        plt.plot(smoothed, label='LOWESS')
        plt.axvline(peak, c='r', label='Maximum deviation')
        plt.legend()
        plt.show()
    return peak

In [None]:
def headcutter(whichone, sigma=5, threshold=150, verbose=False):
    '''
    Find the position along the fish at which the head is cropped off.

    Takes the lateral MIP of sample `whichone` from the global `Data`, smooths
    and thresholds it, and computes the fraction of bright pixels per column.
    The cut is the maximum of this (LOWESS-smoothed) profile, see get_oto.
    '''
    img = Data['MIP_Lateral'][whichone]
    bright = scipy.ndimage.gaussian_filter(img, sigma=sigma, order=0) > threshold
    profile = numpy.mean(bright, axis=0)
    cut = get_oto(profile)
    if verbose:
        plt.imshow(img, vmax=150)
        plt.plot(img.shape[0] / profile.max() * 0.618 * profile, c='r')
        plt.axvline(cut)
        plt.axis('off')
        plt.show()
    return cut

In [None]:
# Position where we crop the head off, for every sample
Data['HeadCrop'] = list(map(headcutter, range(len(Data))))

In [None]:
def get_minimum(curve, verbose=False):
    '''
    Return the position of the steepest increase of the LOWESS-smoothed `curve`
    (the argmax of its first difference).

    tailcutter uses this to decide where to crop the tail off.
    Adapted from the 'detect_minima' function from the ZMK tooth cohort notebook
    (https://git.io/J3qqL), ultimately based on
    https://stackoverflow.com/a/28541805/323100 and some manual tweaking
    '''
    from statsmodels.nonparametric.smoothers_lowess import lowess
    smoothed = lowess(curve, range(len(curve)), return_sorted=False, frac=0.025)
    steepest = numpy.diff(smoothed).argmax()
    if verbose:
        plt.plot(curve, alpha=0.6, label='Input curve')
        plt.plot(smoothed, label='LOWESS')
        plt.axvline(steepest, c='r', label='Maximum deviation')
        plt.legend()
        plt.show()
    return steepest

In [None]:
def tailcutter(whichone, start=2000, sigma=11, verbose=False):
    '''
    Find the position along the fish at which the tail is cropped off.

    Smooths the lateral MIP of sample `whichone` (from the global `Data`),
    averages it per column and searches the steepest increase of this profile
    (see get_minimum), ignoring the first `start` columns.
    Returns the position in full-image coordinates.
    '''
    img = Data['MIP_Lateral'][whichone]
    profile = numpy.mean(scipy.ndimage.gaussian_filter(img, sigma=sigma, order=0), axis=0)
    # Only search in the tail part of the fish, then shift back to image coordinates
    cut = start + get_minimum(profile[start:])
    if verbose:
        plt.imshow(img, vmax=150)
        plt.plot(img.shape[0] / profile.max() * 0.618 * profile, c='r')
        plt.axvline(cut)
        plt.axis('off')
        plt.show()
    return cut

In [None]:
# Position where we crop the tail off, for every sample
Data['TailCrop'] = list(map(tailcutter, range(len(Data))))

In [None]:
# Show the lateral MIPs with the head and tail crop positions as red lines
for c, row in Data.iterrows():
    plt.imshow(row['MIP_Lateral'])
    for position in (row.HeadCrop, row.TailCrop):
        plt.axvline(position, c='r')
    plt.title(row.Sample)
    plt.axis('off')
    plt.gca().add_artist(ScaleBar(row['Voxelsize'], 'um'))
    plt.show()

In [None]:
# Crop the reconstructions along the first axis to the range between head and tail crop
ReconstructionsCrop = [reconstruction[start:stop]
                       for reconstruction, start, stop in zip(Reconstructions,
                                                              Data['HeadCrop'],
                                                              Data['TailCrop'])]

In [None]:
# Show every 100th slice along the second axis (starting at 50) of each cropped fish
for c, rec in enumerate(ReconstructionsCrop):
    positions = range(50, rec.shape[1], 100)
    for k, i in enumerate(positions):
        plt.subplot(1, len(positions), k + 1)
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][c], 'um'))
        plt.imshow(rec[:, i, :], vmax=150)
        plt.axis('off')
        plt.title('%s: slice %s' % (Data['Sample'][c], i))
    plt.show()

In [None]:
# Shape of every cropped reconstruction
Data['SizeCrop'] = [cropped.shape for cropped in ReconstructionsCrop]

In [None]:
# Read or calculate the cropped middle slices, put them into the dataframe and save them to disk
for direction in directions:
    Data['Mid_Crop_' + direction] = ''
for c, row in notebook.tqdm(Data.iterrows(), desc='Cropped middle images', total=len(Data)):
    for d, direction in notebook.tqdm(enumerate(directions),
                                      desc='Fish %s' % row['Sample'],
                                      leave=False,
                                      total=len(directions)):
        outfilepath = os.path.join(os.path.dirname(row['Folder']),
                                   '%s.Crop.Middle.%s.png' % (row['Sample'],
                                                              direction))
        # Read/calculate inside the direction loop, so *every* direction is handled
        if os.path.exists(outfilepath):
            Data.at[c, 'Mid_Crop_' + direction] = imageio.imread(outfilepath)
        else:
            # Generate requested axial view; compute it, so we store a numpy array like for the uncropped data
            if 'Anteroposterior' in direction:
                Data.at[c, 'Mid_Crop_' + direction] = ReconstructionsCrop[c][Data['SizeCrop'][c][0] // 2].compute()
            if 'Lateral' in direction:
                Data.at[c, 'Mid_Crop_' + direction] = ReconstructionsCrop[c][:, Data['SizeCrop'][c][1] // 2, :].compute()
            if 'Dorsoventral' in direction:
                Data.at[c, 'Mid_Crop_' + direction] = ReconstructionsCrop[c][:, :, Data['SizeCrop'][c][2] // 2].compute()
            # Save the calculated 'direction' view to disk
            imageio.imwrite(outfilepath, Data.at[c, 'Mid_Crop_' + direction])

In [None]:
# Read or calculate the cropped directional MIPs, put them into the dataframe and save them to disk
# NOTE(review): as for the uncropped MIPs, axis=-d projects along axes 0, 2, 1 -- confirm this is intended.
for direction in directions:
    Data['MIP_Crop_' + direction] = ''
for c, row in notebook.tqdm(Data.iterrows(), desc='Calculating cropped MIPs', total=len(Data)):
    for d, direction in notebook.tqdm(enumerate(directions),
                                      desc=row['Sample'],
                                      leave=False,
                                      total=len(directions)):
        outfilepath = os.path.join(os.path.dirname(row['Folder']), '%s.Crop.MIP.%s.png' % (row['Sample'], direction))
        if os.path.exists(outfilepath):
            Data.at[c, 'MIP_Crop_' + direction] = imageio.imread(outfilepath)
        else:
            # Generate MIP; compute it, so the dataframe holds a numpy array and not a lazy dask array
            Data.at[c, 'MIP_Crop_' + direction] = ReconstructionsCrop[c].max(axis=-d).compute()
            # Save it out
            imageio.imwrite(outfilepath, Data.at[c, 'MIP_Crop_' + direction].astype('uint8'))

In [None]:
# Gray value histograms (256 bins over 0..255) of all cropped reconstructions.
# dask.array.histogram returns (counts, bins); only the computed counts are kept.
Data['Histogram'] = [dask.array.histogram(rec, bins=2**8, range=[0, 2**8])[0].compute()
                     for rec in ReconstructionsCrop]

In [None]:
def histogramclusterer(img, number_of_clusters = 5, verbose=False):
    '''
    Cluster the gray values of `img` with k-means and return the sorted cluster centers.

    Speed things up with MiniBatchKMeans, see
    https://jakevdp.github.io/PythonDataScienceHandbook/05.11-k-means.html
    To speed things up further, pass a subset of the image (e.g. every n-th slice).

    img: array-like (numpy or dask) of gray values
    number_of_clusters: number of k-means clusters
    verbose: plot the logarithmic histogram together with the cluster centers
    Returns a sorted list of the cluster centers (gray values).
    '''
    # Every pixel is one 1D sample. Sorting the pixels beforehand (in Python!) only
    # costs a lot of time: the fitted centers are sorted afterwards anyway.
    graylevels = numpy.asarray(img).reshape(-1, 1)
    kmeans_volume_subset = sklearn.cluster.MiniBatchKMeans(n_clusters=number_of_clusters)
    kmeans_volume_subset.fit(graylevels)
    # ravel (and not squeeze), so that a single cluster still gives an iterable
    centers = sorted(kmeans_volume_subset.cluster_centers_.ravel())
    if verbose:
        histogram, _ = numpy.histogram(graylevels, bins=2**8, range=[0, 2**8])
        plt.semilogy(histogram, label='Gray value histogram (log)')
        for c, center in enumerate(centers):
            plt.axvline(center, label='Cluster center %s at %0.0f' % (c, center),
                        color=seaborn.color_palette(n_colors=number_of_clusters)[c])
        plt.legend()
        plt.xlim([0, 2**8])
        plt.title('Logarithmic histogram of input image with %s cluster centers' % number_of_clusters)
        plt.show()
    return centers

In [None]:
# k-means cluster centers of the gray values of every 100th cropped slice
Data['ClusterCenters'] = [histogramclusterer(reconstruction[::100]) for reconstruction in ReconstructionsCrop]

In [None]:
# Histograms per experiment, with the k-means cluster centers of each sample as vertical lines
experiments = Data.Experiment.unique()
for e, experiment in enumerate(experiments):
    # One subplot row per experiment; the sample loop below uses its own counter `c`
    plt.subplot(len(experiments), 1, e + 1)
    plt.title(experiment)
    for c, row in Data[Data.Experiment == experiment].iterrows():
        color = seaborn.color_palette(n_colors=len(Data))[c]
        plt.semilogy(row.Histogram,
                     label=row.Sample,
                     color=color)
        for cc in row.ClusterCenters:
            plt.axvline(cc,
                        color=color,
                        alpha=.616)
    plt.xlim([0, 2**8])
    plt.legend()
plt.savefig(os.path.join(OutPutDir, 'Histograms.Experiment.ClusterCenters.png'))
plt.show()

In [None]:
# All histograms, colored per experiment (WT: second palette color, everything else: first)
palette = seaborn.color_palette(n_colors=2)
for c, row in Data.iterrows():
    color = 1 if row.Experiment == 'WT' else 0
    plt.semilogy(row.Histogram, label=row.Sample, color=palette[color])
plt.xlim([0, 2**8])
plt.legend()
plt.savefig(os.path.join(OutPutDir, 'Histograms.Experiment.png'))
plt.show()

In [None]:
# Show every `iterator`-th cropped slice of each sample, overlaid with each k-means cluster center used as threshold
iterator = 500
for c, clustercntr in enumerate(Data['ClusterCenters']):
    print('-----Sample %s----' % Data['Sample'][c])
    # File names of the shown slices, taken from *this* sample's row (only the slices between the crops)
    slicenames = Data['Reconstructions'][c][Data['HeadCrop'][c]:Data['TailCrop'][c]][::iterator]
    for imgnr, image in enumerate(ReconstructionsCrop[c][::iterator]):
        # Load the slice once; it is shown twice for every threshold
        image = numpy.asarray(image)
        for d, threshold in enumerate(clustercntr):
            plt.subplot(1, len(clustercntr), d + 1)
            plt.imshow(image)
            plt.imshow(image < threshold, cmap='viridis', alpha=0.5)
            plt.gca().add_artist(ScaleBar(Data['Voxelsize'][c], 'um'))
            plt.title('%s\nThreshold %s' % (os.path.basename(slicenames[imgnr]),
                                            round(threshold, 2)))
            plt.axis('off')
        plt.show()

In [None]:
# Median filter (5 px kernel) on the cropped reconstructions, lazily via dask_image
ReconstructionsMedian = list(map(lambda rec: dask_image.ndfilters.median_filter(rec, size=5),
                                 ReconstructionsCrop))

In [None]:
# Write the median-filtered reconstructions to '<sample>_rec_median.zarr' next to each reconstruction folder,
# unless such a zarr array already exists
Data['OutputNameMedian'] = [os.path.join(os.path.dirname(folder), '%s_rec_median.zarr' % sample)
                            for folder, sample in zip(Data['Folder'], Data['Sample'])]
for c, row in notebook.tqdm(Data.iterrows(),
                            desc='Saving out median-filtered recs to .zarr',
                            total=len(Data)):
    if os.path.exists(row['OutputNameMedian']):
        continue
    compressor = Blosc(cname='zstd', clevel=9, shuffle=Blosc.BITSHUFFLE)
    ReconstructionsMedian[c].rechunk(chunks=200).to_zarr(row['OutputNameMedian'],
                                                         overwrite=True,
                                                         compressor=compressor)

In [None]:
# Load the median-filtered reconstructions back in (lazily) from their zarr arrays
ReconstructionsMedian = list(map(dask.array.from_zarr, Data['OutputNameMedian']))

In [None]:
# Gray value histograms (256 bins over 0..255) of the median-filtered reconstructions; only the counts are kept
Data['HistogramMedian'] = [dask.array.histogram(rec, bins=2**8, range=[0, 2**8])[0].compute()
                           for rec in ReconstructionsMedian]

In [None]:
# Show what we did there: one slice (top: median-filtered, bottom: original) with its logarithmic histogram
whichsample = 3
whichslice = 800
panels = [(ReconstructionsMedian, 'median-filtered rec', 'Logarithmic histogram'),
          (ReconstructionsCrop, 'original rec', 'Logarithmic histogram of original rec')]
for p, (recs, name, histtitle) in enumerate(panels):
    image = recs[whichsample][whichslice]
    plt.subplot(2, 2, 2 * p + 1)
    plt.imshow(image)
    plt.axis('off')
    plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichsample], 'um'))
    plt.title('%s, %s %s' % (Data.Sample[whichsample], name, whichslice))
    plt.subplot(2, 2, 2 * p + 2)
    plt.semilogy(dask.array.histogram(image, bins=2**8, range=[0, 2**8])[0])
    plt.title(histtitle)
plt.show()

In [None]:
# k-means cluster centers of the gray values of every 111th median-filtered slice
Data['ClusterCentersMedian'] = [histogramclusterer(median[::111]) for median in ReconstructionsMedian]

In [None]:
# Histograms of the median-filtered data per experiment, with the k-means cluster centers as vertical lines
experiments = Data.Experiment.unique()
for e, experiment in enumerate(experiments):
    # One subplot row per experiment; the sample loop below uses its own counter `c`
    plt.subplot(len(experiments), 1, e + 1)
    plt.title(experiment)
    for c, row in Data[Data.Experiment == experiment].iterrows():
        color = seaborn.color_palette(n_colors=len(Data))[c]
        plt.semilogy(row.HistogramMedian,
                     label=row.Sample,
                     color=color)
        for cc in row.ClusterCentersMedian:
            plt.axvline(cc,
                        color=color,
                        alpha=.616)
    plt.xlim([0, 2**8])
    plt.legend()
plt.savefig(os.path.join(OutPutDir, 'Histograms.Median.Experiment.ClusterCenters.png'))
plt.show()

In [None]:
# All median-filtered histograms, colored per experiment (WT: second palette color, everything else: first)
palette = seaborn.color_palette(n_colors=2)
for c, row in Data.iterrows():
    color = 1 if row.Experiment == 'WT' else 0
    plt.semilogy(row.HistogramMedian, label=row.Sample, color=palette[color])
plt.xlim([0, 2**8])
plt.legend()
plt.savefig(os.path.join(OutPutDir, 'Histograms.Median.Experiment.png'))
plt.show()

In [None]:
# Peaks of each (unfiltered) histogram with a prominence of at least 777 counts,
# ignoring peaks below gray value 23. We use them for the segmentation afterwards.
Data['Peaks'] = [peaks[peaks >= 23]
                 for peaks, _ in (scipy.signal.find_peaks(h, prominence=[777, None])
                                  for h in Data['Histogram'])]

In [None]:
# Show the histogram peaks found for every sample
Data['Peaks']

In [None]:
# Original and median-filtered histogram of each sample, with the histogram peaks used for the segmentation
lines = 2
palette = seaborn.color_palette(n_colors=3)
for c, row in Data.iterrows():
    plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), c + 1)
    # Distinct labels, so the legend tells the two curves apart
    plt.semilogy(row.Histogram,
                 label='%s original' % row.Sample,
                 color=palette[0])
    plt.semilogy(row.HistogramMedian,
                 label='%s median' % row.Sample,
                 color=palette[1])
    # Plot them peaks (found on the original histogram, hence the same color)
    for p in row.Peaks:
        plt.axvline(p, label=p, color=palette[0])
    plt.xlim([0, 2**8])
    plt.legend()
plt.show()

In [None]:
# Write out the median-filtered slices as PNGs into a 'rec_median' folder next to the reconstruction folder.
# Only the slices between the head and tail crop exist in the median-filtered data.
# NOTE(review): str.replace swaps *every* 'rec' in the full path (folder and file name, cf. '*rec0*.png');
# any other 'rec' in the path would be replaced too.
for c, row in notebook.tqdm(Data.iterrows(),
                            desc='Writing median-filtered reconstructions',
                            total=len(Data)):
    os.makedirs(os.path.join(os.path.dirname(row.Folder), 'rec_median'), exist_ok=True)
    croppednames = row.Reconstructions[row.HeadCrop:row.TailCrop]
    for d, name in notebook.tqdm(enumerate(croppednames),
                                 total=len(ReconstructionsMedian[c]),
                                 leave=False):
        filename = name.replace('rec', 'rec_median')
        if os.path.exists(filename):
            continue
        imageio.imwrite(filename, ReconstructionsMedian[c][d])

In [None]:
def segmentor(image, peaks, verbose=False):
    '''
    Segment a slice with the random walker algorithm.
    https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_random_walker_segmentation.html

    Pixels darker than the first histogram peak are seeded with label 1, brighter
    ones with label 2; pixels exactly at the peak value stay unlabeled and are
    decided by the random walker.

    Returns an 8 bit image: 0 for label 1, 255 for label 2.
    Raises IndexError if `peaks` is empty (the caller catches this and writes out
    the unsegmented slice instead).
    '''
    # Fetch the threshold first, so an empty `peaks` fails before any data is read
    threshold = peaks[0]
    # Materialize the slice once; it is used several times below and may be a lazy dask array
    image = numpy.asarray(image)
    markers = numpy.zeros_like(image, dtype=numpy.uint)
    markers[image < threshold] = 1
    markers[image > threshold] = 2
    labels = random_walker(image, markers, beta=500)
    # 'Scale' labels 1/2 to 0/1
    labels = labels - 1
    if verbose:
        plt.figure()
        plt.subplot(131)
        plt.imshow(image)
        plt.axis('off')
        plt.subplot(132)
        plt.imshow(markers)
        plt.axis('off')
        plt.subplot(133)
        plt.imshow(labels)
        plt.axis('off')
        plt.show()
    # Force image return as 8bit
    return numpy.array(labels * (2**8 - 1), dtype='uint8')

In [None]:
# Random-walker-segment the median-filtered slices and write them as PNGs into a 'rec_median_segmented' folder.
# Only the slices between the head and tail crop exist in the median-filtered data.
for c, row in notebook.tqdm(Data.iterrows(), desc='Writing random-walker segmentation', total=len(Data)):
    os.makedirs(os.path.join(os.path.dirname(row.Folder), 'rec_median_segmented'), exist_ok=True)
    croppednames = row.Reconstructions[row.HeadCrop:row.TailCrop]
    for d, name in notebook.tqdm(enumerate(croppednames),
                                 desc=row.Sample,
                                 total=len(ReconstructionsMedian[c]),
                                 leave=False):
        filename = name.replace('rec', 'rec_median_segmented')
        if os.path.exists(filename):
            continue
        try:
            imageio.imwrite(filename, segmentor(ReconstructionsMedian[c][d], row.Peaks))
        except IndexError:
            # If we're missing a peak, write out the median-filtered data and deal with it later
            imageio.imwrite(filename, ReconstructionsMedian[c][d])

In [None]:
# Read in the segmented slices and save them as one zarr array per sample
Data['OutputNameSegmented'] = [os.path.join(os.path.dirname(f),
                                            fish + '_rec_median_segmented.zarr') for f, fish in zip(Data['Folder'],
                                                                                                    Data['Sample'])]
for c, row in notebook.tqdm(Data.iterrows(),
                            desc='Converting segmented slices to .zarr',
                            total=len(Data)):
    if not os.path.exists(row['OutputNameSegmented']):
        # Only the slices between head and tail crop were segmented, so report that number
        number_of_segmented = len(row['Reconstructions'][row['HeadCrop']:row['TailCrop']])
        print('%2s/%2s: Reading %s slices and saving to %s' % (c + 1,
                                                               len(Data),
                                                               number_of_segmented,
                                                               row['OutputNameSegmented'][len(Root)+1:]))
        Segmented = dask_image.imread.imread(os.path.join(row.Folder.replace('rec', 'rec_median_segmented'),
                                                          '*rec*.png'))
        Segmented.rechunk(chunks=200).to_zarr(row['OutputNameSegmented'],
                                              overwrite=True,
                                              compressor=Blosc(cname='zstd',
                                                               clevel=9,
                                                               shuffle=Blosc.BITSHUFFLE))

In [None]:
# Load the segmented stacks (lazily) from their zarr arrays
Segmented = list(map(dask.array.from_zarr, Data['OutputNameSegmented']))

In [None]:
def get_largest_region(segmentation, verbose=False):
    """
    Return a boolean mask of the largest connected component of `segmentation`.

    Regions are found with `skimage.measure.label`, and the biggest labelled region
    (by element count, background excluded) is kept.
    Adapted from https://stackoverflow.com/a/55110923/323100, also used in
    EAWAG/ExtractOtoliths.ipynb.

    Parameters
    ----------
    segmentation : array-like
        Image or volume to label, e.g. a thresholded (boolean) data set.
    verbose : bool, optional
        If True, show the input next to the extracted component. Inputs with more
        than two dimensions are shown as a maximum projection along their leading axes.

    Returns
    -------
    numpy.ndarray of bool
        True for every element that belongs to the largest connected component.

    Raises
    ------
    ValueError
        If `segmentation` contains no foreground at all.
    """
    # Make sure the submodule is loaded; a bare `import skimage` does not guarantee that
    import skimage.measure

    labels = skimage.measure.label(segmentation)
    # Explicit check instead of `assert`, which disappears when Python runs with -O
    if labels.max() == 0:
        raise ValueError('No connected component found: the segmentation is empty')
    # bincount()[0] counts the background (label 0); skip it and add 1 to get back to the label
    largestCC = labels == numpy.argmax(numpy.bincount(labels.ravel())[1:]) + 1
    if verbose:
        def _to_2d(image):
            # imshow can only show 2D data, so collapse volumes with a maximum projection
            image = numpy.asarray(image)
            while image.ndim > 2:
                image = image.max(axis=0)
            return image
        plt.subplot(121)
        plt.imshow(_to_2d(segmentation))
        plt.subplot(122)
        plt.imshow(_to_2d(largestCC))
        plt.suptitle('Largest connected component')
        plt.show()
    return largestCC

In [None]:
# Quick look: binarize slice 444 of ReconstructionsCrop[3] at a grey value of 33
img = ReconstructionsCrop[3][444]
plt.imshow(numpy.greater(img, 33))

In [None]:
# Threshold one segmented sample and keep only its largest connected component.
# This is for inspection only; the result `a` is not written to disk here.
whichsample = 9
t = 130
a = get_largest_region(Segmented[whichsample] > t, verbose=False)

In [None]:
# Compare one slice of the cropped reconstruction, its segmentation and the largest component.
# The slice index is not called `slice`: that would shadow the built-in slice() for the rest of the notebook.
displayslice = 999
plt.subplot(131)
plt.imshow(ReconstructionsCrop[whichsample][displayslice])
plt.subplot(132)
plt.imshow(Segmented[whichsample][displayslice])
plt.subplot(133)
plt.imshow(a[displayslice])
plt.show()

In [None]:
# Let's see what we did there
whichsample = 0
whichslice = 100

# Evaluate the slice once, instead of recomputing the dask graph for every number below
segslice = Segmented[whichsample][whichslice].compute()

# Show the slice, overlaid in color with everything that is *not* 255 (255 is masked out)
plt.imshow(segslice)
plt.imshow(numpy.ma.masked_equal(segslice, 255), alpha=0.309, cmap='viridis_r')

# Output counts: segmented pixels are the ones equal to 255, background pixels the ones equal to 0.
# (Summing the data with the zeros masked adds up grey values, i.e. 255 per segmented pixel;
# it is not a background pixel count.)
segmented_px = (segslice == 255).sum()
total_px = segslice.shape[0] * segslice.shape[1]
print('Background: %s px' % (segslice == 0).sum())
print('Segmented: %s px' % segmented_px)
print('Image size: %s px x %s px = '
      '%s px - %s px segmented = '
      '%s px background' % (segslice.shape[0],
                            segslice.shape[1],
                            total_px,
                            segmented_px,
                            total_px - segmented_px))
# So counting the pixels equal to 255 measures the segmented area/volume :)
# So we can 'just' sum the masked segmented data correctly :)

In [None]:
# Count the segmented voxels (value 255) of every fish.
# The lazy `== 255` reduction gives the same count as masking the 255s and summing the mask,
# but it does not load each full volume (plus a same-sized mask) into memory.
Data['SegmentedVolume'] = [(s == 255).sum().compute() for s in Segmented]

In [None]:
Data['VolumeCrop'] = [d0 * d1 * d2 for (d0, d1, d2) in Data['SizeCrop']]  # voxel count of each cropped data set

In [None]:
# Quick check: size of the first (slice) axis of the first cropped reconstruction
ReconstructionsCrop[0].shape[0]

In [None]:
# Normalize the segmented volume of each fish ...
# ... to the voxel count of its cropped data set (i.e. the segmented fraction of all voxels)
Data['SegmentedVolume_normalized_vol'] = [vol_seg / vol_data for vol_seg, vol_data in zip(Data['SegmentedVolume'],
                                                                                          Data['VolumeCrop'])]
# ... and to the cut (cropped) length of the fish, i.e. segmented voxels per slice.
# Use 'SizeCrop', like 'VolumeCrop' does: the segmentation is overlaid slice by slice on
# ReconstructionsCrop, so it spans the cropped length, not the uncropped 'Size'.
# Assumes the first entry of 'SizeCrop' is the number of slices — TODO confirm.
Data['SegmentedVolume_normalized_length'] = [vol_seg / zxy[0] for vol_seg, zxy in zip(Data['SegmentedVolume'],
                                                                                      Data['SizeCrop'])]

In [None]:
# Convert voxel counts to cubic millimeters.
# 'Voxelsize' is in µm (the scale bars use it with 'um'), and 1 mm³ = 1e9 µm³, hence the factor 1e-9.
# NOTE(review): 'SegmentedVolume_normalized_vol' is a dimensionless ratio of voxel counts, so its
# "_mm" column has no physical unit; 'SegmentedVolume_normalized_length' is voxels per slice, so its
# "_mm" column is mm³ per slice. Confirm these columns are meant this way.
Data['SegmentedVolume_mm'] = [
    voxels * voxelsize ** 3 * 1e-9
    for voxels, voxelsize in zip(Data['SegmentedVolume'], Data['Voxelsize'])
]
Data['SegmentedVolume_normalized_vol_mm'] = [
    voxels * voxelsize ** 3 * 1e-9
    for voxels, voxelsize in zip(Data['SegmentedVolume_normalized_vol'], Data['Voxelsize'])
]
Data['SegmentedVolume_normalized_length_mm'] = [
    voxels * voxelsize ** 3 * 1e-9
    for voxels, voxelsize in zip(Data['SegmentedVolume_normalized_length'], Data['Voxelsize'])
]

In [None]:
# Show only the rows whose experiment is 'WT'
Data.loc[Data['Experiment'] == 'WT']

In [None]:
# List the volume-normalized segmented volume of every fish, grouped by experiment.
# (No enumerate needed: the counter was never used.)
for exp in Data.Experiment.unique():
    print(exp)
    for d, row in Data[Data.Experiment == exp].iterrows():
        print(d, row.SegmentedVolume_normalized_vol)

In [None]:
# Segmented volume per experiment: box plot, with every fish drawn on top and labeled with its sample name
seaborn.boxplot(data=Data, x='Experiment', y='SegmentedVolume_mm', saturation=1)
seaborn.swarmplot(data=Data, x='Experiment', y='SegmentedVolume_mm', s=25, linewidth=2)
# Assumes seaborn places the categories at x = 0, 1, ... in order of first appearance, like unique()
for c, exp in enumerate(Data.Experiment.unique()):
    subset = Data[Data.Experiment == exp]
    for sample, volume in zip(subset.Sample, subset.SegmentedVolume_mm):
        plt.text(c, volume, sample)
plt.ylabel('Segmented volume [mm³]')
plt.savefig(os.path.join(OutPutDir, 'SegmentedVolume.png'))
plt.show()

In [None]:
# Segmented volume normalized to the cropped data volume, per experiment, each fish labeled by sample name
seaborn.boxplot(data=Data, x='Experiment', y='SegmentedVolume_normalized_vol', saturation=1)
seaborn.swarmplot(data=Data, x='Experiment', y='SegmentedVolume_normalized_vol', s=25, linewidth=2)
# Assumes seaborn places the categories at x = 0, 1, ... in order of first appearance, like unique()
for c, exp in enumerate(Data.Experiment.unique()):
    subset = Data[Data.Experiment == exp]
    for sample, fraction in zip(subset.Sample, subset.SegmentedVolume_normalized_vol):
        plt.text(c, fraction, sample)
plt.ylabel('Segmented volume, normalized to data volume')
plt.savefig(os.path.join(OutPutDir, 'SegmentedVolume.Normalized.Volume.png'))
plt.show()

In [None]:
# Segmented volume normalized to fish length (voxels per slice), per experiment, each fish labeled by sample name
seaborn.boxplot(data=Data,
                x='Experiment',
                y='SegmentedVolume_normalized_length',
                saturation=1)
seaborn.swarmplot(data=Data,
                  x='Experiment',
                  y='SegmentedVolume_normalized_length',
                  s=25,
                  linewidth=2)
for c, exp in enumerate(Data.Experiment.unique()):
    for d, row in Data[Data.Experiment == exp].iterrows():
        plt.text(c, row.SegmentedVolume_normalized_length, row.Sample)
# This plot shows the length normalization, not the data-volume normalization
plt.ylabel('Segmented volume, normalized to fish length [voxels per slice]')
plt.savefig(os.path.join(OutPutDir, 'SegmentedVolume.Normalized.Length.png'))
plt.show()

In [None]:
# Mean (not normalized) segmented volume of each experimental group
for experiment in Data.Experiment.unique():
    group_mean = Data.loc[Data.Experiment == experiment, 'SegmentedVolume_mm'].mean()
    print('The %s fishes have a mean segmented volume (not normalized) of %7.3f mm³'
          % (experiment, group_mean))

In [None]:
# Tell us which fish is the median one
# https://stackoverflow.com/a/61047899/323100
# Within each experiment, report the fish whose volume is closest to the group median.
# Testing for equality with the median only finds a fish for odd group sizes (for even ones the
# median is the mean of two fishes and `.iloc[0]` raises IndexError), and searching all of Data
# could return a fish from another experiment with the same value.
for experiment in Data.Experiment.unique():
    volumes = Data.loc[Data.Experiment == experiment, 'SegmentedVolume_mm']
    medianfish = (volumes - volumes.median()).abs().idxmin()
    print('The median fish of the %s fishes is fish %s (df index %s) and has a volume of %7.3f mm³'
          % (experiment,
             Data.loc[medianfish, 'Sample'],
             medianfish,
             Data.loc[medianfish, 'SegmentedVolume_mm']))

In [None]:
# Overview table: experiment, sample and segmented volume in mm³
Data.loc[:, ['Experiment', 'Sample', 'SegmentedVolume_mm']]

In [None]:
# Mean segmented volume over all fishes, regardless of experiment
print('All {} fishes have a mean segmented volume of {:0.2f} mm³'.format(len(Data),
                                                                        Data['SegmentedVolume_mm'].mean()))

In [None]:
# Mean length-normalized segmented volume (voxels per slice) of each experimental group.
# `%6.0f` does the rounding itself; wrapping the mean in int() truncated it (e.g. 99.9 -> 99)
# and raised ValueError for an all-NaN group.
for experiment in Data.Experiment.unique():
    print('The %s fishes have a mean segmented volume (normalized to length) of %6.0f voxels'
          % (experiment, Data[Data.Experiment == experiment]['SegmentedVolume_normalized_length'].mean()))

In [None]:
# Tell us which fish is the median one (length-normalized volume)
# https://stackoverflow.com/a/61047899/323100
# Within each experiment, report the fish closest to the group median. An exact match with the
# median only exists for odd group sizes. The quantity is in voxels per slice, not mm³.
for experiment in Data.Experiment.unique():
    normalized = Data.loc[Data.Experiment == experiment, 'SegmentedVolume_normalized_length']
    medianfish = (normalized - normalized.median()).abs().idxmin()
    print('The median fish of the %s fishes is fish %s and has a volume (normalized to length) of %6.0f voxels'
          % (experiment,
             Data.loc[medianfish, 'Sample'],
             Data.loc[medianfish, 'SegmentedVolume_normalized_length']))

In [None]:
# Use wt05 and k003 for the visualization

In [None]:
# Overview table: experiment, sample and length-normalized segmented volume
Data.loc[:, ['Experiment', 'Sample', 'SegmentedVolume_normalized_length']]

In [None]:
# for i in Data:
#     print(i)

In [None]:
# Write an Excel sheet for Carolina: one copy in the current working directory, one in OutPutDir
Output = Data[['Sample', 'Folder', 'LogFile', 'Experiment', 'Fish',
               'Voxelsize', 'Number of reconstructions', 'OutputNameRec',
               'Size', 'HeadCrop', 'TailCrop', 'SizeCrop', 'VolumeCrop',
               'Peaks',
               'SegmentedVolume', 'SegmentedVolume_normalized_vol', 'SegmentedVolume_normalized_length',
               'SegmentedVolume_mm', 'SegmentedVolume_normalized_vol_mm', 'SegmentedVolume_normalized_length_mm']]
Output.to_excel('Data.xlsx')
# Also .xlsx in the output folder: writing legacy .xls needs the xlwt engine,
# which pandas deprecated in 1.2 and no longer supports since 2.0
Output.to_excel(os.path.join(OutPutDir, 'Data.xlsx'))

In [None]:
print('Saved all the asked data to {}'.format(OutPutDir))  # tell where the plots and sheets went

#### 

In [None]:
# Deliberate stop so that "Run All" ends here; run the exploratory cells below by hand.
# An explicit raise (instead of a syntax error) keeps this file valid Python.
raise RuntimeError('Stopping here on purpose; run the cells below manually')

In [None]:
# Does that make sense? Every 250 slices, overlay the segmentation on the cropped
# reconstruction of every fish and save one figure per slice.
# The per-fish lists are indexed by position (the order of Data's rows), which does not rely
# on Data having a default 0..n-1 index.
for whichslice in range(250, 2222, 250):
    for position, (_, row) in enumerate(Data.iterrows()):
        plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), position + 1)
        plt.axis('off')
        # Shorter fishes may not reach this slice: leave their panel empty instead of crashing
        if whichslice >= min(ReconstructionsCrop[position].shape[0], Segmented[position].shape[0]):
            plt.title('%s: no slice %s' % (row.Sample, whichslice))
            continue
        plt.imshow(ReconstructionsCrop[position][whichslice])
        plt.imshow(Segmented[position][whichslice], alpha=0.618, cmap='viridis')
        plt.title('%s: Slice %s' % (row.Sample, whichslice))
        plt.gca().add_artist(ScaleBar(row['Voxelsize'], 'um'))
    plt.tight_layout(h_pad=0.5, w_pad=0.5)
    plt.savefig(os.path.join(OutPutDir, 'SegmentedSlices%04d.png' % whichslice))
    plt.show()

In [None]:
# Deliberate stop: the cells below are scratch work and should not run with "Run All".
# An explicit raise (instead of a syntax error) keeps this file valid Python.
raise RuntimeError('Stopping here on purpose; the cells below are scratch work')

In [None]:
import itkwidgets
from itkwidgets import view  # 3d viewer

In [None]:
# Interactive 3D view (itkwidgets) of the first entry of `Reconstructions`
view(Reconstructions[0])

In [None]:
# Scratch: region growing with SimpleITK's ConnectedThreshold from the seed point (600, 200),
# with intensity bounds lower=10 and upper=80.
# NOTE(review): neither `sitk` (presumably `import SimpleITK as sitk`) nor `img_T1` is defined in
# this notebook as far as visible, so this cell fails as-is — TODO import SimpleITK and define the input image.
seg_explicit_thresholds = sitk.ConnectedThreshold(img_T1,
                                                  seedList=[(600,200)],
                                                  lower=10,
                                                  upper=80)

In [None]:
# Scratch: write the region-growing result to 'out.png' in the current working directory.
# NOTE(review): needs `sitk` to be imported and `seg_explicit_thresholds` to exist (ConnectedThreshold cell).
writer = sitk.ImageFileWriter()
writer.SetFileName('out.png')
writer.Execute(seg_explicit_thresholds)

In [None]:
# Scratch: display the region-growing result.
# NOTE(review): `seg_explicit_thresholds` is a SimpleITK image; matplotlib may need a NumPy array
# (e.g. sitk.GetArrayFromImage(...)) to show it — confirm this displays as intended.
plt.imshow(seg_explicit_thresholds)