In [None]:
import numpy as np
import pandas as pd
import pytiff
import h5py
import time
import os

from tqdm.auto import tqdm

%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
# List the contents of the preprocessed 201021_BreastFFPE_Final experiment directory.
# NOTE(review): this is a different mount point (/mnt/linux-data) and a different
# experiment than the /storage/.../210113_Breast_Cassette11_reg1 paths read by the
# following cells — confirm this listing is still relevant.
!ls /mnt/linux-data/codex/preprocessed_data/201021_BreastFFPE_Final/

In [None]:
# Load the per-cell centroid table for region 1 of cassette 11.
# The first CSV column becomes the row index and the first row the header.
cells = pd.read_csv(
    '/storage/codex/preprocessed_data/210113_Breast_Cassette11_reg1/210113_Breast_Cassette11_reg1_2_centroids.csv',
    header=0,
    index_col=0,
)
cells.head(n=10)

In [None]:
imagefs = !ls /storage/codex/preprocessed_data/210113_Breast_Cassette11_reg1/images/*.tif
dapi_images = [f for f in imagefs if 'DAPI' in f]
non_dapi_images = [f for f in imagefs if 'DAPI' not in f]
non_dapi_images = [f for f in non_dapi_images if 'Blank' not in f]
non_dapi_images = [f for f in non_dapi_images if 'Empty' not in f]
for f in non_dapi_images:
    print(os.path.basename(f))

In [None]:
EXPT_NAME = '210113_Breast_Cassette11_reg1'

# Derive one channel name per marker image: drop the '.tif' suffix from the file
# name and keep the second-to-last '_'-separated token. DAPI always comes first.
channel_names = ["DAPI"]
for image_file in non_dapi_images:
    stem = os.path.basename(image_file).replace('.tif', '')
    channel_names.append(stem.split('_')[-2])
print(channel_names)

In [None]:
# Ordered list of image files: the first DAPI image, followed by every marker channel.
image_paths = [dapi_images[0], *non_dapi_images]
print(len(image_paths))
# image_handles = [pytiff.Tiff(dapi_images[0])] + [pytiff.Tiff(f) for f in non_dapi_images]

In [None]:
# Re-key the same files by channel name.
# NOTE: this rebinds `image_paths` from a list to a {channel name: path} dict.
image_paths = dict(zip(channel_names, [dapi_images[0], *non_dapi_images]))

In [None]:
# Histogram of the `Size` column plus its 1st/10th/90th/99th percentiles,
# useful for picking a minimum-size cutoff.
_ = plt.hist(cells['Size'], bins=100)
np.quantile(cells['Size'], q=[0.01, 0.1, 0.9, 0.99])

In [None]:
size = 64

def pull_nuclei(coords, image_paths, out_file='dataset.hdf5', size=64, min_area=100, channel_names=None):
    """Crop a square window around every cell and write the crops to an HDF5 file.

    One dataset named ``cells/<channel>`` is created per channel. Row ``i`` of
    every dataset holds the crop around the ``i``-th cell that survives the
    filtering described below, so the datasets are aligned across channels.

    Parameters
    ----------
    coords : pandas.DataFrame
        One row per cell with columns ``X`` (indexes the image width), ``Y``
        (indexes the image height) and ``Size`` (compared against ``min_area``).
    image_paths : sequence of paths, or dict ``{channel name: path}``
        One TIFF per channel. All crops are taken from page 0 of each file and
        the first file defines the image extent used for edge filtering.
        When a dict is given and ``channel_names`` is None, the dict keys (in
        insertion order) are used as channel names; when ``channel_names`` is
        given, paths are looked up by those names.
    out_file : str
        HDF5 file to create (opened in "w" mode, i.e. overwritten).
    size : int
        Edge length of the crop in pixels. The window is ``2 * int(size / 2)``
        wide, so ``size`` should be even to match the ``(size, size)`` datasets.
    min_area : number or None
        Cells with ``Size <= min_area`` are dropped; ``None`` disables the filter.
    channel_names : list of str, optional
        Dataset names; defaults to ``ch00``, ``ch01``, ... for a path sequence.

    Returns
    -------
    None
        Results are written to ``out_file``.
    """
    # Accept the {channel: path} mapping as well as a plain list of paths.
    if isinstance(image_paths, dict):
        if channel_names is None:
            channel_names = list(image_paths)
        image_paths = [image_paths[c] for c in channel_names]
    else:
        image_paths = list(image_paths)
        if channel_names is None:
            channel_names = [f'ch{i:02d}' for i in range(len(image_paths))]
    assert len(channel_names) == len(image_paths)

    sizeh = int(size/2)
    with pytiff.Tiff(image_paths[0]) as first_image:
        height, width = first_image.shape
    print(height, width)
    maxh = height - sizeh
    maxw = width - sizeh

    # remove coords too near the edges:
    # remember, x = "width" = size[1]; y = "height" = size[0]
    coords = coords.query("X > @sizeh & X < @maxw & Y > @sizeh & Y < @maxh")
    if min_area is not None:
        coords = coords.query("Size > @min_area")
    n_cells = coords.shape[0]

    # Slice bounds must be integers; the cast is a no-op for integer centroids
    # and truncates if the table stores them as floats.
    xs = coords['X'].to_numpy().astype(int)
    ys = coords['Y'].to_numpy().astype(int)

    with h5py.File(out_file, "w") as h5f:
        # Size the datasets from the *filtered* table so no trailing all-zero
        # rows are stored for cells that were dropped above.
        # dtype 'i' is kept from the original layout even though the stored
        # values are 8-bit, so existing readers see the same dataset dtype.
        datasets = [
            h5f.create_dataset(f'cells/{c}', shape=(n_cells, size, size), maxshape=(None, size, size),
                               dtype='i', chunks=(1, size, size), compression='gzip')
            for c in channel_names
        ]

        print(f'Pulling {n_cells} cells')
        for pth, d, c in zip(image_paths, datasets, channel_names):
            print(f'Pulling from channel {c}')
            with pytiff.Tiff(pth) as handle:
                # Read the whole first page once; crops are then plain array slices.
                page = handle.pages[0][:]
                for i, (x, y) in enumerate(tqdm(zip(xs, ys), total=n_cells)):
                    crop = page[y - sizeh:y + sizeh, x - sizeh:x + sizeh]
                    # Rescale to 0-255; dividing by 2**16 assumes 16-bit pixel data — TODO confirm.
                    d[i, ...] = (255 * (crop / 2**16)).astype(np.uint8)
            h5f.flush()

# pull_nuclei(cells, image_paths, out_file='dataset.hdf5', min_area=100, channel_names=channel_names)

In [None]:
from skimage.filters import threshold_otsu
def pull_nuclei_onechannel(coords, image_path, N, size=64, min_area=100):
    """Sample ``N`` random cell crops from one channel, raw and background-suppressed.

    Cells closer than half a window to the image border, and cells with
    ``Size <= min_area``, are discarded first. A background threshold is then
    estimated as half the Otsu threshold over a random sample of up to 5000
    crops, and ``N`` cells are drawn (without replacement) for the output.
    Sampling uses NumPy's global random state, so results vary between calls
    unless the caller seeds it.

    Parameters
    ----------
    coords : pandas.DataFrame
        One row per cell with columns ``X`` (indexes the image width), ``Y``
        (indexes the image height) and ``Size``.
    image_path : str
        TIFF file of the channel; page 0 is used.
    N : int
        Number of cells to return. Must not exceed the number of cells left
        after filtering (``numpy.random.choice`` raises ``ValueError`` otherwise).
    size : int
        Edge length of the crop window in pixels (window is ``2 * int(size / 2)``).
    min_area : number or None
        Minimum ``Size`` (exclusive); ``None`` disables the filter.

    Returns
    -------
    raw_images : list of numpy.ndarray
        Unmodified crops, copied out of the image.
    images : list of numpy.ndarray
        The same crops with pixels below the threshold set to 0, rescaled to
        ``uint8``.

    Raises
    ------
    ValueError
        If no cell is left after filtering.
    """
    sizeh = int(size/2)
    with pytiff.Tiff(image_path) as handle:
        h, w = handle.shape
        print(h, w)
        maxh = h - sizeh
        maxw = w - sizeh

        # remove coords too near the edges:
        # remember, x = "width" = size[1]; y = "height" = size[0]
        coords = coords.query("X > @sizeh & X < @maxw & Y > @sizeh & Y < @maxh")
        if min_area is not None:
            coords = coords.query("Size > @min_area")

        n_cells = coords.shape[0]
        print(f'Found {n_cells} cells')
        if n_cells == 0:
            raise ValueError('No cells left after edge/area filtering')

        # Work on positional arrays: the filtered frame keeps its original index
        # labels, so the positions drawn below must not be used as labels.
        # The int cast makes float centroids usable as slice bounds.
        xs = coords['X'].to_numpy().astype(int)
        ys = coords['Y'].to_numpy().astype(int)

        # Read the page once and reuse it for both sampling passes.
        page = handle.pages[0][:]

        def crop(i):
            # Copy so later processing can never write back into the page.
            return page[ys[i] - sizeh:ys[i] + sizeh, xs[i] - sizeh:xs[i] + sizeh].copy()

        # Sample to build up a background distribution (capped at the number of cells available)
        background_ids = np.random.choice(n_cells, min(5000, n_cells), replace=False)
        background = np.dstack([crop(i) for i in background_ids])

        ids = np.random.choice(n_cells, N, replace=False)
        raw_images = [crop(i) for i in ids]

    print(background.shape)
    thr = threshold_otsu(background.ravel())/2

    images = []
    for raw_image in raw_images:
        # Threshold a copy: the raw crop is returned to the caller untouched.
        cleaned = raw_image.copy()
        cleaned[cleaned < thr] = 0
        # Rescale to 0-255; dividing by 2**16 assumes 16-bit pixel data — TODO confirm.
        images.append(np.ceil(255 * (cleaned / 2**16)).astype(np.uint8))
    return raw_images, images
            


In [None]:
# Preview 25 randomly sampled CD138 crops as two 5x5 grids:
# the thresholded 8-bit tiles first, then the corresponding raw crops.
nrow = 5
ncol = 5
raw_images, images = pull_nuclei_onechannel(cells, image_paths['CD138'], 25, size=64)

for tiles in (images, raw_images):
    fig, axs = plt.subplots(nrow, ncol, figsize=(2*ncol, 2*nrow), dpi=90)
    for ax, tile in zip(axs.ravel(), tiles):
        mappable = ax.matshow(tile)
        plt.colorbar(mappable, ax=ax)
        # Pixel coordinates carry no information here; hide the ticks.
        ax.set_xticks([])
        ax.set_yticks([])