# Finetuning Segment Anything with `µsam`

This notebook shows how to use Segment Anything for Microscopy (`µsam`) to fine-tune a Segment Anything Model (SAM) on open-source data with multiple channels.

In this notebook, we use confocal microscopy images from the HPA Kaggle Challenge for protein identification (from [Ouyang et al.](https://doi.org/10.1038/s41592-019-0658-6)) for the cell segmentation task. The functionality shown here should also work for your own (microscopy) images.

## Running this notebook

If you have an environment with `µsam` on your computer, you can run this notebook in it. To install `µsam` on your computer, follow the [installation instructions](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation).

You can also run this notebook in the cloud on [Kaggle Notebooks](https://www.kaggle.com/code/). This service offers free usage of a GPU to speed up running the code. If you use it, the next cells will take care of the installation for you.

In [None]:
# Check whether we are running this notebook on Kaggle, Google Colab, or local compute resources.

import os
current_spot = os.getcwd()

if current_spot.startswith("/kaggle/working"):
    print("Kaggle says hi!")
    root_dir = "/kaggle/working"

elif current_spot.startswith("/content"):
    print("Google Colab says hi!")
    print("NOTE: The scripts have not been tested on Google Colab; you might need to adapt the installation a bit.")
    root_dir = "/content"

    # You might need to install condacolab on Google Colab to be able to install packages using conda / mamba
    # !pip install -q condacolab
    # import condacolab
    # condacolab.install()

else:
    msg = "You are running this notebook on your own compute resources. Follow our installation instructions here:"
    msg += " https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation"
    print(msg)
    root_dir = ""  # set this to the root directory where the data, checkpoints, and other outputs will be stored

### Installation

The next cells will install the `micro_sam` library on Kaggle Notebooks. **Please skip these cells and go to `Importing the libraries` if you are running the notebook on your own computer.**

In [None]:
!git clone --quiet https://github.com/computational-cell-analytics/micro-sam.git
tmp_dir = os.path.join(root_dir, "micro-sam")
!pip install --quiet $tmp_dir

In [None]:
!git clone --quiet https://github.com/constantinpape/torch-em.git
tmp_dir = os.path.join(root_dir, "torch-em")
!pip install --quiet $tmp_dir

In [None]:
!git clone --quiet https://github.com/constantinpape/elf.git
tmp_dir = os.path.join(root_dir, "elf")
!pip install --quiet $tmp_dir

Known Issues on **Kaggle Notebooks**:

1. `warning  libmamba Cache file "/opt/conda/pkgs/cache/2ce54b42.json" was modified by another program` (multiple lines of such warnings)
    - We received this warning while testing this notebook on Kaggle. It does not cause any issues when using the installed packages, so you can ignore the warnings and proceed.

In [None]:
!mamba install -q -y -c conda-forge nifty affogato zarr z5py
!pip uninstall -y --quiet qtpy  # qtpy is not supported on Kaggle / Google Colab, so we remove it to avoid errors.

### Importing the libraries

In [None]:
import os
from glob import glob
from pathlib import Path
from natsort import natsorted
from IPython.display import FileLink

import h5py
import numpy as np
import imageio.v3 as imageio
from matplotlib import pyplot as plt
from skimage.measure import label as connected_components

import torch

import torch_em
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_hpa_segmentation_paths
from torch_em.util.util import get_random_colors
from torch_em.transform.label import PerObjectDistanceTransform

from micro_sam import util
import micro_sam.training as sam_training
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
from micro_sam.instance_segmentation import (
    InstanceSegmentationWithDecoder,
    get_predictor_and_decoder,
    mask_data_to_segmentation
)

### Let's download the dataset

First, we download the data (thanks to `torch-em`), combine the channel-wise inputs of each sample into one image, and store the images and corresponding labels as `tif` files.

In [None]:
# Download the data into a directory
DATA_FOLDER = os.path.join(root_dir, "data")
volume_paths = get_hpa_segmentation_paths(path=os.path.join(DATA_FOLDER, "hpa"))

# Store inputs as tif files
image_dir = os.path.join(DATA_FOLDER, "hpa", "preprocessed", "images")
label_dir = os.path.join(DATA_FOLDER, "hpa", "preprocessed", "labels")
os.makedirs(image_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)

for volume_path in volume_paths:
    fname = Path(volume_path).stem

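    with h5py.File(volume_path, "r") as f:
        # Read the channel-wise inputs into memory and stack them to shape (H, W, 4)
        image = np.stack(
            [f["raw/microtubule"][:], f["raw/protein"][:], f["raw/nuclei"][:], f["raw/er"][:]], axis=-1
        )
        # Read the labels into memory while the file is still open
        labels = f["labels"][:]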

    image_path = os.path.join(image_dir, f"{fname}.tif")
    label_path = os.path.join(label_dir, f"{fname}.tif")

    imageio.imwrite(image_path, image)
    imageio.imwrite(label_path, labels)

print(f"The inputs have been preprocessed and stored at: '{os.path.join(DATA_FOLDER, 'hpa', 'preprocessed')}'")
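The preprocessing cell above reads four single-channel arrays from each HDF5 file and stacks them along a new last axis. A minimal sketch of that stacking step with synthetic NumPy arrays shows why the result has shape `(H, W, 4)`; the channel names mirror the HDF5 keys above, while the array sizes and the helper `stack_channels` are illustrative assumptions.

```python
import numpy as np

# Channel order used in the preprocessing cell: microtubule, protein, nuclei, er.
CHANNEL_NAMES = ["microtubule", "protein", "nuclei", "er"]


def stack_channels(channels: dict, order: list) -> np.ndarray:
    """Hypothetical helper: stack per-channel (H, W) arrays into one (H, W, C) array."""
    shapes = {channels[name].shape for name in order}
    if len(shapes) != 1:
        raise ValueError(f"All channels must share the same shape, got {shapes}")
    # axis=-1 puts the channels last, matching the (H, W, C) layout written to tif.
    return np.stack([channels[name] for name in order], axis=-1)


# Synthetic stand-ins for f["raw/<name>"][:], each of shape (H, W) = (64, 48).
rng = np.random.default_rng(0)
fake_channels = {
    name: rng.integers(0, 255, size=(64, 48), dtype=np.uint8) for name in CHANNEL_NAMES
}

image = stack_channels(fake_channels, CHANNEL_NAMES)
print(image.shape)  # (64, 48, 4)
print(np.array_equal(image[..., 2], fake_channels["nuclei"]))  # True
```

Index `i` of the last axis holds the `i`-th entry of the channel list, which is what the later channel selection relies on.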

### Let's understand the structure of our inputs

In [None]:
image_paths = natsorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = natsorted(glob(os.path.join(label_dir, "*.tif")))

for image_path, label_path in zip(image_paths, label_paths):
    image = imageio.imread(image_path)
    labels = imageio.imread(label_path)

    print(f"Shape of inputs: '{image.shape}'")  # Expected shape: (H, W, 4), where 4 is the number of channels
    print(f"Shape of corresponding labels: '{labels.shape}'")  # Expected shape: (H, W)
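The cell above pairs images and labels by zipping two sorted lists, which silently misaligns them if a file is missing on either side. Since the preprocessing step writes each image and its labels under the same file name, a stricter option is to pair them by file stem. A minimal sketch, where the helper `pair_by_stem` is hypothetical:

```python
from pathlib import Path


def pair_by_stem(image_paths, label_paths):
    """Hypothetical helper: match image and label files by their file name stem.

    Unlike zipping two sorted lists, this fails loudly if a file is missing on one side.
    """
    images = {Path(p).stem: p for p in image_paths}
    labels = {Path(p).stem: p for p in label_paths}
    unpaired = set(images) ^ set(labels)  # stems present on only one side
    if unpaired:
        raise ValueError(f"Unpaired files for stems: {sorted(unpaired)}")
    # Plain lexicographic order here; use natsorted instead if you need natural ordering.
    return [(images[stem], labels[stem]) for stem in sorted(images)]


pairs = pair_by_stem(
    ["images/a.tif", "images/b.tif"],
    ["labels/b.tif", "labels/a.tif"],
)
print(pairs)  # [('images/a.tif', 'labels/a.tif'), ('images/b.tif', 'labels/b.tif')]
```

In this notebook, `pair_by_stem(image_paths, label_paths)` could replace `zip(image_paths, label_paths)` in the loops.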

Segment Anything accepts inputs with either 1 channel or 3 channels. To fine-tune Segment Anything on our data, we must therefore select either 1 or 3 of the 4 available channels.

We choose the `microtubule` (first channel), `protein` (second channel), and `nuclei` (third channel) channels for fine-tuning Segment Anything.
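Selecting channels by name rather than by position makes this choice explicit and easy to change, for example to train on a single channel instead. A minimal sketch, assuming the channel order written during preprocessing (`microtubule`, `protein`, `nuclei`, `er`); the helper `select_channels` is hypothetical and not part of `µsam` or `torch-em`:

```python
import numpy as np

# Channel order written by the preprocessing cell (last axis of each image).
CHANNEL_ORDER = ("microtubule", "protein", "nuclei", "er")


def select_channels(image: np.ndarray, keep: tuple) -> np.ndarray:
    """Hypothetical helper: keep the named channels of an (H, W, C) image, in the given order."""
    if len(keep) not in (1, 3):
        raise ValueError("Segment Anything expects 1 or 3 channels")
    indices = [CHANNEL_ORDER.index(name) for name in keep]  # ValueError for unknown names
    selected = image[..., indices]
    # Return a 2D (H, W) array for the single-channel case.
    return selected[..., 0] if len(keep) == 1 else selected


image = np.arange(2 * 2 * 4).reshape(2, 2, 4)
three_channels = select_channels(image, ("microtubule", "protein", "nuclei"))
nuclei_only = select_channels(image, ("nuclei",))
print(three_channels.shape, nuclei_only.shape)  # (2, 2, 3) (2, 2)
```

For the three channels chosen here, `select_channels(image, ("microtubule", "protein", "nuclei"))` gives the same array as `image[..., :-1]`.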

In [None]:
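# We remove the 'er' channel, i.e. the last channel.
# The shape check makes this cell safe to re-run: images that already have 3 channels are left untouched.
for image_path in image_paths:
    image = imageio.imread(image_path)
    if image.shape[-1] == 4:
        image = image[..., :-1]
        imageio.imwrite(image_path, image)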

### Let's create the dataloaders

Our task is to segment cells in confocal microscopy images. The dataset comes from https://zenodo.org/records/4665863, and its dataloader is implemented in [torch-em](https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/light_microscopy/hpa.py).
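Conceptually, a segmentation dataloader draws training patches of a fixed size from each image and cuts the matching window out of the label image. The sketch below illustrates only that core idea with NumPy; it is not how `torch-em` implements its loaders, and the helper `random_patch` and the patch shape are hypothetical.

```python
import numpy as np


def random_patch(image, labels, patch_shape, rng):
    """Crop the same random (h, w) window from an (H, W, C) image and its (H, W) labels.

    Illustrative sketch only; torch-em's own loaders and samplers are more elaborate.
    """
    h, w = patch_shape
    H, W = labels.shape
    if image.shape[:2] != (H, W):
        raise ValueError("Image and labels must have the same spatial shape")
    if h > H or w > W:
        raise ValueError("Patch is larger than the image")
    # Upper bounds are exclusive, so the window always fits inside the image.
    y = rng.integers(0, H - h + 1)
    x = rng.integers(0, W - w + 1)
    return image[y:y + h, x:x + w], labels[y:y + h, x:x + w]


rng = np.random.default_rng(42)
image = rng.random((100, 80, 3))
labels = np.zeros((100, 80), dtype=np.uint32)
labels[10:40, 20:50] = 1  # one synthetic "cell"

img_patch, lab_patch = random_patch(image, labels, (32, 32), rng)
print(img_patch.shape, lab_patch.shape)  # (32, 32, 3) (32, 32)
```

Real loaders typically add more on top of this, such as rejecting patches that contain too few objects; see the `torch-em` documentation for the options it offers.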

#### First, let's visualize how our samples look.
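After removing the `er` channel, each image has three channels, so Matplotlib's `imshow` renders it as RGB (any `cmap` is ignored for RGB data). Matplotlib expects RGB data as floats in `[0, 1]` or integers in `[0, 255]` and clips values outside that range. If your channels are not 8-bit, a per-channel min-max normalization keeps the display from saturating. A minimal sketch, where the helper `normalize_for_display` is hypothetical and whether you need it depends on your data's value range:

```python
import numpy as np


def normalize_for_display(image: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Min-max normalize each channel of an (H, W, C) image to float32 values in [0, 1]."""
    image = image.astype(np.float32)
    mins = image.min(axis=(0, 1), keepdims=True)
    maxs = image.max(axis=(0, 1), keepdims=True)
    # eps avoids division by zero for constant channels (they map to 0).
    return (image - mins) / np.maximum(maxs - mins, eps)


image = np.stack([
    np.full((4, 4), 1000, dtype=np.uint16),                     # constant channel
    np.arange(16, dtype=np.uint16).reshape(4, 4),               # ramp 0..15
    np.linspace(0, 60000, 16).reshape(4, 4).astype(np.uint16),  # wide 16-bit range
], axis=-1)
normalized = normalize_for_display(image)
print(normalized.min(), normalized.max())  # 0.0 1.0
```

In the visualization cell, `ax[0].imshow(normalize_for_display(image))` would then show all three channels on a comparable scale.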

In [None]:
for image_path, label_path in zip(image_paths, label_paths):
    image = imageio.imread(image_path)
    labels = imageio.imread(label_path)

    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax[0].imshow(image)  # 3-channel image, displayed as RGB (a colormap would be ignored)
    ax[0].set_title("Input Image")
    ax[0].axis("off")
    
    labels = connected_components(labels)
    ax[1].imshow(labels, cmap=get_random_colors(labels), interpolation="nearest")
    ax[1].set_title("Ground Truth Instances")
    ax[1].axis("off")
    
    plt.show()
    plt.close()
    
    break  # comment this out in case you want to visualize all the images
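The visualization above colors each ground-truth instance using `get_random_colors` from `torch-em`. The idea behind such a random colormap can be sketched in plain NumPy as a lookup table indexed by label id, with the background kept black. This is an illustrative assumption about the approach, not `torch-em`'s implementation, and the helper `labels_to_rgb` is hypothetical.

```python
import numpy as np


def labels_to_rgb(labels: np.ndarray, seed: int = 0) -> np.ndarray:
    """Map an (H, W) instance label image to an (H, W, 3) float RGB image.

    Each label id gets a random color; background (id 0) stays black.
    Illustrative sketch, not torch-em's get_random_colors implementation.
    """
    rng = np.random.default_rng(seed)
    n_ids = int(labels.max()) + 1
    lut = rng.random((n_ids, 3))  # one RGB color in [0, 1) per label id
    lut[0] = 0.0  # background is black
    return lut[labels]  # fancy indexing: every pixel looks up its label's color


labels = np.array([[0, 1, 1],
                   [0, 2, 2]])
rgb = labels_to_rgb(labels)
print(rgb.shape)  # (2, 3, 3)
print(rgb[0, 0])  # [0. 0. 0.]
```

Because the result is a float RGB array in `[0, 1]`, `ax[1].imshow(labels_to_rgb(labels))` would give a comparable picture without a colormap.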

#### Next, let's create the dataloaders.