# Fiddle with the EAWAG scans
Look at the orientation and see if we can do some cropping based on landmarks.

In [None]:
import platform
import os
import glob
import pandas
import imageio
import numpy
import scipy
import k3d
import matplotlib
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
import seaborn
import dask
import dask_image.imread
import skimage
from tqdm.auto import tqdm, trange
import math
from numcodecs import Blosc
from skimage.segmentation import random_walker

In [None]:
# Point dask's temporary files to a suitable folder.
# This has to happen before a client is created: https://stackoverflow.com/a/62804525/323100
import tempfile
if 'Linux' in platform.system() and os.path.exists(os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')):
    # On Linux, prefer the FastSSD when it is mounted
    tmp = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD', 'tmp')
elif 'Linux' in platform.system() or 'Darwin' in platform.system():
    # Linux without FastSSD, or macOS: use the system temporary folder
    tmp = tempfile.gettempdir()
elif 'anaklin' in platform.node():
    # Any other platform (e.g. Windows): machine-specific drive letters
    tmp = os.path.join('F:\\tmp')
else:
    tmp = os.path.join('D:\\tmp')
dask.config.set({'temporary_directory': tmp})
print('Dask temporary files go to %s' % dask.config.get('temporary_directory'))

In [None]:
# Create the dask client. The temporary directory (previous cell) has to be set before this.
# NOTE(review): without arguments, Client() starts a local cluster (per the dask.distributed docs)
from dask.distributed import Client
client = Client()

In [None]:
# Show the dask client summary as cell output
client

In [None]:
# Tell the user where to find the dask dashboard.
# `client.dashboard_link` is the complete URL (host, port and path). Building the URL from
# the dashboard port alone assumed the dashboard runs on localhost.
print('You can see what DASK is doing at "%s"' % client.dashboard_link)

In [None]:
# # Ignore warnings in the notebook
# import warnings
# warnings.filterwarnings("ignore")

In [None]:
# Seaborn theme first, then our matplotlib defaults:
# gray-scale images with 'nearest' interpolation, slightly bigger figures at a higher dpi
seaborn.set_theme(context='notebook', style='dark')
plt.rcParams.update({'image.cmap': 'gray',
                     'image.interpolation': 'nearest',
                     'figure.figsize': (8, 4.5),
                     'figure.dpi': 200})

In [None]:
# Scale bar defaults: white, without a frame, in the lower right corner
plt.rcParams.update({'scalebar.location': 'lower right',
                     'scalebar.frameon': False,
                     'scalebar.color': 'white'})

In [None]:
# Display all plots identically: number of subplot rows to use in overview figures
# NOTE(review): `lines` is not referenced in the cells visible here
lines = 3
# And then do something like
# plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), c + 1)

In [None]:
# Different locations if running either on Linux, macOS ('Darwin') or Windows.
# On Linux the first enabled flag of FastSSD, overthere and nanoct decides where the data
# comes from; on Windows only FastSSD (and the machine name) is considered.
FastSSD = True  # Load the data from the FastSSD, to speed things up significantly
overthere = False  # Load the data directly from the iee-research_storage drive
nanoct = True  # Load the data directly from the 2214
if 'Linux' in platform.system():
    if FastSSD:
        BasePath = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')
    elif overthere:
        BasePath = os.path.join(os.sep, 'home', 'habi', 'research-storage-iee')
    elif nanoct:
        BasePath = os.path.join(os.path.sep, 'home', 'habi', '2214')
    else:
        BasePath = os.path.join(os.sep, 'home', 'habi', '1272')
elif 'Darwin' in platform.system():
    # On macOS the data always comes from a fixed local folder, FastSSD is switched off
    FastSSD = False
    BasePath = os.path.join('/Users/habi/Dev/EAWAG/Data')
elif 'Windows' in platform.system():
    if FastSSD:
        BasePath = os.path.join('F:\\')
    else:
        if 'anaklin' in platform.node():
            BasePath = os.path.join('S:\\')
        else:
            BasePath = os.path.join('D:\\Results')
# NOTE(review): on any other platform BasePath is never set and the next lines raise a NameError
if not overthere:
    Root = os.path.join(BasePath, 'EAWAG')
else:
    # With `overthere`, BasePath itself is the root folder of the data
    Root = BasePath
# if overthere:
#         Root = os.path.join('I:\\microCTupload')
print('We are loading all the data from %s' % Root)

In [None]:
def get_pixelsize(logfile):
    """
    Get the pixel size from the scan log file.

    The value after the '=' of the line containing 'Image Pixel' (but not
    'Scaled') is parsed as float. If several lines match, the last one wins.

    Raises ValueError if the log file contains no such line. Previously this
    surfaced as an obscure UnboundLocalError.
    """
    pixelsize = None
    with open(logfile, 'r') as f:
        for line in f:
            if 'Image Pixel' in line and 'Scaled' not in line:
                pixelsize = float(line.split('=')[1])
    if pixelsize is None:
        raise ValueError('No "Image Pixel" line found in %s' % logfile)
    return pixelsize

In [None]:
def get_git_hash():
    '''
    Return the short hash of HEAD of the git repository in the current working
    directory, as a stripped string. Whatever git writes to stdout is returned;
    if git fails this is typically an empty string.
    Based on http://stackoverflow.com/a/949391/323100 and
    http://stackoverflow.com/a/18283905/323100
    '''
    import os
    import subprocess
    command = ['git', '--git-dir', os.path.join(os.getcwd(), '.git'),
               'rev-parse', '--short', '--verify', 'HEAD']
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    return completed.stdout.strip().decode("utf-8")

In [None]:
# # Make directory for output
# OutPutDir = os.path.join(os.getcwd(), 'Output', get_git_hash())
# print('We are saving all the output to %s' % OutPutDir)
# os.makedirs(OutPutDir, exist_ok=True)

In [None]:
# Make us a dataframe for saving all that we need.
# It starts empty; the following cells add its columns.
Data = pandas.DataFrame()

In [None]:
# Every '.log' file anywhere below Root, in os.walk order (unsorted but fast)
Data['LogFile'] = [os.path.join(folder, filename)
                   for folder, _, filenames in os.walk(Root)
                   for filename in filenames
                   if filename.endswith('.log')]

In [None]:
# Report how many log files were found below Root, before any filtering
print('We found %s log files in %s' % (len(Data), Root))

In [None]:
# Limit *all* the data to only the 'head' scans.
# `regex=False`: 'head' is a plain substring. `.copy()`: the next cells add columns to `Data`,
# which on a filtered view of the original frame triggers pandas' SettingWithCopyWarning.
Data = Data[Data['LogFile'].str.contains('head', regex=False)].copy()

In [None]:
# Report how many log files are left after keeping only the 'head' scans
print('We have %s log files with "head" in their name in %s' % ((len(Data)), Root))

In [None]:
# The folder of each log file
Data['Folder'] = Data['LogFile'].map(os.path.dirname)

In [None]:
# Get rid of all non-rec logfiles. Keep only log files that
#   - lie in a folder with 'rec' in its path,
#   - which is not a 'SubScan' folder,
#   - and are not the temporary 'rectmp.log' files.
# This uses one vectorized filter instead of dropping rows from `Data` while iterating over it.
Data = Data[Data['Folder'].str.contains('rec', regex=False)
            & ~Data['Folder'].str.contains('SubScan', regex=False)
            & ~Data['LogFile'].str.contains('rectmp.log', regex=False)]
# Reset dataframe to something that we would get if we only would have loaded the 'rec' files
Data = Data.reset_index(drop=True)

In [None]:
# Meaningful columns from the path relative to Root:
# 'Fish' is the first folder; 'Scan' is all the folders after that (without the file name), joined with '_'
Data['Fish'] = [logfile[len(Root) + 1:].split(os.sep)[0] for logfile in Data['LogFile']]
Data['Scan'] = ['_'.join(logfile[len(Root) + 1:].split(os.sep)[1:-1]) for logfile in Data['LogFile']]

In [None]:
# Show the last rows of the table as cell output
Data.tail()

In [None]:
# Reconstruction file names: every PNG with 'rec0' in its name, anywhere below each folder
Data['Reconstructions'] = [[os.path.join(dirpath, filename)
                            for dirpath, _, filenames in os.walk(folder)
                            for filename in filenames
                            if filename.endswith('.png') and 'rec0' in filename]
                           for folder in Data['Folder']]
# How many reconstruction files each folder has
Data['Number of reconstructions'] = Data['Reconstructions'].map(len)

In [None]:
# Keep only samples with at least one reconstruction. Others have either not been reconstructed yet,
# or we deleted their reconstructions with `find . -name "*rec*.png" -type f -mtime +333 -delete`
# Based on https://stackoverflow.com/a/13851602
Data = Data[Data['Number of reconstructions'] > 0].reset_index(drop=True)
print('We have %s folders with reconstructions in %s' % ((len(Data)), Root))

In [None]:
# Voxel size of every scan, parsed from its log file
Data['Voxelsize'] = list(map(get_pixelsize, Data['LogFile']))

In [None]:
# Load all reconstructions into DASK arrays, one per row of `Data`.
# The index of `Data` was reset above, so list position == row label.
Reconstructions = [dask_image.imread.imread(os.path.join(folder, '*rec*.png'))
                   for folder in tqdm(Data['Folder'],
                                      desc='Load reconstructions',
                                      total=len(Data))]

In [None]:
# Check if something went wrong
# for file in Data['OutputNameRec']:
#     print(file)
#     dask.array.from_zarr(file)

In [None]:
# Shape of every reconstruction stack
Data['Size'] = [stack.shape for stack in Reconstructions]

In [None]:
# The three cardinal directions, in the order of the reconstruction stack axes they are projected along.
# Names adapted to fishes: https://en.wikipedia.org/wiki/Fish_anatomy#Body
directions = ['Anteroposterior', 'Lateral', 'Dorsoventral']

In [None]:
# Read in previously generated MIPs or calculate them.
# There is one column per direction; each cell holds the 2D maximum intensity projection
# of that scan along the axis with the same position in `directions`.
for direction in directions:
    Data['MIP_' + direction] = ''
for c, row in tqdm(Data.iterrows(), desc='Calculating MIPs', total=len(Data)):
    for d, direction in tqdm(enumerate(directions),
                             desc='%s/%s' % (row['Fish'], row['Scan']),
                             leave=False,
                             total=len(directions)):
        # MIPs are cached as PNG in the parent folder of the reconstruction folder
        outfilepath = os.path.join(os.path.dirname(row['Folder']),
                                   '%s.%s.MIP.%s.png' % (row['Fish'], row['Scan'], direction))
        if os.path.exists(outfilepath):
            # Compute the cached MIP right away. Then the column always holds in-memory numpy
            # arrays, like the freshly calculated MIPs below, and not lazy dask arrays
            # that only appear on later runs.
            Data.at[c, 'MIP_' + direction] = dask_image.imread.imread(outfilepath).squeeze().compute()
        else:
            # Generate MIP
            Data.at[c, 'MIP_' + direction] = Reconstructions[c].max(axis=d).compute().squeeze()
            # Save it out
            imageio.imwrite(outfilepath, Data.at[c, 'MIP_' + direction].astype('uint8'))

In [None]:
# Collect views: one overview figure per scan with its three MIPs side by side.
# Scans whose overview already exists on disk are skipped.
for c, row in tqdm(Data.iterrows(),
                   desc='Saving overview of MIP images',
                   total=len(Data)):
    outfilepath = os.path.join(os.path.dirname(row['Folder']),
                               '%s.%s.MIPs.png' % (row['Fish'], row['Scan']))
    if not os.path.exists(outfilepath):
        print('%s/%s: %s' % (c, len(Data), os.path.join(row.Fish, row.Scan)))
        for d, direction in tqdm(enumerate(directions),
                                 desc='%s/%s' % (row['Fish'], row['Scan']),
                                 leave=False,
                                 total=len(directions)):
            plt.subplot(1, 3, d + 1)
            plt.imshow(row['MIP_' + direction])
            plt.gca().add_artist(ScaleBar(row['Voxelsize'],
                                          'um'))
            plt.axis('off')
            plt.title('%s\n%s MIP' % (os.path.join(row['Fish'], row['Scan']), direction))
        # Save once, after all three panels are drawn. Saving inside the loop rewrote the file
        # for every panel.
        plt.savefig(outfilepath,
                    transparent=True,
                    bbox_inches='tight')
        plt.show()

The functions below were copied from `Hearts-Melly/SubMyocardAnalysis.ipynb`, in which we also look at the orientation of things.

In [None]:
# From Hearts-Melly/SubMyocardAnalysis.ipynb
def get_properties(roi, verbose=False):
    '''
    Label the connected regions of `roi` and return a DataFrame with at most one row.
    The row holds the label, centroid, area, perimeter and orientation of the region
    with the biggest area.
    If `verbose`, overlay the labelled image on `roi` and show it.
    '''
    labels = skimage.measure.label(roi)
    # regionprops_table output goes straight into pandas: https://stackoverflow.com/a/66632023/323100
    wanted = ('label', 'centroid', 'area', 'perimeter', 'orientation')
    table = pandas.DataFrame(skimage.measure.regionprops_table(labels, properties=wanted))
    # Only the biggest region. reset_index keeps the former index as an 'index' column.
    biggest = table.sort_values(by='area', ascending=False).iloc[:1].reset_index()
    if verbose:
        plt.imshow(roi, alpha=0.5)
        plt.title('Original')
        plt.axis('off')
        plt.imshow(numpy.ma.masked_equal(labels, 0), cmap='viridis', alpha=0.5)
        plt.title('Labelled')
        plt.axis('off')
        plt.show()
    return biggest

In [None]:
def get_largest_region(segmentation, verbose=False):
    '''
    Return a boolean mask of the largest connected component in `segmentation`.

    Get out biggest item, from https://stackoverflow.com/a/55110923/323100:
    label the components, count the pixels of each label (ignoring the
    background label 0) and keep the most frequent label.

    Raises ValueError if `segmentation` has no connected component at all.
    If `verbose`, show the input next to the extracted component.
    '''
    labels = skimage.measure.label(segmentation)
    # Explicit check: an `assert` is stripped when Python runs with -O
    if labels.max() == 0:
        raise ValueError('The segmentation does not contain any connected component')
    # bincount()[1:] skips the background; +1 maps the argmax position back to its label
    largestCC = labels == numpy.argmax(numpy.bincount(labels.flat)[1:]) + 1
    if verbose:
        plt.subplot(121)
        plt.imshow(segmentation)
        plt.subplot(122)
        plt.imshow(largestCC)
        plt.suptitle('Largest connected component')
        plt.show()
    return largestCC

In [None]:
def threshold(image, verbose=False):
    '''
    Return the Otsu threshold of `image`, computed from its non-zero pixels only.
    If `verbose`, show the thresholded overlay next to the log-histogram.
    '''
    nonzero = image[image > 0]
    otsu = skimage.filters.threshold_otsu(nonzero)
    if verbose:
        plt.subplot(121)
        plt.imshow(image)
        plt.imshow(dask.array.ma.masked_equal(image > otsu, 0),
                   alpha=0.618,
                   cmap='viridis_r')
        plt.subplot(122)
        plt.semilogy(histogram(image), label='Log-Histogram')
        plt.axvline(otsu, label='Otsu threshold: %s' % otsu)
        plt.legend()
        plt.show()
    return otsu

In [None]:
# Histogram of an 8 bit image
def histogram(img):
    '''
    Return the counts of dask.array.histogram of `img`, using 256 bins over
    the 8 bit range [0, 256). The bin edges are discarded.
    '''
    counts, _ = dask.array.histogram(dask.array.array(img),
                                     bins=256,
                                     range=[0, 256])
    return counts

In [None]:
# Pick one scan and use its anteroposterior MIP as a dask array
whichone = 4
print(os.path.join(Data['Fish'][whichone], Data['Scan'][whichone]))
img = dask.array.asarray(Data['MIP_Anteroposterior'][whichone])

In [None]:
t = threshold(image=img.compute(), verbose=True)

In [None]:
a = get_largest_region(segmentation=img > t, verbose=True)

In [None]:
def get_centroid(img, verbose=False):
    '''
    Return the centroid (x, y) of the biggest region in `img`, as found by `get_properties`.
    x is the column ('centroid-1') and y the row ('centroid-0'). Both are returned
    as the (at most one-element) pandas Series from the properties table.
    If `verbose`, mark the centroid on the image.
    '''
    props = get_properties(img)
    # Drawing from https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_regionprops.html
    x0 = props['centroid-1']
    y0 = props['centroid-0']
    if verbose:
        plt.imshow(img)
        plt.scatter(x0, y0, marker=None, color='r')
        plt.axis('off')
        plt.show()
    return (x0, y0)

In [None]:
def get_contour(filled_img, verbose=False):
    '''
    Return the x and y coordinates (cx, cy) of the longest contour of the
    largest connected region in `filled_img`.
    Contouring from https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_regionprops.html
    If `verbose`, draw the contour on top of the image.
    '''
    region = get_largest_region(filled_img, verbose=False)
    # The largest region alone can still give several contours. Keep the longest one;
    # on ties, the stable sort keeps the last of them.
    longest = sorted(skimage.measure.find_contours(region), key=len)[-1]
    cy, cx = longest.T
    if verbose:
        plt.imshow(filled_img)
        plt.plot(cx, cy, lw=1, c='r')
        plt.axis('off')
        plt.show()
    return (cx, cy)

In [None]:
contour = get_contour(filled_img=img > t, verbose=True)

In [None]:
centroid = get_centroid(img=img > t, verbose=True)

In [None]:
def draw_orientation(img, x0, x1, x2, y0, y1, y2, self=False):
    '''
    Draw two red lines starting at (x0, y0): one to (x1, y1), one to (x2, y2).
    With `self=True`, the image is shown first and the figure is displayed afterwards.
    Otherwise the lines are only drawn onto the current axes.
    Returns an empty tuple.
    '''
    standalone = self
    if standalone:
        plt.imshow(img)
    for x_end, y_end in ((x1, y1), (x2, y2)):
        plt.plot((x0, x_end), (y0, y_end), '-r', linewidth=1)
    if standalone:
        plt.axis('off')
        plt.show()
    return ()

In [None]:
def get_orientation(img, voxelsize, length=10, verbose=False):
    '''
    Get and draw the orientation of the largest region in `img`, as two bars of `length` mm.

    Starting at the centroid (x0, y0), two perpendicular bars are computed from the region's
    'orientation' angle. One ends at (x1, y1), the other at (x2, y2), as in
    https://scikit-image.org/docs/stable/auto_examples/segmentation/plot_regionprops.html
    `voxelsize` is in um and is used to convert `length` (mm) to pixels.

    Returns (x0, x1, x2, y0, y1, y2).
    '''
    props = get_properties(img)
    # length in mm, voxelsize in um -> bar length in pixels
    reallength = length / voxelsize * 1000
    # Use the scalar angle. math.cos/sin on a one-element Series relies on float(Series),
    # which pandas has deprecated.
    orientation = props['orientation'].iloc[0]
    # Same centroid as get_centroid(img), without labelling the image a second time
    x0, y0 = props['centroid-1'], props['centroid-0']
    x1 = x0 + math.cos(orientation) * reallength
    y1 = y0 - math.sin(orientation) * reallength
    x2 = x0 - math.sin(orientation) * reallength
    y2 = y0 - math.cos(orientation) * reallength
    if verbose:
        plt.imshow(img)
        plt.scatter(x0, y0, marker=None, color='r')
        draw_orientation(img, x0, x1, x2, y0, y1, y2)
        plt.gca().add_artist(ScaleBar(voxelsize, 'um'))
        plt.title('Image with %s mm long orientation bars' % length)
        plt.show()
    return (x0, x1, x2, y0, y1, y2)

In [None]:
# Orientation of the thresholded MIP, with 2 mm long bars.
# Use the voxel size of the scan the image came from (`whichone`), not a hard-coded
# row that silently goes stale when another scan is picked.
x0, x1, x2, y0, y1, y2 = get_orientation(img > t,
                                         Data.Voxelsize[whichone],
                                         length=2,
                                         verbose=True)

The functions above were copied from `Hearts-Melly/SubMyocardAnalysis.ipynb`, in which we also look at the orientation of things.

In [None]:
# Properties of the largest thresholded region
a = get_properties(roi=img > t)
print(a)

In [None]:
# Use the angle and centroid to rotate the image.
# skimage.transform.rotate expects `center` as (column, row), i.e. (x, y). That is the
# reverse of the (row, column) order of the regionprops centroid.
img_rotated = skimage.transform.rotate(img.compute(),
                                       angle=numpy.rad2deg(a.orientation[0]),
                                       center=(a['centroid-1'][0], a['centroid-0'][0]),
                                       preserve_range=True)

In [None]:
# Orientation of the region in degrees, plus 90° (shown as cell output)
numpy.rad2deg(a.orientation[0])+90

In [None]:
# Let's show what we did.
# plt.scatter takes (x, y) = (column, row), so the centroid is ('centroid-1', 'centroid-0').
# The rotation is around the centroid, so it stays at the same spot in both images.
plt.subplot(121)
plt.imshow(img)
plt.scatter(a['centroid-1'], a['centroid-0'], s=50)
plt.title('Original image with centroid')
plt.subplot(122)
plt.imshow(img_rotated)
plt.scatter(a['centroid-1'], a['centroid-0'], s=50)
plt.title('Image rotated by %0.f°, with centroid' % numpy.rad2deg(a.orientation[0]))
plt.show()

Figure out the otolith position on each of the directional views.

In [None]:
def smoother(curve, frac=0.1):
    '''
    Smooth `curve` with statsmodels' LOWESS, using the sample index as x.
    `frac` is passed on to LOWESS. The smoothed values come back in the
    original order of `curve` (return_sorted=False).
    '''
    from statsmodels.nonparametric.smoothers_lowess import lowess
    positions = range(len(curve))
    return lowess(curve, positions, frac=frac, return_sorted=False)

In [None]:
def get_minimum(curve, verbose=False):
    '''
    Detect a 'border' based on the gray value along the image: the position of
    the steepest descent, i.e. the minimum of the derivative (numpy.diff) of the
    LOWESS-smoothed gray value curve.
    Based on https://stackoverflow.com/a/28541805/323100 and some manual tweaking
    If `verbose`, plot the curve, its smoothed version and the border.
    '''
    smoothed = smoother(curve)
    steepest_descent = numpy.argmin(numpy.diff(smoothed))
    if verbose:
        plt.plot(curve, alpha=0.6, label='Input curve')
        plt.plot(smoothed, label='LOWESS smoothed')
        plt.axvline(steepest_descent, c='r', label='Border')
        plt.legend()
    return steepest_descent

In [None]:
def get_maximum(curve, verbose=False):
    '''
    Detect a 'border' based on the gray value along the image.
    We do this by detecting the MAXIMUM of the derivative (numpy.diff) of the
    LOWESS-smoothed gray value curve, i.e. the position of the steepest increase.
    Based on https://stackoverflow.com/a/28541805/323100 and some manual tweaking
    If `verbose`, plot the curve, its smoothed version and the border.
    Returns the index of that maximum in the derivative.
    '''
    smoothed = smoother(curve)
    maximal_diff = numpy.argmax(numpy.diff(smoothed))
    if verbose:
        plt.plot(curve, alpha=0.6, label='Input curve')
        plt.plot(smoothed, label='LOWESS smoothed')
        plt.axvline(maximal_diff, c='r', label='Border')
        plt.legend()
    return maximal_diff

In [None]:
def get_peak(curve, start=None, stop=None, frac=0.25, height=0.25, verbose=False, min_width=100):
    '''
    Find the highest peak in the LOWESS-smoothed `curve`.

    Parameters
    ----------
    curve : 1D array-like
        Values to search the peak in.
    start, stop : int, optional
        If given, the values before `start` and from `stop` on are set to 0
        before smoothing, so that no peak is found there.
    frac : float
        Passed on to `smoother` (LOWESS).
    height : float
        Relative height at which the peak width is measured
        (`rel_height` of scipy.signal.peak_widths).
    verbose : bool
        Print and plot intermediate results.
    min_width : int
        Minimal width (in samples) a peak needs for scipy.signal.find_peaks.
        The default is the previously hard-coded value of 100.

    Returns
    -------
    (peak position, peak value, peak width), each truncated to int.
    If several peaks are found, the highest (first one on ties) is returned.

    Raises
    ------
    ValueError if no peak is found.
    '''
    # `import scipy` alone does not load the signal submodule on every SciPy version
    import scipy.signal
    # Peak finding from https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.find_peaks.html
    # Mask a bit at the start and/or the end of the curve, if desired
    mask = numpy.zeros_like(curve)
    if start:
        mask[:start] = 1
    if stop:
        mask[stop:] = 1
    if start or stop:
        smoothed = smoother(numpy.ma.masked_where(mask, curve).filled(fill_value=0), frac=frac)
    else:
        smoothed = smoother(curve, frac=frac)
    if verbose:
        print('The input curve has a length of %s' % len(curve))
        if start:
            print('We discard the %s values from the start' % start)
        if stop:
            print('We discard the values from %s to the end' % stop)
        print('The input to the smoother has a length of %s' % len(numpy.ma.masked_where(mask, curve).compressed()))
    peaks, _ = scipy.signal.find_peaks(smoothed, width=min_width)
    if len(peaks) == 0:
        raise ValueError('No peak with a width of at least %s samples found' % min_width)
    # Peak width from https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.peak_widths.html
    widths, width_heights, left_ips, right_ips = scipy.signal.peak_widths(smoothed, peaks, rel_height=height)
    # Keep only the highest peak. Extracting scalars explicitly avoids NumPy's deprecated
    # int() conversion of one-element arrays.
    highest = numpy.argmax(smoothed[peaks])
    peak = int(peaks[highest])
    peak_value = int(smoothed[peak])
    width = int(widths[highest])
    if verbose:
        if start or stop:
            plt.plot(curve, alpha=0.618, label='Original')
        plt.plot(numpy.ma.masked_where(mask, curve), label='Input')
        plt.plot(smoothed, label='Smoothed (frac=%s)' % frac)
        plt.plot(peak,
                 smoothed[peak],
                 'x',
                 color='C2',
                 label='Peak@%s' % peak)
        plt.hlines(width_heights[highest], left_ips[highest], right_ips[highest],
                   color="C3",
                   label='Peak width at %d%%: %s' % (100 * height, width))
        plt.legend()
        plt.show()
    return (peak, peak_value, width)

In [None]:
# Put gray values in temporary lists.
# The expensive max over the full 3D stack is computed once per direction. Its 2D MIP is
# then summed along both image axes: gv sums axis 0, gh sums axis 1.
# NOTE(review): assumes `Data` has at least 20 rows
whichone = 19
mips = [Reconstructions[whichone].max(axis=axis).compute() for axis in range(len(directions))]
gv = [mip.sum(axis=0) for mip in mips]
gh = [mip.sum(axis=1) for mip in mips]

In [None]:
# Peak of gv[0], ignoring the first 100 values and everything from index 1000 on (result shown as cell output)
get_peak(gv[0], start=100, stop=1000, verbose=True)

In [None]:
# Plot the peak of each 'gh' profile (MIPs summed along axis 1)
for i in gh:
    get_peak(i, verbose=True)

In [None]:
# Plot the peak of each 'gv' profile (MIPs summed along axis 0); height=0.25 is also get_peak's default
for i in gv:
    get_peak(i, height=0.25, verbose=True)

In [None]:
def rescale_linear(array, new_min, new_max):
    """
    Rescale an array linearly, so that its minimum maps to `new_min` and its maximum to `new_max`.
    From https://stackoverflow.com/a/50011743/323100

    A constant array (minimum == maximum) cannot be stretched. Instead of dividing by zero
    (which yields inf/NaN), every value is mapped to `new_min`.
    """
    minimum, maximum = numpy.min(array), numpy.max(array)
    if maximum == minimum:
        return numpy.full(numpy.shape(array), new_min, dtype=float)
    m = (new_max - new_min) / (maximum - minimum)
    b = new_min - m * minimum
    return m * array + b

In [None]:
def get_grayvalues(image, plane, which):
    '''
    Get the gray value profile of `image` along a horizontal or vertical image plane.

    `plane='horizontal'` reduces over axis 0, `plane='vertical'` over axis 1.
    `which` selects the reduction, 'sum' or 'max'. Invalid arguments print a hint
    and return an empty list.
    '''
    ax = 0 if plane == 'horizontal' else (1 if plane == 'vertical' else None)
    if ax is None:
        print('No plane given, specify either "plane=horizontal" or "plane=vertical"')
        print('Returning EMPTY grayvalues')
        return []
    if which == 'sum':
        reduce = image.sum
    elif which == 'max':
        reduce = image.max
    else:
        print('No method given, specify either "which=max" or "which=sum"')
        print('Returning EMPTY grayvalues')
        return []
    return reduce(axis=ax)

In [None]:
# List the cardinal directions, one per line.
# Joining avoids a module-level loop variable named `dir`, which shadowed the builtin `dir()`
# for the rest of the notebook.
print('\n'.join(directions))

In [None]:
# The two image planes along which gray value profiles are taken
planes = 'horizontal vertical'.split()

In [None]:
# Generate empty columns to fill in the values we calculate below
for c, direction in enumerate(directions):
    # Per fish direction: general discard range, merged otolith peak position, merged otolith peak width
    Data['Discard_%s_fish' % direction] = ''
    Data['Otolith_Peak_%s_fish' % direction] = ''
    Data['Otolith_Width_%s_fish' % direction] = ''
    for plane in planes:
        # Per MIP and image plane: the discard range used for this image (copied from the general ones),
        # otolith peak position, otolith peak width and the gray value profile
        Data['Discard_%s_image_%s' % (direction, plane)] = ''
        Data['Otolith_Peak_%s_%s' % (direction, plane)] = ''
        Data['Otolith_Width_%s_%s' % (direction, plane)] = ''
        Data['Grayvalues_%s_%s' % (direction, plane)] = ''

In [None]:
def otolither_region(whichone, discard_front=None, bottom=1000, showregion=True, verbose=False):
    '''
    Find the approximate position and size of the otoliths of scan `whichone` (a row of `Data`).

    Modeled after algorithm for finding the enamel/dentin border for the tooth project
    (https://github.com/habi/zmk-tooth-cohort/blob/master/ToothAnalysis.ipynb), we
    look for a change and bump in the gray values along different axis of the fishes.
    This works out nicely to detect the approximate region of the otoliths.

    Steps:
    - For every MIP (one per direction) and image plane, compute a gray value profile
      with `get_grayvalues`.
    - Mask the regions which cannot contain the otoliths.
    - Find the highest peak with `get_peak`.
    - Take rounded means of the two image profiles that run along the same fish axis
      to get the per-fish values.
    All results are written into the global `Data`, row `whichone`.

    `discard_front`, `bottom` and `showregion` are currently unused; they are only kept
    so existing calls keep working. Returns an empty tuple.
    '''
    if verbose:
        print('We try to find the otholith for the %s scan of fish %s' % (Data.Scan[whichone], Data.Fish[whichone]))

    # Discard some regions of the images for finding the otoliths
    # frontally: contains the teeth
    # back: can contain labels or dorsal fins
    Data.at[whichone, 'Discard_Anteroposterior_fish'] = [Data.MIP_Lateral[whichone].shape[0] // 4,
                                                         Data.MIP_Lateral[whichone].shape[0] - Data.MIP_Lateral[whichone].shape[0] // 5]
    # laterally otoliths are in the middle of the fish
    Data.at[whichone, 'Discard_Lateral_fish'] = [Data.MIP_Dorsoventral[whichone].shape[1] // 8,
                                                 Data.MIP_Dorsoventral[whichone].shape[1] - Data.MIP_Dorsoventral[whichone].shape[1] // 8]
    # bottom: a lot of the fish, no otolith
    # top: often empty
    Data.at[whichone, 'Discard_Dorsoventral_fish'] = [Data.MIP_Anteroposterior[whichone].shape[1] // 3,
                                                      Data.MIP_Anteroposterior[whichone].shape[1] - Data.MIP_Anteroposterior[whichone].shape[1] // 8]

    # Copy the discard values to the image planes whose profile runs along the same fish axis.
    # This is *very* hacky, but makes it all run in one single loop
    Data.at[whichone, 'Discard_Anteroposterior_image_horizontal'] = Data['Discard_Dorsoventral_fish'][whichone]
    Data.at[whichone, 'Discard_Anteroposterior_image_vertical'] = Data['Discard_Lateral_fish'][whichone]
    Data.at[whichone, 'Discard_Lateral_image_horizontal'] = Data['Discard_Dorsoventral_fish'][whichone]
    Data.at[whichone, 'Discard_Lateral_image_vertical'] = Data['Discard_Anteroposterior_fish'][whichone]
    Data.at[whichone, 'Discard_Dorsoventral_image_horizontal'] = Data['Discard_Lateral_fish'][whichone]
    Data.at[whichone, 'Discard_Dorsoventral_image_vertical'] = Data['Discard_Anteroposterior_fish'][whichone]

    if verbose:
        print('We discard the ventral %s and the dorsal %s slices of the fish' % (Data['Discard_Dorsoventral_fish'][whichone][0], Data['Discard_Dorsoventral_fish'][whichone][1]))
        print('We discard the anterior %s and the posterior %s slices of the fish' % (Data['Discard_Anteroposterior_fish'][whichone][0], Data['Discard_Anteroposterior_fish'][whichone][1]))
        print('We discard the fish laterally between slices %s and %s' % (Data['Discard_Lateral_fish'][whichone][0], Data['Discard_Lateral_fish'][whichone][1]))

    for direction in directions:
        for plane in planes:
            if verbose:
                print('Calculating Grayvalues_' + direction + '_' + plane)
            # The lateral MIP uses the maximum along the horizontal plane, everything else the sum
            if ('Lateral' in direction) and ('horizontal' in plane):
                method = 'max'
            else:
                method = 'sum'
            Data.at[whichone, 'Grayvalues_' + direction + '_' + plane] = get_grayvalues(Data['MIP_' + direction][whichone],
                                                                                        plane=plane,
                                                                                        which=method)
            if verbose:
                print(80 * '-')
                print(direction, plane)
                print('GV length', len(Data['Grayvalues_' + direction + '_' + plane][whichone]))
                print('For %s/%s we want to discard %s' % (direction, plane, Data.at[whichone, 'Discard_' + direction + '_image_' + plane]))
                print('MIP shape', Data['MIP_' + direction][whichone].shape)
                print(80 * '-')
            # Our peak finder function returns peak position, peak value and peak width
            # We only need position and width and don't save the value (for now)
            peak, _, width = get_peak(Data['Grayvalues_' + direction + '_' + plane][whichone],
                                      start=Data['Discard_' + direction + '_image_' + plane][whichone][0],
                                      stop=Data['Discard_' + direction + '_image_' + plane][whichone][1],
                                      frac=0.25,
                                      height=0.25,
                                      verbose=False)
            Data.at[whichone, 'Otolith_Peak_' + direction + '_' + plane] = peak
            Data.at[whichone, 'Otolith_Width_' + direction + '_' + plane] = width

    # Merge the two image profiles which run along the same fish axis into one value per axis
    # (same pairing as the discard ranges above), as rounded means to be on the safe side.
    # The dorsoventral axis uses 'Lateral_horizontal'. 'Lateral_vertical' runs along the
    # anteroposterior axis and is therefore only used for the anteroposterior value.
    axis_profiles = {'Anteroposterior': ('Dorsoventral_vertical', 'Lateral_vertical'),
                     'Lateral': ('Anteroposterior_vertical', 'Dorsoventral_horizontal'),
                     'Dorsoventral': ('Lateral_horizontal', 'Anteroposterior_horizontal')}
    for axis, (first, second) in axis_profiles.items():
        for quantity in ('Otolith_Peak_', 'Otolith_Width_'):
            Data.at[whichone, quantity + axis + '_fish'] = int(round(numpy.mean((Data[quantity + first][whichone],
                                                                               Data[quantity + second][whichone]))))
    return ()

In [None]:
# Row of `Data` (scan) to look for the otoliths in
whichone = 6

In [None]:
otolither_region(whichone)

In [None]:
def display_otolith_position(whichone):
    """Plot the otolith detection results of one scan.

    Top row, one column per viewing direction: the MIP with the region used
    for the detection, the raw (light gray) and smoothed (white) gray value
    profiles of that region, the detected peak positions and the detected
    otolith region.
    Bottom row: a MIP (along the same axis) of the otolith region cut out of
    the full reconstruction.

    Reads the globals ``Data``, ``directions`` and ``Reconstructions``; the
    ``Otolith_*`` columns of ``Data`` have to be filled by
    ``otolither_region`` first.

    Parameters
    ----------
    whichone : index label of the scan in ``Data``

    Returns
    -------
    An empty tuple.
    """
    # Based on https://matplotlib.org/tutorials/intermediate/gridspec.html
    fig = plt.figure(constrained_layout=True)
    gs = fig.add_gridspec(2, 3)
    # Otolith region in the reconstruction: peak +- half width along AP, LT, DV.
    # Identical for every column, so build it once.
    # Clamp the start at 0, a negative start index would wrap around and give
    # an empty region for otoliths close to the border of the volume.
    region = tuple(slice(max(0, Data['Otolith_Peak_' + fish_axis + '_fish'][whichone] - Data['Otolith_Width_' + fish_axis + '_fish'][whichone] // 2),
                         Data['Otolith_Peak_' + fish_axis + '_fish'][whichone] + Data['Otolith_Width_' + fish_axis + '_fish'][whichone] // 2)
                   for fish_axis in ('Anteroposterior', 'Lateral', 'Dorsoventral'))
    for c, direction in enumerate(directions):
        mip = Data['MIP_' + direction][whichone]
        # [start, stop] of the region used for the detection, per image axis
        discard_h = Data['Discard_' + direction + '_image_horizontal'][whichone]
        discard_v = Data['Discard_' + direction + '_image_vertical'][whichone]
        gray_h = Data['Grayvalues_' + direction + '_horizontal'][whichone]
        gray_v = Data['Grayvalues_' + direction + '_vertical'][whichone]
        peak_h = Data['Otolith_Peak_' + direction + '_horizontal'][whichone]
        peak_v = Data['Otolith_Peak_' + direction + '_vertical'][whichone]
        width_h = Data['Otolith_Width_' + direction + '_horizontal'][whichone]
        width_v = Data['Otolith_Width_' + direction + '_vertical'][whichone]

        # Top row: MIP with the detection results
        fig.add_subplot(gs[0, c])
        plt.imshow(mip)
        # Outline of the region we look at for the detection
        plt.gca().add_patch(matplotlib.patches.Rectangle((discard_h[0], discard_v[0]),
                                                         discard_h[1] - discard_h[0],
                                                         discard_v[1] - discard_v[0],
                                                         edgecolor=seaborn.color_palette()[c],
                                                         facecolor='none',
                                                         label='Region for detection'))
        # Plot *only* the gray values inside the detection region (raw and
        # smoothed), rescaled so that they fit into that region on the image
        x_h = range(discard_h[0], len(gray_h[:discard_h[1]]))
        y_v = range(discard_v[0], len(gray_v[:discard_v[1]]))
        profile_h = gray_h[discard_h[0]:discard_h[1]]
        profile_v = gray_v[discard_v[0]:discard_v[1]]
        for values, color in ((profile_h, 'lightgray'), (smoother(profile_h), 'white')):
            plt.plot(x_h,
                     rescale_linear(values, discard_v[0], discard_v[1]),
                     color=color, alpha=0.618)
        for values, color in ((profile_v, 'lightgray'), (smoother(profile_v), 'white')):
            plt.plot(rescale_linear(values, discard_h[0], discard_h[1]),
                     y_v,
                     color=color, alpha=0.618)

        # Detected peaks, labelled with the direction they belong to
        plt.axhline(peak_v,
                    label='%s vertical peak @ %s' % (direction, peak_v),
                    color='red')
        plt.axvline(peak_h,
                    label='%s horizontal peak @ %s' % (direction, peak_h),
                    color='green')
        # Otolith region, centered on the peaks; the label states its drawn size
        plt.gca().add_patch(matplotlib.patches.Rectangle((peak_h - width_h // 2,
                                                          peak_v - width_v // 2),
                                                         width_h,
                                                         width_v,
                                                         color=seaborn.color_palette()[c],
                                                         alpha=0.618,
                                                         label='Otolith region (%sx%s)' % (width_h, width_v)))
        # plt.legend(loc='lower left')
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][whichone], 'um', color=seaborn.color_palette()[c]))
        plt.title('%s MIP\nwith a size of %s x %s px' % (direction, mip.shape[0], mip.shape[1]))

        # Bottom row: MIP of the extracted region along the same axis
        fig.add_subplot(gs[1, c])
        plt.imshow(Reconstructions[whichone][region].max(axis=c))
        plt.title('Extracted %s region' % (direction))
    plt.show()
    return()

In [None]:
# Show the detection results for the scan picked above
display_otolith_position(whichone)

In [None]:
# Debug output: detected peak per direction and view
# for direction in directions:
#     for view in views:
#         print(direction, view, 'peak', Data['Otolith_Peak_' + direction + '_' + view][whichone])

In [None]:
# Debug output: detected width per direction and view
# for direction in directions:
#     for view in views:
#         print(direction, view, 'width', Data['Otolith_Width_' + direction + '_' + view][whichone])

In [None]:
# Run the otolith detection for all scans.
# NOTE: the loop variable overwrites the global `whichone`
for whichone, row in tqdm(Data.iterrows(),
                          desc='Extracting otolith regions',
                          total=len(Data)):
    otolither_region(whichone, verbose=False)

In [None]:
# Show the detection results for all scans
for whichone, row in tqdm(Data.iterrows(),
                          desc='Displaying otolith regions',
                          total=len(Data)):
    display_otolith_position(whichone)

In [None]:
# Save out otolith regions as .zarr files
# One file per scan, in `os.path.dirname` of the scan's `Folder`.
# Existing files are skipped; delete them to extract a region again.
Data['OutputNameOtolithRegion'] = [os.path.join(os.path.dirname(f), 'otolith.region.zarr') for f in Data['Folder']]
for whichone, row in tqdm(Data.iterrows(),
                          desc='Extracting otolith regions to .zarr files',
                          total=len(Data)):
    if not os.path.exists(row['OutputNameOtolithRegion']):
        # Region of peak +- half the width along the AP, LT and DV axes.
        # Clamp the start at 0: a negative start index would wrap around and
        # give an empty region for otoliths close to the border of the volume.
        region = tuple(slice(max(0, Data['Otolith_Peak_' + fish_axis + '_fish'][whichone] - Data['Otolith_Width_' + fish_axis + '_fish'][whichone] // 2),
                             Data['Otolith_Peak_' + fish_axis + '_fish'][whichone] + Data['Otolith_Width_' + fish_axis + '_fish'][whichone] // 2)
                       for fish_axis in ('Anteroposterior', 'Lateral', 'Dorsoventral'))
        Reconstructions[whichone][region].to_zarr(row['OutputNameOtolithRegion'],
                                                  overwrite=True,
                                                  compressor=Blosc(cname='zstd',
                                                                   shuffle=Blosc.BITSHUFFLE))

In [None]:
# Load the otoliths in again
# (lazily, one dask array per row of Data, in the order of the rows)
Otoliths = [dask.array.from_zarr(file) for file in Data['OutputNameOtolithRegion']]

In [None]:
# NOTE(review): this is a syntax error, presumably placed on purpose to stop
# "Run all" at this point
asdfasdfasdf==

In [None]:
# Pick one otolith region to work on. `Otoliths` is a list, so this is a
# position, not an index label of Data
whichone=12
o = Otoliths[whichone]

In [None]:
Otoliths[whichone].max().compute()

In [None]:
Otoliths[whichone].mean().compute()

In [None]:
Otoliths[whichone].min().compute()

In [None]:
o.shape

In [None]:
# Look at a single slice (index 750 along the last axis)
img = o[:,:,750].compute()
plt.imshow(img)
plt.show()

In [None]:
# Plot histogram of the otolith region
# 2**8 bins over [0, 2**8), i.e. one bin per integer gray value
h, b = dask.array.histogram(o, bins=2**8, range=(0, 2**8))
plt.semilogy(h.compute())
# Calculate the multi-Otsu thresholds (5 classes give 4 thresholds) and show them.
# Keep them in `peaks`; the segmentation below uses them as seeds.
peaks = skimage.filters.threshold_multiotsu(o.compute(), classes=5)
for t in peaks:
    plt.axvline(t, label=t, c='red')
plt.xlim([0, 2**8])
plt.legend()

In [None]:
def segmentor(image, peaks, verbose=False):
    """Segment a 2D image with the random walker, seeded from gray value thresholds.

    Based on
    https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_random_walker_segmentation.html
    and the details given on
    https://scipy-lectures.org/packages/scikit-image/index.html?highlight=random%20walker#random-walker-segmentation

    The markers (seeds) are set from the thresholds in `peaks`:
      - below peaks[1]: -1, i.e. inactive pixels that random_walker ignores
      - strictly between peaks[1] and peaks[2]: label 1
      - above peaks[3]: label 2
      - everything else: 0, i.e. unlabeled, to be segmented
    peaks[0] is not used.

    Parameters
    ----------
    image : 2D numpy array
        The image to segment (cast to uint8 for the random walker).
    peaks : sequence of at least 4 increasing thresholds
        For example skimage.filters.threshold_multiotsu(STACK.compute(), classes=5)
    verbose : bool
        Print the thresholds; show the image, markers, labels and one panel
        per marker value; print the unique marker and label values.

    Returns
    -------
    The labels, linearly stretched from (min, max) to (min, 255) and cast to
    uint8, or an empty tuple if fewer than 4 thresholds were given.
    """
    if peaks is None or len(peaks) < 4:
        print('Also give us some peaks to work with')
        print('skimage.filters.threshold_multiotsu(STACK.compute(), classes=5) is a good start')
        return()
    # Use a *signed* dtype, since random_walker treats negative markers as
    # inactive. With 'uint8' the -1 below wrapped around to 255 (NumPy >= 2
    # raises instead), so the background became an additional seed label.
    markers = numpy.zeros_like(image, dtype=numpy.int8)
    if verbose:
        print('Discarding everything below %s' % peaks[1])
    markers[image < peaks[1]] = -1
    if verbose:
        print('Setting everything between %s and %s to 1' % (peaks[1], peaks[2]))
    markers[(image > peaks[1]) & (image < peaks[2])] = 1
    if verbose:
        print('Setting everything above %s to 2' % peaks[3])
    markers[image > peaks[3]] = 2
    # Keep copy=True (the default): with copy=False random_walker overwrites
    # `markers` with its result, so the 'Markers' panel below would show labels
    labels = random_walker(image.astype('uint8'), markers)
    if verbose:
        marker_values = numpy.unique(markers)
        plt.subplot(2, 3, 1)
        plt.imshow(image)
        plt.title('Original')
        plt.axis('off')
        plt.subplot(2, 3, 2)
        plt.imshow(markers)
        plt.title('Markers')
        plt.axis('off')
        plt.subplot(2, 3, 3)
        plt.imshow(labels)
        plt.title('Labels')
        plt.axis('off')
        # One panel per marker value, highlighting where that value was set
        for c, value in enumerate(marker_values):
            plt.subplot(2, len(marker_values), len(marker_values) + c + 1)
            plt.imshow(image)
            # `markers` is a plain numpy array, so use numpy's masked arrays
            plt.imshow(numpy.ma.masked_not_equal(markers, value),
                       cmap='viridis_r')
            plt.title('%s=%s' % (c, value))
            plt.axis('off')
        plt.show()
        print(marker_values)
        print(numpy.unique(labels))
    # If inactive pixels come back negative, they must not wrap around to 255
    # in the uint8 cast below
    labels = numpy.clip(labels, 0, None)
    # Return labeled image as an interpolated 8bit image
    return(numpy.interp(labels, (labels.min(), labels.max()), (labels.min(), 255)).astype('uint8'))

In [None]:
# `peaks` are the multi-Otsu thresholds used as seeds for the segmentation.
# They must already be defined; uncomment the next line to (re)compute them
# peaks=skimage.filters.threshold_multiotsu(o.compute(), classes=5)
print(peaks)

In [None]:
# Segment a single slice (index 800 along the last axis) of the otolith region
mask = segmentor(o[:,:,800].compute(),
                 peaks,
                 verbose=True)

In [None]:
numpy.unique(mask)

In [None]:
# Overlay the segmentation on the same slice
plt.imshow(o[:,:,800])
plt.imshow(mask,
           alpha=0.309,
           cmap='viridis')

In [None]:
# NOTE(review): syntax error, presumably on purpose to stop "Run all" here
asdfasdfasdfasdf==

In [None]:
# Shape of the segmentation (`shape` is an attribute, calling it raised a TypeError)
mask.shape

In [None]:
numpy.unique(mask)

In [None]:
# View one otolith in 3D

In [None]:
whichone = 0

In [None]:
# Only use every `subsample`-th voxel along each axis for the 3D view
subsample = 3

In [None]:
# NOTE(review): assigning to `threshold` shadows the `threshold()` helper
# that later cells call as a function
# threshold = threshold(Otoliths[whichone], verbose=False)
# histogram = histogram(Otoliths[whichone])

In [None]:
mask.shape

In [None]:
# NOTE(review): `histogram` and `threshold` are meant to be the values from
# the commented-out cell above. With that cell commented out, `threshold` is
# the helper function (called as one in later cells), so this cell presumably
# fails unless they were defined elsewhere - TODO confirm
plt.semilogy(histogram)
plt.axvline(threshold, label='Threshold @%s' % threshold)
plt.legend()
plt.show()

In [None]:
# Load into K3D
# Binary volume: voxels above the highest threshold in `peaks` become 1, the
# rest 0; only every `subsample`-th voxel along each axis is used
# plt_volume = k3d.volume(Otoliths[whichone][::subsample,::subsample,::subsample].astype(numpy.float16), color_range=[threshold,Otoliths[whichone].max().compute()])
plt_volume = k3d.volume((Otoliths[whichone][::subsample,::subsample,::subsample]>peaks[-1]).astype(numpy.float16))

In [None]:
# Display the otolith
plot = k3d.plot()
plot += plt_volume
plot.display()

In [None]:
# NOTE(review): syntax error, presumably on purpose to stop "Run all" here
asdfasdfsda=

In [None]:
# Export the 3D view as an inline HTML snapshot (currently commented out)
# plot5 = k3d.get_plot()
# plot5.snapshot_type = 'inline'
# plot5.display()

# data = plot5.get_snapshot()

# with open('snapshot_inline.html', 'w') as f:
#     f.write(data)

In [None]:
for c, row in Data.iterrows():
    print(row.Fish, row.Folder)
    otolither_region(c, verbose=True)

In [None]:
whichone = 5
out = Reconstructions[5][716:1009+716-500,300:-300,800:1000]

In [None]:
out.shape

In [None]:
plt.imshow(out.max(axis=0))

In [None]:
# Make a column for saving the otolith positions
# (the column names keep the 'Otholith' spelling, later cells rely on it)
for d, direction in enumerate(directions):
    Data['Otholith_Positions_' + direction] = ''
    Data['Otholith_Position_Mean_' + direction] = ''

In [None]:
# Detect otolith positions on each MIP with the `otolither` helper
for c, row in Data.iterrows():
    print('Finding otolith position for %s/%s' % (row.Fish, row.Scan))
    for d, direction in enumerate(directions):
        Data.at[c, 'Otholith_Positions_' + direction] = otolither(row['MIP_' + direction], sigma=11, verbose=True)

In [None]:
# Save us the mean position
# presumably `otolither` returns two coordinate sequences (entries 0 and 1),
# which are averaged separately - TODO confirm
for d, direction in enumerate(directions):
    Data['Otholith_Position_Mean_' + direction] = [(numpy.mean(op[0]),
                                                    numpy.mean(op[1])) for op in Data['Otholith_Positions_' + direction]]

In [None]:
for i in Data['Otholith_Positions_Lateral']:
    print(round(numpy.mean(i[0])),
          round(numpy.mean(i[1])))

In [None]:
for i in Data['Otholith_Position_Mean_Lateral']:
    print(round(i[0]), round(i[1]))

In [None]:
# NOTE: `Data.Fish[whichone]` is a label lookup, so this assumes Data has the
# default 0..len(Data)-1 index
for whichone in range(len(Data)):
    print(whichone, os.path.join(Data.Fish[whichone], Data.Scan[whichone]))

In [None]:
direction = 'Lateral'

In [None]:
# Show the mean otolith position on the three MIPs of every scan
for c, row in Data.iterrows():
    for d, direction in enumerate(directions):
        plt.subplot(1, 3, d + 1)
        plt.imshow(row['MIP_' + direction])
        plt.title([round(i) for i in row['Otholith_Position_Mean_' + direction]])
        plt.axhline(row['Otholith_Position_Mean_' + direction][1])
        plt.axvline(row['Otholith_Position_Mean_' + direction][0])
        plt.suptitle('%s/%s' % (row.Fish, row.Scan))
        plt.gca().add_artist(ScaleBar(row.Voxelsize, 'um'))
    plt.show()

In [None]:
# Detect otolith position by looking for maximum gray value along fish
for d, direction in enumerate(directions):
    Data['GrayValueAlong_' + direction] = ''
    Data['Otolith_MIP_Position_' + direction] = ''
for whichone, row in Data.iterrows():
    for d, direction in enumerate(directions):
        mip = Data['MIP_' + direction][whichone]
        # Calculate gray value sum along fish (one sum per image row, lazily)
        graysum = dask.array.sum(mip, axis=1)
        Data.at[whichone, 'GrayValueAlong_' + direction] = graysum
        # Maximum of this should give us the otolith position.
        # Reuse the sum instead of building the same dask graph twice; the lazy
        # result is stored because later cells call .compute() on it
        maximum = dask.array.argmax(graysum)
        Data.at[whichone, 'Otolith_MIP_Position_' + direction] = maximum
        # Compute it only once for plotting
        position = maximum.compute()
        # Plot what we found
        plt.subplot(1, 3, d + 1)
        plt.imshow(mip)
        # Plot the *rescaled* values over the image
        plt.plot(rescale_linear(graysum,
                                100,
                                mip.shape[1] - 100),
                 range(len(graysum)),
                 label='Normalized gray value sum along fish',
                 color=seaborn.color_palette()[0])
        plt.axhline(position,
                    label='Max@%s' % position,
                    color=seaborn.color_palette()[1])
        plt.title('%s MIP' % direction)
        plt.suptitle('%s/%s: MIPs of %s/%s' % (whichone, len(Data), Data.Fish[whichone], Data.Scan[whichone]))
        plt.legend(loc='lower center')
    plt.show()

In [None]:
# Compare the otolith positions found by both methods for every scan
for c, row in Data.iterrows():
    print(c, len(Data), os.path.join(row.Fish, row.Scan))
    print('\t Otolith from MIP',
          row['Otolith_MIP_Position_Anteroposterior'].compute(),
          row['Otolith_MIP_Position_Lateral'].compute(),
          row['Otolith_MIP_Position_Dorsoventral'].compute())
    print('\t Otolith from otholither function',
          row['Otholith_Position_Mean_Anteroposterior'],
          row['Otholith_Position_Mean_Lateral'],
          row['Otholith_Position_Mean_Dorsoventral'])

In [None]:
whichone = 3

In [None]:
# Show both position estimates on the three MIPs of one scan
for c, direction in enumerate(directions):
    plt.subplot(1, 3, c + 1)
    plt.imshow(Data['MIP_' + direction][whichone])
    # From otolither function
    plt.axhline(Data['Otholith_Position_Mean_' + direction][whichone][1],
                label='otholither mean position 1: %s' % round(Data['Otholith_Position_Mean_' + direction][whichone][1]),
                color=seaborn.color_palette()[0])
    plt.axvline(Data['Otholith_Position_Mean_' + direction][whichone][0],
                label='otholither mean posistion 0: %s' % round(Data['Otholith_Position_Mean_' + direction][whichone][0]),
                color=seaborn.color_palette()[1])
    # From sum along axis
    plt.axhline(Data['Otolith_MIP_Position_' + direction][whichone],
                label='MIP sum: %s' % Data['Otolith_MIP_Position_' + direction][whichone].compute(),
                color=seaborn.color_palette()[3])
    plt.legend(loc='lower center')
    plt.title(direction)
    plt.suptitle('%s/%s' % (Data.Fish[whichone], Data.Scan[whichone]))
    plt.gca().add_artist(ScaleBar(Data.Voxelsize[whichone], 'um'))
plt.show()

In [None]:
for direction in directions:
    print(direction, Data['MIP_' + direction][whichone].shape)

In [None]:
for direction in directions:
    print(direction, Data['Otholith_Position_Mean_' + direction][whichone])

In [None]:
for direction in directions:
    print(direction, round(Data['Otolith_MIP_Position_' + direction][whichone].compute()))

In [None]:
# Open question: can we calculate both the DV and the LT position on the AP MIP?

In [None]:
# Get us positions of otolith in relation to original data.
# Average the estimates of both detection methods and round to the nearest
# integer, as otolither_region does. numpy.mean(..., dtype='int') truncated
# every estimate to an integer *before* averaging and then truncated the mean.
# float() also computes the lazy (dask) argmax positions.
position_ap = int(round(numpy.mean([float(p) for p in (Data['Otholith_Position_Mean_Lateral'][whichone][1],
                                                        Data['Otholith_Position_Mean_Dorsoventral'][whichone][1],
                                                        Data['Otolith_MIP_Position_Lateral'][whichone],
                                                        Data['Otolith_MIP_Position_Dorsoventral'][whichone])])))
# laterally, we assume the center of the image for now
# position_lt = numpy.mean((Data['Otholith_Position_Mean_' + direction][whichone][1],
#                           Data['Otolith_MIP_Position_' + direction][whichone]),dtype='int')
position_lt = Data.MIP_Dorsoventral[whichone].shape[1] // 2
position_dv = int(round(numpy.mean([float(p) for p in (Data['Otholith_Position_Mean_Anteroposterior'][whichone][0],
                                                        Data['Otholith_Position_Mean_Lateral'][whichone][0])])))

In [None]:
print(position_ap)
print(position_lt)
print(position_dv)

In [None]:
# Extent (in voxels) of the region cut out around the otolith position:
# +- slicethickness // 2 along each axis
slicethickness = 250

In [None]:
# Show the band of `slicethickness` px around the otolith position on each MIP
for c, direction in enumerate(directions):
    mip = Data['MIP_' + direction][whichone]
    # Each MIP shows two of the three axes: pick the positions along those
    if c == 0:
        position_rows, position_columns = position_lt, position_dv
    elif c == 1:
        position_rows, position_columns = position_ap, position_dv
    else:
        position_rows, position_columns = position_ap, position_lt
    plt.subplot(1, 3, c + 1)
    plt.imshow(mip)
    # Horizontal band over the full image width
    plt.fill_between(range(mip.shape[1]),
                     position_rows - slicethickness // 2,
                     position_rows + slicethickness // 2,
                     alpha=0.5)
    # Vertical band over (nearly) the full image height
    plt.fill_between(range(position_columns - slicethickness // 2, position_columns + slicethickness // 2),
                     1,
                     mip.shape[0] - 1,
                     alpha=0.5)
    plt.title(direction)
    plt.suptitle('%s/%s' % (Data.Fish[whichone], Data.Scan[whichone]))
    plt.gca().add_artist(ScaleBar(Data.Voxelsize[whichone], 'um'))

outfilepath = os.path.join(os.path.dirname(Data['Folder'][whichone]),
                           '%s.%s.Otolither.png' % (Data['Fish'][whichone], Data['Scan'][whichone]))
if not os.path.exists(outfilepath):
    # `transparent` was misspelled as `oarent`, which is not a savefig keyword
    plt.savefig(outfilepath,
                transparent=True,
                bbox_inches='tight')
    print('Figure saved to %s' % outfilepath)
plt.show()

In [None]:
print(position_ap)
print(position_lt)
print(position_dv)

In [None]:
# Grab region calculated above from reconstructions
# (+- slicethickness // 2 around the otolith position along AP, LT and DV)
# NOTE(review): a position closer than slicethickness // 2 to the start of an
# axis gives a negative start index, which wraps around and yields an empty slice
otolithregion = Reconstructions[whichone][position_ap - slicethickness // 2:position_ap + slicethickness // 2,
                                          position_lt - slicethickness // 2:position_lt + slicethickness // 2,
                                          position_dv - slicethickness // 2:position_dv + slicethickness // 2
                                         ]
# MIP of the region along each of the three axes
for ax in range(3):
    plt.subplot(1, 3, ax + 1)
    plt.imshow(dask.array.max(otolithregion, axis=ax))
    plt.suptitle('%s/%s: MIP from AP %s:%s, LT %s:%s, DV %s:%s' % (Data.Fish[whichone],
                                                                   Data.Scan[whichone],
                                                                   position_ap - slicethickness // 2, position_ap + slicethickness // 2,
                                                                   position_lt - slicethickness // 2, position_lt + slicethickness // 2,
                                                                   position_dv - slicethickness // 2, position_dv + slicethickness // 2))
plt.show()

In [None]:
threshold(otolithregion.compute())

In [None]:
# Threshold of only the voxels brighter than 42
threshold(otolithregion[otolithregion > 42].compute())

In [None]:
# Overlay the voxels brighter than 102 on the middle slice along the first axis
# NOTE(review): 42 and 102 are hard-coded gray values, presumably read off the
# threshold outputs above - TODO confirm
plt.imshow(otolithregion[slicethickness//2])
plt.imshow(dask.array.ma.masked_equal(otolithregion[slicethickness//2] > 102,0), cmap='viridis', alpha=0.5)

In [None]:
# Optional caching of `largest` on disk with pickle (all commented out)
# import pickle

In [None]:
# file = open('largest', 'wb')
# pickle.dump(largest,file)
# file.close()

In [None]:
# file = open('largest', 'rb')
# largest = pickle.load(file)
# file.close()

In [None]:
# largest.shape

In [None]:
# # Make file smaller for testing reasons
# subsample = 4
# largest_smaller = largest[::subsample, ::subsample, ::subsample]
# largest_smaller.shape

In [None]:
subsample = 2

In [None]:
vmin = threshold(otolithregion.compute())
print(vmin)

In [None]:
# Load into K3D
# NOTE(review): the mask uses a hard-coded 102 instead of `vmin` computed above.
# Also unverified whether the mask survives the conversion for k3d, or whether
# masked voxels are shown with their original values - TODO confirm
plt_volume = k3d.volume(dask.array.ma.masked_less(otolithregion,102)[::subsample, ::subsample, ::subsample].astype(numpy.float16))

In [None]:
# Display the otolith
plot = k3d.plot()
plot += plt_volume
plot.display()    

In [None]:
# Export the 3D view as an inline HTML snapshot (currently commented out)
# plot5 = k3d.get_plot()
# plot5.snapshot_type = 'inline'
# plot5.display()

# data = plot5.get_snapshot()

# with open('snapshot_inline.html', 'w') as f:
#     f.write(data)