# LACSS Supervised Training Demo

This notebook shows the general workflow for supervised training of a LACSS model from scratch.

This demo uses a small dataset from the [Cell Image Library](http://www.cellimagelibrary.org/home) collection.

We will go through these steps:

- Set up the data pipeline
- Initialize a model trainer
- Perform model training
- Visualize the results

## Setting up the environment

In [None]:
!pip install lacss

import imageio
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf

from skimage.color import label2rgb
from tqdm import tqdm
from os.path import join

import lacss

## Data pipeline

In [None]:
# First download the dataset

!wget -c https://data.mendeley.com/public-files/datasets/894mmsd9nj/files/568e524f-9a95-45a6-9f80-3619969c2a37/file_downloaded -O images.zip

import zipfile

data_path = 'image_data'
with zipfile.ZipFile('images.zip', "r") as f:
    f.extractall(data_path)

img = imageio.imread(join(data_path, 'train', '000_img.png'))
gt = imageio.imread(join(data_path, 'train', '000_masks.png'))

lacss.utils.show_images([
    img,
    gt,
])
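The mask `gt` is a label image: pixels with value 0 are background, and each cell has its own positive integer id. The NumPy sketch below counts the cells and computes one (row, col) centroid per cell. Treating centroids as point annotations is our assumption about what `gt_locations` contains; we have not checked exactly how `lacss.data.dataset_from_img_mask_pairs` derives its `locations`. The helper `instance_centroids` is hypothetical and not part of lacss.

```python
import numpy as np

def instance_centroids(label: np.ndarray) -> np.ndarray:
    """Return an (N, 2) float32 array of (row, col) centroids, one per nonzero label id.

    Hypothetical helper; assumes 0 is background and every other id is one instance.
    """
    ids = np.unique(label)
    ids = ids[ids != 0]  # drop the background id
    centroids = [np.argwhere(label == i).mean(axis=0) for i in ids]
    return np.asarray(centroids, dtype=np.float32).reshape(-1, 2)

# Toy label image with two "cells"
toy = np.zeros((6, 8), dtype=np.int32)
toy[1:3, 1:3] = 1  # 2x2 block, centroid (1.5, 1.5)
toy[3:6, 5:8] = 2  # 3x3 block, centroid (4.0, 6.0)
print(instance_centroids(toy))
```

On the real data, `len(instance_centroids(gt))` gives the number of cells in an image, which should stay below the padding length of 256 used in the parser further down.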

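LACSS expects training data from a Python generator that yields `(x_data, y_data)` tuples with the following structure, where B is the batch size, H and W are the image height and width, C is the number of channels, and N is the number of cell locations per image (padded to a fixed length):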

```
x_data, y_data = (
  {
    "image": ndarray[B, H, W, C],
    "gt_locations": ndarray[B, N, 2]
  },
  {
    "gt_labels": ndarray[B, H, W]
  }
)
```
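To make this structure concrete, here is a hypothetical generator (`synthetic_batches` is our own name, not a lacss function) that yields random batches with the documented keys and shapes. It assumes float32 images in height × width × channel order, float32 (row, col) locations padded with -1 to a fixed length N, and integer label images; these dtypes are assumptions, not checked against lacss.

```python
import numpy as np

def synthetic_batches(batch_size=1, height=64, width=64, channels=3, max_cells=256, seed=0):
    """Hypothetical generator yielding (x_data, y_data) tuples in the layout described above."""
    rng = np.random.default_rng(seed)
    while True:
        images = rng.standard_normal((batch_size, height, width, channels)).astype(np.float32)
        labels = np.zeros((batch_size, height, width), dtype=np.int32)
        # Every location row starts as padding (-1); real cells overwrite the first rows.
        locations = np.full((batch_size, max_cells, 2), -1.0, dtype=np.float32)
        for b in range(batch_size):
            # Draw one 4x4 square "cell" and record its (row, col) centroid.
            r, c = rng.integers(4, min(height, width) - 4, size=2)
            labels[b, r - 2:r + 2, c - 2:c + 2] = 1
            locations[b, 0] = (r - 0.5, c - 0.5)
        yield {"image": images, "gt_locations": locations}, {"gt_labels": labels}

x_data, y_data = next(synthetic_batches(batch_size=2))
print(x_data["image"].shape, x_data["gt_locations"].shape, y_data["gt_labels"].shape)
# (2, 64, 64, 3) (2, 256, 2) (2, 64, 64)
```

A synthetic source like this can be handy for smoke-testing a training setup before the real data pipeline is ready.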

Here we set up the data pipeline with TensorFlow's `tf.data` API, which provides convenient utilities for mapping, repeating, and batching datasets.

In [None]:
batch_size = 1

imgfiles = [join(data_path, 'train', f'{k:03d}_img.png') for k in range(89)]
maskfiles = [join(data_path, 'train', f'{k:03d}_masks.png') for k in range(89)]

# Create a TensorFlow dataset from the image/mask files on disk
ds = lacss.data.dataset_from_img_mask_pairs(imgfiles, maskfiles)

def parser(data):

    image = data['image']
    label = data['label']
    locations = data['locations']

    height = tf.shape(image)[0]
    width = tf.shape(image)[1]

    # simple augmentations
    if tf.random.uniform(()) >= 0.5:
        image = tf.image.flip_left_right(image)
        label = label[:, ::-1]
        locations = locations * [1, -1] + [0, width-1]

    if tf.random.uniform(()) >= 0.5:
        image = tf.image.flip_up_down(image)
        label = label[::-1, :]
        locations = locations * [-1, 1] + [height-1, 0]

    # Pad the locations tensor to a fixed length (256) so that all elements of the dataset
    # have the same shape and can be batched. This assumes no image contains more than 256 cells.
    n_pad = 256 - tf.shape(locations)[0]
    locations = tf.pad(locations, [[0, n_pad], [0, 0]], constant_values=-1)

    return (
        dict(
            image = image,
            gt_locations = locations, 
        ),
        dict(
            gt_labels = label,
        ),
    )

ds = ds.map(parser).repeat().batch(batch_size)

# make sure the dataset has the correct element structure
ds.element_spec
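The coordinate updates in the flip branches follow from the index mapping of a flip: a horizontal flip sends column `x` to `width - 1 - x` and leaves the row unchanged, and a vertical flip does the same for rows. The NumPy check below (with a hypothetical `centroid` helper and toy data) confirms that transforming a centroid with the parser's formulas matches the centroid of the flipped label image. It assumes locations are (row, col) pixel-index coordinates, as the parser's arithmetic implies.

```python
import numpy as np

def centroid(mask: np.ndarray) -> np.ndarray:
    """(row, col) centroid of the nonzero pixels of a 2-D mask."""
    return np.argwhere(mask).mean(axis=0)

height, width = 5, 7
label = np.zeros((height, width), dtype=np.int32)
label[1:3, 0:2] = 1  # a small object near the top-left corner

loc = centroid(label)  # (1.5, 0.5)

# Horizontal flip: same transform as the first branch of parser()
flipped_lr = label[:, ::-1]
loc_lr = loc * [1, -1] + [0, width - 1]
assert np.allclose(loc_lr, centroid(flipped_lr))  # (1.5, 5.5)

# Vertical flip: same transform as the second branch of parser()
flipped_ud = label[::-1, :]
loc_ud = loc * [-1, 1] + [height - 1, 0]
assert np.allclose(loc_ud, centroid(flipped_ud))  # (2.5, 0.5)
```

Because the parser pads the locations only after the flips, the -1 sentinel rows are never transformed.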

In [None]:
# Convert the tf.data.Dataset into a Python generator for the trainer
train_gen = lacss.train.TFDatasetAdapter(ds, steps=-1).get_dataset()

## Initialize a trainer

- `LPNLoss()` is the loss function for training the detection head.
- `SupervisedInstanceLoss()` trains the segmentation head.
- The `VMapped` strategy computes on batched input data on a single GPU. Use a different strategy if your setup differs (e.g., TPU training).
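As a rough illustration of what vectorizing over a batch means in JAX (not a description of how `lacss.train.strategy.VMapped` is implemented internally), `jax.vmap` turns a function written for a single example into one that maps over a leading batch axis. The per-image function `per_image_loss` below is hypothetical.

```python
import jax
import jax.numpy as jnp

def per_image_loss(pred: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical per-example loss: mean squared error over one H x W image."""
    return jnp.mean((pred - target) ** 2)

# vmap maps per_image_loss over axis 0 of both arguments, returning one loss per image
batched_loss = jax.vmap(per_image_loss)

preds = jnp.ones((4, 8, 8))
targets = jnp.zeros((4, 8, 8))
losses = batched_loss(preds, targets)
print(losses.shape)          # (4,)
print(float(losses.mean()))  # 1.0
```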



In [None]:
# Model Hyperparameters
cfg_json = '''
{
  "backbone": "ConvNeXt",
  "backbone_cfg": {
    "depths": [3,3,27,3],
    "drop_path_rate": 0.4
  },
  "detector": {
    "test_max_output": 256,
    "train_max_output": 256
  },
  "segmentor": {
    "conv_spec": [[384,384,384],[64]],
    "feature_level": 2,
    "instance_crop_size": 128,
    "learned_encoding": true
  }
}
'''

import json
lacss_cfg = json.loads(cfg_json)

trainer = lacss.train.Trainer(
    model=lacss.modules.Lacss(**lacss_cfg),
    losses=[
        lacss.losses.LPNLoss(),
        lacss.losses.SupervisedInstanceLoss(),
    ],
    optimizer=optax.adam(0.001),
    strategy=lacss.train.strategy.VMapped,
    seed=1234, # RNG seed
)
trainer.initialize(train_gen)
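Because the hyperparameters are plain JSON, variants are easy to derive before constructing the model. The sketch below uses a hypothetical `with_overrides` helper to merge overrides recursively into a copy of the config, here swapping the ConvNeXt stage depths `[3,3,27,3]` for the shallower `[3,3,9,3]` used by ConvNeXt-T, e.g. for a quicker experiment. We have not tested how well such a variant trains; the resulting dict would be passed to `lacss.modules.Lacss(**cfg)` in the same way as above.

```python
import copy
import json

def with_overrides(cfg: dict, overrides: dict) -> dict:
    """Return a deep copy of cfg with overrides merged in recursively (hypothetical helper)."""
    out = copy.deepcopy(cfg)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = with_overrides(out[key], value)  # merge nested sections
        else:
            out[key] = copy.deepcopy(value)  # replace leaf values outright
    return out

# A trimmed copy of the config above, to keep the example self-contained
base_cfg = json.loads(
    '{"backbone": "ConvNeXt", "backbone_cfg": {"depths": [3,3,27,3], "drop_path_rate": 0.4}}'
)
small_cfg = with_overrides(base_cfg, {"backbone_cfg": {"depths": [3, 3, 9, 3]}})
print(small_cfg["backbone_cfg"])         # {'depths': [3, 3, 9, 3], 'drop_path_rate': 0.4}
print(base_cfg["backbone_cfg"]["depths"])  # [3, 3, 27, 3] (unchanged)
```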

## Training

`Trainer.train()` returns an iterator; each call to `next()` on it runs one training step and returns the current logs.
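This is the ordinary Python iterator protocol: the caller decides how many steps to run by calling `next()`. The stand-in below replaces `trainer.train()` with a hypothetical `fake_train` generator that yields a log dict per step, and uses `itertools.islice` to run a fixed number of steps per epoch while keeping only the last logs, mirroring the loop below.

```python
import itertools

def fake_train():
    """Hypothetical stand-in for trainer.train(): yields a logs dict after every step."""
    step = 0
    while True:
        step += 1
        yield {"step": step, "loss": 1.0 / step}

def run_epoch(train_iter, steps):
    """Advance train_iter by `steps` steps and return the last logs."""
    logs = None
    for logs in itertools.islice(train_iter, steps):
        pass
    return logs

train_iter = fake_train()
for epoch in range(2):
    logs = run_epoch(train_iter, steps=5)
    print(f"Epoch {epoch + 1}: " + ", ".join(f"{k}:{v:.4f}" for k, v in logs.items()))
```

Because the iterator keeps its state between calls, the second epoch continues from step 6 rather than restarting.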

In [None]:
n_epoch = 10
steps_per_epoch = 3000

train_iter = trainer.train(train_gen, rng_cols=["droppath"], training=True)

for epoch in range(n_epoch):
    print(f"Epoch {epoch+1}")

    for _ in tqdm(range(steps_per_epoch)):
        logs = next(train_iter)

    print(", ".join([f"{k}:{v:.4f}" for k, v in logs.items()]))

    # reset logs
    trainer.reset()

    # perform validation here

    # maybe save a training checkpoint
    # trainer.checkpoint(f"cp-{epoch}")

# save the current model
trainer.save_model("model.pkl")
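The loop above leaves validation as a placeholder. One simple option, sketched here as an assumption rather than the metric LACSS itself reports, is the foreground intersection-over-union between a predicted label image and the ground truth (for example, `pred` and `gt` as computed in the next section). It ignores instance identities; instance-level metrics such as average precision over matched instances would need an additional matching step. `foreground_iou` is a hypothetical helper.

```python
import numpy as np

def foreground_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU of the foreground (label > 0) regions of two label images of equal shape."""
    p, g = pred > 0, gt > 0
    union = np.logical_or(p, g).sum()
    if union == 0:
        return 1.0  # both images empty: treat as a perfect match
    return float(np.logical_and(p, g).sum() / union)

toy_pred = np.array([[0, 1, 1], [0, 2, 2], [0, 0, 0]])
toy_gt = np.array([[0, 1, 1], [0, 1, 0], [3, 0, 0]])
print(foreground_iou(toy_pred, toy_gt))  # 0.6 (3 shared pixels / 5 pixels in the union)
```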

## Visualize the model prediction

In [None]:
image = imageio.imread(join(data_path, 'test', '000_img.png'))
gt = imageio.imread(join(data_path, 'test', '000_masks.png'))

# normalize
img = image - image.mean()
img /= img.std()

# prediction
model_output = trainer.model.apply(
    dict(params=trainer.params),
    image = img,
)
pred = lacss.ops.patches_to_label(model_output, input_size=img.shape[:2])
pred = np.asarray(pred)

lacss.utils.show_images([
    img,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])
titles = ['Input', "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)
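The inline normalization in the prediction cell divides by the standard deviation, which produces NaNs for a constant image. A slightly more defensive version is sketched below. The `eps` guard and the rule of appending a channel axis to 2-D (grayscale) inputs are our assumptions about what the model expects (an H × W × C float array), not documented LACSS behavior; `normalize_image` is a hypothetical helper.

```python
import numpy as np

def normalize_image(image: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Per-image zero-mean, unit-variance normalization returning float32 with shape (H, W, C)."""
    img = np.asarray(image, dtype=np.float32)
    if img.ndim == 2:
        img = img[..., None]  # assumed: add a trailing channel axis for grayscale input
    img = img - img.mean()
    return img / max(float(img.std()), eps)  # eps avoids dividing by zero on flat images

flat = normalize_image(np.full((4, 4), 7, dtype=np.uint8))
print(flat.shape, float(np.abs(flat).max()))  # (4, 4, 1) 0.0
```

Casting to float32 up front also keeps the input in the precision JAX uses by default.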