# LACSS Weakly-supervised Training Demo

This demo trains a model to segment microscopy images of cells using two kinds of weak labels: point labels and image-level mask labels.

 * The point label was produced automatically from DAPI images.

 * The image-level mask label was produced manually.

We will go through these steps:

- Set up the data pipeline

- Initialize a model trainer

- Perform model training

- Visualize the results

## Setting up the environment

In [None]:
!pip install git+https://github.com/jiyuuchc/lacss@lacss1

import imageio.v2 as imageio
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from skimage.color import label2rgb
from tqdm import tqdm
from pathlib import Path
from flax.core.frozen_dict import freeze, unfreeze

import lacss.data
from lacss.train import LacssTrainer, VMapped, TFDatasetAdapter
from lacss.ops import patches_to_label
from lacss.utils import show_images

## Data pipeline

Lacss expects training data from a Python generator that produces the following data:

```
x_data, y_data = (
  {
    "image": ndarray[B, W, H, C],
    "gt_locations": ndarray[B, N, 2]
  },
  {
    "gt_image_mask": ndarray[B, W, H]
  }
)
```
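
To make the expected structure concrete, here is a minimal sketch of a plain Python generator that yields random batches in this layout. The function name `dummy_batches` and its parameters are hypothetical and not part of Lacss. The sizes (512 pixels, 768 location slots) and the `-1` padding value are borrowed from the parser used later in this demo.

```python
import numpy as np

def dummy_batches(batch_size=1, size=512, max_cells=768, n_cells=5, seed=0):
    """Yield (x_data, y_data) pairs shaped like the structure described above.

    Hypothetical helper for illustration only; the values are random noise.
    """
    rng = np.random.default_rng(seed)
    while True:
        image = rng.random((batch_size, size, size, 1), dtype=np.float32)

        # Fixed-length point labels: unused rows are filled with -1.
        locations = np.full((batch_size, max_cells, 2), -1.0, dtype=np.float32)
        locations[:, :n_cells] = rng.uniform(0, size, (batch_size, n_cells, 2))

        # Image-level mask: one 2D array per image, no channel axis.
        mask = rng.random((batch_size, size, size)) > 0.5

        yield (
            dict(image=image, gt_locations=locations),
            dict(gt_image_mask=mask),
        )

x_data, y_data = next(dummy_batches())
print({k: v.shape for k, v in x_data.items()})
print({k: v.shape for k, v in y_data.items()})
```

A generator like this is only useful for checking shapes and dtypes; the real pipeline below reads images and annotations from disk.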

Here we set up the data pipeline with TensorFlow's `tf.data` library, which has many useful utilities.

In [None]:
# Download the dataset
!wget -c https://data.mendeley.com/public-files/datasets/89s3ymz5wn/files/f976856c-08c5-4bba-85a7-3881e0593115/file_downloaded -O A431.zip

import zipfile

data_path = Path('image_data')
with zipfile.ZipFile('A431.zip', "r") as f:
    f.extractall(data_path)

In [None]:
batch_size = 1

def parser(data):
    # built-in data augmentation functions
    data = lacss.data.random_resize(data, scaling=[0.8, 1.2])
    data = lacss.data.random_crop_or_pad(data, target_size=[512,512])

    # It is important to pad the locations tensor so that all elements of the dataset are of the same shape
    locations = data['centroids']
    n_pad = 768 - len(locations)
    locations = tf.pad(locations, [[0, n_pad], [0,0]], constant_values=-1)

    return (
        dict(
            image = tf.ensure_shape(data['image'], [512,512,1]),
            gt_locations = tf.ensure_shape(locations, [768,2]) 
        ),
        dict(
            gt_image_mask = data['image_mask'],
        ),
    )

# create a TensorFlow dataset from the files on disk
ds = (
    lacss.data.dataset_from_simple_annotations(
        data_path/"train.json",
        data_path/"train",
        image_shape=[512, 512, 1]
    )
    .map(parser)
    .repeat()
    .batch(batch_size)
    .prefetch(10)
)

# make sure the dataset has the correct element structure
ds.element_spec
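
The padding step in `parser` is what lets images with different cell counts be batched together. The sketch below reproduces the same idea in NumPy so it can be inspected outside the `tf.data` graph. The helper names `pad_locations` and `valid_locations` are hypothetical. The check for too many points is an added assumption: the parser above does not handle that case explicitly.

```python
import numpy as np

MAX_LOCATIONS = 768  # same fixed length as in parser()

def pad_locations(locations, max_len=MAX_LOCATIONS, pad_value=-1):
    """Pad an (N, 2) array of point labels to a fixed (max_len, 2) shape."""
    locations = np.asarray(locations, dtype=np.float32).reshape(-1, 2)
    n_pad = max_len - len(locations)
    if n_pad < 0:
        # Assumption: fail loudly instead of silently truncating.
        raise ValueError(f"{len(locations)} locations exceed max_len={max_len}")
    return np.pad(locations, [(0, n_pad), (0, 0)], constant_values=pad_value)

def valid_locations(padded):
    """Drop padding rows, assuming real coordinates are never negative."""
    return padded[(padded >= 0).all(axis=-1)]

points = [[10.0, 20.0], [30.0, 40.0]]
padded = pad_locations(points)
print(padded.shape)             # (768, 2)
print(valid_locations(padded))  # the two original points
```

The same filter can be applied to `x_data['gt_locations'][0]` before plotting, so that padding rows are not drawn as markers.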

In [None]:
# show an example of the training data

from matplotlib.patches import Circle

x_data, y_data = next(ds.as_numpy_iterator())
img = x_data['image'][0]
locations = x_data['gt_locations'][0]
mask = y_data['gt_image_mask'][0]

show_images([
    img,
    np.stack([mask]*3, axis=-1) * 0.5,
])

ax = plt.gcf().get_axes()
ax[0].set_title("Image")
for pos in locations:
    c = Circle((pos[1], pos[0]), radius=2, edgecolor='white')
    ax[1].add_patch(c)
ax[1].set_title("Label")

## Initialize a trainer

The idea is to co-train two models: a principal model and a collaborator model.

In [None]:
# configuration for the principal model
principal_cfg = {
    "backbone": {
        "drop_path_rate": 0.4
    },
    "lpn": {
        "conv_spec": ((256,256,256,256),()),
    },
    "segmentor": {
        "conv_spec": ((256,256,256),(64,)),
    },    
}
# use default setting for collaborator model
collaborator_cfg = {} 

trainer = LacssTrainer(
    principal_cfg, 
    collaborator_cfg, 
    seed=1234,
    strategy=VMapped,
)

In [None]:
from pprint import pprint

print("---Current model configuration---")
pprint(
    trainer.model.principal.get_config(),
)

## Training

`Trainer.train()` returns an iterator; stepping through it drives the training of the model. The cell below instead calls `do_training()`, passing the number of training steps and the validation interval.

In [None]:
n_steps = 9000
validation_interval = 3000

trainer.do_training(
  TFDatasetAdapter(ds),
  n_steps = n_steps,
  validation_interval = validation_interval,
)

## Visualize the model prediction

In [None]:
# get data
img = imageio.imread(data_path/'test'/'img_0001.tif')
gt = imageio.imread(data_path/'test'/'masks_0001.tif')
img = img[..., None]

# prediction
model_output = trainer.model.apply(
    dict(params=trainer.parameters),
    image = img,
)
pred = patches_to_label(model_output, input_size=img.shape[:2])
pred = np.asarray(pred)

show_images([
    img,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])
titles = ['Input', "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)
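
Alongside the visual comparison, a quick numeric check is to count the labeled instances in each label image. This is a minimal sketch, assuming that `pred` and `gt` are integer label images with background `0` (the same `bg_label` passed to `label2rgb` above). The helper `count_instances` is hypothetical and is shown here on a toy array.

```python
import numpy as np

def count_instances(label_image, bg_label=0):
    """Count distinct non-background labels in an integer label image."""
    ids = np.unique(label_image)
    return int(np.sum(ids != bg_label))

# Toy label image with three instances (labels 1, 2, 3) on background 0.
toy = np.array([
    [0, 0, 1, 1],
    [0, 2, 2, 1],
    [3, 3, 0, 0],
])
print(count_instances(toy))  # 3
```

Calling `count_instances(pred)` and `count_instances(gt)` gives a rough sense of over- or under-detection, but it does not measure how well the individual masks overlap.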