In [None]:
import concurrent
import concurrent.futures
import posixpath
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tifffile import imread, TiffFile
from tqdm import tqdm

In [None]:
# Draws one random integer below 10000; the value is only displayed, never stored.
# NOTE(review): presumably this is how the seed used in the next cell was picked — confirm,
# then consider deleting this cell.
np.random.randint(10000)

In [None]:
# Seed NumPy's legacy global random generator so that later np.random calls are reproducible.
np.random.seed(4409)

In [None]:
# Root folder that contains one sub-folder per acquisition.
# Raw string: the backslashes are path separators, not escape sequences
# ('\E' and '\C' are invalid escapes and trigger a warning on recent Python versions).
# NOTE(review): absolute drive path — adjust when running on another machine.
basepath = Path(r'Y:\Eva\CARE')

# Sub-folders of `basepath` that are searched for image stacks.
dirlist = [
    '08.12.20_14h',
    '08.12.20_19h',
    '08.12.20_24h',
    '09.12.20_14h',
]

In [None]:
# Channel tokens as they appear in the file names (`..._ch1_...`).
gtChannel = 'ch1'   # files of this channel are globbed below
lowChannel = 'ch2'  # files of this channel are the ones analysed in the following cells
# Stacks whose file name reports fewer than `min_z` planes are skipped.
min_z = 16

p = Path(basepath)
# Derive the glob from `gtChannel` so the pattern and the substitution below cannot drift apart.
pattern = f'Pos*/*_{gtChannel}_*.tif*'

# Every matching file, split into its path components.
image_paths = [f.parts for d in dirlist for f in (p / d).glob(pattern)]

# File names end in `...Nz<number>.tif[f]`; parse that number as the plane count of the stack.
z_heights = [int(parts[-1].split('Nz')[-1].split('.tif')[0]) for parts in image_paths]

# Path of the `lowChannel` file that belongs to each sufficiently deep `gtChannel` stack.
# Only the delimited channel token is replaced, so an unrelated 'ch1' substring elsewhere in
# the file name (e.g. 'batch1') is left untouched.
source_paths = [
    posixpath.join(*parts[:-1], parts[-1].replace(f'_{gtChannel}_', f'_{lowChannel}_'))
    for parts, n_planes in zip(image_paths, z_heights)
    if n_planes >= min_z
]

In [None]:
# Cache file for the per-plane standard deviations; when it exists, the values are loaded
# from it instead of being recomputed.
std_output_path = 'debug/z_standard_deviations.csv'
# Destination of the table that lists the detected plane index for every stack.
argmax_output_path = 'debug/argmax.csv'

In [None]:
def compute_mean_std_prctile(path, max_planes=20, hot_pixel_percentile=99.99):
    """Return the per-plane standard deviation of one image stack.

    Despite the name, only the standard deviation is returned.

    Parameters
    ----------
    path : str or Path
        TIFF stack that is read with ``imread``.
    max_planes : int, optional
        Number of planes kept after the leading overview plane; deeper planes are ignored.
    hot_pixel_percentile : float, optional
        Per plane, pixels brighter than this percentile are excluded from the statistic.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``max_planes``. Stacks with fewer planes are padded with
        NaN so that all results can be stacked into one rectangular array.
    """
    im = imread(str(path))
    # Cut away the overview plane (index 0) and ignore everything beyond `max_planes` planes.
    im = im[1:1 + max_planes]

    # One row per plane; the row length is derived from the data instead of assuming 1024x1024.
    n_planes = im.shape[0]
    im_ = im.reshape(n_planes, int(np.prod(im.shape[1:])))
    # The threshold is required due to 'hot pixels' which occur in about 1 of 10 stacks.
    threshs = np.percentile(im_, hot_pixel_percentile, axis=1, keepdims=True)
    im_ = im_.astype('float')
    im_[im_ > threshs] = np.nan

    plane_stds = np.full(max_planes, np.nan)
    plane_stds[:n_planes] = np.nanstd(im_, axis=1)
    return plane_stds


if Path(std_output_path).is_file():
    stds = np.genfromtxt(std_output_path)
    # genfromtxt collapses a file with a single row to 1-D; restore the (stack, plane) layout.
    if stds.ndim == 1 and stds.size:
        stds = stds[np.newaxis, :]
    # The cache carries no path information, so a stale file would silently pair the
    # standard deviations with the wrong stacks.
    if len(stds) != len(source_paths):
        warnings.warn(
            f'{std_output_path} holds {len(stds)} rows but {len(source_paths)} source paths '
            'were found; delete the file to recompute.'
        )
else:
    # Executor.map keeps the input order and re-raises worker exceptions on iteration.
    # For debugging, replace it with a plain loop over `source_paths`.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        stds = np.array(list(executor.map(compute_mean_std_prctile, source_paths)))

    # Create the output folder first so the results are not lost after a long computation.
    Path(std_output_path).parent.mkdir(parents=True, exist_ok=True)
    np.savetxt(std_output_path, stds)
    

In [None]:
# Based on the shape of the standard deviation per slice, we expect:
# - a steep increase for the first slices
# - a peak
# - a slow decrease
# The end of the steep increase can be spotted by linearly extrapolating the standard deviation
# from two consecutive slices to the next one. At the first slice whose measured value lies
# below that extrapolation, the standard deviation did not increase as much as expected,
# so we are close to the peak.

def find_std_knee(std):
    """Return the index of the first slice whose std falls below its linear extrapolation.

    Parameters
    ----------
    std : array-like
        Standard deviation per slice of one stack (overview plane already removed).

    Returns
    -------
    int
        Index into ``std``. If no slice falls below the extrapolation (the curve never bends),
        the index of the largest non-NaN value is returned instead.

    Raises
    ------
    ValueError
        If ``std`` contains no finite candidate for the fallback (raised by ``np.nanargmax``).
    """
    std = np.asarray(std, dtype=float)
    # Extrapolate from slices i and i+1 to slice i+2: std[i] + 2 * (std[i+1] - std[i]).
    predicted = std[:-2] + 2 * np.diff(std[:-1])
    below = np.flatnonzero(predicted > std[2:])
    if below.size:
        # + 2 because the comparison can only start at the third measurement, so the result
        # is the true index in the stack (without the overview plane).
        return int(below[0]) + 2
    return int(np.nanargmax(std))


stds_argmax = [find_std_knee(std) for std in stds]

plt.hist(stds_argmax)
plt.xlabel('detected slice index (overview plane excluded)')
plt.ylabel('number of stacks');

In [None]:
# One row per analysed stack: its path and the detected slice index.
# Building from a dict keeps `std_argmax` numeric (stacking both lists in one NumPy array
# would turn the integers into strings) and reports a length mismatch clearly.
df = pd.DataFrame({'Path': source_paths, 'std_argmax': stds_argmax})

In [None]:
# Set the column dtypes explicitly; uint8 restricts std_argmax to the range 0..255.
df = df.astype({'Path':'str', 'std_argmax':'uint8'})

In [None]:
# Display the table as the cell output.
df

In [None]:
# Create the output folder if necessary; to_csv does not create missing directories.
Path(argmax_output_path).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(argmax_output_path)

# Control results

In [None]:
from tifffile import TiffFile

def readTiffPage(path, i):
    """Read a single page of a TIFF file.

    Opens ``path`` with ``TiffFile``, selects page ``i`` and returns the result of its
    ``asarray()``. The file is closed again when the function returns.
    """
    with TiffFile(path) as tif:
        page = tif.pages[i]
        return page.asarray()

In [None]:
# Rank the stacks by their detected slice index (ascending) and bring the paths,
# the indices and the std curves into that same order.
order = np.argsort(stds_argmax)
source_paths_ordered = [source_paths[k] for k in order]
stds_max_ordered = [stds_argmax[k] for k in order]
stds_ordered = np.take(stds, order, axis=0)

In [None]:
# Layout of the control figures below: n rows by m columns, i.e. up to n*m stacks per figure.
n = 10
m = 5

In [None]:
# Image planes of the n*m stacks with the highest detected slice index.
f, axes = plt.subplots(n, m, figsize=(15, 30))

last_paths = source_paths_ordered[-n*m:]
last_indices = stds_max_ordered[-n*m:]
for stack_path, plane_idx, panel in zip(last_paths, last_indices, axes.flat):
    # NOTE(review): the std computation skips exactly one overview plane, which would make
    # the matching TIFF page plane_idx+1; plane_idx+2 shows the plane after it — confirm intent.
    panel.imshow(readTiffPage(Path(stack_path), plane_idx+2))
    panel.set(title=str(plane_idx), xticks=[], yticks=[])

plt.tight_layout()

In [None]:
# Image planes of the n*m stacks with the lowest detected slice index.
f, axes = plt.subplots(n, m, figsize=(15, 30))

first_paths = source_paths_ordered[:n*m]
first_indices = stds_max_ordered[:n*m]
for stack_path, plane_idx, panel in zip(first_paths, first_indices, axes.flat):
    # NOTE(review): the std computation skips exactly one overview plane, which would make
    # the matching TIFF page plane_idx+1; plane_idx+2 shows the plane after it — confirm intent.
    panel.imshow(readTiffPage(Path(stack_path), plane_idx+2))
    panel.set(title=str(plane_idx), xticks=[], yticks=[])

plt.tight_layout()

In [None]:
# Std-per-slice curves of the n*m stacks with the highest detected slice index;
# each panel is titled with that index.
f, axes = plt.subplots(n, m, figsize=(15, 30))

last_curves = stds_ordered[-n*m:]
last_indices = stds_max_ordered[-n*m:]
for curve, plane_idx, panel in zip(last_curves, last_indices, axes.flat):
    panel.plot(curve)
    panel.set_title(str(plane_idx))

plt.tight_layout()

In [None]:
# Std-per-slice curves of the n*m stacks with the lowest detected slice index;
# each panel is titled with that index.
f, axes = plt.subplots(n, m, figsize=(15, 30))

first_curves = stds_ordered[:n*m]
first_indices = stds_max_ordered[:n*m]
for curve, plane_idx, panel in zip(first_curves, first_indices, axes.flat):
    panel.plot(curve)
    panel.set_title(str(plane_idx))

plt.tight_layout()