In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd

import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. Looping (rather than indexing [0]) avoids an IndexError on machines
# without a GPU and configures every visible GPU on multi-GPU hosts.
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

from micron2.clustering import Autoencoder, train_AE_simCLR
from micron2.data import stream_dataset

import h5py
import tqdm.auto as tqdm

In [None]:
use_channels = ['DAPI', 'CD45', 'PanCytoK', 'CD3e', 'CD4', 'CD8', 'PDGFRb', 'CD20', 'CD68', 'IgG']
with h5py.File('/home/ingn/tmp/micron2-data/dataset.hdf5', 'r') as f:
    all_channels = [b.decode('UTF-8') for b in f['meta/channel_names'][:]]
    print(all_channels)

    # Fail early with a readable message instead of an opaque h5py KeyError
    # raised from inside the comprehensions below.
    missing = [c for c in use_channels if f'intensity/{c}' not in f]
    if missing:
        raise KeyError(f'channels missing from the intensity/ group: {missing}')

    # Per-channel normalisation statistics, read from the 'mean' / 'std'
    # attributes of each intensity/<channel> entry, in use_channels order.
    means = tf.constant([f[f'intensity/{c}'].attrs['mean'] for c in use_channels], 
                        dtype=tf.float32)
    stds = tf.constant([f[f'intensity/{c}'].attrs['std'] for c in use_channels],
                       dtype=tf.float32)
    # Standardisation divides by these values; a zero std would yield inf/NaN.
    tf.debugging.assert_positive(stds, message='every channel std must be > 0')
    print(means)
    print(stds)


In [None]:
print(use_channels)

def process(x):
    """Standardise an image per channel using the dataset-wide statistics.

    This is applied with ``dataset.map`` *before* ``.batch``, so ``x`` is a
    single sample with channels on the last axis (presumably ``[h, w, c]``).
    The per-channel ``means`` / ``stds`` vectors (length ``c``) broadcast
    over that trailing axis, so an already-batched ``[N, h, w, c]`` input
    works as well.

    Returns:
        A float32 tensor with the same shape as ``x``.
    """
    x = tf.cast(x, tf.float32)
    return (x - means) / stds
    
# Training input pipeline: stream samples from the HDF5 file (as yielded by
# micron2.data.stream_dataset), repeat 10 passes, shuffle, standardise per
# channel, then batch and prefetch.
# NOTE(review): shuffling after repeat() mixes samples across pass boundaries.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/dataset.hdf5', use_channels=use_channels)
dataset = (dataset.repeat(10)
           .shuffle(1024 * 4)
           # process() is element-wise and stateless, so it can run in
           # parallel; tf.data keeps output order deterministic by default.
           .map(process, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(8)
           .prefetch(32)
          )

# Pull a single batch for shape / sanity checks.
sample_x = next(iter(dataset))
print(sample_x.shape)

# Per-channel mean of the standardised batch; it should be roughly 0 if the
# stored mean/std statistics describe this data. Convert to numpy once.
sample_np = sample_x.numpy()
for k in range(sample_np.shape[-1]):
    print(use_channels[k], sample_np[..., k].mean())

In [None]:
# Randomly crop every sample in the inspection batch to 48x48 pixels, keeping
# the batch size and all channels, then build the autoencoder for that
# per-sample shape.
sample_x = tf.image.random_crop(
    sample_x,
    size=[sample_x.shape[0], 48, 48, sample_x.shape[-1]],
)
ae_model = Autoencoder(input_shape=sample_x.shape[1:])

In [None]:
# Smoke-test the freshly built model: a full forward pass that also returns
# the extra `g` output (return_g=True), then the encoder on its own.
y, z_g = ae_model(sample_x, return_g=True)
print(y.shape, z_g.shape, sep='\n')
z = ae_model.encode(sample_x)
print(z.shape)

ae_model.summary()

In [None]:
# Output directory for SimCLR-trained weights; resume from a previous run if
# a weights file is already there. (`os` is imported in the first cell.)
outdir = '/home/ingn/tmp/micron2-data/single_simclr'
weights_path = os.path.join(outdir, 'weights.h5')
if os.path.exists(weights_path):
    ae_model.load_weights(weights_path)
    print(f'loaded weights from {weights_path}')
else:
    print(f'no weights found at {weights_path}; starting from scratch')

In [None]:
# Train the autoencoder with the project's SimCLR routine on the streaming
# dataset; see micron2.clustering.train_AE_simCLR for what batch_reps controls.
train_AE_simCLR(
    dataset,
    ae_model,
    batch_reps=1,
)

In [None]:
# Persist the trained weights. Create the output directory first so saving
# does not fail when the notebook runs on a machine without it.
os.makedirs(outdir, exist_ok=True)
ae_model.save_weights(f'{outdir}/weights.h5')

In [None]:
!ls -lha trained_simclr

In [None]:
from matplotlib import pyplot as plt

def _channel_sums(t):
    """Sum a tensor over every axis except the trailing channel axis."""
    return tf.reduce_sum(t, axis=list(range(len(t.shape) - 1))).numpy()

# Compare total intensity per channel between the model output and the input
# batch; one reduction per tensor instead of one per channel.
sample_xout = ae_model(sample_x)
pred_sums = _channel_sums(sample_xout)
real_sums = _channel_sums(sample_x)
for j, (name, pred, real) in enumerate(zip(use_channels, pred_sums, real_sums)):
    print(f'channel {j}\t{name:<10}\tpred {pred:<4.2f}\treal {real:<4.2f}')

# Pick one random sample from the batch and one random channel to visualise.
idx = np.random.choice(sample_xout.shape[0])
jdx = np.random.choice(sample_xout.shape[-1])
print(idx, jdx)

# 2-D slices of the input and of the model output for that sample/channel.
sx, sxout = (t.numpy()[idx, :, :, jdx] for t in (sample_x, sample_xout))
print(jdx, sx.sum(), sxout.sum())

# Side-by-side view of the selected input channel and the model output.
# plt.matshow(...) opens its own new figure when fignum is None, so the old
# plt.figure() calls only produced empty extra figures; draw onto explicit
# axes instead, each with its own colour scale.
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 4))

im_in = ax_in.matshow(sx)
fig.colorbar(im_in, ax=ax_in)
ax_in.set_title(f'input: {use_channels[jdx]} (sample {idx})')

im_out = ax_out.matshow(sxout)
fig.colorbar(im_out, ax=ax_out)
ax_out.set_title(f'output: {use_channels[jdx]} (sample {idx})')

plt.show()