# LACSS Point-supervised Training Demo

This demo trains a model to segment cells in microscopy images using only point labels.

* The point labels were produced automatically from DAPI images.

We will go through these steps:

- Set up the data pipeline

- Initialize a model trainer

- Perform model training

- Visualize the results

## Setting up the environment

In [None]:
!pip install git+https://github.com/jiyuuchc/lacss@lacss1

import imageio.v2 as imageio
import json
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from skimage.color import label2rgb
from tqdm import tqdm
from pathlib import Path
from flax.core.frozen_dict import freeze, unfreeze

import lacss.data
from lacss.train import LacssTrainer, VMapped, TFDatasetAdapter
from lacss.ops import patches_to_label
from lacss.utils import show_images, load_from_pretrained
from lacss.deploy import model_urls

## Data pipeline

Lacss expects training data from a Python generator that yields dictionaries of the following form, where `B` is the batch size and `N` is the number of point labels per image:

```
{
  "image": ndarray[B, H, W, C],
  "gt_locations": ndarray[B, N, 2]
}
```
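To make this format concrete, the sketch below builds batches with the expected keys and shapes from synthetic NumPy data. It assumes that each location is a (row, col) pair, as suggested by the plotting code later in this notebook, and it pads every image's locations to a fixed 768 rows filled with -1, mirroring the `padded_batch` settings used below. `pad_locations` and `toy_generator` are hypothetical helpers for illustration only, not part of lacss.

```python
import numpy as np

MAX_POINTS = 768  # same as padded_shapes["gt_locations"] in this notebook
PAD_VALUE = -1.0  # value used to fill unused location slots


def pad_locations(locs, max_points=MAX_POINTS, pad_value=PAD_VALUE):
    """Pad an [N, 2] array of point locations to a fixed [max_points, 2] array."""
    locs = np.asarray(locs, dtype=np.float32).reshape(-1, 2)
    if len(locs) > max_points:
        raise ValueError(f"{len(locs)} points exceed max_points={max_points}")
    out = np.full((max_points, 2), pad_value, dtype=np.float32)
    out[: len(locs)] = locs
    return out


def toy_generator(n_batches, batch_size=1, image_size=512, seed=0):
    """Yield dicts with "image" [B, H, W, C] and "gt_locations" [B, N, 2] from random data."""
    rng = np.random.default_rng(seed)
    for _ in range(n_batches):
        images, locations = [], []
        for _ in range(batch_size):
            n_cells = int(rng.integers(1, 50))
            images.append(
                rng.standard_normal((image_size, image_size, 1)).astype(np.float32)
            )
            points = rng.uniform(0, image_size, size=(n_cells, 2))  # assumed (row, col)
            locations.append(pad_locations(points))
        yield {
            "image": np.stack(images),
            "gt_locations": np.stack(locations),
        }
```

Real training data should of course come from annotated images; the random arrays here only demonstrate the layout.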

Here we set up the data pipeline with TensorFlow's `tf.data` API, which provides convenient utilities for mapping, batching, and prefetching data.

In [None]:
# Download the dataset
!wget -c https://data.mendeley.com/public-files/datasets/89s3ymz5wn/files/f976856c-08c5-4bba-85a7-3881e0593115/file_downloaded -O A431.zip

import zipfile
from matplotlib.patches import Circle

data_path = Path('image_data')
with zipfile.ZipFile('A431.zip', "r") as f:
    f.extractall(data_path)

In [None]:
batch_size = 1

def parser(data):
    # standardize intensities, then apply lacss's built-in augmentations
    data["image"] = tf.image.per_image_standardization(data["image"])
    data = lacss.data.random_resize(data, scaling=[.8, 1.2])
    data = lacss.data.random_crop_or_pad(data, target_size=[512,512])

    return dict(
      image = data['image'],
      gt_locations = data["centroids"],
    )

# create a TensorFlow dataset from the files on disk
ds = (
    lacss.data.dataset_from_simple_annotations(
        data_path/"train.json",
        data_path/"train",
        image_shape=[None, None, 1]
    )
    .map(parser)
    .repeat()
    .padded_batch(
        batch_size,
        padded_shapes=dict(
            image=[512,512,1],
            gt_locations=[768,2],
        ),
        padding_values=-1.0,  # fill unused location slots with -1
    )
    .prefetch(1)
)

# make sure the dataset has the correct element structure
ds.element_spec
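The first step of `parser` is `tf.image.per_image_standardization`, which TensorFlow documents as computing `(x - mean) / max(stddev, 1/sqrt(N))`, with `N` the number of elements in the image. The NumPy sketch below restates that formula, which can be handy when preparing images outside the `tf.data` pipeline (for example at inference time). The function name is our own, and the sketch does not reproduce the random resizing and cropping done by the lacss augmentation functions.

```python
import numpy as np


def per_image_standardization(image):
    """NumPy restatement of the formula documented for tf.image.per_image_standardization.

    Computes (x - mean) / max(std, 1 / sqrt(N)), where N is the number of
    elements in the image; the lower bound avoids dividing by zero for
    uniform images.
    """
    x = np.asarray(image, dtype=np.float32)
    adjusted_std = max(float(x.std()), 1.0 / np.sqrt(x.size))
    return (x - x.mean()) / adjusted_std
```

For any image that is not nearly uniform, this gives the same result as subtracting the mean and dividing by the standard deviation.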

In [None]:
# show an example of the training data

data = next(iter(ds))
img = data['image'][0]
locations = data['gt_locations'][0]

show_images([
    img,
    np.zeros_like(img),
])
ax = plt.gcf().get_axes()
ax[0].set_title("Image")
for pos in locations:
    if pos[0] < 0:  # skip the -1 padding added by padded_batch
        continue
    c = Circle((pos[1], pos[0]), radius=2, edgecolor='white')
    ax[1].add_patch(c)
ax[1].set_title("Label")
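The introduction notes that the point labels were produced automatically from DAPI images, but the exact procedure is not described here. One plausible approach, sketched below, takes the centroid of each object in a nuclear label image (for instance, one obtained by segmenting the DAPI channel) and reports it in (row, col) order, the order the plotting loop above assumes. `label_centroids` is a hypothetical helper, and producing the nuclear label image itself is outside this sketch.

```python
import numpy as np


def label_centroids(labels):
    """Return an [N, 2] array of (row, col) centroids, one per non-zero label.

    `labels` is assumed to be a 2-D integer label image with 0 as background.
    """
    labels = np.asarray(labels)
    rows, cols = np.indices(labels.shape)
    flat = labels.ravel()
    mask = flat != 0  # 0 is treated as background
    ids, inverse = np.unique(flat[mask], return_inverse=True)
    counts = np.bincount(inverse, minlength=len(ids))
    # Average the pixel coordinates belonging to each label.
    mean_row = np.bincount(inverse, weights=rows.ravel()[mask], minlength=len(ids)) / counts
    mean_col = np.bincount(inverse, weights=cols.ravel()[mask], minlength=len(ids)) / counts
    return np.stack([mean_row, mean_col], axis=1)
```

Centroids are returned in ascending label order; an image with no labeled objects yields an empty `[0, 2]` array.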

## Initialize a trainer

We will use transfer learning by starting from a pre-trained model. Transfer learning is generally beneficial even if the original model was trained on data that looks very different from the current images.

The main training interface is `LacssTrainer`, which takes separate configurations for a principal model and a collaborator model.


In [None]:
# Load a pretrained model
# This model was trained on bright field microscopy images (LIVECell dataset)
# It will serve as the principal model and be retrained on the new data
!wget -c {model_urls["cnsp4-bf"]} -O cnsp4_bf
pretrained_module, pretrained_params = load_from_pretrained("cnsp4_bf")
principal_cfg = pretrained_module.get_config()

collaborator_cfg = {} # use default config for collaborator model

trainer = LacssTrainer(
    principal_cfg,
    collaborator_cfg,
    strategy=VMapped,
)

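# initialize random weights for all models
params = trainer.get_init_params(TFDatasetAdapter(ds))

# replace the principal model's random weights with the pre-trained weights
# (unfreeze returns a mutable dict copy, in case params is a FrozenDict)
params = unfreeze(params)
params['principal'] = pretrained_params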
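Assigning `pretrained_params` to `params['principal']` only makes sense if the pre-trained weights have the same structure and shapes as the freshly initialized principal weights, which should hold here because the principal model is built from the pre-trained model's own config. A minimal sketch of such a check follows, assuming both parameter trees are nested mappings with array leaves (a flax `FrozenDict` is also a `Mapping`). `tree_mismatches` is a hypothetical helper, not a lacss or flax function.

```python
from collections.abc import Mapping

import numpy as np


def tree_mismatches(reference, candidate, path=""):
    """Return human-readable differences between two nested parameter trees.

    Both trees are assumed to be nested Mappings whose leaves are arrays.
    An empty list means the structures and leaf shapes agree.
    """
    problems = []
    if isinstance(reference, Mapping) and isinstance(candidate, Mapping):
        for key in sorted(set(reference) | set(candidate), key=str):
            sub = f"{path}/{key}"
            if key not in candidate:
                problems.append(f"missing in candidate: {sub}")
            elif key not in reference:
                problems.append(f"unexpected in candidate: {sub}")
            else:
                problems.extend(tree_mismatches(reference[key], candidate[key], sub))
    elif isinstance(reference, Mapping) or isinstance(candidate, Mapping):
        problems.append(f"structure differs at {path or '/'}")
    else:
        ref_shape, cand_shape = np.shape(reference), np.shape(candidate)
        if ref_shape != cand_shape:
            problems.append(f"shape differs at {path}: {ref_shape} vs {cand_shape}")
    return problems
```

For example, calling `tree_mismatches(params['principal'], pretrained_params)` before the assignment should return an empty list; any entries point to the layers that differ.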

In [None]:
from pprint import pprint

print("---Current model configuration---")
pprint(
    trainer.model.principal.get_config(),
    sort_dicts=False
)

## Training

In [None]:
n_steps = 12000
validation_interval = 3000

trainer.do_training(
    TFDatasetAdapter(ds),
    n_steps = n_steps,
    validation_interval = validation_interval,
    init_vars = dict(params=params),
)

## Visualize the model prediction

In [None]:
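# load a test image and its ground-truth masks
image = imageio.imread(data_path/'test'/'img_0001.tif')
gt = imageio.imread(data_path/'test'/'masks_0001.tif')

# standardize to zero mean and unit variance, as in the training pipeline
img = image.astype("float32")
img = img - img.mean()
img /= img.std()

# add a channel axis: [H, W] -> [H, W, 1]
img = img[..., None]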

# prediction
model_output = trainer.model.apply(
    dict(params=trainer.parameters),
    image = img,
)
pred = patches_to_label(
    model_output, 
    input_size=img.shape[:2]
)
pred = np.asarray(pred)

# display
show_images([
    img,
    pred,
    label2rgb(gt, bg_label=0),
])
titles = ["Input", "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)

## What's next?

- You can train for more steps
- You can perform quantitative evaluation
- You can incorporate validation and checkpointing into the training loop
- You can export the trained model

Check the [documentation](https://jiyuuchc.github.io/lacss/api/deploy/) for details.
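For the quantitative evaluation mentioned above, a common starting point is instance-level precision, recall, and F1 at an IoU threshold of 0.5. The sketch below computes these in NumPy for two integer label images with 0 as background. `iou_matrix` and `instance_f1` are hypothetical helpers written for illustration, not lacss functions, and lacss may provide its own evaluation tools.

```python
import numpy as np


def _relabel(labels):
    """Map arbitrary non-negative labels to 0..K, keeping 0 as background."""
    flat = np.asarray(labels).ravel()
    values = np.union1d([0], flat)  # sorted unique labels, always includes 0
    return np.searchsorted(values, flat), len(values) - 1


def iou_matrix(gt, pred):
    """IoU for every (ground-truth instance, predicted instance) pair."""
    if np.shape(gt) != np.shape(pred):
        raise ValueError("gt and pred must have the same shape")
    g, n_gt = _relabel(gt)
    p, n_pred = _relabel(pred)
    # Pixel counts for every (gt label, pred label) pair, background included.
    overlap = np.zeros((n_gt + 1, n_pred + 1), dtype=np.int64)
    np.add.at(overlap, (g, p), 1)
    area_gt = overlap.sum(axis=1, keepdims=True)
    area_pred = overlap.sum(axis=0, keepdims=True)
    union = area_gt + area_pred - overlap
    iou = overlap / np.maximum(union, 1)
    return iou[1:, 1:]  # drop the background row and column


def instance_f1(gt, pred, threshold=0.5):
    """Instance-level precision, recall and F1, counting a match when IoU > threshold."""
    # Instances within a label image do not overlap, so with threshold >= 0.5
    # each instance can exceed the threshold with at most one partner, and
    # matches can be counted directly without an assignment step.
    if threshold < 0.5:
        raise ValueError("this counting shortcut requires threshold >= 0.5")
    iou = iou_matrix(gt, pred)
    n_gt, n_pred = iou.shape
    tp = int((iou > threshold).sum())
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gt if n_gt else 0.0
    f1 = 2 * tp / (n_gt + n_pred) if (n_gt + n_pred) else 0.0
    return precision, recall, f1
```

Calling `instance_f1(gt, pred)` with the arrays from the visualization cell would score the single test image, assuming both are integer label images of the same shape. Lower thresholds would need a proper one-to-one assignment (for example Hungarian matching) instead of simple counting.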