In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd

import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# The memory-growth setting has to be the same on every visible GPU, so it is applied
# to each of them; on a machine without a GPU the loop simply does nothing instead of
# failing with an IndexError on physical_devices[0].
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

from micron2.clustering import Encoder, train_moco
from micron2.data import stream_dataset

import h5py
import tqdm.auto as tqdm

In [None]:
# List the contents of the data directory (shell command run through IPython).
!ls /home/ingn/tmp/micron2-data

In [None]:
# input_size sizes the dummy encoder input and the crop fraction used further down;
# crop_size is the side length the encoders are declared with and the value handed
# to train_moco. (Presumably input_size is the stored tile size - TODO confirm.)
input_size = 128
crop_size = 96
# Disabled: a hand-picked channel subset.
# use_channels = ['DAPI', 'CD45', 'PanCytoK', 'CD3e', 'CD4', 'CD8', 'PDGFRb', 'CD20', 'CD68', 'IgG', 'C1q']

# Channel names are read from meta/channel_names and decoded from bytes to str.
with h5py.File('/home/ingn/tmp/micron2-data/pembroRT-set1-set2/merged.hdf5', 'r') as f:
    all_channels = [raw.decode('UTF-8') for raw in f['meta/channel_names'][:]]

    # Disabled: per-channel mean / std read from the cell_intensity attributes.
#     means = tf.constant([f[f'cell_intensity/{c}'].attrs['mean'] for c in use_channels], dtype=tf.float32)
#     stds =  tf.constant([f[f'cell_intensity/{c}'].attrs['std'] for c in use_channels], dtype=tf.float32)
#     print(means)
#     print(stds)

print(all_channels)

# just use all the channels
use_channels = all_channels



In [None]:
# Directory that receives everything this notebook writes
# (channel list, weights, loss history, embeddings).
outdir = '/home/ingn/tmp/micron2-data/pembroRT-set1-set2/moco-tiles'
# exist_ok=True makes the call a no-op when the directory is already there, which
# removes the separate isdir() check and the gap between checking and creating.
os.makedirs(outdir, exist_ok=True)

# Record the channel order, one name per line. Plain write mode is enough:
# the file is never read back through this handle.
with open(f'{outdir}/use_channels.txt', 'w') as f:
    f.writelines(f'{c}\n' for c in use_channels)

# Disabled: resume from previously saved weights (needs `model`, which is only
# created in a later cell).
# import os
# if os.path.exists(f'{outdir}/weights.h5'):
#     model.load_weights(f'{outdir}/weights.h5')

In [None]:
print(use_channels)

def process(x):
    """
    Convert x to float32, divide by 255 and standardize it.

    The original note describes x as [N, h, w, c].
    NOTE(review): this function is mapped over the dataset *before* `.batch()`,
    so confirm whether stream_dataset yields rank-4 or rank-3 elements.

    tf.transpose without `perm` reverses the axis order, and
    tf.image.per_image_standardization normalizes over the last three axes of
    whatever it receives. For a rank-4 x the statistics are therefore taken per
    channel over (N, h, w); for a rank-3 x all three axes are normalized
    together. The second transpose restores the original axis order.
    """
    # Disabled alternative: normalize with dataset-wide statistics.
    #     x = tf.cast(x, tf.float32)
    #     x = (x - means) / stds

    scaled = tf.cast(x, tf.float32)/255.
    standardized = tf.image.per_image_standardization(tf.transpose(scaled))
    return tf.transpose(standardized)
    
# Training input pipeline: an endlessly repeating, shuffled stream of tiles that
# are preprocessed one element at a time and then grouped into batches of 32.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/pembroRT-set1-set2/merged.hdf5',
                         group_name='images',
                         use_channels=use_channels)
dataset = dataset.repeat(None)        # count=None -> repeat indefinitely
dataset = dataset.shuffle(1024 * 4)   # shuffle buffer of 4096 elements
dataset = dataset.map(process, num_parallel_calls=4)
dataset = dataset.batch(32)
dataset = dataset.prefetch(64)
# Disabled: dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))


In [None]:
# Dummy single-element batch used to call the encoders once before training.
# NOTE(review): sample_x is input_size x input_size while the encoder is declared
# with crop_size x crop_size - confirm that Encoder accepts both sizes.
sample_x = tf.zeros([1, input_size, input_size, len(use_channels)], dtype=tf.float32)
model = Encoder(input_shape=[crop_size, crop_size, len(use_channels)])

# Forward pass with return_g=True yields two outputs; print both shapes.
y, z_g = model(sample_x, return_g=True)
print(y.shape, z_g.shape, sep='\n')

# Shape of the output of encode(), the same method used for the final embedding.
z = model.encode(sample_x)
print(z.shape)

model.summary()

In [None]:
# # Placing the key encoder on CPU helps training speed, maybe
# The key encoder takes the same constructor arguments as `model`; it is created
# and first called inside a CPU device scope.
with tf.device('/CPU:0'):
    # Disabled alternative: kmodel = Encoder(input_shape=sample_x.shape[1:])
    kmodel = Encoder(input_shape=[crop_size, crop_size, len(use_channels)])

    # Same shape checks as for `model`, on the same dummy input.
    y, z_g = kmodel(sample_x, return_g=True)
    print(y.shape, z_g.shape, sep='\n')
    z = kmodel.encode(sample_x)
    print(z.shape)

kmodel.summary()

In [None]:
# Run MoCo training with `model` and `kmodel` on the repeating training dataset.
# The return value is kept as loss_history and is plotted and saved in later cells.
# max_steps is 7500*3 = 22500. The dataset repeats indefinitely, so presumably
# max_steps is what ends training - TODO confirm against train_moco.
# NOTE(review): the meaning of max_queue_len, temp and perturb is defined inside
# micron2.clustering.train_moco and is not visible here.
loss_history = train_moco(dataset, model, kmodel, 
                          lr=1e-4,
                          max_queue_len=128, 
                          crop_size=crop_size,
                          max_steps=7500*3, temp=0.1,
                          perturb=False
                         )

In [None]:
from matplotlib import pyplot as plt

# Wide, short figure showing log10 of each recorded loss value against its
# position in loss_history; the very thin line keeps a long, noisy curve readable.
lh = len(loss_history)
plt.figure(figsize=(8, 2))
plt.plot(np.arange(lh), np.log10(loss_history), linewidth=0.1)

In [None]:
# Write the weights of `model` and the loss history returned by train_moco into outdir.
model.save_weights(f'{outdir}/weights.h5')
np.save(f'{outdir}/loss_history.npy', np.array(loss_history),)

# Process a whole dataset

In [None]:
# Fraction of each spatial side kept by the central crop (crop_size out of input_size).
crop_frac = crop_size / input_size
print(crop_frac)
def process_crop(x):
    """
    Preprocess x exactly as during training, then keep its central region.

    The scaling and standardization are delegated to `process` (defined in the
    training-pipeline cell) rather than repeated here, so the preprocessing
    used for embedding cannot drift away from the one used for training.
    tf.image.central_crop then keeps the central `crop_frac` of the image.

    The original note describes x as [N, h, w, c].
    NOTE(review): this function is mapped before `.batch()`, so confirm the
    rank of the elements stream_dataset yields.
    """
    return tf.image.central_crop(process(x), crop_frac)

# Inference pipeline: one pass over the data in stream order (no repeat and no
# shuffle), so the embeddings are collected in the order stream_dataset yields.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/pembroRT-set1-set2/merged.hdf5',
                         group_name='images',
                         use_channels=use_channels)
dataset = dataset.map(process_crop, num_parallel_calls=2)
dataset = dataset.batch(32)
dataset = dataset.prefetch(8)

# Encode every batch with training=False, convert each result to a NumPy array,
# and stack the per-batch results along the first axis.
z = np.concatenate(
    [model.encode(batch, training=False).numpy() for batch in tqdm.tqdm(dataset)],
    axis=0,
)
print(z.shape)

In [None]:
# Save the concatenated embedding array `z` as z.npy in outdir.
np.save(f'{outdir}/z.npy', z)