-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ViM-UNet implementation and experiments
- Loading branch information
Showing
15 changed files
with
1,494 additions
and
264 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import numpy as np | ||
|
||
import matplotlib.pyplot as plt | ||
from matplotlib.ticker import FormatStrFormatter | ||
|
||
|
||
LIVECELL_RESULTS = { | ||
"UNet": {"boundaries": 0.372, "distances": 0.429}, | ||
r"UNETR$_{Base}$": {"boundaries": 0.11, "distances": 0.145}, | ||
r"UNETR$_{Large}$": {"boundaries": 0.171, "distances": 0.157}, | ||
r"UNETR$_{Huge}$": {"boundaries": 0.216, "distances": 0.136}, | ||
r"nnUNet$_{v2}$": {"boundaries": 0.228}, | ||
r"UMamba$_{Bot}$": {"boundaries": 0.234}, | ||
r"UMamba$_{Enc}$": {"boundaries": 0.23}, | ||
r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.269, "distances": 0.381}, | ||
r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.274, "distances": 0.397}, | ||
} | ||
|
||
CREMI_RESULTS = { | ||
"UNet": {"boundaries": 0.354}, | ||
r"UNETR$_{Base}$": {"boundaries": 0.285}, | ||
r"UNETR$_{Large}$": {"boundaries": 0.325}, | ||
r"UNETR$_{Huge}$": {"boundaries": 0.324}, | ||
r"nnUNet$_{v2}$": {"boundaries": 0.452}, | ||
r"UMamba$_{Bot}$": {"boundaries": 0.471}, | ||
r"UMamba$_{Enc}$": {"boundaries": 0.467}, | ||
r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.518}, | ||
r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.53}, | ||
} | ||
|
||
DATASET_MAPPING = { | ||
"livecell": "LIVECell", | ||
"cremi": "CREMI" | ||
} | ||
|
||
plt.rcParams["font.size"] = 24 | ||
|
||
|
||
def plot_per_dataset(dataset_name): | ||
if dataset_name == "livecell": | ||
results = LIVECELL_RESULTS | ||
else: | ||
results = CREMI_RESULTS | ||
|
||
models = list(results.keys()) | ||
metrics = list(results[models[0]].keys()) | ||
|
||
markers = ['^', '*'] | ||
|
||
fig, ax = plt.subplots(figsize=(15, 12)) | ||
|
||
x_pos = np.arange(len(models)) | ||
|
||
bar_width = 0.05 | ||
|
||
for i, metric in enumerate(metrics): | ||
scores_list = [] | ||
for model in models: | ||
try: | ||
score = results[model][metric] | ||
except KeyError: | ||
score = None | ||
|
||
scores_list.append(score) | ||
|
||
ax.scatter(x_pos + i * bar_width - bar_width, scores_list, s=250, label=metric, marker=markers[i]) | ||
|
||
ax.set_xticks(x_pos) | ||
ax.set_xticklabels(models, va='top', ha='center', rotation=45) | ||
|
||
ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) | ||
if dataset_name == "cremi": | ||
ax.set_yticks(np.linspace(0, 0.5, 11)[1:]) | ||
else: | ||
ax.set_yticks(np.linspace(0, 0.4, 9)[1:]) | ||
|
||
ax.set_ylabel('Segmentation Accuracy', labelpad=15) | ||
ax.set_xlabel(None) | ||
ax.set_title(DATASET_MAPPING[dataset_name], fontsize=32, y=1.025) | ||
ax.set_ylim(0) | ||
ax.legend(loc='lower center', fancybox=True, shadow=True, ncol=2) | ||
|
||
best_models = sorted(models, key=lambda x: max(results[x].values()), reverse=True)[:3] | ||
sizes = [100, 70, 40] | ||
for size, best_model in zip(sizes, best_models): | ||
best_scores = [results[best_model].get(metric, 0) for metric in metrics] | ||
best_index = models.index(best_model) | ||
|
||
# HACK | ||
offset = 0 if dataset_name == "livecell" else 0.05 | ||
|
||
ax.plot( | ||
best_index - offset, max(best_scores), marker='o', markersize=size, linestyle='dotted', | ||
markerfacecolor='gray', markeredgecolor='black', markeredgewidth=2, alpha=0.2 | ||
) | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
plt.savefig(f"{dataset_name}.png") | ||
plt.savefig(f"{dataset_name}.svg", transparent=True) | ||
plt.savefig(f"{dataset_name}.pdf") | ||
|
||
|
||
def main(): | ||
plot_per_dataset("livecell") | ||
plot_per_dataset("cremi") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*.out | ||
*.sh | ||
*.png |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# ViM-UNet: Vision Mamba in Biomedical Segmentation | ||
|
||
We introduce **ViM-UNet**. a novel segmentation architecture based on Vision Mamba for instance segmentation in microscopy. | ||
|
||
To get started, make sure to take a look at the [documentation](https://github.com/constantinpape/torch-em/blob/main/vimunet.md). | ||
|
||
Here are the experiments for instance segmentation on: | ||
1. LIVECell for cell segmentation in phase-contrast microscopy. | ||
- You can run the boundary-based /distance-based experiments. See `run_livecell.py -h` for details. | ||
```python | ||
python run_livecell.py -i <PATH_TO_DATA> | ||
-s <PATH_TO_SAVE_CHECKPOINTS> | ||
-m <MODEL_NAME> # the supported models are 'vim_t', 'vim_s' and 'vim_b' | ||
--train # for training | ||
--predict # for inference on trained models | ||
--result_path <PATH_TO_SAVE_RESULTS> | ||
# below is how you can provide the choice for training for either methods | ||
--boundaries / --distances | ||
``` | ||
|
||
2. CREMI for neurites segmentation in electron microscopy. | ||
- You can run the boundary-based experiment. See `run_livecell.py -h` for details. Below is an example script: | ||
```python | ||
python run_cremi.py -i <PATH_TO_DATA> | ||
-s <PATH_TO_SAVE_CHECKPOINTS> | ||
-m <MODEL_NAME> # the supported models are 'vim_t', 'vim_s' and 'vim_b' | ||
--train # for training | ||
--predict # for inference on trained models | ||
--result_path <PATH_TO_SAVE_RESULTS> | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import os | ||
import argparse | ||
import numpy as np | ||
import pandas as pd | ||
from glob import glob | ||
from tqdm import tqdm | ||
|
||
import imageio.v3 as imageio | ||
|
||
import torch | ||
|
||
import torch_em | ||
from torch_em.loss import DiceLoss | ||
from torch_em.util import segmentation | ||
from torch_em.data import MinInstanceSampler | ||
from torch_em.model import get_vimunet_model | ||
from torch_em.data.datasets import get_cremi_loader | ||
from torch_em.util.prediction import predict_with_halo | ||
|
||
from elf.evaluation import mean_segmentation_accuracy | ||
|
||
|
||
ROOT = "/scratch/usr/nimanwai" | ||
|
||
# the splits have been customed made | ||
# to reproduce the results: | ||
# extract slices ranging from "100 to 125" for all three volumes | ||
CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original" | ||
|
||
|
||
def get_loaders(args, patch_shape=(1, 512, 512)): | ||
train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]} | ||
val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]} | ||
|
||
sampler = MinInstanceSampler() | ||
|
||
train_loader = get_cremi_loader( | ||
path=args.input, | ||
patch_shape=patch_shape, | ||
batch_size=2, | ||
rois=train_rois, | ||
sampler=sampler, | ||
ndim=2, | ||
label_dtype=torch.float32, | ||
defect_augmentation_kwargs=None, | ||
boundaries=True, | ||
num_workers=16, | ||
download=True, | ||
) | ||
val_loader = get_cremi_loader( | ||
path=args.input, | ||
patch_shape=patch_shape, | ||
batch_size=1, | ||
rois=val_rois, | ||
sampler=sampler, | ||
ndim=2, | ||
label_dtype=torch.float32, | ||
defect_augmentation_kwargs=None, | ||
boundaries=True, | ||
num_workers=16, | ||
download=True, | ||
) | ||
return train_loader, val_loader | ||
|
||
|
||
def run_cremi_training(args): | ||
# the dataloaders for cremi dataset | ||
train_loader, val_loader = get_loaders(args) | ||
|
||
# the vision-mamba + decoder (UNet-based) model | ||
model = get_vimunet_model( | ||
out_channels=1, | ||
model_type=args.model_type, | ||
with_cls_token=True | ||
) | ||
|
||
save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type) | ||
|
||
# loss function | ||
loss = DiceLoss() | ||
|
||
# trainer for the segmentation task | ||
trainer = torch_em.default_segmentation_trainer( | ||
name="cremi-vimunet", | ||
model=model, | ||
train_loader=train_loader, | ||
val_loader=val_loader, | ||
learning_rate=1e-4, | ||
loss=loss, | ||
metric=loss, | ||
log_image_interval=50, | ||
save_root=save_root, | ||
compile_model=False, | ||
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10} | ||
) | ||
trainer.fit(iterations=int(1e5)) | ||
|
||
|
||
def run_cremi_inference(args, device): | ||
save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type) | ||
checkpoint = os.path.join(save_root, "checkpoints", "cremi-vimunet", "best.pt") | ||
|
||
# the vision-mamba + decoder (UNet-based) model | ||
model = get_vimunet_model( | ||
out_channels=1, | ||
model_type=args.model_type, | ||
with_cls_token=True, | ||
checkpoint=checkpoint | ||
) | ||
|
||
all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif")) | ||
all_test_labels = glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif")) | ||
|
||
msa_list, sa50_list, sa75_list = [], [], [] | ||
for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)): | ||
image = imageio.imread(image_path) | ||
labels = imageio.imread(label_path) | ||
|
||
predictions = predict_with_halo( | ||
image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True, | ||
) | ||
|
||
bd = predictions.squeeze() | ||
instances = segmentation.watershed_from_components(bd, np.ones_like(bd)) | ||
|
||
msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) | ||
msa_list.append(msa) | ||
sa50_list.append(sa_acc[0]) | ||
sa75_list.append(sa_acc[5]) | ||
|
||
res = { | ||
"CREMI": "Metrics", | ||
"mSA": np.mean(msa_list), | ||
"SA50": np.mean(sa50_list), | ||
"SA75": np.mean(sa75_list) | ||
} | ||
res_path = os.path.join(args.result_path, "results.csv") | ||
df = pd.DataFrame.from_dict([res]) | ||
df.to_csv(res_path) | ||
print(df) | ||
print(f"The result is saved at {res_path}") | ||
|
||
|
||
def main(args): | ||
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU") | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
if args.train: | ||
run_cremi_training(args) | ||
|
||
if args.predict: | ||
run_cremi_inference(args, device) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi"), help="Path to CREMI dataset." | ||
) | ||
parser.add_argument( | ||
"-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved." | ||
) | ||
parser.add_argument( | ||
"-m", "--model_type", type=str, default="vim_t", help="Choice of ViM backbone" | ||
) | ||
parser.add_argument( | ||
"--train", action="store_true", help="Whether to train the model." | ||
) | ||
parser.add_argument( | ||
"--predict", action="store_true", help="Whether to run inference on the trained model." | ||
) | ||
parser.add_argument( | ||
"--result_path", type=str, default="./", help="Path to save quantitative results." | ||
) | ||
args = parser.parse_args() | ||
main(args) |
Oops, something went wrong.