Commit

Merge pull request #188 from constantinpape/dev
Dev
constantinpape committed Dec 22, 2023
2 parents 7d91a1b + 1409c56 commit 020f3fb
Showing 6 changed files with 250 additions and 27 deletions.
3 changes: 2 additions & 1 deletion experiments/vision-transformer/unetr/.gitignore
@@ -1,2 +1,3 @@
*.sh
*.out
*.out
*.csv
9 changes: 7 additions & 2 deletions experiments/vision-transformer/unetr/livecell/common.py
@@ -174,7 +174,7 @@ def get_unetr_model(
        model = torch_em_models.UNETR(
            backbone=backbone, encoder=model_name, out_channels=output_channels,
            use_sam_stats=sam_initialization, final_activation="Sigmoid",
            encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None,
            encoder_checkpoint=MODELS[model_name] if sam_initialization else None,
        )

    elif source_choice == "monai":
@@ -241,7 +241,12 @@ def predict_for_unetr(
    elif with_distances:  # inference using foreground and center-/boundary-distance maps
        outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
        fg, cdist, bdist = outputs.squeeze()
        dm_seg = segmentation.watershed_from_center_and_boundary_distances(cdist, bdist, fg, min_size=50)
        dm_seg = segmentation.watershed_from_center_and_boundary_distances(
            cdist, bdist, fg, min_size=50,
            center_distance_threshold=0.5,
            boundary_distance_threshold=0.6,
            distance_smoothing=1.0
        )

    else:  # inference using foreground-boundary inputs - for the unetr training
        outputs = predict_with_halo(
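To make the new post-processing parameters concrete, here is a minimal sketch (not part of the commit) of the distance-based watershed call; the import path for the `segmentation` module and the random stand-in predictions are assumptions:

import numpy as np
from torch_em.util import segmentation  # assumed import path for the `segmentation` module used above

# stand-ins for the squeezed (3, H, W) network output: foreground, center distances, boundary distances
fg, cdist, bdist = np.random.rand(3, 256, 256).astype("float32")

instance_seg = segmentation.watershed_from_center_and_boundary_distances(
    cdist, bdist, fg, min_size=50,
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.6,
    distance_smoothing=1.0,
)
print(instance_seg.shape)  # instance segmentation with the same spatial shape as the inputs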
134 changes: 134 additions & 0 deletions experiments/vision-transformer/unetr/livecell/train_by_parts.py
@@ -0,0 +1,134 @@
import os
from collections import OrderedDict

import torch
from torch_em import model as torch_em_models

import common


def prune_prefix(checkpoint_path):
    state = torch.load(checkpoint_path, map_location="cpu")
    model_state = state["model_state"]

    # prune the "sam.image_encoder." prefix from the finetuned model's state dict
    sam_prefix = "sam.image_encoder."
    updated_model_state = []
    for k, v in model_state.items():
        if k.startswith(sam_prefix):
            updated_model_state.append((k[len(sam_prefix):], v))
    updated_model_state = OrderedDict(updated_model_state)

    return updated_model_state


def get_custom_unetr_model(
    device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder, joint_training
):
    if checkpoint_path is not None:
        if checkpoint_path.endswith("pt"):  # for finetuned models
            model_state = prune_prefix(checkpoint_path)
        else:  # for vanilla sam models
            model_state = checkpoint_path
    else:  # checkpoint path is None, hence we train from scratch
        model_state = checkpoint_path

    model = torch_em_models.UNETR(
        backbone="sam",
        encoder=model_name,
        out_channels=output_channels,
        use_sam_stats=sam_initialization,
        final_activation="Sigmoid",
        encoder_checkpoint=model_state,
        use_skip_connection=not joint_training  # no skip connections for joint training, otherwise use them by default
    )

    model.to(device)

    # if desired, freeze the image encoder
    if freeze_encoder:
        for name, param in model.named_parameters():
            if name.startswith("encoder"):
                param.requires_grad = False

    return model


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # overwrite to use more complex device setups
    patch_shape = (512, 512)  # patch size used for training on livecell

    # directory structure for saving the different parts of the training scheme
    dir_structure = os.path.join(
        args.model_name, f"freeze_encoder_{args.freeze_encoder}", "distances", "dicebaseddistloss",
        f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch"
    )

    # get the desired loss function for training
    loss = common.get_loss_function(with_distances=True, combine_dist_with_dice=True)

    # get the custom model for training and inference on the livecell dataset
    model = get_custom_unetr_model(
        device, args.model_name, sam_initialization=args.do_sam_ini, output_channels=3,
        checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder, joint_training=args.joint_training
    )

    # determines where to save the checkpoints and tensorboard logs
    save_root = os.path.join(args.save_root, dir_structure) if args.save_root is not None else args.save_root

    # determines the directory where the predictions will be saved
    root_save_dir = os.path.join(args.save_dir, dir_structure)

    if args.train:
        print("2d (custom) UNETR training (with distances) on LiveCELL...")

        # get the desired livecell loaders for training
        train_loader, val_loader = common.get_my_livecell_loaders(
            args.input, patch_shape, args.cell_type, with_distances=True,
            input_norm=not args.do_sam_ini
        )

        common.do_unetr_training(
            train_loader=train_loader, val_loader=val_loader, model=model, loss=loss,
            device=device, save_root=save_root, iterations=args.iterations
        )

    if args.predict:
        print("2d (custom) UNETR inference (with distances) on LiveCELL...")
        common.do_unetr_inference(
            input_path=args.input, device=device, model=model, save_root=save_root,
            root_save_dir=root_save_dir, with_distances=True,
            # the logic written for `input_norm` is complicated, but the idea is simple:
            # - should standardize the inputs when we do NOT use SAM initialization
            # - should not standardize the inputs when we use SAM initialization
            input_norm=not args.do_sam_ini
        )
        print("Predictions are saved in", root_save_dir)

    if args.evaluate:
        print("2d (custom) UNETR evaluation (with distances) on LiveCELL...")
        csv_save_dir = os.path.join("results", dir_structure)
        os.makedirs(csv_save_dir, exist_ok=True)

        common.do_unetr_evaluation(
            input_path=args.input, root_save_dir=root_save_dir, csv_save_dir=csv_save_dir, with_distances=True
        )


# we train three setups:
# - training from scratch, seeing the performance using instance segmentation
# - training from vanilla SAM, seeing the performance using instance segmentation
# - training from finetuned SAM, seeing the performance using instance segmentation
if __name__ == "__main__":
    parser = common.get_parser()
    parser.add_argument(
        "--checkpoint", type=str, default=None, help="The checkpoint path to the specific pretrained model."
    )
    parser.add_argument(
        "--freeze_encoder", action="store_true", help="Whether to freeze the image encoder."
    )
    parser.add_argument(
        "--joint_training", action="store_true", help="Uses VNETR for training"
    )
    args = parser.parse_args()
    main(args)
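For orientation, here is a minimal sketch (not part of the commit) of how `get_custom_unetr_model` covers the three setups listed above; the checkpoint paths are placeholders and the import assumes the script is used from its own directory:

import torch

from train_by_parts import get_custom_unetr_model  # assumed module import

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1) training from scratch: no checkpoint and no SAM data statistics
model_scratch = get_custom_unetr_model(
    device, "vit_b", sam_initialization=False, output_channels=3,
    checkpoint_path=None, freeze_encoder=False, joint_training=False,
)

# 2) training from vanilla SAM: the ".pth" checkpoint is passed through as the encoder checkpoint path
model_vanilla = get_custom_unetr_model(
    device, "vit_b", sam_initialization=True, output_channels=3,
    checkpoint_path="/path/to/sam_vit_b_01ec64.pth", freeze_encoder=False, joint_training=False,
)

# 3) training from finetuned SAM: the ".pt" checkpoint gets its "sam.image_encoder." prefix pruned
model_finetuned = get_custom_unetr_model(
    device, "vit_b", sam_initialization=True, output_channels=3,
    checkpoint_path="/path/to/finetuned_livecell_sam.pt", freeze_encoder=False, joint_training=False,
)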
35 changes: 35 additions & 0 deletions scripts/vision_transformer/load_sam_encoder_in_unetr.py
@@ -0,0 +1,35 @@
import torch

from torch_em.model import UNETR

from micro_sam.util import get_sam_model


def main():
    checkpoint = "/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth"
    model_type = "vit_b"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    predictor = get_sam_model(
        model_type=model_type,
        checkpoint_path=checkpoint
    )

    model = UNETR(
        backbone="sam",
        encoder=predictor.model.image_encoder,
        out_channels=3,
        use_sam_stats=True,
        final_activation="Sigmoid",
        use_skip_connection=False
    )
    model.to(device)

    x = torch.ones((1, 1, 512, 512)).to(device)
    y = model(x)

    print("UNETR Model successfully created and encoder initialized from", checkpoint)


if __name__ == "__main__":
    main()
91 changes: 67 additions & 24 deletions torch_em/model/unetr.py
@@ -6,7 +6,7 @@
import torch.nn.functional as F

from .unet import Decoder, ConvBlock2d, Upsampler2d
from .vit import get_vision_transformer
from .vit import get_vision_transformer, ViT_MAE, ViT_Sam

try:
    from micro_sam.util import get_sam_model
@@ -24,7 +24,7 @@ class UNETR(nn.Module):
    def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

        if isinstance(checkpoint, str):
            if backbone == "sam":
            if backbone == "sam" and isinstance(encoder, str):
                # If we have a SAM encoder, then we first try to load the full SAM Model
                # (using micro_sam) and otherwise fall back on directly loading the encoder state
                # from the checkpoint
@@ -63,23 +63,47 @@ def __init__(
        self,
        img_size: int = 1024,
        backbone: str = "sam",
        encoder: str = "vit_b",
        encoder: Optional[Union[nn.Module, str]] = "vit_b",
        decoder: Optional[nn.Module] = None,
        out_channels: int = 1,
        use_sam_stats: bool = False,
        use_mae_stats: bool = False,
        encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
        final_activation: Optional[Union[str, nn.Module]] = None,
        use_skip_connection: bool = True,
        embed_dim: Optional[int] = None
    ) -> None:
        super().__init__()

        self.use_sam_stats = use_sam_stats
        self.use_mae_stats = use_mae_stats
        self.use_skip_connection = use_skip_connection

        print(f"Using {encoder} from {backbone.upper()}")
        self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
        if encoder_checkpoint is not None:
            self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
        if isinstance(encoder, str):  # "vit_b" / "vit_l" / "vit_h"
            print(f"Using {encoder} from {backbone.upper()}")
            self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
            if encoder_checkpoint is not None:
                self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

            in_chans = self.encoder.in_chans
            if embed_dim is None:
                embed_dim = self.encoder.embed_dim

        else:  # `nn.Module` ViT backbone
            self.encoder = encoder

            have_neck = False
            for name, _ in self.encoder.named_parameters():
                if name.startswith("neck"):
                    have_neck = True

            if embed_dim is None:
                if have_neck:
                    embed_dim = self.encoder.neck[2].out_channels  # the value is 256
                else:
                    embed_dim = self.encoder.patch_embed.proj.out_channels

            in_chans = self.encoder.patch_embed.proj.in_channels

        # parameters for the decoder network
        depth = 3
Expand All @@ -99,18 +123,21 @@ def __init__(
        else:
            self.decoder = decoder

        self.z_inputs = ConvBlock2d(self.encoder.in_chans, features_decoder[-1])
        self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1])

        self.base = ConvBlock2d(embed_dim, features_decoder[0])

        self.base = ConvBlock2d(self.encoder.embed_dim, features_decoder[0])
        self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

        self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
        self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
        self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
        self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
        self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

        self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])

        self.deconv4 = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

        self.decoder_head = ConvBlock2d(2*features_decoder[-1], features_decoder[-1])
        self.final_activation = self._get_activation(final_activation)

    def _get_activation(self, activation):
@@ -167,26 +194,42 @@ def forward(self, x):
        # backbone used for reshaping inputs to the desired "encoder" shape
        x = torch.stack([self.preprocess(e) for e in x], dim=0)

        z0 = self.z_inputs(x)
        use_skip_connection = getattr(self, "use_skip_connection", True)

        z12, from_encoder = self.encoder(x)
        x = self.base(z12)
        encoder_outputs = self.encoder(x)

        from_encoder = from_encoder[::-1]
        z9 = self.deconv1(from_encoder[0])
        if isinstance(self.encoder, ViT_Sam) or isinstance(self.encoder, ViT_MAE):
            z12, from_encoder = encoder_outputs
        else:
            z12 = encoder_outputs

        z6 = self.deconv1(from_encoder[1])
        z6 = self.deconv2(z6)
        if use_skip_connection:
            # TODO: we share the weights in the deconv(s), and should preferably avoid doing that
            from_encoder = from_encoder[::-1]
            z9 = self.deconv1(from_encoder[0])

        z3 = self.deconv1(from_encoder[2])
        z3 = self.deconv2(z3)
        z3 = self.deconv3(z3)
            z6 = self.deconv1(from_encoder[1])
            z6 = self.deconv2(z6)

            z3 = self.deconv1(from_encoder[2])
            z3 = self.deconv2(z3)
            z3 = self.deconv3(z3)

            z0 = self.z_inputs(x)

        else:
            z9 = self.deconv1(z12)
            z6 = self.deconv2(z9)
            z3 = self.deconv3(z6)
            z0 = self.deconv4(z3)

        updated_from_encoder = [z9, z6, z3]

        x = self.base(z12)
        x = self.decoder(x, encoder_inputs=updated_from_encoder)
        x = self.deconv4(x)
        x = torch.cat([x, z0], dim=1)
        x = self.deconv_out(x)

        x = torch.cat([x, z0], dim=1)
        x = self.decoder_head(x)

        x = self.out_conv(x)
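A minimal sketch (not part of the commit; it assumes torch_em and its SAM dependencies are installed) contrasting the two decoder paths introduced here, with and without skip connections:

import torch

from torch_em.model import UNETR

x = torch.randn(1, 1, 512, 512)

# default path: the reversed encoder features are routed to the decoder via skip connections
unetr_skip = UNETR(backbone="sam", encoder="vit_b", out_channels=3, use_skip_connection=True)

# joint-training path: the decoder inputs are derived from the final encoder feature only
unetr_no_skip = UNETR(backbone="sam", encoder="vit_b", out_channels=3, use_skip_connection=False)

with torch.no_grad():
    print(unetr_skip(x).shape)    # 3 output channels, spatial size following the input
    print(unetr_no_skip(x).shape)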
5 changes: 5 additions & 0 deletions torch_em/transform/label.py
@@ -302,6 +302,7 @@ def __init__(
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=False,
        apply_label=True,
        correct_centers=True,
        min_size=0,
@@ -313,6 +314,7 @@
        self.boundary_distances = boundary_distances
        self.directed_distances = directed_distances
        self.foreground = foreground
        self.instances = instances

        self.apply_label = apply_label
        self.correct_centers = correct_centers
@@ -441,4 +443,7 @@ def __call__(self, labels):
            binary_labels = (labels > 0).astype("float32")
            distances = np.concatenate([binary_labels[None], distances], axis=0)

        if self.instances:
            distances = np.concatenate([labels[None], distances], axis=0)

        return distances
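A minimal sketch of the new `instances` option (not part of the commit; it assumes the transform modified here is `PerObjectDistanceTransform` from `torch_em.transform.label`), which prepends the instance labels to the stacked target channels:

import numpy as np

from torch_em.transform.label import PerObjectDistanceTransform  # assumed class name

# toy label image with two instances
labels = np.zeros((128, 128), dtype="uint32")
labels[10:40, 10:40] = 1
labels[60:100, 50:90] = 2

trafo = PerObjectDistanceTransform(foreground=True, instances=True)
target = trafo(labels)

# expected channel order: instance labels, foreground mask, and the default distance channels
print(target.shape)  # e.g. (4, 128, 128) with center and boundary distances enabled by default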
