From 33239758f675c3ec93713ba9b2e95e3ba911c910 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 12 Apr 2024 09:19:04 +0200 Subject: [PATCH] Add ViM-UNet (#236) Add ViM-UNet implementation and experiments --- experiments/misc/get_vimunet_plots.py | 110 +++++++ experiments/vision-mamba/.gitignore | 3 + experiments/vision-mamba/vimunet/README.md | 31 ++ experiments/vision-mamba/vimunet/run_cremi.py | 176 ++++++++++ .../vision-mamba/vimunet/run_livecell.py | 234 ++++++++++++++ .../vision-transformer/unetr/README.md | 42 ++- .../unetr/cremi/cremi_unetr.py | 80 ----- .../for_vimunet_benchmarking/run_cremi.py | 197 ++++++++++++ .../for_vimunet_benchmarking/run_livecell.py | 253 +++++++++++++++ .../unetr/livecell/common.py | 12 +- .../unetr/livecell/train_livecell.py | 31 +- torch_em/data/datasets/neurips_cell_seg.py | 302 +++++++++--------- torch_em/model/__init__.py | 1 + torch_em/model/vim.py | 205 ++++++++++++ vimunet.md | 81 ++++- 15 files changed, 1494 insertions(+), 264 deletions(-) create mode 100644 experiments/misc/get_vimunet_plots.py create mode 100644 experiments/vision-mamba/.gitignore create mode 100644 experiments/vision-mamba/vimunet/README.md create mode 100644 experiments/vision-mamba/vimunet/run_cremi.py create mode 100644 experiments/vision-mamba/vimunet/run_livecell.py delete mode 100644 experiments/vision-transformer/unetr/cremi/cremi_unetr.py create mode 100644 experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_cremi.py create mode 100644 experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_livecell.py create mode 100644 torch_em/model/vim.py diff --git a/experiments/misc/get_vimunet_plots.py b/experiments/misc/get_vimunet_plots.py new file mode 100644 index 00000000..7f87a1ee --- /dev/null +++ b/experiments/misc/get_vimunet_plots.py @@ -0,0 +1,110 @@ +import numpy as np + +import matplotlib.pyplot as plt +from matplotlib.ticker import FormatStrFormatter + + +LIVECELL_RESULTS = { + "UNet": {"boundaries": 0.372, "distances": 0.429}, + r"UNETR$_{Base}$": {"boundaries": 0.11, "distances": 0.145}, + r"UNETR$_{Large}$": {"boundaries": 0.171, "distances": 0.157}, + r"UNETR$_{Huge}$": {"boundaries": 0.216, "distances": 0.136}, + r"nnUNet$_{v2}$": {"boundaries": 0.228}, + r"UMamba$_{Bot}$": {"boundaries": 0.234}, + r"UMamba$_{Enc}$": {"boundaries": 0.23}, + r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.269, "distances": 0.381}, + r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.274, "distances": 0.397}, +} + +CREMI_RESULTS = { + "UNet": {"boundaries": 0.354}, + r"UNETR$_{Base}$": {"boundaries": 0.285}, + r"UNETR$_{Large}$": {"boundaries": 0.325}, + r"UNETR$_{Huge}$": {"boundaries": 0.324}, + r"nnUNet$_{v2}$": {"boundaries": 0.452}, + r"UMamba$_{Bot}$": {"boundaries": 0.471}, + r"UMamba$_{Enc}$": {"boundaries": 0.467}, + r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.518}, + r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.53}, +} + +DATASET_MAPPING = { + "livecell": "LIVECell", + "cremi": "CREMI" +} + +plt.rcParams["font.size"] = 24 + + +def plot_per_dataset(dataset_name): + if dataset_name == "livecell": + results = LIVECELL_RESULTS + else: + results = CREMI_RESULTS + + models = list(results.keys()) + metrics = list(results[models[0]].keys()) + + markers = ['^', '*'] + + fig, ax = plt.subplots(figsize=(15, 12)) + + x_pos = np.arange(len(models)) + + bar_width = 0.05 + + for i, metric in enumerate(metrics): + scores_list = [] + for model in models: + try: + score = results[model][metric] + except KeyError: + score 
= None + + scores_list.append(score) + + ax.scatter(x_pos + i * bar_width - bar_width, scores_list, s=250, label=metric, marker=markers[i]) + + ax.set_xticks(x_pos) + ax.set_xticklabels(models, va='top', ha='center', rotation=45) + + ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) + if dataset_name == "cremi": + ax.set_yticks(np.linspace(0, 0.5, 11)[1:]) + else: + ax.set_yticks(np.linspace(0, 0.4, 9)[1:]) + + ax.set_ylabel('Segmentation Accuracy', labelpad=15) + ax.set_xlabel(None) + ax.set_title(DATASET_MAPPING[dataset_name], fontsize=32, y=1.025) + ax.set_ylim(0) + ax.legend(loc='lower center', fancybox=True, shadow=True, ncol=2) + + best_models = sorted(models, key=lambda x: max(results[x].values()), reverse=True)[:3] + sizes = [100, 70, 40] + for size, best_model in zip(sizes, best_models): + best_scores = [results[best_model].get(metric, 0) for metric in metrics] + best_index = models.index(best_model) + + # HACK + offset = 0 if dataset_name == "livecell" else 0.05 + + ax.plot( + best_index - offset, max(best_scores), marker='o', markersize=size, linestyle='dotted', + markerfacecolor='gray', markeredgecolor='black', markeredgewidth=2, alpha=0.2 + ) + + plt.tight_layout() + plt.show() + plt.savefig(f"{dataset_name}.png") + plt.savefig(f"{dataset_name}.svg", transparent=True) + plt.savefig(f"{dataset_name}.pdf") + + +def main(): + plot_per_dataset("livecell") + plot_per_dataset("cremi") + + +if __name__ == "__main__": + main() diff --git a/experiments/vision-mamba/.gitignore b/experiments/vision-mamba/.gitignore new file mode 100644 index 00000000..27d28c01 --- /dev/null +++ b/experiments/vision-mamba/.gitignore @@ -0,0 +1,3 @@ +*.out +*.sh +*.png \ No newline at end of file diff --git a/experiments/vision-mamba/vimunet/README.md b/experiments/vision-mamba/vimunet/README.md new file mode 100644 index 00000000..2fba3f4a --- /dev/null +++ b/experiments/vision-mamba/vimunet/README.md @@ -0,0 +1,31 @@ +# ViM-UNet: Vision Mamba in Biomedical Segmentation + +We introduce **ViM-UNet**. a novel segmentation architecture based on Vision Mamba for instance segmentation in microscopy. + +To get started, make sure to take a look at the [documentation](https://github.com/constantinpape/torch-em/blob/main/vimunet.md). + +Here are the experiments for instance segmentation on: +1. LIVECell for cell segmentation in phase-contrast microscopy. + - You can run the boundary-based /distance-based experiments. See `run_livecell.py -h` for details. + ```python + python run_livecell.py -i + -s + -m # the supported models are 'vim_t', 'vim_s' and 'vim_b' + --train # for training + --predict # for inference on trained models + --result_path + # below is how you can provide the choice for training for either methods + --boundaries / --distances + ``` + +2. CREMI for neurites segmentation in electron microscopy. + - You can run the boundary-based experiment. See `run_livecell.py -h` for details. 
Below is an example script: + ```python + python run_cremi.py -i + -s + -m # the supported models are 'vim_t', 'vim_s' and 'vim_b' + --train # for training + --predict # for inference on trained models + --result_path + ``` + diff --git a/experiments/vision-mamba/vimunet/run_cremi.py b/experiments/vision-mamba/vimunet/run_cremi.py new file mode 100644 index 00000000..880e5d3f --- /dev/null +++ b/experiments/vision-mamba/vimunet/run_cremi.py @@ -0,0 +1,176 @@ +import os +import argparse +import numpy as np +import pandas as pd +from glob import glob +from tqdm import tqdm + +import imageio.v3 as imageio + +import torch + +import torch_em +from torch_em.loss import DiceLoss +from torch_em.util import segmentation +from torch_em.data import MinInstanceSampler +from torch_em.model import get_vimunet_model +from torch_em.data.datasets import get_cremi_loader +from torch_em.util.prediction import predict_with_halo + +from elf.evaluation import mean_segmentation_accuracy + + +ROOT = "/scratch/usr/nimanwai" + +# the splits have been customed made +# to reproduce the results: +# extract slices ranging from "100 to 125" for all three volumes +CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original" + + +def get_loaders(args, patch_shape=(1, 512, 512)): + train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} + val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} + + sampler = MinInstanceSampler() + + train_loader = get_cremi_loader( + path=args.input, + patch_shape=patch_shape, + batch_size=2, + rois=train_rois, + sampler=sampler, + ndim=2, + label_dtype=torch.float32, + defect_augmentation_kwargs=None, + boundaries=True, + num_workers=16, + download=True, + ) + val_loader = get_cremi_loader( + path=args.input, + patch_shape=patch_shape, + batch_size=1, + rois=val_rois, + sampler=sampler, + ndim=2, + label_dtype=torch.float32, + defect_augmentation_kwargs=None, + boundaries=True, + num_workers=16, + download=True, + ) + return train_loader, val_loader + + +def run_cremi_training(args): + # the dataloaders for cremi dataset + train_loader, val_loader = get_loaders(args) + + # the vision-mamba + decoder (UNet-based) model + model = get_vimunet_model( + out_channels=1, + model_type=args.model_type, + with_cls_token=True + ) + + save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type) + + # loss function + loss = DiceLoss() + + # trainer for the segmentation task + trainer = torch_em.default_segmentation_trainer( + name="cremi-vimunet", + model=model, + train_loader=train_loader, + val_loader=val_loader, + learning_rate=1e-4, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10} + ) + trainer.fit(iterations=int(1e5)) + + +def run_cremi_inference(args, device): + save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type) + checkpoint = os.path.join(save_root, "checkpoints", "cremi-vimunet", "best.pt") + + # the vision-mamba + decoder (UNet-based) model + model = get_vimunet_model( + out_channels=1, + model_type=args.model_type, + with_cls_token=True, + checkpoint=checkpoint + ) + + all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif")) + all_test_labels = glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif")) + + msa_list, sa50_list, sa75_list = [], [], [] + for image_path, label_path in tqdm(zip(all_test_images, 
all_test_labels), total=len(all_test_images)): + image = imageio.imread(image_path) + labels = imageio.imread(label_path) + + predictions = predict_with_halo( + image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True, + ) + + bd = predictions.squeeze() + instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) + + msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) + msa_list.append(msa) + sa50_list.append(sa_acc[0]) + sa75_list.append(sa_acc[5]) + + res = { + "CREMI": "Metrics", + "mSA": np.mean(msa_list), + "SA50": np.mean(sa50_list), + "SA75": np.mean(sa75_list) + } + res_path = os.path.join(args.result_path, "results.csv") + df = pd.DataFrame.from_dict([res]) + df.to_csv(res_path) + print(df) + print(f"The result is saved at {res_path}") + + +def main(args): + print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") + device = "cuda" if torch.cuda.is_available() else "cpu" + + if args.train: + run_cremi_training(args) + + if args.predict: + run_cremi_inference(args, device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi"), help="Path to CREMI dataset." + ) + parser.add_argument( + "-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved." + ) + parser.add_argument( + "-m", "--model_type", type=str, default="vim_t", help="Choice of ViM backbone" + ) + parser.add_argument( + "--train", action="store_true", help="Whether to train the model." + ) + parser.add_argument( + "--predict", action="store_true", help="Whether to run inference on the trained model." + ) + parser.add_argument( + "--result_path", type=str, default="./", help="Path to save quantitative results." 
+ ) + args = parser.parse_args() + main(args) diff --git a/experiments/vision-mamba/vimunet/run_livecell.py b/experiments/vision-mamba/vimunet/run_livecell.py new file mode 100644 index 00000000..4c922a57 --- /dev/null +++ b/experiments/vision-mamba/vimunet/run_livecell.py @@ -0,0 +1,234 @@ +import os +import argparse +import numpy as np +import pandas as pd +from glob import glob +from tqdm import tqdm + +import imageio.v3 as imageio + +import torch + +import torch_em +from torch_em.util import segmentation +from torch_em.model import get_vimunet_model +from torch_em.transform.raw import standardize +from torch_em.data.datasets import get_livecell_loader +from torch_em.loss import DiceLoss, DiceBasedDistanceLoss + +from elf.evaluation import mean_segmentation_accuracy + + +ROOT = "/scratch/usr/nimanwai" + + +def get_loaders(args, patch_shape=(512, 512)): + if args.distances: + label_trafo = torch_em.transform.label.PerObjectDistanceTransform( + distances=True, + boundary_distances=True, + directed_distances=False, + foreground=True, + min_size=25 + ) + else: + label_trafo = None + + train_loader = get_livecell_loader( + path=args.input, + split="train", + patch_shape=patch_shape, + batch_size=2, + label_dtype=torch.float32, + boundaries=args.boundaries, + label_transform=label_trafo, + num_workers=16, + download=True, + ) + val_loader = get_livecell_loader( + path=args.input, + split="val", + patch_shape=patch_shape, + batch_size=1, + label_dtype=torch.float32, + boundaries=args.boundaries, + label_transform=label_trafo, + num_workers=16, + download=True, + ) + return train_loader, val_loader + + +def get_output_channels(args): + if args.boundaries: + output_channels = 2 + else: + output_channels = 3 + + return output_channels + + +def get_loss_function(args): + if args.distances: + loss = DiceBasedDistanceLoss(mask_distances_in_bg=True) + + else: + loss = DiceLoss() + + return loss + + +def get_save_root(args): + # experiment_type + if args.boundaries: + experiment_type = "boundaries" + else: + experiment_type = "distances" + + model_name = args.model_type + + # saving the model checkpoints + save_root = os.path.join(args.save_root, "scratch", experiment_type, model_name) + return save_root + + +def run_livecell_training(args): + # the dataloaders for livecell dataset + train_loader, val_loader = get_loaders(args) + + output_channels = get_output_channels(args) + + # the vision-mamba + decoder (UNet-based) model + model = get_vimunet_model( + out_channels=output_channels, + model_type=args.model_type, + with_cls_token=True, + ) + + save_root = get_save_root(args) + + # loss function + loss = get_loss_function(args) + + # trainer for the segmentation task + trainer = torch_em.default_segmentation_trainer( + name="livecell-vimunet", + model=model, + train_loader=train_loader, + val_loader=val_loader, + learning_rate=1e-4, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10} + ) + trainer.fit(iterations=int(1e5)) + + +def run_livecell_inference(args, device): + output_channels = get_output_channels(args) + + save_root = get_save_root(args) + + checkpoint = os.path.join(save_root, "checkpoints", "livecell-vimunet", "best.pt") + + # the vision-mamba + decoder (UNet-based) model + model = get_vimunet_model( + out_channels=output_channels, + model_type=args.model_type, + with_cls_token=True, + checkpoint=checkpoint, + ) + + # the splits are provided with the livecell dataset + # to 
reproduce the results: + # run the inference on the entire test datasets as it is. + test_image_dir = os.path.join(ROOT, "data", "livecell", "images", "livecell_test_images") + all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*")) + + msa_list, sa50_list, sa75_list = [], [], [] + for label_path in tqdm(all_test_labels): + labels = imageio.imread(label_path) + image_id = os.path.split(label_path)[-1] + + image = imageio.imread(os.path.join(test_image_dir, image_id)) + image = standardize(image) + + tensor_image = torch.from_numpy(image)[None, None].to(device) + + predictions = model(tensor_image) + predictions = predictions.squeeze().detach().cpu().numpy() + + if args.boundaries: + fg, bd = predictions + instances = segmentation.watershed_from_components(bd, fg) + else: + fg, cdist, bdist = predictions + instances = segmentation.watershed_from_center_and_boundary_distances( + cdist, bdist, fg, min_size=50, + center_distance_threshold=0.5, + boundary_distance_threshold=0.6, + distance_smoothing=1.0 + ) + + msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) + msa_list.append(msa) + sa50_list.append(sa_acc[0]) + sa75_list.append(sa_acc[5]) + + res = { + "LIVECell": "Metrics", + "mSA": np.mean(msa_list), + "SA50": np.mean(sa50_list), + "SA75": np.mean(sa75_list) + } + res_path = os.path.join(args.result_path, "results.csv") + df = pd.DataFrame.from_dict([res]) + df.to_csv(res_path) + print(df) + print(f"The result is saved at {res_path}") + + +def main(args): + assert (args.boundaries + args.distances) == 1, "Choose only one of boundaries / distances to run." + + print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.train: + run_livecell_training(args) + + if args.predict: + run_livecell_inference(args, device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input", type=str, default=os.path.join(ROOT, "data", "livecell"), help="Path to LIVECell dataset." + ) + parser.add_argument( + "-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved." + ) + parser.add_argument( + "-m", "--model_type", type=str, default="vim_t", help="Choice of ViM backbone." + ) + parser.add_argument( + "--train", action="store_true", help="Whether to train the model." + ) + parser.add_argument( + "--predict", action="store_true", help="Whether to run inference on the trained model." + ) + parser.add_argument( + "--result_path", type=str, default="./", help="Path to save quantitative results." + ) + parser.add_argument( + "--boundaries", action="store_true", help="Runs the boundary-based methods." + ) + parser.add_argument( + "--distances", action="store_true", help="Runs the distance-based methods." 
+ ) + args = parser.parse_args() + main(args) diff --git a/experiments/vision-transformer/unetr/README.md b/experiments/vision-transformer/unetr/README.md index 43a6d7ca..70577b3a 100644 --- a/experiments/vision-transformer/unetr/README.md +++ b/experiments/vision-transformer/unetr/README.md @@ -1,19 +1,39 @@ -## SAM's ViT Initialization in UNETR +## UNETR +## Integrating SegmentAnything's Vision Transformer -Note: -- `model_type` - [`vit_b`/`vit_l`/`vit_h`] -- `out_channels` - Number of output channels -- `encoder_checkpoint_path` - Pass the checkpoints from the pretrained [Segment Anything](https://github.com/facebookresearch/segment-anything) models to initialize the SAM weights to the (ViT) encoder backbone (Click on the model names to download them - [ViT-b](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) / [ViT-l](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) / [ViT-h](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)) +The UNETR is implemented by adapting the vision transformer from Segment Anything for biomedical image segmentation. +Key Mentions: +- It's expected to install [SegmentAnything](https://github.com/facebookresearch/segment-anything) for this. +- The supported models are ViT Base, ViT Large and ViT Huge. They are often abbreviated as: [`vit_b`/`vit_l`/`vit_h`] +- The advantage of using SegmentAnything's vision transformer is to enable loading the pretrained weights without any hassle. It's exposed in the `UNETR` class configuration under the argument name: `encoder_checkpoint_path` - You need to pass the checkpoints from the pretrained [SegmentAnything models](https://github.com/facebookresearch/segment-anything#model-checkpoints) to initialize the SAM weights to the (ViT) encoder backbone (click on the model names to download them - [vit_b](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth) / [vit_l](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth) / [vit_h](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)) -### How to initialize ViT models for UNETR? -``` + +### How to train UNETR from scratch? +```python from torch_em.model import UNETR -unetr = UNETR(encoder=model_type, out_channels=out_channels, encoder_checkpoint_path=checkpoint_path) +model = UNETR( + encoder=, # name of the vit backbone (see supported abbreviations above) + out_channels=, # number of output channels matching the segmentation targets +) ``` -### Vanilla ViT models for UNETR -``` +### How to train UNETR, initialized with pretrained SegmentAnything weights? +```python from torch_em.model import UNETR -unetr = UNETR(encoder=model_type, out_channels=out_channels) +unetr = UNETR( + encoder=, # name of the vit backbone (see supported abbreviations above) + out_channels=, # number of output channels matching the segmentation targets + encoder_checkpoint_path=, # path to the pretrained model weights + use_sam_stats=True # uses the image statistics from SA-1B dataset +) ``` + +## Description: +- `for_vimunet_benchmarking/`: (see [ViM-UNet description](https://github.com/constantinpape/torch-em/blob/main/vimunet.md) for details) + - `run_livecell.py`: Benchmarking UNet and UNETR for cell segmentation in phase contrast microscopy. + - `run_cremi.py`: Benchmarking UNet and UNETR for neurites segmentation in electron microscopy. + +### Additional Experiments: +- `dsb/`: Experiments on DSB data for segmentation of nuclei in light microscopy. 
+- `livecell/`: Experiments on LIVECell data for segmentation of cells in phase contrast microscopy. diff --git a/experiments/vision-transformer/unetr/cremi/cremi_unetr.py b/experiments/vision-transformer/unetr/cremi/cremi_unetr.py deleted file mode 100644 index c4696ee1..00000000 --- a/experiments/vision-transformer/unetr/cremi/cremi_unetr.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import argparse -import numpy as np - -import torch -import torch_em -from torch_em.model import UNETR -from torch_em.data.datasets import get_cremi_loader - - -def do_unetr_training(data_path: str, save_root: str, iterations: int, device, patch_shape=(1, 512, 512)): - os.makedirs(data_path, exist_ok=True) - - cremi_train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} - cremi_val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} - - train_loader = get_cremi_loader( - path=data_path, - patch_shape=patch_shape, download=True, - rois=cremi_train_rois, - ndim=2, - defect_augmentation_kwargs=None, - boundaries=True, - batch_size=2 - ) - - val_loader = get_cremi_loader( - path=data_path, - patch_shape=patch_shape, download=True, - rois=cremi_val_rois, - ndim=2, - defect_augmentation_kwargs=None, - boundaries=True, - batch_size=1 - ) - - model = UNETR( - encoder="vit_b", out_channels=1, - encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth") - model.to(device) - - trainer = torch_em.default_segmentation_trainer( - name="unetr-cremi", - model=model, - train_loader=train_loader, - val_loader=val_loader, - device=device, - learning_rate=1e-5, - log_image_interval=10, - save_root=save_root, - compile_model=False - ) - - trainer.fit(iterations) - - -def main(args): - print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - if args.train: - print("Training a 2D UNETR on Cremi dataset") - do_unetr_training( - data_path=args.inputs, - save_root=args.save_root, - iterations=args.iterations, - device=device - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--train", action='store_true', help="Enables UNETR training on Cremi dataset") - parser.add_argument("-i", "--inputs", type=str, default="./cremi/", - help="Path where the dataset already exists/will be downloaded by the dataloader") - parser.add_argument("-s", "--save_root", type=str, default=None, - help="Path where checkpoints and logs will be saved") - parser.add_argument("--iterations", type=int, default=100000, help="No. 
of iterations to run the training for") - args = parser.parse_args() - main(args) diff --git a/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_cremi.py b/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_cremi.py new file mode 100644 index 00000000..ceea2531 --- /dev/null +++ b/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_cremi.py @@ -0,0 +1,197 @@ +import os +import argparse +from glob import glob +from tqdm import tqdm + +import numpy as np +import pandas as pd +import imageio.v3 as imageio + +import torch + +import torch_em +from torch_em.loss import DiceLoss +from torch_em.util import segmentation +from torch_em.model import UNETR, UNet2d +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets import get_cremi_loader +from torch_em.model.unetr import SingleDeconv2DBlock +from torch_em.util.prediction import predict_with_halo + +from elf.evaluation import mean_segmentation_accuracy + + +ROOT = "/scratch/usr/nimanwai" + +# the splits have been customed made +# to reproduce the results: +# extract slices ranging from "100 to 125" for all three volumes +CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original" + + +def get_loaders(args, patch_shape=(1, 512, 512)): + train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} + val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} + + sampler = MinInstanceSampler() + + train_loader = get_cremi_loader( + path=args.input, + patch_shape=patch_shape, + batch_size=2, + rois=train_rois, + sampler=sampler, + ndim=2, + label_dtype=torch.float32, + defect_augmentation_kwargs=None, + boundaries=True, + num_workers=16, + download=True, + ) + val_loader = get_cremi_loader( + path=args.input, + patch_shape=patch_shape, + batch_size=1, + rois=val_rois, + sampler=sampler, + ndim=2, + label_dtype=torch.float32, + defect_augmentation_kwargs=None, + boundaries=True, + num_workers=16, + download=True, + ) + return train_loader, val_loader + + +def get_model(args, device): + if args.model_type == "unet": + # the UNet model + model = UNet2d( + in_channels=1, + out_channels=1, + initial_features=64, + final_activation="Sigmoid", + sampler_impl=SingleDeconv2DBlock, + ) + else: + # the UNETR model + model = UNETR( + encoder=args.model_type, + out_channels=1, + final_activation="Sigmoid", + use_skip_connection=False, + ) + model.to(device) + + return model + + +def run_cremi_training(args, device): + # the dataloaders for cremi dataset + train_loader, val_loader = get_loaders(args) + + model = get_model(args, device) + + save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type) + + # loss function + loss = DiceLoss() + + trainer = torch_em.default_segmentation_trainer( + name="cremi-unet" if args.model_type == "unet" else "cremi-unetr", + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + learning_rate=1e-4, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10} + ) + trainer.fit(int(1e5)) + + +def run_cremi_inference(args, device): + save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type) + checkpoint = os.path.join( + save_root, "checkpoints", "cremi-unet" if args.model_type == "unet" else "cremi-unetr", "best.pt" + ) + + model = get_model(args, device) + + assert os.path.exists(checkpoint), checkpoint + 
model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'))["model_state"]) + model.to(device) + model.eval() + + all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif")) + all_test_labels = glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif")) + + msa_list, sa50_list, sa75_list = [], [], [] + for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)): + image = imageio.imread(image_path) + labels = imageio.imread(label_path) + + predictions = predict_with_halo( + image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True, + ) + + bd = predictions.squeeze() + instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) + + msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) + msa_list.append(msa) + sa50_list.append(sa_acc[0]) + sa75_list.append(sa_acc[5]) + + res = { + "CREMI": "Metrics", + "mSA": np.mean(msa_list), + "SA50": np.mean(sa50_list), + "SA75": np.mean(sa75_list) + } + res_path = os.path.join(args.result_path, "results.csv") + df = pd.DataFrame.from_dict([res]) + df.to_csv(res_path) + print(df) + print(f"The result is saved at {res_path}") + + +def main(args): + print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.train: + run_cremi_training(args, device) + + if args.predict: + run_cremi_inference(args, device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi"), help="Path to CREMI dataset." + ) + parser.add_argument( + "-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved." + ) + parser.add_argument( + "-m", "--model_type", type=str, required=True, + help="Choice of encoder. Supported models are 'unet', 'vit_b', 'vit_l' and 'vit_h'." + ) + parser.add_argument( + "--train", action="store_true", help="Whether to train the model." + ) + parser.add_argument( + "--predict", action="store_true", help="Whether to run inference for the trained model." + ) + parser.add_argument( + "--result_path", type=str, default="./", help="Path to save quantitative results." 
+ ) + args = parser.parse_args() + main(args) diff --git a/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_livecell.py b/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_livecell.py new file mode 100644 index 00000000..e7f33dec --- /dev/null +++ b/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_livecell.py @@ -0,0 +1,253 @@ +import os +import argparse +from glob import glob +from tqdm import tqdm + +import numpy as np +import pandas as pd +import imageio.v3 as imageio + +import torch + +import torch_em +from torch_em.util import segmentation +from torch_em.model import UNETR, UNet2d +from torch_em.transform.raw import standardize +from torch_em.model.unetr import SingleDeconv2DBlock +from torch_em.data.datasets import get_livecell_loader +from torch_em.util.prediction import predict_with_padding +from torch_em.loss import DiceLoss, DiceBasedDistanceLoss + +from elf.evaluation import mean_segmentation_accuracy + + +ROOT = "/scratch/usr/nimanwai" + + +def get_loaders(args, patch_shape=(512, 512)): + if args.distances: + label_trafo = torch_em.transform.label.PerObjectDistanceTransform( + distances=True, + boundary_distances=True, + directed_distances=False, + foreground=True, + min_size=25 + ) + else: + label_trafo = None + + train_loader = get_livecell_loader( + path=args.input, + split="train", + patch_shape=patch_shape, + batch_size=2, + label_dtype=torch.float32, + boundaries=args.boundaries, + label_transform=label_trafo, + num_workers=16, + download=True, + ) + val_loader = get_livecell_loader( + path=args.input, + split="val", + patch_shape=patch_shape, + batch_size=1, + label_dtype=torch.float32, + boundaries=args.boundaries, + label_transform=label_trafo, + num_workers=16, + download=True, + ) + return train_loader, val_loader + + +def get_output_channels(args): + if args.boundaries: + output_channels = 2 + else: + output_channels = 3 + + return output_channels + + +def get_loss_function(args): + if args.distances: + loss = DiceBasedDistanceLoss(mask_distances_in_bg=True) + else: + loss = DiceLoss() + + return loss + + +def get_save_root(args): + # experiment_type + if args.boundaries: + experiment_type = "boundaries" + else: + experiment_type = "distances" + + # saving the model checkpoints + save_root = os.path.join(args.save_root, "scratch", experiment_type, args.model_type) + return save_root + + +def get_model(args, device): + output_channels = get_output_channels(args) + + if args.model_type == "unet": + # the UNet model + model = UNet2d( + in_channels=1, + out_channels=output_channels, + initial_features=64, + final_activation="Sigmoid", + sampler_impl=SingleDeconv2DBlock, + ) + else: + # the UNETR model + model = UNETR( + encoder=args.model_type, + out_channels=output_channels, + final_activation="Sigmoid", + use_skip_connection=False, + ) + model.to(device) + + return model + + +def run_livecell_unetr_training(args, device): + # the dataloaders for livecell dataset + train_loader, val_loader = get_loaders(args) + + model = get_model(args, device) + + save_root = get_save_root(args) + + # loss function + loss = get_loss_function(args) + + trainer = torch_em.default_segmentation_trainer( + name="livecell-unet" if args.model_type == "unet" else "livecell-unetr", + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + learning_rate=1e-4, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + scheduler_kwargs={"mode": "min", "factor": 0.9, 
"patience": 10} + ) + + trainer.fit(int(1e5)) + + +def run_livecell_unetr_inference(args, device): + save_root = get_save_root(args) + + checkpoint = os.path.join( + save_root, + "checkpoints", + "livecell-unet" if args.model_type == "unet" else "livecell-unetr", + "best.pt" + ) + + model = get_model(args, device) + + assert os.path.exists(checkpoint), checkpoint + model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'))["model_state"]) + model.to(device) + model.eval() + + # the splits are provided with the livecell dataset + # to reproduce the results: + # run the inference on the entire dataset as it is. + test_image_dir = os.path.join(ROOT, "data", "livecell", "images", "livecell_test_images") + all_test_labels = glob(os.path.join(ROOT, "data", "livecell", "annotations", "livecell_test_images", "*", "*")) + + msa_list, sa50_list, sa75_list = [], [], [] + for label_path in tqdm(all_test_labels): + labels = imageio.imread(label_path) + image_id = os.path.split(label_path)[-1] + + image = imageio.imread(os.path.join(test_image_dir, image_id)) + image = standardize(image) + + predictions = predict_with_padding(model, image, min_divisible=(16, 16), device=device) + predictions = predictions.squeeze() + + if args.boundaries: + fg, bd = predictions + instances = segmentation.watershed_from_components(bd, fg) + else: + fg, cdist, bdist = predictions + instances = segmentation.watershed_from_center_and_boundary_distances( + cdist, bdist, fg, min_size=50, + center_distance_threshold=0.5, + boundary_distance_threshold=0.6, + distance_smoothing=1.0 + ) + + msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) + msa_list.append(msa) + sa50_list.append(sa_acc[0]) + sa75_list.append(sa_acc[5]) + + res = { + "LIVECell": "Metrics", + "mSA": np.mean(msa_list), + "SA50": np.mean(sa50_list), + "SA75": np.mean(sa75_list) + } + res_path = os.path.join(args.result_path, "results.csv") + df = pd.DataFrame.from_dict([res]) + df.to_csv(res_path) + print(df) + print(f"The result is saved at {res_path}") + + +def main(args): + assert (args.boundaries + args.distances) == 1 + + print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.train: + run_livecell_unetr_training(args, device) + + if args.predict: + run_livecell_unetr_inference(args, device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", "--input", type=str, default=os.path.join(ROOT, "data", "livecell"), help="Path to LIVECell dataset." + ) + parser.add_argument( + "-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved." + ) + parser.add_argument( + "-m", "--model_type", type=str, required=True, + help="Choice of encoder. Supported models are 'unet', 'vit_b', 'vit_l' and 'vit_h'." + ) + parser.add_argument( + "--train", action="store_true", help="Whether to train the model." + ) + parser.add_argument( + "--predict", action="store_true", help="WWhether to rn inference on the trained model." + ) + parser.add_argument( + "--result_path", type=str, default="./", help="Path to save quantitative results." + ) + parser.add_argument( + "--boundaries", action="store_true", help="Runs the boundary-based methods." + ) + parser.add_argument( + "--distances", action="store_true", help="Runs the distance-based methods." 
+ ) + args = parser.parse_args() + main(args) diff --git a/experiments/vision-transformer/unetr/livecell/common.py b/experiments/vision-transformer/unetr/livecell/common.py index 567d474c..fca8b85e 100644 --- a/experiments/vision-transformer/unetr/livecell/common.py +++ b/experiments/vision-transformer/unetr/livecell/common.py @@ -147,15 +147,14 @@ def get_my_livecell_loaders( } -def get_unet_model(output_channels, use_conv_transpose): - from torch_em.model.unet import UNet2d, Upsampler2d - from torch_em.model.unetr import SingleDeconv2DBlock +def get_unet_model(output_channels): + from torch_em.model.unet import UNet2d + model = UNet2d( in_channels=1, out_channels=output_channels, initial_features=64, final_activation="Sigmoid", - sampler_impl=SingleDeconv2DBlock if use_conv_transpose else Upsampler2d ) return model @@ -166,7 +165,6 @@ def get_unetr_model( patch_shape: Tuple[int, int], sam_initialization: bool, output_channels: int, - use_conv_transpose: bool, backbone: str = "sam", ): """Returns the expected UNETR model @@ -181,7 +179,6 @@ def get_unetr_model( use_sam_stats=sam_initialization, final_activation="Sigmoid", encoder_checkpoint=MODELS[model_name] if sam_initialization else None, - use_conv_transpose=use_conv_transpose ) elif source_choice == "monai": @@ -468,9 +465,6 @@ def get_parser(): # this argument takes care of which ViT encoder to use for the UNETR (as ViTs from SAM and MAE are different) parser.add_argument("--pretrained_choice", type=str, default="sam") - parser.add_argument( - "--use_bilinear", action="store_true", help="Use bilinear interpolation for upsampling." - ) return parser diff --git a/experiments/vision-transformer/unetr/livecell/train_livecell.py b/experiments/vision-transformer/unetr/livecell/train_livecell.py index 9c6c9262..25795853 100644 --- a/experiments/vision-transformer/unetr/livecell/train_livecell.py +++ b/experiments/vision-transformer/unetr/livecell/train_livecell.py @@ -12,10 +12,6 @@ def main(args): patch_shape = tuple(args.patch_shape) # patch size used for training on livecell _name = args.model_name if not args.use_unet else "unet" - if args.use_bilinear: - _name += "-bilinear" - else: - _name += "-conv-transpose" # directory folder to save different parts of the scheme dir_structure = os.path.join( @@ -27,10 +23,7 @@ def main(args): loss = common.get_loss_function(args.experiment_name) if args.use_unet: - model = common.get_unet_model( - output_channels=common.get_output_channels(args.experiment_name), - use_conv_transpose=not args.use_bilinear - ) + model = common.get_unet_model(output_channels=common.get_output_channels(args.experiment_name)) _store_model_name = "unet" else: # get the unetr model for the training and inference on livecell dataset @@ -40,7 +33,6 @@ def main(args): patch_shape=patch_shape, sam_initialization=args.do_sam_ini, output_channels=common.get_output_channels(args.experiment_name), - use_conv_transpose=not args.use_bilinear ) _store_model_name = "unetr" model.to(device) @@ -64,16 +56,27 @@ def main(args): ) common.do_unetr_training( - train_loader=train_loader, val_loader=val_loader, model=model, loss=loss, - device=device, save_root=save_root, iterations=args.iterations, name=f"livecell-{_store_model_name}" + train_loader=train_loader, + val_loader=val_loader, + model=model, + loss=loss, + device=device, + save_root=save_root, + iterations=args.iterations, + name=f"livecell-{_store_model_name}", ) if args.predict: print(f"2d {_store_model_name.upper()} inference (with {args.experiment_name}) on LiveCELL...") 
common.do_unetr_inference( - input_path=args.input, device=device, model=model, save_root=save_root, - root_save_dir=root_save_dir, experiment_name=args.experiment_name, - input_norm=not args.do_sam_ini, name_extension=f"livecell-{_store_model_name}" + input_path=args.input, + device=device, + model=model, + save_root=save_root, + root_save_dir=root_save_dir, + experiment_name=args.experiment_name, + input_norm=not args.do_sam_ini, + name_extension=f"livecell-{_store_model_name}" ) print("Predictions are saved in", root_save_dir) diff --git a/torch_em/data/datasets/neurips_cell_seg.py b/torch_em/data/datasets/neurips_cell_seg.py index b8acb256..e3ae7694 100644 --- a/torch_em/data/datasets/neurips_cell_seg.py +++ b/torch_em/data/datasets/neurips_cell_seg.py @@ -1,31 +1,41 @@ -import json import os +import numpy as np from glob import glob +from typing import Union, Tuple, Any, Optional -import numpy as np import torch + import torch_em +from . import util +from .. import ImageCollectionDataset, RawImageCollectionDataset, ConcatDataset + + +URL = { + "train": "https://zenodo.org/records/10719375/files/Training-labeled.zip", + "val": "https://zenodo.org/records/10719375/files/Tuning.zip", + "test": "https://zenodo.org/records/10719375/files/Testing.zip", + "unlabeled": "https://zenodo.org/records/10719375/files/train-unlabeled-part1.zip", + "unlabeled_wsi": "https://zenodo.org/records/10719375/files/train-unlabeled-part2.zip" +} + +CHECKSUM = { + "train": "b2383929eb8e99b2716fa0d4e2f6e03983e626a57cf00fe85175869c54aa3592", + "val": "849423d36bb8fcc2d91a5b189a3b6d93c3d4071c9701eaaa44ba393a510459c4", + "test": "3379730221f43830d30fddf131750e967c9c9bdf04f98811e852a050eb659ccc", + "unlabeled": "390b38b398b05e9e5306a024a3bd48ab22e49592cfab3c1a119eab3636b38e0d", + "unlabeled_wsi": "d1e68eba2918305eab8b846e7578ac14683de970e3fa6a7c2a4a55753be56204" +} -"""TODO: refactor the loader based on the updated data structure -- Training - - images (multi-modal training inputs) - - labels - - unlabeled (WSI) -- Tuning - - images (multi-modal tuning inputs) - - labels -- Testing - - Public - - images (multi-modal testing inputs) - - labels - - WSI (whole-slide testing inputs) - - WSI-labels - - * (results from `osilab` - ranked 1st in the challenge) - - Hidden - - images (multi-modal hidden testing inputs - unlabeled) - - * (results from `osilab` - ranked 1st in the challenge) -""" -URL = "https://drive.google.com/drive/folders/1NFplvkQzc_nHFwpnB55lw2nD6coc91VV" + +DIR_NAMES = { + "train": "Training-labeled", "val": "Tuning", "test": "Testing/Public", + "unlabeled": "release-part1", "unlabeled_wsi": "train-unlabeled-part2" +} + +ZIP_PATH = { + "train": "Training-labeled.zip", "val": "Tuning.zip", "test": "Testing.zip", + "unlabeled": "train-unlabeled-part1.zip", "unlabeled_wsi": "train-unlabeled-part2.zip" +} def to_rgb(image): @@ -40,12 +50,21 @@ def to_rgb(image): return image -# would be better to make balanced splits for the different data modalities -# (but we would need to know mapping of images to modality) -def _get_image_and_label_paths(root, split, val_fraction): - path = os.path.join(root, "TrainLabeled") - assert os.path.exists(root), "Please download the dataset and assort the data as expected here.\ - See `get_neurips_cellseg_supervised_dataset`" +def _download_dataset(root, split, download): + os.makedirs(root, exist_ok=True) + + target_dir = os.path.join(root, DIR_NAMES[split]) + zip_path = os.path.join(root, ZIP_PATH[split]) + + if not os.path.exists(target_dir): + 
util.download_source(path=zip_path, url=URL[split], download=download, checksum=CHECKSUM[split]) + util.unzip(zip_path=zip_path, dst=root) + + return target_dir + + +def _get_image_and_label_paths(root, split, download): + path = _download_dataset(root, split, download) image_folder = os.path.join(path, "images") assert os.path.exists(image_folder) @@ -58,60 +77,29 @@ def _get_image_and_label_paths(root, split, val_fraction): all_label_paths.sort() assert len(all_image_paths) == len(all_label_paths) - if split is None: - return all_image_paths, all_label_paths - - split_file = os.path.join( - os.path.split(__file__)[0], f"split_{val_fraction}.json" - ) - - if os.path.exists(split_file): - with open(split_file) as f: - split_ids = json.load(f)[split] - else: - # split into training and val images - n_images = len(all_image_paths) - n_train = int((1.0 - val_fraction) * n_images) - image_ids = list(range(n_images)) - np.random.shuffle(image_ids) - train_ids, val_ids = image_ids[:n_train], image_ids[n_train:] - assert len(train_ids) + len(val_ids) == n_images - - with open(split_file, "w") as f: - json.dump({"train": train_ids, "val": val_ids}, f) - - split_ids = val_ids if split == "val" else train_ids - - image_paths = [all_image_paths[idx] for idx in split_ids] - label_paths = [all_label_paths[idx] for idx in split_ids] - assert len(image_paths) == len(label_paths) - return image_paths, label_paths + return all_image_paths, all_label_paths def get_neurips_cellseg_supervised_dataset( - root, split, patch_shape, - make_rgb=True, - label_transform=None, - label_transform2=None, - raw_transform=None, - transform=None, - label_dtype=torch.float32, - n_samples=None, - sampler=None, - val_fraction=0.1, + root: Union[str, os.PathLike], + split: str, + patch_shape: Tuple[int, int], + make_rgb: bool = True, + label_transform: Optional[Any] = None, + label_transform2: Optional[Any] = None, + raw_transform: Optional[Any] = None, + transform: Optional[Any] = None, + label_dtype: torch.dtype = torch.float32, + n_samples: Optional[int] = None, + sampler: Optional[Any] = None, + download: bool = False, ): """Dataset for the segmentation of cells in light microscopy. - This dataset is part of the NeuRIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/. - - NOTE: - - The dataset isn't available to download using an in-built functionality - - Please download the dataset from here:\ - https://drive.google.com/drive/folders/1NFplvkQzc_nHFwpnB55lw2nD6coc91VV - - REMEMBER: to convert the available data in the expected directory format + This dataset is part of the NeurIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/. 
""" - assert split in ("train", "val", None), split - image_paths, label_paths = _get_image_and_label_paths(root, split, val_fraction) + assert split in ("train", "val", "test"), split + image_paths, label_paths = _get_image_and_label_paths(root, split, download) if raw_transform is None: trafo = to_rgb if make_rgb else None @@ -119,54 +107,64 @@ def get_neurips_cellseg_supervised_dataset( if transform is None: transform = torch_em.transform.get_augmentations(ndim=2) - ds = torch_em.data.ImageCollectionDataset(image_paths, label_paths, - patch_shape=patch_shape, - raw_transform=raw_transform, - label_transform=label_transform, - label_transform2=label_transform2, - label_dtype=label_dtype, - transform=transform, - n_samples=n_samples, - sampler=sampler) + ds = ImageCollectionDataset( + raw_image_paths=image_paths, + label_image_paths=label_paths, + patch_shape=patch_shape, + raw_transform=raw_transform, + label_transform=label_transform, + label_transform2=label_transform2, + label_dtype=label_dtype, + transform=transform, + n_samples=n_samples, + sampler=sampler + ) return ds def get_neurips_cellseg_supervised_loader( - root, split, - patch_shape, batch_size, - make_rgb=True, - label_transform=None, - label_transform2=None, - raw_transform=None, - transform=None, - label_dtype=torch.float32, - n_samples=None, - sampler=None, - val_fraction=0.1, + root: Union[str, os.PathLike], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + make_rgb: bool = True, + label_transform: Optional[Any] = None, + label_transform2: Optional[Any] = None, + raw_transform: Optional[Any] = None, + transform: Optional[Any] = None, + label_dtype: torch.dtype = torch.float32, + n_samples: Optional[Any] = None, + sampler: Optional[Any] = None, + download: bool = False, **loader_kwargs ): """Dataloader for the segmentation of cells in light microscopy. 
See `get_neurips_cellseg_supervised_dataset`.""" ds = get_neurips_cellseg_supervised_dataset( - root, split, patch_shape, make_rgb=make_rgb, label_transform=label_transform, - label_transform2=label_transform2, raw_transform=raw_transform, transform=transform, - label_dtype=label_dtype, n_samples=n_samples, sampler=sampler, val_fraction=val_fraction, + root=root, + split=split, + patch_shape=patch_shape, + make_rgb=make_rgb, + label_transform=label_transform, + label_transform2=label_transform2, + raw_transform=raw_transform, + transform=transform, + label_dtype=label_dtype, + n_samples=n_samples, + sampler=sampler, + download=download ) return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs) -def _get_image_paths(root): - path = os.path.join(root, "TrainUnlabeled") - assert os.path.exists(path), "Please download the dataset and assort the data as expected here.\ - See `get_neurips_cellseg_unsupervised_dataset`" +def _get_image_paths(root, download): + path = _download_dataset(root, "unlabeled", download) image_paths = glob(os.path.join(path, "*")) image_paths.sort() return image_paths -def _get_wholeslide_paths(root, patch_shape): - path = os.path.join(root, "TrainUnlabeled_WholeSlide") - assert os.path.exists(path), "Please download the dataset and assort the data as expected here.\ - See `get_neurips_cellseg_unsupervised_dataset`" +def _get_wholeslide_paths(root, patch_shape, download): + path = _download_dataset(root, "unlabeled_wsi", download) image_paths = glob(os.path.join(path, "*")) image_paths.sort() @@ -185,24 +183,20 @@ def _get_wholeslide_paths(root, patch_shape): def get_neurips_cellseg_unsupervised_dataset( - root, patch_shape, - make_rgb=True, - raw_transform=None, - transform=None, - dtype=torch.float32, - sampler=None, - use_images=True, - use_wholeslide=True, + root: Union[str, os.PathLike], + patch_shape: Tuple[int, int], + make_rgb: bool = True, + raw_transform: Optional[Any] = None, + transform: Optional[Any] = None, + dtype: torch.dtype = torch.float32, + sampler: Optional[Any] = None, + use_images: bool = True, + use_wholeslide: bool = True, + download: bool = False, ): """Dataset for the segmentation of cells in light microscopy. - This dataset is part of the NeuRIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/. - - NOTE: - - The dataset isn't available to download using an in-built functionality - - Please download the dataset from here:\ - https://drive.google.com/drive/folders/1NFplvkQzc_nHFwpnB55lw2nD6coc91VV - - REMEMBER: to convert the available data in the expected directory format + This dataset is part of the NeurIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/. 
""" if raw_transform is None: trafo = to_rgb if make_rgb else None @@ -212,40 +206,52 @@ def get_neurips_cellseg_unsupervised_dataset( datasets = [] if use_images: - image_paths = _get_image_paths(root) - datasets.append(torch_em.data.RawImageCollectionDataset(image_paths, - patch_shape=patch_shape, - raw_transform=raw_transform, - transform=transform, - dtype=dtype, - sampler=sampler)) + image_paths = _get_image_paths(root, download) + datasets.append( + RawImageCollectionDataset( + raw_image_paths=image_paths, + patch_shape=patch_shape, + raw_transform=raw_transform, + transform=transform, + dtype=dtype, + sampler=sampler + ) + ) if use_wholeslide: - image_paths, n_samples = _get_wholeslide_paths(root, patch_shape) - datasets.append(torch_em.data.RawImageCollectionDataset(image_paths, - patch_shape=patch_shape, - raw_transform=raw_transform, - transform=transform, - dtype=dtype, - n_samples=n_samples, - sampler=sampler)) + image_paths, n_samples = _get_wholeslide_paths(root, patch_shape, download) + datasets.append( + RawImageCollectionDataset( + raw_image_paths=image_paths, + patch_shape=patch_shape, + raw_transform=raw_transform, + transform=transform, + dtype=dtype, + n_samples=n_samples, + sampler=sampler + ) + ) assert len(datasets) > 0 - return torch.utils.data.ConcatDataset(datasets) + return ConcatDataset(*datasets) def get_neurips_cellseg_unsupervised_loader( - root, patch_shape, batch_size, - make_rgb=True, - raw_transform=None, - transform=None, - dtype=torch.float32, - sampler=None, - use_images=True, - use_wholeslide=True, + root: Union[str, os.PathLike], + patch_shape: Tuple[int, int], + batch_size: int, + make_rgb: bool = True, + raw_transform: Optional[Any] = None, + transform: Optional[Any] = None, + dtype: torch.dtype = torch.float32, + sampler: Optional[Any] = None, + use_images: bool = True, + use_wholeslide: bool = True, + download: bool = False, **loader_kwargs, ): - """Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_unsupervised_dataset`.""" + """Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_unsupervised_dataset`. 
+ """ ds = get_neurips_cellseg_unsupervised_dataset( - root, patch_shape, make_rgb=make_rgb, raw_transform=raw_transform, transform=transform, - dtype=dtype, sampler=sampler, use_images=use_images, use_wholeslide=use_wholeslide + root=root, patch_shape=patch_shape, make_rgb=make_rgb, raw_transform=raw_transform, transform=transform, + dtype=dtype, sampler=sampler, use_images=use_images, use_wholeslide=use_wholeslide, download=download ) return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs) diff --git a/torch_em/model/__init__.py b/torch_em/model/__init__.py index 672abd95..d6dd8755 100644 --- a/torch_em/model/__init__.py +++ b/torch_em/model/__init__.py @@ -2,3 +2,4 @@ from .probabilistic_unet import ProbabilisticUNet from .unetr import UNETR from .vit import get_vision_transformer +from .vim import get_vimunet_model diff --git a/torch_em/model/vim.py b/torch_em/model/vim.py new file mode 100644 index 00000000..2cf97991 --- /dev/null +++ b/torch_em/model/vim.py @@ -0,0 +1,205 @@ +# installation from https://github.com/hustvl/Vim +# encoder from https://github.com/hustvl/Vim +# decoder from https://github.com/constantinpape/torch-em + +# pretrained model weights: vim_t - https://huggingface.co/hustvl/Vim-tiny/blob/main/vim_tiny_73p1.pth + +import torch + +from .unetr import UNETR + +try: + from vim.models_mamba import VisionMamba, rms_norm_fn, RMSNorm, layer_norm_fn + _have_vim_installed = True +except ImportError: + VisionMamba = object + rms_norm_fn = RMSNorm = layer_norm_fn = None + _have_vim_installed = False + +try: + from timm.models.vision_transformer import _cfg +except ImportError: + _cfg = None + + +class ViM(VisionMamba): + def __init__( + self, + **kwargs + ): + assert _have_vim_installed, "Please install Vim." + super().__init__(**kwargs) + + def convert_to_expected_dim(self, inputs_): + # reshape the outputs to desired shape (N x H*W X C -> N x H x W x C) + rdim = inputs_.shape[1] + dshape = int(rdim ** 0.5) # finding the square root of the outputs for obtaining the patch shape + inputs_ = torch.unflatten(inputs_, 1, (dshape, dshape)) + inputs_ = inputs_.permute(0, 3, 1, 2) + return inputs_ + + def forward_features(self, x, inference_params=None): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add the dist_token + x = self.patch_embed(x) + if self.if_cls_token: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + + if self.if_abs_pos_embed: + x = x + self.pos_embed + x = self.pos_drop(x) + + # mamba implementation + residual = None + hidden_states = x + for layer in self.layers: + # rope about + if self.if_rope: + hidden_states = self.rope(hidden_states) + if residual is not None and self.if_rope_residual: + residual = self.rope(residual) + + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + + if not self.fused_add_norm: + if residual is None: + residual = hidden_states + else: + residual = residual + self.drop_path(hidden_states) + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm = False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + self.drop_path(hidden_states), + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + 
prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + + if self.final_pool_type == 'none': + return hidden_states[:, -1, :] + elif self.final_pool_type == 'mean': + return hidden_states.mean(dim=1) + elif self.final_pool_type == 'max': + return hidden_states.max(dim=1) + elif self.final_pool_type == 'all': + return hidden_states + else: + raise NotImplementedError + + def forward(self, x, inference_params=None): + x = self.forward_features(x, inference_params) + + if self.if_cls_token: # remove the class token + x = x[:, 1:, :] + + # let's get the patches back from the 1d tokens + x = self.convert_to_expected_dim(x) + + return x # from here, the tokens can be upsampled easily (N x H x W x C) + + +def get_vim_encoder(model_type="vim_t", with_cls_token=True): + if model_type == "vim_t": + # `vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual_with_cls_token` + # *has an imagenet pretrained model + encoder = ViM( + img_size=1024, + patch_size=16, + embed_dim=192, + depth=24, + rms_norm=True, + residual_in_fp32=True, + fused_add_norm=True, + final_pool_type='all', + if_abs_pos_embed=True, + if_rope=True, + if_rope_residual=True, + bimamba_type="v2", + if_cls_token=with_cls_token, + ) + elif model_type == "vim_s": + # `vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual` + # AA: added a class token to the default models + encoder = ViM( + img_size=1024, + patch_size=16, + embed_dim=384, + depth=24, + rms_norm=True, + residual_in_fp32=True, + fused_add_norm=True, + final_pool_type='all', + if_abs_pos_embed=True, + if_rope=True, + if_rope_residual=True, + bimamba_type="v2", + if_cls_token=with_cls_token, + ) + elif model_type == "vim_b": + # `vim_base_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_rope_also_residual` + # AA: added a class token to the default models + encoder = ViM( + img_size=1024, + patch_size=16, + embed_dim=768, + depth=24, + rms_norm=True, + residual_in_fp32=True, + fused_add_norm=True, + final_pool_type='all', + if_abs_pos_embed=True, + if_rope=True, + if_rope_residual=True, + bimamba_type="v2", + if_cls_token=with_cls_token, + ) + else: + raise ValueError("Choose from 'vim_t' / 'vim_s' / 'vim_b'") + + encoder.default_cfg = _cfg() + return encoder + + +def get_vimunet_model( + out_channels, model_type="vim_t", with_cls_token=True, device=None, checkpoint=None +): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + encoder = get_vim_encoder(model_type, with_cls_token) + + model_state = None + if checkpoint is not None: + state = torch.load(checkpoint, map_location="cpu") + + if checkpoint.endswith(".pth"): # from Vim + encoder_state = state["model"] + encoder.load_state_dict(encoder_state) + + else: # from torch_em + model_state = state["model_state"] + + encoder.img_size = encoder.patch_embed.img_size[0] + + model = UNETR( + encoder=encoder, + out_channels=out_channels, + resize_input=False, + use_skip_connection=False, + final_activation="Sigmoid", + ) + + if model_state is not None: + model.load_state_dict(model_state) + + model.to(device) + + return model diff --git a/vimunet.md b/vimunet.md index 2e1b7354..86d03030 100644 --- a/vimunet.md +++ b/vimunet.md @@ -1,3 +1,80 @@ -# ViM-UNet: Vision Mamba-based UNet for Biomedical Image Segmentation +# ViM-UNet: Vision Mamba in Biomedical Segmentation* -TODO +We introduce **ViM-UNet**. a novel segmentation architecture based on Vision Mamba for instance segmentation in microscopy. 
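
For orientation, here is a minimal usage sketch (not part of this patch) for the `get_vimunet_model` helper added in `torch_em/model/vim.py` above. It assumes the ViM installation described below and a CUDA device for the mamba kernels; the dummy input and the output-shape comment are only indicative for the default 1024 px encoder resolution.

```python
import torch
from torch_em.model import get_vimunet_model

# build ViM-UNet with the tiny Vision Mamba encoder; "vim_s" would select the small one
model = get_vimunet_model(out_channels=1, model_type="vim_t")
model.eval()

# dummy 3-channel input at the encoder's native 1024 x 1024 resolution
# (the helper moves the model to the GPU if one is available; the mamba kernels need CUDA)
device = next(model.parameters()).device
x = torch.randn(1, 3, 1024, 1024, device=device)

with torch.no_grad():
    y = model(x)

print(tuple(y.shape))  # expected to be (1, 1, 1024, 1024) for this configuration
```

Pretrained ViM weights (e.g. the `vim_tiny` checkpoint linked at the top of the file) can be passed via the `checkpoint` argument, which the helper above handles for both Vim and torch_em checkpoints.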
+
+This document covers the installation instructions, known issues with suggested workarounds, the benchmarking scripts, and a link to the tutorial notebook.
+
+## TLDR
+1. Please install [`torch-em`](https://github.com/constantinpape/torch-em) and `ViM` (based on our fork: https://github.com/anwai98/Vim).
+2. ViM-UNet currently supports the `ViM Tiny` and `ViM Small` backbones for 2d segmentation.
+3. *More details coming soon in the preprint.
+    - Our observations: "ViM-UNet performs similarly to or better than UNet (depending on the task), and outperforms UNETR while being more efficient."
+
+## Benchmarking Methods
+
+### Re-implemented methods in `torch-em`:
+1. [ViM-UNet]()
+2. [UNet]()
+3. [UNETR]()
+
+### External methods:
+
+> [Here](https://github.com/anwai98/vimunet-benchmarking) are the scripts to run the benchmarking for the external methods listed below.
+
+1. nnU-Net (see [here](https://github.com/MIC-DKFZ/nnUNet) for installation instructions)
+2. U-Mamba (see [here](https://github.com/bowang-lab/U-Mamba#installation) for installation instructions, and the [issues]() we encountered, together with our suggestions for handling them)
+
+## Installation
+
+### For ViM-UNet:
+1. Create a new environment and activate it:
+```bash
+$ mamba create -n vimunet python=3.10.13
+$ mamba activate vimunet
+```
+2. Install `torch-em` [from source](https://github.com/constantinpape/torch-em#from-source).
+
+3. Install `PyTorch`:
+```bash
+$ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
+```
+> Q1. Why use `pip`? - for installation consistency.
+
+> Q2. Why choose CUDA 11.8? - Vim seems to prefer $\le$ 11.8 ([hint](https://github.com/hustvl/Vim/issues/51)).
+
+4. Install `ViM` and the related dependencies (`causal-conv1d`\**, `mamba`, `Vim`\***):
+```bash
+$ git clone https://github.com/anwai98/Vim.git
+$ cd Vim
+$ pip install -r vim/vim_requirements.txt
+$ pip install -e causal_conv1d/
+$ pip install -e mamba/
+$ pip install -e .
+```
+
+> NOTE: The installation can be a bit tricky, but following the steps above and keeping the footnotes below in mind should do the trick.
+
+### For UNet and UNETR:
+
+1. Install `torch-em` [from source](https://github.com/constantinpape/torch-em#from-source).
+2. Install `segment-anything` [from source](https://github.com/facebookresearch/segment-anything#installation).
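
Independent of which of the two setups you follow, a quick optional sanity check (an editor-added suggestion, not part of the original instructions) is to confirm that the key imports resolve and that a GPU is visible, since the mamba kernels require CUDA:

```bash
$ python -c "import torch_em; print('torch-em OK')"
$ python -c "from vim.models_mamba import VisionMamba; print('ViM OK')"  # only needed for the ViM-UNet setup
$ python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
```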
+
+
+## Known Issues and Suggestions
+
+- `GLIBCXX_`-related issues:
+    - Suggestion: Add the library path of your mamba environment to `LD_LIBRARY_PATH`. For example:
+    ```bash
+    $ export LD_LIBRARY_PATH=/scratch/usr/nimanwai/mambaforge/lib/
+    ```
+
+- `FileNotFoundError: [Error 2] No such file or directory: 'ldconfig'`:
+    - Suggestion: A likely reason is that the `PATH` variable isn't set correctly. We found this [thread](https://unix.stackexchange.com/questions/160019/dpkg-cannot-find-ldconfig-start-stop-daemon-in-the-path-variable) quite useful. For example:
+    ```bash
+    $ export PATH=$PATH:/usr/sbin  # it could also be located at /usr/bin, etc.; please check your system configuration
+    ```
+
+- **`NameError: name 'bare_metal_version' is not defined` while installing `causal-conv1d`:
+    - Suggestion: This one is a bit tricky. In our experience, the likely issue is that `CUDA_HOME` is not visible to the installed PyTorch. The quickest way to test this is: `python -c "from torch.utils.cpp_extension import CUDA_HOME; print(CUDA_HOME)"`. CUDA is often installed at `/usr/local/cuda`; to expose the path, run, for example: `export CUDA_HOME=/usr/local/cuda`.
+    > NOTE: If you are using your cluster's CUDA installation and are not sure where it is located, `module show cuda/$VERSION` should do the trick.
+
+- ***Remember to install `ViM` from the suggested branch. This is important because it enables a few changes: a) it installs Vision Mamba as a development module automatically, and b) it sets AMP to false to avoid known issues (see [mention 1](https://github.com/hustvl/Vim/issues/30) and [mention 2](https://github.com/bowang-lab/U-Mamba/issues/8) for hints).
\ No newline at end of file
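
As a closing illustration (not part of the patch itself), here is a rough sketch of how the refactored `get_neurips_cellseg_unsupervised_loader` from `torch_em/data/datasets/neurips_cell_seg.py` earlier in this patch might be called. The `root` path, the patch shape, and the `shuffle` keyword (forwarded through `**loader_kwargs` to the underlying torch DataLoader) are placeholder assumptions.

```python
import torch
from torch_em.data.datasets.neurips_cell_seg import get_neurips_cellseg_unsupervised_loader

loader = get_neurips_cellseg_unsupervised_loader(
    root="./data/neurips-cell-seg",  # hypothetical download/storage location
    patch_shape=(512, 512),
    batch_size=4,
    make_rgb=True,
    dtype=torch.float32,
    use_images=True,
    use_wholeslide=True,
    download=True,   # the flag added by this patch
    shuffle=True,    # forwarded via **loader_kwargs
)

# one batch of raw (unlabeled) patches
batch = next(iter(loader))
print(batch.shape)
```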