In [None]:
%load_ext autoreload
%autoreload 2

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

import os
import sys

from micron2.clustering import Encoder, Classifier

import tensorflow as tf

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU; on a machine without GPUs the loop is a no-op
# (indexing [0] unconditionally would raise IndexError there).
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

In [None]:
# TODO: change this cell to load from multiple samples and concatenate them.
# Make sure channels are placed consistently, in case they are permuted between collections.
import glob
srch = '/storage/codex/preprocessed_data/pembro_TLS_panel/210702_PembroRT_Cas19_TLSpanel_reg2/collection*'

image_files = sorted(glob.glob(f'{srch}/training_cells_images.npy'))
annot_files = sorted(glob.glob(f'{srch}/training_cells_annots.npy'))
channel_files = sorted(glob.glob(f'{srch}/training_cells_channels.npy'))

# Fail with a readable message instead of a bare IndexError on [0] below.
if not (image_files and annot_files and channel_files):
    raise FileNotFoundError(f'No training_cells_*.npy files found under {srch}')

# A bare expression in the middle of a cell is not displayed, so print explicitly.
print(image_files)

# Only the first matching collection is loaded for now (see TODO above).
images = np.load(image_files[0])
annots = np.load(annot_files[0])
channel_names = np.load(channel_files[0])

# Images and annotations are indexed with the same permutation further down,
# so they must describe the same number of cells.
if images.shape[0] != annots.shape[0]:
    raise ValueError(
        f'images ({images.shape[0]}) and annots ({annots.shape[0]}) disagree on the number of cells'
    )

# Map channel name -> index along the last image axis.
channels = {name: idx for idx, name in enumerate(channel_names)}

print(images.shape, images.dtype)
print(annots.shape)
print(np.unique(annots, return_counts=True))
print(channels)

# Shuffle the cells; one permutation is applied to both arrays so they stay paired.
n_cells = images.shape[0]
perm = np.random.choice(n_cells, n_cells, replace=False)
images, annots = images[perm], annots[perm]

# One-hot encode the annotations (column j <-> u_annots[j]), then drop column 0,
# which leaves an all-zero row for every cell whose annotation is u_annots[0].
# NOTE(review): presumably u_annots[0] is the empty-string "unlabelled" annotation — confirm.
u_annots, annots_int = np.unique(annots, return_inverse=True)
labels = np.eye(len(u_annots))[annots_int][:, 1:]

print(labels.shape)
print('labelled cells:', np.count_nonzero(labels.sum(axis=1) > 0))

In [None]:
# Weights file that is later passed to classifier.load_weights().
# NOTE(review): absolute, user-specific path — adjust for your environment.
weights_path = '/home/ingn/devel/micron2/notebooks/devel_debug/moco-models/weights_cls.h5'

In [None]:
# Model input shape: 52 x 52 pixels, one channel per entry in `channels`.
data_shape = [52, 52, len(channels)]
print(data_shape)

# A single all-zero example, used only to call each model once before the weights are loaded.
# NOTE(review): presumably this forward pass is what creates the model variables — confirm.
x_dummy = tf.zeros([1, *data_shape], dtype=tf.float32)
print(x_dummy.shape)

encoder = Encoder(data_shape=data_shape, z_dim=256, encoder_type='EfficientNetB1')
z = encoder(x_dummy)

# One output per column of `labels`.
classifier = Classifier(encoder=encoder, n_classes=labels.shape[1], mlp_dim=128)
y = classifier(x_dummy)

print(z.shape, y.shape, sep='\n')

classifier.load_weights(weights_path)

In [None]:
def process(x, central_fraction=0.8125):
    """
    Convert an image tensor to float32, divide by 255 and keep its central region.

    x: image tensor, either a single image [h, w, c] or a batch [N, h, w, c]
       (tf.image.central_crop accepts both). When this function is mapped over
       an un-batched dataset, x is a single [h, w, c] image.
       NOTE(review): the /255 scaling presumably expects 8-bit intensities — confirm.
    central_fraction: fraction of the height and width to keep, in (0, 1].
       The default 0.8125 is the value that was previously hard-coded.

    Returns the scaled, centrally cropped float32 tensor.
    """
    x = tf.cast(x, tf.float32) / 255.
    return tf.image.central_crop(x, central_fraction)
    
# Inference dataset: a single pass over `images` (no shuffle, no repeat),
# preprocessed per image and then grouped into batches; the last batch may be smaller.
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(images)
dataset = dataset.map(process, num_parallel_calls=4)
dataset = dataset.batch(batch_size, drop_remainder=False)
dataset = dataset.prefetch(1)

# Sanity check: show the shape of the first batch only.
for batch in dataset.take(1):
    print(batch.shape)

In [None]:
import tqdm.auto as tqdm

# Run the classifier over every batch with training=False and collect the outputs as numpy arrays.
batch_preds = []
for b in tqdm.tqdm(dataset):
    y = classifier(b, training=False)
    batch_preds.append(y.numpy().copy())

# Stack the per-batch outputs along the first axis.
preds = np.concatenate(batch_preds, axis=0)
print(preds.shape)

In [None]:
# Predicted class per cell: index of the largest classifier output in each row.
y = preds.argmax(axis=1)

In [None]:
# Annotation name -> column index in `labels` / `preds`.
# The first unique annotation is skipped because its one-hot column was dropped.
label_map = {name: idx for idx, name in enumerate(u_annots[1:])}
label_map

In [None]:
# How many cells were assigned to each predicted class.
pred_classes, pred_counts = np.unique(y, return_counts=True)
pred_classes, pred_counts

In [None]:
# Accuracy over the cells that carry a label (rows of `labels` with a non-zero entry).
has_labels = labels.sum(axis=1) > 0
true_idx = labels.argmax(axis=1)
(y[has_labels] == true_idx[has_labels]).mean()

In [None]:
# Number of labelled cells per true class, for comparison with the predicted counts.
np.unique(labels[has_labels].argmax(axis=1), return_counts=True)

In [None]:
chs = list(channels.keys())

# Kept from the original cell; the grid below is no longer sized from it.
u_cts = np.unique(annots)

# Pick one random example cell for every class that received at least one prediction.
# Classes with no predictions are left out, so they neither take up a row nor
# swallow the column titles.
examples = []
for ct, class_idx in label_map.items():
    inds = np.nonzero(y == class_idx)[0]
    if len(inds) == 0:
        continue
    examples.append((ct, np.random.choice(inds)))

# One row per plotted class, one column per channel.
fig = plt.figure(figsize=(20,12))
n_rows = len(examples)
n_cols = len(chs)

for row, (ct, ind) in enumerate(examples):
    for j, ch in enumerate(chs):
        ax = fig.add_subplot(n_rows, n_cols, row * n_cols + j + 1)

        # Show the stored (un-cropped) image for this channel, scaled by 1/255.
        img = images[ind, :, :, channels[ch]] / 255.
        ax.matshow(img, vmin=0, vmax=1)
        ax.axis('off')

        # Channel names go on the first plotted row, whichever class that is.
        if row == 0:
            ax.set_title(ch)

        # Class name goes on the first column of each row.
        if j == 0:
            ax.annotate(ct, (0.05, 0.8), xycoords='axes fraction', color='w')