Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement UNETR #146

Merged
merged 11 commits into from
Aug 1, 2023
5 changes: 5 additions & 0 deletions experiments/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ Experiments for the re-implementation of [Sparse Object-Level Supervision for In
## probabilistic_domain_adaptation

Experiments for the re-implementation of [Probabilistic Domain Adaptation for Biomedical Image Segmentation](https://arxiv.org/abs/2303.11790). Work in progress.


## vision-transformer

WIP
72 changes: 72 additions & 0 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
import argparse

import torch
import torch_em
from torch_em.model import UNETR
from torch_em.data.datasets import get_livecell_loader


def do_unetr_training(data_path: str, save_root: str, cell_type: list, iterations: int, device, patch_shape=(256, 256)):
os.makedirs(data_path, exist_ok=True)
train_loader = get_livecell_loader(
path=data_path,
split='train',
patch_shape=patch_shape,
batch_size=2,
cell_types=cell_type,
download=True,
binary=True
)

val_loader = get_livecell_loader(
path=data_path,
split='val',
patch_shape=patch_shape,
batch_size=1,
cell_types=cell_type,
download=True,
binary=True
)

model = UNETR(out_channels=1,
initialize_from_sam=True)
model.to(device)

trainer = torch_em.default_segmentation_trainer(
name=f"unet-source-livecell-{cell_type[0]}",
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
learning_rate=1.0e-4,
log_image_interval=10,
save_root=save_root,
compile_model=False
)

trainer.fit(iterations)


def main(args):
print(torch.cuda.get_device_name() if torch.cuda.is_available() else "GPU not available, hence running on CPU")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.train:
print("Training a 2D UNETR on LiveCELL dataset")
do_unetr_training(data_path=args.inputs,
save_root=args.save_root,
cell_type=args.cell_type,
iterations=args.iterations,
device=device)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train", action='store_true', help="Enables UNETR training on LiveCELL dataset")
parser.add_argument("-c", "--cell_type", nargs='+', default=["A172"], help="Choice of cell-type for doing the training")
parser.add_argument("-i", "--inputs", type=str, default="./livecell/", help="Path where the dataset already exists/will be downloaded by the dataloader")
parser.add_argument("-s", "--save_root", type=str, default=None, help="Path where checkpoints and logs will be saved")
parser.add_argument("--iterations", type=int, default=100000, help="No. of iterations to run the training for")
args = parser.parse_args()
main(args)
40 changes: 40 additions & 0 deletions test/model/test_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest
import torch

try:
import segment_anything
except ImportError:
segment_anything = None

try:
import micro_sam
except ImportError:
micro_sam = None


@unittest.skipIf(segment_anything is None, "Needs segment_anything")
class TestUnetr(unittest.TestCase):
def _test_net(self, net, shape):
x = torch.rand(*shape, requires_grad=True)
y = net(x)
expected_shape = shape[:1] + (net.out_channels,) + shape[2:]
self.assertEqual(y.shape, expected_shape)
loss = y.sum()
loss.backward()

def test_unetr(self):
from torch_em.model import UNETR

model = UNETR()
self._test_net(model, (1, 3, 256, 256))

@unittest.skipIf(micro_sam is None, "Needs micro_sam")
def test_unetr_from_sam(self):
from torch_em.model import build_unetr_with_sam_intialization

model = build_unetr_with_sam_intialization()
self._test_net(model, (1, 3, 256, 256))


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torch_em/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .unet import AnisotropicUNet, UNet2d, UNet3d
from .probabilistic_unet import ProbabilisticUNet
from .unetr import UNETR, build_unetr_with_sam_intialization
Loading
Loading