In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import pandas as pd
import h5py

# Needed for RTX 3000 series cards running CUDA 11.0, cudnn 8.0.4, tensorflow 2.4 (2021-Jan-17)
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
# Turn on memory growth for every GPU TensorFlow can see. Iterating over the
# list (rather than indexing element 0) keeps this cell runnable when the list
# is empty, i.e. on a machine with no visible GPU, and applies the same
# setting to all devices when there is more than one.
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

from micron2.embed_sets import SetEncoding, train_sets_SimCLR, stream_sets

In [None]:
# Channel names are stored in the dataset as bytes; decode each to a str.
with h5py.File("/home/ingn/tmp/micron2-data/dataset_v2.hdf5", "r") as h5f:
    all_channels = [name.decode('utf-8') for name in h5f['meta/channel_names'][:]]

model = SetEncoding(
    inner_dim=256,
    outer_dim=128,
    g_dim=32,
    crop_size=48,  # crop_size is for model set up only
    encoder_type='keras_resnet',
    n_channels=len(all_channels),
)

# Run one all-zeros input through the model. Presumably this is what creates
# the model's variables so that weights can be loaded afterwards -- TODO confirm.
# NOTE(review): the axes look like (batch, set size, height, width, channels);
# verify against SetEncoding.
x = np.zeros((2, 6, 64, 64, len(all_channels)))
y = model(x)

In [None]:
# Shell escape: list the files in the single_moco directory.
!ls /home/ingn/tmp/micron2-data/single_moco

In [None]:
# Load saved weights from the HDF5 file into `model`; by_name=True is passed
# so that weights are matched to layers by name.
model.load_weights('/home/ingn/tmp/micron2-data/single_moco/weights.h5', by_name=True)

In [None]:
# Display the list of channel names.
all_channels

In [None]:
# Collect the 'mean' and 'std' attributes stored on each
# cell_intensity/<channel> entry, in the order given by all_channels, and
# turn them into float32 constants.
channel_means = []
channel_stds = []
with h5py.File("/home/ingn/tmp/micron2-data/dataset_v2.hdf5", "r") as h5f:
    for c in all_channels:
        channel_attrs = h5f[f'cell_intensity/{c}'].attrs
        channel_means.append(channel_attrs['mean'])
        channel_stds.append(channel_attrs['std'])

means = tf.constant(channel_means, dtype=tf.float32)
stds = tf.constant(channel_stds, dtype=tf.float32)
    
# print(means, stds)

In [None]:
import tqdm.auto as tqdm
def process(x, crop_size=48):
    """Center-crop `x` and standardise it with the notebook's `means` / `stds`.

    Args:
        x: image array or tensor. Axis 1 is read as the size being cropped, so
            a 4-D (n, height, width, channels) layout is presumably expected
            -- TODO confirm against `stream_sets`.
        crop_size: target size in pixels along axis 1 after cropping. Defaults
            to 48, the value that used to be hard-coded here.

    Returns:
        A float32 tensor: the central crop of `x`, minus `means`, divided by
        `stds` (both are defined in an earlier cell).
    """
    # tf.image.central_crop expects the fraction of the image to keep, not a
    # pixel count, so convert the requested size into a fraction of axis 1.
    crop_pct = crop_size / x.shape[1]
    x = tf.cast(x, tf.float32)
    x = tf.image.central_crop(x, crop_pct)
    x = (x - means) / stds
    return x
    
    
BATCH_SIZE = 8  # number of streamed items sent through the model at once


def _embed_batch(samples):
    """Stack a non-empty list of processed items and return the model output as a NumPy array."""
    stacked = tf.stack(samples, axis=0)
    return model(stacked, training=False).numpy().copy()


# Embed every item produced by stream_sets, BATCH_SIZE at a time, and collect
# the per-batch outputs in `zs`.
zs = []
with h5py.File("/home/ingn/tmp/micron2-data/dataset_v2.hdf5", "r") as h5f:
    coords = h5f['meta/cell_coordinates'][:]
    streamer = stream_sets(h5f, coords=coords, use_channels=all_channels)

    batch = []
    for x in tqdm.tqdm(streamer):
        batch.append(process(x))
        # Flush on the batch length, not on the loop index: testing
        # `i % 8 == 0` sent a batch of one at i == 0 and left the final
        # batch empty whenever the last index was a multiple of 8.
        if len(batch) == BATCH_SIZE:
            zs.append(_embed_batch(batch))
            batch = []

    # Embed the leftover items, if any. Skipping this when `batch` is empty
    # avoids handing the model an empty stack.
    if batch:
        zs.append(_embed_batch(batch))

zs = np.concatenate(zs, axis=0)
print(zs.shape)

In [None]:
from matplotlib import pyplot as plt
from matplotlib import rcParams

In [None]:
from sklearn.cluster import MiniBatchKMeans
# Partition the embeddings into 20 clusters. random_state is pinned so that
# the k-means initialisation and mini-batch sampling are the same on every
# run; without it the group labels change between kernel restarts.
MBKM = MiniBatchKMeans(n_clusters=20, random_state=0)
groups = MBKM.fit_predict(zs)
print(groups.shape)

In [None]:
# One scatter call per cluster label, so each label is drawn as its own series.
rcParams['figure.dpi'] = 600
# Column 0 is plotted on the horizontal axis and the negated column 1 on the
# vertical axis (presumably to match image orientation -- TODO confirm).
xs, ys = coords[:, 0], -coords[:, 1]
for cluster_id in np.unique(groups):
    in_cluster = groups == cluster_id
    plt.scatter(xs[in_cluster], ys[in_cluster], s=0.25)