# LACSS Point-supervised Training Demo

This demo trains a model to segment microscopy images of cells, using only point labels.

 * The point labels were produced automatically from DAPI images.

We will go through these steps:

- Set up the data pipeline

- Initialize a model trainer

- Perform model training

- Visualize the results

## Setting up the environment

In [None]:
!pip install lacss

import imageio
import json
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from skimage.color import label2rgb
from tqdm import tqdm

import lacss

## Data pipeline

Lacss expects training data from a Python generator that produces the following data:

```
{
  "image": ndarray[B, W, H, C],
  "gt_locations": ndarray[B, N, 2]
},
```

Here we will set up the data pipeline using TensorFlow's `tf.data` dataset library, which has many useful utilities.

In [None]:
# Download the dataset
!wget -c https://data.mendeley.com/public-files/datasets/89s3ymz5wn/files/f976856c-08c5-4bba-85a7-3881e0593115/file_downloaded -O A431.zip

import zipfile
import tensorflow as tf
from os.path import join
from matplotlib.patches import Circle

# Unpack the downloaded archive into a local folder; every later path is
# built relative to `data_path`.
data_path = 'image_data'
with zipfile.ZipFile('A431.zip', "r") as f:
    f.extractall(data_path)

# show an example of the training data
img = imageio.imread(join(data_path, 'train', 'img_0000.tif'))
# train.json is read as a list of records, each carrying a 'locations' entry
# with the point labels of one image.
with open(join(data_path, "train.json")) as f:
    locations = json.load(f)

# Two panels: the raw image, and an all-zero image of the same shape that
# serves as a blank canvas for drawing the point labels.
lacss.utils.show_images([
    img,
    np.zeros_like(img),
])

ax = plt.gcf().get_axes()
ax[0].set_title("Image")
# NOTE(review): assumes the first record in train.json belongs to
# img_0000.tif -- confirm against the annotation file.
for pos in locations[0]['locations']:
    # The two coordinates are swapped before being handed to Circle;
    # presumably the labels are stored as (row, col) while matplotlib
    # expects (x, y) -- confirm.
    c = Circle((pos[1], pos[0]), radius=2, edgecolor='white')
    ax[1].add_patch(c)
ax[1].set_title("Label")

In [None]:
# Number of samples per batch produced by the data pipeline.
batch_size = 1

# Create a TensorFlow dataset from the files on disk: the annotation file
# (train.json) plus the folder holding the training images.
# NOTE(review): image_shape=[None, None, 1] presumably means variable
# height/width with a single channel -- confirm against the lacss docs.
ds = lacss.data.dataset_from_simple_annotations(
    join(data_path, "train.json"),
    join(data_path, "train"),
    image_shape=[None, None, 1]
)

def parser(data):
    """Augment one dataset element and pad its point labels to a fixed size.

    The element is first passed through lacss's built-in training
    augmentation. The resulting ``locations`` tensor is then padded along
    its first axis with rows of -1 until it has 768 rows, so that every
    element of the dataset has the same shape.

    Returns a dict with the augmented ``image`` and the padded points under
    ``gt_locations``.
    """
    max_points = 768

    # built-in data augmentation function
    augmented = lacss.data.parse_train_data_func(data, size_jitter=[0.8, 1.2])

    # Only the first axis (the list of points) is padded; the coordinate
    # axis is left untouched.
    points = augmented['locations']
    missing_rows = max_points - len(points)
    paddings = [[0, missing_rows], [0, 0]]
    padded_points = tf.pad(points, paddings, constant_values=-1)

    return {
        "image": augmented['image'],
        "gt_locations": padded_points,
    }

# Assemble the pipeline one stage at a time: per-element parsing and
# padding, endless repetition, then batching.
ds = ds.map(parser)
ds = ds.repeat()
ds = ds.batch(batch_size)

# make sure the dataset has the correct element structure
ds.element_spec

In [None]:
# Convert the tf.data dataset into a generator that the trainer can consume.
# NOTE(review): steps=-1 presumably means "no step limit" -- confirm against
# the TFDatasetAdapter documentation.
train_gen = lacss.train.TFDatasetAdapter(ds, steps=-1).get_dataset()

## Initialize a trainer

We will use transfer learning by starting from a pre-trained model. Transfer learning is generally beneficial even if the original model was trained on data that looks very different from the current images.


In [None]:
# Load a pretrain model
# This model was trained on bright field microscopy images (LIVECell dataset)
!wget -c https://data.mendeley.com/public-files/datasets/sj3vrvm6w3/files/667af564-0242-4c1d-87b0-c46e6bc3f63d/file_downloaded -O cnsp4_lc.pkl

from dataclasses import asdict
from lacss.deploy import load_from_pretrained
from flax.core.frozen_dict import freeze, unfreeze

# Load the pretrained module definition together with its parameters.
pretrained_module, pretrained_params = load_from_pretrained("cnsp4_lc.pkl")

# LacssWithHelper contains both the principal model and the collaborator model
# The principal model reuses the pretrained module's configuration, converted
# from a dataclass into a plain dict.
lacss_cfg = asdict(pretrained_module)
model = lacss.modules.lacss.LacssWithHelper(
    cfg=lacss_cfg,
    aux_edge_cfg={},
    aux_fg_cfg={"conv_spec": (16, 32)},
)

# Loss terms passed to the trainer.
losses = [
    lacss.losses.LPNLoss(), # detector head loss
    lacss.losses.SelfSupervisedInstanceLoss(), # segmentation head loss
    lacss.losses.AuxEdgeLoss(), # consistency loss
    lacss.losses.AuxSegLoss( # consistency loss
        offset_sigma=20.0,  # use bigger value for large cells
        offset_scale=2.0,
    ),
    lacss.losses.AuxSizeLoss(), # prevent model collapse
]

trainer = lacss.train.Trainer(
    model=model,
    losses=losses,
    optimizer=optax.adam(0.001),
    strategy=lacss.train.strategy.VMapped,
    seed=1234,
)

# Initialize the trainer from the training data generator. Part of the
# resulting parameters is replaced by the pretrained weights below.
trainer.initialize(train_gen)

# Merge with the pre-trained weights
# Only the two helper-module parameter sets from the fresh initialization are
# kept; the "_lacss" entry is replaced by the pretrained parameters.
# NOTE(review): the (rest, value) unpacking assumes trainer.params is a flax
# FrozenDict, whose pop() returns that pair without mutating -- confirm.
_, aux_edge_params = trainer.params.pop("_aux_edge_module")
_, aux_fg_params = trainer.params.pop("_aux_fg_module")
trainer.state = trainer.state.replace(
    params = freeze(dict(
        _lacss=pretrained_params,
        _aux_edge_module=aux_edge_params,
        _aux_fg_module=aux_fg_params,
    ))
)

## Training

`Trainer.train()` returns an iterator; stepping through it drives the training of the model.

In [None]:
# Training schedule: number of epochs and training steps per epoch.
n_epoch = 5
steps_per_epoch = 3000

# Advancing this iterator is what actually runs the training.
train_iter = trainer.train(
    train_gen,
    rng_cols=["droppath", "augment"],
    training=True,
)

for epoch in range(n_epoch):
    print(f"Epoch {epoch+1}")

    # Advance the iterator a fixed number of times; the value returned by
    # the last step of the epoch is what gets printed below.
    for _step in tqdm(range(steps_per_epoch)):
        logs = next(train_iter)

    summary = ", ".join(f"{k}:{v:.4f}" for k, v in logs.items())
    print(summary)

    # reset logs
    trainer.reset()

    # perform validation here

    # maybe save a training checkpoint
    # trainer.checkpoint(f"cp-{epoch}")

# save the current model. We only need the principal model
trainer.save_model("model.pkl", "_lacss")

## Visualize the model prediction

In [None]:
# Load one test image together with its ground-truth mask.
image = imageio.imread(join(data_path, 'test', 'img_0001.tif'))
gt = imageio.imread(join(data_path, 'test', 'masks_0001.tif'))

# normalize to zero mean and unit standard deviation, then append a
# trailing channel axis
img = image - image.mean()
img /= img.std()
img = img[..., None]

# prediction
model_output = trainer.model.apply(
    dict(params=trainer.params),
    image = img,
)
# Convert the model output into a label image covering the input's first two
# dimensions, then into a NumPy array for plotting.
pred = lacss.ops.patches_to_label(model_output, input_size=img.shape[:2])
pred = np.asarray(pred)

lacss.utils.show_images([
    img,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])
titles = ['Input', "Prediction", "Ground Truth"]
# Use a plain loop for the side effect: a list comprehension here would build
# a throwaway list of return values and echo it as the cell's output.
for panel, panel_title in zip(plt.gcf().get_axes(), titles):
    panel.set_title(panel_title)