# Preview the Stickleback scans

The initial state of this notebook is a simple copy of [the notebook generated for the EAWAG Cichlids project](https://github.com/habi/EAWAG/blob/main/DisplayFishes.ipynb).

The cells below are used to set up the whole notebook.
They load needed libraries and set some default values.

In [None]:
# Load the modules we need
import platform
import os
import glob
import pandas
import imageio
import numpy
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
import seaborn
import dask
import dask_image.imread
from dask.distributed import Client, LocalCluster
from tqdm.auto import tqdm, trange

In [None]:
# Load our own log file parsing code
# This is loaded as a submodule to alleviate excessive copy-pasting between *all* projects we do
# See https://github.com/habi/BrukerSkyScanLogfileRuminator for details on its inner workings
from BrukerSkyScanLogfileRuminator.parsing_functions import *

In [None]:
# # Code linting
# %load_ext pycodestyle_magic

In [None]:
# %pycodestyle_on

In [None]:
# Set the folder for the temporary files of dask
# This has to happen before creating a client: https://stackoverflow.com/a/62804525/323100
# We use the fast internal SSD for speed reasons, if we can
import tempfile
if 'Linux' in platform.system() and os.path.exists(os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')):
    # Linux machine on which we mounted the FastSSD
    tmp = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD', 'tmp')
elif 'Linux' in platform.system() or 'Darwin' in platform.system():
    # Linux machine without the FastSSD, or a Mac: use the standard temporary folder
    tmp = tempfile.gettempdir()
elif 'anaklin' in platform.node():
    # Any other system (presumably Windows), on the machine called 'anaklin'
    tmp = 'F:\\tmp'
else:
    tmp = 'D:\\tmp'
dask.config.set({'temporary_directory': tmp})
print('Dask temporary files go to %s' % dask.config.get('temporary_directory'))

In [None]:
# Start a dask client (`Client` is already imported together with the other modules at the top)
client = Client()

In [None]:
# Show the dask client; as the last expression of the cell, Jupyter displays it
client

In [None]:
# Use seaborn's 'notebook' plotting context for all following plots
seaborn.set_context('notebook')

In [None]:
# Set up figure defaults:
# display all images in b&w and with 'nearest' interpolation, and size up figures a bit
plt.rcParams.update({'image.cmap': 'gray',
                     'image.interpolation': 'nearest',
                     'figure.figsize': (16, 9),
                     'figure.dpi': 200})

In [None]:
# Set up scale bar defaults: white scale bars without a frame, in the lower right corner
plt.rcParams.update({'scalebar.location': 'lower right',
                     'scalebar.frameon': False,
                     'scalebar.color': 'white'})

In [None]:
# Display all plots identically
# `lines` is the number of subplot rows used when we show all scans in one figure
lines = 3
# And then do something like
# plt.subplot(lines, int(numpy.ceil(len(Data) / float(lines))), c + 1)

Since the (tomographic) data can reside on different drives, we set the folder to use below.

In [None]:
# Different locations if running either on Linux or Windows
# Set FastSSD to False to load the data from the alternative location instead
FastSSD = True
if 'Linux' in platform.system():
    if FastSSD:
        BasePath = os.path.join(os.sep, 'media', 'habi', 'Fast_SSD')
    else:
        BasePath = os.path.join(os.sep, 'home', 'habi', '2214')
elif 'Windows' in platform.system():
    if FastSSD:
        BasePath = 'F:\\'
    else:
        BasePath = 'N:\\'
else:
    # Fail with a clear message instead of a NameError on `BasePath` (e.g. on macOS)
    raise RuntimeError('No data location is configured for %s, please add one here' % platform.system())
Root = os.path.join(BasePath, 'IEE Stickleback')
print('We are loading all the data from %s' % Root)

Now that we are set up, we actually start to load/ingest the data.

In [None]:
# Make us a dataframe for saving all that we need
# (one row per found log file, filled column by column in the cells below)
Data = pandas.DataFrame()

In [None]:
# Collect *all* log files present below Root
# Walking the tree with os.walk is way faster than a recursive glob.glob,
# and we don't sort the found logfiles, which keeps it quick as well
Data['LogFile'] = [os.path.join(folder, filename)
                   for folder, _subfolders, filenames in os.walk(Root)
                   for filename in filenames
                   if filename.endswith('.log')]

In [None]:
# The folder each log file lives in
Data['Folder'] = Data['LogFile'].map(os.path.dirname)

In [None]:
# Show a (small) sample of the loaded data as a first check
# Cap the sample size: DataFrame.sample raises a ValueError if asked for more rows than there are
Data.sample(n=min(5, len(Data)))

In [None]:
# Check for samples which are not yet reconstructed
for c, row in Data.iterrows():
    parentfolder, scanfolder = os.path.split(row.Folder)
    # Iterate over every 'proj' folder
    if 'proj' in scanfolder:
        if 'TScopy' not in row.Folder and 'PR' not in row.Folder:
            # If there's nothing with 'rec*' on the same level, then tell us.
            # Only the name of the scan folder itself is changed: replacing 'proj' in the whole path
            # would build a wrong search pattern if any parent folder also contains 'proj'.
            if not glob.glob(os.path.join(parentfolder, scanfolder.replace('proj', '*rec*'))):
                print('- %s is missing matching reconstructions' % os.path.relpath(row.LogFile, Root))

In [None]:
# Look for the .csv files in each folder.
# NRecon only writes them when the "X/Y Alignment With a Reference Scan" was performed.
# If they are *not* there, we forgot to do the alignment and should correct for this.
Data['XYAlignment'] = Data['Folder'].map(lambda folder: glob.glob(os.path.join(folder, '*T*.csv')))

In [None]:
# Display samples which are missing the .csv-files for the XY-alignment
# Log files containing any of these strings are not reported
skipped_for_xy = ('rectmp.log',  # because we only exclude temporary logfiles in a later step
                  'proj_nofilter',  # since these two scans of single teeth don't contain a reference scan
                  'TScopy',  # discard *t*hermal *s*hift data
                  )
for c, row in Data.iterrows():
    # Only 'proj' folders without any alignment file are of interest
    if 'proj' not in row['Folder'] or row['XYAlignment']:
        continue
    if any(marker in row.LogFile for marker in skipped_for_xy):
        continue
    print('- %s has *not* been X/Y aligned' % row.LogFile[len(Root) + 1:])

In [None]:
# Get rid of all the logfiles from all the folders that might be on disk but that we don't want to load the data from:
# - all projection folders (named exactly 'proj')
# - all test scans which are not named 'Sticklbucket_*' or something else containing 'ucket'
# - all log files that we write in this notebook (to $scan$_region folders)
# Build one boolean mask, instead of dropping rows one by one while iterating over the dataframe
keep_logfile = [os.path.split(folder)[-1] != 'proj'
                and 'ucket' in folder
                and '_regions' not in folder
                for folder in Data['Folder']]
# Reset dataframe to something that we would get if we only would have loaded the 'rec' files
Data = Data.loc[keep_logfile].reset_index(drop=True)

In [None]:
# Generate us some meaningful columns in the dataframe:
# the sample name (log file name without '_rec.log') and the scan (name of the folder containing the log file)
Data['Sample'] = Data['LogFile'].map(lambda log: os.path.basename(log).replace('_rec.log', ''))
Data['Scan'] = Data['LogFile'].map(lambda log: os.path.basename(os.path.dirname(log)))

In [None]:
# Quickly show the data from the last loaded scans (the last five rows of the dataframe)
Data.tail(n=5)

In [None]:
# Collect the (sorted) file names of all the reconstructions of each scan, and count them
Data['Filenames Reconstructions'] = Data['Folder'].map(
    lambda folder: sorted(glob.glob(os.path.join(folder, '*rec0*.png'))))
Data['Number of reconstructions'] = Data['Filenames Reconstructions'].map(len)

In [None]:
# Drop samples which have either not been reconstructed yet or of which we deleted the reconstructions with
# `find . -name "*rec*.png" -type f -mtime +333 -delete`
# Based on https://stackoverflow.com/a/13851602
# To see which folders are affected, print `row.Folder` for every row whose 'Number of reconstructions' is zero,
# (they contain no PNG files, maybe because we are currently reconstructing them).
# Keep only folders with at least one reconstruction and renumber the index for easier indexing afterwards
Data = Data.loc[Data['Number of reconstructions'] > 0].reset_index(drop=True)
print('We have %s folders with reconstructions' % len(Data))

In [None]:
# Get the parameters to double-check from the log files, with the functions of our log file parser.
# Each entry maps a dataframe column to the function that extracts its value from one log file;
# the columns are filled in the order listed here.
logfile_parameters = {'Voxelsize': pixelsize,
                      'Filter': whichfilter,
                      'Exposuretime': exposuretime,
                      'Scanner': scanner,
                      'Averaging': averaging,
                      'ProjectionSize': projection_size,
                      'RotationStep': rotationstep,
                      'Grayvalue': reconstruction_grayvalue,
                      'RingartefactCorrection': ringremoval,
                      'BeamHardeningCorrection': beamhardening,
                      'DefectPixelMasking': defectpixelmasking,
                      'Scan date': scandate}
for column_name, parse_function in logfile_parameters.items():
    Data[column_name] = [parse_function(log) for log in Data['LogFile']]

In [None]:
# Order the scans chronologically (and renumber the index accordingly)
Data = Data.sort_values(by='Scan date', ignore_index=True)

Display the parameters we extracted from the log files (with [our log file parser](https://github.com/habi/BrukerSkyScanLogfileRuminator)) to check for consistency.

In [None]:
# List the distinct ring artefact correction values used on each scanner
for machine in Data['Scanner'].unique():
    rac_values = Data.loc[Data['Scanner'] == machine, 'RingartefactCorrection'].unique()
    print('For the %s we have '
          'ringartefact-correction values of %s' % (machine, rac_values))

In [None]:
# Display how often each ring removal parameter was used
# value_counts(dropna=False) also counts NaN values, which a comparison with '==' never matches
for rac, howmany in Data['RingartefactCorrection'].value_counts(dropna=False).sort_index().items():
    print('Ringartefact-correction %02s is found in %03s scans' % (rac, howmany))

In [None]:
# Display ring removal parameter for non-zero values
# The loop variable is called `machine` so that we don't overwrite the `scanner` log file parsing function
for machine in Data.Scanner.unique():
    print('----', machine, '----')
    for c, row in Data[Data.Scanner == machine].iterrows():
        # The value is set to 'nan' when zero, so we only show the values that are set (neither NaN nor zero).
        # A plain `not value` is wrong here: it is False for NaN and True for zero.
        if pandas.notna(row.RingartefactCorrection) and row.RingartefactCorrection:
            print('Fish %s scan %s was reconstructed with RAC of %s' % (row['Sample'],
                                                                        row['Scan'],
                                                                        row['RingartefactCorrection']))

In [None]:
# Check beamhardening parameters
# The loop variable is called `machine` so that we don't overwrite the `scanner` log file parsing function
for machine in Data.Scanner.unique():
    print('For the %s we have '
          'beamhardening correction values of %s' % (machine,
                                                     Data[Data.Scanner == machine]['BeamHardeningCorrection'].unique()))

In [None]:
# Display beamhardening parameters for non-zero values
# The loop variable is called `machine` so that we don't overwrite the `scanner` log file parsing function
for machine in Data.Scanner.unique():
    print('----', machine, '----')
    for c, row in Data[Data.Scanner == machine].iterrows():
        # The value is set to 'nan' when zero, so we only show the values that are set (neither NaN nor zero).
        # A plain `not value` is wrong here: it is False for NaN and True for zero.
        if pandas.notna(row.BeamHardeningCorrection) and row.BeamHardeningCorrection:
            print('Scan %s of fish %s was reconstructed with beam hardening correction of %s' % (row['Scan'],
                                                                                                 row['Sample'],
                                                                                                 row['BeamHardeningCorrection']))

In [None]:
# Check defect pixel masking parameters
# The loop variable is called `machine` so that we don't overwrite the `scanner` log file parsing function
for machine in Data.Scanner.unique():
    print('For the %s we have '
          'defect pixel masking values of %s' % (machine,
                                                 Data[Data.Scanner == machine]['DefectPixelMasking'].unique()))

In [None]:
# Display how often each defect pixel masking value was used
# value_counts(dropna=False) also counts NaN values, which a comparison with '==' never matches
for dpm, howmany in Data['DefectPixelMasking'].value_counts(dropna=False).sort_index().items():
    print('A defect pixel masking of %02s is found in %03s scans' % (dpm, howmany))

In [None]:
# seaborn.scatterplot(data=Data, x='Fish', y='DefectPixelMasking', hue='Scanner')
# plt.title('Defect pixel masking')
# plt.show()

In [None]:
# Display the scans whose defect pixel masking differs from 50 (on the SkyScan1272) or from 0 (on the SkyScan2214)
# The loop variable is called `machine` so that we don't overwrite the `scanner` log file parsing function
for machine in Data.Scanner.unique():
    print('----', machine, '----')
    for c, row in Data[Data.Scanner == machine].iterrows():
        unexpected_1272 = row.Scanner == 'SkyScan1272' and row.DefectPixelMasking != 50
        unexpected_2214 = row.Scanner == 'SkyScan2214' and row.DefectPixelMasking != 0
        if unexpected_1272 or unexpected_2214:
            # There is no 'Fish' column in the dataframe, the fish is stored in 'Sample'
            print('Fish %s scan %s was reconstructed with DPM of %s' % (row['Sample'],
                                                                        row['Scan'],
                                                                        row['DefectPixelMasking']))

----
Now that we've double-checked some of the parameters (and corrected any issues that might have shown up) we start to load the preview images.
The three cells below load the machine-generated preview images and show them in one overview figure.

In [None]:
# Use the (alphabetically first) machine-generated preview image of each scan.
# min(..., default=None) picks the same file as sorted(...)[0], but gives None instead of raising an IndexError
# for folders without a preview, so that the random-noise placeholder below is actually used for them.
Data['Filename PreviewImage'] = [min(glob.glob(os.path.join(f, '*_spr.bmp')), default=None)
                                 for f in Data['Folder']]
Data['PreviewImage'] = [dask_image.imread.imread(pip).squeeze()
                        if pip
                        else numpy.random.random((100, 100)) for pip in Data['Filename PreviewImage']]

In [None]:
# Make an approximately square overview image
# `lines` is the number of subplot rows of the overview figure below
lines = 2

In [None]:
# Show all preview images in one figure with `lines` rows, and save it to Root
ncolumns = int(numpy.ceil(len(Data) / float(lines)))
for c, row in Data.iterrows():
    ax = plt.subplot(lines, ncolumns, c + 1)
    ax.imshow(row.PreviewImage.squeeze())
    ax.set_title(os.path.join(row['Sample'], row['Scan']))
    ax.add_artist(ScaleBar(row['Voxelsize'], 'um', color='black', frameon=True))
    ax.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(Root, 'ScanOverviews.png'), bbox_inches='tight')
plt.show()

Now we 'load' all reconstructions from disk into stacks.

In [None]:
# Load all reconstructions into ephemeral DASK arrays, with a nice progress bar...
# We use the same file pattern ('*rec0*.png') as when counting the reconstructions, so that the depth of each
# stack matches 'Number of reconstructions' and no other PNG files are picked up.
# [:, :, :, 0] gets rid of the color channel. This assumes the PNGs are read as (slice, y, x, channel); TODO confirm.
Reconstructions = [dask_image.imread.imread(os.path.join(folder, '*rec0*.png'))[:, :, :, 0]
                   for folder in tqdm(Data['Folder'],
                                      desc='Loading reconstructions',
                                      total=len(Data))]

In [None]:
# Show the first reconstruction stack as a sanity check (Jupyter displays the last expression of the cell)
Reconstructions[0]

In [None]:
# Summarize what we have on disk
reconstruction_total = Data['Number of reconstructions'].sum()
print('We have %s reconstructions on %s' % (reconstruction_total, Root))
print('This is about %s reconstructions per scan (%s scans in %s folders)' % (round(reconstruction_total / len(Data)),
                                                                              len(Data),
                                                                              len(Data.Sample.unique())))

In [None]:
# Store the shape of each reconstruction stack
Data['Size'] = [stack.shape for stack in Reconstructions]

In [None]:
# The three cardinal directions
# The position in this list is used as the array axis along which the maximum intensity projection is taken
directions = ['Axial',
              'Frontal',
              'Median']

In [None]:
# Read (or, if not yet on disk, calculate and save) the MIPs along all directions and put them into the dataframe
for direction in directions:
    Data['MIP_' + direction] = ''
for c, row in tqdm(Data.iterrows(), desc='Working on MIPs', total=len(Data)):
    scanlabel = '%s/%s' % (row['Sample'], row['Scan'])
    for axis, direction in tqdm(enumerate(directions),
                                desc=scanlabel,
                                leave=False,
                                total=len(directions)):
        # The MIPs are saved one level above the reconstruction folder
        mipfile = os.path.join(os.path.dirname(row['Folder']),
                               '%s.%s.MIP.%s.png' % (row['Sample'], row['Scan'], direction))
        if not os.path.exists(mipfile):
            imageio.imwrite(mipfile, Reconstructions[c].max(axis=axis).compute().astype('uint8'))
        Data.at[c, 'MIP_' + direction] = dask_image.imread.imread(mipfile).squeeze()

In [None]:
# Show and save an overview image with the three directional MIPs of each scan
# (scans for which the overview image already exists on disk are skipped)
for c, row in tqdm(Data.iterrows(),
                   desc='Saving overview of MIP images',
                   total=len(Data)):
    outfilepath = os.path.join(os.path.dirname(row['Folder']),
                               '%s.%s.MIPs.png' % (row['Sample'], row['Scan']))
    if not os.path.exists(outfilepath):
        for d, direction in tqdm(enumerate(directions),
                                 desc='%s/%s' % (row['Sample'], row['Scan']),
                                 leave=False,
                                 total=len(directions)):
            plt.subplot(1, 3, d + 1)
            plt.imshow(row['MIP_' + direction])
            plt.gca().add_artist(ScaleBar(row['Voxelsize'],
                                          'um'))
            plt.axis('off')
            plt.title('%s\n%s MIP' % (os.path.join(row['Sample'], row['Scan']), direction))
        # Save once, after all three panels are drawn, instead of re-saving the figure after every panel
        plt.savefig(outfilepath,
                    transparent=True,
                    bbox_inches='tight')
        plt.show()

For further checking the data, we look at the gray values and gray value histograms of the reconstructions.
This helps us to find scans that have not been reconstructed well and might either need to be repeated or simply re-reconstructed.

In [None]:
def overeexposecheck(item, threshold=222, howmanypercent=0.01, whichone='Axial', verbose=False):
    '''
    Check if more than a certain percentage of the pixels of a MIP are brighter than a threshold.

    The MIP is taken from the global `Data` dataframe, from column 'MIP_<whichone>' at row `item`.

    Parameters
    ----------
    item : index label of the scan (row) in `Data`
    threshold : gray value a pixel has to exceed to count as bright
    howmanypercent : percentage of all MIP pixels that have to be bright for the MIP to count as overexposed
    whichone : which MIP to check, one of the directions ('Axial', 'Frontal' or 'Median')
    verbose : if True, overexposed MIPs are shown, with their bright pixels overlaid in color

    Returns
    -------
    True if more than `howmanypercent` % of the pixels are brighter than `threshold`, False otherwise
    '''
    # Bring the MIP into memory once, instead of (re-)computing it for every use below
    mip = numpy.asarray(Data['MIP_%s' % whichone][item])
    bright = mip > threshold
    nbright = bright.sum()
    if nbright <= mip.size * howmanypercent / 100:
        return False
    if verbose:
        plt.imshow(mip)
        # Mask all pixels which are *not* bright, so that only the bright ones are overlaid in color
        plt.imshow(numpy.ma.masked_less(bright, 1),
                   cmap='viridis_r',
                   alpha=0.5)
        plt.title('%s/%s\n%s px of %s Mpixels (>%s%%) are brighter '
                  'than %s' % (Data['Sample'][item],
                               Data['Scan'][item],
                               nbright,
                               round(1e-6 * mip.size, 2),
                               howmanypercent,
                               threshold))
        plt.axis('off')
        plt.gca().add_artist(ScaleBar(Data['Voxelsize'][item],
                                      'um'))
        plt.show()
    return True

In [None]:
# Flag (and show) the scans in which 'too much' of the frontal MIP is overexposed
# TODO: How much is too much?
Data['OverExposed'] = [overeexposecheck(index, whichone='Frontal', verbose=True)
                       for index in Data.index]

In [None]:
# Calculate the gray value histograms of all the reconstructions
# Caveat: dask.array.histogram returns the histogram AND the bin edges, as a tuple [h, b].
# We only keep (and compute) h, so that we can easily plot the histograms below.
subsample = 1
Data['Histogram'] = [dask.array.histogram(rec[::subsample, ::subsample, ::subsample],
                                          bins=2**8,
                                          range=[0, 2**8])[0].compute()
                     for rec in Reconstructions]

In [None]:
# Plot (and save) the histogram of each scan, on a linear and on a logarithmic scale
# (scans for which the histogram image already exists on disk are skipped)
for c, row in Data.iterrows():
    if subsample > 1:
        histfilename = '%s.%s.Histogram.%02dx_subsampled.png' % (row['Sample'], row['Scan'], subsample)
        subsamplenote = ' (%sx subsampled)' % subsample
    else:
        histfilename = '%s.%s.Histogram.png' % (row['Sample'], row['Scan'])
        subsamplenote = ''
    outfilepath = os.path.join(os.path.dirname(row['Folder']), histfilename)
    if os.path.exists(outfilepath):
        continue
    panels = [(plt.plot, 'Linear', 0),
              (plt.semilogy, 'Logarithmic', 10**0)]
    for panel, (plotter, scaletitle, lowest) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel)
        plotter(row.Histogram)
        plt.title(scaletitle)
        plt.xlim([0, 2**8])
        plt.ylim(bottom=lowest)
        seaborn.despine()
    plt.suptitle('Histogram of all %s reconstructions%s of %s' % (row['Number of reconstructions'],
                                                                   subsamplenote,
                                                                   row['Sample']))
    plt.savefig(outfilepath,
                transparent=True,
                bbox_inches='tight')
    plt.show()

In [None]:
# Plot all histograms together
for c, row in Data.iterrows():
    plt.semilogy(row.Histogram, label=row['Sample'])
# Set the limits only *after* all histograms are plotted: setting a y limit switches off the y-autoscaling,
# so doing it inside the loop froze the upper limit at that of the first histogram and clipped all higher ones
plt.xlim([0, 2**8])
plt.ylim(bottom=10**0)
seaborn.despine()
if subsample > 1:
    plt.title('Histogram of all %s scans (%sx subsampled)' % (len(Data), subsample))
else:
    plt.title('Histogram of all %s scans' % (len(Data)))
plt.legend()
plt.show()

In [None]:
# Final tally of what we looked at
scancount = len(Data)
samplecount = len(Data.Sample.unique())
print('We have previewed %s scans of %s folders with reconstructions in %s' % (scancount, samplecount, Root))