In [1]:
%load_ext autoreload
%autoreload 2

# Functions & Imports

In [6]:
import sys
import os
sys.path.append(os.path.join(os.getcwd(), os.pardir))

import numpy as np
import pandas as pd
import tensorflow as tf
from Tools.leica_tools import RawLoader
from Tools.db_tools import DbManager
from functools import partial
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

import keras
from keras.api.optimizers import Adam
from keras.api.models import Model, load_model
from keras.api.layers import Input, Conv2D, Flatten, Dense, MaxPooling2D, Dropout
from keras.api.losses import CategoricalCrossentropy
from keras.api.metrics import CategoricalAccuracy
from keras.api.utils import plot_model

# Cell count classification

In [3]:
def cell_count(inputs, cls_label):
    conv1 = Conv2D(32, (3, 3), activation='relu', name=cls_label + '_conv1')(inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2), name=cls_label + '_pool1')(conv1)

    conv2 = Conv2D(64, (3, 3), activation='relu', name=cls_label + '_conv2')(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2), name=cls_label + '_pool2')(conv2)

    conv3 = Conv2D(128, (3, 3), activation='relu', name=cls_label + '_conv3')(pool2)
    pool3 = MaxPooling2D(pool_size=(2, 2), name=cls_label + '_pool3')(conv3)

    flatten = Flatten(name=cls_label + '_flatten')(pool3)

    dense1 = Dense(512, activation='relu', name=cls_label + '_dense1')(flatten)
    dropout1 = Dropout(0.5, name=cls_label + '_dropout1')(dense1)

    dense2 = Dense(256, activation='relu', name=cls_label + '_dense2')(dropout1)
    dropout2 = Dropout(0.5, name=cls_label + '_dropout2')(dense2)

    dense3 = Dense(128, activation='relu', name=cls_label + '_dense3')(dropout2)
    dropout3 = Dropout(0.5, name=cls_label + '_dropout3')(dense3)

    output = Dense(5, activation='softmax', name=cls_label + '_output')(dropout3)

    model = Model(inputs=inputs, outputs=output, name=cls_label + '_model')

    return model

# Depending on how many labels need to be predicted from the droplets, several CNNs will be added to the model in parallel
def get_model(labels):
    input = Input(shape=(128,128,4), name='cell_count_input')
    models = [cell_count(input, label) for label in labels]
    model = Model(inputs=[input], outputs=[m.output for m in models])
    model.compile(optimizer=Adam(),
                  loss={label+'_output': CategoricalCrossentropy() for label in labels},
                  metrics={label+'_output': CategoricalAccuracy() for label in labels})
    return model

Building the tensorflow dataset from .tfrecord files. Importantly, not all frames from a droplet database can be used. 
Only droplets that are fully annotated and were not excluded as outliers should be added to teh training dataset.

In [13]:
#Build dataset from annotated data from multiple experiments
def build_dataset(expIDs, annotation_keys):
    dataset = dbm.get_datasets(expIDs, shuffle=True)
    
    # Scan through the droplets which droplets contain annotations and which droplets need to be excluded
    filter_dfs = []
    ann_dfs = []
    for expID in expIDs:
        drop_register = RawLoader(expID).get_droplet_df()
        ann_df = dbm.get_wps(expID, filter_annotations='full').set_index('GlobalID').filter(annotation_keys)
        ann_df = ann_df[ann_df.apply(lambda row: (row != 10).all(), axis=1)].copy()
        ann_df[ann_df > 4] = 4


        filter_df = pd.Series({ID: (ID in ann_df.index) for ID in drop_register.index}, name='include').to_frame()
        filter_df.set_index(pd.MultiIndex.from_product([[expID], filter_df.index]), inplace=True)
        filter_dfs.append(filter_df)

        ann_df.set_index(pd.MultiIndex.from_product([[expID], ann_df.index]), inplace=True)
        ann_dfs.append(ann_df)
    filter_df = pd.concat(filter_dfs)
    ann_df = pd.concat(ann_dfs)

    filtered_dataset = dataset.filter(partial(filter_dataset, filter_df=filter_df))
    annotated_dataset = filtered_dataset.map(partial(prepare_data, annotations=ann_df))
    return annotated_dataset

def filter_dataset(element, filter_df):
    return tf.py_function(lambda x, i: filter_df.loc[(x.numpy().decode(), i.numpy()), 'include'], [element['expID'], element['GlobalID']], tf.bool)
    
def prepare_data(element, annotations):
    globalID = element['GlobalID']
    image = tf.cast(element['frame'], tf.float32)
    image = tf.math.log(image+1)
    image = (image - tf.reduce_min(image)) /(tf.reduce_max(image) - tf.reduce_min(image))
    element['cell_count_input'] = image
    
    outputs = {}
    for ann_key in annotations.columns:
        label = tf.py_function(lambda x, i: annotations.loc[(x.numpy().decode(), i.numpy()), ann_key], [element['expID'], globalID], tf.int64)
        label.set_shape(())
        label = tf.cast(tf.one_hot(label, 5), tf.int64)
        outputs[ann_key + '_output'] = label
    return element, outputs
    

    
    

In [14]:
dbm = DbManager()

In [15]:
expIDs = ['NKIP_FA_066', ]
annotation_keys = ['Target', 'Effector', 'dead_Target', 'dead_Effector']
dataset = build_dataset(expIDs, annotation_keys)
validation_dataset = build_dataset(['NKIP_FA_063'], annotation_keys)

In [16]:
n_elements = dataset.reduce(tf.constant(0), lambda a,b: a+1).numpy()
n_elements_val = validation_dataset.reduce(tf.constant(0), lambda a,b: a+1).numpy()
print(f'{n_elements} frames in train dataset')
print(f'{n_elements_val} frames in test dataset')

499 frames in train dataset
177 frames in test dataset


In [17]:
train_final = dataset.repeat(12).batch(32)
test_final = validation_dataset.repeat(12).batch(32)

In [22]:
model = keras.models.load_model(os.path.join(os.getenv('MODEL_DIR'), 'cell_count', 'cell_count_v2.h5'))
labels = annotation_keys
model.compile(optimizer=Adam(),
              loss={label+'_output': CategoricalCrossentropy() for label in labels},
              metrics={label+'_output': CategoricalAccuracy() for label in labels})



In [19]:
#model_arch = plot_model(model, to_file='cell_count.png', dpi=100)

In [23]:
model.fit(train_final, validation_data=test_final, batch_size=32, steps_per_epoch=187, epochs=12, validation_steps=66)

Epoch 1/12
[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m319s[0m 2s/step - Effector_output_categorical_accuracy: 0.8519 - Target_output_categorical_accuracy: 0.8583 - dead_Effector_output_categorical_accuracy: 0.9531 - dead_Target_output_categorical_accuracy: 0.9442 - loss: 1.0853 - val_Effector_output_categorical_accuracy: 0.6960 - val_Target_output_categorical_accuracy: 0.7794 - val_dead_Effector_output_categorical_accuracy: 0.9943 - val_dead_Target_output_categorical_accuracy: 0.9091 - val_loss: 3.4988
Epoch 2/12
[1m  1/187[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 112ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 0.7500 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.5247

2024-09-12 13:40:59.893223: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
  self.gen.throw(value)


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 0.7500 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.5247 - val_Effector_output_categorical_accuracy: 0.5000 - val_Target_output_categorical_accuracy: 0.9167 - val_dead_Effector_output_categorical_accuracy: 1.0000 - val_dead_Target_output_categorical_accuracy: 1.0000 - val_loss: 2.6423
Epoch 3/12


2024-09-12 13:41:00.226669: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m322s[0m 2s/step - Effector_output_categorical_accuracy: 0.9526 - Target_output_categorical_accuracy: 0.9519 - dead_Effector_output_categorical_accuracy: 0.9863 - dead_Target_output_categorical_accuracy: 0.9794 - loss: 0.3445 - val_Effector_output_categorical_accuracy: 0.6955 - val_Target_output_categorical_accuracy: 0.8082 - val_dead_Effector_output_categorical_accuracy: 0.9943 - val_dead_Target_output_categorical_accuracy: 0.9091 - val_loss: 5.2684
Epoch 4/12
[1m  1/187[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m21s[0m 113ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0577

2024-09-12 13:46:22.218094: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0577 - val_Effector_output_categorical_accuracy: 0.5833 - val_Target_output_categorical_accuracy: 0.7500 - val_dead_Effector_output_categorical_accuracy: 1.0000 - val_dead_Target_output_categorical_accuracy: 1.0000 - val_loss: 5.4974
Epoch 5/12


2024-09-12 13:46:22.544702: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m325s[0m 2s/step - Effector_output_categorical_accuracy: 0.9815 - Target_output_categorical_accuracy: 0.9813 - dead_Effector_output_categorical_accuracy: 0.9980 - dead_Target_output_categorical_accuracy: 0.9912 - loss: 0.1395 - val_Effector_output_categorical_accuracy: 0.7576 - val_Target_output_categorical_accuracy: 0.8021 - val_dead_Effector_output_categorical_accuracy: 0.9943 - val_dead_Target_output_categorical_accuracy: 0.9096 - val_loss: 5.4244
Epoch 6/12
[1m  1/187[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 113ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0949

2024-09-12 13:51:47.706149: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0949 - val_Effector_output_categorical_accuracy: 0.6667 - val_Target_output_categorical_accuracy: 0.8333 - val_dead_Effector_output_categorical_accuracy: 1.0000 - val_dead_Target_output_categorical_accuracy: 0.9167 - val_loss: 7.0302
Epoch 7/12


2024-09-12 13:51:48.028166: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m314s[0m 2s/step - Effector_output_categorical_accuracy: 0.9819 - Target_output_categorical_accuracy: 0.9767 - dead_Effector_output_categorical_accuracy: 0.9906 - dead_Target_output_categorical_accuracy: 0.9837 - loss: 0.2185 - val_Effector_output_categorical_accuracy: 0.7576 - val_Target_output_categorical_accuracy: 0.7623 - val_dead_Effector_output_categorical_accuracy: 0.9943 - val_dead_Target_output_categorical_accuracy: 0.9091 - val_loss: 5.4523
Epoch 8/12
[1m  1/187[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m19s[0m 105ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0027

2024-09-12 13:57:02.095415: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0027 - val_Effector_output_categorical_accuracy: 0.6667 - val_Target_output_categorical_accuracy: 0.8333 - val_dead_Effector_output_categorical_accuracy: 1.0000 - val_dead_Target_output_categorical_accuracy: 0.9167 - val_loss: 6.3984
Epoch 9/12


2024-09-12 13:57:02.399645: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m313s[0m 2s/step - Effector_output_categorical_accuracy: 0.9760 - Target_output_categorical_accuracy: 0.9880 - dead_Effector_output_categorical_accuracy: 0.9948 - dead_Target_output_categorical_accuracy: 0.9916 - loss: 0.1573 - val_Effector_output_categorical_accuracy: 0.5824 - val_Target_output_categorical_accuracy: 0.8419 - val_dead_Effector_output_categorical_accuracy: 0.9943 - val_dead_Target_output_categorical_accuracy: 0.9266 - val_loss: 9.1942
Epoch 10/12
[1m  1/187[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 112ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0040

2024-09-12 14:02:15.666205: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0040 - val_Effector_output_categorical_accuracy: 0.5000 - val_Target_output_categorical_accuracy: 0.7500 - val_dead_Effector_output_categorical_accuracy: 1.0000 - val_dead_Target_output_categorical_accuracy: 0.9167 - val_loss: 10.0517
Epoch 11/12


2024-09-12 14:02:16.029792: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m325s[0m 2s/step - Effector_output_categorical_accuracy: 0.9911 - Target_output_categorical_accuracy: 0.9825 - dead_Effector_output_categorical_accuracy: 0.9993 - dead_Target_output_categorical_accuracy: 0.9963 - loss: 0.1189 - val_Effector_output_categorical_accuracy: 0.5710 - val_Target_output_categorical_accuracy: 0.8310 - val_dead_Effector_output_categorical_accuracy: 0.9943 - val_dead_Target_output_categorical_accuracy: 0.8977 - val_loss: 9.0282
Epoch 12/12
[1m  1/187[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 110ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0424

2024-09-12 14:07:41.255438: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - Effector_output_categorical_accuracy: 1.0000 - Target_output_categorical_accuracy: 1.0000 - dead_Effector_output_categorical_accuracy: 1.0000 - dead_Target_output_categorical_accuracy: 1.0000 - loss: 0.0424 - val_Effector_output_categorical_accuracy: 0.5000 - val_Target_output_categorical_accuracy: 0.7500 - val_dead_Effector_output_categorical_accuracy: 1.0000 - val_dead_Target_output_categorical_accuracy: 1.0000 - val_loss: 8.5062


2024-09-12 14:07:41.645874: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]


<keras.src.callbacks.history.History at 0x36bc93aa0>

In [None]:
model.save(os.path.join(os.getenv('MODEL_DIR'),'cell_count', 'cell_count_v3.h5'))

# Evaluate performance

In [None]:
dbm = DbManager()
expIDs = [ 'NKIP_FA_052', 'NKIP_FA_055', 'NKIP_FA_056', 'FA_2024_049', 'FA_2024_050', 'AT_2024_007',]
annotation_keys = ['Target', 'Effector', 'dead_Target', 'dead_Effector']
dataset = build_dataset(expIDs, annotation_keys)

In [None]:
model = load_model(os.path.join(os.getenv('MODEL_DIR'), 'cell_count_v2.h5'))
predictions = model.predict(dataset.batch(32))
y_pred = pd.DataFrame(np.argmax(np.array(predictions), axis=-1).transpose(), columns=annotation_keys)

y_true = y_pred.copy()
for i, (element, output) in enumerate(dataset.as_numpy_iterator()):
    y_true.iloc[i, :] = np.array([np.argmax(output[f'{key}_output']) for key in annotation_keys])

In [None]:
ann2title = {'Target': 'Target',
             'Effector': 'Effector',
             'dead_Target': 'Dead Target',
             'dead_Effector': 'Dead Effector'}

fig,axs = plt.subplots(ncols=2, nrows=2, figsize=(4,4), dpi=100, sharey=True, sharex=True)
for ax, ann in zip(axs.flatten(), annotation_keys):
    ConfusionMatrixDisplay.from_predictions(y_true=y_true[ann], y_pred=y_pred[ann], labels=np.arange(5),cmap='Blues',ax=ax,colorbar=False, normalize='true', values_format='.2f', text_kw={'fontsize': 7})
    ax.grid(False)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(ann2title[ann])

axs.flatten()[0].set_ylabel('True number of cells')
axs.flatten()[2].set_ylabel('True number of cells')
axs.flatten()[2].set_xlabel('Predicted number of cells')
axs.flatten()[3].set_xlabel('Predicted number of cells')