In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_io as tfio

# Ask TensorFlow to grow GPU memory on demand instead of reserving the whole
# device up front. The setting is applied to every visible GPU (TensorFlow
# expects memory growth to be configured consistently across GPUs), and the loop
# is simply skipped on a CPU-only machine, where indexing `physical_devices[0]`
# would raise IndexError.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

from micron2.clustering import MoCo, UpdateQueue
from micron2.data import stream_dataset, stream_dataset_parallel

import h5py
import tqdm.auto as tqdm
from tqdm.keras import TqdmCallback

# Shorthand for tf.data's autotune sentinel.
# NOTE(review): the pipelines in the cells visible here pass a fixed
# num_parallel_calls=6 and never reference AUTO.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
# Patch geometry and MoCo hyper-parameters shared by the cells below.
# crop_size is the spatial size handed to the model; max_queue_len is passed to
# both MoCo and the UpdateQueue callback.
input_size, crop_size = 64, 56
batch_size = 32
max_queue_len = 4096

# Channel names are stored as byte strings under `meta/channel_names`; decode
# them once. A hand-picked subset used previously is kept for reference:
# # use_channels = ['DAPI', 'CD45', 'PanCytoK', 'CD3e', 'CD4', 'CD8', 'PDGFRb', 'CD20', 'CD68']
with h5py.File('/home/ingn/tmp/micron2-data/bladder/bladder_merged_v4.hdf5', 'r') as f:
    all_channels = [raw_name.decode('UTF-8') for raw_name in f['meta/channel_names'][:]]

#     means =  tf.constant([f[f'cells/{c}'].attrs['mean'] for c in use_channels], dtype=tf.float32)
#     maxes =  tf.constant([f[f'cells/{c}'].attrs['max'] for c in use_channels], dtype=tf.float32)
#     print(means)
#     print(maxes)

# Train on the full panel: use_channels is the same list object as all_channels.
use_channels = all_channels

print(use_channels)


In [None]:
outdir = '/home/ingn/tmp/micron2-data/bladder/moco-cells-v2'
# exist_ok=True makes directory creation idempotent and removes the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(outdir, exist_ok=True)

# Record the channel list (one name per line, in order) next to the weights and
# embedding written to this directory. The file is only written, so plain 'w'
# is sufficient.
with open(f'{outdir}/use_channels.txt', 'w') as f:
    f.writelines(f'{c}\n' for c in use_channels)
        
# import os
# if os.path.exists(f'{outdir}/weights.h5'):
#     model.load_weights(f'{outdir}/weights.h5')

In [None]:
fname = '/home/ingn/tmp/micron2-data/bladder/bladder_merged_v4.hdf5'
# Number of shuffled batches sampled to estimate the mean image.
n_mean_batches = 200

dataset = (stream_dataset(fname, all_channels)
           .shuffle(1024 * 2, reshuffle_each_iteration=True)
           .batch(batch_size, drop_remainder=True)
           .prefetch(16)
          )

# Estimate the mean over the batch axis from the first `n_mean_batches` batches.
# A running float64 sum is kept instead of concatenating every batch, so memory
# stays at one image-sized accumulator and the average is not computed in the
# (possibly integer) dtype of the raw images.
pixel_sum = None
n_images = 0
source_dtype = None
for batch in dataset.take(n_mean_batches):
    batch = batch.numpy()
    source_dtype = batch.dtype
    batch_sum = batch.sum(axis=0, dtype=np.float64)
    pixel_sum = batch_sum if pixel_sum is None else pixel_sum + batch_sum
    n_images += batch.shape[0]

if n_images == 0:
    raise ValueError(f'No complete batch of {batch_size} images could be read from {fname}')

# The result is returned in the dtype of the streamed images so that code which
# combines it directly with raw images (e.g. `x - mean_tensor`) keeps working.
mean_tensor = tf.constant((pixel_sum / n_images).astype(source_dtype))
print(mean_tensor.shape)

In [None]:
def process(x):
    """
    Training-time preprocessing and augmentation for one image.

    x is a single un-batched element of the stream (this function is mapped
    before `.batch`), laid out as [h, w, c]; the 3-element crop size below
    requires a rank-3 input.

    Steps: subtract the mean image, scale by 1/255, take a random
    crop_size x crop_size crop, then apply random horizontal and vertical flips.

    Returns a float32 tensor of shape [crop_size, crop_size, len(use_channels)].
    """
    # Cast both operands to float32 *before* subtracting: subtracting in an
    # unsigned integer dtype would wrap around wherever a pixel is below the mean.
    x = tf.cast(x, tf.float32) - tf.cast(mean_tensor, tf.float32)
    x = x / 255.
    #x = x / maxes

    #x = tf.image.random_brightness(x, 0.2)
    # Crop shape is derived from the config so it always matches the data_shape
    # the model is built with, instead of a hard-coded [56, 56, 36].
    x = tf.image.random_crop(x, [crop_size, crop_size, len(use_channels)])
    x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    return x
    
# Training input pipeline: stream -> shuffle -> augment -> batch -> prefetch.
# There is no `.repeat()` in this chain, so the dataset is only as long as the
# underlying stream.
fname = '/home/ingn/tmp/micron2-data/bladder/bladder_merged_v4.hdf5'

dataset = stream_dataset(fname, all_channels)
# 2048-element shuffle buffer, reshuffled on every pass over the data.
dataset = dataset.shuffle(2048, reshuffle_each_iteration=True)
# Augmentation runs per image, before batching.
dataset = dataset.map(process, num_parallel_calls=6)
# Incomplete final batches are dropped so every batch has exactly batch_size images.
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(16)
# Optional, currently disabled:
# dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))


In [None]:
# Dummy single-image batch, used below to call the model and both encoders once
# before printing the summary.
sample_x = tf.zeros([1, crop_size, crop_size, len(use_channels)],dtype=tf.float32)
model = MoCo(data_shape=[crop_size, crop_size, len(use_channels)],
             z_dim=128, max_queue_len=max_queue_len, 
             batch_size=batch_size,
             temp=0.1, crop_size=crop_size,
             encoder_type='EfficientNetB1')

# `learning_rate` is the supported keyword for Keras optimizers; the `lr` alias
# is deprecated and rejected by newer Keras releases.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), 
              loss=tf.keras.losses.sparse_categorical_crossentropy,)
z = model(sample_x)
_ = model.encode_g(sample_x)
_ = model.encode_k(sample_x)
print(z.shape)

model.summary()

In [None]:
# Keras' own progress output is silenced (verbose=0); progress is reported by
# the tqdm callback instead.
# NOTE(review): UpdateQueue's first argument (0.999) is presumably a momentum
# coefficient -- confirm against micron2.clustering.
training_callbacks = [
    UpdateQueue(0.999, max_queue_len),
    TqdmCallback(verbose=2),
]
model.fit(dataset, epochs=20, verbose=0, callbacks=training_callbacks)

In [None]:
# Persist the full model and each of its two encoders to separate HDF5 files
# in the output directory.
weight_targets = (
    (model, f'{outdir}/weights.h5'),
    (model.encode_g, f'{outdir}/weights_g.h5'),
    (model.encode_k, f'{outdir}/weights_k.h5'),
)
for network, weights_path in weight_targets:
    network.save_weights(weights_path)

In [None]:
# Fraction of each stored patch kept by the deterministic centre crop. Derived
# from the config so it tracks crop_size / input_size (56 / 64 with the values
# set above).
crop_pct = crop_size / input_size
def process(x):
    """
    Deterministic preprocessing for embedding extraction.

    x is a single un-batched element of the stream (this function is mapped
    before `.batch`) -- presumably [h, w, c], as in training.

    Applies the same normalisation as the training pipeline (mean-image
    subtraction, then scaling by 1/255) so the encoder sees inputs on the scale
    it was trained on, followed by a central crop of fraction `crop_pct`.
    No random augmentation is applied.
    """
    # Cast before subtracting so the arithmetic is done in float32.
    x = tf.cast(x, tf.float32) - tf.cast(mean_tensor, tf.float32)
    x = x / 255.
    x = tf.image.central_crop(x, crop_pct)
    return x
    
# Inference input pipeline: no shuffling, so embeddings come out in stream
# order, and no drop_remainder, so the final partial batch is kept.
# Alternative source, currently disabled:
# dataset = (tfio.IOdataset.from_hdf5(fname, '/images/cells')
dataset = stream_dataset(fname, all_channels)
dataset = dataset.map(process, num_parallel_calls=6)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(128)
# Optional, currently disabled:
# dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))


In [None]:
# Run every batch through the `encode_g` encoder in inference mode and stack the
# per-batch outputs along the first axis.
z = np.concatenate(
    [model.encode_g(batch, training=False).numpy() for batch in tqdm.tqdm(dataset)],
    axis=0,
)
print(z.shape)

In [None]:
# Display the output directory as a sanity check before writing the embedding.
outdir

In [None]:
# Write the stacked embedding array to the output directory.
np.save(file=f'{outdir}/embedding.npy', arr=z)