Skip to content

Commit

Permalink
Merge pull request #146 from constantinpape/unetr
Browse files Browse the repository at this point in the history
Implement UNETR
  • Loading branch information
constantinpape committed Aug 1, 2023
2 parents aba351c + e08092f commit e43fbed
Show file tree
Hide file tree
Showing 6 changed files with 426 additions and 0 deletions.
5 changes: 5 additions & 0 deletions experiments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ Experiments for the re-implementation of [Sparse Object-Level Supervision for In
## probabilistic_domain_adaptation

Experiments for the re-implementation of [Probabilistic Domain Adaptation for Biomedical Image Segmentation](https://arxiv.org/abs/2303.11790). Work in progress.


## vision-transformer

WIP
9 changes: 9 additions & 0 deletions experiments/vision-transformer/unetr/initialize_with_sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torch
from torch_em.model.unetr import build_unetr_with_sam_intialization

# FIXME this doesn't work yet
model = build_unetr_with_sam_intialization()
x = torch.randn(1, 3, 1024, 1024)

y = model(x)
print(y.shape)
72 changes: 72 additions & 0 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import argparse

import torch
import torch_em
from torch_em.model import UNETR
from torch_em.data.datasets import get_livecell_loader


def do_unetr_training(data_path: str, save_root: str, cell_type: list, iterations: int, device, patch_shape=(256, 256)):
os.makedirs(data_path, exist_ok=True)
train_loader = get_livecell_loader(
path=data_path,
split="train",
patch_shape=patch_shape,
batch_size=2,
cell_types=cell_type,
download=True,
binary=True
)

val_loader = get_livecell_loader(
path=data_path,
split="val",
patch_shape=patch_shape,
batch_size=1,
cell_types=cell_type,
download=True,
binary=True
)

model = UNETR(out_channels=1,
initialize_from_sam=True)
model.to(device)

trainer = torch_em.default_segmentation_trainer(
name=f"unet-source-livecell-{cell_type[0]}",
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
learning_rate=1.0e-4,
log_image_interval=10,
save_root=save_root,
compile_model=False
)

trainer.fit(iterations)


def main(args):
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.train:
print("Training a 2D UNETR on LiveCELL dataset")
do_unetr_training(data_path=args.inputs,
save_root=args.save_root,
cell_type=args.cell_type,
iterations=args.iterations,
device=device)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train", action='store_true', help="Enables UNETR training on LiveCELL dataset")
parser.add_argument("-c", "--cell_type", nargs='+', default=["A172"], help="Choice of cell-type for doing the training")
parser.add_argument("-i", "--inputs", type=str, default="./livecell/", help="Path where the dataset already exists/will be downloaded by the dataloader")
parser.add_argument("-s", "--save_root", type=str, default=None, help="Path where checkpoints and logs will be saved")
parser.add_argument("--iterations", type=int, default=100000, help="No. of iterations to run the training for")
args = parser.parse_args()
main(args)
40 changes: 40 additions & 0 deletions test/model/test_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
import torch

try:
import segment_anything
except ImportError:
segment_anything = None

try:
import micro_sam
except ImportError:
micro_sam = None


@unittest.skipIf(segment_anything is None, "Needs segment_anything")
class TestUnetr(unittest.TestCase):
def _test_net(self, net, shape):
x = torch.rand(*shape, requires_grad=True)
y = net(x)
expected_shape = shape[:1] + (net.out_channels,) + shape[2:]
self.assertEqual(y.shape, expected_shape)
loss = y.sum()
loss.backward()

def test_unetr(self):
from torch_em.model import UNETR

model = UNETR()
self._test_net(model, (1, 3, 256, 256))

@unittest.skipIf(micro_sam is None, "Needs micro_sam")
def test_unetr_from_sam(self):
from torch_em.model import build_unetr_with_sam_intialization

model = build_unetr_with_sam_intialization()
self._test_net(model, (1, 3, 256, 256))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torch_em/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .unet import AnisotropicUNet, UNet2d, UNet3d
from .probabilistic_unet import ProbabilisticUNet
from .unetr import UNETR, build_unetr_with_sam_intialization
Loading

0 comments on commit e43fbed

Please sign in to comment.