
Add ViM-UNet #236

Merged: 55 commits merged into main from vision-mamba on Apr 12, 2024
Commits (55)
b003fab
Add vision mamba + decoder experiments
anwai98 Jan 23, 2024
b9b8d13
Fix checkpoint loading issue
anwai98 Jan 25, 2024
3d9bff8
Update repo installation
anwai98 Jan 25, 2024
31ba444
Update save_root
anwai98 Jan 25, 2024
ef05d9c
Add inference script
anwai98 Jan 25, 2024
e7bdae0
Update inference + add all instance segmentation setups
anwai98 Jan 26, 2024
9fbba4d
Refactor Vim backbone + Add more vision mamba configs
anwai98 Feb 3, 2024
0a04fd8
Add class token
anwai98 Feb 4, 2024
c28e8f8
Fix imports
anwai98 Feb 4, 2024
0342ac9
Fix more imports
anwai98 Feb 4, 2024
ddeba72
Fix vim imports
anwai98 Feb 4, 2024
ec10d30
Update result saving
anwai98 Feb 5, 2024
b812895
Fix missing updates
anwai98 Feb 5, 2024
57d0e7e
Add plotting script
anwai98 Feb 5, 2024
87d6234
Update installation
anwai98 Feb 5, 2024
8e32994
Add training for longer + use class tokens as default
anwai98 Feb 7, 2024
20f631e
Merge branch 'main' into vision-mamba
anwai98 Feb 8, 2024
9137191
Update vimunet training - with new unetr backbone
anwai98 Feb 8, 2024
0e0411a
Merge branch 'main' into vision-mamba
anwai98 Feb 10, 2024
73d6db0
Update livecell training
anwai98 Feb 10, 2024
d192446
Check result outputs
anwai98 Feb 11, 2024
1339cbc
Merge branch 'main' into vision-mamba
anwai98 Feb 15, 2024
a44a227
Add plotting for unetr
anwai98 Feb 16, 2024
55fba22
Merge branch 'main' into vision-mamba
anwai98 Mar 19, 2024
73f7021
Add CREMI experiments using ViMUNet (#12)
anwai98 Mar 30, 2024
fbb6f6c
Update NeurIPS CellSeg - to allow download from zenodo links (#14)
anwai98 Mar 30, 2024
83d70c1
Add NeurIPS CellSeg experiments for ViMUNet (#15)
anwai98 Mar 31, 2024
b4d3c81
Refactor LIVECell experiments for UNETR setup
anwai98 Mar 31, 2024
7c29db5
Add LIVECell benchmarking experiments (#16)
anwai98 Apr 1, 2024
1fd8b2d
Add scripts for training all benchmarking experiments (#17)
anwai98 Apr 1, 2024
0e009ef
Add LM training for ViMUNet (#18)
anwai98 Apr 4, 2024
4ad364f
Update inference scripts (#19)
anwai98 Apr 5, 2024
1c83e76
Expose learning rate parameter for vimunet experiments
anwai98 Apr 5, 2024
28f420d
Expose lr for vit and unet-based methods
anwai98 Apr 5, 2024
02566b2
Add neurips cellseg inference
anwai98 Apr 5, 2024
6a80cfe
Add resource efficient limited data training (#20)
anwai98 Apr 7, 2024
d198843
Update ViMUNet documentation
anwai98 Apr 7, 2024
9e949af
Update README.md
anwai98 Apr 7, 2024
3fc0996
Update README.md
anwai98 Apr 7, 2024
9a9038c
Minor fix to limited data training
anwai98 Apr 7, 2024
8bc515e
Update multicut for livecell
anwai98 Apr 7, 2024
f062e8b
Cleanup scripts
anwai98 Apr 11, 2024
34f8109
Cleanup code2
anwai98 Apr 11, 2024
926cd1c
Merge branch 'main' into vision-mamba
anwai98 Apr 11, 2024
33f96f6
Cleanup vision mamba directory
anwai98 Apr 11, 2024
63fbb53
Refactor scripts
anwai98 Apr 11, 2024
9d74ecd
Simplify scripts with helper arguments
anwai98 Apr 11, 2024
bdc975d
Final refactor for vimunet scripts
anwai98 Apr 11, 2024
e740893
Refactor unet and unetr benchmarking scripts
anwai98 Apr 11, 2024
d396f44
Update documentation for using unetr
anwai98 Apr 11, 2024
f4a52bb
Update vimunet documentation
anwai98 Apr 11, 2024
257e362
Update documentation
anwai98 Apr 11, 2024
7435cec
Fix pip install scripts
anwai98 Apr 11, 2024
9230d80
Minor doc fix
anwai98 Apr 11, 2024
4b1689b
Minor cuda-related mention
anwai98 Apr 11, 2024
110 changes: 110 additions & 0 deletions experiments/misc/get_vimunet_plots.py
@@ -0,0 +1,110 @@
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter


LIVECELL_RESULTS = {
"UNet": {"boundaries": 0.372, "distances": 0.429},
r"UNETR$_{Base}$": {"boundaries": 0.11, "distances": 0.145},
r"UNETR$_{Large}$": {"boundaries": 0.171, "distances": 0.157},
r"UNETR$_{Huge}$": {"boundaries": 0.216, "distances": 0.136},
r"nnUNet$_{v2}$": {"boundaries": 0.228},
r"UMamba$_{Bot}$": {"boundaries": 0.234},
r"UMamba$_{Enc}$": {"boundaries": 0.23},
r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.269, "distances": 0.381},
r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.274, "distances": 0.397},
}

CREMI_RESULTS = {
"UNet": {"boundaries": 0.354},
r"UNETR$_{Base}$": {"boundaries": 0.285},
r"UNETR$_{Large}$": {"boundaries": 0.325},
r"UNETR$_{Huge}$": {"boundaries": 0.324},
r"nnUNet$_{v2}$": {"boundaries": 0.452},
r"UMamba$_{Bot}$": {"boundaries": 0.471},
r"UMamba$_{Enc}$": {"boundaries": 0.467},
r"$\bf{ViMUNet}$$_{Tiny}$": {"boundaries": 0.518},
r"$\bf{ViMUNet}$$_{Small}$": {"boundaries": 0.53},
}

DATASET_MAPPING = {
"livecell": "LIVECell",
"cremi": "CREMI"
}

plt.rcParams["font.size"] = 24


def plot_per_dataset(dataset_name):
if dataset_name == "livecell":
results = LIVECELL_RESULTS
else:
results = CREMI_RESULTS

models = list(results.keys())
metrics = list(results[models[0]].keys())

    markers = ['^', '*']  # one marker per metric: boundaries, distances

fig, ax = plt.subplots(figsize=(15, 12))

x_pos = np.arange(len(models))

bar_width = 0.05

for i, metric in enumerate(metrics):
        # models without a score for this metric get None, which matplotlib plots as a gap
        scores_list = [results[model].get(metric) for model in models]

ax.scatter(x_pos + i * bar_width - bar_width, scores_list, s=250, label=metric, marker=markers[i])

ax.set_xticks(x_pos)
ax.set_xticklabels(models, va='top', ha='center', rotation=45)

ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
if dataset_name == "cremi":
ax.set_yticks(np.linspace(0, 0.5, 11)[1:])
else:
ax.set_yticks(np.linspace(0, 0.4, 9)[1:])

ax.set_ylabel('Segmentation Accuracy', labelpad=15)
ax.set_xlabel(None)
ax.set_title(DATASET_MAPPING[dataset_name], fontsize=32, y=1.025)
ax.set_ylim(0)
ax.legend(loc='lower center', fancybox=True, shadow=True, ncol=2)

    # highlight the three best-scoring models with gray circles of decreasing size
    best_models = sorted(models, key=lambda x: max(results[x].values()), reverse=True)[:3]
    sizes = [100, 70, 40]
    for size, best_model in zip(sizes, best_models):
        best_scores = [results[best_model].get(metric, 0) for metric in metrics]
        best_index = models.index(best_model)

        # HACK: nudge the highlight slightly to the left for CREMI so it lines up with the markers
        offset = 0 if dataset_name == "livecell" else 0.05

        ax.plot(
            best_index - offset, max(best_scores), marker='o', markersize=size, linestyle='dotted',
            markerfacecolor='gray', markeredgecolor='black', markeredgewidth=2, alpha=0.2
        )

    plt.tight_layout()
    # save before showing: with GUI backends, `plt.show()` can close the figure,
    # which would leave the saved files blank
    plt.savefig(f"{dataset_name}.png")
    plt.savefig(f"{dataset_name}.svg", transparent=True)
    plt.savefig(f"{dataset_name}.pdf")
    plt.show()


def main():
plot_per_dataset("livecell")
plot_per_dataset("cremi")


if __name__ == "__main__":
main()
3 changes: 3 additions & 0 deletions experiments/vision-mamba/.gitignore
@@ -0,0 +1,3 @@
*.out
*.sh
*.png
31 changes: 31 additions & 0 deletions experiments/vision-mamba/vimunet/README.md
@@ -0,0 +1,31 @@
# ViM-UNet: Vision Mamba in Biomedical Segmentation

We introduce **ViM-UNet**, a novel segmentation architecture based on Vision Mamba for instance segmentation in microscopy.

To get started, make sure to take a look at the [documentation](https://github.com/constantinpape/torch-em/blob/main/vimunet.md).
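
As a quick orientation, the model itself can be constructed via `torch_em.model.get_vimunet_model`. Below is a minimal sketch based on how the experiment scripts in this PR use that function; the single-channel input and patch size are assumptions matching the setups here:

```python
import torch
from torch_em.model import get_vimunet_model

# ViM-UNet with the tiny Vision Mamba backbone ('vim_s' and 'vim_b' are also supported);
# out_channels=1 matches the boundary-based setup used in the scripts below
model = get_vimunet_model(out_channels=1, model_type="vim_t", with_cls_token=True)

# NOTE: the Vision Mamba backbone requires a CUDA device in practice
if torch.cuda.is_available():
    model = model.to("cuda")
    x = torch.randn(1, 1, 512, 512, device="cuda")  # a single-channel 512x512 patch
    print(model(x).shape)
```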

Here are the experiments for instance segmentation on:
1. LIVECell for cell segmentation in phase-contrast microscopy.
- You can run the boundary-based or the distance-based experiments. See `run_livecell.py -h` for details.
```bash
python run_livecell.py -i <PATH_TO_DATA>
-s <PATH_TO_SAVE_CHECKPOINTS>
-m <MODEL_NAME> # the supported models are 'vim_t', 'vim_s' and 'vim_b'
--train # for training
--predict # for inference on trained models
--result_path <PATH_TO_SAVE_RESULTS>
# choose exactly one of the two training targets
--boundaries / --distances
```

2. CREMI for neurite segmentation in electron microscopy.
- You can run the boundary-based experiment. See `run_cremi.py -h` for details. Below is an example invocation:
```bash
python run_cremi.py -i <PATH_TO_DATA>
-s <PATH_TO_SAVE_CHECKPOINTS>
-m <MODEL_NAME> # the supported models are 'vim_t', 'vim_s' and 'vim_b'
--train # for training
--predict # for inference on trained models
--result_path <PATH_TO_SAVE_RESULTS>
```
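
Trained checkpoints are stored under the `--save_root` directory. As a rough sketch (following `run_cremi.py`; the checkpoint path and image file are placeholders), a trained model can be loaded back for custom tiled inference:

```python
import imageio.v3 as imageio
from torch_em.model import get_vimunet_model
from torch_em.util.prediction import predict_with_halo

# hypothetical paths: adjust to your own save_root and data layout
checkpoint = "checkpoints/cremi-vimunet/best.pt"
model = get_vimunet_model(out_channels=1, model_type="vim_t", with_cls_token=True, checkpoint=checkpoint)

image = imageio.imread("my_volume_slice.tif")
# tiled prediction with an overlapping halo to avoid artifacts at block borders
prediction = predict_with_halo(image, model, ["cuda"], block_shape=[512, 512], halo=[128, 128])
```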

176 changes: 176 additions & 0 deletions experiments/vision-mamba/vimunet/run_cremi.py
@@ -0,0 +1,176 @@
import os
import argparse
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm

import imageio.v3 as imageio

import torch

import torch_em
from torch_em.loss import DiceLoss
from torch_em.util import segmentation
from torch_em.data import MinInstanceSampler
from torch_em.model import get_vimunet_model
from torch_em.data.datasets import get_cremi_loader
from torch_em.util.prediction import predict_with_halo

from elf.evaluation import mean_segmentation_accuracy


ROOT = "/scratch/usr/nimanwai"

# the test split is custom-made; to reproduce the results,
# extract slices 100 to 125 from all three CREMI volumes
CREMI_TEST_ROOT = "/scratch/projects/nim00007/sam/data/cremi/slices_original"
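

def _extract_cremi_test_slices(volume_path, sample_name):
    """Hypothetical helper (not used by this script): sketches how the custom
    test split above could be created, assuming the standard CREMI HDF5 keys.
    """
    import h5py
    with h5py.File(volume_path, "r") as f:
        raw = f["volumes/raw"][100:125]
        labels = f["volumes/labels/neuron_ids"][100:125]
    for i, (raw_slice, label_slice) in enumerate(zip(raw, labels)):
        imageio.imwrite(os.path.join(CREMI_TEST_ROOT, "raw", f"cremi_test_{sample_name}_{i:02}.tif"), raw_slice)
        imageio.imwrite(os.path.join(CREMI_TEST_ROOT, "labels", f"cremi_test_{sample_name}_{i:02}.tif"), label_slice)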


def get_loaders(args, patch_shape=(1, 512, 512)):
train_rois = {"A": np.s_[0:75, :, :], "B": np.s_[0:75, :, :], "C": np.s_[0:75, :, :]}
val_rois = {"A": np.s_[75:100, :, :], "B": np.s_[75:100, :, :], "C": np.s_[75:100, :, :]}

sampler = MinInstanceSampler()

train_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=2,
rois=train_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=True,
num_workers=16,
download=True,
)
val_loader = get_cremi_loader(
path=args.input,
patch_shape=patch_shape,
batch_size=1,
rois=val_rois,
sampler=sampler,
ndim=2,
label_dtype=torch.float32,
defect_augmentation_kwargs=None,
boundaries=True,
num_workers=16,
download=True,
)
return train_loader, val_loader


def run_cremi_training(args):
# the dataloaders for cremi dataset
train_loader, val_loader = get_loaders(args)

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=1,
model_type=args.model_type,
with_cls_token=True
)

save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type)

# loss function
loss = DiceLoss()

# trainer for the segmentation task
trainer = torch_em.default_segmentation_trainer(
name="cremi-vimunet",
model=model,
train_loader=train_loader,
val_loader=val_loader,
learning_rate=1e-4,
loss=loss,
metric=loss,
log_image_interval=50,
save_root=save_root,
compile_model=False,
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10}
)
trainer.fit(iterations=int(1e5))


def run_cremi_inference(args, device):
save_root = os.path.join(args.save_root, "scratch", "boundaries", args.model_type)
checkpoint = os.path.join(save_root, "checkpoints", "cremi-vimunet", "best.pt")

# the vision-mamba + decoder (UNet-based) model
model = get_vimunet_model(
out_channels=1,
model_type=args.model_type,
with_cls_token=True,
checkpoint=checkpoint
)

all_test_images = glob(os.path.join(CREMI_TEST_ROOT, "raw", "cremi_test_*.tif"))
all_test_labels = glob(os.path.join(CREMI_TEST_ROOT, "labels", "cremi_test_*.tif"))

msa_list, sa50_list, sa75_list = [], [], []
for image_path, label_path in tqdm(zip(all_test_images, all_test_labels), total=len(all_test_images)):
image = imageio.imread(image_path)
labels = imageio.imread(label_path)

        # tiled prediction with an overlapping halo to avoid artifacts at the block borders
        predictions = predict_with_halo(
            image, model, [device], block_shape=[512, 512], halo=[128, 128], disable_tqdm=True,
        )

        # turn the predicted boundary map into an instance segmentation via watershed
        bd = predictions.squeeze()
        instances = segmentation.watershed_from_components(bd, np.ones_like(bd))

        # `sa_acc` holds segmentation accuracies at IoU thresholds 0.5, 0.55, ..., 0.95,
        # so index 0 corresponds to SA50 and index 5 to SA75
        msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
        msa_list.append(msa)
        sa50_list.append(sa_acc[0])
        sa75_list.append(sa_acc[5])

res = {
"CREMI": "Metrics",
"mSA": np.mean(msa_list),
"SA50": np.mean(sa50_list),
"SA75": np.mean(sa75_list)
}
res_path = os.path.join(args.result_path, "results.csv")
df = pd.DataFrame.from_dict([res])
df.to_csv(res_path)
print(df)
print(f"The result is saved at {res_path}")


def main(args):
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU")
device = "cuda" if torch.cuda.is_available() else "cpu"

if args.train:
run_cremi_training(args)

if args.predict:
run_cremi_inference(args, device)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input", type=str, default=os.path.join(ROOT, "data", "cremi"), help="Path to CREMI dataset."
)
parser.add_argument(
"-s", "--save_root", type=str, default="./", help="Path where the model checkpoints will be saved."
)
parser.add_argument(
"-m", "--model_type", type=str, default="vim_t", help="Choice of ViM backbone"
)
parser.add_argument(
"--train", action="store_true", help="Whether to train the model."
)
parser.add_argument(
"--predict", action="store_true", help="Whether to run inference on the trained model."
)
parser.add_argument(
"--result_path", type=str, default="./", help="Path to save quantitative results."
)
args = parser.parse_args()
main(args)