Add ViM-UNet (#236)
Add ViM-UNet implementation and experiments
anwai98 committed Apr 12, 2024
1 parent f91d20d commit 3323975
Showing 15 changed files with 1,494 additions and 264 deletions.
110 changes: 110 additions & 0 deletions experiments/misc/get_vimunet_plots.py
@@ -0,0 +1,110 @@
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter


LIVECELL_RESULTS = {
"UNet": {"boundaries": 0.372, "distances": 0.429},
r"UNETR$_{Base}$": {"boundaries": 0.11, "distances": 0.145},
r"UNETR$_{Large}$": {"boundaries": 0.171, "distances": 0.157},
r"UNETR$_{Huge}$": {"boundaries": 0.216, "distances": 0.136},
r"nnUNet$_{v2}$": {"boundaries": 0.228},
r"UMamba$_{Bot}$": {"boundaries": 0.234},
r"UMamba$_{Enc}$": {"boundaries": 0.23},
r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.269, "distances": 0.381},
r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.274, "distances": 0.397},
}

CREMI_RESULTS = {
"UNet": {"boundaries": 0.354},
r"UNETR$_{Base}$": {"boundaries": 0.285},
r"UNETR$_{Large}$": {"boundaries": 0.325},
r"UNETR$_{Huge}$": {"boundaries": 0.324},
r"nnUNet$_{v2}$": {"boundaries": 0.452},
r"UMamba$_{Bot}$": {"boundaries": 0.471},
r"UMamba$_{Enc}$": {"boundaries": 0.467},
r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.518},
r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.53},
}

DATASET_MAPPING = {
"livecell": "LIVECell",
"cremi": "CREMI"
}

plt.rcParams["font.size"] = 24


def plot_per_dataset(dataset_name):
if dataset_name == "livecell":
results = LIVECELL_RESULTS
else:
results = CREMI_RESULTS

models = list(results.keys())
metrics = list(results[models[0]].keys())

markers = ['^', '*']

fig, ax = plt.subplots(figsize=(15, 12))

x_pos = np.arange(len(models))

bar_width = 0.05

for i, metric in enumerate(metrics):
scores_list = []
for model in models:
try:
score = results[model][metric]
except KeyError:
score = None

scores_list.append(score)

ax.scatter(x_pos + i * bar_width - bar_width, scores_list, s=250, label=metric, marker=markers[i])

ax.set_xticks(x_pos)
ax.set_xticklabels(models, va='top', ha='center', rotation=45)

ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
if dataset_name == "cremi":
ax.set_yticks(np.linspace(0, 0.5, 11)[1:])
else:
ax.set_yticks(np.linspace(0, 0.4, 9)[1:])

ax.set_ylabel('Segmentation Accuracy', labelpad=15)
ax.set_xlabel(None)
ax.set_title(DATASET_MAPPING[dataset_name], fontsize=32, y=1.025)
ax.set_ylim(0)
ax.legend(loc='lower center', fancybox=True, shadow=True, ncol=2)

best_models = sorted(models, key=lambda x: max(results[x].values()), reverse=True)[:3]
sizes = [100, 70, 40]
for size, best_model in zip(sizes, best_models):
best_scores = [results[best_model].get(metric, 0) for metric in metrics]
best_index = models.index(best_model)

        # HACK: shift the highlight so it lines up with the scatter offset
        # (CREMI has a single metric, which is plotted at x - bar_width)
        offset = 0 if dataset_name == "livecell" else 0.05

ax.plot(
best_index - offset, max(best_scores), marker='o', markersize=size, linestyle='dotted',
markerfacecolor='gray', markeredgecolor='black', markeredgewidth=2, alpha=0.2
)

    plt.tight_layout()
    # save before showing: figures saved after plt.show() can end up blank
    plt.savefig(f"{dataset_name}.png")
    plt.savefig(f"{dataset_name}.svg", transparent=True)
    plt.savefig(f"{dataset_name}.pdf")
    plt.show()


def main():
plot_per_dataset("livecell")
plot_per_dataset("cremi")


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions experiments/vision-mamba/.gitignore
@@ -0,0 +1,3 @@
*.out
*.sh
*.png
31 changes: 31 additions & 0 deletions experiments/vision-mamba/vimunet/README.md
@@ -0,0 +1,31 @@
# ViM-UNet: Vision Mamba in Biomedical Segmentation

We introduce **ViM-UNet**, a novel segmentation architecture based on Vision Mamba for instance segmentation in microscopy.

To get started, make sure to take a look at the [documentation](https://github.com/constantinpape/torch-em/blob/main/vimunet.md).

Here are the experiments for instance segmentation on:
1. LIVECell for cell segmentation in phase-contrast microscopy.
- You can run the boundary-based or distance-based experiments. See `run_livecell.py -h` for details. The general usage is shown below, followed by a concrete example.
```bash
# the supported models are 'vim_t', 'vim_s' and 'vim_b';
# use --train for training and --predict for inference with trained models;
# choose the training target with either --boundaries or --distances
python run_livecell.py -i <PATH_TO_DATA> \
    -s <PATH_TO_SAVE_CHECKPOINTS> \
    -m <MODEL_NAME> \
    --train \
    --predict \
    --result_path <PATH_TO_SAVE_RESULTS> \
    --boundaries
```
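For example, a hypothetical run with the small backbone and distance-based targets could look like this (all paths are placeholders):
```bash
python run_livecell.py -i /path/to/livecell -s ./checkpoints -m vim_s \
    --train --predict --result_path ./results --distances
```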

2. CREMI for neurite segmentation in electron microscopy.
- You can run the boundary-based experiment. See `run_cremi.py -h` for details. The general usage is shown below, followed by a concrete example.
```bash
# the supported models are 'vim_t', 'vim_s' and 'vim_b';
# use --train for training and --predict for inference with trained models
python run_cremi.py -i <PATH_TO_DATA> \
    -s <PATH_TO_SAVE_CHECKPOINTS> \
    -m <MODEL_NAME> \
    --train \
    --predict \
    --result_path <PATH_TO_SAVE_RESULTS>
```
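For example, a hypothetical inference-only run with an already trained tiny model could look like this (all paths are placeholders):
```bash
python run_cremi.py -i /path/to/cremi -s ./checkpoints -m vim_t \
    --predict --result_path ./results
```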

176 changes: 176 additions & 0 deletions experiments/vision-mamba/vimunet/run_cremi.py
@@ -0,0 +1,176 @@
import os
import argparse
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

import imageio.v3 as imageio

import torch

import torch_em
from torch_em.loss import DiceLoss
from torch_em.util import segmentation
from torch_em.data import MinInstanceSampler
from torch_em.model import get_vimunet_model
from torch_em.data.datasets import get_cremi_loader
from torch_em.util.prediction import predict_with_halo

from elf.evaluation import mean_segmentation_accuracy


ROOT = "/scratch/usr/nimanwai"

# the test split has been custom-made: to reproduce the results,
# extract slices 100 to 125 from all three volumes
# (a hypothetical sketch for this is given in `_make_cremi_test_split` below)
CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original"
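

def _make_cremi_test_split(volume_dir, out_dir):
    """A hypothetical sketch, not part of the original setup, of how the custom
    test split could be created from the official CREMI training volumes.
    It assumes the official file names and hdf5 keys; adjust paths as needed.
    """
    import h5py  # assumed to be available; only needed for this sketch
    for sample in ["A", "B", "C"]:
        with h5py.File(os.path.join(volume_dir, f"sample_{sample}_20160501.hdf"), "r") as f:
            raw = f["volumes/raw"][100:125]
            labels = f["volumes/labels/neuron_ids"][100:125]
        for z in range(raw.shape[0]):
            imageio.imwrite(os.path.join(out_dir, "raw", f"cremi_test_{sample}_{z:02d}.tif"), raw[z])
            imageio.imwrite(os.path.join(out_dir, "labels", f"cremi_test_{sample}_{z:02d}.tif"), labels[z])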


def get_loaders(args, patch_shape=(1, 512, 512)):
train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]}
val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]}

sampler = MinInstanceSampler()

train_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=2,
rois=train_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=True,
num_workers=16,
download=True,
)
val_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=1,
rois=val_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=True,
num_workers=16,
download=True,
)
return train_loader, val_loader


def run_cremi_training(args):
# the dataloaders for cremi dataset
train_loader, val_loader = get_loaders(args)

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=1,
model_type=args.model_type,
with_cls_token=True
)

save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type)

# loss function
loss = DiceLoss()

# trainer for the segmentation task
trainer = torch_em.default_segmentation_trainer(
name="cremi-vimunet",
model=model,
train_loader=train_loader,
val_loader=val_loader,
learning_rate=1e-4,
loss=loss,
metric=loss,
log_image_interval=50,
save_root=save_root,
compile_model=False,
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10}
)
trainer.fit(iterations=int(1e5))


def run_cremi_inference(args, device):
save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type)
checkpoint = os.path.join(save_root, "checkpoints", "cremi-vimunet", "best.pt")

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=1,
model_type=args.model_type,
with_cls_token=True,
checkpoint=checkpoint
)

    # sort the file lists so that images and labels are paired correctly
    all_test_images = sorted(glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif")))
    all_test_labels = sorted(glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif")))

msa_list, sa50_list, sa75_list = [], [], []
for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)):
image = imageio.imread(image_path)
labels = imageio.imread(label_path)

predictions = predict_with_halo(
image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True,
)

        # derive an instance segmentation from the predicted boundaries via watershed
        bd = predictions.squeeze()
        instances = segmentation.watershed_from_components(bd, np.ones_like(bd))

        msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
        msa_list.append(msa)
        sa50_list.append(sa_acc[0])  # segmentation accuracy at IoU threshold 0.5
        sa75_list.append(sa_acc[5])  # segmentation accuracy at IoU threshold 0.75

res = {
"CREMI": "Metrics",
"mSA": np.mean(msa_list),
"SA50": np.mean(sa50_list),
"SA75": np.mean(sa75_list)
}
res_path = os.path.join(args.result_path, "results.csv")
df = pd.DataFrame.from_dict([res])
df.to_csv(res_path)
print(df)
print(f"The result is saved at {res_path}")


def main(args):
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU")
device = "cuda" if torch.cuda.is_available() else "cpu"

if args.train:
run_cremi_training(args)

if args.predict:
run_cremi_inference(args, device)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi"), help="Path to CREMI dataset."
)
parser.add_argument(
"-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved."
)
parser.add_argument(
"-m", "--model_type", type=str, default="vim_t", help="Choice of ViM backbone"
)
parser.add_argument(
"--train", action="store_true", help="Whether to train the model."
)
parser.add_argument(
"--predict", action="store_true", help="Whether to run inference on the trained model."
)
parser.add_argument(
"--result_path", type=str, default="./", help="Path to save quantitative results."
)
args = parser.parse_args()
main(args)