In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_io as tfio
# import tensorflow_addons as tfa

# Ask TensorFlow to grow GPU memory on demand rather than reserving it up front.
# Every visible GPU is configured, and the loop is simply skipped when no GPU
# is present (indexing element 0 would raise IndexError on a CPU-only machine).
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

from micron2.clustering import MoCo, MoCo_Classifier, UpdateQueue
from micron2.data import stream_dataset, stream_dataset_parallel

import h5py
import tqdm.auto as tqdm
from tqdm.keras import TqdmCallback

# Short alias for tf.data's AUTOTUNE constant.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
!ls /storage/codex/preprocessed_data/pembro_TLS_panel/210702_PembroRT_Cas19_TLSpanel_reg2/collection*

In [None]:
# Load the training arrays from every matching collection directory and
# concatenate them along the first (cell) axis.
# Channel order must agree between collections, otherwise the concatenated
# image stack would mix markers; this is verified below.
import glob
srch = '/storage/codex/preprocessed_data/pembro_TLS_panel/210702_PembroRT_Cas19_TLSpanel_reg2/collection*'

# Sorting keeps the three lists aligned by collection directory.
image_files = sorted(glob.glob(f'{srch}/training_cells_images.npy'))
annot_files = sorted(glob.glob(f'{srch}/training_cells_annots.npy'))
channel_files = sorted(glob.glob(f'{srch}/training_cells_channels.npy'))

# Fail early with a readable message instead of an opaque np.concatenate /
# IndexError when the search pattern matches nothing or a file is missing.
if not image_files:
    raise FileNotFoundError(f'no training_cells_images.npy found under {srch}')
if not (len(image_files) == len(annot_files) == len(channel_files)):
    raise ValueError(
        f'mismatched file counts: {len(image_files)} image, '
        f'{len(annot_files)} annotation, {len(channel_files)} channel files')

# A bare expression in the middle of a cell displays nothing, so print it.
print(image_files)

images = np.concatenate([np.load(f) for f in image_files], axis=0)
annots = np.concatenate([np.load(f) for f in annot_files], axis=0)
if images.shape[0] != annots.shape[0]:
    raise ValueError(
        f'{images.shape[0]} images but {annots.shape[0]} annotations were loaded')

# Take the channel list from the first collection and require every other
# collection to list the same channels in the same order.
channel_names = np.load(channel_files[0])
for channel_file in channel_files[1:]:
    if list(np.load(channel_file)) != list(channel_names):
        raise ValueError(
            f'channel order in {channel_file} differs from {channel_files[0]}')

# Map channel name -> index along the channel axis.
channels = {k:i for i,k in enumerate(channel_names)}

print(images.shape, images.dtype)
print(annots.shape)
print(np.unique(annots, return_counts=True))
print(channels)

# Draw one unseeded random ordering of all cells and apply it to both arrays,
# so each image stays paired with its annotation.
n_cells = images.shape[0]
perm = np.random.choice(n_cells, size=n_cells, replace=False)
images, annots = images[perm], annots[perm]

# One-hot encode the annotations. np.unique returns the sorted distinct
# annotation values and, for every cell, the index of its value in that list.
u_annots, annots_int = np.unique(annots, return_inverse=True)
identity = np.eye(len(u_annots))

# Drop the column belonging to the first sorted annotation; cells carrying it
# end up with an all-zero row and are not counted as labelled below.
# NOTE(review): presumably that first value is the empty string '' used for
# unannotated cells (an earlier variant zeroed rows where annots == '') --
# confirm, otherwise a real class is discarded here.
labels = identity[annots_int][:, 1:]

print(labels.shape)
rows_with_label = np.sum(labels, axis=1) > 0
print('labelled cells:', np.sum(rows_with_label))

In [None]:
!pwd

In [None]:
# Settings shared by the cells below.
# input_size = 64
crop_size = 52
batch_size = 64
max_queue_len = 4096
n_channels = len(channels)

# Output directory; created (including parents) when it does not exist yet.
outdir = '/home/ingn/devel/micron2/notebooks/devel_debug/moco-models'
os.makedirs(outdir, exist_ok=True)

# Write the channel names, one per line, in the order of the `channels` dict.
with open(f'{outdir}/channels.txt', 'w+') as f:
    f.writelines(f'{c}\n' for c in channels)

In [None]:
def process(x,y):
    """
    Augment one (image, label) pair for training.

    The image is cast to float32 and divided by 255, randomly cropped to
    [crop_size, crop_size, n_channels], then randomly flipped left/right and
    up/down. The label is returned untouched.

    x is a single [h, w, c] image: the rank-3 crop size requires it, and the
    pipeline below maps this function before batching.
    """
    scaled = tf.cast(x, tf.float32)/255.

    cropped = tf.image.random_crop(scaled, [crop_size, crop_size, n_channels])
    flipped = tf.image.random_flip_up_down(
        tf.image.random_flip_left_right(cropped)
    )

    return flipped, y
    
# Training pipeline over all cells: pair each image with its label row,
# shuffle, augment with `process`, batch, and prefetch.
# There is no .repeat() here, so one pass over `dataset` is one pass over the
# data. An earlier image-only variant (shuffle buffer 1024, same map / batch /
# prefetch steps) was replaced by the zipped version below.
image_dataset = tf.data.Dataset.from_tensor_slices(images)
label_dataset = tf.data.Dataset.from_tensor_slices(labels)

dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
dataset = dataset.shuffle(2048, reshuffle_each_iteration=True)
dataset = dataset.map(process, num_parallel_calls=4)
# drop_remainder keeps every batch at exactly batch_size elements.
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(1)

In [None]:
# Sanity check: print the shapes of the first batch only.
for batch, label in dataset.take(1):
    print(batch.shape, label.shape)

In [None]:
# Build the MoCo classifier, then run an all-zeros dummy batch of one image
# through the model and each of its sub-networks to print their output shapes.
sample_x = tf.zeros([1, crop_size, crop_size, len(channels)],dtype=tf.float32)
model = MoCo_Classifier(data_shape=[crop_size, crop_size, len(channels)], 
                        z_dim=256, 
                        n_classes=labels.shape[1],
                        mlp_dim=128,
                        max_queue_len=max_queue_len,  
                        batch_size=batch_size, 
                        temp=0.1, 
                        encoder_type='EfficientNetB1')

# `learning_rate` is the supported keyword for Keras optimizers; the old `lr`
# alias is deprecated and rejected by newer Keras releases.
# NOTE(review): the datasets yield one-hot label rows, whereas
# sparse_categorical_crossentropy expects integer class ids -- presumably
# MoCo_Classifier computes its own losses during training; confirm.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-2), 
              loss=tf.keras.losses.sparse_categorical_crossentropy,)
z = model(sample_x)
g = model.encode_g(sample_x)
k = model.encode_k(sample_x)
y = model.classifier(sample_x)

print(z.shape)
print(g.shape)
print(k.shape)
print(y.shape)

model.summary()

In [None]:
# First training run: 25 epochs over the dataset built from all cells.
# Keras' built-in progress output is switched off (verbose=0); TqdmCallback
# reports progress instead.
# NOTE(review): UpdateQueue's arguments are presumably (momentum, queue
# length) -- confirm against micron2.clustering.
stage1_callbacks = [
    UpdateQueue(0.999, max_queue_len),
    TqdmCallback(verbose=2),
]
model.fit(dataset, epochs=25, verbose=0, callbacks=stage1_callbacks)

In [None]:
# Ratio of a 52-pixel crop to a 64-pixel edge (= 0.8125).
52 / 64

In [None]:
# Keep only the cells whose label row has at least one non-zero entry.
has_labels = np.sum(labels, axis=1) > 0
print(has_labels.sum())

use_labels = labels[has_labels]
labelled_images = images[has_labels]

def process(x,y):
    """
    Preprocess one (image, label) pair for the labelled-only pipeline.

    This redefinition replaces the earlier `process`: the random crop is
    swapped for a deterministic central crop, while the float32 / 255 scaling
    and the random left/right and up/down flips are kept. The label is
    returned untouched.

    x is a single image; the pipeline below maps this function before batching.
    """
    scaled = tf.cast(x, tf.float32)/255.

    # 0.8125 == 52 / 64. Presumably this cuts a crop_size-wide center window
    # out of 64-pixel inputs -- TODO confirm the input size.
    centered = tf.image.central_crop(scaled, 0.8125)
    flipped = tf.image.random_flip_up_down(
        tf.image.random_flip_left_right(centered)
    )

    return flipped, y

# Labelled-only pipeline: pair each labelled image with its label row, repeat
# the data 10 times, shuffle, preprocess with the current `process`, batch,
# and prefetch. Rebinding `dataset` replaces the all-cells pipeline.
image_dataset = tf.data.Dataset.from_tensor_slices(labelled_images)
label_dataset = tf.data.Dataset.from_tensor_slices(use_labels)
dataset = (tf.data.Dataset.zip((image_dataset, label_dataset)) 
           .repeat(10)
           .shuffle(1024, reshuffle_each_iteration=True)
           .map(process, num_parallel_calls=4)
           # Use the shared batch_size (previously a hard-coded 64) so this
           # stage cannot drift from the batch size the model was built with.
           .batch(batch_size, drop_remainder=True)
           .prefetch(1)
          )

# Sanity check: unpack the first batch only and print its shapes.
for batch in dataset.take(1):
    batch_images, batch_labels = batch
    print(batch_images.shape, batch_labels.shape)

In [None]:
# Second training run: 20 epochs on the labelled-only dataset after changing
# the model's loss weights.
# NOTE(review): which of alpha / beta weights the classifier term is not
# visible here -- confirm against MoCo_Classifier.
model.alpha, model.beta = 0.99, 0.01

stage2_callbacks = [
    UpdateQueue(0.999, max_queue_len),
    TqdmCallback(verbose=2),
]
model.fit(dataset, epochs=20, verbose=0, callbacks=stage2_callbacks)

In [None]:
# Write the weights of the full model and of each sub-network to separate
# HDF5 files inside `outdir`.
print(outdir)
weight_targets = (
    ('weights', model),
    ('weights_g', model.encode_g),
    ('weights_k', model.encode_k),
    ('weights_cls', model.classifier),
)
for file_stem, network in weight_targets:
    network.save_weights(f'{outdir}/{file_stem}.h5')