In [None]:
# Widen the notebook container to the full browser width by injecting CSS.
# `display`/`HTML` come from the public IPython.display module; importing them
# from IPython.core.display is deprecated.
# NOTE(review): the `.container` selector targets the classic Notebook UI;
# it probably has no effect in JupyterLab — confirm if that matters.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

We start with a stack of channel images collected across all cycles.

The goal is to find the signal that is consistently common to all channels and remove it from every image.

Ideally, the signal unique to each channel will remain

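The strong assumption is that the channels never perfectly overlap anywhere, i.e. no pixel carries true signal in every channel at once; otherwise that shared signal would be treated as background and removed.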

In [None]:
%load_ext autoreload
%matplotlib inline

from matplotlib import pyplot as plt
from matplotlib import rcParams
import seaborn as sns
import numpy as np
import pandas as pd

import cv2
import pytiff

In [None]:
# Inspect which preprocessed files exist for sample 210127_Breast_Cassette9_reg2.
!ls /storage/codex/preprocessed_data/210127_Breast_Cassette9_reg2

In [None]:
# Load the panel's cycle/channel table (first row is the header, no index column).
# NOTE(review): presumably one row per cycle and one column per detection channel — confirm.
stain_info = pd.read_csv(
    '/home/ingn/tmp/micron2-data/pembroRT/pembroRT_cycle_channels.csv',
    header=0,
    index_col=None,
)
stain_info

In [None]:
# Give every figure an opaque white background (RGBA).
rcParams.update({'figure.facecolor': (1, 1, 1, 1)})

In [None]:
import glob

# Crop window into the full-size images: the first pair slices axis 0 (rows),
# the second pair slices axis 1 (columns) — i.e. [row0, row1, col0, col1].
bbox = [7758, 8525, 15626, 16461]

sample = '210127_Breast_Cassette9_reg2'
nuclei_f = glob.glob(f'/storage/codex/preprocessed_data/{sample}/{sample}_*nuclei.tif')
print(nuclei_f)

# Keep only real stains: drop 'Blank' / 'Empty' placeholder slots.
# Non-string entries (e.g. NaN from an empty CSV cell) are skipped rather than
# raising TypeError on the `in` test.
channels = [
    x for x in stain_info['TTO_550'].values
    if isinstance(x, str) and 'Blank' not in x and 'Empty' not in x
]

image_home = f'/storage/codex/preprocessed_data/{sample}/images'
# Sorted so channel lookup is deterministic when several files could match.
all_images = sorted(glob.glob(f'{image_home}/*.tif'))

def find_source(ch, candidates=None):
    """Return the first image path whose file name contains ``_{ch}_``.

    Parameters
    ----------
    ch : str
        Channel name as it appears (delimited by underscores) in the file names.
    candidates : iterable of str, optional
        Paths to search, in priority order. Defaults to the module-level
        ``all_images`` list.

    Returns
    -------
    str or None
        The first matching path, or ``None`` if no file name matches.
    """
    import os

    if candidates is None:
        candidates = all_images
    token = f'_{ch}_'
    # Match on the file name only: directory components such as the sample
    # folder ('..._Breast_Cassette9_reg2') must not produce false hits.
    for path in candidates:
        if token in os.path.basename(path):
            return path
    return None
        
sources = [find_source(ch) for ch in channels]
print(sources)

# find_source returns None when no file matches; fail here with a readable
# message instead of handing None to pytiff.
missing = [ch for ch, src in zip(channels, sources) if src is None]
if missing:
    raise FileNotFoundError(f'No image found in {image_home} for channel(s): {missing}')

# Crop every channel to bbox: axis 0 -> bbox[0]:bbox[1], axis 1 -> bbox[2]:bbox[3].
images = []
for src in sources:
    with pytiff.Tiff(src, "r") as f:
        page = f.pages[0]
        img = page[bbox[0]:bbox[1], bbox[2]:bbox[3]]
        print(src, page.shape, img.shape)
        # Store an independent copy of the crop.
        # NOTE(review): may be redundant if pytiff slicing already returns a fresh array.
        images.append(img.copy())

In [None]:
from skimage.filters import difference_of_gaussians, gaussian


In [None]:
ncol = 4
nrow = int(np.ceil(len(channels) / ncol))

# Per-channel display ceiling (0.999 quantile of the raw crop), one entry per channel.
ch_saturation = []
fig, axs = plt.subplots(nrow, ncol, figsize=(3.25 * ncol, 3 * nrow), dpi=300, squeeze=False)
axs = axs.ravel()
for i, ch in enumerate(channels):
    ax = axs[i]
    img = images[i].copy()
    # Clip the brightest 0.1% of pixels so a few hot pixels don't dominate the colormap.
    top = np.quantile(img, 0.999)
    img[img > top] = top
    ch_saturation.append(top)
    img[img < 0] = 0
    m = ax.matshow(img)
    ax.set_title(ch)
    fig.colorbar(m, ax=ax, shrink=0.7)
    ax.set_xticks([])
    ax.set_yticks([])

# Hide the unused trailing panels when len(channels) is not a multiple of ncol.
for ax in axs[len(channels):]:
    ax.axis('off')

In [None]:
ncol = 4
# One extra panel for the correction image itself.
nrow = int(np.ceil((len(channels) + 1) / ncol))

# Smooth every channel (Gaussian, sigma=3 px, original intensity range kept), then
# take the pixel-wise minimum across channels as the common signal to subtract.
gauss_images = [gaussian(raw, 3, preserve_range=True) for raw in images]
correction = np.min(gauss_images, axis=0)

fig, axs = plt.subplots(nrow, ncol, figsize=(5.25 * ncol, 5 * nrow), dpi=300, squeeze=False)
axs = axs.ravel()
for i, ch in enumerate(channels + ['correction']):
    ax = axs[i]
    if ch == 'correction':
        # Copy: the clipping below must not modify `correction` itself, which
        # stays in the kernel namespace for later cells.
        img = correction.copy()
        real_min = np.min(img)
        real_max = np.max(img)
        top = np.quantile(img, 0.999)
        img[img > top] = top
    else:
        # astype always returns a new array, so gauss_images is untouched.
        img = gauss_images[i].astype(correction.dtype)
        # Title reports the smoothed channel's range before correction.
        real_min = np.min(img)
        real_max = np.max(img)
        img = img - correction
        top = np.quantile(img, 0.999)
        # Alternative ceiling: the raw-image value from the previous cell,
        # top = ch_saturation[i]
        img[img > top] = top
        img[img < 0] = 0

    m = ax.matshow(img, vmax=top)
    ax.set_title(f'{i}. {ch} ({real_min:3.3f}-{real_max:3.3f})')
    fig.colorbar(m, ax=ax, shrink=0.7)

for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])
# Hide the unused trailing panels (channels + correction may not fill the grid).
for ax in axs[len(channels) + 1:]:
    ax.axis('off')