In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import h5py
import numpy as np
import tqdm.auto as tqdm

# Needed for RTX 3000-series cards running CUDA 11.0, cuDNN 8.0.4, TensorFlow 2.4 (2021-Jan-17)
import tensorflow as tf
# Ask TensorFlow to grow GPU memory on demand rather than reserving it all at
# start-up. The setting is applied to every visible GPU (TensorFlow requires it
# to be identical across GPUs); on a CPU-only machine the list is empty and the
# loop does nothing instead of failing on `physical_devices[0]`.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs have already been initialised, e.g. when this
        # cell is re-run in a live kernel; the setting can no longer change.
        print(f'Could not set memory growth on {gpu}: {err}')

from micron2.embed_sets import SetEncoding, train_sets_SimCLR, stream_sets
from micron2.clustering import train_moco
from micron2 import stream_dataset

from matplotlib import pyplot as plt
import seaborn as sns

from MulticoreTSNE import MulticoreTSNE

In [None]:
def process(x):
    """Standardize one dataset element to zero mean and unit variance.

    The input is cast to float32, then the mean and standard deviation are
    reduced over axes 0-3 and used to normalise the whole tensor.

    Args:
        x: numeric tensor with at least four axes (the reductions name axes
            0-3). Presumably one set of image crops shaped
            (set_size, height, width, channels) -- TODO confirm against
            `stream_dataset`.

    Returns:
        A float32 tensor with the same shape as `x`. Where the standard
        deviation is zero (constant input) the result is 0 rather than NaN.
    """
    x = tf.cast(x, tf.float32)
    m = tf.math.reduce_mean(x, axis=[0, 1, 2, 3])
    s = tf.math.reduce_std(x, axis=[0, 1, 2, 3])
    # A constant input has s == 0; plain division would give 0/0 = NaN and
    # propagate into the loss. divide_no_nan returns 0 for those entries and
    # is identical to ordinary division everywhere else.
    return tf.math.divide_no_nan(x - m, s)

In [None]:
# Read the channel names from the dataset's metadata group. The stored entries
# are byte strings, so each one is decoded to a Python str.
# NOTE(review): names come from dataset.hdf5 while the streams below read
# setdataset.hdf5 -- confirm both files share the same channel order.
with h5py.File("/home/ingn/tmp/micron2-data/dataset.hdf5", "r") as h5f:
    all_channels = []
    for raw_name in h5f['meta/channel_names'][:]:
        all_channels.append(raw_name.decode('utf-8'))

In [None]:
# NOTE(review): USE_CHANNELS is defined but not referenced in this cell -- the
# stream below is built from `all_channels`. Confirm which one is intended.
USE_CHANNELS = ['DAPI', 'CD45', 'PanCytoK', 'CD3e', 'CD20', 'C1q', 'CD4', 'CD8', 'CD40', 'HLA-DR']

# Training input pipeline: repeat indefinitely, shuffle with a 2048-element
# buffer, standardize each element with `process`, batch by 16, and keep up to
# 32 batches prefetched.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/setdataset.hdf5',
                         use_channels=all_channels)
dataset = dataset.repeat(None).shuffle(2048)
dataset = dataset.map(process, num_parallel_calls=8)
dataset = dataset.batch(16).prefetch(32)

In [None]:
# Both encoders are built from one shared configuration so they cannot drift
# apart. They are passed to `train_moco` as `model` and `kmodel` (presumably
# the MoCo query and key encoders -- confirm against micron2.clustering).
encoder_config = dict(
    inner_dim=256,
    outer_dim=128,
    g_dim=32,
    crop_size=48,  # crop_size is for model set up only
    encoder_type='keras_resnet',
    n_channels=len(all_channels),
)
model = SetEncoding(**encoder_config)
kmodel = SetEncoding(**encoder_config)

# Call each model once on a correctly sized dummy batch to initialize its
# variables. The batch is float32 to match what `process` feeds the models;
# np.zeros would otherwise default to float64.
x = np.zeros((2, 6, 64, 64, len(all_channels)), dtype=np.float32)
y = model(x)
yk = kmodel(x)
print(y.shape)

In [None]:
# Train with `train_moco`; the returned `loss_history` is treated as a
# sequence of loss values by the plotting cells that follow.
loss_history = train_moco(
    dataset,
    model,
    kmodel,
    max_queue_len=512,
    crop_size=48,
    max_steps=1e5,  # NOTE(review): float literal -- confirm train_moco accepts a non-integer step count
    temp=0.1,
    lr=1e-4,
    perturb=False,
)

In [None]:
# Raw loss curve: one point per entry of `loss_history`, drawn with a very
# thin line because the unsmoothed values are noisy.
plt.figure(figsize=(10,1), dpi=180)
# Assign to `_` so the cell does not echo the list of Line2D objects.
_ = plt.plot(np.arange(len(loss_history)), loss_history, lw=0.1)
plt.xlabel('entry in loss_history')
_ = plt.ylabel('loss')

In [None]:
# Smooth the loss with a moving average: the mean of each `window`-entry
# slice, with slice starts spaced `stride` entries apart.
window, stride = 50, 10
# The last valid slice start is len(loss_history) - window, so the exclusive
# range bound is one past it; this keeps the final full window (the most
# recent losses) in the plot. A history shorter than `window` yields no points.
smth = [np.mean(loss_history[i:i + window])
        for i in range(0, len(loss_history) - window + 1, stride)]
plt.figure(figsize=(10,1), dpi=180)
_ = plt.plot(np.arange(len(smth)), smth, lw=1)

In [None]:
# Embed a leading portion of the data for a quick look at the representation.
# No repeat or shuffle is applied to this stream. Note that this rebinds
# `dataset`, replacing the training pipeline defined earlier.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/setdataset.hdf5',
                         use_channels=all_channels)
dataset = dataset.map(process)
dataset = dataset.batch(16)
dataset = dataset.prefetch(128)

z = []
for i, batch in enumerate(tqdm.tqdm(dataset)):
    # Forward pass with training=False; keep the result as a NumPy array.
    z.append(model(batch, training=False).numpy())
    if i >= 2001:  # batches 0..2001 have been embedded; stop here
        break

# Stack the per-batch outputs along the first axis.
z = np.concatenate(z, axis=0)
print(z.shape)

In [None]:
# t-SNE embedding of the encoder outputs in `z`, used for the scatter plot.
# t-SNE is stochastic; a fixed random_state lets the same `z` reproduce the
# same layout across runs.
# NOTE(review): n_jobs=24 is a hard-coded thread count -- adjust to the host.
emb = MulticoreTSNE(n_jobs=24, random_state=0).fit_transform(z)
print(emb.shape)

In [None]:
# Scatter plot of the first two columns of `emb`; tiny black markers keep
# dense regions readable.
plt.figure(figsize=(4,4), dpi=180)
# Assign to `_` so the cell does not echo the PathCollection repr.
_ = plt.scatter(emb[:,0], emb[:,1], s=0.1, color='k')
plt.xlabel('t-SNE 1')
_ = plt.ylabel('t-SNE 2')

In [None]:
# Embed everything the stream yields (no early stop this time), in batches of
# 32. This rebinds both `dataset` and `z`.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/setdataset.hdf5',
                         use_channels=all_channels)
dataset = dataset.map(process).batch(32).prefetch(128)

# One forward pass per batch with training=False, then stack the per-batch
# NumPy arrays along the first axis.
z = np.concatenate(
    [model(batch, training=False).numpy() for batch in tqdm.tqdm(dataset)],
    axis=0,
)
print(z.shape)

In [None]:
import os

# Persist the results: the weights of `model` (not `kmodel`) and the embedding
# array `z`, both under `save_dir`, which is created if it does not exist.
save_dir = '/home/ingn/tmp/micron2-data/sets_moco'
os.makedirs(save_dir, exist_ok=True)

weights_path = f'{save_dir}/weights.h5'
embedding_path = f'{save_dir}/embedding.npy'

model.save_weights(weights_path)
np.save(embedding_path, z)