In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_io as tfio

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. Memory growth is set on *every* visible GPU: setting it on only one
# makes TF refuse to initialise on multi-GPU machines. The loop is a no-op on a
# CPU-only machine (where indexing [0] used to raise IndexError), and the
# RuntimeError TF raises if the GPUs were already initialised (e.g. this cell
# re-run in a live kernel) is reported instead of aborting the cell.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    print(err)

from micron2.clustering import Encoder, train_moco
from micron2.data import stream_dataset, stream_dataset_parallel

import tqdm.auto as tqdm
import h5py

# tf.data AUTOTUNE sentinel, for passing as `num_parallel_calls` / buffer sizes.
# NOTE(review): not referenced in the cells shown — the pipeline below
# hardcodes num_parallel_calls=12 and prefetch(256).
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
# Marker channels to train on, kept in ASCII-sorted order. Each name is looked
# up as the `cells/<name>` dataset in the merged HDF5 file when reading the
# per-channel normalisation statistics.
use_channels = '''
    C1q      CD103  CD11c   CD134   CD138  CD20   CD31
    CD3e     CD4    CD40    CD40LG  CD45   CD45RA CD45RO
    CD64     CD68   CD69    CD8     CD80   CD89   CXCL13
    CXCR5    DAPI   FOXP3   GZMB    HLA-DR IL7R   IgA
    IgG      IgM    Ki-67   LAG3    OX40L  PD-1   PD-L1  PDGFRb
    PNaD     PanCytoK TIM3  aSMA
'''.split()

In [None]:
input_size = 64
crop_size = 48
batch_size = 64
max_queue_len = 8192

# Read the channel names stored in the merged file, plus the per-channel
# normalisation statistics: the `mean` and `max` attributes of each
# `cells/<channel>` dataset, gathered in `use_channels` order.
# To train on every channel in the file instead, set
# `use_channels = all_channels` right after `all_channels` is read.
stats_path = '/home/ingn/tmp/micron2-data/pembroRT-set1-set2/merged_v2.hdf5'
with h5py.File(stats_path, 'r') as h5:
    all_channels = [raw.decode('UTF-8') for raw in h5['meta/channel_names'][:]]
    print(all_channels)

    # One float32 vector per attribute, read in the order 'mean' then 'max'.
    means, maxes = (
        tf.constant([h5[f'cells/{c}'].attrs[key] for c in use_channels], dtype=tf.float32)
        for key in ('mean', 'max')
    )
    print(means)
    print(maxes)

print(use_channels)


In [None]:
def process(x):
    """Normalise and randomly augment a single cell image.

    Applied with `Dataset.map` *before* `.batch`, so `x` is one image of
    shape [h, w, c], not a batch (the rank-3 crop size below requires this).
    Its channels are presumably in `use_channels` order so they line up with
    `maxes` — TODO confirm against `stream_dataset`.

    Steps:
      * cast to float32;
      * divide each channel by its stored max;
      * apply a random brightness shift (max delta 0.2);
      * take a random `crop_size` x `crop_size` spatial crop, keeping every channel;
      * apply random left/right and up/down flips.

    Returns:
        float32 tensor of shape [crop_size, crop_size, c].
    """
    x = tf.cast(x, tf.float32)
    # divide_no_nan yields 0 instead of NaN/inf for a channel whose stored max
    # is 0, so a single empty channel cannot poison a whole batch.
    x = tf.math.divide_no_nan(x, maxes)

    x = tf.image.random_brightness(x, 0.2)
    # Crop spatially only. The channel count comes from `maxes` (the normaliser
    # applied above, built from `use_channels`) rather than `all_channels`,
    # which need not have the same length.
    x = tf.image.random_crop(x, [crop_size, crop_size, maxes.shape[-1]])
    x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    return x
    
# Infinitely repeating, shuffled, augmented, batched stream of cell images.
# Stream exactly `use_channels`: `process` divides by `maxes`, which was read
# per channel in `use_channels` order, so the image channel axis must match it.
# Streaming `all_channels` breaks that broadcast, or silently mis-normalises,
# whenever the two lists differ. (Assumes stream_dataset emits channels in the
# order given — TODO confirm.)
# Optional: append .apply(tf.data.experimental.prefetch_to_device("/gpu:0")).
fname = '/home/ingn/tmp/micron2-data/pembroRT-set1-set2/merged_v2.hdf5'
dataset = (stream_dataset(fname, use_channels)
           .repeat(None)              # repeat forever
           .shuffle(1024 * 1)
           .map(process, num_parallel_calls=12)
           .batch(batch_size)
           .prefetch(256)
          )


In [None]:
# Benchmark input-pipeline throughput. `dataset` repeats forever, so the run is
# bounded by `max_bench_batches` to keep Restart & Run All terminating. Set it
# to None to stream until the kernel is interrupted.
# The try wraps the whole loop. A KeyboardInterrupt lands almost always while
# the next batch is being fetched (the `for` header), not in the body, so a try
# around `pass` never caught it. `batch` keeps the last batch fetched.
max_bench_batches = 1000
bench_stream = dataset if max_bench_batches is None else dataset.take(max_bench_batches)
try:
    for batch in tqdm.tqdm(bench_stream, total=max_bench_batches):
        pass
except KeyboardInterrupt:
    pass  # stop benchmarking early; not an error

In [None]:
# Shape of the last batch produced by the benchmark loop above.
last_batch_shape = batch.shape
print(last_batch_shape)