## LACSS Demo
This is a simple demo of semi-supervised LACSS training. We will train a model using the neuroblastoma dataset from the Cell Image Library. 


In [None]:
!git clone https://github.com/jiyuuchc/lacss.git
!wget https://data.mendeley.com/api/datasets/894mmsd9nj/draft/files/568e524f-9a95-45a6-9f80-3619969c2a37

In [None]:
import sys
from os.path import join

import matplotlib.pyplot as plt
import matplotlib.patches
import cv2
from skimage.color import label2rgb
import numpy as np
import tensorflow as tf
layers = tf.keras.layers

sys.path.append('/content/lacss/')
import lacss

### Setting up the data pipeline

In [None]:
import zipfile

# Unpack the downloaded archive into the local `data` folder.
data_path = 'data'
with zipfile.ZipFile('cil.zip', "r") as archive:
    archive.extractall(data_path)

# Training pairs: <data_path>/train/000_img.png + 000_masks.png ... 088_img.png + 088_masks.png
train_stems = [join(data_path, 'train', f'{k:03d}') for k in range(89)]
imgfiles = [stem + '_img.png' for stem in train_stems]
maskfiles = [stem + '_masks.png' for stem in train_stems]

ds_train = lacss.data.dataset_from_img_mask_pairs(imgfiles, maskfiles)

In [None]:
# The original dataset contains full segmentation annotations, which we will not use.
# Therefore we set up a data parser that removes the extra annotations.

def train_parser(x):
    """Drop the full-mask annotations from one sample, then augment it for training.

    The 'mask_indices' and 'bboxes' entries are removed from ``x`` in place (a
    KeyError is raised if either is missing). The rest of the sample is handed to
    ``lacss.data.parse_train_data_func``, which applies simple data augmentation
    (e.g. flipping and resizing). Its result is returned.
    """
    for unused_key in ('mask_indices', 'bboxes'):
        x.pop(unused_key)

    # size_jitter=(0.9, 1.1): presumably the random rescaling range — see lacss.data
    return lacss.data.parse_train_data_func(x, size_jitter=(0.9, 1.1))

# Shuffle with a buffer larger than the 89 training images (i.e. a full shuffle).
# The shuffle sits upstream of .repeat(), so every pass over the small training set
# gets a fresh order. Without it, each epoch would visit the images in the same fixed sequence.
ds_train = ds_train.shuffle(buffer_size=128)

# Strip the mask annotations and apply data augmentation (see train_parser)
ds_train = ds_train.map(train_parser)

# just in case, we remove samples without any cells in them, then repeat indefinitely
ds_train = ds_train.filter(lambda x: tf.shape(x['locations'])[0] > 0).repeat()

# We use ragged batching, and set batch_size = 1 so it will run on any GPU
ds_train = ds_train.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=1))

# Prepare the next batch while the current training step runs
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)


### Model configuration and training

In [None]:
# Build the LACSS model with the 'resnet_att' backbone.
# NOTE(review): 'resnet_att' is presumably a ResNet variant with attention — confirm in lacss.models.
# train_supervised=False presumably selects the semi-supervised mode, consistent with the
# data parser above that strips the mask annotations — confirm.
# The last loss_weight is set to 0, which disables the auxnet.
model = lacss.models.LacssModel(
    backbone='resnet_att', 
    train_supervised=False,
    # presumably the side length (pixels) of the per-instance output patches — confirm
    instance_crop_size=128,
    loss_weights=(1.0, 1.0, 1.0, 0.0),
    )

# Use Adam with its default (Keras) hyperparameters
optimizer = tf.keras.optimizers.Adam()
model.compile(optimizer=optimizer)

# Train for 15 x 1000 = 15,000 steps. ds_train repeats forever, so steps_per_epoch
# (not the dataset size) is what defines an "epoch" here.
model.fit(ds_train, epochs=15, steps_per_epoch=1000)

### Display the results on the validation set 

In [None]:
# Load the validation dataset:
# <data_path>/test/000_img.png + 000_masks.png ... 010_img.png + 010_masks.png
val_stems = [join(data_path, 'test', f'{k:03d}') for k in range(11)]
imgfiles = [stem + '_img.png' for stem in val_stems]
maskfiles = [stem + '_masks.png' for stem in val_stems]

ds_val = lacss.data.dataset_from_img_mask_pairs(imgfiles, maskfiles)

In [None]:
# Let's check the 3rd image (index 2) in the validation set.
# next(iter(...)) raises StopIteration if the set is shorter, instead of leaving `x` unset or stale.
x = next(iter(ds_val.skip(2).take(1)))
x = lacss.data.parse_test_data_func(x) # this pads the input image so that ResNet won't complain
y = model(x)

# Adjust the contrast of the input image for display, then clip to [0, 1] before the
# uint8 conversion. NumPy's float->uint8 cast does not saturate, so out-of-range
# intensities would otherwise wrap around instead of showing as black/white.
input_img = (x['image'] + 0.5) / 6
input_img = (np.clip(input_img.numpy(), 0.0, 1.0) * 255).astype('uint8')

# Ground-truth label image: every pixel listed in row i of the ragged `mask_indices`
# gets label i + 1, and all other pixels stay 0 (background).
gt_label = tf.scatter_nd(x['mask_indices'].values, x['mask_indices'].value_rowids() + 1, x['image'].shape[:2])
# We will use the RGB label format to display the ground truth. label2rgb returns floats
# in [0, 1], so convert to uint8. Contour colours drawn on this image later (e.g. (64, 64, 64))
# are then on the same 0-255 scale as input_img, instead of being clipped to white by imshow.
gt_label_rgb = (label2rgb(gt_label.numpy(), bg_label=0) * 255).astype('uint8')

In [None]:
# Overlay contours of the model predictions on both the input image and the
# ground-truth label image (both arrays are drawn on in place).
# NOTE(review): only the first 110 predicted instances are used. The reason for this cap
# is not shown here — confirm (e.g. against the number of valid detections).
max_instances = 110
coords = y['instance_coords'][0][:max_instances, ...]
patches = y['instance_output'][0][:max_instances, :, :, 0]
n_patches, patch_size, _ = patches.shape

# Prefix every pixel coordinate with its instance index, so that each instance is
# scattered onto its own page of an (n_patches, H, W) stack.
page_index = tf.broadcast_to(
    tf.range(n_patches)[:, None, None, None],
    [n_patches, patch_size, patch_size, 1],
)
coords_ext = tf.concat([page_index, coords], axis=-1)
stack_shape = [n_patches, *x['image'].shape[:2].as_list()]
instance_masks = tf.scatter_nd(coords_ext, patches, stack_shape)
# Binarize each page at 0.5; cv2.findContours expects an 8-bit single-channel image
instance_masks = (instance_masks.numpy() >= 0.5).astype('uint8')

for mask in instance_masks:
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # gray outline on the input image, darker outline on the label image
    for canvas, shade in ((input_img, 128), (gt_label_rgb, 64)):
        cv2.drawContours(canvas, contours, -1, (shade, shade, shade), 1, cv2.LINE_AA)

# Show the annotated input image and ground-truth labels side by side
fig, ax = plt.subplots(1, 2, figsize=(15, 10))
for panel, picture in zip(ax, (input_img, gt_label_rgb)):
    panel.imshow(picture)
    panel.axis('off')
fig.tight_layout()