In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import pandas as pd

import tensorflow as tf
# Enable on-demand GPU memory allocation so TensorFlow does not reserve the
# whole device up front. Memory growth has to be configured identically on
# every visible GPU, so loop over all of them; the loop is also a no-op on a
# CPU-only machine instead of raising IndexError on physical_devices[0].
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

from micron2.clustering import Encoder, train_simCLR
from micron2.data import stream_dataset

import h5py
import tqdm.auto as tqdm

In [None]:
# Channels fed to the encoder, in the order used for the mean/std vectors below.
use_channels = [
    'DAPI', 'CD45', 'PanCytoK', 'CD3e', 'CD4', 'CD8',
    'PDGFRb', 'CD20', 'CD68', 'IgG', 'C1q',
]

with h5py.File('/home/ingn/tmp/micron2-data/dataset.hdf5', 'r') as f:
    # Channel names are stored as bytes; decode them for display.
    all_channels = list(map(lambda raw: raw.decode('UTF-8'), f['meta/channel_names'][:]))
    print(all_channels)

    # Per-channel statistics are read from the attributes of each
    # `intensity/<channel>` entry and packed into float32 vectors.
    channel_attrs = [f[f'intensity/{c}'].attrs for c in use_channels]
    means = tf.constant([attrs['mean'] for attrs in channel_attrs], dtype=tf.float32)
    stds = tf.constant([attrs['std'] for attrs in channel_attrs], dtype=tf.float32)
    print(means)
    print(stds)


In [None]:
# Echo the selected channels so the cell output records which ones are in use.
print(use_channels)

def process(x):
    """
    Standardize an image tensor channel-wise with the dataset statistics.

    The input is cast to float32, then the module-level `means` is subtracted
    and the result divided by `stds`; both are vectors with one entry per
    channel in `use_channels`, so they broadcast along the last axis.

    x was originally noted as [N, h, w, c].
    NOTE(review): this function is mapped before `.batch()`, so it may
    actually receive [h, w, c] elements -- confirm against stream_dataset.
    """
    as_float = tf.cast(x, tf.float32)
    standardized = (as_float - means) / stds

    # Alternative normalizations kept for reference (disabled):
    # x = tf.cast(x, tf.float32)/255.
    # x = tf.transpose(tf.image.per_image_standardization(tf.transpose(x)))
    return standardized
    
# Training input pipeline: repeat the stream, shuffle with an 8192-element
# buffer, standardize each element with `process`, then batch and prefetch.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/dataset.hdf5', use_channels=use_channels)
dataset = (dataset.repeat(125)
           .shuffle(1024 * 8)
           .map(process)
           .batch(256)
           .prefetch(16)
           #.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
          )

# Pull a single batch to inspect shapes and to build the model later on.
sample_x = next(iter(dataset))

print(sample_x.shape)

# Convert to NumPy once; the original re-ran `.numpy()` for every channel.
sample_np = sample_x.numpy()
for k in range(sample_x.shape[-1]):
    # Per-channel mean of the standardized batch, as a quick sanity check.
    print(use_channels[k], sample_np[..., k].mean())

In [None]:
# Take one random 48x48 spatial crop of the sample batch, keeping the batch
# and channel dimensions, and size the encoder input from the cropped shape.
crop_size = (sample_x.shape[0], 48, 48, sample_x.shape[-1])
sample_x = tf.image.random_crop(sample_x, size=crop_size)
model = Encoder(input_shape=sample_x.shape[1:])

In [None]:
# Run the sample batch through the model once and report the output shapes
# of the full forward pass (with `return_g=True`) and of `encode`.
y, z_g = model(sample_x, return_g=True)
print(y.shape, z_g.shape, sep='\n')

z = model.encode(sample_x)
print(z.shape)

model.summary()

In [None]:
# Directory that receives the channel list, the trained weights and the
# embeddings written by later cells.
outdir = '/home/ingn/tmp/micron2-data/single_simclr'

# Create the directory up front: open() below (and the later save calls)
# would otherwise fail on a fresh machine where it does not exist yet.
os.makedirs(outdir, exist_ok=True)

# Record the channels used, one name per line.
with open(f'{outdir}/use_channels.txt', 'w') as f:
    f.writelines(f'{c}\n' for c in use_channels)
# import os
# if os.path.exists(f'{outdir}/weights.h5'):
#     model.load_weights(f'{outdir}/weights.h5')

In [None]:
# Train the encoder with the project's `train_simCLR` routine on `dataset`.
# The returned value is kept as `loss_history` and plotted in the next cell.
# NOTE(review): the meaning of `batch_reps` is defined in micron2.clustering
# and is not visible here -- confirm before changing it.
loss_history = train_simCLR(dataset, model, batch_reps=1)

In [None]:
from matplotlib import pyplot as plt

# Plot the recorded loss values against their position in `loss_history`.
lh = len(loss_history)
steps = np.arange(lh)
plt.plot(steps, loss_history)
# Log-scaled x-axis variant (disabled):
# plt.plot(np.log1p(np.arange(lh)), loss_history)

In [None]:
# Persist the model weights to `weights.h5` inside `outdir`.
model.save_weights(f'{outdir}/weights.h5')

# Process slide

In [None]:
# Fraction of each spatial dimension kept by the central crop (48 out of 64).
crop_frac = 48 / 64

def process_crop(x):
    """
    Standardize an image tensor channel-wise, then take its central crop.

    The input is cast to float32, normalized with the module-level `means`
    and `stds` (one entry per channel, broadcast along the last axis), and
    finally cropped around the center to `crop_frac` of its spatial size.

    x was originally noted as [N, h, w, c].
    NOTE(review): this function is mapped before `.batch()`, so it may
    actually receive [h, w, c] elements -- confirm against stream_dataset.
    """
    standardized = (tf.cast(x, tf.float32) - means) / stds

    # Alternative normalizations kept for reference (disabled):
    # x = tf.cast(x, tf.float32)/255.
    # x = tf.transpose(tf.image.per_image_standardization(tf.transpose(x)))
    return tf.image.central_crop(standardized, crop_frac)

# Inference pipeline: a single unshuffled pass over the stream, with each
# element standardized and center-cropped by `process_crop`.
dataset = stream_dataset('/home/ingn/tmp/micron2-data/dataset.hdf5', use_channels=use_channels)
dataset = dataset.map(process_crop).batch(128).prefetch(8)

# Encode batch by batch in inference mode and stack the results along the
# first axis into one array.
z = np.concatenate(
    [model.encode(batch, training=False).numpy() for batch in tqdm.tqdm(dataset)],
    axis=0,
)
print(z.shape)

In [None]:
# Write the concatenated encoder outputs `z` to `z.npy` inside `outdir`.
np.save(f'{outdir}/z.npy', z)