In [None]:
# Stretch the classic-notebook container to the full browser width.
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import scanpy as sc
import pandas as pd
import numpy as np

import h5py

from matplotlib import pyplot as plt
from matplotlib import rcParams
import seaborn as sns

In [None]:
import cv2
from skimage.filters.rank import entropy
from skimage.morphology import disk
# 3x3 all-ones structuring element used by `get_mask_outline` for dilation.
kernel = np.ones((3, 3))


def get_mask_outline(mask):
    """Return a boolean image of the pixels added by dilating `mask` once.

    The mask is cast to uint8, dilated one time with the module-level 3x3
    `kernel`, and compared with the (uint8) input. For a binary mask the
    result is the one-pixel ring immediately outside the masked region.

    Args:
        mask (np.ndarray): 2D mask; cast to uint8 before dilation.

    Returns:
        np.ndarray: boolean array, True where the dilated mask exceeds the
        original mask.
    """
    mask_u8 = mask.astype(np.uint8)
    # `iterations` has to be given by keyword: the third positional argument
    # of cv2.dilate is the output array `dst`, not the iteration count.
    dilated = cv2.dilate(mask_u8, kernel, iterations=1)
    # Subtract the uint8 copy (dilated >= mask_u8 everywhere) so the
    # difference can never wrap around.
    return (dilated - mask_u8) > 0


def bias_cell_choice_intensity(h5f, key, channel, normalize='high', set_zero=None):
    """Build a probability vector for biasing the random choice of cells.

    The values stored at ``h5f[f'{key}/{channel}']`` are divided by their
    maximum and turned into weights:

    normalize='high': high values get high weight
    normalize='low': low values get high weight (``1 - value/max``)

    The weights are finally divided by their sum so they can be passed as
    ``p`` to ``np.random.choice``.

    Args:
        h5f (h5py.File): dataset (anything indexable as ``h5f[name][:]``)
        key (str): top-level key (`cells`, `images`)
        channel (str): h5f['{key}/{channel}'] must exist
        normalize (str): 'high' or 'low' (see description)
        set_zero (float or None): if not None, every weight that is exactly 0
            after normalization is replaced by this value, so those cells
            keep a (small) chance of being drawn. If None, zero weights are
            left at 0.

    Returns:
        p (np.ndarray of float64): choice bias vector as in:
        `np.random.choice(..., p=p)`. If every weight is 0 (e.g. an all-zero
        channel with normalize='high' and no `set_zero`) a uniform vector is
        returned instead of NaNs.

    Raises:
        ValueError: if `normalize` is not 'high' or 'low', or if the dataset
            is empty.
    """
    if normalize not in ('high', 'low'):
        raise ValueError(f"normalize must be 'high' or 'low', got {normalize!r}")

    # float64 keeps integer datasets from truncating `set_zero` and gives the
    # final normalization full precision.
    vals = np.asarray(h5f[f'{key}/{channel}'][:], dtype=np.float64)
    if vals.size == 0:
        raise ValueError(f"'{key}/{channel}' is empty; cannot build a choice bias")

    peak = vals.max()
    # An all-zero channel would otherwise divide by zero and produce NaNs.
    scaled = vals / peak if peak != 0 else np.zeros_like(vals)
    weights = scaled if normalize == 'high' else 1.0 - scaled

    if set_zero is not None:
        weights[weights == 0] = set_zero

    total = weights.sum()
    if total == 0:
        # No cell carries any weight: every cell is equally (un)likely.
        return np.full(weights.shape, 1.0 / weights.size)
    return weights / total
    

from skimage.filters import threshold_otsu
def show_cells(h5f, key, channel, cell_bias=None,
               seed=None, ids=None, return_ids=False,
               intensity_key=None,
               n=49, dpi=90, force_max=None,
               nuclei=False
              ):
    """Plot a grid of single-cell images for one channel.

    Images are read from ``h5f[f'{key}/{channel}']`` (indexed along the first
    axis) and drawn with ``matshow``, one cell per panel with its own
    colorbar. The grid is up to 7 columns wide and has as many rows as the
    number of cells requires; unused panels are hidden.

    Args:
        h5f (h5py.File): open dataset.
        key (str): top-level group holding the images (e.g. `cells`).
        channel (str): h5f['{key}/{channel}'] must exist.
        cell_bias (array-like or None): passed as ``p`` to
            ``np.random.choice`` when cells are sampled.
        seed (int or None): if given and `ids` is None, seeds NumPy's global
            RNG before sampling. Ignored when `ids` is provided.
        ids (sequence of int or None): indices of the cells to show. If None,
            `n` indices are sampled without replacement.
        return_ids (bool): if True, return the indices that were plotted.
        intensity_key (str or None): if given, per-cell values are read from
            h5f['{intensity_key}/{channel}'], log1p-transformed and centred
            on the mean of the strictly positive values. Each panel is
            annotated with its centred value, black if it is above the Otsu
            threshold of the centred positive values and red otherwise.
        n (int): number of cells to sample when `ids` is None.
        dpi (int): figure resolution.
        force_max (float or None): ``vmax`` for every panel.
        nuclei (bool): if True and ``key == 'cells'``, paint the outline of
            h5f['meta/nuclear_masks'][i] onto each image.

    Returns:
        The plotted indices if `return_ids` is True, otherwise None.
    """
    dataset = f'{key}/{channel}'
    images = h5f[dataset]

    if ids is None:
        if seed is not None:
            np.random.seed(seed)
        ids = np.random.choice(images.shape[0], n, replace=False, p=cell_bias)
    n_cells = len(ids)

    if intensity_key is not None:
        ivals = np.log1p(h5f[f'{intensity_key}/{channel}'][:])
        nonzero = ivals[ivals > 0]
        # Centre on the mean of the non-zero values before thresholding.
        imean = np.mean(nonzero)
        thr = threshold_otsu(nonzero - imean)

    # Size the grid from the number of cells actually shown rather than a
    # fixed 7x7, so no cell is silently dropped and no empty panel is drawn.
    ncol = max(1, min(7, n_cells))
    nrow = max(1, -(-n_cells // ncol))  # ceiling division
    fig, axs = plt.subplots(nrow, ncol, figsize=(1.7*ncol, 1.5*nrow), dpi=dpi,
                            gridspec_kw=dict(hspace=0.5, wspace=0.5),
                            squeeze=False)  # always a 2D array of axes
    axes = axs.ravel()

    for i, ax in zip(ids, axes):
        img = images[i, ...]
        if nuclei and (key == 'cells'):
            mask = h5f['meta/nuclear_masks'][i, ...]
            # Draw the outline at the brightest value so it stays visible.
            img[get_mask_outline(mask)] = max(img.max(), 1)
        m = ax.matshow(img, vmax=force_max)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.colorbar(m, ax=ax)

        if intensity_key is not None:
            intens = ivals[i] - imean
            c = 'k' if intens > thr else 'r'
            ax.annotate(f'{intens:3.3f}', (0, 1), xycoords='axes fraction',
                        color=c,
                        va='bottom')

    # Hide panels left over when the cell count does not fill the grid.
    for ax in axes[n_cells:]:
        ax.axis('off')

    # The threshold attribute may be absent; only mention it when stored.
    threshold = images.attrs.get('threshold')
    if threshold is None:
        title = f'{key}/{channel}'
    else:
        title = f'{key}/{channel}: >{threshold:3.3f}'
    fig.suptitle(title, y=0.9, va='bottom')

    if return_ids:
        return ids

In [None]:
!ls -d /storage/codex/preprocessed_data/*

In [None]:
# Show a random, intensity-biased sample of cells for one marker, then the
# same cells in the DAPI channel.
channel = 'CD20'
intensity_key = 'cell_intensity'
sample_id = '210115_Breast_Cassette15_reg1'
pth = f'/storage/codex/preprocessed_data/{sample_id}/{sample_id}.hdf5'

# Keyword arguments shared by both show_cells calls.
kws = {
    'force_max': None,
    'nuclei': False,
    'intensity_key': intensity_key
}

print(pth)
with h5py.File(pth, "r") as h5f:
    # Draw a seed, print it, and actually use it for the sampling call so the
    # printed value is enough to reproduce the selection.
    seed = np.random.choice(999)
    print(seed)
    bias = bias_cell_choice_intensity(h5f, intensity_key, channel, normalize='high', set_zero=1e-4)
    ids = show_cells(h5f, 'cells', channel, seed=seed, return_ids=True, cell_bias=bias, **kws)
    # Same cells in DAPI: `ids` is fixed, so no seed is needed, and the call
    # returns None without return_ids, so `ids` is not rebound.
    show_cells(h5f, 'cells', 'DAPI', ids=ids, **kws)
                    

In [None]:
# Per-channel summary of the stored cell intensities: mean, std, their ratio
# (std/mean), the threshold stored on the cell images, and the percentage of
# cells with exactly zero intensity.
with h5py.File(pth, "r") as h5f:
    print(f'{"channel":<10} mean\tstd\tstd/mean\tthreshold\t%zero')
    for k in sorted(h5f['cells'].keys()):
        intensities = h5f[f'cell_intensity/{k}']
        m = intensities.attrs['mean']
        s = intensities.attrs['std']
        nz = np.mean(intensities[:] == 0)
        # Guard against a zero mean so an empty channel does not abort the table.
        disp = s / m if m else float('nan')
        t = h5f[f'cells/{k}'].attrs['threshold']
        print(f'{k:<10} {m:5.3f}\t{s:3.3f}\t{disp:3.3f}\t{t}\t{100*nz:3.2f}%')