# LACSS Supervised Training Demo

This notebook shows the general workflow for supervised training of a LACSS model from scratch.

This demo uses a small dataset from the [Cell Image Library](http://www.cellimagelibrary.org/home) collection.

We will go through these steps:

- Set up the data pipeline
- Initialize a model trainer
- Perform model training
- Visualize the results

## Setting up the environment

In [None]:
!pip install "lacss[train] @ git+https://github.com/jiyuuchc/lacss"

In [None]:
from pathlib import Path

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
import pprint

from lacss.utils import show_images

## Data pipeline

In [None]:
# First download the dataset

!wget -c https://data.mendeley.com/public-files/datasets/894mmsd9nj/files/568e524f-9a95-45a6-9f80-3619969c2a37/file_downloaded -O images.zip

import zipfile

data_path = Path('image_data')
with zipfile.ZipFile('images.zip', "r") as f:
    f.extractall(data_path)

img = imageio.imread(data_path / 'train' / '000_img.png')
gt = imageio.imread(data_path / 'train'/ '000_masks.png')

show_images([img, gt])

LACSS expects training data from a Python generator that yields pairs in the following format:

```
x_data, y_data = (
  {
    "image": ndarray[B, W, H, C],
    "gt_locations": ndarray[B, N, 2]
  },
  {
    "gt_labels": ndarray[B, W, H]
  }
)
```
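To make the expected structure concrete, here is a minimal sketch that builds one dummy batch with the keys and axis order listed above. The sizes, the helper name `make_dummy_batch`, and the use of `0` as the background label are assumptions made for illustration only; the `-1` padding value mirrors the generator defined in the next cell.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
B, W, H, C, N = 1, 64, 64, 1, 512

def make_dummy_batch(n_cells=3, seed=0):
    """Build one (x_data, y_data) pair with the shapes described above."""
    rng = np.random.default_rng(seed)
    image = rng.random((B, W, H, C), dtype=np.float32)

    # Real locations come first; the remaining rows are filled with -1.
    locations = np.full((B, N, 2), -1.0, dtype=np.float32)
    locations[:, :n_cells] = rng.uniform(0, W, size=(B, n_cells, 2))

    # Label image, assuming 0 marks background (all background here).
    labels = np.zeros((B, W, H), dtype=np.int32)

    x_data = {"image": image, "gt_locations": locations}
    y_data = {"gt_labels": labels}
    return x_data, y_data

x_data, y_data = make_dummy_batch()
print({k: v.shape for k, v in x_data.items()})
print({k: v.shape for k, v in y_data.items()})
```

Checking shapes this way before training can catch a mismatched axis order or an unpadded locations array early.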

In [None]:
from lacss.data import img_mask_pair_generator
from lacss.data.utils import gf_batch, gf_cycle, image_standardization
import lacss.data.augment_ as augment

BATCHSIZE = 1

@gf_batch(batch_size=BATCHSIZE)
@gf_cycle
def dataset_generator():
    imgfiles = [data_path / 'train' / f'{k:03d}_img.png' for k in range(89)]
    maskfiles = [data_path / 'train'/ f'{k:03d}_masks.png' for k in range(89)]

    for data in img_mask_pair_generator(imgfiles, maskfiles):

        data['image_mask'] = data['label']

        # simple augmentation
        data = augment.flip_left_right(data, p = 0.5)
        data = augment.flip_up_down(data, p = 0.5)

        # It is important to pad the locations tensor so that all elements of the dataset are of the same shape
        locations = data['centroids']
        n_pad = 512 - len(locations)
        locations = np.pad(locations, [[0, n_pad], [0,0]], constant_values=-1)

        yield (
            dict(
                image = image_standardization(data['image']),
                gt_locations = locations, 
            ),
            dict(
                gt_labels = data['image_mask'],
            ),
        )
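The padding step in the generator assumes that no image contains more than 512 cells; with more, `n_pad` becomes negative and `np.pad` raises an error. The sketch below isolates that step in a hypothetical helper, `pad_locations` (not part of LACSS), that makes the limit explicit.

```python
import numpy as np

def pad_locations(locations, max_cells=512, pad_value=-1):
    """Pad an (n, 2) array of centroids to a fixed (max_cells, 2) shape.

    Hypothetical helper for illustration; it repeats the np.pad call
    used in the generator and adds an explicit size check.
    """
    locations = np.asarray(locations, dtype=np.float32).reshape(-1, 2)
    n_pad = max_cells - len(locations)
    if n_pad < 0:
        raise ValueError(
            f"{len(locations)} locations exceed max_cells={max_cells}"
        )
    return np.pad(locations, [[0, n_pad], [0, 0]], constant_values=pad_value)

# Two real centroids padded to four rows.
padded = pad_locations([[10.5, 20.0], [33.0, 47.5]], max_cells=4)
print(padded)
```

A fixed output shape means every element of the dataset can be batched without ragged arrays.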

## Initialize a trainer

The ```xtrain.Trainer``` class is the main interface we use for training. It needs a few things to start:

- A model definition.
- An optimizer from the optax library
- Definitions of the losses to train on
- An optional strategy specifying the training backend. Here we use `VMapped`, which is suitable for single-GPU training on batched data.

In [None]:
import optax
from lacss.modules import Lacss
from lacss.losses import supervised_instance_loss
from xtrain import Trainer, VMapped
from lacss.train.train import train_fn

trainer = Trainer(
    model = Lacss.get_small_model(),
    optimizer = optax.adam(1e-4),
    losses = [ # each loss is either a string or a function
        "losses/lpn_detection_loss", 
        "losses/lpn_localization_loss",
        supervised_instance_loss,
    ],
    strategy=VMapped,
)

it = trainer.train(dataset_generator, method=train_fn)

## Training

In [None]:
from tqdm import tqdm

n_steps = 15000
validation_interval = 3000

for step in tqdm(range(n_steps)):

    if (step + 1) % validation_interval == 0:
        print(it.loss)
        it.reset_loss_logs()

    next(it)
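The reporting schedule of this loop can be checked without running any training. The sketch below replays the same loop with a plain counter standing in for `next(it)`, and assumes that `it.loss` summarizes the steps run since the last `reset_loss_logs()` call.

```python
n_steps = 15000
validation_interval = 3000

steps_done = 0      # stands in for the number of next(it) calls
last_reset = 0
report_steps = []   # loop indices at which the loss is printed
window_sizes = []   # training steps covered by each report

for step in range(n_steps):
    if (step + 1) % validation_interval == 0:
        report_steps.append(step)
        window_sizes.append(steps_done - last_reset)
        last_reset = steps_done
    steps_done += 1

print(report_steps)               # [2999, 5999, 8999, 11999, 14999]
print(window_sizes)               # [2999, 3000, 3000, 3000, 3000]
print(steps_done - last_reset)    # 1
```

Because the report happens before `next(it)`, the first window covers 2999 steps and the final training step is never included in a report. Moving the check after `next(it)` would align each report with exactly 3000 steps.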


## Visualize the model prediction

In [None]:
from lacss.data.utils import image_standardization
from lacss.ops import patches_to_label
from skimage.color import label2rgb

image = imageio.imread(data_path/'test'/'000_img.png')
gt_label = imageio.imread(data_path/'test'/'000_masks.png')

# prediction
model_output = trainer.model.apply(
    dict(params = it.parameters),
    image = image_standardization(image),
)

label = patches_to_label(
    model_output, 
    input_size=image.shape[:2]
)

show_images([
    image,
    label2rgb(np.asarray(label), bg_label=0),
    label2rgb(gt_label, bg_label=0),
])
titles = ['Input', "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)

## What's more?

- You can train for more steps
- You can perform quantitative evaluation
- You can incorporate validation and checkpointing into the training loop
- You can export the trained model

Check the [documentation](https://jiyuuchc.github.io/lacss/api/deploy/) for details.
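As a rough illustration of the quantitative evaluation mentioned in the list above, the sketch below compares two label images by intersection over union (IoU). It is a plain NumPy example, not the evaluation API that LACSS provides; the function names, the toy arrays, and the use of `0` as background are assumptions.

```python
import numpy as np

def label_iou_matrix(pred, gt):
    """IoU between every predicted and ground-truth instance (0 = background)."""
    pred_ids = [i for i in np.unique(pred) if i != 0]
    gt_ids = [i for i in np.unique(gt) if i != 0]
    iou = np.zeros((len(pred_ids), len(gt_ids)))
    for r, p in enumerate(pred_ids):
        p_mask = pred == p
        for c, g in enumerate(gt_ids):
            g_mask = gt == g
            inter = np.logical_and(p_mask, g_mask).sum()
            union = np.logical_or(p_mask, g_mask).sum()
            iou[r, c] = inter / union
    return iou

def matched_fraction(pred, gt, threshold=0.5):
    """Fraction of ground-truth instances overlapped by a prediction at IoU >= threshold."""
    iou = label_iou_matrix(pred, gt)
    if iou.shape[0] == 0 or iou.shape[1] == 0:
        return 0.0
    return float((iou.max(axis=0) >= threshold).mean())

# Toy example: one perfect match and one partial match (IoU = 4/6).
gt_toy = np.zeros((4, 6), dtype=int)
gt_toy[0:2, 0:2] = 1
gt_toy[2:4, 3:6] = 2
pred_toy = np.zeros((4, 6), dtype=int)
pred_toy[0:2, 0:2] = 1
pred_toy[2:4, 4:6] = 2

print(matched_fraction(pred_toy, gt_toy, threshold=0.5))   # 1.0
print(matched_fraction(pred_toy, gt_toy, threshold=0.75))  # 0.5
```

The same functions could be applied to `np.asarray(label)` and `gt_label` from the visualization cell, though the nested loop is slow for images with many cells.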