Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - Instance Segmentation Experiments #187

Merged
merged 10 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion experiments/vision-transformer/unetr/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.sh
*.out
*.out
*.csv
9 changes: 7 additions & 2 deletions experiments/vision-transformer/unetr/livecell/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def get_unetr_model(
model = torch_em_models.UNETR(
backbone=backbone, encoder=model_name, out_channels=output_channels,
use_sam_stats=sam_initialization, final_activation="Sigmoid",
encoder_checkpoint_path=MODELS[model_name] if sam_initialization else None,
encoder_checkpoint=MODELS[model_name] if sam_initialization else None,
)

elif source_choice == "monai":
Expand Down Expand Up @@ -241,7 +241,12 @@ def predict_for_unetr(
elif with_distances: # inference using foreground and hv distance maps
outputs = predict_with_padding(model, input_, device=device, min_divisible=(16, 16))
fg, cdist, bdist = outputs.squeeze()
dm_seg = segmentation.watershed_from_center_and_boundary_distances(cdist, bdist, fg, min_size=50)
dm_seg = segmentation.watershed_from_center_and_boundary_distances(
cdist, bdist, fg, min_size=50,
center_distance_threshold=0.5,
boundary_distance_threshold=0.6,
distance_smoothing=1.0
)

else: # inference using foreground-boundary inputs - for the unetr training
outputs = predict_with_halo(
Expand Down
134 changes: 134 additions & 0 deletions experiments/vision-transformer/unetr/livecell/train_by_parts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import os
from collections import OrderedDict

import torch
from torch_em import model as torch_em_models

import common


def prune_prefix(checkpoint_path):
state = torch.load(checkpoint_path, map_location="cpu")
model_state = state["model_state"]

# let's prune the `.sam` prefix for the finetuned models
sam_prefix = "sam.image_encoder."
updated_model_state = []
for k, v in model_state.items():
if k.startswith(sam_prefix):
updated_model_state.append((k[len(sam_prefix):], v))
updated_model_state = OrderedDict(updated_model_state)

return updated_model_state


def get_custom_unetr_model(
device, model_name, sam_initialization, output_channels, checkpoint_path, freeze_encoder, joint_training
):
if checkpoint_path is not None:
if checkpoint_path.endswith("pt"): # for finetuned models
model_state = prune_prefix(checkpoint_path)
else: # for vanilla sam models
model_state = checkpoint_path
else: # while checkpoint path is None, hence we train from scratch
model_state = checkpoint_path

model = torch_em_models.UNETR(
backbone="sam",
encoder=model_name,
out_channels=output_channels,
use_sam_stats=sam_initialization,
final_activation="Sigmoid",
encoder_checkpoint=model_state,
use_skip_connection=not joint_training # if joint_training, no skip con. else, use skip con. by default
)

model.to(device)

# if expected, let's freeze the image encoder
if freeze_encoder:
for name, param in model.named_parameters():
if name.startswith("encoder"):
param.requires_grad = False

return model


def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # overwrite to use complex device setups
patch_shape = (512, 512) # patch size used for training on livecell

# directory folder to save different parts of the scheme
dir_structure = os.path.join(
args.model_name, f"freeze_encoder_{args.freeze_encoder}", "distances", "dicebaseddistloss",
f"{args.source_choice}-sam" if args.do_sam_ini else f"{args.source_choice}-scratch"
)

# get the desired loss function for training
loss = common.get_loss_function(with_distances=True, combine_dist_with_dice=True)

# get the custom model for the training and inference on livecell dataset
model = get_custom_unetr_model(
device, args.model_name, sam_initialization=args.do_sam_ini, output_channels=3,
checkpoint_path=args.checkpoint, freeze_encoder=args.freeze_encoder, joint_training=args.joint_training
)

# determining where to save the checkpoints and tensorboard logs
save_root = os.path.join(args.save_root, dir_structure) if args.save_root is not None else args.save_root

# determines the directory where the predictions will be saved
root_save_dir = os.path.join(args.save_dir, dir_structure)

if args.train:
print("2d (custom) UNETR training (with distances) on LiveCELL...")

# get the desired livecell loaders for training
train_loader, val_loader = common.get_my_livecell_loaders(
args.input, patch_shape, args.cell_type, with_distances=True,
input_norm=not args.do_sam_ini
)

common.do_unetr_training(
train_loader=train_loader, val_loader=val_loader, model=model, loss=loss,
device=device, save_root=save_root, iterations=args.iterations
)

if args.predict:
print("2d (custom) UNETR inference (with distances) on LiveCELL...")
common.do_unetr_inference(
input_path=args.input, device=device, model=model, save_root=save_root,
root_save_dir=root_save_dir, with_distances=True,
# the logic written for `input_norm` is complicated, but the idea is simple:
# - should standardize the inputs when we "DONOT" use SAM initialization
# - should not standardize the inputs when we use SAM initialization
input_norm=not args.do_sam_ini
)
print("Predictions are saved in", root_save_dir)

if args.evaluate:
print("2d (custom) UNETR evaluation (with distances) on LiveCELL...")
csv_save_dir = os.path.join("results", dir_structure)
os.makedirs(csv_save_dir, exist_ok=True)

common.do_unetr_evaluation(
input_path=args.input, root_save_dir=root_save_dir, csv_save_dir=csv_save_dir, with_distances=True
)


# we train three setups:
# - training from scratch, seeing the performance using instance segmentation
# - training from vanilla SAM, seeing the performance using instance segmentation
# - training from finetuned SAM, seeing the performance using instance segmentation
if __name__ == "__main__":
parser = common.get_parser()
parser.add_argument(
"--checkpoint", type=str, default=None, help="The checkpoint to the specific pretrained models."
)
parser.add_argument(
"--freeze_encoder", action="store_true", help="Experiments to freeze the encoder."
)
parser.add_argument(
"--joint_training", action="store_true", help="Uses VNETR for training"
)
args = parser.parse_args()
main(args)
35 changes: 35 additions & 0 deletions scripts/vision_transformer/load_sam_encoder_in_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch

from torch_em.model import UNETR

from micro_sam.util import get_sam_model


def main():
checkpoint = "/scratch/usr/nimanwai/models/segment-anything/checkpoints/sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

predictor = get_sam_model(
model_type=model_type,
checkpoint_path=checkpoint
)

model = UNETR(
backbone="sam",
encoder=predictor.model.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False
)
model.to(device)

x = torch.ones((1, 1, 512, 512)).to(device)
y = model(x)

print("UNETR Model successfully created and encoder initialized from", checkpoint)


if __name__ == "__main__":
main()
91 changes: 67 additions & 24 deletions torch_em/model/unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn.functional as F

from .unet import Decoder, ConvBlock2d, Upsampler2d
from .vit import get_vision_transformer
from .vit import get_vision_transformer, ViT_MAE, ViT_Sam

try:
from micro_sam.util import get_sam_model
Expand All @@ -24,7 +24,7 @@ class UNETR(nn.Module):
def _load_encoder_from_checkpoint(self, backbone, encoder, checkpoint):

if isinstance(checkpoint, str):
if backbone == "sam":
if backbone == "sam" and isinstance(encoder, str):
# If we have a SAM encoder, then we first try to load the full SAM Model
# (using micro_sam) and otherwise fall back on directly loading the encoder state
# from the checkpoint
Expand Down Expand Up @@ -63,23 +63,47 @@ def __init__(
self,
img_size: int = 1024,
backbone: str = "sam",
encoder: str = "vit_b",
encoder: Optional[Union[nn.Module, str]] = "vit_b",
decoder: Optional[nn.Module] = None,
out_channels: int = 1,
use_sam_stats: bool = False,
use_mae_stats: bool = False,
encoder_checkpoint: Optional[Union[str, OrderedDict]] = None,
final_activation: Optional[Union[str, nn.Module]] = None,
use_skip_connection: bool = True,
embed_dim: Optional[int] = None
) -> None:
super().__init__()

self.use_sam_stats = use_sam_stats
self.use_mae_stats = use_mae_stats
self.use_skip_connection = use_skip_connection

print(f"Using {encoder} from {backbone.upper()}")
self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
if encoder_checkpoint is not None:
self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)
if isinstance(encoder, str): # "vit_b" / "vit_l" / "vit_h"
print(f"Using {encoder} from {backbone.upper()}")
self.encoder = get_vision_transformer(img_size=img_size, backbone=backbone, model=encoder)
if encoder_checkpoint is not None:
self._load_encoder_from_checkpoint(backbone, encoder, encoder_checkpoint)

in_chans = self.encoder.in_chans
if embed_dim is None:
embed_dim = self.encoder.embed_dim

else: # `nn.Module` ViT backbone
self.encoder = encoder

have_neck = False
for name, _ in self.encoder.named_parameters():
if name.startswith("neck"):
have_neck = True

if embed_dim is None:
if have_neck:
embed_dim = self.encoder.neck[2].out_channels # the value is 256
else:
embed_dim = self.encoder.patch_embed.proj.out_channels

in_chans = self.encoder.patch_embed.proj.in_channels

# parameters for the decoder network
depth = 3
Expand All @@ -99,18 +123,21 @@ def __init__(
else:
self.decoder = decoder

self.z_inputs = ConvBlock2d(self.encoder.in_chans, features_decoder[-1])
self.z_inputs = ConvBlock2d(in_chans, features_decoder[-1])

self.base = ConvBlock2d(embed_dim, features_decoder[0])

self.base = ConvBlock2d(self.encoder.embed_dim, features_decoder[0])
self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)

self.deconv1 = Deconv2DBlock(self.encoder.embed_dim, features_decoder[0])
self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0])
self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1])
self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2])
self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3])

self.deconv_out = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])

self.deconv4 = SingleDeconv2DBlock(features_decoder[-1], features_decoder[-1])
self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])

self.decoder_head = ConvBlock2d(2*features_decoder[-1], features_decoder[-1])
self.final_activation = self._get_activation(final_activation)

def _get_activation(self, activation):
Expand Down Expand Up @@ -167,26 +194,42 @@ def forward(self, x):
# backbone used for reshaping inputs to the desired "encoder" shape
x = torch.stack([self.preprocess(e) for e in x], dim=0)

z0 = self.z_inputs(x)
use_skip_connection = getattr(self, "use_skip_connection", True)

z12, from_encoder = self.encoder(x)
x = self.base(z12)
encoder_outputs = self.encoder(x)

from_encoder = from_encoder[::-1]
z9 = self.deconv1(from_encoder[0])
if isinstance(self.encoder, ViT_Sam) or isinstance(self.encoder, ViT_MAE):
z12, from_encoder = encoder_outputs
else:
z12 = encoder_outputs

z6 = self.deconv1(from_encoder[1])
z6 = self.deconv2(z6)
if use_skip_connection:
# TODO: we share the weights in the deconv(s), and should preferably avoid doing that
from_encoder = from_encoder[::-1]
z9 = self.deconv1(from_encoder[0])

z3 = self.deconv1(from_encoder[2])
z3 = self.deconv2(z3)
z3 = self.deconv3(z3)
z6 = self.deconv1(from_encoder[1])
z6 = self.deconv2(z6)

z3 = self.deconv1(from_encoder[2])
z3 = self.deconv2(z3)
z3 = self.deconv3(z3)

z0 = self.z_inputs(x)

else:
z9 = self.deconv1(z12)
z6 = self.deconv2(z9)
z3 = self.deconv3(z6)
z0 = self.deconv4(z3)

updated_from_encoder = [z9, z6, z3]

x = self.base(z12)
x = self.decoder(x, encoder_inputs=updated_from_encoder)
x = self.deconv4(x)
x = torch.cat([x, z0], dim=1)
x = self.deconv_out(x)

x = torch.cat([x, z0], dim=1)
x = self.decoder_head(x)

x = self.out_conv(x)
Expand Down
5 changes: 5 additions & 0 deletions torch_em/transform/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def __init__(
boundary_distances=True,
directed_distances=False,
foreground=True,
instances=False,
apply_label=True,
correct_centers=True,
min_size=0,
Expand All @@ -313,6 +314,7 @@ def __init__(
self.boundary_distances = boundary_distances
self.directed_distances = directed_distances
self.foreground = foreground
self.instances = instances

self.apply_label = apply_label
self.correct_centers = correct_centers
Expand Down Expand Up @@ -441,4 +443,7 @@ def __call__(self, labels):
binary_labels = (labels > 0).astype("float32")
distances = np.concatenate([binary_labels[None], distances], axis=0)

if self.instances:
distances = np.concatenate([labels[None], distances], axis=0)

return distances