In [None]:
# OPTIONAL: Load the "autoreload" extension so that code changes are picked up without restarting the kernel
%reload_ext autoreload

# OPTIONAL: always reload modules so that as you change code in src, it gets loaded
%autoreload 2

import sys
sys.path.append('../src')

import cv2
from datetime import datetime
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view as sww
from pathlib import Path
import os
import re
import tifffile as tiff
from time import time
from tqdm import tqdm

from metadata import metadata
from utils import list_subdir_filter as lsd, get_id, time_diff, unique

# Module-level project metadata; the functions below read md.folders and md.markers from it.
md = metadata()

In [None]:
def how_old(path):
    """Return a coarse, human-readable age of ``path`` based on its mtime.

    Whole elapsed hours are bucketed as '<1h' (also used for mtimes in the
    future), '<N>h' below one day, '<N>d' below 365 days (days counted as
    24-hour blocks), and '>1y' beyond that.
    """
    elapsed_hours = int((time() - os.path.getmtime(path)) // 3600)
    if elapsed_hours >= 24:
        elapsed_days = int(elapsed_hours // 24)
        return f'{elapsed_days}d' if elapsed_days < 365 else '>1y'
    return f'{elapsed_hours}h' if elapsed_hours >= 1 else '<1h'
    
        
def export_img_signal_histogram(signal_minus_AF, sample_id, marker_name):
    """Save a log-scaled intensity histogram of an AF-corrected image.

    Parameters
    ----------
    signal_minus_AF : array-like
        Signal image after autofluorescence subtraction; may contain negative
        values where the AF exceeded the signal.
    sample_id : str
        Image/sample identifier, used in the title and output file name.
    marker_name : str
        Marker name, used in the title and output file name.

    Writes ``intensity_dist_AF_correction_{sample_id}_{marker_name}.png`` into
    ``md.folders['bg_removed']``.
    """
    # Every 4th pixel is enough to show the shape of the distribution; ravel
    # avoids the full-size copy that flatten() makes before subsampling.
    values = np.ravel(signal_minus_AF)[::4]

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        counts, _, _ = ax.hist(values, log=True, bins=150, color='green')
        ax.set_title(f'{sample_id} ({marker_name}) post-AF correction')
        # Zero marker: mass to the left is where the AF exceeded the signal.
        ax.vlines(x=0, ymin=0.1, ymax=counts.max() * 1.02, color='gold')
        # +/-65535 spans a difference of two uint16 images (presumably the TIFF dtype).
        ax.set_xlim(-65535, 65535)
        fig.savefig(os.path.join(md.folders['bg_removed'], f'intensity_dist_AF_correction_{sample_id}_{marker_name}.png'))
    finally:
        # Release the figure even if plotting or saving fails, so batch runs don't accumulate figures.
        plt.close(fig)



def export_img_signal_AF_overlap(rgb_array, sample_id, marker_name):
    """Save an RGB overlay image of the cleaned signal against the AF channel.

    Parameters
    ----------
    rgb_array : array-like
        Three-channel image, either channel-first (e.g. the list of planes
        returned by ``combine_clean_AF_channels``) or channel-last.
    sample_id : str
        Image/sample identifier, used in the title and output file name.
    marker_name : str
        Marker name, used in the title and output file name.

    Writes ``AF_correction_{sample_id}_{marker_name}.png`` (transparent
    background) into ``md.folders['bg_removed']``.
    """
    rgb = np.asarray(rgb_array)
    # Heuristic: a first axis shorter than the last is taken to be the channel
    # axis and moved to the end, the layout imshow expects for colour images.
    if rgb.shape[0] < rgb.shape[-1]:
        rgb = np.moveaxis(rgb, 0, -1)

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(rgb)
        ax.axis('off')
        ax.set_title(f'{sample_id}: AF/{marker_name}')
        fig.savefig(os.path.join(md.folders['bg_removed'], f'AF_correction_{sample_id}_{marker_name}.png'), transparent=True)
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)



def combine_clean_AF_channels(signal_minus_AF, AF, clipping_threshold, scale_factor=8):
    """Build a downsampled, channel-first overlay of cleaned signal vs. AF.

    Both images are subsampled by keeping every ``scale_factor``-th row and
    column, clipped to ``[0, clipping_threshold]`` and scaled to ``[0, 1]``.

    Parameters
    ----------
    signal_minus_AF : numpy.ndarray
        2-D AF-corrected signal image.
    AF : numpy.ndarray
        2-D autofluorescence image.
    clipping_threshold : number
        Intensity mapped to full brightness; must be positive.
    scale_factor : int, default 8
        Subsampling step along both axes.

    Returns
    -------
    list of numpy.ndarray
        ``[R, G, B]`` planes: the cleaned signal in red and blue (rendered as
        magenta) and the AF in green. R and B are equal but distinct arrays.

    Raises
    ------
    ValueError
        If ``clipping_threshold`` is not positive. Zero would divide by zero
        (NaNs), and a negative value would invert the scaling.
    """
    if clipping_threshold <= 0:
        raise ValueError(f'clipping_threshold must be positive, got {clipping_threshold!r}')

    def normalize_clipped(array):
        return np.clip(array, 0, clipping_threshold) / clipping_threshold

    step = scale_factor
    signal_plane = normalize_clipped(signal_minus_AF[::step, ::step])
    af_plane = normalize_clipped(AF[::step, ::step])

    # The signal plane is needed for both R and B: compute it once and copy,
    # rather than clipping and normalising the same data twice.
    return [signal_plane, af_plane, signal_plane.copy()]



def subtract_AF(signal, af, img_id, marker_name, clipping_threshold=20000):
    """Subtract an autofluorescence (AF) image from a marker image and export the result.

    Side effects, all written under ``md.folders['bg_removed']``:
      * an intensity histogram of the difference (via ``export_img_signal_histogram``);
      * a signal/AF overlay PNG (via ``export_img_signal_AF_overlap``);
      * ``clean_{marker_name}_{img_id}.npy``: the difference clipped to
        ``[0, 65535]`` and reduced to uint8 by keeping the high byte.

    Parameters
    ----------
    signal, af : numpy.ndarray
        Marker and AF images of the same shape (presumably uint16 TIFF pages; confirm).
    img_id, marker_name : str
        Used in plot titles and output file names.
    clipping_threshold : number, default 20000
        Full-brightness level for the overlay image only; does not affect the saved array.
    """
    # Widen before subtracting so pixels where AF > signal become negative
    # instead of wrapping around in an unsigned dtype.
    clean = signal.astype('int32') - af

    export_img_signal_histogram(clean, img_id, marker_name)

    combined_clean_AF = combine_clean_AF_channels(
        clean,
        af,
        clipping_threshold=clipping_threshold,
        scale_factor=8,
    )
    export_img_signal_AF_overlap(combined_clean_AF, img_id, marker_name)

    # Clip to the uint16 range. The upper bound is explicit because np.Inf was
    # removed in NumPy 2.0, and a finite integer bound keeps the data integer
    # and cannot overflow the uint16 cast. Then keep the high byte as uint8.
    clean_8bit = (np.clip(clean, 0, 65535).astype('uint16') >> 8).astype('uint8')
    np.save(
        os.path.join(md.folders['bg_removed'], f'clean_{marker_name}_{img_id}.npy'),
        clean_8bit,
    )



def do_the_cleaning(img_file):
    """Remove autofluorescence from the marker channels of one multi-page TIFF.

    NOTE(review): assumes page ``i`` of ``img_file`` holds ``md.markers[i]``; confirm.

    Markers from index 3 onward are cleaned. Odd-indexed markers use page 1
    as the AF reference and even-indexed markers use page 2. Markers 0-2 are
    not cleaned themselves.

    If every expected ``clean_{marker}_{img_id}.npy`` already exists in
    ``md.folders['bg_removed']``, the image is skipped and None is returned.
    Otherwise all of them are (re)computed.
    """
    img_id = get_id(img_file)
    print(img_id)

    markers = list(md.markers)
    # Only markers[3:] are produced below. Derive the expected outputs from the
    # panel instead of relying on a fixed file count.
    out_dir = md.folders['bg_removed']
    if all(
        os.path.exists(os.path.join(out_dir, f'clean_{marker}_{img_id}.npy'))
        for marker in markers[3:]
    ):
        print(f'{img_id} already processed, skipping')
        return None

    # (AF reference page, first marker index): odd-indexed markers are
    # corrected with page 1, even-indexed markers with page 2.
    for af_page, first_idx in ((1, 3), (2, 4)):
        af = tiff.imread(img_file, key=af_page)
        for i in range(first_idx, len(markers), 2):
            marker = markers[i]
            print(marker)
            signal = tiff.imread(img_file, key=i)
            subtract_AF(signal, af, img_id, marker_name=marker)


def safe_cleaning(img_file, raise_errors=True):
    """Run ``do_the_cleaning`` on one image, optionally containing failures.

    Parameters
    ----------
    img_file : path-like
        Multi-page TIFF to clean.
    raise_errors : bool, default True
        If True, exceptions from ``do_the_cleaning`` propagate (useful while
        debugging). If False, the error is printed with the image id and the
        function returns False, so a batch loop can continue with the next image.

    Returns
    -------
    bool
        True if cleaning completed (including the already-processed skip);
        False if it failed and ``raise_errors`` is False.
    """
    # do_the_cleaning prints the image id itself, so it is not printed here.
    if raise_errors:
        do_the_cleaning(img_file)
        return True

    img_id = get_id(img_file)
    try:
        do_the_cleaning(img_file)
    except Exception as exc:  # not bare: lets KeyboardInterrupt stop the batch
        print(f'error with {img_id}: {exc!r}')
        return False
    return True

In [None]:
# Collect the .tif files under the images folder (lsd's True flag is passed
# through unchanged; presumably "recurse"; confirm against utils.list_subdir_filter).
all_img_files = lsd(md.folders['images'], True, '.tif')

# Show each image id next to how long ago its file was last modified.
for img_path in all_img_files:
    print(get_id(img_path), how_old(img_path))

In [None]:

# Clean every image. Images whose outputs all exist are skipped inside do_the_cleaning.
for img_file in tqdm(all_img_files):
    safe_cleaning(img_file)