# LACSS Point-supervised Training Demo

This demo trains a model to segment microscopy images of cells using only point labels.

 * The point labels were produced automatically from DAPI images.

We will go through these steps:

- Set up the data pipeline

- Initialize a model trainer

- Perform model training

- Visualize the results

## Setting up the environment

In [None]:
!pip install "lacss[train] @ git+https://github.com/jiyuuchc/lacss"

In [None]:
from pathlib import Path

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np

from skimage.color import label2rgb

import lacss.data
from lacss.utils import show_images

## Data pipeline

Lacss expects training data from a Python generator that yields dictionaries of the following form:

```
{
  "image": ndarray[B, W, H, C],
  "gt_locations": ndarray[B, N, 2]
}
```
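
As a rough illustration of this format, the sketch below builds a generator that yields synthetic batches with the shapes listed above. The function name `toy_batches` and all sizes are made up for illustration and are not part of Lacss; the random values stand in for real images and point labels.

```python
import numpy as np

def toy_batches(batch_size=1, width=64, height=64, channels=1, n_points=8, seed=0):
    """Yield synthetic batches shaped like the dictionary above (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    while True:
        # random pixel values in [0, 1), shape [B, W, H, C]
        image = rng.random((batch_size, width, height, channels), dtype=np.float32)
        # one 2D coordinate per point, bounded by the first two image axes, shape [B, N, 2]
        gt_locations = rng.uniform(
            0, [width, height], size=(batch_size, n_points, 2)
        ).astype(np.float32)
        yield {"image": image, "gt_locations": gt_locations}

batch = next(toy_batches())
print(batch["image"].shape, batch["gt_locations"].shape)
```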

Here we set up the data pipeline with Python generators and the helper utilities in `lacss.data`.

In [None]:
# Download the dataset
!wget -c https://data.mendeley.com/public-files/datasets/89s3ymz5wn/files/f976856c-08c5-4bba-85a7-3881e0593115/file_downloaded -O A431.zip

import zipfile
from matplotlib.patches import Circle

data_path = Path('image_data')
with zipfile.ZipFile('A431.zip', "r") as f:
    f.extractall(data_path)

In [None]:
from lacss.data import simple_generator
from lacss.data.utils import gf_batch, gf_cycle, image_standardization
import lacss.data.augment_ as augment

BATCHSIZE = 1

@gf_batch(batch_size=BATCHSIZE)
@gf_cycle
def dataset_generator():
    for data in simple_generator(data_path/"train.json", data_path/"train"):

        # simple augmentation
        data = augment.flip_left_right(data, p = 0.5)
        data = augment.flip_up_down(data, p = 0.5)

        # It is important to pad the locations tensor so that all elements of the dataset are of the same shape
        locations = data['centroids']
        n_pad = 1024 - len(locations)
        locations = np.pad(locations, [[0, n_pad], [0,0]], constant_values=-1)

        yield dict(
            image = image_standardization(data['image']),
            gt_locations = locations,
        )
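
The padding step in the generator above can be looked at in isolation. The sketch below pads a small array of made-up centroids to a fixed length with -1 and then recovers the real points with a mask. The helper name `pad_locations` and the sample values are illustrative only, and the masking assumes that real coordinates are never negative.

```python
import numpy as np

def pad_locations(locations, max_points=1024, pad_value=-1):
    """Pad an [N, 2] array of points to [max_points, 2] (hypothetical helper)."""
    locations = np.asarray(locations, dtype=np.float32).reshape(-1, 2)
    n_pad = max_points - len(locations)
    if n_pad < 0:
        raise ValueError(f"got {len(locations)} points, more than max_points={max_points}")
    return np.pad(locations, [[0, n_pad], [0, 0]], constant_values=pad_value)

centroids = [[10.5, 20.0], [33.0, 7.25], [50.0, 61.5]]  # made-up values
padded = pad_locations(centroids, max_points=8)

# rows where both coordinates are non-negative are the real points
valid = padded[(padded >= 0).all(axis=-1)]
print(padded.shape, valid.shape)
```

The explicit check gives a clearer error when an image contains more points than the fixed length.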

In [None]:
# show an example of the training data

data = next(dataset_generator())
img = data['image'][0]
img = img - img.min()
img /= img.max()
locations = data['gt_locations'][0]

show_images([
    img,
    np.zeros_like(img),
])
ax = plt.gcf().get_axes()
ax[0].set_title("Image")
for pos in locations:
    c = Circle((pos[1], pos[0]), radius=2, edgecolor='white')
    ax[1].add_patch(c)
ax[1].set_title("Label")

## Initialize a trainer

To speed up training, we start from a model pre-trained on the LiveCell dataset (bright field microscopy).

In [None]:
import optax
from functools import partial
from ml_collections import ConfigDict
from lacss.train import Trainer, train_fn, CKS
from lacss.modules import Lacss

# Normally we don't segment all cells during training, to save time.
# CKS, however, needs every cell to be segmented, so we raise the limit.
config = ConfigDict()
config.max_training_instances = 1024

trainer = Trainer(
    model = Lacss.get_small_model(),
    optimizer = optax.adam(1e-4),
    losses = [], # losses is ignored by the CKS module
    strategy = CKS, # The CKS module implements the training logic
)

method = partial(train_fn, config=config)
it = trainer.train(dataset_generator, method=method)

## Training

In [None]:
from tqdm import tqdm

n_steps = 15000
validation_interval = 3000

for step in tqdm(range(n_steps)):

    if (step + 1) % validation_interval == 0:
        print(it.variables['cks'].loss)
        it.reset_loss_logs()

    next(it)
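
The loop above reports the loss once every `validation_interval` steps and then clears the accumulated logs. The same pattern, stripped of any Lacss specifics, is sketched below with a stand-in iterator that yields made-up loss values. `fake_training_steps` and `run_with_periodic_report` are hypothetical names, not part of the Lacss API.

```python
import itertools

def fake_training_steps():
    """Stand-in for a training iterator: yields a slowly decreasing loss (made-up values)."""
    for step in itertools.count():
        yield 1.0 / (1 + step)

def run_with_periodic_report(step_iter, n_steps, interval):
    """Advance `step_iter` n_steps times and return the mean loss of each full interval."""
    reports = []
    running = []
    for step in range(n_steps):
        running.append(next(step_iter))
        if (step + 1) % interval == 0:
            reports.append(sum(running) / len(running))
            running.clear()  # analogous to resetting the loss logs
    return reports

reports = run_with_periodic_report(fake_training_steps(), n_steps=6, interval=3)
print(reports)
```

Unlike the loop above, this sketch advances the iterator before checking the interval, so each report covers exactly `interval` steps.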


## Visualize the model prediction

In [None]:
from lacss.deploy.predict import Predictor
from lacss.data.utils import image_standardization
from skimage.color import label2rgb

# get data
image = imageio.imread(data_path/'test'/'img_0001.tif')
gt_label = imageio.imread(data_path/'test'/'masks_0001.tif')

# predict
# reuse the model instance held by the trainer (assumes `Trainer` exposes it as `.model`)
predictor = Predictor((trainer.model, it.parameters))
label = predictor.predict(image, score_threshold=0.4)["pred_label"]

show_images([
    image,
    label2rgb(np.asarray(label), bg_label=0),
    label2rgb(gt_label, bg_label=0),
])
titles = ['Input', "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)

## What's next?

- You can train for more steps
- You can perform quantitative evaluation
- You can incorporate validation and checkpointing into the training loop
- You can export the trained model

Check the [documentation](https://jiyuuchc.github.io/lacss/api/deploy/) for details.
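
For the quantitative evaluation mentioned above, one common starting point is the intersection-over-union (IoU) between predicted and ground-truth instance masks. The sketch below computes a pairwise IoU matrix with NumPy only. It is not the evaluation code shipped with Lacss, and `pairwise_iou`, the 0.5 threshold, and the tiny label arrays are made up for illustration.

```python
import numpy as np

def pairwise_iou(pred_label, gt_label):
    """IoU between every predicted and ground-truth instance (label 0 is background)."""
    pred_ids = np.unique(pred_label)
    pred_ids = pred_ids[pred_ids != 0]
    gt_ids = np.unique(gt_label)
    gt_ids = gt_ids[gt_ids != 0]

    iou = np.zeros((len(pred_ids), len(gt_ids)), dtype=np.float64)
    for i, p in enumerate(pred_ids):
        pred_mask = pred_label == p
        for j, g in enumerate(gt_ids):
            gt_mask = gt_label == g
            inter = np.logical_and(pred_mask, gt_mask).sum()
            union = np.logical_or(pred_mask, gt_mask).sum()  # never zero: both masks are non-empty
            iou[i, j] = inter / union
    return iou

# toy label images: two predicted instances, two ground-truth instances
pred = np.array([[1, 1, 0, 0],
                 [1, 1, 0, 2],
                 [0, 0, 0, 2]])
gt = np.array([[1, 1, 0, 0],
               [1, 0, 0, 2],
               [0, 0, 0, 2]])

iou = pairwise_iou(pred, gt)
# count predictions whose best-matching ground-truth instance reaches IoU >= 0.5
matched = int((iou.max(axis=1) >= 0.5).sum())
print(iou)
print(matched)
```

With real data, `pred` would be the `label` array produced in the visualization cell and `gt` the corresponding mask image.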