In [None]:
# OPTIONAL: Load the "autoreload" extension so that edits to imported modules take effect without restarting the kernel
%reload_ext autoreload

# OPTIONAL: reload all modules before executing each cell, so changes to code in ../src are picked up automatically
%autoreload 2

import sys
sys.path.append('../src')

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view as sww
from pathlib import Path
import os
import re
import tifffile as tiff
from time import time

from metadata import metadata
from utils import list_subdir_filter as lsd

# Project metadata object; this notebook reads md.folders (e.g. 'images', 'bg_removed' paths)
# and md.markers (whose keys are used, in order, as per-channel plot labels) from it.
md = metadata()

In [None]:
def time_diff(reference_time, decimals=1):
    """Return the seconds elapsed since ``reference_time`` as a rounded string.

    Parameters
    ----------
    reference_time : float
        A timestamp previously obtained from ``time()``.
    decimals : int, optional
        Number of decimals passed to ``round`` (default 1).

    Returns
    -------
    str
        The rounded elapsed time, e.g. ``'12.3'``.
    """
    elapsed = time() - reference_time
    return str(round(elapsed, decimals))


def tiled_median(array_, square_side):
    """Estimate a per-channel background level from tiled medians.

    Each 2-D channel of ``array_`` is cut into non-overlapping
    ``square_side`` x ``square_side`` tiles (incomplete tiles at the right and
    bottom edges are dropped). The median of every tile is computed, and the
    *smallest* tile median is taken as that channel's background value.

    Parameters
    ----------
    array_ : iterable of 2-D numpy arrays
        Iterated channel by channel (e.g. a channel-first ``(C, H, W)`` stack).
    square_side : int
        Tile side in pixels. If it exceeds a channel's height or width it is
        clipped to that channel's smaller dimension, so small images yield a
        single tile instead of raising.

    Returns
    -------
    list of numpy.uint16
        One background value per channel (fractional medians are truncated).

    Raises
    ------
    ValueError
        If ``square_side`` is smaller than 1.
    """
    if square_side < 1:
        raise ValueError(f'square_side must be a positive integer, got {square_side}')

    background_values = []
    for channel in array_:
        side = min(square_side, *channel.shape)
        # Sliding windows stepped by `side` => non-overlapping tiles of shape (rows, cols, side, side).
        tiles = sww(channel, (side, side))[::side, ::side]
        tile_medians = np.median(tiles.reshape(-1, side, side), axis=(1, 2))
        background_values.append(tile_medians.min().astype('uint16'))

    return background_values

def extract_remove_bg(array_, square_side):
    """Subtract a tiled-median background from each channel and rescale to uint8.

    The background of every channel is estimated with ``tiled_median`` (the
    smallest median over non-overlapping ``square_side`` tiles). It is
    subtracted with clipping at 0, and the 16-bit result is mapped to 8 bits
    by keeping its high byte (``>> 8``).

    ``array_`` is not modified: results are written into a freshly allocated
    uint8 array, so views of the input held by the caller keep the raw values.

    Parameters
    ----------
    array_ : numpy.ndarray
        Channel-first image stack; assumed uint16 ``(C, H, W)`` — TODO confirm
        (``>> 8`` requires an integer dtype).
    square_side : int
        Tile side in pixels, passed to ``tiled_median``.

    Returns
    -------
    numpy.ndarray
        uint8 array with the same shape as ``array_``.
    """
    background_values = tiled_median(array_, square_side)
    cleaned = np.empty(array_.shape, dtype=np.uint8)
    for i, (channel, bg) in enumerate(zip(array_, background_values)):
        # Subtract in int32 so pixels below the background clip to 0 instead of wrapping around.
        shifted = np.maximum(channel.astype(np.int32) - bg, 0).astype(np.uint16) >> 8
        cleaned[i] = shifted.astype(np.uint8)
    return cleaned


def _sample_name(image_file):
    """Extract the sample ID (``A40.<digits>`` or ``LCCH-<digits/T/->``) from a file path.

    The greedy ``^.*`` prefix makes the last occurrence in the path win. When
    nothing matches, the file stem is returned instead of the full path, so
    output file names built from it stay inside their target folder.
    """
    match = re.search(r'^.*(A40\.[0-9]+|LCCH-[0-9T-]+)', str(image_file))
    return match.group(1) if match else Path(image_file).stem


def _save_clipped_overview(img, sample, redux_factor, png_path):
    """Save a 4x3 grid of downsampled channels, each displayed clipped at its 95th percentile.

    Panels are labelled with the keys of ``md.markers`` in order; the sample ID
    is written at the bottom of the second-to-last panel. Panels without a
    channel are hidden.

    Raises
    ------
    ValueError
        If ``img`` has more channels than the grid has panels.
    """
    n_rows, n_cols = 4, 3
    img_redux = img[:, ::redux_factor, ::redux_factor]
    if len(img_redux) > n_rows * n_cols:
        raise ValueError(f'{len(img_redux)} channels do not fit in a {n_rows}x{n_cols} grid')

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 24))
    ax = axes.ravel()
    marker_names = list(md.markers.keys())  # one label per channel, matched by position
    for i, im in enumerate(img_redux):
        ax[i].imshow(im, cmap='gray', vmin=0, vmax=np.quantile(im, 0.95))
        ax[i].axis('off')
        ax[i].text(
            0.5, 0.98, marker_names[i],
            horizontalalignment='center', verticalalignment='top',
            transform=ax[i].transAxes, fontsize=22, c='white'
        )
        if i == len(ax) - 2:
            ax[i].text(
                0.5, 0.02, sample,
                horizontalalignment='center', verticalalignment='bottom',
                transform=ax[i].transAxes, fontsize=25, c='lightgrey'
            )
    for unused in ax[len(img_redux):]:
        unused.axis('off')
    fig.tight_layout()
    fig.savefig(png_path)
    plt.close(fig)  # close explicitly: this runs once per file in a loop


def process_raw_image(image_file, square_side=6500, redux_factor=8, figure_dir=None):
    """Remove the background from one raw image, save it, and plot an overview.

    Steps: load ``image_file`` with tifffile; subtract each channel's
    tiled-median background and convert to uint8 (``extract_remove_bg``);
    save the result as ``NEW_<sample>_bg_removed_uint8.npy`` in
    ``md.folders['bg_removed']``; save a downsampled, percentile-clipped
    overview as ``<sample>_clipped.png`` in ``figure_dir``. The duration of
    each step is printed.

    The sample is skipped when both outputs already exist, so the batch loop
    can be re-run to resume after an interruption.

    Parameters
    ----------
    image_file : str or os.PathLike
        Path to the raw image; the sample ID is parsed from it.
    square_side : int, optional
        Tile side in pixels for background estimation (default 6500).
    redux_factor : int, optional
        Spatial downsampling step for the overview figure (default 8).
    figure_dir : str, optional
        Folder for the overview PNG; defaults to the notebook-level ``out_dir``.

    Returns
    -------
    None
    """
    if figure_dir is None:
        figure_dir = out_dir  # notebook-level global from the configuration cell
    sample = _sample_name(image_file)
    # Build output paths once so the skip check tests exactly what is written below.
    npy_path = os.path.join(md.folders['bg_removed'], f'NEW_{sample}_bg_removed_uint8.npy')
    png_path = os.path.join(figure_dir, f'{sample}_clipped.png')

    print(f'{sample}--------------------------------')
    if os.path.exists(npy_path) and os.path.exists(png_path):
        print(f'Sample {sample} exists. Skipping...')
        return
    t_0 = time()

    print(f'{"Loading":.>25}', end=' ')
    img = tiff.imread(image_file)
    print(f'{time_diff(t_0):>6}s')

    t_ = time()
    print(f'{"Extract/remove bg":.>25}', end=' ')
    img = extract_remove_bg(img, square_side)
    print(f'{time_diff(t_):>6}s')

    t_ = time()
    print(f'{"Saving clean data (npy)":.>25}', end=' ')
    np.save(npy_path, img)
    print(f'{time_diff(t_):>6}s')

    t_ = time()
    print(f'{"Plotting clipped":.>25}', end=' ')
    _save_clipped_overview(img, sample, redux_factor, png_path)
    print(f'{time_diff(t_):>6}s')

    print(f'{"Total duration:":_>26} {time_diff(t_0):>6}s\n')
    

In [None]:
# Raw string: '\.tif' in a normal literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning since Python 3.12); the value passed is unchanged.
# NOTE(review): the third argument is presumably a regex filter on file names — confirm in utils.
all_imgs = lsd(md.folders['images'], True, r'\.tif')
# NOTE(review): absolute cluster path; consider moving it into the project metadata/config.
base_dir = '/projects/ag-bozek/lunaphore/'
# Output folder for the figures written by process_raw_image; created if missing.
out_dir = os.path.join(base_dir, 'reports/figures/bg_correction_new/')
Path(out_dir).mkdir(exist_ok=True, parents=True)

In [None]:
# NOTE(review): the first three files are skipped, presumably because an earlier run
# already processed them — confirm, and set to 0 for a full, reproducible run.
START_INDEX = 3
for image_file in all_imgs[START_INDEX:]:
    process_raw_image(image_file)
print('_'*40, '\n\n--------- All files processed ----------')

  47.0s
.........Plotting clipped 