# LACSS Supervised Training Demo

This notebook shows the general workflow of supervised training of an LACSS model from scratch.

This demo uses a small dataset from the [Cell Image Library](http://www.cellimagelibrary.org/home) collection.

We will go through these steps:

- Set up the data pipeline
- Initialize a model trainer
- Perform model training
- Visualize the results

## Setting up the environment

In [None]:
!pip install "jax[cuda12]==0.4.28"
!pip install git+https://github.com/jiyuuchc/lacss

from pathlib import Path

import imageio.v2 as imageio
import jax
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from pprint import pprint

from lacss.utils import show_images

## Data pipeline

In [None]:
# download the dataset
!wget -c https://data.mendeley.com/public-files/datasets/894mmsd9nj/files/568e524f-9a95-45a6-9f80-3619969c2a37/file_downloaded -O images.zip

import zipfile

data_path = Path('image_data')

with zipfile.ZipFile('images.zip', "r") as f:
    f.extractall(data_path)

show_images([
    imageio.imread(data_path / 'train' / '000_img.png'),
    imageio.imread(data_path / 'train'/ '000_masks.png'),
])
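The parser in the next cell reads a `centroids` entry that `lacss.data.dataset_from_img_mask_pairs` presumably derives from each mask. To illustrate what such centroids are, the NumPy sketch below computes one (row, col) centroid per labeled cell. It assumes a 2-D integer mask where 0 is background. `mask_to_centroids` is a hypothetical helper, and lacss's own implementation (including its coordinate order) may differ.

```python
import numpy as np


def mask_to_centroids(mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: one (row, col) centroid per nonzero label in `mask`.

    Assumes `mask` is a 2-D integer label image where 0 is background and
    every other value marks a single cell.
    """
    ids = np.unique(mask)
    ids = ids[ids != 0]  # drop the background label
    rows, cols = np.indices(mask.shape)  # per-pixel row and column indices
    centroids = [(rows[mask == i].mean(), cols[mask == i].mean()) for i in ids]
    # reshape keeps the (N, 2) shape even when the mask contains no cells
    return np.asarray(centroids, dtype=np.float32).reshape(-1, 2)


# toy 4x4 label image containing two cells
toy_mask = np.array([
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 2],
    [0, 0, 0, 2],
])
print(mask_to_centroids(toy_mask))  # centroids (0.5, 0.5) and (2.5, 3.0)
```

On the real data, `mask_to_centroids(imageio.imread(data_path / 'train' / '000_masks.png'))` should work as long as the mask PNG loads as a single-channel label image.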

In [None]:
from functools import partial

import lacss.data

def parser(data):
    image = data['image']
    label = data['label']
    locations = data['centroids']

    # normalize image
    image = tf.image.per_image_standardization(image)

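    # pad the locations tensor to 512 rows so that all elements of the dataset have the same shape
    # (tf.shape is used because len() is not defined for symbolic tensors inside Dataset.map)
    n_pad = 512 - tf.shape(locations)[0]
    locations = tf.pad(locations, [[0, n_pad], [0,0]], constant_values=-1)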

    return (
        dict(
            image = image,
            gt_locations = locations, 
        ),  # these are inputs to the model, the dict keys match the model's argnames
        dict(
            gt_labels = label,
        ), # these are extra labels for the training
    )

imgfiles = [data_path/'train'/f'{k:03d}_img.png' for k in range(89)]
maskfiles = [data_path/'train'/f'{k:03d}_masks.png' for k in range(89)]

# create a TensorFlow dataset from the files on disk, with random flips for augmentation
ds = (
    lacss.data.dataset_from_img_mask_pairs(imgfiles, maskfiles)
    .map(partial(lacss.data.flip_left_right, p=0.5))
    .map(partial(lacss.data.flip_up_down, p=0.5))
    .map(parser)
    .repeat()
    .prefetch(1)
)

# make sure the dataset has the correct element structure
ds.element_spec
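To see what the parser does to a single sample without TensorFlow, the NumPy sketch below mirrors its two preprocessing steps. The first is per-image standardization, following the formula documented for `tf.image.per_image_standardization`: subtract the mean and divide by the standard deviation, with the divisor floored at `1/sqrt(num_elements)`. The second pads the centroid list to 512 rows with -1. The helper names are hypothetical, and the explicit error for more than 512 cells is an addition of this sketch, not part of the parser.

```python
import numpy as np

MAX_CELLS = 512  # padding length used by the parser above


def standardize(image):
    """NumPy version of the tf.image.per_image_standardization formula."""
    x = np.asarray(image, dtype=np.float32)
    # floor the divisor so that constant images do not cause a division by zero
    adjusted_std = max(float(x.std()), 1.0 / np.sqrt(x.size))
    return (x - x.mean()) / adjusted_std


def pad_locations(locations, max_cells=MAX_CELLS):
    """Append rows of -1 so every sample has exactly `max_cells` rows."""
    locations = np.asarray(locations, dtype=np.float32).reshape(-1, 2)
    n_pad = max_cells - locations.shape[0]
    if n_pad < 0:
        raise ValueError(f"{locations.shape[0]} cells exceed max_cells={max_cells}")
    return np.pad(locations, [[0, n_pad], [0, 0]], constant_values=-1)


img = np.arange(48, dtype=np.uint8).reshape(4, 4, 3)
std_img = standardize(img)
print(std_img.mean(), std_img.std())  # approximately 0 and 1

padded = pad_locations([[10.0, 20.0], [30.0, 40.0]])
print(padded.shape)  # (512, 2)

# real centroids have non-negative pixel coordinates, so the -1 rows filter out easily
valid = padded[(padded >= 0).all(axis=-1)]
print(valid)
```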

## Initialize a trainer

The `Trainer` class from the `xtrain` package is the main interface we use for training. In this notebook it is given:

- The model: the small LACSS model from `Lacss.get_small_model()`, with the maximum number of detections per image reduced to 256.
- The losses to minimize: the two LPN losses `"losses/lpn_detection_loss"` and `"losses/lpn_localization_loss"`, referenced by name, plus the `supervised_instance_loss` function from `lacss.losses`.
- The optimizer: `optax.adamw` with a learning rate of 1e-4.

In [None]:
from dataclasses import asdict

from xtrain import Trainer, VMapped
from lacss.modules import Lacss
from lacss.losses import supervised_instance_loss

model = Lacss.get_small_model()
model.detector.max_output = 256 # reduce the max number of cells per image to save a bit of time

trainer = Trainer(
    model,
    losses = (
      "losses/lpn_detection_loss", 
      "losses/lpn_localization_loss",
      supervised_instance_loss,
    ),
    optimizer = optax.adamw(1e-4),
)

# print the current model hyperparameters

print("---Current model configuration---")
pprint(
    asdict(trainer.model), 
)
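The optimizer passed to the trainer, `optax.adamw(1e-4)`, is Adam with decoupled weight decay. As a sketch of the update it applies to each parameter array, the NumPy function below follows the AdamW rule with what we understand to be optax's default hyperparameters (`b1=0.9`, `b2=0.999`, `eps=1e-8`, `weight_decay=1e-4`). It is an illustration, not optax's implementation, and `adamw_step` is a hypothetical name.

```python
import numpy as np


def adamw_step(params, grads, m, v, t, lr=1e-4, b1=0.9, b2=0.999,
               eps=1e-8, weight_decay=1e-4):
    """One AdamW update. `t` is the 1-based step count; returns (params, m, v)."""
    m = b1 * m + (1 - b1) * grads        # running mean of gradients
    v = b2 * v + (1 - b2) * grads ** 2   # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)            # bias-corrected estimates
    v_hat = v / (1 - b2 ** t)
    # weight decay is added to the update directly ("decoupled"),
    # instead of being folded into the gradient as in L2 regularization
    update = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * params
    return params - lr * update, m, v


# toy problem: minimize sum(p**2), whose gradient is 2 * p
p = np.array([1.0, -2.0])
m = np.zeros_like(p)
v = np.zeros_like(p)
for t in range(1, 501):
    p, m, v = adamw_step(p, 2 * p, m, v, t, lr=0.05)
print(p)  # both entries end up close to 0
```

Because the step is normalized by the square root of the second-moment estimate, the first update moves each parameter by roughly `lr` regardless of the gradient's scale, which is one reason a small learning rate such as 1e-4 is a common starting point.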

## Training

In [None]:
from tqdm import tqdm

n_steps = 15000
log_interval = 3000  # print the accumulated loss logs, then reset them, every 3000 steps

train_it = trainer.train(ds, rng_cols="dropout", training=True)

with tqdm(total=n_steps) as pbar:
    while train_it.step < n_steps:
        pred = next(train_it)
        pbar.update(1)

        if train_it.step % log_interval == 0:
            pprint(train_it.loss_logs)
            train_it.reset_loss_logs()

## Visualize the model prediction

In [None]:
from skimage.color import label2rgb
from lacss.ops import patches_to_label

image = imageio.imread(data_path/'test'/'000_img.png')
gt = imageio.imread(data_path/'test'/'000_masks.png')

# normalize the image the same way as in the training pipeline
img = tf.image.per_image_standardization(image).numpy()

# prediction
model_output = trainer.model.apply(
    dict(params = train_it.parameters),
    image = img,
)
pred = patches_to_label(
    model_output["predictions"], 
    input_size=image.shape[:2]
)
pred = np.asarray(pred)

show_images([
    image,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])
titles = ["Input", "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)

## What's more?

- You can train for more steps
- You can perform quantitative evaluation
- You can incorporate validation and checkpointing into the training loop
- You can export the trained model

Check the [documentation](https://jiyuuchc.github.io/lacss/api/deploy/) for details.
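For a first quantitative check, the `pred` and `gt` label images from the visualization step can be compared instance by instance. The NumPy sketch below uses a generic IoU-matching scheme, not lacss's own evaluation API, and its helper names are hypothetical. It builds an IoU matrix from a joint histogram of the two label images and counts a prediction as a true positive when its IoU with a ground-truth cell exceeds 0.5. With a threshold of at least 0.5 (and non-overlapping instances in each image), a cell can exceed it with at most one partner, so no extra assignment step is needed.

```python
import numpy as np


def iou_matrix(pred, gt):
    """IoU between every predicted label (rows) and every ground-truth label (cols).

    Both inputs are integer label images of the same shape, 0 = background.
    Rows/columns are indexed by label value minus 1; absent label values get IoU 0.
    """
    pred = np.asarray(pred, dtype=np.int64).ravel()
    gt = np.asarray(gt, dtype=np.int64).ravel()
    n_gt = gt.max() + 1
    # joint histogram: inter[i, j] = number of pixels with pred == i and gt == j
    inter = np.bincount(pred * n_gt + gt, minlength=(pred.max() + 1) * n_gt)
    inter = inter.reshape(-1, n_gt)
    area_pred = inter.sum(axis=1, keepdims=True)
    area_gt = inter.sum(axis=0, keepdims=True)
    union = area_pred + area_gt - inter
    iou = inter / np.maximum(union, 1)  # avoid 0 / 0 for absent labels
    return iou[1:, 1:]  # drop the background row and column


def match_instances(pred, gt, threshold=0.5):
    """Count one-to-one matches at an IoU threshold (must be >= 0.5)."""
    assert threshold >= 0.5, "uniqueness of matches needs threshold >= 0.5"
    pred, gt = np.asarray(pred), np.asarray(gt)
    tp = int((iou_matrix(pred, gt) > threshold).sum())
    n_pred = len(np.unique(pred[pred != 0]))
    n_gt = len(np.unique(gt[gt != 0]))
    return dict(
        tp=tp, fp=n_pred - tp, fn=n_gt - tp,
        precision=tp / max(n_pred, 1), recall=tp / max(n_gt, 1),
    )


toy_gt = np.array([[1, 1, 0, 0],
                   [1, 1, 0, 0],
                   [0, 0, 2, 2],
                   [0, 0, 2, 2]])
toy_pred = np.array([[1, 1, 0, 0],
                     [1, 1, 0, 0],
                     [0, 0, 0, 0],
                     [0, 0, 0, 3]])
print(match_instances(toy_pred, toy_gt))
# {'tp': 1, 'fp': 1, 'fn': 1, 'precision': 0.5, 'recall': 0.5}
```

In the notebook, `match_instances(pred, gt)` can be called directly on the arrays from the visualization cell; repeating it over all test images and summing the counts gives dataset-level precision and recall.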