Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and Updates for UNETR #148

Merged
merged 5 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions experiments/vision-transformer/unetr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
## Usage of SAM's ViT initialization in UNETR

### Initialize the UNETR encoder (ViT) from SAM weights
```python
from torch_em.model import UNETR
unetr = UNETR(encoder=model_type, out_channels=out_channels, encoder_checkpoint_path=checkpoint_path)
```

### Vanilla ViT encoder for UNETR (no checkpoint loaded)
```python
from torch_em.model import UNETR
unetr = UNETR(encoder=model_type, out_channels=out_channels)
```
80 changes: 80 additions & 0 deletions experiments/vision-transformer/unetr/cremi_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import argparse
import numpy as np

import torch
import torch_em
from torch_em.model import UNETR
from torch_em.data.datasets import get_cremi_loader


def do_unetr_training(
    data_path: str,
    save_root: str,
    iterations: int,
    device,
    patch_shape=(1, 512, 512),
    encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth",
):
    """Train a UNETR with a "vit_b" encoder on CREMI samples A, B and C.

    Args:
        data_path: Directory holding the CREMI data. It is created if missing
            and handed to `get_cremi_loader` with `download=True`.
        save_root: Forwarded to the trainer as `save_root`. The command line
            passes `None` when `-s/--save_root` is not given.
        iterations: Number of iterations passed to `trainer.fit`.
        device: Device the model is moved to and the trainer runs on.
        patch_shape: Patch shape requested from both data loaders.
        encoder_checkpoint_path: Checkpoint passed to `UNETR` as
            `encoder_checkpoint_path`. The default keeps the path that used to
            be hard-coded here; pass your own path on other machines.
    """
    os.makedirs(data_path, exist_ok=True)

    def _rois(z_range):
        # Same z-range for each of the three samples, full extent in y and x.
        return {sample: np.s_[z_range, :, :] for sample in ("A", "B", "C")}

    # Slices 0-74 are used for training, slices 75-99 for validation.
    cremi_train_rois = _rois(slice(0, 75))
    cremi_val_rois = _rois(slice(75, 100))

    # Settings shared by the training and the validation loader.
    loader_kwargs = dict(
        path=data_path,
        patch_shape=patch_shape,
        download=True,
        ndim=2,
        defect_augmentation_kwargs=None,
        boundaries=True,
    )
    train_loader = get_cremi_loader(rois=cremi_train_rois, batch_size=2, **loader_kwargs)
    val_loader = get_cremi_loader(rois=cremi_val_rois, batch_size=1, **loader_kwargs)

    model = UNETR(
        encoder="vit_b",
        out_channels=1,
        encoder_checkpoint_path=encoder_checkpoint_path,
    )
    model.to(device)

    trainer = torch_em.default_segmentation_trainer(
        name="unetr-cremi",
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        learning_rate=1e-5,
        log_image_interval=10,
        save_root=save_root,
        compile_model=False,
    )

    trainer.fit(iterations)


def main(args):
    """Select the compute device and, if requested, start the CREMI training.

    Args:
        args: Parsed command-line namespace. Reads `train`, and when `train`
            is set also `inputs`, `save_root` and `iterations`.
    """
    # Query CUDA availability once so the message and the device always agree.
    cuda_available = torch.cuda.is_available()
    print(torch.cuda.get_device_name() if cuda_available else "GPU not available, hence running on CPU")
    device = torch.device("cuda" if cuda_available else "cpu")

    if not args.train:
        # Without this hint the script would exit silently after the device line.
        print("Nothing to do: pass --train to start the UNETR training")
        return

    print("Training a 2D UNETR on Cremi dataset")
    do_unetr_training(
        data_path=args.inputs,
        save_root=args.save_root,
        iterations=args.iterations,
        device=device,
    )


if __name__ == "__main__":
    # Command-line entry point: build the option parser, parse, and dispatch to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--train",
        action='store_true',
        help="Enables UNETR training on Cremi dataset",
    )
    cli.add_argument(
        "-i", "--inputs",
        type=str,
        default="./cremi/",
        help="Path where the dataset already exists/will be downloaded by the dataloader",
    )
    cli.add_argument(
        "-s", "--save_root",
        type=str,
        default=None,
        help="Path where checkpoints and logs will be saved",
    )
    cli.add_argument(
        "--iterations",
        type=int,
        default=100000,
        help="No. of iterations to run the training for",
    )
    args = cli.parse_args()
    main(args)
12 changes: 8 additions & 4 deletions experiments/vision-transformer/unetr/initialize_with_sam.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import torch
from torch_em.model.unetr import build_unetr_with_sam_intialization
from torch_em.model import UNETR

# FIXME this doesn't work yet
model = build_unetr_with_sam_intialization()
x = torch.randn(1, 3, 1024, 1024)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNETR(encoder="vit_h", out_channels=1,
encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_h_4b8939.pth")
model.to(device)

x = torch.randn(1, 3, 1024, 1024).to(device=device)

y = model(x)
print(y.shape)
38 changes: 23 additions & 15 deletions experiments/vision-transformer/unetr/livecell_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration
batch_size=2,
cell_types=cell_type,
download=True,
binary=True
boundaries=True
)

val_loader = get_livecell_loader(
Expand All @@ -26,20 +26,23 @@ def do_unetr_training(data_path: str, save_root: str, cell_type: list, iteration
batch_size=1,
cell_types=cell_type,
download=True,
binary=True
boundaries=True
)

model = UNETR(out_channels=1,
initialize_from_sam=True)
n_channels = 2

model = UNETR(
encoder="vit_b", out_channels=n_channels,
encoder_checkpoint_path="/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth")
model.to(device)

trainer = torch_em.default_segmentation_trainer(
name=f"unet-source-livecell-{cell_type[0]}",
name=f"unetr-source-livecell-{cell_type[0]}",
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
learning_rate=1.0e-4,
learning_rate=1e-5,
log_image_interval=10,
save_root=save_root,
compile_model=False
Expand All @@ -53,20 +56,25 @@ def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.train:
print("Training a 2D UNETR on LiveCELL dataset")
do_unetr_training(data_path=args.inputs,
save_root=args.save_root,
cell_type=args.cell_type,
iterations=args.iterations,
device=device)
print("Training a 2D UNETR on LiveCell dataset")
do_unetr_training(
data_path=args.inputs,
save_root=args.save_root,
cell_type=args.cell_type,
iterations=args.iterations,
device=device
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train", action='store_true', help="Enables UNETR training on LiveCELL dataset")
parser.add_argument("-c", "--cell_type", nargs='+', default=["A172"], help="Choice of cell-type for doing the training")
parser.add_argument("-i", "--inputs", type=str, default="./livecell/", help="Path where the dataset already exists/will be downloaded by the dataloader")
parser.add_argument("-s", "--save_root", type=str, default=None, help="Path where checkpoints and logs will be saved")
parser.add_argument("-c", "--cell_type", nargs='+', default=["A172"],
help="Choice of cell-type for doing the training")
parser.add_argument("-i", "--inputs", type=str, default="./livecell/",
help="Path where the dataset already exists/will be downloaded by the dataloader")
parser.add_argument("-s", "--save_root", type=str, default=None,
help="Path where checkpoints and logs will be saved")
parser.add_argument("--iterations", type=int, default=100000, help="No. of iterations to run the training for")
args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion torch_em/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .unet import AnisotropicUNet, UNet2d, UNet3d
from .probabilistic_unet import ProbabilisticUNet
from .unetr import UNETR, build_unetr_with_sam_intialization
from .unetr import UNETR
79 changes: 51 additions & 28 deletions torch_em/model/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
get_sam_model = None


class ViTb_Sam(ImageEncoderViT):
class ViT_Sam(ImageEncoderViT): # type: ignore
"""Vision Transformer derived from the Segment Anything Codebase (https://arxiv.org/abs/2304.02643):
https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py
"""
Expand All @@ -39,7 +39,7 @@ def __init__(
"and then rerun your code."
)

super().__init__(**kwargs)
super().__init__(embed_dim=embed_dim, **kwargs)
self.chunks_for_projection = global_attn_indexes
self.in_chans = in_chans
self.embed_dim = embed_dim
Expand Down Expand Up @@ -112,10 +112,11 @@ def window_unpartition(
class UNETR(nn.Module):
def __init__(
self,
encoder=None,
encoder="vit_b",
decoder=None,
out_channels=1,
use_sam_preprocessing=False,
use_sam_preprocessing=True,
encoder_checkpoint_path=None
) -> None:
depth = 3
initial_features = 64
Expand All @@ -127,8 +128,8 @@ def __init__(

super().__init__()

if encoder is None:
self.encoder = ViTb_Sam(
if encoder == "vit_b":
self.encoder = ViT_Sam(
depth=12,
embed_dim=768,
img_size=1024,
Expand All @@ -142,8 +143,50 @@ def __init__(
window_size=14,
out_chans=256,
)

elif encoder == "vit_l":
self.encoder = ViT_Sam(
depth=24,
embed_dim=1024,
img_size=1024,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), # type: ignore
num_heads=16,
patch_size=16,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=[5, 11, 17, 23], # type: ignore
window_size=14,
out_chans=256
)

elif encoder == "vit_h":
self.encoder = ViT_Sam(
depth=32,
embed_dim=1280,
img_size=1024,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), # type: ignore
num_heads=16,
patch_size=16,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=[7, 15, 23, 31], # type: ignore
window_size=14,
out_chans=256
)

else:
self.encoder = encoder
raise ValueError(f"{encoder} is not supported. Currently only vit_b, vit_l, vit_h are supported.")

if encoder_checkpoint_path is not None:
_, model = get_sam_model(
model_type=encoder,
checkpoint_path=encoder_checkpoint_path,
return_sam=True
) # type: ignore
for param1, param2 in zip(model.parameters(), self.encoder.parameters()):
param2.data = param1

if decoder is None:
self.decoder = Decoder(
Expand Down Expand Up @@ -206,7 +249,7 @@ def forward(self, x):

z0 = self.z_inputs(x)

z12, from_encoder = self.encoder(x)
z12, from_encoder = self.encoder(x) # type: ignore
x = self.base(z12)

from_encoder = from_encoder[::-1]
Expand Down Expand Up @@ -277,23 +320,3 @@ def __init__(self, in_planes, out_planes, kernel_size=3):

def forward(self, x):
return self.block(x)


def build_unetr_with_sam_intialization(out_channels=1, model_type="vit_b", checkpoint_path=None):
if get_sam_model is None:
raise RuntimeError(
"micro_sam is required to initialize the UNETR image encoder from segment anything weights."
"Please install it from https://github.com/computational-cell-analytics/micro-sam"
"and then rerun your code."
)
predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
_image_encoder = predictor.model.image_encoder

image_encoder = ViTb_Sam()
# FIXME this doesn't work yet because the parameters don't match
with torch.no_grad():
for param1, param2 in zip(_image_encoder.parameters(), image_encoder.parameters()):
param2.data = param1.data

unetr = UNETR(encoder=image_encoder, out_channels=out_channels)
return unetr
Loading