# Separate the scans

We scanned multiple fish in [a special sample holder](https://github.com/TomoGraphics/Hol3Drs/blob/master/STL/Stickleback.Multiple.stl).
This notebook is used to separate them into different bunches of reconstructions.

The cells below are used to set up the whole notebook.
They load needed libraries and set some default values.

In [None]:
# Load the modules we need
import platform
import os
import glob
import pandas
import imageio
import nrrd
import numpy
import random
import matplotlib
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
import seaborn
import dask
import dask_image.imread
from dask.distributed import Client, LocalCluster
import skimage
from tqdm.auto import tqdm, trange
from joblib import Parallel, delayed

In [None]:
# Load our own log file parsing code
# This is loaded as a submodule to alleviate excessive copy-pasting between *all* projects we do
# See https://github.com/habi/BrukerSkyScanLogfileRuminator for details on its inner workings
from BrukerSkyScanLogfileRuminator.parsing_functions import *

In [None]:
# # Code linting
# %load_ext pycodestyle_magic

In [None]:
# %pycodestyle_on

In [None]:
# Point dask at a temporary folder of our choosing
# This has to happen *before* creating a client: https://stackoverflow.com/a/62804525/323100
# Where it is available, we use the fast internal SSD for speed reasons
import tempfile
operating_system = platform.system()
fast_ssd = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')
if 'Linux' in operating_system and os.path.exists(fast_ssd):
    # The FastSSD is mounted, so we use a folder on it
    tmp = os.path.join(fast_ssd, 'tmp')
elif 'Linux' in operating_system or 'Darwin' in operating_system:
    # No FastSSD mounted (or we are on macOS): go to the standard temporary folder
    tmp = tempfile.gettempdir()
elif 'anaklin' in platform.node():
    # Everything else: the drive to use depends on the name of the machine
    tmp = os.path.join('F:\\tmp')
else:
    tmp = os.path.join('D:\\tmp')
dask.config.set({'temporary_directory': tmp})
print('Dask temporary files go to %s' % dask.config.get('temporary_directory'))

In [None]:
# Start a dask client without any arguments, i.e. with its default settings
# `Client` is already imported in the first cell of the notebook, importing it again here is harmless
from dask.distributed import Client
client = Client()

In [None]:
client

In [None]:
seaborn.set_context('notebook')

In [None]:
# Figure defaults
# Display all images in b&w and with 'nearest' interpolation
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['image.interpolation'] = 'nearest'
# plt.rcParams['figure.figsize'] = (16, 9)  # Size up figures a bit
plt.rc('figure', dpi=200)

In [None]:
# Defaults for the scale bars we add to the images
plt.rcParams.update({'scalebar.location': 'lower right',
                     'scalebar.frameon': False,
                     'scalebar.color': 'white'})

In [None]:
# Display all plots identically
lines = 3
# And then do something like
# plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), c + 1)

Since the (tomographic) data can reside on different drives, we set a folder to use below.

In [None]:
# The data lives in different locations, depending on whether we run on Linux or Windows
# `FastSSD` switches between the two locations that are configured per operating system
FastSSD = True
if 'Linux' in platform.system():
    if FastSSD:
        BasePath = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')
    else:
        BasePath = os.path.join(os.path.sep, 'home', 'habi', '2214')
elif 'Windows' in platform.system():
    if FastSSD:
        BasePath = os.path.join('F:\\')
    else:
        BasePath = os.path.join('N:\\')
else:
    # Without this branch `BasePath` stays undefined (or keeps a stale value from an earlier run)
    # on any other operating system, so we fail early and with a clear message instead
    raise NotImplementedError('No data location is configured for "%s"' % platform.system())
Root = os.path.join(BasePath, 'IEE Stickleback')
print('We are loading all the data from %s' % Root)

Now that we are set up, actually start to load/ingest the data.

In [None]:
# Make us a dataframe for saving all that we need
Data = pandas.DataFrame()

In [None]:
# Collect *all* log files below the root folder
# os.walk is way faster than a recursive glob.glob, see DataWrangling.ipynb for details
# We deliberately do not sort the found log files, which makes it quicker, too
Data['LogFile'] = [os.path.join(folder, filename)
                   for folder, _subfolders, filenames in os.walk(Root)
                   for filename in filenames
                   if filename.endswith('.log')]

In [None]:
# The folder of each scan is the folder its log file is in
Data['Folder'] = Data['LogFile'].map(os.path.dirname)

In [None]:
# Show a (small) sampler of the loaded data as a first check
Data.sample(n=5)

In [None]:
# Get rid of all the logfiles from all the folders that might be on disk but that we don't want to load the data from
# We build one keep/drop mask for the whole dataframe instead of dropping rows from it while iterating over it
keep = pandas.Series([('ucket' in folder  # Only use the scans named Bucket* here
                       and 'rec' in folder  # Only look at logs in the rec folders
                       and '_regions' not in folder  # Exclude all log files that we write in this notebook (to $scan$_region folders)
                       and 'SubScan' not in folder  # Exclude any log files from rsyncing temporary data
                       and 'rectmp.log' not in logfile)  # Exclude any log files from rsyncing temporary data
                      for folder, logfile in zip(Data['Folder'], Data['LogFile'])],
                     index=Data.index,
                     dtype=bool)
# Reset dataframe to something that we would get if we only would have loaded the 'rec' files
Data = Data[keep].reset_index(drop=True)

In [None]:
# Generate us some meaningful columns in the dataframe
# The sample is the name of the log file without its '_rec.log' ending
Data['Sample'] = Data['LogFile'].map(lambda logfile: os.path.basename(logfile).replace('_rec.log', ''))
# The scan is the name of the folder the log file is in
Data['Scan'] = Data['LogFile'].map(lambda logfile: os.path.basename(os.path.dirname(logfile)))

In [None]:
# Does the dataframe look plausible?
Data.tail()

In [None]:
# Collect the (sorted) file names of all the reconstructions of all the scans
Data['Filenames Reconstructions'] = Data['Folder'].map(lambda folder: sorted(glob.glob(os.path.join(folder, '*rec0*.png'))))
# How many reconstructions do we have per scan?
Data['Number of reconstructions'] = Data['Filenames Reconstructions'].map(len)

In [None]:
# Drop samples which have either not been reconstructed yet or of which we deleted the reconstructions
# In the same go, reset the dataframe count/index for easier indexing afterwards
Data = Data[Data['Number of reconstructions'] > 0].reset_index(drop=True)
print('We have %s folders with reconstructions' % (len(Data)))

In [None]:
# Get parameters to doublecheck from logfiles
# Each column of the dataframe is filled by running one of the parsing functions over all the log files
logfile_parsers = {'Voxelsize': pixelsize,
                   'Filter': whichfilter,
                   'Exposuretime': exposuretime,
                   'Scanner': scanner,
                   'Averaging': averaging,
                   'ProjectionSize': projection_size,
                   'RotationStep': rotationstep,
                   'Grayvalue': reconstruction_grayvalue,
                   'RingartefactCorrection': ringremoval,
                   'BeamHardeningCorrection': beamhardening,
                   'DefectPixelMasking': defectpixelmasking,
                   'Scan date': scandate}
for column, parser in logfile_parsers.items():
    Data[column] = [parser(log) for log in Data['LogFile']]

In [None]:
# Sort the dataframe based on the scan date and renumber the rows while we are at it
Data = Data.sort_values(by='Scan date', ignore_index=True)

In [None]:
# Does the dataframe look plausible?
Data[['Sample', 'Scan', 'Scan date']]

In [None]:
# # Load all reconstructions DASK arrays
# Reconstructions = [dask_image.imread.imread(os.path.join(folder,'*rec*.png')) for folder in Data['Folder']]
# Load all reconstructions into ephemeral DASK arrays, with a nice progress bar...
# We only keep the first channel of each stack, which gets rid of the color channel
# NOTE(review): '*rec*.png' is broader than the '*rec0*.png' pattern we use for counting the reconstructions,
# confirm that no other PNG files in the folders match it
Reconstructions = [dask_image.imread.imread(os.path.join(folder, '*rec*.png'))[:, :, :, 0]
                   for folder in tqdm(Data['Folder'],
                                      desc='Loading reconstructions',
                                      total=len(Data))]

In [None]:
# Extract the bucket name
# It is whatever follows the last underscore in the third-to-last part of the log file path,
# i.e. in the name of the parent of the folder the log file is in
Data['Bucket'] = Data['LogFile'].map(lambda log: log.split(os.sep)[-3].rsplit('_', 1)[-1])

In [None]:
# The three cardinal directions
directions = ['Axial',
              'Frontal',
              'Median']

In [None]:
# Read or calculate the directional MIPs, put them into the dataframe and save them to disk
# First make us an (empty) column per direction
for direction in directions:
    Data['MIP_' + direction] = ''
for c, row in tqdm(Data.iterrows(), desc='Working on MIPs', total=len(Data)):
    # The position of a direction in the list is the axis along which we project
    per_direction = tqdm(enumerate(directions),
                         desc='%s/%s' % (row['Sample'], row['Scan']),
                         leave=False,
                         total=len(directions))
    for axis, direction in per_direction:
        mip_filename = '%s.%s.MIP.%s.png' % (row['Sample'], row['Scan'], direction)
        # The MIPs are saved next to (not in) the folder with the reconstructions
        outfilepath = os.path.join(os.path.dirname(row['Folder']), mip_filename)
        if not os.path.exists(outfilepath):
            # Generate and write out MIP, but only if we have not done so already
            mip = Reconstructions[c].max(axis=axis).compute()
            imageio.imwrite(outfilepath, mip)
        # In any case, we read the MIP back from disk
        Data.at[c, 'MIP_' + direction] = dask_image.imread.imread(outfilepath).squeeze()

In [None]:
def getLargestCC(segmentation):
    """
    Return a boolean mask of the largest connected component of `segmentation`.
    Based on https://stackoverflow.com/a/55110923
    """
    labeled = skimage.measure.label(segmentation)
    # We assume that there is at least one connected component
    assert labeled.max() != 0
    # Count the pixels per label and skip the first count, which belongs to label 0
    component_sizes = numpy.bincount(labeled.flat)[1:]
    # Since we skipped label 0, the position of the largest count is off by one
    return labeled == component_sizes.argmax() + 1

In [None]:
def vial_label_extractor(whichscan, threshold=35, part=333, verbose=True):
    """
    Extract the vial labels (the numbers) which are 'hidden' in the lower part of a scan.

    We calculate the MIP of the first `part` reconstructions (or read it from disk, if we
    already calculated it), blank its central region, threshold it and then remove both the
    largest connected component (the separation walls of the bucket) and all small objects.

    whichscan: Position of the scan in `Data` and `Reconstructions`.
    threshold: Gray value above which a pixel is considered.
               If not set (None or 0) we use the highest multi-Otsu threshold,
               which is replaced by 35 if it turns out to be larger than 50.
    part:      Number of reconstructions (from the start of the stack) to use for the MIP.
    verbose:   Show an overview figure of the result and save it to disk.

    Returns a boolean image in which only the labels should remain.
    """
    # Make sure `matplotlib.patches` is available, like we do in `detect_fish_position`
    import matplotlib.patches
    scan_folder = os.path.dirname(Data['Folder'][whichscan])
    bottom_mip_filename = os.path.join(scan_folder,
                                       '%s.%s.MIP.Bottom%04dslices.png' % (Data['Sample'][whichscan],
                                                                           Data['Scan'][whichscan],
                                                                           part))
    # Generate and write out file to speed up process
    if not os.path.exists(bottom_mip_filename):
        # Let's get out the numbers, they are 'hidden' in the lower part
        # Explicitly compute the MIP before writing it, like we do for the directional MIPs
        bottom_mip = Reconstructions[whichscan][:part].max(axis=0).compute()
        imageio.imwrite(bottom_mip_filename, bottom_mip)
    bottom_mip = dask_image.imread.imread(bottom_mip_filename).squeeze().compute()
    # Clean central part
    region_radius = 100
    center_row = bottom_mip.shape[0] // 2
    center_column = bottom_mip.shape[1] // 2
    bottom_mip[center_row - region_radius:center_row + region_radius,
               center_column - region_radius:center_column + region_radius] = 0
    if not threshold:
        # Calculate multi Otsu with three classes, use highest threshold
        threshold = skimage.filters.threshold_multiotsu(bottom_mip)[-1]
        # For at least bucket E, the threshold is borked
        # If it's larger than 50, set it to something reasonable
        if threshold > 50:
            threshold = 35
    # Remove largest component from thresholded bottom MIP
    # The largest component are the separation walls of the bucket
    thresholded = bottom_mip > threshold
    numbers = numpy.bitwise_xor(thresholded, getLargestCC(thresholded))
    # Clean up the image by removing small objects, only labels should remain
    numbers_cleaned = skimage.morphology.remove_small_objects(numbers, min_size=10000)
    if verbose:
        # Left: frontal MIP, on which we mark the region we used
        plt.subplot(121)
        plt.imshow(Data['MIP_Frontal'][whichscan],
                   vmin=0,
                   vmax=2**8)
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
        plt.gca().add_artist(matplotlib.patches.Rectangle((0, 0),
                                                          Data['MIP_Frontal'][whichscan].shape[1],
                                                          part,
                                                          edgecolor=None,
                                                          facecolor='yellow',
                                                          alpha=0.618))
        plt.title('Bucket %s' % Data['Bucket'][whichscan])
        plt.axis('off')
        # Right: axial MIP, overlaid with the extracted labels
        plt.subplot(122)
        plt.imshow(Data['MIP_Axial'][whichscan],
                   vmin=0,
                   vmax=2**8)
        plt.imshow(numpy.ma.masked_equal(numbers_cleaned, 0), cmap='viridis_r', alpha=0.618)
        plt.title('MIP of marked region\n%s recs>%s - their largest CC' % (part, threshold))
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
        plt.axis('off')
        plt.savefig(os.path.join(scan_folder,
                                 '%s.%s.Labels.Overview.png' % (Data.Sample[whichscan],
                                                                Data.Scan[whichscan])))
        plt.show()
    return numbers_cleaned

In [None]:
# vial_label_extractor(4, verbose=True, threshold=None)

In [None]:
# Test the extractor thingamajig
vial_label_extractor(1, verbose=True, part=400, threshold=None)

In [None]:
# Extract the vial labels of all the scans, without showing and saving the overview figures
# Since we pass `threshold=None`, the extractor calculates the threshold for each scan by itself
Data['VialLabels'] = [vial_label_extractor(i, verbose=False, threshold=None) for i in range(len(Data))]

In [None]:
def detect_fish_position(whichscan, threshold=None, verbose=False):
    """
    Detect the fish positions based on blobs in the top-down (axial) MIP of a scan.

    We threshold the MIP, blank its central part, remove small objects and
    then label what remains and return the properties of the found regions.

    whichscan: Position of the scan in `Data`.
    threshold: Gray value above which a pixel is considered.
               If not set (None or 0) we use the Otsu threshold of all pixels brighter than 10.
    verbose:   Show an overview figure of the detected regions and save it to disk.

    Returns the list of detected regions (as given by `skimage.measure.regionprops`).
    We print a note if we did not find exactly six of them.
    """
    import matplotlib.patches
    td_mip = Data['MIP_Axial'][whichscan].compute()
    if not threshold:
        threshold = skimage.filters.threshold_otsu(td_mip[td_mip > 10])
    td_mip_thresholded = td_mip > threshold
    center_row = td_mip_thresholded.shape[0] // 2
    center_column = td_mip_thresholded.shape[1] // 2
    # Remove central part, on some scans the connector shows up...
    region_radius = 200
    td_mip_thresholded[center_row - region_radius:center_row + region_radius,
                       center_column - region_radius:center_column + region_radius] = 0
    # Clean speckles, assuming all fish are larger than 5000 px
    cleaned = skimage.morphology.remove_small_objects(td_mip_thresholded,
                                                      min_size=5000)
    # Remove a (larger) central part once more, now from the cleaned image
    region_radius = 275
    cleaned[center_row - region_radius:center_row + region_radius,
            center_column - region_radius:center_column + region_radius] = 0
    # Label image
    label_image = skimage.measure.label(cleaned)
    # Detect regions
    regions = skimage.measure.regionprops(label_image)
    # Drop small areas, if we found more than 6 fish
    # NOTE(review): we already removed objects smaller than 5000 px above, so this can only
    # drop what remains of objects we cut with the second central blanking - confirm this is intended
    if len(regions) > 6:
        print('Found more than 6 regions')
        print('Dropping region with area < 1000 from %s found regions' % len(regions))
        regions = [item for item in regions if item.area > 1000]
    if len(regions) < 6:
        print('Found less than 6 regions')
        # regions = (regions + 6 * [numpy.nan])[:6]
        # print(regions)
    if verbose:
        # Left: axial MIP, overlaid with the vial labels
        plt.subplot(121)
        plt.imshow(td_mip)
        plt.imshow(numpy.ma.masked_equal(Data['VialLabels'][whichscan], 0), cmap='viridis_r', alpha=0.618)
        plt.title('Bucket %s' % Data['Bucket'][whichscan])
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
        plt.axis('off')
        # Right: axial MIP, overlaid with the detected regions
        plt.subplot(122)
        # to make the background transparent, pass the value of `bg_label`,
        # and leave `bg_color` as `None` and `kind` as `overlay`
        plt.imshow(skimage.color.label2rgb(label_image, image=td_mip, bg_label=0))
        for region in regions:
            # draw rectangle around segmented fish
            minr, minc, maxr, maxc = region.bbox
            rect = matplotlib.patches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                                fill=False, edgecolor='white', ls='--')
            plt.gca().add_patch(rect)
            # mark the centroid of the fish with its label
            plt.scatter(region.centroid[1], region.centroid[0], s=200, c='black', edgecolors='white')
            plt.annotate('%s' % region.label,
                         xy=(region.centroid[1], region.centroid[0]),
                         ha='center',
                         va='center',
                         color='white')
        plt.title('%s detected fish' % len(regions))
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
        plt.axis('off')
        plt.savefig(os.path.join(os.path.dirname(Data['Folder'][whichscan]),
                                 '%s.%s.Labels.Detected.png' % (Data.Sample[whichscan],
                                                                Data.Scan[whichscan])))
        plt.show()
    return regions

In [None]:
# Test position detector
test = detect_fish_position(0, verbose=True)

In [None]:
def reorder_list(inputlist, neworder=(0, 1, 3, 5, 4, 2), verbose=True):
    """
    Shuffle fish positions to a new order.
    The default neworder is the "usual" order we find the fish in.
    We *deliberately* want a new list, so we can keep the old one around for double-checks
    https://stackoverflow.com/questions/2177590/how-can-i-reorder-a-list#comment106984501_2177607
    Since double-checking is cumbersome, we print out both lists.

    inputlist: The list (of detected regions) to reorder.
    neworder:  For each position of the new list, the position of the item in `inputlist`.
    verbose:   Print the labels of the items in the original and in the reordered list.

    Returns the reordered list.
    If `inputlist` and `neworder` do not have the same length, we return `inputlist` itself.
    Raises a ValueError if a position is given more than once in `neworder`.
    """
    # Catch a different number of items than we have positions for (e.g. less than 6 fish in bucket)
    if len(inputlist) != len(neworder):
        print('We found %s items but "neworder" has %s positions, '
              'so we simply return the original, unsorted list.' % (len(inputlist), len(neworder)))
        print('Call the "reorder_list" command with (for example) "neworder=%s".'
              % random.sample(range(len(inputlist)), len(inputlist)))
        return inputlist
    # A repeated position would silently duplicate one item and lose another one
    if len(set(neworder)) != len(neworder):
        raise ValueError('Every position may only be given once in "neworder", got %s' % list(neworder))
    ordered_list = [inputlist[i] for i in neworder]
    if verbose:
        print('Clockwise Original : %s' % [element.label for element in inputlist])
        print('Clockwise Reordered: %s' % [element.label for element in ordered_list])
    return ordered_list

In [None]:
# First pass: detect the fish positions of every scan (without any plots)...
Data['Regions'] = [detect_fish_position(whichscan, verbose=False) for whichscan in range(len(Data))]
# ...and sort them in the default order
Data['Regions_Ordered'] = Data['Regions'].map(lambda regions: reorder_list(regions, verbose=False))
# Below we look at each of the scans again and reorder if necessary

In [None]:
# # Some fish regions are not detected in a consistent order, reorder them correctly now.
# # The `verbose` output makes it easy to double-check the correct labels.
# # Bucket C: label 4 and 5 are swapped
# whichbucket = 2
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 2, 3, 5, 4, 1],
#                                                     verbose=True)

In [None]:
# # Bucket D: label 4 and 5 are swapped
# whichbucket = 3
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 2, 3, 5, 4, 1],
#                                                     verbose=True)

In [None]:
# # Bucket E: label 4 and 5 are swapped
# whichbucket = 4
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 4, 5, 3, 2],
#                                                     verbose=True)

In [None]:
# # Bucket F: we found 7 things to label :)
# whichbucket = 5
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 3, 5, 4, 2],
#                                                     verbose=True)

In [None]:
# # Bucket G
# whichbucket = 6
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 4, 5, 3, 2],
#                                                     verbose=True)

In [None]:
# # Bucket H
# whichbucket = 7
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 3, 5, 4, 2],
#                                                     verbose=True)

In [None]:
# # Bucket I
# whichbucket = 8
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 4, 5, 3, 2],
#                                                     verbose=True)

In [None]:
# # Bucket J
# whichbucket = 9
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 2, 3, 5, 4, 1],
#                                                     verbose=True)

In [None]:
# # Bucket K
# whichbucket = 10
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 3, 5, 4, 2],
#                                                     verbose=True)

In [None]:
# # Bucket M
# whichbucket = 12
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 4, 5, 3, 2],
#                                                     verbose=True)

In [None]:
# # Bucket N
# whichbucket = 13
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 2, 3, 5, 4, 1],
#                                                     verbose=True)

In [None]:
# # Bucket O
# whichbucket = 14
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 4, 5, 3, 2],
#                                                     verbose=True)

In [None]:
# # Bucket Q
# whichbucket = 16
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 2, 3, 5, 4, 1],
#                                                     verbose=True)

In [None]:
# # Bucket R
# whichbucket = 17
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 3, 5, 4, 2],
#                                                     verbose=True)

In [None]:
# # Bucket S
# whichbucket = 18
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 3, 5, 4, 2],
#                                                     verbose=True)

In [None]:
# # Bucket T
# whichbucket = 19
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 4, 5, 3, 2],
#                                                     verbose=True)

In [None]:
# # Bucket U
# whichbucket = 20
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     neworder=[0, 1, 2, 3],  # Only 4 fish!
#                                                     verbose=True)

In [None]:
# Reorder regions from the second (Sticklebucket) batch
whichbucket = 2
print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# Show us the detected regions, so we can double-check the new order
_ = detect_fish_position(whichbucket, verbose=True)
# We write the reordered list with `.at`, since the chained assignment
# `Data['Regions_Ordered'][whichbucket] = ...` is not guaranteed to end up in the dataframe
Data.at[whichbucket, 'Regions_Ordered'] = reorder_list(Data['Regions'][whichbucket],
                                                       neworder=[0, 1, 2, 4, 3],
                                                       verbose=True)

In [None]:
# Reorder regions from the second (Sticklebucket) batch
whichbucket = 3
print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# Show us the detected regions, so we can double-check the new order
_ = detect_fish_position(whichbucket, verbose=True)
# We write the reordered list with `.at`, since the chained assignment
# `Data['Regions_Ordered'][whichbucket] = ...` is not guaranteed to end up in the dataframe
Data.at[whichbucket, 'Regions_Ordered'] = reorder_list(Data['Regions'][whichbucket],
                                                       neworder=[0, 1, 4, 5, 3, 2],
                                                       verbose=True)

In [None]:
# Reorder regions from the second (Sticklebucket) batch
whichbucket = 4
print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# Show us the detected regions, so we can double-check the new order
_ = detect_fish_position(whichbucket, verbose=True)
# We write the reordered list with `.at`, since the chained assignment
# `Data['Regions_Ordered'][whichbucket] = ...` is not guaranteed to end up in the dataframe
Data.at[whichbucket, 'Regions_Ordered'] = reorder_list(Data['Regions'][whichbucket],
                                                       neworder=[0, 1, 4, 5, 3, 2],
                                                       verbose=True)

In [None]:
# Reorder regions from the second (Sticklebucket) batch
whichbucket = 5
print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# Show us the detected regions, so we can double-check the new order
_ = detect_fish_position(whichbucket, verbose=True)
# We write the reordered list with `.at`, since the chained assignment
# `Data['Regions_Ordered'][whichbucket] = ...` is not guaranteed to end up in the dataframe
Data.at[whichbucket, 'Regions_Ordered'] = reorder_list(Data['Regions'][whichbucket],
                                                       neworder=[0, 2, 3, 4, 1],
                                                       verbose=True)

In [None]:
# Reorder regions from the second (Sticklebucket) batch
whichbucket = 7
print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# Show us the detected regions, so we can double-check the new order
_ = detect_fish_position(whichbucket, verbose=True)
# We write the reordered list with `.at`, since the chained assignment
# `Data['Regions_Ordered'][whichbucket] = ...` is not guaranteed to end up in the dataframe
Data.at[whichbucket, 'Regions_Ordered'] = reorder_list(Data['Regions'][whichbucket],
                                                       neworder=[0, 2, 3, 5, 4, 1],
                                                       verbose=True)

In [None]:
# # Reorder regions from the second (Sticklebucket) batch
# whichbucket = 8
# print('Looking at Bucket %s' % Data['Bucket'][whichbucket])
# _ = detect_fish_position(whichbucket, verbose=True)
# Data['Regions_Ordered'][whichbucket] = reorder_list(Data['Regions'][whichbucket],
#                                                     # neworder=[0, 2, 3, 5, 4, 1],
#                                                     verbose=True)

In [None]:
# REPEAT THE ABOVE FOR EACH NEWLY ACQUIRED BUCKET
# WITH VERBOSE OUTPUT

In [None]:
# # Construct us a *consecutive* fish number, based on 6 fish per scan
# # This 'FishNumber' is *not* the same as the 'Fish Number' in Bens tracking sheet
# # We use our FishNumber to construct a FishID, which matches the 'Fish Number'
# Data['FishNumber'] = [[reg.label + (6 * c - 1) for reg in region] for c, region in enumerate(Data.Regions)]

In [None]:
# # Overwrite blunders in first two batches with their correct numbers
# # Bucket A
# Data['FishNumber'][0] = [fn + 1 for fn in range(len(Data['FishNumber'][0]))]
# Data['FishNumber'][0][2] = 30
# # Bucket B
# Data['FishNumber'][1][0] = 3

In [None]:
# # Read in tracking sheet from Ben
# trackingsheet = pandas.read_excel('/home/habi/research_storage_ben/microCT_Stickleback/CT_Sticklebacks_Tracking_Sheet.xlsx', index_col=None)

In [None]:
# trackingsheet.head()

In [None]:
# trackingsheet.tail()

In [None]:
# trackingsheet['Unique_ID']

In [None]:
# # Construct the fish IDs
# # This ID corresponds to the 'Unique_ID' in the tracking sheet read in above
# # Set the ID for all buckets, technically they are only valid for buckets A-D
# Data['FishID'] = [['FG.X23.%03d' % n for n in number] for number in Data.FishNumber]

In [None]:
# # Update fish ID for later scans
# # In bucket E we start with the SL fish
# Data['FishID'][5] = ['SL.X23.%03d' % int(c + 1) for c, id in enumerate(Data['FishID'][5])]
# # The first fish of Bucket E is still an FG one
# Data['FishID'][5][-1] = 'FG.X23.031'

In [None]:
# # Buckets F-J are SL fish
# Data['FishID'][6] = ['SL.X23.%03d' % int(c + 1 + 5) for c, id in enumerate(Data['FishID'][6])]
# Data['FishID'][7] = ['SL.X23.%03d' % int(c + 1 + 5 + 6) for c, id in enumerate(Data['FishID'][7])]
# Data['FishID'][8] = ['SL.X23.%03d' % int(c + 1 + 5 + 6 + 6) for c, id in enumerate(Data['FishID'][8])]
# Data['FishID'][9] = ['SL.X23.%03d' % int(c + 1 + 5 + 6 + 6 + 6) for c, id in enumerate(Data['FishID'][9])]

In [None]:
# # In bucket J we start with the SR fish
# Data['FishID'][9][3] = 'SR.X23.001'
# Data['FishID'][9][4] = 'SR.X23.002'
# Data['FishID'][9][5] = 'SR.X23.003'
# # Bucket K-P are SR fish
# Data['FishID'][10] = ['SR.X23.%03d' % int(c + 1 + 3) for c, id in enumerate(Data['FishID'][10])]
# Data['FishID'][11] = ['SR.X23.%03d' % int(c + 1 + 3 + 6) for c, id in enumerate(Data['FishID'][11])]
# Data['FishID'][12] = ['SR.X23.%03d' % int(c + 1 + 3 + 6 + 6) for c, id in enumerate(Data['FishID'][12])]
# Data['FishID'][13] = ['SR.X23.%03d' % int(c + 1 + 3 + 6 + 6 + 6) for c, id in enumerate(Data['FishID'][13])]
# Data['FishID'][14] = ['SR.X23.%03d' % int(c + 1 + 3 + 6 + 6 + 6 + 6) for c, id in enumerate(Data['FishID'][14])]
# Data['FishID'][15] = ['SR.X23.%03d' % int(c + 1 + 3 + 6 + 6 + 6 + 6 + 6) for c, id in enumerate(Data['FishID'][15])]

In [None]:
# # In bucket P we start with the WT fish
# Data['FishID'][15][2] = 'WT.X23.001'
# Data['FishID'][15][3] = 'WT.X23.002'
# Data['FishID'][15][4] = 'WT.X23.003'
# Data['FishID'][15][5] = 'WT.X23.004'
# # Bucket Q- are WR fish
# Data['FishID'][16] = ['WR.X23.%03d' % int(c + 1 + 4) for c, id in enumerate(Data['FishID'][16])]
# Data['FishID'][17] = ['WR.X23.%03d' % int(c + 1 + 4 + 6) for c, id in enumerate(Data['FishID'][17])]
# Data['FishID'][18] = ['WR.X23.%03d' % int(c + 1 + 4 + 6 + 6) for c, id in enumerate(Data['FishID'][18])]
# Data['FishID'][19] = ['WR.X23.%03d' % int(c + 1 + 4 + 6 + 6 + 6) for c, id in enumerate(Data['FishID'][19])]
# Data['FishID'][20] = ['WR.X23.%03d' % int(c + 1 + 4 + 6 + 6 + 6 + 6) for c, id in enumerate(Data['FishID'][20])]

In [None]:
def mapping(whichscan, verbose=True):
    """
    Return the list of fish IDs for one scan, read from a mapping file on disk.

    For the first batch we constructed a consecutive FishID.
    For the second (Sticklebucket) batch, a markdown table next to the scan
    folder (matching '*Mapping*.md') maps the fish positions to the labels.
    The table has the columns 'Fish', 'Petal' (sample holder compartment)
    and 'Region' (detected fish position).

    whichscan: row of `Data` whose folder is searched for the mapping file
    verbose: print dropped rows and the final table

    Returns a list with the stripped entries of the 'Fish' column,
    or None (after printing a hint) if no mapping file is found.
    """
    scanfolder = os.path.dirname(Data['Folder'][whichscan])
    candidates = glob.glob(os.path.join(scanfolder, '*Mapping*.md'))
    if not candidates:
        print('You will need to provide a mapping file for this scan')
        return
    # A markdown table parses nicely as a '|'-separated file.
    # This leaves spaces in the column names, which we strip right away.
    table = pandas.read_csv(candidates[0], sep='|')
    table = table.replace({' None ': None})
    table.columns = table.columns.str.replace(' ', '')
    # Row 0 is the markdown header separator, not data
    table = table.drop(0)
    # The leading/trailing '|' produce empty 'Unnamed' columns: https://stackoverflow.com/a/52696683/323100
    table = table.dropna(how='all', axis='columns')
    # Collect the rows whose 'Petal' entry cannot be read as an integer and drop them in one go.
    # NOTE(review): rows where 'Petal' is None are *kept* (only a ValueError leads to a drop),
    # although the intent seems to be to remove fish that were not scanned - please confirm.
    unusable = []
    for label, entry in table.iterrows():
        petal = entry['Petal']
        if petal is None:
            continue
        try:
            int(petal)
        except ValueError:
            if verbose:
                print('Dropping "%s" for %s' % (petal, entry['Fish']))
            unusable.append(label)
    table = table.drop(unusable).reset_index(drop=True)
    if verbose:
        print(table)
    return [fish.strip() for fish in table.Fish]

In [None]:
m = mapping(5, verbose=True)

In [None]:
Data['FishID'] = [mapping(i) for i in range(len(Data))]

In [None]:
Data['FishID'][5]

In [None]:
Data[['Bucket', 'FishID']]

In [None]:
# Load overview and labbook image, if present
def _photo_or_placeholder(scanfolder, photoname):
    """
    Lazily load `photoname` from the parent directory of `scanfolder`.
    If there is no such photo, return a 64 x 64 random-noise image as placeholder.
    """
    photo = os.path.join(os.path.dirname(scanfolder), photoname)
    if os.path.exists(photo):
        return dask_image.imread.imread(photo).squeeze()
    return numpy.random.random((2**6, 2**6))


Data['LabbookImage'] = [_photo_or_placeholder(f, '_labbook.jpg') for f in Data['Folder']]
Data['OverviewImage'] = [_photo_or_placeholder(f, '_overview.jpg') for f in Data['Folder']]

In [None]:
def doublecheck_fish_position(whichscan, buffer=50):
    """
    Plot a 2x2 figure to visually double-check the fish labelling of one scan.

    Top left: axial MIP with the vial labels and the running number of each
    entry of Data['Regions'].
    Top right: same background, annotated with 'number:FishID' for each entry
    of Data['Regions_Ordered'], plus a dashed rectangle showing the bounding
    box of the region enlarged by `buffer` on every side.
    Bottom row: the labbook photo and the overview photo of the tubes.

    The figure is saved as '<Sample>.<Scan>.Labels.Check.png' in the parent
    directory of the scan folder and then shown.

    whichscan: row of `Data` to display
    buffer: margin (in array indices of the MIP) drawn around each bounding box.
        Defaults to 50, the value used everywhere else in this notebook
        (previously this was read from a global variable of the same name).
    """
    plt.subplot(221)
    plt.imshow(Data['MIP_Axial'][whichscan])
    plt.imshow(numpy.ma.masked_equal(Data['VialLabels'][whichscan], False), cmap='viridis_r')
    for c, region in enumerate(Data['Regions'][whichscan]):
        # The label is shifted to the right of the centroid, so it does not cover the fish
        plt.annotate('%s' % str(c + 1),
                     xy=(region.centroid[1] + 222, region.centroid[0]),
                     color='black',
                     va='center',
                     bbox=dict(fc="white", alpha=0.618))
    plt.title('MIP & Calculated labels')
    plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
    plt.axis('off')
    plt.subplot(222)
    plt.imshow(Data['MIP_Axial'][whichscan])
    plt.imshow(numpy.ma.masked_equal(Data['VialLabels'][whichscan], False), cmap='viridis_r')
    for c, region in enumerate(Data['Regions_Ordered'][whichscan]):
        plt.annotate('%s:%s' % (c + 1, Data['FishID'][whichscan][c]),
                     xy=(region.centroid[1], region.centroid[0]),
                     color='black',
                     fontsize=8,
                     va='center',
                     ha='center',
                     bbox=dict(fc="white", alpha=0.618))
        # Draw rectangle around segmented fish
        # Bounding box is (min_row, min_col, max_row, max_col)
        min_row, min_col, max_row, max_col = region.bbox
        left, right = min_col - buffer, max_col + buffer
        top, bottom = min_row - buffer, max_row + buffer
        # Walk around the four corners and return to the first one, so the outline is closed
        bx = (left, right, right, left, left)
        by = (top, top, bottom, bottom, top)
        plt.plot(bx, by, '--')
    plt.title('Resorted label:mapped ID')
    plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
    plt.axis('off')
    plt.subplot(223)
    plt.imshow(Data['LabbookImage'][whichscan])
    plt.axis('off')
    plt.title('Photo of labbook')
    plt.subplot(224)
    plt.imshow(Data['OverviewImage'][whichscan])
    plt.axis('off')
    plt.title('Photo of tubes')
    plt.suptitle('Bucket %s' % Data['Bucket'][whichscan])
    plt.savefig('%s.%s.Labels.Check.png' % (os.path.join(os.path.dirname(Data['Folder'][whichscan]), Data.Sample[whichscan]), Data.Scan[whichscan]))
    plt.show()
    return

In [None]:
mapping(5, verbose=True)

In [None]:
# Test checking function
buffer = 50
doublecheck_fish_position(5)

In [None]:
buffer = 50
for i in range(len(Data)):
    doublecheck_fish_position(i)

In [None]:
def regionextractor(whichscan, buffer=50, verbose=True):
    """
    Crop every fish of one scan out of the reconstruction stack and save it as .zarr.

    For each entry of Data['Regions_Ordered'][whichscan], the bounding box is
    enlarged by `buffer` on every side and cut out of Reconstructions[whichscan]
    (all slices). The result is written to
    '<Folder>_regions/region_<number>_<FishID>.zarr', unless that already exists.

    whichscan: row of `Data` to process
    buffer: margin (in array indices) added around each bounding box
    verbose: print progress and plot/save a figure comparing the cut-out of the
        axial MIP (top row) with the recalculated MIP of the written file
        (bottom row). The figure is saved as '<Sample>.<Scan>.Regions.Check.png'.

    NOTE(review): the log-file cells further down record 'bbox - buffer'
    themselves; keep them in sync with the clamping done here.
    """
    regions = Data['Regions_Ordered'][whichscan]
    regionfolder = Data.Folder[whichscan] + '_regions'
    os.makedirs(regionfolder, exist_ok=True)
    # Six columns as before; grow the grid if a scan ever contains more regions
    columns = max(6, len(regions))
    for c, region in tqdm(enumerate(regions),
                          total=len(regions),
                          desc='Extracting and visualizing regions'):
        outputfilename = os.path.join(regionfolder, 'region_%s_%s.zarr' % (str(c + 1),
                                                                           Data['FishID'][whichscan][c]))
        # Bounding box is (min_row, min_col, max_row, max_col).
        # A negative start index would wrap around to the far end of the axis
        # and give a wrong (usually empty) crop, so we clamp it to 0.
        # A stop index beyond the image is clipped by the slicing itself.
        row_start = max(region.bbox[0] - buffer, 0)
        row_stop = region.bbox[2] + buffer
        col_start = max(region.bbox[1] - buffer, 0)
        col_stop = region.bbox[3] + buffer
        if not os.path.exists(outputfilename):
            # Crop current region out of reconstructions stack and rechunk, making for even more efficient access
            currentregion = Reconstructions[whichscan][:, row_start:row_stop, col_start:col_stop].rechunk('auto')
            if verbose:
                print('Writing to %s. This takes a while...' % outputfilename[len(Root) + 1:])
            dask.array.to_zarr(currentregion, outputfilename)
        if verbose:
            # Read written file back in, so we can profit from the rechunking
            currentregion = dask.array.from_zarr(outputfilename)
            plt.subplot(2, columns, c + 1)
            plt.imshow(Data['MIP_Axial'][whichscan][row_start:row_stop, col_start:col_stop])
            plt.imshow(numpy.ma.masked_equal(Data['VialLabels'][whichscan][row_start:row_stop, col_start:col_stop],
                                             False),
                       cmap='viridis_r')
            plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
            plt.title('Cut of MIP\n%s' % str(c + 1))
            plt.axis('off')
            plt.subplot(2, columns, c + 1 + columns)
            # Recalculate MIP for double-checking
            plt.imshow(currentregion.max(axis=0))
            plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichscan], 'um'))
            plt.title('MIP of cut\n%s' % Data['FishID'][whichscan][c])
            plt.axis('off')
    if verbose:
        # Title the figure once, and only if we actually drew one
        plt.suptitle('Bucket %s' % Data['Bucket'][whichscan])
        plt.savefig('%s.%s.Regions.Check.png' % (os.path.join(os.path.dirname(Data['Folder'][whichscan] + '_regions'),
                                                              Data.Sample[whichscan]),
                                                 Data.Scan[whichscan]))
        plt.show()
    return

In [None]:
buffer = 50
regionextractor(0, buffer=buffer, verbose=True)

In [None]:
# trackingsheet.iloc[0]

In [None]:
# Explicitly state the buffer, we want it later for adding the crop region to the regional log files
buffer = 50
for i in range(len(Data)):
    regionextractor(i, buffer=buffer, verbose=True)

In [None]:
# # Double-check our labels with the tracking sheet
# for c, row in Data.iterrows():
#     print(20 * '-', row.Bucket, 20 * '-')
#     print('FishNumber |', 'Bens ID |', 'Our ID')
#     for d, fishnumber in enumerate(row.FishNumber):
#         print(fishnumber,
#               trackingsheet.iloc[fishnumber - 1]['Unique_ID'],
#               row.FishID[d])

In [None]:
# Read in regions again
# The file names are sorted, otherwise the order of the regions is arbitrary and we would mix them up.
# Caveat: this is a plain string sort of 'region_<number>_<FishID>.zarr',
# so 'region_10_...' comes before 'region_2_...' once there are more than nine regions.
RegionZarrFiles = []
Regions = []
for regionfolder in Data['Folder']:
    zarrfiles = sorted(glob.glob(os.path.join(regionfolder + '_regions', '*.zarr')))
    RegionZarrFiles.append(zarrfiles)
    Regions.append([dask.array.from_zarr(zarrfile) for zarrfile in zarrfiles])

In [None]:
# We want to generate log files for the cutout regions
# Aeons ago, we wrote a little wrapper function to log stuff at TOMCAT
# https://github.com/habi/TOMCAT/blob/master/postscan/StackedScanOverlapFinder.py#L104
# The function below is slightly tweaked from there
def myLogger(logfilename, verbose=False):
    """
    Return a logger that writes INFO (and more severe) messages to `logfilename`.

    The file is opened in 'w' mode, so it is truncated on every call.
    The `logging` module caches loggers by name, hence asking for the same
    file twice returns the same logger object. Handlers left over from an
    earlier call are removed and closed before the new one is attached;
    otherwise every message would be written once per call and the old file
    handles would stay open.

    logfilename: path of the log file, also used as the name of the logger
    verbose: print the name of the log file
    """
    import logging
    logger = logging.getLogger(logfilename)
    logger.setLevel(logging.INFO)
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    handler = logging.FileHandler(logfilename, 'w')
    logger.addHandler(handler)
    if verbose:
        print('Logging to %s' % logfilename)
    return logger
# Then write to the file with
# logfile = myLogger(Filename)
# logfile.info('Put this into the log file')

In [None]:
# Save out a log file
for c, row in tqdm(Data.iterrows(), total=len(Data), desc='Writing log files for regions'):
    # regionextractor() crops entry d of 'Regions_Ordered' into the file named after FishID[d].
    # We thus have to iterate the *ordered* regions here, otherwise centroid
    # and bounding box of a different fish could end up in the log.
    for d, region in tqdm(enumerate(row['Regions_Ordered']),
                          total=len(row['Regions_Ordered']),
                          desc=Data.Folder[c][len(Root) + 1:],
                          leave=False):
        # Generate output directory
        outputdir = os.path.join(row.Folder + '_regions', row['FishID'][d])
        os.makedirs(outputdir, exist_ok=True)
        # Generate logfile name
        logfilename = os.path.join(outputdir, row['FishID'][d] + '.log')
        # Delete logfile, if it already exists
        if os.path.exists(logfilename):
            os.remove(logfilename)
        logfile = myLogger(logfilename)
        logfile.info('Scan = %s' % os.path.join(row.Sample, row.Scan))
        logfile.info('Voxel size = %s um' % row.Voxelsize)
        logfile.info('ID = %s' % row['FishID'][d])
        logfile.info('Vial = %s' % str(d + 1))
        logfile.info('Centroid (x,y) in the original stack = %s, %s' % (int(round(region.centroid[1])), int(round(region.centroid[0]))))
        # A box closer than `buffer` to the image border would start at a negative
        # position, which does not exist in the stack; report 0 in that case.
        logfile.info('Bounding box (x1:x2, y1:y2) of this region in the original stack = %s:%s, %s:%s' % (max(region.bbox[1] - buffer, 0), region.bbox[3] + buffer,
                                                                                                          max(region.bbox[0] - buffer, 0), region.bbox[2] + buffer))
        # Flush and release the file, we would otherwise keep one open file handle per fish
        for handler in list(logfile.handlers):
            handler.close()
            logfile.removeHandler(handler)

In [None]:
# Save out one .nrrd file per extracted region
for c, row in tqdm(Data.iterrows(),
                   total=len(Data),
                   desc='Saving out .nrrd file for each region of each bucket'):
    for d, zarrfile in tqdm(enumerate(Regions[c]),
                            total=len(Regions[c]),
                            desc=Data.Folder[c][len(Root) + 1:],
                            leave=False):
        # The file is named after the fish and lives in the '_regions' folder of the scan
        outputname = os.path.join(row.Folder + '_regions', row['FishID'][d] + '.nrrd')
        if os.path.exists(outputname):
            # Already exported in an earlier run
            continue
        # Header that goes into the file; the voxel size is divided by 1000 to go with the 'mm' units
        # https://pynrrd.readthedocs.io/en/stable/examples.html#example-with-fields-and-custom-fields
        header = {'encoding': 'raw',
                  'units': ['mm', 'mm', 'mm'],
                  'spacings': [row.Voxelsize / 1000, row.Voxelsize / 1000, row.Voxelsize / 1000]}
        # Write out file with https://github.com/mhe/pynrrd/
        # With c-index order we get the data out with the stack in z-direction
        nrrd.write(outputname, zarrfile.compute(), header, index_order='C')

In [None]:
def imsaver(image, filename):
    '''
    Write `image` as 8-bit image to `filename`, unless that file already exists.
    Kept as a small stand-alone function so that it can be handed to joblib for parallel writing.
    '''
    if os.path.exists(filename):
        # There is already an image on disk, leave it alone
        return
    imageio.imwrite(filename, image.astype('uint8'))

In [None]:
# Save out PNG slices
for c, row in tqdm(Data.iterrows(),
                   total=len(Data),
                   desc='Saving out PNGs for each region of each bucket'):
    for d, zarrfile in tqdm(enumerate(Regions[c]),
                            total=len(Regions[c]),
                            desc=Data.Folder[c][len(Root) + 1:],
                            leave=False):
        # Every fish gets its own directory inside the '_regions' folder
        fishid = row['FishID'][d]
        outputdir = os.path.join(row.Folder + '_regions', fishid)
        os.makedirs(outputdir, exist_ok=True)
        # Reuse the names of the reconstruction files, with the sample name swapped for the fish ID
        outputfilenames = []
        for fn in Data['Filenames Reconstructions'][c]:
            pngname = os.path.basename(fn).replace(Data.Sample[c], fishid)
            outputfilenames.append(os.path.join(outputdir, pngname))
        parallelize = True
        if parallelize:
            # Hat tip to Oleksiy for providing a snippet to parallelize the PNG writing
            # It is paramount that the filenames are sorted though!
            Parallel(n_jobs=-1)(delayed(imsaver)(zarrfile[slicenumber], pngname)
                                for slicenumber, pngname in enumerate(outputfilenames))
        else:
            # Serial fallback; imsaver skips slices which are already on disk
            for slicenumber, pngname in tqdm(enumerate(outputfilenames),
                                             total=len(outputfilenames),
                                             desc='%s' % os.path.splitext(RegionZarrFiles[c][d])[0][len(Root) + 1:],
                                             leave=False):
                imsaver(zarrfile[slicenumber], pngname)

In [None]:
def thresholder(stack, discard=5, verbose=False):
    '''
    Threshold function to reliably threshold *only* bones.
    Runs a 4-class multi-Otsu on all gray values above `discard`
    and returns the middle one of the three resulting thresholds.

    stack: dask array with the gray values of one region
    discard: gray values up to and including this value are ignored for the threshold calculation
    verbose: plot the (logarithmic) histogram of the whole stack with all thresholds marked
    '''
    relevantvalues = stack[stack > discard].compute()
    thresholds = skimage.filters.threshold_multiotsu(relevantvalues, classes=4)
    middle = thresholds[1]
    if not verbose:
        return middle
    # Histogram of the *full* stack, including the discarded values
    histogram, _ = dask.array.histogram(stack,
                                        bins=2**8,
                                        range=[0, 2**8])
    plt.semilogy(histogram)
    plt.axvline(discard,
                label='completely discarded, below %s' % discard,
                color='red')
    for t in thresholds:
        plt.axvline(t, label='threshold %s' % t)
    plt.xlim([0, 2**8])
    plt.legend()
    seaborn.despine()
    plt.show()
    return middle

In [None]:
whichbucket = 4
whichID = 3

In [None]:
threshold = thresholder(Regions[whichbucket][whichID], verbose=True)
print(threshold)

In [None]:
# Show one slice of the selected region: original, thresholded, and both overlaid
# The slice number is deliberately not called `slice`: as a notebook global
# that would shadow the builtin of the same name for all cells that follow.
slicenumber = 1500
currentslice = Regions[whichbucket][whichID][slicenumber]
plt.subplot(131)
plt.imshow(currentslice)
plt.title('%s' % Data.FishID[whichbucket][whichID])
plt.axis('off')
plt.subplot(132)
plt.imshow((currentslice > threshold))
plt.title('%s > %s' % (Data.FishID[whichbucket][whichID], threshold))
plt.axis('off')
plt.subplot(133)
plt.imshow(currentslice)
# Mask everything below the threshold, so only the thresholded part is drawn on top
plt.imshow(dask.array.ma.masked_equal((currentslice > threshold), 0).compute(),
           cmap='viridis_r',
           alpha=0.618)
plt.title('Overlay')
plt.axis('off')
plt.show()

In [None]:
# Calculate threshold for each separated region
Data['RegionThreshold'] = [[thresholder(rg) for rg in regions] for regions in Regions]

In [None]:
Data['RegionThreshold']

In [None]:
# for c, row in Data.iterrows():
#     print(row.Bucket)
#     for d, region in enumerate(Regions[c]):
#         plt.imshow(region[len(region) // 5] > row['RegionThreshold'][d] * 2)
#         plt.title('Slice %s of %s > %s' % (len(region) // 5, row['FishID'][d], row['RegionThreshold'][d]))
#         plt.gca().add_artist(ScaleBar(Data['Voxelsize'][c], 'um'))
#         plt.axis('off')
#         plt.show()

In [None]:
# Write out thresholded regions as .zarr files
for c, row in tqdm(Data.iterrows(),
                   total=len(Data),
                   desc='Saving out bucket'):
    for d, region in tqdm(enumerate(Regions[c]),
                          total=len(Regions[c]),
                          desc='Saving out regions',
                          leave=False):
        fishid = row.FishID[d]
        regionthreshold = row.RegionThreshold[d]
        # Derive the output name from the region file: other folder, and the threshold appended to the fish ID
        thresholdedfolder = RegionZarrFiles[c][d].replace('rec_regions', 'rec_regions_thresholded')
        outputfilename = thresholdedfolder.replace(fishid, '%s_thresholded_%03d' % (fishid, regionthreshold))
        if os.path.exists(outputfilename):
            continue
        print('Writing %s > %s to %s.' % (fishid,
                                          regionthreshold,
                                          outputfilename[len(Root) + 1:]))
        # The comparison gives a boolean stack, which is what we save
        dask.array.to_zarr((region > regionthreshold), outputfilename)

In [None]:
# Load the thresholded regions
# One sorted list of .zarr stores per scan folder, then the same nesting with the lazily loaded dask arrays.
# The lists are sorted by file name (plain string sort), as for the unthresholded regions.
ThresholdedRegionZarrFiles = [sorted(glob.glob(os.path.join(folder + '_regions_thresholded', '*.zarr'))) for folder in Data['Folder']]
ThresholdedRegions = [[dask.array.from_zarr(f) for f in files] for files in ThresholdedRegionZarrFiles]

In [None]:
# Save out a log file for the thresholded files
for c, row in tqdm(Data.iterrows(), total=len(Data), desc='Writing log files for regions'):
    # regionextractor() crops entry d of 'Regions_Ordered' into the file named after FishID[d].
    # We thus have to iterate the *ordered* regions here, otherwise centroid
    # and bounding box of a different fish could end up in the log.
    for d, region in tqdm(enumerate(row['Regions_Ordered']),
                          total=len(row['Regions_Ordered']),
                          desc=Data.Folder[c][len(Root) + 1:],
                          leave=False):
        # Generate output directory, with the threshold appended to the fish ID
        outputdir = os.path.join(row.Folder + '_regions_thresholded',
                                 row['FishID'][d]).replace(row.FishID[d],
                                                           '%s_thresholded_%03d' % (row.FishID[d], row.RegionThreshold[d]))
        os.makedirs(outputdir, exist_ok=True)
        # Generate logfile name
        logfilename = os.path.join(outputdir, row['FishID'][d] + '.log')
        # Delete logfile, if it already exists
        if os.path.exists(logfilename):
            os.remove(logfilename)
        logfile = myLogger(logfilename)
        logfile.info('Scan = %s' % os.path.join(row.Sample, row.Scan))
        logfile.info('Voxel size = %s um' % row.Voxelsize)
        logfile.info('ID = %s' % row['FishID'][d])
        logfile.info('Vial = %s' % str(d + 1))
        logfile.info('Centroid (x,y) in the original stack = %s, %s' % (int(round(region.centroid[1])), int(round(region.centroid[0]))))
        # A box closer than `buffer` to the image border would start at a negative
        # position, which does not exist in the stack; report 0 in that case.
        logfile.info('Bounding box (x1:x2, y1:y2) of this region in the original stack = %s:%s, %s:%s' % (max(region.bbox[1] - buffer, 0), region.bbox[3] + buffer,
                                                                                                          max(region.bbox[0] - buffer, 0), region.bbox[2] + buffer))
        logfile.info('Threshold = %s' % row.RegionThreshold[d])
        # Flush and release the file, we would otherwise keep one open file handle per fish
        for handler in list(logfile.handlers):
            handler.close()
            logfile.removeHandler(handler)

In [None]:
# Save out one .nrrd file per thresholded region
for c, row in tqdm(Data.iterrows(),
                   total=len(Data),
                   desc='Saving out .nrrd file for each thresholded region of each bucket'):
    for d, zarrfile in tqdm(enumerate(ThresholdedRegions[c]),
                            total=len(ThresholdedRegions[c]),
                            desc=Data.Folder[c][len(Root) + 1:],
                            leave=False):
        fishid = row.FishID[d]
        # The output name carries the fish ID and the threshold that was applied
        nrrdname = '%s_thresholded_%03d.nrrd' % (fishid, row.RegionThreshold[d])
        outputname = os.path.join(row.Folder + '_regions_thresholded', row['FishID'][d]).replace(fishid, nrrdname)
        if os.path.exists(outputname):
            # Already exported in an earlier run
            continue
        # Header that goes into the file; the voxel size is divided by 1000 to go with the 'mm' units
        # https://pynrrd.readthedocs.io/en/stable/examples.html#example-with-fields-and-custom-fields
        header = {'encoding': 'raw',
                  'units': ['mm', 'mm', 'mm'],
                  'spacings': [row.Voxelsize / 1000, row.Voxelsize / 1000, row.Voxelsize / 1000]}
        # Write out file with https://github.com/mhe/pynrrd/
        # The stack is converted to 8 bit before writing.
        # With c-index order we get the data out with the stack in z-direction
        nrrd.write(outputname, zarrfile.compute().astype('uint8'), header, index_order='C')

In [None]:
# Save out thresholded PNG slices
for c, row in tqdm(Data.iterrows(),
                   total=len(Data),
                   desc='Saving out PNGs for each thresholded region of each bucket'):
    for d, zarrfile in tqdm(enumerate(ThresholdedRegions[c]),
                            total=len(ThresholdedRegions[c]),
                            desc=Data.Folder[c][len(Root) + 1:],
                            leave=False):
        fishid = row.FishID[d]
        regionthreshold = row.RegionThreshold[d]
        # Make output directory, named after fish ID and threshold
        outputdir = os.path.join(row.Folder + '_regions_thresholded',
                                 row['FishID'][d]).replace(fishid, '%s_thresholded_%03d' % (fishid, regionthreshold))
        os.makedirs(outputdir, exist_ok=True)
        # Reuse the names of the reconstruction files: swap the sample name
        # for the fish ID and write the threshold value into the file names
        thresholdmarker = '_thresholded_%03d_rec0' % regionthreshold
        outputfilenames = []
        for fn in Data['Filenames Reconstructions'][c]:
            pngname = os.path.basename(fn).replace(Data.Sample[c], Data['FishID'][c][d])
            outputfilenames.append(os.path.join(outputdir, pngname).replace('_rec0', thresholdmarker))
        parallelize = True
        if parallelize:
            # Hat tip to Oleksiy for providing a snippet to parallelize the PNG writing
            # It is paramount that the filenames are sorted though!
            Parallel(n_jobs=-1)(delayed(imsaver)(zarrfile[slicenumber], pngname)
                                for slicenumber, pngname in enumerate(outputfilenames))
        else:
            # Serial fallback; imsaver skips slices which are already on disk
            # NOTE(review): the progress bar is labelled with the *unthresholded* region file name - intended?
            for slicenumber, pngname in tqdm(enumerate(outputfilenames),
                                             total=len(outputfilenames),
                                             desc='%s' % os.path.splitext(RegionZarrFiles[c][d])[0][len(Root) + 1:],
                                             leave=False):
                imsaver(zarrfile[slicenumber], pngname)

In [None]:
import k3d
import math
import numpy as np
from k3d.colormaps import matplotlib_color_maps

In [None]:
subsample = 3

In [None]:
whichbucket = 3
whichfish = 0

In [None]:
print('Displaying fish %s (region %s from bucket %s) below' % (Data['FishID'][whichbucket][whichfish],
                                                               whichfish + 1,
                                                               Data['Bucket'][whichbucket]))

In [None]:
currentfish = Regions[whichbucket][whichfish][::subsample, ::subsample, ::subsample].astype(np.float16).compute()

In [None]:
# Load fish with correct bounds: https://github.com/K3D-tools/K3D-jupyter/issues/417#issuecomment-1557778798
fish = k3d.volume(currentfish,
                  bounds=[0, Data['Voxelsize'][whichbucket] * currentfish.shape[2],
                          0, Data['Voxelsize'][whichbucket] * currentfish.shape[1],
                          0, Data['Voxelsize'][whichbucket] * currentfish.shape[0]])
plot = k3d.plot()
plot += fish
plot.display()

In [None]:
currentfish_thresholded = ThresholdedRegions[whichbucket][whichfish][::subsample, ::subsample, ::subsample].astype(np.float16).compute()

In [None]:
# Load fish with correct bounds: https://github.com/K3D-tools/K3D-jupyter/issues/417#issuecomment-1557778798
thresholdedfish = k3d.volume(currentfish_thresholded,
                             bounds=[0, Data['Voxelsize'][whichbucket] * currentfish_thresholded.shape[2],
                                     0, Data['Voxelsize'][whichbucket] * currentfish_thresholded.shape[1],
                                     0, Data['Voxelsize'][whichbucket] * currentfish_thresholded.shape[0]],
                             color_map=matplotlib_color_maps.Bone
                             )
plot = k3d.plot()
plot += thresholdedfish
plot.display()

In [None]:
# Set nice view above and save camera state
# https://github.com/K3D-tools/K3D-jupyter/issues/417
plot.camera

In [None]:
# Save out HTML page to the bucket directory
outputname = os.path.join(os.path.dirname(Data['Folder'][whichbucket]),
                          '%s.3D.html' % (Data['FishID'][whichbucket][whichfish]))
if not os.path.exists(outputname):
    # Fixed camera state for the snapshot (presumably copied from `plot.camera` after setting a nice view by hand)
    plot.camera = [12233.580110967298,
                   1495.1256137363334,
                   8669.759910429906,
                   1567.5184326171875,
                   2295.02685546875,
                   11340.1328125,
                   0.018134042199251698,
                   -0.9899408456925172,
                   0.14031492630187442]
    # Render the snapshot *before* opening the file: if rendering fails we do not
    # leave an empty .html behind, which would keep us from ever saving it again.
    snapshot = plot.get_snapshot()
    with open(outputname, "w", encoding="utf-8") as f:
        f.write(snapshot)
    print('3D view saved to %s' % outputname)
else:
    print('3D view was already saved to %s, not saving it again' % outputname)

In [None]:
# NOTE(review): the line below is not valid Python. Presumably it is a deliberate
# stopper, so that 'Run all' aborts before the cells that follow - please confirm.
StopThisThingHere==

In [None]:
# Let's try to label the stack
def labeler(stack):
    """
    Placeholder for labelling a stack; not implemented yet.

    NOTE(review): `stack` is not used and `labeled_stack` is never assigned
    inside this function, so a call only works if a global of that name exists.
    """
    return labeled_stack

In [None]:
# Minimize .zarr files to only fish-extent

In [None]:
# Plot and save the MIPs of each region along all three axes, unless the overview image already exists
# NOTE(review): `RegionZarrFiles` and `Regions` are built as lists of lists (one list per scan folder)
# further up, so `RegionZarrFiles[c]` is a list here and has no `.replace`, and `region` is a list
# of arrays - presumably this cell predates that change, confirm before running it.
# NOTE(review): `directions` and `voxelsize` are not assigned in the part of the notebook visible here.
for c, region in enumerate(Regions):
    outfilename = RegionZarrFiles[c].replace('_rec.zarr', '.MIPs.png')
    if not os.path.exists(outfilename):
        for d, direction in enumerate(directions):
            plt.subplot(1, 3 , d+1)
            # Maximum intensity projection along axis d
            plt.imshow(region.max(axis=d))
            plt.title('Region %s\n%s MIP' % (c, direction))
            plt.axis('off')
            plt.gca().add_artist(ScaleBar(voxelsize, 'um'))
        plt.savefig(outfilename)
        plt.show()
    else:
        print('MIP already saved to %s' % outfilename)

In [None]:
# Calculate the histograms of one of the MIPs
# Caveat: dask.da.histogram returns histogram AND bins, making each histogram a 'nested' list of [h, b]
Histograms = [dask.array.histogram(dask.array.array(region),
                                          bins=2**8,
                                          range=[0, 2**8]) for region in Regions]
# Actually compute the data and put only h into the dataframe, so we can use it below.
# Discard the bins
Histograms = [h.compute() for h, b in Histograms]

In [None]:
Thresholds = [skimage.filters.threshold_otsu(region[:,:,:,0][region[:,:,:,0]>10].compute()) for region in Regions]

In [None]:
for c, hist in enumerate(Histograms):
    plt.semilogy(hist,
                 c=seaborn.color_palette()[c])
    plt.axvline(Thresholds[c],
                label='R%s: %s' % (c, Thresholds[c]),
                c=seaborn.color_palette()[c])
plt.legend()
plt.show()

In [None]:
for c, region in enumerate(Regions):
    outfilename = RegionZarrFiles[c].replace('_rec.zarr', '.MIPsasdfasdfa.png')
    region = region[:,:,:,0].compute()
    if not os.path.exists(outfilename):
        for d, direction in enumerate(directions):
            plt.subplot(1, 3 , d+1)
            plt.imshow((region>Thresholds[c]).max(axis=d))
            plt.title('Region %s\n%s MIP' % (c, direction))
            plt.axis('off')
            plt.gca().add_artist(ScaleBar(voxelsize, 'um'))
        # plt.savefig(outfilename)
        plt.show()
    else:
        print('MIP already saved to %s' % outfilename[len(Root) + 1:])

In [None]:
Thresholds

In [None]:
labels = skimage.morphology.label(Regions[0][:,:,:,0]>Thresholds[0])

In [None]:
import zarr

In [None]:
# Label fish and save out as .zarr
# The first channel of each region is thresholded and its connected components are labeled.
# The labeled volume is written into the '<scan folder>_labeled' directory, unless the file exists already.
labeledfolder = Data.Folder[whichscan] + '_labeled'
os.makedirs(labeledfolder, exist_ok=True)
# Progress bar total and number of subplot columns both follow the list we actually iterate over
for c, region in tqdm(enumerate(Regions), total=len(Regions)):
    plt.subplot(1, len(Regions), c + 1)
    # The labeling is always calculated, since the MIP below is shown even if the file is already on disk
    currentregion = skimage.morphology.label(region[:, :, :, 0] > Thresholds[c])
    # File names count the regions from 1, while the plot title below shows the index counted from 0
    outputfilename = os.path.join(labeledfolder, 'region_%s_rec_labeled.zarr' % str(c + 1))
    if not os.path.exists(outputfilename):
        print('writing to', outputfilename)
        zarr.save(outputfilename, currentregion)
    else:
        print(outputfilename[len(Root) + 1:], 'already exists')
    currentmip = currentregion.max(axis=0)
    plt.imshow(currentmip)
    plt.gca().add_artist(ScaleBar(voxelsize, 'um'))
    plt.title('Region %s' % c)
    plt.axis('off')
plt.show()

In [None]:
# Load the labeled volumes from disk again
# Collect all .zarr files in the '_labeled' folder of the current scan, in sorted order,
# and open each of them as a dask array.
LabelZarrFiles = glob.glob(os.path.join(Data.Folder[whichscan] + '_labeled', '*.zarr'))
LabelZarrFiles.sort()
Labels = list(map(dask.array.from_zarr, LabelZarrFiles))

In [None]:
# Maximum intensity projections of every labeled volume, one subplot per direction.
# The overview is saved next to the .zarr file and only generated if it is not there yet.
for c, region in enumerate(Labels):
    outfilename = LabelZarrFiles[c].replace('_rec_labeled.zarr', '.MIPs.labeled.png')
    print(outfilename)
    if os.path.exists(outfilename):
        # Nothing to do, the image is on disk already
        print('MIP overview image already saved to %s' % outfilename[len(Root) + 1:])
        continue
    for d, direction in enumerate(directions):
        plt.subplot(1, 3, d + 1)
        plt.imshow(region.max(axis=d))
        plt.title('Region %s\n%s MIP' % (c, direction))
        plt.axis('off')
        plt.gca().add_artist(ScaleBar(voxelsize, 'um'))
    plt.savefig(outfilename)
    plt.show()

In [None]:
# Sum of every labeled volume along each direction, one subplot per direction.
# The image is saved next to the .zarr file and only generated if it is not there yet.
for c, region in enumerate(Labels):
    outfilename = LabelZarrFiles[c].replace('_rec_labeled.zarr', '.Summed.labeled.png')
    if os.path.exists(outfilename):
        # Nothing to do, the image is on disk already
        print('Summed image already saved to %s' % outfilename[len(Root) + 1:])
        continue
    for d, direction in enumerate(directions):
        plt.subplot(1, 3, d + 1)
        plt.imshow(region.sum(axis=d))
        plt.title('Region %s\n%s Sum' % (c, direction))
        plt.axis('off')
        plt.gca().add_artist(ScaleBar(voxelsize, 'um'))
    plt.savefig(outfilename)
    plt.show()

In [None]:
# Show one slice of every region, overlaid with the labeled connected components of that slice.
# The components are found in the first channel of the slice, thresholded with the threshold of the region.
# The index is deliberately not called `slice`, which would shadow the Python builtin for all later cells.
slicenumber = 333
for c, r in enumerate(Regions):
    plt.subplot(2, 3, c + 1)
    plt.imshow(r[slicenumber])
    plt.imshow(skimage.morphology.label(r[:, :, :, 0][slicenumber] > Thresholds[c]), alpha=0.5, cmap='viridis')
    plt.title('R%s' % c)
    plt.axis('off')
plt.show()

In [None]:
# Display the first of the labeled volumes that were read back from disk
Labels[0]

In [None]:
# Save out PNG slices for later use
# Every labeled volume gets a folder named like its .zarr file (without the extension).
# Each slice is written there as one PNG, named after the corresponding reconstruction of the scan.
reconstructionnames = Data['Filenames Reconstructions'][whichscan]
for c, zarrfile in tqdm(enumerate(Labels),
                        total=len(Labels),
                        desc=Data.Folder[whichscan][len(Root) + 1:]):
    # Make output directory
    outputfolder = os.path.splitext(LabelZarrFiles[c])[0]
    os.makedirs(outputfolder, exist_ok=True)
    # The loop variable is not called `slice`, which would shadow the Python builtin
    for d, labelslice in tqdm(enumerate(zarrfile),
                              total=len(zarrfile),
                              desc='Saving to %s' % outputfolder[len(Root) + 1:],
                              leave=False):
        # Insert the region number into the file name only,
        # so that a '_rec00' somewhere in the folder path is left alone
        pngname = os.path.basename(reconstructionnames[d]).replace('_rec00',
                                                                   '_region_%s_labeled_rec00' % str(c + 1))
        outfilepath = os.path.join(outputfolder, pngname)
        if not os.path.exists(outfilepath):
            # NOTE(review): label values above 255 do not fit into uint8.
            # Confirm that no region has more than 255 labels, otherwise labels are not preserved in the PNGs.
            imageio.imwrite(outfilepath, labelslice.compute().astype('uint8'))

In [None]:
# Detect blobs in `clean` with skimage's blob_dog, using its default parameters
# NOTE(review): `clean` is not defined in this part of the notebook - confirm that the defining cell has been run
blobs = skimage.feature.blob_dog(clean)

In [None]:
# Display the result of the blob detection
blobs

In [None]:
# Show `clean` (left) and `mip` (right) side by side
# NOTE(review): both images are defined outside of this part of the notebook
plt.subplot(121)
plt.imshow(clean)
plt.subplot(122)
plt.imshow(mip)