Meant to be run in Google Colab.

In [None]:
!pip install konda
import konda
konda.install()

Collecting konda
  Downloading konda-0.1.0-py3-none-any.whl.metadata (3.7 kB)
Downloading konda-0.1.0-py3-none-any.whl (7.3 kB)
Installing collected packages: konda
Successfully installed konda-0.1.0
Downloading Miniconda installer...
Installing Miniconda to /usr/local...
✅ Miniconda installed successfully!
Run '!conda --version' to check if conda is working.

📋 Usage examples:
  konda create -n my_env python=3.11 -y
  konda activate my_env


In [None]:
!konda install -y -q -c conda-forge micro_sam > /dev/null 2>&1

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!cp -r /content/drive/MyDrive/full_data /content/full_data

Mounted at /content/drive


In [None]:
import os
from glob import glob
import pathlib
from typing import Tuple

import numpy as np
import imageio.v3 as imageio    # pip install imageio>=0.25

# =====================================================================
# Settings (adjust as needed)
# =====================================================================
# Folder holding the source PNGs, relative to the current working directory.
IN_DIR: str = "full_data/images"
# Folder that receives the 8-bit versions; main() creates it when missing.
OUT_DIR: str = "full_data/images_uint8"
# When True, sub-directories of IN_DIR are searched as well.
RECURSIVE: bool = False
# When True, each file's intensity range is printed before and after conversion.
VERBOSE: bool = True

# ---------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------
def rescale_to_uint8(arr: np.ndarray) -> Tuple[np.ndarray, bool]:
    """Min-max rescale an image to the full 0..255 uint8 range.

    Each call normalises with this image's own min/max, so intensities are not
    comparable across images afterwards.

    Args:
        arr: image of any numeric (or bool) dtype and any shape.

    Returns:
        (converted_array, was_already_uint8).

        - A uint8 input is returned unchanged (the same object) with ``True``.
        - Otherwise a new uint8 array is returned with ``False``. The smallest
          finite value maps to 0 and the largest to 255. NaN maps to 0, +inf to
          255 and -inf to 0.
        - Flat, empty and all-non-finite inputs yield zeros.

    Works with NumPy 1.x and 2.x (ndarray.ptp was removed in 2.0, so the range
    is computed from min/max directly).
    """
    # Every uint8 value already lies in 0..255, so no range check is needed.
    if arr.dtype == np.uint8:
        return arr, True

    arr_f = arr.astype(np.float32)

    # Derive the range from finite pixels only, so NaN/inf cannot poison it.
    finite = np.isfinite(arr_f)
    if not finite.any():                       # empty, or only NaN/inf
        return np.zeros(arr_f.shape, dtype=np.uint8), False

    finite_vals = arr_f[finite]
    lo = finite_vals.min()
    denom = finite_vals.max() - lo
    if denom <= 0:                             # flat image
        return np.zeros(arr_f.shape, dtype=np.uint8), False

    # clip() pins +/-inf to the ends of [0, 1]; NaN passes through clip
    # unchanged, so it is zeroed explicitly before the integer cast.
    arr_norm = np.clip((arr_f - lo) / denom, 0.0, 1.0)
    arr_norm = np.nan_to_num(arr_norm, nan=0.0)
    out = (arr_norm * 255).round().astype(np.uint8)
    return out, False
# ---------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------
def main():
    """Convert every PNG found under IN_DIR to 8-bit and save it under OUT_DIR.

    Reads the module-level settings IN_DIR, OUT_DIR, RECURSIVE and VERBOSE.
    Each output keeps its path relative to IN_DIR. With RECURSIVE=True, two
    inputs that share a file name in different sub-directories therefore no
    longer overwrite each other. With RECURSIVE=False the relative path is
    just the file name.

    Raises:
        FileNotFoundError: if no PNG matches under IN_DIR.
    """
    os.makedirs(OUT_DIR, exist_ok=True)

    pattern = "**/*.png" if RECURSIVE else "*.png"
    src_paths = sorted(glob(os.path.join(IN_DIR, pattern), recursive=RECURSIVE))

    if not src_paths:
        raise FileNotFoundError(f"No PNGs found in {pathlib.Path(IN_DIR).resolve()}")

    for src in src_paths:
        img = imageio.imread(src)
        # For uint8 input `converted` is `img` itself, so the single write
        # below covers both the converted and the pass-through case.
        converted, already_ok = rescale_to_uint8(img)

        if VERBOSE:
            print(f"{os.path.basename(src):<30} : "
                  f"before [{img.min():>7.1f}, {img.max():>7.1f}]  "
                  f"→ after [{converted.min():3d}, {converted.max():3d}]"
                  + ("  (skipped)" if already_ok else ""))

        # Mirror the input layout under OUT_DIR (creating sub-dirs as needed).
        dst = os.path.join(OUT_DIR, os.path.relpath(src, IN_DIR))
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        # Pixels are always decoded and re-encoded; this is not a byte-for-byte copy.
        imageio.imwrite(dst, converted)

    print(f"\n✔ All done! 8-bit images are in: {pathlib.Path(OUT_DIR).resolve()}")

# ---------------------------------------------------------------------
# In a notebook kernel __name__ is "__main__", so this guard runs main()
# every time the cell is executed (the output below the cell shows it).
if __name__ == "__main__":
    main()

1200_01_061025.png             : before [  122.0, 55297.0]  → after [  0, 255]
1200_02_061025.png             : before [  144.0, 49910.0]  → after [  0, 255]
1200_03_061025.png             : before [  127.0, 65535.0]  → after [  0, 255]
1200_04_061025.png             : before [  136.0, 65535.0]  → after [  0, 255]
1200_05_061025.png             : before [  153.0, 61080.0]  → after [  0, 255]
1200_06_061025.png             : before [  175.0, 65535.0]  → after [  0, 255]
1200_07_061025.png             : before [  114.0, 61455.0]  → after [  0, 255]
1200_08_061025.png             : before [  135.0, 65535.0]  → after [  0, 255]
150_01_061025.png              : before [  129.0, 65535.0]  → after [  0, 255]
150_02_061025.png              : before [  188.0, 65535.0]  → after [  0, 255]
150_03_061025.png              : before [  272.0, 65535.0]  → after [  0, 255]
150_04_061025.png              : before [  309.0, 65535.0]  → after [  0, 255]
150_05_061025.png              : before [  178.0, 65

In [None]:
%%bash
python

# --------------------------- imports ----------------------------------
import os
from glob import glob
from IPython.display import FileLink
from typing import Union, Tuple, Optional

import numpy as np
import imageio.v3 as imageio
from matplotlib import pyplot as plt
from skimage.measure import label as connected_components

import torch

from torch_em.util.debug import check_loader
from torch_em.data import MinInstanceSampler
from torch_em.util.util import get_random_colors

import micro_sam.training as sam_training
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation

# --------------------------- hyper-params -----------------------------
patch_shape = (1, 656, 656)
raw_key, label_key = "*.png", "*.png"
batch_size  = 1
n_epochs    = 5
model_type  = "vit_b_lm"     # or "vit_h_lm"
decoder     = True           # automatic instance segmentation
device      = "cuda" if torch.cuda.is_available() else "cpu"
image_dir = "full_data/images_uint8"
segmentation_dir = "full_data/masks"
checkpoint_name = "sam_hela"
# The 'roi' argument can be used to subselect parts of the data.
# Here, we use it to select the first 390 images (frames) for the train split and the other frames for the val split.
train_roi = np.s_[:390, :, :]
val_roi = np.s_[390:, :, :]
n_objects_per_batch = 5  # the number of objects per batch that will be sampled
# Train an additional convolutional decoder for end-to-end automatic instance segmentation
# NOTE 1: It's important to have densely annotated-labels while training the additional convolutional decoder.
# NOTE 2: In case you do not have labeled images, we recommend using `micro-sam` annotator tools to annotate as many objects as possible per image for best performance.
train_instance_segmentation = True

# NOTE: The dataloader internally takes care of adding label transforms: i.e. used to convert the ground-truth
# labels to the desired instances for finetuning Segment Anythhing, or, to learn the foreground and distances
# to the object centers and object boundaries for automatic segmentation.

# There are cases where our inputs are large and the labeled objects are not evenly distributed across the image.
# For this we use samplers, which ensure that valid inputs are chosen subjected to the paired labels.
# The sampler chosen below makes sure that the chosen inputs have atleast one foreground instance, and filters out small objects.
sampler = MinInstanceSampler(min_size=25)  # NOTE: The choice of 'min_size' value is paired with the same value in 'min_size' filter in 'label_transform'.

train_loader = sam_training.default_sam_loader(
    raw_paths=image_dir,
    raw_key=raw_key,
    label_paths=segmentation_dir,
    label_key=label_key,
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    rois=train_roi,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)

val_loader = sam_training.default_sam_loader(
    raw_paths=image_dir,
    raw_key=raw_key,
    label_paths=segmentation_dir,
    label_key=label_key,
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    rois=val_roi,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)

# Run training
sam_training.train_sam(
    name=checkpoint_name,
    save_root="models",
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)

Start fitting for 1990 iterations /  5 epochs
with 398 iterations per epoch
Training with mixed precision
Finished training after 5 epochs / 1990 iterations.
The best epoch is number 3.
Training took 2864.271598339081 seconds (= 00:47:44 hours)


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
Verifying labels in 'train' dataloader:   0%|          | 0/50 [00:00<?, ?it/s]Verifying labels in 'train' dataloader:   4%|▍         | 2/50 [00:00<00:03, 13.21it/s]Verifying labels in 'train' dataloader:   8%|▊         | 4/50 [00:00<00:03, 14.05it/s]Verifying labels in 'train' dataloader:  12%|█▏        | 6/50 [00:00<00:03, 13.68it/s]Verifying labels in 'train' dataloader:  16%|█▌        | 8/50 [00:00<00:03, 13.72it/s]Verifying labels in 'train' dataloader:  20%|██        | 10/50 [00:00<00:02, 13.81it/s]Verifying labels in 'train' dataloader:  24%|██▍       | 12/50 [00:00<00:02, 13.42it/s]Verifying labels in 'train' dataloader:  28%|██▊       | 14/50 [00:01<00:02, 13.22it/s]Verifying labels in 'train' dataloader:  32%|███▏      | 16/50 [00:01<00:02, 13.42it/s]Verifying labels in 'train' dataloader:  36%|███▌      | 18/50 [00:01<00:02, 14.08it/s]Verifying labels in 'train' dataloader:  40%|████      | 20/50 [00:01<00:02,