# LACSS Inference Demo

This is a small notebook demonstrating the workflow of applying an LACSS model to make segmentation predictions.




## Setting up the environment

In [None]:
!pip install git+https://github.com/jiyuuchc/lacss@lacss1

import imageio.v2 as imageio
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import label2rgb

from lacss.deploy import Predictor, model_urls
from lacss.utils import show_images

## Load a pre-trained model

Here we load a model pre-trained on the [TissueNet](https://datasets.deepcell.org/) dataset.

In [None]:
predictor = Predictor(model_urls["cnsp4-fl"])

## Also download some image data

We will download some microscopy images from the [Cell Image Library](http://www.cellimagelibrary.org/home) collection.

In [None]:
!wget -c https://data.mendeley.com/public-files/datasets/894mmsd9nj/files/568e524f-9a95-45a6-9f80-3619969c2a37/file_downloaded -O images.zip

import zipfile

data_path = 'image_data'
with zipfile.ZipFile('images.zip', "r") as f:
    f.extractall(data_path)
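The cell below reads a single image, `image_data/test/000_img.png`, together with its mask `000_masks.png`. To loop over more than one image, a small helper that pairs each image with its mask by filename can help. This is a hypothetical helper, not part of LACSS. It assumes every file follows the `NNN_img.png` / `NNN_masks.png` naming seen in this dataset.

```python
from pathlib import Path


def find_image_mask_pairs(data_dir):
    """Return sorted (image_path, mask_path) pairs found under data_dir.

    Hypothetical helper: assumes each image `NNN_img.png` has a matching
    mask `NNN_masks.png` in the same folder. Images without a mask are skipped.
    """
    pairs = []
    for img_path in sorted(Path(data_dir).rglob("*_img.png")):
        # Derive the mask filename by swapping the "_img" suffix for "_masks"
        mask_path = img_path.with_name(img_path.name.replace("_img.png", "_masks.png"))
        if mask_path.exists():
            pairs.append((img_path, mask_path))
    return pairs


pairs = find_image_mask_pairs("image_data")
print(f"found {len(pairs)} image/mask pairs")
```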

## Make a prediction

In [None]:
image = imageio.imread("image_data/test/000_img.png")
gt = imageio.imread("image_data/test/000_masks.png")

pred = predictor.predict(image.astype("float32"))["pred_label"]

# By default the model outputs are JAX arrays. It is more convenient
# to use a NumPy array for downstream analysis / visualization.
pred = np.asarray(pred)

show_images([
    image,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])

titles = ["Input", "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)
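Both `pred` and `gt` are label images: `label2rgb(..., bg_label=0)` treats 0 as background, and we assume every other integer marks one cell. Under that assumption, a quick sanity check beyond eyeballing the figure is to count objects in each. The helper `count_instances` below is hypothetical, not a LACSS function; in the notebook you would call `count_instances(pred)` and `count_instances(gt)`.

```python
import numpy as np


def count_instances(label_image, bg_label=0):
    """Count distinct object labels in an integer label image.

    Assumes background pixels carry `bg_label` and every other value
    marks exactly one segmented cell.
    """
    labels = np.unique(np.asarray(label_image))
    return int(np.count_nonzero(labels != bg_label))


# Tiny synthetic label image with three cells (labels 1, 2, 3)
toy = np.array([[0, 0, 1],
                [2, 2, 1],
                [0, 3, 3]])
print(count_instances(toy))  # 3
```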


You may notice that the `predictor.predict()` call runs quite slowly. This is because LACSS is a model based on the [JAX](https://jax.readthedocs.io/en/latest/) framework, which performs **just-in-time (JIT) compilation** of the model the first time it is called. Compilation takes some time, but it only happens on the first run.
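The first-call cost can be reproduced with plain JAX, independent of LACSS. In the toy sketch below (the function `smooth` is made up for illustration), `jax.jit` traces and compiles on the first call, and later calls with the same input shape and dtype reuse the compiled program, while a new input shape triggers another compilation. This is generic JAX behavior; we have not checked whether the LACSS `Predictor` pads inputs to avoid recompiling for differently sized images. Timings depend on your hardware, so no expected numbers are given.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def smooth(x):
    # A toy function: a few element-wise ops that XLA compiles together
    return jnp.tanh(x) * 2.0 + jnp.sin(x)


x = jnp.ones((1024, 1024), dtype=jnp.float32)

for i in range(3):
    t0 = time.perf_counter()
    # JAX dispatches asynchronously; wait for the result before timing
    smooth(x).block_until_ready()
    print(f"call {i}: {time.perf_counter() - t0:.4f} s")  # call 0 includes compilation

# A different input shape forces a fresh trace and compilation
y = jnp.ones((512, 512), dtype=jnp.float32)
t0 = time.perf_counter()
smooth(y).block_until_ready()
print(f"new shape: {time.perf_counter() - t0:.4f} s")
```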

In addition, the inference result is BAD! Why? Well, the model was trained on the TissueNet dataset, but the image we are analyzing comes from an unrelated dataset, which has a different channel organization and a different pixel-value normalization. We can improve the results by rearranging the data to match the structure of the original training data.
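The fix in the next cell, `image[..., (1,0,2)] / 255.0`, consists of two NumPy operations. The sketch below applies them to a tiny synthetic array so their effect is visible; it assumes the input is 8-bit (values 0 to 255), which is what dividing by 255.0 implies. It only illustrates the array mechanics and makes no claim about which channel holds which stain.

```python
import numpy as np

# A tiny stand-in for an 8-bit, 3-channel image: 1 x 2 pixels
toy = np.array([[[10, 200, 0],
                 [255, 0, 30]]], dtype=np.uint8)

# (1, 0, 2) is an advanced index on the last axis: output channel 0 takes
# input channel 1, output channel 1 takes input channel 0, channel 2 stays.
swapped = toy[..., (1, 0, 2)]

# Dividing by 255.0 maps the uint8 range [0, 255] onto [0.0, 1.0];
# NumPy promotes the result to float64.
rescaled = swapped / 255.0

print(swapped[0, 0])                   # [200  10   0]
print(rescaled.min(), rescaled.max())  # 0.0 1.0
```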

In [None]:
# Swap the first two channels and rescale 8-bit pixel values to [0, 1]
image_rearranged = image[..., (1, 0, 2)] / 255.0

pred = predictor.predict(image_rearranged)["pred_label"]
pred = np.asarray(pred)

show_images([
    image,
    label2rgb(pred, bg_label=0),
    label2rgb(gt, bg_label=0),
])

titles = ["Input", "Prediction", "Ground Truth"]
for ax, title in zip(plt.gcf().get_axes(), titles):
    ax.set_title(title)

OK, much better.

It is still not good enough, though. The remaining inaccuracies reflect the **domain shift** between the training data and the inference data. We can further improve the results by retraining on the new dataset. Check the [training demos](https://www.github.com/jiyuuchc/lacss_jax) to see how to do that.
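To put a rough number on "not good enough", for example to compare results before and after retraining, one option is a pixel-level foreground IoU between `pred` and `gt`. The helper below is hypothetical and not part of LACSS. It ignores instance identity (it only asks whether each pixel is called cell or background), so it is a coarse sanity check; a proper instance-segmentation evaluation would match individual cells. In the notebook you would call `foreground_iou(pred, gt)`.

```python
import numpy as np


def foreground_iou(pred_labels, gt_labels, bg_label=0):
    """Pixel-level intersection-over-union of the foreground masks.

    Any pixel whose label differs from `bg_label` counts as foreground;
    which cell it belongs to is ignored.
    """
    pred_fg = np.asarray(pred_labels) != bg_label
    gt_fg = np.asarray(gt_labels) != bg_label
    union = np.logical_or(pred_fg, gt_fg).sum()
    if union == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return float(np.logical_and(pred_fg, gt_fg).sum() / union)


# Toy example: 2 overlapping foreground pixels out of 4 in the union
toy_pred = np.array([[0, 1, 1],
                     [0, 0, 2]])
toy_gt = np.array([[0, 1, 1],
                   [3, 0, 0]])
print(foreground_iou(toy_pred, toy_gt))  # 0.5
```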