# LACSS Supervised Training Demo

This notebook shows the general workflow for supervised training of a LACSS model from scratch.

This notebook uses a small dataset from the [Cell Image Library](http://www.cellimagelibrary.org/home) collection.

We will go through these steps:

- Set up the data pipeline
- Initialize a model trainer
- Perform model training
- Visualize the results

## Setting up the environment

In [None]:
!pip install git+https://github.com/jiyuuchc/lacss@lacss1

from os.path import join
import imageio.v2 as imageio
import json
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from skimage.color import label2rgb
from tqdm import tqdm
from pathlib import Path
from dataclasses import asdict

import lacss.data
from lacss.train import LacssTrainer, VMapped, TFDatasetAdapter
from lacss.ops import patches_to_label
from lacss.utils import show_images

## Data pipeline

In [None]:
# First download the dataset

!wget -c https://data.mendeley.com/public-files/datasets/894mmsd9nj/files/568e524f-9a95-45a6-9f80-3619969c2a37/file_downloaded -O images.zip

import zipfile

data_path = Path('image_data')
with zipfile.ZipFile('images.zip', "r") as f:
    f.extractall(data_path)

img = imageio.imread(data_path / 'train' / '000_img.png')
gt = imageio.imread(data_path / 'train'/ '000_masks.png')

show_images([
    img,
    gt,
])

LACSS expects training data from a Python generator that yields tuples with the following structure:

```
x_data, y_data = (
  {
    "image": ndarray[B, H, W, C],
    "gt_locations": ndarray[B, N, 2]
  },
  {
    "gt_labels": ndarray[B, H, W]
  }
)
```
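
As an illustration of that structure, here is a minimal sketch of a plain Python generator that yields synthetic batches. The function name `synthetic_batches`, the image size, and the number of cells are all made up for this example, and the layout assumes the first spatial axis is image rows. It is not part of the LACSS API.

```python
import numpy as np


def synthetic_batches(n_batches, batch_size=1, height=64, width=64,
                      n_cells=3, max_cells=8, seed=0):
    """Yield (x_data, y_data) tuples of random data (illustration only)."""
    rng = np.random.default_rng(seed)
    for _ in range(n_batches):
        # Single-channel float image batch: [B, H, W, C]
        image = rng.random((batch_size, height, width, 1), dtype=np.float32)

        # Cell centers as (row, column); rows past n_cells stay at -1 (padding)
        locations = np.full((batch_size, max_cells, 2), -1.0, dtype=np.float32)
        locations[:, :n_cells] = rng.uniform(
            low=0, high=[height, width], size=(batch_size, n_cells, 2)
        )

        # Label image: 0 is background, cell k is a small square of value k + 1
        labels = np.zeros((batch_size, height, width), dtype=np.int32)
        for b in range(batch_size):
            for k in range(n_cells):
                r, c = locations[b, k].astype(int)
                labels[b, max(r - 2, 0):r + 3, max(c - 2, 0):c + 3] = k + 1

        yield (
            {"image": image, "gt_locations": locations},
            {"gt_labels": labels},
        )


x_data, y_data = next(synthetic_batches(n_batches=1))
print(x_data["image"].shape, x_data["gt_locations"].shape, y_data["gt_labels"].shape)
```

Each yielded element is an `(x_data, y_data)` tuple. Unused rows of `gt_locations` are filled with -1, the same padding value the parser in this notebook uses.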

Here we set up the data pipeline with TensorFlow's `tf.data` API, which provides many useful utilities.

In [None]:
from os.path import join

def parser(data):
    image = data['image']
    label = data['label']
    locations = data['centroids']

    height = tf.shape(image)[0]
    width = tf.shape(image)[1]

    # simple augmentations
    if tf.random.uniform(()) >= 0.5:
        image = tf.image.flip_left_right(image)
        label = label[:, ::-1]
        locations = locations * [1, -1] + [0, width]

    if tf.random.uniform(()) >= 0.5:
        image = tf.image.flip_up_down(image)
        label = label[::-1, :]
        locations = locations * [-1, 1] + [height, 0]

    # Pad the locations tensor to a fixed length (512 here) so that all
    # elements of the dataset have the same shape. Padded rows are set to -1.
    n_pad = 512 - len(locations)
    locations = tf.pad(locations, [[0, n_pad], [0,0]], constant_values=-1)

    return (
        dict(
            image = image,
            gt_locations = locations, 
        ),
        dict(
            gt_labels = label,
        ),
    )

batch_size = 1
imgfiles = [join(data_path, 'train', f'{k:03d}_img.png') for k in range(89)]
maskfiles = [join(data_path, 'train', f'{k:03d}_masks.png') for k in range(89)]

# Create a TensorFlow dataset from the files on disk
ds = (
    lacss.data.dataset_from_img_mask_pairs(imgfiles, maskfiles)
    .map(parser)
    .repeat()
    .batch(batch_size)
)

# make sure the dataset has the correct element structure
ds.element_spec
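
To make the two key steps of `parser` easier to check by hand, the sketch below repeats the horizontal flip and the fixed-length padding in plain NumPy. The helper names `flip_left_right` and `pad_locations` are introduced only for this illustration, and locations are assumed to be stored as (row, column) pairs, as the flip arithmetic in `parser` implies.

```python
import numpy as np

MAX_CELLS = 512  # same fixed length that the parser pads to


def flip_left_right(image, label, locations):
    """Mirror an (H, W, C) image, its (H, W) label, and (row, col) locations."""
    width = image.shape[1]
    # Rows are unchanged; each column coordinate x becomes width - x
    flipped = locations * np.array([1.0, -1.0]) + np.array([0.0, width])
    return image[:, ::-1], label[:, ::-1], flipped


def pad_locations(locations, max_cells=MAX_CELLS):
    """Pad an (N, 2) array to (max_cells, 2), filling unused rows with -1."""
    n_pad = max_cells - len(locations)
    if n_pad < 0:
        raise ValueError(f"{len(locations)} cells exceed the limit of {max_cells}")
    return np.pad(locations, [(0, n_pad), (0, 0)], constant_values=-1)


image = np.zeros((32, 48, 1), dtype=np.float32)
label = np.zeros((32, 48), dtype=np.int32)
locations = np.array([[10.0, 20.0]])

image, label, locations = flip_left_right(image, label, locations)
padded = pad_locations(locations)
print(locations, padded.shape)
```

Note that an image containing more than 512 cells would make the pad length negative, so the fixed length has to be at least the largest cell count in the dataset.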

## Initialize a trainer

The ```lacss.train.LacssTrainer``` class is the main interface we use for training. It needs a few things to start:

- A configuration dictionary to override the default model hyperparameters.
- An optional random seed value to control the stochastic gradient descent process.
- An optional strategy that specifies the training backend to use. Here we use `VMapped`, which is suitable for single-GPU training on batched data.
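
To illustrate what "override the default hyperparameters" means for a nested configuration, here is a minimal sketch of a recursive dictionary merge. This is only an assumption about the general idea, not LACSS's actual implementation; `merge_overrides` is a hypothetical helper, and the default values shown are placeholders rather than the library's real defaults.

```python
import copy


def merge_overrides(defaults, overrides):
    """Return a copy of `defaults` with the nested `overrides` applied."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Descend into nested sections so sibling keys are preserved
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Placeholder defaults (hypothetical values, for illustration only)
hypothetical_defaults = {
    "backbone": {"drop_path_rate": 0.0, "placeholder_option": "unchanged"},
    "segmentor": {"instance_crop_size": 96},
}
overrides = {
    "backbone": {"drop_path_rate": 0.4},
    "segmentor": {"instance_crop_size": 128},
}

print(merge_overrides(hypothetical_defaults, overrides))
```

Only the keys named in the override dictionary change. Everything else keeps its default, which is why a partial configuration is enough.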

In [None]:
# Model configuration. We override a few default hyperparameters
cfg = {
  "backbone": {
    "drop_path_rate": 0.4
  },
  "segmentor": {
    "instance_crop_size": 128
  }
}

# LacssTrainer is the main class for model training
trainer = LacssTrainer(
    cfg,
    seed=1234, # RNG seed
    strategy=VMapped,
)

# Print the current model hyperparameters
from pprint import pprint

print("---Current model configuration---")
pprint(
    asdict(trainer.model.principal), 
)

## Training

In [None]:
n_steps = 15000
validation_interval = 3000

trainer.do_training(
  TFDatasetAdapter(ds),
  n_steps = n_steps,
  validation_interval = validation_interval,
)

## Visualize the model prediction

In [None]:
image = imageio.imread(data_path/'test'/'000_img.png')
gt = imageio.imread(data_path/'test'/'000_masks.png')

# prediction
model_output = trainer.model.apply(
    dict(params = trainer.parameters),
    image = image / 255,
)
pred = patches_to_label(
    model_output, 
    input_size=image.shape[:2]
)
pred = np.asarray(pred)

show_images([
    image,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])
titles = ["Input", "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)

## What's more?

- You can train for more steps
- You can perform quantitative evaluation
- You can incorporate validation and checkpointing into the training loop
- You can export the trained model

Check the [documentation](https://jiyuuchc.github.io/lacss/api/deploy/) for details.
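
As a starting point for quantitative evaluation, the sketch below scores a predicted label image against a ground-truth label image by matching instances on intersection-over-union (IoU) and reporting an F1 score. It is a plain NumPy illustration, not the evaluation code shipped with LACSS; the function names `iou_matrix` and `detection_f1` are hypothetical, and both inputs are assumed to be integer label images with 0 as background.

```python
import numpy as np


def iou_matrix(pred, gt):
    """IoU between every predicted and every ground-truth instance."""
    pred_ids = np.unique(pred)
    pred_ids = pred_ids[pred_ids > 0]
    gt_ids = np.unique(gt)
    gt_ids = gt_ids[gt_ids > 0]

    ious = np.zeros((len(pred_ids), len(gt_ids)))
    for i, p in enumerate(pred_ids):
        p_mask = pred == p
        for j, g in enumerate(gt_ids):
            g_mask = gt == g
            intersection = np.logical_and(p_mask, g_mask).sum()
            union = np.logical_or(p_mask, g_mask).sum()
            ious[i, j] = intersection / union
    return ious


def detection_f1(pred, gt, threshold=0.5):
    """F1 score counting a prediction as correct when its best IoU > threshold."""
    if threshold < 0.5:
        # Above 0.5 each instance can match at most one counterpart
        raise ValueError("threshold must be at least 0.5 for unique matches")
    ious = iou_matrix(pred, gt)
    n_pred, n_gt = ious.shape
    if n_pred + n_gt == 0:
        return 1.0
    if n_pred == 0 or n_gt == 0:
        return 0.0
    true_positives = int((ious.max(axis=1) > threshold).sum())
    return 2 * true_positives / (n_pred + n_gt)


# Toy example: two ground-truth cells, one of them found exactly
gt_demo = np.zeros((10, 10), dtype=int)
gt_demo[1:4, 1:4] = 1
gt_demo[6:9, 6:9] = 2
pred_demo = np.zeros((10, 10), dtype=int)
pred_demo[1:4, 1:4] = 5

print(detection_f1(pred_demo, gt_demo))
```

With the `pred` and `gt` arrays from the visualization cell, `detection_f1(pred, gt)` would give a single-image score. The double loop is quadratic in the number of cells, so it is meant for small checks rather than large-scale benchmarking.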