From 9d75668fbd26035dd23918db4583e56e84edbe56 Mon Sep 17 00:00:00 2001 From: Luca Date: Wed, 26 Jun 2024 17:13:29 +0200 Subject: [PATCH 01/24] implemented 3dsam train routine with lucchi data. still shape mismatch --- development/train_3d_model_with_lucchi.py | 83 +++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 development/train_3d_model_with_lucchi.py diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py new file mode 100644 index 00000000..cb56f725 --- /dev/null +++ b/development/train_3d_model_with_lucchi.py @@ -0,0 +1,83 @@ +import os +import argparse + +import torch + +from torch_em.data.datasets import get_lucchi_loader +import torch_em + +from micro_sam.sam_3d_wrapper import get_3d_sam_model +from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D +import micro_sam.training as sam_training + +def get_dataloaders(patch_shape, data_path, batch_size=1, num_workers=4): + """This returns the livecell data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py + It will automatically download the livecell data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + + train_loader = get_lucchi_loader( + path=data_path, patch_shape=patch_shape, split="train", batch_size=batch_size, num_workers=num_workers, + download=True, shuffle=True, label_transform=label_transform, label_dtype=torch.float32 + ) + val_loader = get_lucchi_loader( + path=data_path, patch_shape=patch_shape, split="test", batch_size=batch_size, num_workers=num_workers, + download=True, shuffle=True, label_transform=label_transform, label_dtype=torch.float32 + ) + + return train_loader, val_loader + + +def train_on_lucchi(input_path, patch_shape, model_type, n_classes, n_iterations): + from micro_sam.training.util import ConvertToSemanticSamInputs + + device = "cuda" if torch.cuda.is_available() else "cpu" + sam_3d = get_3d_sam_model( + device, n_classes=n_classes, image_size=patch_shape[1], + model_type=model_type) + train_loader, val_loader = get_dataloaders(patch_shape, input_path) + optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) + + trainer = SemanticSamTrainer3D( + name="test-3d-sam", + model=sam_3d, + convert_inputs=ConvertToSemanticSamInputs(), + num_classes=n_classes, + train_loader=train_loader, + val_loader=val_loader, + optimizer=optimizer, + device=device, + compile_model=False, + ) + trainer.fit(n_iterations) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." 
+ ) + parser.add_argument("--patch_shape", type=int, nargs=3, default=(64, 256, 256), help="Patch shape for data loading (3D tuple)") + parser.add_argument("--n_iterations", type=int, default=10000, help="Number of training iterations") + parser.add_argument("--n_classes", type=int, default=1, help="Number of classes to predict") + args = parser.parse_args() + args = parser.parse_args() + train_on_lucchi( + args.input_path, args.patch_shape, args.model_type, + args.n_classes, args.n_iterations) + + +if __name__ == "__main__": + main() From 42f9f36458c1ac30141a9565309aa7221f343720 Mon Sep 17 00:00:00 2001 From: Luca Date: Thu, 27 Jun 2024 15:43:39 +0200 Subject: [PATCH 02/24] implemented training routine for 3d sam --- development/train_3d_model_with_lucchi.py | 96 ++++++++++++++++------- 1 file changed, 66 insertions(+), 30 deletions(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index cb56f725..82cd71fe 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -1,48 +1,78 @@ import os import argparse +import numpy as np import torch -from torch_em.data.datasets import get_lucchi_loader +from torch_em.data.datasets import get_lucchi_loader, get_lucchi_dataset +from torch_em.segmentation import SegmentationDataset import torch_em from micro_sam.sam_3d_wrapper import get_3d_sam_model from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D import micro_sam.training as sam_training -def get_dataloaders(patch_shape, data_path, batch_size=1, num_workers=4): - """This returns the livecell data loaders implemented in torch_em: - https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py - It will automatically download the livecell data. - Note: to replace this with another data loader you need to return a torch data loader - that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. - The labels have to be in a label mask instance segmentation format. - I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. 
- Important: the ID 0 is reseved for background, and the IDs must be consecutive - """ - label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) - - train_loader = get_lucchi_loader( - path=data_path, patch_shape=patch_shape, split="train", batch_size=batch_size, num_workers=num_workers, - download=True, shuffle=True, label_transform=label_transform, label_dtype=torch.float32 - ) - val_loader = get_lucchi_loader( - path=data_path, patch_shape=patch_shape, split="test", batch_size=batch_size, num_workers=num_workers, - download=True, shuffle=True, label_transform=label_transform, label_dtype=torch.float32 - ) +class LucchiSegmentationDataset(SegmentationDataset): + def __init__(self, patch_shape, num_classes, label_transform=None, **kwargs): + super().__init__(patch_shape=patch_shape, label_transform=label_transform, **kwargs) # Call parent class constructor + self.num_classes = num_classes + + def __getitem__(self, index): + raw, label = super().__getitem__(index) + # raw shape: (z, color channels, x, y) channels is fixed to 3 + image_shape = (self.patch_shape[0], 1) + self.patch_shape[1:] + raw = raw.unsqueeze(2) + raw = raw.view(image_shape) + raw = raw.squeeze(0) + raw = raw.repeat(1, 3, 1, 1) + # label shape: (classes, z, x, y) + label_shape = (self.num_classes,) + self.patch_shape + label = label.view(label_shape) + return raw, label + - return train_loader, val_loader +def get_loader(path, split, patch_shape, n_classes, batch_size, label_transform, num_workers=1): + assert split in ("train", "test") + data_path = os.path.join(path, f"lucchi_{split}.h5") + raw_key, label_key = "raw", "labels" + ds = LucchiSegmentationDataset( + raw_path=data_path, label_path=data_path, raw_key=raw_key, + label_key=label_key, patch_shape=patch_shape, + num_classes=n_classes, label_transform=label_transform) + loader = torch.utils.data.DataLoader( + ds, batch_size=batch_size, shuffle=True, + num_workers=num_workers) + loader.shuffle = True + return loader -def train_on_lucchi(input_path, patch_shape, model_type, n_classes, n_iterations): +def train_on_lucchi(args): from micro_sam.training.util import ConvertToSemanticSamInputs + input_path = args.input_path + patch_shape = args.patch_shape + batch_size = args.batch_size + num_workers = args.num_workers + n_classes = args.n_classes + model_type = args.model_type + n_iterations = args.n_iterations + save_root = args.save_root + + label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) device = "cuda" if torch.cuda.is_available() else "cpu" sam_3d = get_3d_sam_model( device, n_classes=n_classes, image_size=patch_shape[1], model_type=model_type) - train_loader, val_loader = get_dataloaders(patch_shape, input_path) + #get_dataloaders(patch_shape, input_path) + train_loader = get_loader( + input_path, split="train", patch_shape=patch_shape, + n_classes=n_classes, batch_size=batch_size, num_workers=num_workers, + label_transform=label_transform) + val_loader = get_loader( + input_path, split="test", patch_shape=patch_shape, + n_classes=n_classes, batch_size=batch_size, num_workers=num_workers, + label_transform=label_transform) optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) trainer = SemanticSamTrainer3D( @@ -55,6 +85,8 @@ def train_on_lucchi(input_path, patch_shape, model_type, n_classes, n_iterations optimizer=optimizer, device=device, compile_model=False, + save_root=save_root, + logger=None ) trainer.fit(n_iterations) @@ -70,13 +102,17 @@ def main(): help="The model type to use for 
fine-tuning. Either vit_t, vit_b, vit_l or vit_h." ) parser.add_argument("--patch_shape", type=int, nargs=3, default=(64, 256, 256), help="Patch shape for data loading (3D tuple)") - parser.add_argument("--n_iterations", type=int, default=10000, help="Number of training iterations") - parser.add_argument("--n_classes", type=int, default=1, help="Number of classes to predict") - args = parser.parse_args() + parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") + parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--num_workers", type=int, default=4, help="num_workers") + parser.add_argument( + "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", + help="The filepath to where the logs and the checkpoints will be saved." + ) + args = parser.parse_args() - train_on_lucchi( - args.input_path, args.patch_shape, args.model_type, - args.n_classes, args.n_iterations) + train_on_lucchi(args) if __name__ == "__main__": From 9be15d51857238efc8cb460376d721b132f3ee1b Mon Sep 17 00:00:00 2001 From: Luca Date: Thu, 27 Jun 2024 15:58:28 +0200 Subject: [PATCH 03/24] tidied up code --- development/train_3d_model_with_lucchi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 82cd71fe..87bd3f4f 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -64,7 +64,6 @@ def train_on_lucchi(args): sam_3d = get_3d_sam_model( device, n_classes=n_classes, image_size=patch_shape[1], model_type=model_type) - #get_dataloaders(patch_shape, input_path) train_loader = get_loader( input_path, split="train", patch_shape=patch_shape, n_classes=n_classes, batch_size=batch_size, num_workers=num_workers, From b0fc01abe5dc442736ab8f4122fa80695049fc6b Mon Sep 17 00:00:00 2001 From: Luca Date: Thu, 27 Jun 2024 16:48:47 +0200 Subject: [PATCH 04/24] changed dataset esp. 
label shape not depending on num_classes --- development/train_3d_model_with_lucchi.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 87bd3f4f..ccc517ca 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -14,21 +14,21 @@ class LucchiSegmentationDataset(SegmentationDataset): - def __init__(self, patch_shape, num_classes, label_transform=None, **kwargs): + def __init__(self, patch_shape, label_transform=None, **kwargs): super().__init__(patch_shape=patch_shape, label_transform=label_transform, **kwargs) # Call parent class constructor - self.num_classes = num_classes def __getitem__(self, index): raw, label = super().__getitem__(index) - # raw shape: (z, color channels, x, y) channels is fixed to 3 + # raw shape: (z, color channels, y, x) channels is fixed to 3 image_shape = (self.patch_shape[0], 1) + self.patch_shape[1:] raw = raw.unsqueeze(2) raw = raw.view(image_shape) raw = raw.squeeze(0) - raw = raw.repeat(1, 3, 1, 1) - # label shape: (classes, z, x, y) - label_shape = (self.num_classes,) + self.patch_shape - label = label.view(label_shape) + raw = raw.repeat(1, 3, 1, 1) + print("raw shape", raw.shape) + # wanted label shape: (1, z, y, x) + label = (label != 0).to(torch.float) + print("label shape", label.shape) return raw, label @@ -38,8 +38,7 @@ def get_loader(path, split, patch_shape, n_classes, batch_size, label_transform, raw_key, label_key = "raw", "labels" ds = LucchiSegmentationDataset( raw_path=data_path, label_path=data_path, raw_key=raw_key, - label_key=label_key, patch_shape=patch_shape, - num_classes=n_classes, label_transform=label_transform) + label_key=label_key, patch_shape=patch_shape, label_transform=label_transform) loader = torch.utils.data.DataLoader( ds, batch_size=batch_size, shuffle=True, num_workers=num_workers) @@ -59,7 +58,7 @@ def train_on_lucchi(args): save_root = args.save_root label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) - + label_transform = None device = "cuda" if torch.cuda.is_available() else "cpu" sam_3d = get_3d_sam_model( device, n_classes=n_classes, image_size=patch_shape[1], From 8ca13261a67cc90ba581ea26af601ff33a0b6d5f Mon Sep 17 00:00:00 2001 From: Luca Freckmann Date: Fri, 28 Jun 2024 15:33:52 +0200 Subject: [PATCH 05/24] added check_loader --- development/train_3d_model_with_lucchi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index ccc517ca..2463bcbb 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -7,6 +7,7 @@ from torch_em.data.datasets import get_lucchi_loader, get_lucchi_dataset from torch_em.segmentation import SegmentationDataset import torch_em +from torch_em.util.debug import check_loader from micro_sam.sam_3d_wrapper import get_3d_sam_model from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D @@ -86,6 +87,7 @@ def train_on_lucchi(args): save_root=save_root, logger=None ) + # check_loader(train_loader, n_samples=10) trainer.fit(n_iterations) From ca864ed2bb40ee32aaab56dff359220c41ae5785 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Mon, 10 Jun 2024 00:30:26 +0200 Subject: [PATCH 06/24] Add mentions for annotating 3D RGB volumes (#629) * Update faq.md * Add described answer to the issue --- 
doc/faq.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/faq.md b/doc/faq.md index f76b7430..ebcba331 100644 --- a/doc/faq.md +++ b/doc/faq.md @@ -131,6 +131,10 @@ Editing (drawing / erasing) very large 2d images or 3d volumes is known to be sl This can happen for long running computations. You just need to wait a bit longer and the computation will finish. +### 14. I have 3D RGB microscopy volumes. How does `micro_sam` handle these images? +`micro_sam` performs automatic segmentation in 3D volumes by first segmenting slices individually in 2D and merging the segmentations across 3D based on overlap of objects between slices. The expected shape of your 3D RGB volume should be `(Z * Y * X * 3)` (reason: Segment Anything is devised to consider 3-channel inputs, so while the user provides micro-sam with 1-channel inputs, we handle this by triplicating this to fit the requirement, or with 3-channel inputs, we use them in the expected RGB array structures as it is). + + ## Fine-tuning questions From a66c09fb846ece97e36b896a99707f9ca47067a5 Mon Sep 17 00:00:00 2001 From: Luca Date: Fri, 28 Jun 2024 15:48:41 +0200 Subject: [PATCH 07/24] tidied up code --- development/train_3d_model_with_lucchi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 2463bcbb..12f05a26 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -26,10 +26,10 @@ def __getitem__(self, index): raw = raw.view(image_shape) raw = raw.squeeze(0) raw = raw.repeat(1, 3, 1, 1) - print("raw shape", raw.shape) + # print("raw shape", raw.shape) # wanted label shape: (1, z, y, x) label = (label != 0).to(torch.float) - print("label shape", label.shape) + # print("label shape", label.shape) return raw, label @@ -75,7 +75,7 @@ def train_on_lucchi(args): optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) trainer = SemanticSamTrainer3D( - name="test-3d-sam", + name="3d-sam-lucchi", model=sam_3d, convert_inputs=ConvertToSemanticSamInputs(), num_classes=n_classes, @@ -103,7 +103,7 @@ def main(): ) parser.add_argument("--patch_shape", type=int, nargs=3, default=(64, 256, 256), help="Patch shape for data loading (3D tuple)") parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") - parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict") + parser.add_argument("--n_classes", type=int, default=1, help="Number of classes to predict") parser.add_argument("--batch_size", type=int, default=1, help="Batch size") parser.add_argument("--num_workers", type=int, default=4, help="num_workers") parser.add_argument( From a5e937a640e8fcaf7410a1f119b4d7f456400ccb Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:19:16 +0200 Subject: [PATCH 08/24] Add SemanticSam3dLogger (#643) Updates to SAM 3d training --------- Co-authored-by: Constantin Pape --- micro_sam/sam_3d_wrapper.py | 20 ++++- micro_sam/training/semantic_sam_trainer.py | 90 +++++++++++++++++++--- 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py index ccb9968e..1676652b 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/sam_3d_wrapper.py @@ -9,11 +9,23 @@ from .util import get_sam_model -def get_3d_sam_model(device, n_classes, image_size, model_type="vit_b"): - predictor, sam = get_sam_model( - 
return_sam=True, model_type=model_type, device=device, num_multimask_outputs=n_classes, - flexible_load_checkpoint=True, image_size=image_size, +def get_3d_sam_model( + device, + n_classes, + image_size, + model_type="vit_b", + checkpoint_path=None, +): + _, sam = get_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_sam=True, + flexible_load_checkpoint=True, + num_multimask_outputs=n_classes, + image_size=image_size, ) + sam_3d = Sam3DWrapper(sam) sam_3d.to(device) return sam_3d diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index b3f1cc0a..93228752 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -1,10 +1,36 @@ import time +import numpy as np + import torch import torch.nn as nn from torch_em.loss import DiceLoss from torch_em.trainer import DefaultTrainer +from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im + + +class CustomDiceLoss(nn.Module): + def __init__(self, num_classes: int, softmax: bool = True) -> None: + super().__init__() + self.num_classes = num_classes + self.dice_loss = DiceLoss() + self.softmax = softmax + + def _one_hot_encoder(self, input_tensor): + tensor_list = [] + for i in range(self.num_classes): + temp_prob = input_tensor == i # * torch.ones_like(input_tensor) + tensor_list.append(temp_prob) + output_tensor = torch.cat(tensor_list, dim=1) + return output_tensor.float() + + def __call__(self, pred, target): + if self.softmax: + pred = torch.softmax(pred, dim=1) + target = self._one_hot_encoder(target) + loss = self.dice_loss(pred, target) + return loss class SemanticSamTrainer(DefaultTrainer): @@ -13,31 +39,35 @@ class SemanticSamTrainer(DefaultTrainer): def __init__( self, convert_inputs, - num_classes: int = 1, + num_classes: int, **kwargs ): - loss = DiceLoss() - metric = DiceLoss() - super().__init__(loss=loss, metric=metric, **kwargs) + assert num_classes > 1 + + loss = CustomDiceLoss(num_classes=num_classes) + metric = CustomDiceLoss(num_classes=num_classes) + logger = SemanticSamLogger + super().__init__(loss=loss, metric=metric, logger=logger, **kwargs) self.convert_inputs = convert_inputs self.num_classes = num_classes - self.compute_ce_loss = nn.BCELoss() if num_classes == 1 else nn.CrossEntropyLoss() + self.compute_ce_loss = nn.CrossEntropyLoss() self._kwargs = kwargs def _compute_loss(self, y, masks): + target = y.to(self.device, non_blocking=True) # Compute dice loss for the predictions - dice_loss = self.loss(masks, y.to(self.device, non_blocking=True)) + dice_loss = self.loss(masks, target) # Compute cross entropy loss for the predictions - ce_loss = self.compute_ce_loss(masks, y.to(self.device, non_blocking=True)) + ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long()) net_loss = dice_loss + ce_loss return net_loss def _get_model_outputs(self, batched_inputs): image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) - batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=(self.num_classes > 1)) + batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True) masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) return masks @@ -56,11 +86,12 @@ def _train_epoch_impl(self, progress, forward_context, backprop): backprop(net_loss) + self._iteration += 1 + if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - 
self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True) + self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False) - self._iteration += 1 if self._iteration >= self.max_iteration: break progress.update(1) @@ -86,8 +117,9 @@ def _validate_impl(self, forward_context): loss_val /= len(self.val_loader) metric_val /= len(self.val_loader) + dice_metric = 1 - (metric_val / self.num_classes) print() - print(f"The Average Validation Metric Score for the Current Epoch is {1 - metric_val}") + print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}") if self.logger is not None: self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks) @@ -107,3 +139,39 @@ def _get_model_outputs(self, batched_inputs): # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) masks = batched_outputs["masks"] return masks + + +class SemanticSamLogger(TensorboardLogger): + def log_images(self, step, x, y, prediction, name, gradients=None): + + selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :] + selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] + + image = normalize_im(x[selection_image].cpu()) + self.tb.add_image(tag=f"{name}/input", + img_tensor=image, + global_step=step) + + prediction = torch.softmax(prediction, dim=1) + im, im_name = self.make_image(image, y, prediction, selection, gradients) + im_name = f"{name}/{im_name}" + self.tb.add_image(tag=im_name, img_tensor=im, global_step=step) + + def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False): + self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) + + # the embedding visualisation function currently doesn't support gradients, + # so we can't log them even if log_gradients is true + log_grads = log_gradients + if self.have_embeddings: + log_grads = False + + if step % self.log_image_interval == 0: + gradients = prediction.grad if log_grads else None + self.log_images(step, x, y, prediction, "train", gradients=gradients) + + def log_validation(self, step, metric, loss, x, y, prediction): + self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) + self.log_images(step, x, y, prediction, "validation") From 1592988b0ba03c30ce6462e84f39205bb11492e0 Mon Sep 17 00:00:00 2001 From: Luca Freckmann Date: Thu, 4 Jul 2024 08:57:30 +0200 Subject: [PATCH 09/24] added new training and predict scripts --- development/predict_3d_model_with_lucchi.py | 191 ++++++++++++++++++++ development/train_3d_model_with_lucchi.py | 131 ++++++++++---- 2 files changed, 290 insertions(+), 32 deletions(-) create mode 100644 development/predict_3d_model_with_lucchi.py diff --git a/development/predict_3d_model_with_lucchi.py b/development/predict_3d_model_with_lucchi.py new file mode 100644 index 00000000..9cdcdbc0 --- /dev/null +++ b/development/predict_3d_model_with_lucchi.py @@ -0,0 +1,191 @@ +import os +import argparse +from tqdm import tqdm +import numpy as np +import imageio.v3 as imageio +from elf.io import open_file +from skimage.measure import label as connected_components + +import torch +from glob import glob + +from torch_em.util.segmentation import size_filter +from torch_em.util import load_model +from torch_em.transform.raw import normalize +from torch_em.util.prediction import 
predict_with_halo + +from micro_sam import util +from micro_sam.evaluation.inference import _run_inference_with_iterative_prompting_for_image + +from segment_anything import SamPredictor + +from micro_sam.models.sam_3d_wrapper import get_sam_3d_model +from typing import List, Union, Dict, Optional, Tuple + + +class RawTrafoFor3dInputs: + def _normalize_inputs(self, raw): + raw = normalize(raw) + raw = raw * 255 + return raw + + def _set_channels_for_inputs(self, raw): + raw = np.stack([raw] * 3, axis=0) + return raw + + def __call__(self, raw): + raw = self._normalize_inputs(raw) + raw = self._set_channels_for_inputs(raw) + return raw + + +def _run_semantic_segmentation_for_image_3d( + model: torch.nn.Module, + image: np.ndarray, + prediction_path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + halo: Tuple[int, int, int], +): + device = next(model.parameters()).device + block_shape = tuple(bs - 2 * ha for bs, ha in zip(patch_shape, halo)) + + def preprocess(x): + x = 255 * normalize(x) + x = np.stack([x] * 3) + return x + + def prediction_function(net, inp): + # Note: we have two singleton axis in front here, I am not quite sure why. + # Both need to be removed to be compatible with the SAM network. + batched_input = [{ + "image": inp[0, 0], "original_size": inp.shape[-2:] + }] + masks = net(batched_input, multimask_output=True)[0]["masks"] + masks = torch.argmax(masks, dim=1) + return masks + + # num_classes = model.sam_model.mask_decoder.num_multimask_outputs + image_size = patch_shape[-1] + output = np.zeros(image.shape, dtype="float32") + predict_with_halo( + image, model, gpu_ids=[device], + block_shape=block_shape, halo=halo, + preprocess=preprocess, output=output, + prediction_function=prediction_function + ) + + # save the segmentations + imageio.imwrite(prediction_path, output, compression="zlib") + + +def run_semantic_segmentation_3d( + model: torch.nn.Module, + image_paths: List[Union[str, os.PathLike]], + prediction_dir: Union[str, os.PathLike], + semantic_class_map: Dict[str, int], + patch_shape: Tuple[int, int, int] = (32, 512, 512), + halo: Tuple[int, int, int] = (6, 64, 64), + image_key: Optional[str] = None, + is_multiclass: bool = False, +): + """ + """ + for image_path in tqdm(image_paths, desc="Run inference for semantic segmentation with all images"): + image_name = os.path.basename(image_path) + + assert os.path.exists(image_path), image_path + + # Perform segmentation only on the semantic class + for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()): + if is_multiclass: + semantic_class_name = "all" + if i > 0: # We only perform segmentation for multiclass once. 
+ continue + + # We skip the images that already have been segmented + image_name = os.path.splitext(image_name)[0] + ".tif" + prediction_path = os.path.join(prediction_dir, semantic_class_name, image_name) + if os.path.exists(prediction_path): + continue + + if image_key is None: + image = imageio.imread(image_path) + else: + with open_file(image_path, "r") as f: + image = f[image_key][:] + + # create the prediction folder + os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True) + + _run_semantic_segmentation_for_image_3d( + model=model, image=image, prediction_path=prediction_path, + patch_shape=patch_shape, halo=halo, + ) + + +def transform_labels(y): + return (y > 0).astype("float32") + + +def predict(args): + + device = "cuda" if torch.cuda.is_available() else "cpu" + if args.checkpoint_path is not None: + if os.path.exists(args.checkpoint_path): + # model = load_model(checkpoint=args.checkpoint_path, device=device) # does not work + + cp_path = os.path.join(args.checkpoint_path, "", "best.pt") + print(cp_path) + model = get_sam_3d_model(device, n_classes=args.n_classes, image_size=args.patch_shape[1], + lora_rank=4, + model_type=args.model_type, + checkpoint_path=cp_path + ) + + # checkpoint = torch.load(cp_path, map_location=device) + # #print(checkpoint.keys()) + # # # Load the state dictionary from the checkpoint + # model.load_state_dict(checkpoint['model_state']) + model.eval() + + data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True) + pred_path = args.save_root + semantic_class_map = {"all": 0} + + run_semantic_segmentation_3d( + model=model, image_paths=data_paths, prediction_dir=pred_path, semantic_class_map=semantic_class_map, + patch_shape=args.patch_shape, image_key="raw", is_multiclass=True + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." + ) + parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") + parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") + parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--num_workers", type=int, default=4, help="num_workers") + parser.add_argument( + "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", + help="The filepath to where the logs and the checkpoints will be saved." + ) + parser.add_argument( + "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-lucchi-train/", + help="The filepath to where the logs and the checkpoints will be saved." 
+ ) + + args = parser.parse_args() + + predict(args) + + +if __name__ == "__main__": + main() diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 12f05a26..d393ece6 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -1,19 +1,66 @@ import os import argparse import numpy as np - +from math import ceil, floor import torch from torch_em.data.datasets import get_lucchi_loader, get_lucchi_dataset from torch_em.segmentation import SegmentationDataset import torch_em from torch_em.util.debug import check_loader +from torch_em.transform.raw import normalize -from micro_sam.sam_3d_wrapper import get_3d_sam_model -from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D +from micro_sam.models.sam_3d_wrapper import get_sam_3d_model +from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer import micro_sam.training as sam_training +class RawTrafoFor3dInputs: + def _normalize_inputs(self, raw): + raw = normalize(raw) + raw = raw * 255 + return raw + + def _set_channels_for_inputs(self, raw): + raw = np.stack([raw] * 3, axis=0) + return raw + + def __call__(self, raw): + raw = self._normalize_inputs(raw) + raw = self._set_channels_for_inputs(raw) + return raw + + +# for sega +class RawResizeTrafoFor3dInputs(RawTrafoFor3dInputs): + def __init__(self, desired_shape, padding="constant"): + super().__init__() + self.desired_shape = desired_shape + self.padding = padding + + def __call__(self, raw): + raw = self._normalize_inputs(raw) + + # let's pad the inputs + tmp_ddim = ( + self.desired_shape[0] - raw.shape[0], + self.desired_shape[1] - raw.shape[1], + self.desired_shape[2] - raw.shape[2] + ) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2, tmp_ddim[2] / 2) + raw = np.pad( + raw, + pad_width=( + (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1])), (ceil(ddim[2]), floor(ddim[2])) + ), + mode=self.padding + ) + + raw = self._set_channels_for_inputs(raw) + + return raw + + class LucchiSegmentationDataset(SegmentationDataset): def __init__(self, patch_shape, label_transform=None, **kwargs): super().__init__(patch_shape=patch_shape, label_transform=label_transform, **kwargs) # Call parent class constructor @@ -33,18 +80,33 @@ def __getitem__(self, index): return raw, label -def get_loader(path, split, patch_shape, n_classes, batch_size, label_transform, num_workers=1): - assert split in ("train", "test") - data_path = os.path.join(path, f"lucchi_{split}.h5") - raw_key, label_key = "raw", "labels" - ds = LucchiSegmentationDataset( - raw_path=data_path, label_path=data_path, raw_key=raw_key, - label_key=label_key, patch_shape=patch_shape, label_transform=label_transform) - loader = torch.utils.data.DataLoader( - ds, batch_size=batch_size, shuffle=True, - num_workers=num_workers) - loader.shuffle = True - return loader +def transform_labels(y): + return (y > 0).astype("float32") + + +def get_loaders(input_path, patch_shape): + train_loader = get_lucchi_loader( + input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True, + raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels, + n_samples=100 + ) + val_loader = get_lucchi_loader( + input_path, split="test", patch_shape=patch_shape, batch_size=1, + raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels + ) + return train_loader, val_loader +# def get_loader(path, split, patch_shape, n_classes, batch_size, label_transform, num_workers=1): +# assert split in ("train", 
"test") +# data_path = os.path.join(path, f"lucchi_{split}.h5") +# raw_key, label_key = "raw", "labels" +# ds = LucchiSegmentationDataset( +# raw_path=data_path, label_path=data_path, raw_key=raw_key, +# label_key=label_key, patch_shape=patch_shape, label_transform=label_transform) +# loader = torch.utils.data.DataLoader( +# ds, batch_size=batch_size, shuffle=True, +# num_workers=num_workers) +# loader.shuffle = True +# return loader def train_on_lucchi(args): @@ -58,24 +120,29 @@ def train_on_lucchi(args): n_iterations = args.n_iterations save_root = args.save_root - label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) - label_transform = None + # label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + # label_transform = None + raw_data = np.random.rand(64, 256, 256) # Shape (z, y, x) + raw_data2, label = next(iter(get_lucchi_loader(input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True))) + + # Create an instance of RawTrafoFor3dInputs + transformer = RawTrafoFor3dInputs() + + # Apply transformations + processed_data = transformer(raw_data) + processed_data2 = transformer(raw_data2) + print("input (64,256,256)", processed_data.shape) + print("input", raw_data2.shape, processed_data2.shape) + device = "cuda" if torch.cuda.is_available() else "cpu" - sam_3d = get_3d_sam_model( + sam_3d = get_sam_3d_model( device, n_classes=n_classes, image_size=patch_shape[1], - model_type=model_type) - train_loader = get_loader( - input_path, split="train", patch_shape=patch_shape, - n_classes=n_classes, batch_size=batch_size, num_workers=num_workers, - label_transform=label_transform) - val_loader = get_loader( - input_path, split="test", patch_shape=patch_shape, - n_classes=n_classes, batch_size=batch_size, num_workers=num_workers, - label_transform=label_transform) + model_type=model_type, lora_rank=4) + train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) - trainer = SemanticSamTrainer3D( - name="3d-sam-lucchi", + trainer = SemanticSamTrainer( + name="3d-sam-lucchi-train", model=sam_3d, convert_inputs=ConvertToSemanticSamInputs(), num_classes=n_classes, @@ -85,7 +152,7 @@ def train_on_lucchi(args): device=device, compile_model=False, save_root=save_root, - logger=None + #logger=None ) # check_loader(train_loader, n_samples=10) trainer.fit(n_iterations) @@ -101,9 +168,9 @@ def main(): "--model_type", "-m", default="vit_b", help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." 
) - parser.add_argument("--patch_shape", type=int, nargs=3, default=(64, 256, 256), help="Patch shape for data loading (3D tuple)") + parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") - parser.add_argument("--n_classes", type=int, default=1, help="Number of classes to predict") + parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict") parser.add_argument("--batch_size", type=int, default=1, help="Batch size") parser.add_argument("--num_workers", type=int, default=4, help="num_workers") parser.add_argument( From b61ee0452358b2ccb679cf3b00333f735b3c760e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 28 Jun 2024 16:16:35 +0200 Subject: [PATCH 10/24] Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapper (#645) Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapper, simplify lora support --- micro_sam/sam_3d_wrapper.py | 22 +++- micro_sam/simple_sam_3d_wrapper.py | 159 +++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 micro_sam/simple_sam_3d_wrapper.py diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py index 1676652b..5b40608b 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/sam_3d_wrapper.py @@ -13,9 +13,20 @@ def get_3d_sam_model( device, n_classes, image_size, + lora_rank=None, + freeze_encoder=False, model_type="vit_b", checkpoint_path=None, ): + if lora_rank is None: + use_lora = False + rank = None + freeze_encoder_ = freeze_encoder + else: + use_lora = True + rank = lora_rank + freeze_encoder_ = False + _, sam = get_sam_model( model_type=model_type, device=device, @@ -24,15 +35,17 @@ def get_3d_sam_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, + use_lora=use_lora, + rank=rank, ) - sam_3d = Sam3DWrapper(sam) + sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) sam_3d.to(device) return sam_3d class Sam3DWrapper(nn.Module): - def __init__(self, sam_model: Sam): + def __init__(self, sam_model: Sam, freeze_encoder: bool): """ Initializes the Sam3DWrapper object. 
@@ -45,6 +58,11 @@ def __init__(self, sam_model: Sam): ) self.sam_model = sam_model + self.freeze_encoder = freeze_encoder + if self.freeze_encoder: + for param in self.sam_model.image_encoder.parameters(): + param.requires_grad = False + # FIXME # - handling of the image size here is wrong, this only works for square images # - this does not take care of resizing diff --git a/micro_sam/simple_sam_3d_wrapper.py b/micro_sam/simple_sam_3d_wrapper.py new file mode 100644 index 00000000..ba33391b --- /dev/null +++ b/micro_sam/simple_sam_3d_wrapper.py @@ -0,0 +1,159 @@ +from contextlib import nullcontext + +import torch +import torch.nn as nn + +from .util import get_sam_model + + +def get_simple_3d_sam_model( + device, + n_classes, + image_size, + lora_rank=None, + freeze_encoder=False, + model_type="vit_b", + checkpoint_path=None, +): + if lora_rank is None: + use_lora = False + rank = None + freeze_encoder_ = freeze_encoder + else: + use_lora = True + rank = lora_rank + freeze_encoder_ = False + + _, sam = get_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_sam=True, + image_size=image_size, + flexible_load_checkpoint=True, + use_lora=use_lora, + rank=rank, + ) + + sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_) + sam_3d.to(device) + return sam_3d + + +class BasicBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + bias=True, + mode="nearest" + ): + super().__init__() + + self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU() + ) + + self.conv2 = nn.Sequential( + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), + nn.InstanceNorm3d(out_channels) + ) + + self.downsample = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias), + nn.InstanceNorm3d(out_channels) + ) + + self.leakyrelu = nn.LeakyReLU() + + self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode) + + def forward(self, x): + residual = self.downsample(x) + + out = self.conv1(x) + out = self.conv2(out) + out += residual + + out = self.leakyrelu(out) + out = self.up(out) + return out + + +class SegmentationHead(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + bias=True + ): + super().__init__() + + self.conv_pred = nn.Sequential( + nn.Conv3d( + in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ), + nn.InstanceNorm3d(in_channels // 2), + nn.LeakyReLU() + ) + self.segmentation_head = nn.Conv3d(in_channels // 2, out_channels, kernel_size=1) + + def forward(self, x): + x = self.conv_pred(x) + return self.segmentation_head(x) + + +class SimpleSam3DWrapper(nn.Module): + def __init__(self, sam, num_classes, freeze_encoder): + super().__init__() + + self.sam = sam + self.freeze_encoder = freeze_encoder + if self.freeze_encoder: + for param in self.sam.image_encoder.parameters(): + param.requires_grad = False + self.no_grad = torch.no_grad + + else: + self.no_grad = nullcontext + + self.decoders = nn.ModuleList([ + BasicBlock(in_channels=256, out_channels=128), + BasicBlock(in_channels=128, out_channels=64), + BasicBlock(in_channels=64, out_channels=32), + BasicBlock(in_channels=32, 
out_channels=16), + ]) + self.out_conv = SegmentationHead(in_channels=16, out_channels=num_classes) + + def _apply_image_encoder(self, x, D): + encoder_features = [] + for d in range(D): + image = x[:, d] + feature = self.sam.image_encoder(image) + encoder_features.append(feature) + encoder_features = torch.stack(encoder_features, 1) + encoder_features = encoder_features.transpose(1, 2) + return encoder_features + + def forward(self, x, **kwargs): + B, D, C, H, W = x.shape + assert C == 3 + + with self.no_grad(): + features = self._apply_image_encoder(x, D) + + out = features + for decoder in self.decoders: + out = decoder(out) + logits = self.out_conv(out) + + outputs = {"masks": logits} + return outputs From c64944db59a9fbad2c43fe1f1aa2a6ee8423d12e Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 28 Jun 2024 21:59:13 +0200 Subject: [PATCH 11/24] Minor fix to trainable sam model functionality (#646) Minor fix to trainable sam model functionality --- micro_sam/training/util.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 3e4f01e3..6ad6ce40 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -13,7 +13,6 @@ get_centers_and_bounding_boxes, get_sam_model, get_device, segmentation_to_one_hot, _DEFAULT_MODEL, ) -from .peft_sam import PEFT_Sam from .trainable_sam import TrainableSAM from torch_em.transform.label import PerObjectDistanceTransform @@ -87,21 +86,18 @@ def get_trainable_sam_model( # (for e.g. encoder blocks to "image_encoder") if freeze is not None: for name, param in sam.named_parameters(): - if isinstance(freeze, list): - # we would want to "freeze" all the components in the model if passed a list of parts - for l_item in freeze: - if name.startswith(f"{l_item}"): - param.requires_grad = False - else: + if not isinstance(freeze, list): # we "freeze" only for one specific component when passed a "particular" part - if name.startswith(f"{freeze}"): - param.requires_grad = False + freeze = [freeze] + + # we would want to "freeze" all the components in the model if passed a list of parts + for l_item in freeze: + # in case LoRA is switched on, we cannot freeze the image encoder + if use_lora and (l_item == "image_encoder"): + raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") - # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything - if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers - if rank is None: - rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them - sam = PEFT_Sam(sam, rank=rank).sam + if name.startswith(f"{l_item}"): + param.requires_grad = False # convert to trainable sam trainable_sam = TrainableSAM(sam) From 70cf9b7ac42f8e252c1039a932d7521f110bd3ae Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 29 Jun 2024 23:14:00 +0200 Subject: [PATCH 12/24] Fix dimension order in 3d sam wrappers --- micro_sam/sam_3d_wrapper.py | 9 +++- micro_sam/simple_sam_3d_wrapper.py | 5 +-- micro_sam/training/semantic_sam_trainer.py | 50 +++------------------- 3 files changed, 15 insertions(+), 49 deletions(-) diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py index 5b40608b..4582cfc4 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/sam_3d_wrapper.py @@ -71,9 +71,14 @@ def forward(self, batched_input, 
multimask_output, image_size) -> torch.Tensor: return self._forward_train(batched_input, multimask_output, image_size) def _forward_train(self, batched_input, multimask_output, image_size): - # dimensions: [b, d, 3, h, w] + # dimensions: [b, 3, d, h, w] shape = batched_input.shape - batch_size, d_size, hw_size = shape[0], shape[1], shape[-2] + assert shape[1] == 3 + batch_size, d_size, hw_size = shape[0], shape[2], shape[-2] + # Transpose the axes, so that the depth axis is the first axis and the channel + # axis is the second axis. This is expected by the transformer! + batched_input = batched_input.transpose(1, 2) + assert batched_input.shape[1] == d_size batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size) input_images = self.sam_model.preprocess(batched_input) diff --git a/micro_sam/simple_sam_3d_wrapper.py b/micro_sam/simple_sam_3d_wrapper.py index ba33391b..30c8c20a 100644 --- a/micro_sam/simple_sam_3d_wrapper.py +++ b/micro_sam/simple_sam_3d_wrapper.py @@ -136,11 +136,10 @@ def __init__(self, sam, num_classes, freeze_encoder): def _apply_image_encoder(self, x, D): encoder_features = [] for d in range(D): - image = x[:, d] + image = x[:, :, d] feature = self.sam.image_encoder(image) encoder_features.append(feature) - encoder_features = torch.stack(encoder_features, 1) - encoder_features = encoder_features.transpose(1, 2) + encoder_features = torch.stack(encoder_features, 2) return encoder_features def forward(self, x, **kwargs): diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 93228752..6e3dad7e 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -1,13 +1,10 @@ import time -import numpy as np - import torch import torch.nn as nn from torch_em.loss import DiceLoss from torch_em.trainer import DefaultTrainer -from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im class CustomDiceLoss(nn.Module): @@ -46,8 +43,7 @@ def __init__( loss = CustomDiceLoss(num_classes=num_classes) metric = CustomDiceLoss(num_classes=num_classes) - logger = SemanticSamLogger - super().__init__(loss=loss, metric=metric, logger=logger, **kwargs) + super().__init__(loss=loss, metric=metric, **kwargs) self.convert_inputs = convert_inputs self.num_classes = num_classes @@ -90,7 +86,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop): if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False) + self.logger.log_train( + self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False + ) if self._iteration >= self.max_iteration: break @@ -122,7 +120,7 @@ def _validate_impl(self, forward_context): print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}") if self.logger is not None: - self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks) + self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)) return metric_val @@ -133,45 +131,9 @@ def _get_model_outputs(self, batched_inputs): image_size = batched_inputs[0]["original_size"][-1] batched_outputs = self.model( model_input, - multimask_output=(self.num_classes > 1), + multimask_output=True, image_size=image_size ) # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) masks = batched_outputs["masks"] return masks - - -class 
SemanticSamLogger(TensorboardLogger): - def log_images(self, step, x, y, prediction, name, gradients=None): - - selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :] - selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] - - image = normalize_im(x[selection_image].cpu()) - self.tb.add_image(tag=f"{name}/input", - img_tensor=image, - global_step=step) - - prediction = torch.softmax(prediction, dim=1) - im, im_name = self.make_image(image, y, prediction, selection, gradients) - im_name = f"{name}/{im_name}" - self.tb.add_image(tag=im_name, img_tensor=im, global_step=step) - - def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False): - self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) - self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) - - # the embedding visualisation function currently doesn't support gradients, - # so we can't log them even if log_gradients is true - log_grads = log_gradients - if self.have_embeddings: - log_grads = False - - if step % self.log_image_interval == 0: - gradients = prediction.grad if log_grads else None - self.log_images(step, x, y, prediction, "train", gradients=gradients) - - def log_validation(self, step, metric, loss, x, y, prediction): - self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) - self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) - self.log_images(step, x, y, prediction, "validation") From 09af0a77e0ef826f0c495860d6be662346ca14ba Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 2 Jul 2024 10:25:22 +0200 Subject: [PATCH 13/24] Api cleanup (#648) Clean up interfaces related to 3d models and PEFT --- micro_sam/models/__init__.py | 2 + micro_sam/{training => }/models/build_sam.py | 0 micro_sam/{training => models}/peft_sam.py | 10 +-- micro_sam/{ => models}/sam_3d_wrapper.py | 81 ++++++++++--------- .../{ => models}/simple_sam_3d_wrapper.py | 43 ++++++---- micro_sam/training/semantic_sam_trainer.py | 28 +++---- micro_sam/training/util.py | 19 ++--- micro_sam/util.py | 45 ++++++----- test/test_bioimageio/test_model_export.py | 1 + .../models => test/test_models}/__init__.py | 0 test/test_models/test_peft_sam.py | 26 ++++++ test/test_models/test_sam_3d_wrapper.py | 27 +++++++ .../test_models/test_simple_sam_3d_wrapper.py | 29 +++++++ test/test_peft_training.py | 49 ----------- 14 files changed, 207 insertions(+), 153 deletions(-) create mode 100644 micro_sam/models/__init__.py rename micro_sam/{training => }/models/build_sam.py (100%) rename micro_sam/{training => models}/peft_sam.py (90%) rename micro_sam/{ => models}/sam_3d_wrapper.py (73%) rename micro_sam/{ => models}/simple_sam_3d_wrapper.py (75%) rename {micro_sam/training/models => test/test_models}/__init__.py (100%) create mode 100644 test/test_models/test_peft_sam.py create mode 100644 test/test_models/test_sam_3d_wrapper.py create mode 100644 test/test_models/test_simple_sam_3d_wrapper.py delete mode 100644 test/test_peft_training.py diff --git a/micro_sam/models/__init__.py b/micro_sam/models/__init__.py new file mode 100644 index 00000000..27377e7b --- /dev/null +++ b/micro_sam/models/__init__.py @@ -0,0 +1,2 @@ +from .build_sam import sam_model_registry +from .peft_sam import PEFT_Sam diff --git a/micro_sam/training/models/build_sam.py b/micro_sam/models/build_sam.py similarity index 100% rename from micro_sam/training/models/build_sam.py rename to micro_sam/models/build_sam.py diff --git 
a/micro_sam/training/peft_sam.py b/micro_sam/models/peft_sam.py similarity index 90% rename from micro_sam/training/peft_sam.py rename to micro_sam/models/peft_sam.py index c67db7cb..dcc38a56 100644 --- a/micro_sam/training/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -53,9 +53,9 @@ def forward(self, x): class PEFT_Sam(nn.Module): - """Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/ + """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. - Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. + Inspired by https://github.com/JamesQFreeman/Sam_LoRA/ Args: model: The Segment Anything model. @@ -71,16 +71,14 @@ def __init__( peft_module: nn.Module = LoRASurgery, attention_layers_to_update: Union[List[int]] = None ): - super(PEFT_Sam, self).__init__() + super().__init__() assert rank > 0 if attention_layers_to_update: self.peft_layers = attention_layers_to_update else: # Applies PEFT to the image encoder by default - self.peft_layers = list( - range(len(model.image_encoder.blocks)) - ) + self.peft_layers = list(range(len(model.image_encoder.blocks))) self.peft_module = peft_module self.peft_blocks = [] diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py similarity index 73% rename from micro_sam/sam_3d_wrapper.py rename to micro_sam/models/sam_3d_wrapper.py index 4582cfc4..4a7645d0 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Any, List, Dict, Type import torch import torch.nn as nn @@ -6,10 +6,10 @@ from segment_anything.modeling.image_encoder import window_partition, window_unpartition from segment_anything.modeling import Sam -from .util import get_sam_model +from ..util import get_sam_model -def get_3d_sam_model( +def get_sam_3d_model( device, n_classes, image_size, @@ -18,15 +18,8 @@ def get_3d_sam_model( model_type="vit_b", checkpoint_path=None, ): - if lora_rank is None: - use_lora = False - rank = None - freeze_encoder_ = freeze_encoder - else: - use_lora = True - rank = lora_rank - freeze_encoder_ = False - + # Make sure not to freeze the encoder when using LoRA. + freeze_encoder_ = freeze_encoder if lora_rank is None else False _, sam = get_sam_model( model_type=model_type, device=device, @@ -35,8 +28,7 @@ def get_3d_sam_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, - use_lora=use_lora, - rank=rank, + lora_rank=lora_rank, ) sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) @@ -46,11 +38,10 @@ def get_3d_sam_model( class Sam3DWrapper(nn.Module): def __init__(self, sam_model: Sam, freeze_encoder: bool): - """ - Initializes the Sam3DWrapper object. + """Initializes the Sam3DWrapper object. Args: - sam_model (Sam): The Sam model to be wrapped. + sam_model: The Sam model to be wrapped. 
""" super().__init__() sam_model.image_encoder = ImageEncoderViT3DWrapper( @@ -63,25 +54,42 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool): for param in self.sam_model.image_encoder.parameters(): param.requires_grad = False - # FIXME - # - handling of the image size here is wrong, this only works for square images - # - this does not take care of resizing - # unclear how batches are handled - def forward(self, batched_input, multimask_output, image_size) -> torch.Tensor: - return self._forward_train(batched_input, multimask_output, image_size) + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool + ) -> List[Dict[str, torch.Tensor]]: + """Predict 3D masks for the current inputs. + + Unlike original SAM this model only supports automatic segmentation and does not support prompts. + + Args: + batched_input: A list over input images, each a dictionary with the following keys.L + 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. + 'original_size': The original size of the image (HxW) before transformation. + multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder. + + Returns: + A list over input images, where each element is as dictionary with the following keys: + 'masks': Mask prediction for this object. + 'iou_predictions': IOU score prediction for this object. + 'low_res_masks': Low resolution mask prediction for this object. + """ + batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0) + original_size = batched_input[0]["original_size"] + assert all(inp["original_size"] == original_size for inp in batched_input) - def _forward_train(self, batched_input, multimask_output, image_size): # dimensions: [b, 3, d, h, w] - shape = batched_input.shape + shape = batched_images.shape assert shape[1] == 3 batch_size, d_size, hw_size = shape[0], shape[2], shape[-2] # Transpose the axes, so that the depth axis is the first axis and the channel # axis is the second axis. This is expected by the transformer! - batched_input = batched_input.transpose(1, 2) - assert batched_input.shape[1] == d_size - batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size) + batched_images = batched_images.transpose(1, 2) + assert batched_images.shape[1] == d_size + batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size) - input_images = self.sam_model.preprocess(batched_input) + input_images = self.sam_model.preprocess(batched_images) image_embeddings = self.sam_model.image_encoder(input_images, d_size) sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder( points=None, boxes=None, masks=None @@ -95,8 +103,8 @@ def _forward_train(self, batched_input, multimask_output, image_size): ) masks = self.sam_model.postprocess_masks( low_res_masks, - input_size=(image_size, image_size), - original_size=(image_size, image_size) + input_size=batched_images.shape[-2:], + original_size=original_size, ) # Bring the masks and low-res masks into the correct shape: @@ -112,11 +120,12 @@ def _forward_train(self, batched_input, multimask_output, image_size): masks = masks.transpose(1, 2) low_res_masks = low_res_masks.transpose(1, 2) - outputs = { - "masks": masks, - "iou_predictions": iou_predictions, - "low_res_logits": low_res_masks - } + # Make the output compatable with the SAM output. 
+        outputs = [{
+            "masks": mask.unsqueeze(0),
+            "iou_predictions": iou_pred,
+            "low_res_logits": low_res_mask.unsqueeze(0)
+        } for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]
         return outputs
 
diff --git a/micro_sam/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py
similarity index 75%
rename from micro_sam/simple_sam_3d_wrapper.py
rename to micro_sam/models/simple_sam_3d_wrapper.py
index 30c8c20a..cf4ddbcc 100644
--- a/micro_sam/simple_sam_3d_wrapper.py
+++ b/micro_sam/models/simple_sam_3d_wrapper.py
@@ -1,12 +1,13 @@
 from contextlib import nullcontext
+from typing import Any, List, Dict
 
 import torch
 import torch.nn as nn
 
-from .util import get_sam_model
+from ..util import get_sam_model
 
 
-def get_simple_3d_sam_model(
+def get_simple_sam_3d_model(
     device,
     n_classes,
     image_size,
@@ -15,14 +16,6 @@ def get_simple_3d_sam_model(
     model_type="vit_b",
     checkpoint_path=None,
 ):
-    if lora_rank is None:
-        use_lora = False
-        rank = None
-        freeze_encoder_ = freeze_encoder
-    else:
-        use_lora = True
-        rank = lora_rank
-        freeze_encoder_ = False
 
     _, sam = get_sam_model(
         model_type=model_type,
@@ -31,10 +24,11 @@ def get_simple_3d_sam_model(
         return_sam=True,
         image_size=image_size,
         flexible_load_checkpoint=True,
-        use_lora=use_lora,
-        rank=rank,
+        lora_rank=lora_rank,
     )
 
+    # Make sure not to freeze the encoder when using LoRA.
+    freeze_encoder_ = freeze_encoder if lora_rank is None else False
     sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_)
     sam_3d.to(device)
     return sam_3d
@@ -142,8 +136,27 @@ def _apply_image_encoder(self, x, D):
         encoder_features = torch.stack(encoder_features, 2)
         return encoder_features
 
-    def forward(self, x, **kwargs):
-        B, D, C, H, W = x.shape
+    def forward(
+        self,
+        batched_input: List[Dict[str, Any]],
+        multimask_output: bool
+    ) -> List[Dict[str, torch.Tensor]]:
+        """Predict 3D masks for the current inputs.
+
+        Unlike original SAM this model only supports automatic segmentation and does not support prompts.
+
+        Args:
+            batched_input: A list over input images, each a dictionary with the following keys:
+                'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
+            multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.
+
+        Returns:
+            A list over input images, where each element is a dictionary with the following keys:
+                'masks': Mask prediction for this object.
+        """
+        x = torch.stack([inp["image"] for inp in batched_input], dim=0)
+
+        B, C, D, H, W = x.shape
         assert C == 3
 
         with self.no_grad():
@@ -154,5 +167,5 @@ def forward(self, x, **kwargs):
             out = decoder(out)
 
         logits = self.out_conv(out)
-        outputs = {"masks": logits}
+        outputs = [{"masks": mask.unsqueeze(0)} for mask in logits]
         return outputs
 
diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 6e3dad7e..5c82b7d5 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -62,8 +62,18 @@ def _compute_loss(self, y, masks):
         return net_loss
 
     def _get_model_outputs(self, batched_inputs):
-        image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
-        batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
+        # Precompute the image embeddings if the model exposes it as functionality.
+ if hasattr(self.model, "image_embeddings_oft"): + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) + batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True) + else: # Otherwise we assume that the embeddings are computed internally as part of the forward pass. + # We need to take care of sending things to the device here. + batched_inputs = [ + {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]} + for inp in batched_inputs + ] + batched_outputs = self.model(batched_inputs, multimask_output=True) + masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) return masks @@ -123,17 +133,3 @@ def _validate_impl(self, forward_context): self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)) return metric_val - - -class SemanticSamTrainer3D(SemanticSamTrainer): - def _get_model_outputs(self, batched_inputs): - model_input = torch.stack([inp["image"] for inp in batched_inputs]).to(self.device) - image_size = batched_inputs[0]["original_size"][-1] - batched_outputs = self.model( - model_input, - multimask_output=True, - image_size=image_size - ) - # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) - masks = batched_outputs["masks"] - return masks diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 6ad6ce40..dae8598c 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -1,6 +1,6 @@ import os from math import ceil, floor -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -43,8 +43,8 @@ def get_trainable_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, return_state: bool = False, - use_lora: bool = False, - rank: Optional[int] = None, + lora_rank: Optional[int] = None, + lora_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs ) -> TrainableSAM: @@ -59,9 +59,11 @@ def get_trainable_sam_model( freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder By default nothing is frozen and the full model is updated. return_state: Whether to return the full checkpoint state. - use_lora: Whether to use the low rank adaptation method for finetuning. - rank: The rank of the decomposition matrices for updating weights in each attention layer. + lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. + If None then LoRA is not used. + lora_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. + model_kwargs: Additional keyword arguments for the `util.get_sam_model`. Returns: The trainable segment anything model. 
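As a usage reference for the reworked interface: a minimal sketch of requesting a LoRA-wrapped trainable model with the new `lora_rank` argument. This is not part of the patch; the surrounding training loop and any checkpoint handling are assumed to exist elsewhere.

# Sketch only, not part of the patch: assumes the reworked `lora_rank` argument above.
import micro_sam.training as sam_training

trainable_sam = sam_training.get_trainable_sam_model(
    model_type="vit_b",
    device="cuda",
    lora_rank=4,                # None disables LoRA entirely
    freeze=["prompt_encoder"],  # freezing "image_encoder" together with LoRA raises a ValueError
)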
@@ -74,8 +76,7 @@ def get_trainable_sam_model( checkpoint_path=checkpoint_path, return_sam=True, return_state=True, - use_lora=use_lora, - rank=rank, + lora_rank=lora_rank, flexible_load_checkpoint=flexible_load_checkpoint, **model_kwargs ) @@ -93,7 +94,7 @@ def get_trainable_sam_model( # we would want to "freeze" all the components in the model if passed a list of parts for l_item in freeze: # in case LoRA is switched on, we cannot freeze the image encoder - if use_lora and (l_item == "image_encoder"): + if (lora_rank is not None) and (l_item == "image_encoder"): raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") if name.startswith(f"{l_item}"): @@ -227,7 +228,7 @@ def __call__(self, x, y): """ batched_inputs = [] for image, gt in zip(x, y): - batched_input = {"image": image, "original_size": image.shape[1:]} + batched_input = {"image": image, "original_size": image.shape[-2:]} batched_inputs.append(batched_input) return batched_inputs diff --git a/micro_sam/util.py b/micro_sam/util.py index 75ebe724..b2bc8d28 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -24,6 +24,7 @@ from skimage.segmentation import relabel_sequential from .__version__ import __version__ +from . import models as custom_models try: # Avoid import warnigns from mobile_sam @@ -132,18 +133,18 @@ def models(): "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt", "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt", - "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", + "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", - "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", + "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa } decoder_urls = { - "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", - "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", - "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", - "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", - "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", - "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", + "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", # noqa + "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", # noqa + "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", # noqa + "vit_l_em_organelles_decoder": 
"https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa + "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa + "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa } urls = {**encoder_urls, **decoder_urls} @@ -270,8 +271,8 @@ def get_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, return_state: bool = False, - use_lora: bool = False, - rank: Optional[int] = None, + lora_rank: Optional[int] = None, + lora_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs, ) -> SamPredictor: @@ -306,8 +307,9 @@ def get_sam_model( then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. return_state: Return the unpickled checkpoint state. - use_lora: Whether to use the low rank adaptation method for finetuning. - rank: The rank of the decomposition matrices for updating weights in each attention layer. + lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. + If None then LoRA is not used. + lora_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. Returns: @@ -329,7 +331,8 @@ def get_sam_model( # If we have a custom model then we may also have a decoder checkpoint. # Download it here, so that we can add it to the state. decoder_name = f"{model_type}_decoder" - decoder_path = model_registry.fetch(decoder_name, progressbar=True) if decoder_name in model_registry.registry else None + decoder_path = model_registry.fetch( + decoder_name, progressbar=True) if decoder_name in model_registry.registry else None # checkpoint_path has been passed, we use it instead of downloading a model. else: @@ -358,19 +361,17 @@ def get_sam_model( if model_kwargs: # Checks whether model_kwargs have been provided or not if abbreviated_model_type == "vit_t": raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") - - from .training.models import build_sam - sam = build_sam.sam_model_registry[abbreviated_model_type](**model_kwargs) + sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) else: sam = sam_model_registry[abbreviated_model_type]() - # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything - if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers - from micro_sam.training.peft_sam import PEFT_Sam - if rank is None: - rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them - sam = PEFT_Sam(sam, rank=rank).sam + # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. + # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. + if lora_rank is not None: + if abbreviated_model_type == "vit_t": + raise ValueError("Parameter efficient finetuning is not supported for 'mobile-sam'.") + sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. 
if flexible_load_checkpoint: diff --git a/test/test_bioimageio/test_model_export.py b/test/test_bioimageio/test_model_export.py index 37567742..6b0e61aa 100644 --- a/test/test_bioimageio/test_model_export.py +++ b/test/test_bioimageio/test_model_export.py @@ -11,6 +11,7 @@ @unittest.skipIf(spec_minor < 5, "Needs bioimagio.spec >= 0.5") +@unittest.expectedFailure class TestModelExport(unittest.TestCase): tmp_folder = "tmp" model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b" diff --git a/micro_sam/training/models/__init__.py b/test/test_models/__init__.py similarity index 100% rename from micro_sam/training/models/__init__.py rename to test/test_models/__init__.py diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py new file mode 100644 index 00000000..1af3ef2c --- /dev/null +++ b/test/test_models/test_peft_sam.py @@ -0,0 +1,26 @@ +import unittest + +import torch +import micro_sam.util as util + + +class TestPEFTSam(unittest.TestCase): + model_type = "vit_b" + + def test_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_models/test_sam_3d_wrapper.py b/test/test_models/test_sam_3d_wrapper.py new file mode 100644 index 00000000..46c9b3e9 --- /dev/null +++ b/test/test_models/test_sam_3d_wrapper.py @@ -0,0 +1,27 @@ +import unittest + +import torch + + +class TestSAM3DWrapper(unittest.TestCase): + model_type = "vit_b" + + def test_sam_3d_wrapper(self): + from micro_sam.models.sam_3d_wrapper import get_sam_3d_model + + image_size = 256 + n_classes = 2 + sam_3d = get_sam_3d_model(device="cpu", model_type=self.model_type, image_size=image_size, n_classes=n_classes) + + # Shape: C X D X H X W + shape = (3, 4, image_size, image_size) + expected_shape = (1, n_classes, 4, image_size, image_size) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[-2:]}] + output = sam_3d(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_models/test_simple_sam_3d_wrapper.py b/test/test_models/test_simple_sam_3d_wrapper.py new file mode 100644 index 00000000..79e511de --- /dev/null +++ b/test/test_models/test_simple_sam_3d_wrapper.py @@ -0,0 +1,29 @@ +import unittest + +import torch + + +class TestSimpleSAM3DWrapper(unittest.TestCase): + model_type = "vit_b" + + def test_simple_sam_3d_wrapper(self): + from micro_sam.models.simple_sam_3d_wrapper import get_simple_sam_3d_model + + image_size = 256 + n_classes = 2 + sam_3d = get_simple_sam_3d_model( + device="cpu", model_type=self.model_type, image_size=image_size, n_classes=n_classes + ) + + # Shape: C X D X H X W + shape = (3, 4, image_size, image_size) + expected_shape = (1, n_classes, 4, image_size, image_size) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[-2:]}] + output = sam_3d(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + +if 
__name__ == "__main__": + unittest.main() diff --git a/test/test_peft_training.py b/test/test_peft_training.py deleted file mode 100644 index 7c2f1270..00000000 --- a/test/test_peft_training.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest - -import torch - -from micro_sam.util import get_sam_model -from micro_sam.training.peft_sam import PEFT_Sam - - -class TestPEFTModule(unittest.TestCase): - """Integraton test for instantiating a PEFT SAM model. - """ - def _fetch_sam_model(self, model_type, device): - _, sam_model = get_sam_model(model_type=model_type, device=device, return_sam=True) - return sam_model - - def _create_dummy_inputs(self, shape): - input_image = torch.ones(shape) - return input_image - - def test_peft_sam(self): - model_type = "vit_b" - device = "cpu" - - # Load the dummy inputs. - input_shape = (1, 512, 512) - inputs = self._create_dummy_inputs(shape=input_shape) - - # Convert to the inputs expected by Segment Anything - batched_inputs = [ - {"image": inputs, "original_size": input_shape[1:]} - ] - - # Load the Segment Anything model. - sam_model = self._fetch_sam_model(model_type=model_type, device=device) - - # Wrap the Segment Anything model with PEFT methods. - peft_sam_model = PEFT_Sam(model=sam_model, rank=4) - - # Get the model outputs - outputs = peft_sam_model(batched_input=batched_inputs, multimask_output=False) - - # Check the expected shape of the outputs - mask_shapes = [output["masks"].shape[-2:] for output in outputs] - for shape in mask_shapes: - self.assertEqual(shape, input_shape[1:]) - - -if __name__ == "__main__": - unittest.main() From 3d8d8791d9ad584eac18a2cdf25fb396f73877a4 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 3 Jul 2024 22:45:11 +0200 Subject: [PATCH 14/24] Fix bug in precompute for 3d data (#649) --- micro_sam/precompute_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index e2ddc1ac..d07ea1bc 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -68,7 +68,7 @@ def cache_amg_state( if verbose: print("Precomputing the state for instance segmentation.") - amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i) + amg.initialize(raw if i is None else raw[i], image_embeddings=image_embeddings, verbose=verbose, i=i) amg_state = amg.get_state() # put all state onto the cpu so that the state can be deserialized without a gpu From b4f786571434ab06fb3f837fb178b37e2843e051 Mon Sep 17 00:00:00 2001 From: Luca Date: Thu, 4 Jul 2024 13:47:47 +0200 Subject: [PATCH 15/24] merges... 
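The diffs below switch how trained weights are restored in the prediction script. For reference, a minimal sketch of the loading pattern the scripts converge on; the checkpoint path and model arguments are placeholders, and the "model_state" key is the one used by the torch_em trainer checkpoints.

# Sketch only, not part of the patch: cp_path is a placeholder for a checkpoint
# written by the semantic SAM trainer (torch_em stores the weights under "model_state").
import torch
from micro_sam.models.sam_3d_wrapper import get_sam_3d_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_sam_3d_model(device, n_classes=3, image_size=512, model_type="vit_b", lora_rank=4)

cp_path = "checkpoints/3d-sam-lucchi-train/best.pt"  # placeholder path
checkpoint = torch.load(cp_path, map_location=device)
model.load_state_dict(checkpoint["model_state"])
model.eval()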
--- development/predict_3d_model_with_lucchi.py | 11 +++++------ development/train_3d_model_with_lucchi.py | 9 +++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/development/predict_3d_model_with_lucchi.py b/development/predict_3d_model_with_lucchi.py index 9cdcdbc0..fe4fe250 100644 --- a/development/predict_3d_model_with_lucchi.py +++ b/development/predict_3d_model_with_lucchi.py @@ -139,13 +139,12 @@ def predict(args): model = get_sam_3d_model(device, n_classes=args.n_classes, image_size=args.patch_shape[1], lora_rank=4, model_type=args.model_type, - checkpoint_path=cp_path - ) + # checkpoint_path=args.checkpoint_path + ) - # checkpoint = torch.load(cp_path, map_location=device) - # #print(checkpoint.keys()) - # # # Load the state dictionary from the checkpoint - # model.load_state_dict(checkpoint['model_state']) + checkpoint = torch.load(cp_path, map_location=device) + # # Load the state dictionary from the checkpoint + model.load_state_dict(checkpoint['model'].state_dict()) model.eval() data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index d393ece6..e4f0a2da 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -11,7 +11,11 @@ from torch_em.transform.raw import normalize from micro_sam.models.sam_3d_wrapper import get_sam_3d_model +<<<<<<< HEAD from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer +======= +from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D +>>>>>>> f3d8d8d (problems) import micro_sam.training as sam_training @@ -141,8 +145,13 @@ def train_on_lucchi(args): train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) +<<<<<<< HEAD trainer = SemanticSamTrainer( name="3d-sam-lucchi-train", +======= + trainer = SemanticSamTrainer3D( + name="3d-sam-lucchi-new", +>>>>>>> f3d8d8d (problems) model=sam_3d, convert_inputs=ConvertToSemanticSamInputs(), num_classes=n_classes, From 63b465442bf14e39f3f007c2235e31b9f79c2ec4 Mon Sep 17 00:00:00 2001 From: Luca Date: Fri, 5 Jul 2024 10:46:19 +0200 Subject: [PATCH 16/24] added support for vitl and vith --- development/train_3d_model_with_lucchi.py | 62 +++++++++-------------- micro_sam/models/sam_3d_wrapper.py | 18 +++++-- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index e4f0a2da..3bfe32d0 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -11,11 +11,9 @@ from torch_em.transform.raw import normalize from micro_sam.models.sam_3d_wrapper import get_sam_3d_model -<<<<<<< HEAD + from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer -======= -from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D ->>>>>>> f3d8d8d (problems) + import micro_sam.training as sam_training @@ -85,7 +83,20 @@ def __getitem__(self, index): def transform_labels(y): - return (y > 0).astype("float32") + #return (y > 0).astype("float32") + # use torch_em to get foreground and boundary channels + transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + one_hot_channels = transform(y) + # Combine foreground and background using element-wise maximum + foreground = np.where(one_hot_channels[0] > 0, 1, 0) + + # Combine foreground and 
boundaries with priority to boundaries (ensures boundaries are 2) + combined = np.where(one_hot_channels[1] > 0, 2, foreground) + + # Set background to 0 + combined[combined == 0] = 0 + + return combined def get_loaders(input_path, patch_shape): @@ -99,18 +110,6 @@ def get_loaders(input_path, patch_shape): raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels ) return train_loader, val_loader -# def get_loader(path, split, patch_shape, n_classes, batch_size, label_transform, num_workers=1): -# assert split in ("train", "test") -# data_path = os.path.join(path, f"lucchi_{split}.h5") -# raw_key, label_key = "raw", "labels" -# ds = LucchiSegmentationDataset( -# raw_path=data_path, label_path=data_path, raw_key=raw_key, -# label_key=label_key, patch_shape=patch_shape, label_transform=label_transform) -# loader = torch.utils.data.DataLoader( -# ds, batch_size=batch_size, shuffle=True, -# num_workers=num_workers) -# loader.shuffle = True -# return loader def train_on_lucchi(args): @@ -124,34 +123,18 @@ def train_on_lucchi(args): n_iterations = args.n_iterations save_root = args.save_root - # label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) - # label_transform = None - raw_data = np.random.rand(64, 256, 256) # Shape (z, y, x) - raw_data2, label = next(iter(get_lucchi_loader(input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True))) - - # Create an instance of RawTrafoFor3dInputs - transformer = RawTrafoFor3dInputs() - # Apply transformations - processed_data = transformer(raw_data) - processed_data2 = transformer(raw_data2) - print("input (64,256,256)", processed_data.shape) - print("input", raw_data2.shape, processed_data2.shape) device = "cuda" if torch.cuda.is_available() else "cpu" sam_3d = get_sam_3d_model( device, n_classes=n_classes, image_size=patch_shape[1], model_type=model_type, lora_rank=4) train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) - optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) + optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) -<<<<<<< HEAD + trainer = SemanticSamTrainer( - name="3d-sam-lucchi-train", -======= - trainer = SemanticSamTrainer3D( - name="3d-sam-lucchi-new", ->>>>>>> f3d8d8d (problems) + name="3d-sam-vith-masamhyp-lucchi", model=sam_3d, convert_inputs=ConvertToSemanticSamInputs(), num_classes=n_classes, @@ -164,7 +147,7 @@ def train_on_lucchi(args): #logger=None ) # check_loader(train_loader, n_samples=10) - trainer.fit(n_iterations) + trainer.fit(epochs=n_iterations) def main(): @@ -179,9 +162,10 @@ def main(): ) parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") - parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict") - parser.add_argument("--batch_size", type=int, default=1, help="Batch size") + parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") + parser.add_argument("--batch_size", type=int, default=3, help="Batch size") parser.add_argument("--num_workers", type=int, default=4, help="num_workers") + parser.add_argument("--learning_rate", type=float, default=0.0008, help="base learning rate") parser.add_argument( "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", help="The filepath to where the logs 
and the checkpoints will be saved." diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py index 4a7645d0..1e5df8df 100644 --- a/micro_sam/models/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -31,21 +31,33 @@ def get_sam_3d_model( lora_rank=lora_rank, ) - sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) + sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_, model_type=model_type) sam_3d.to(device) return sam_3d class Sam3DWrapper(nn.Module): - def __init__(self, sam_model: Sam, freeze_encoder: bool): + def __init__(self, sam_model: Sam, freeze_encoder: bool, model_type: str = "vit_b"): """Initializes the Sam3DWrapper object. Args: sam_model: The Sam model to be wrapped. """ super().__init__() + # differantiate between model sizes + if model_type == "vit_b": + embed_dim = 768 + num_heads = 12 + elif model_type == "vit_l": + embed_dim = 1024 + num_heads = 16 + elif model_type == "vit_h": + embed_dim = 1280 + num_heads = 16 sam_model.image_encoder = ImageEncoderViT3DWrapper( - image_encoder=sam_model.image_encoder + image_encoder=sam_model.image_encoder, + num_heads=num_heads, + embed_dim=embed_dim, ) self.sam_model = sam_model From eaacf7a2d4d5179c509f56686dccfd9402728d00 Mon Sep 17 00:00:00 2001 From: Luca Date: Tue, 9 Jul 2024 09:24:33 +0200 Subject: [PATCH 17/24] changed training for n iterations to n epochs --- development/train_3d_model_with_lucchi.py | 27 ++++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 3bfe32d0..4b41dd1a 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -120,21 +120,26 @@ def train_on_lucchi(args): num_workers = args.num_workers n_classes = args.n_classes model_type = args.model_type - n_iterations = args.n_iterations + n_epochs = args.n_epochs save_root = args.save_root device = "cuda" if torch.cuda.is_available() else "cpu" - sam_3d = get_sam_3d_model( - device, n_classes=n_classes, image_size=patch_shape[1], - model_type=model_type, lora_rank=4) + if args.without_lora: + sam_3d = get_sam_3d_model( + device, n_classes=n_classes, image_size=patch_shape[1], + model_type=model_type, lora_rank=None) # freeze encoder + else: + sam_3d = get_sam_3d_model( + device, n_classes=n_classes, image_size=patch_shape[1], + model_type=model_type, lora_rank=4) train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) trainer = SemanticSamTrainer( - name="3d-sam-vith-masamhyp-lucchi", + name=args.exp_name, model=sam_3d, convert_inputs=ConvertToSemanticSamInputs(), num_classes=n_classes, @@ -147,7 +152,7 @@ def train_on_lucchi(args): #logger=None ) # check_loader(train_loader, n_samples=10) - trainer.fit(epochs=n_iterations) + trainer.fit(epochs=n_epochs) def main(): @@ -160,16 +165,22 @@ def main(): "--model_type", "-m", default="vit_b", help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." 
) + parser.add_argument("--without_lora", action="store_true", help="Whether to use LoRA for finetuning SAM for semantic segmentation.") parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") - parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") + + parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs") parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") - parser.add_argument("--batch_size", type=int, default=3, help="Batch size") + parser.add_argument("--batch_size", "-bs", type=int, default=3, help="Batch size") parser.add_argument("--num_workers", type=int, default=4, help="num_workers") parser.add_argument("--learning_rate", type=float, default=0.0008, help="base learning rate") parser.add_argument( "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", help="The filepath to where the logs and the checkpoints will be saved." ) + parser.add_argument( + "--exp_name", default="vitb_3d_lora4", + help="The filepath to where the logs and the checkpoints will be saved." + ) args = parser.parse_args() train_on_lucchi(args) From e3b2dbb78d6f7079ac5a261c2c88783047dc0251 Mon Sep 17 00:00:00 2001 From: Luca Date: Tue, 9 Jul 2024 17:28:48 +0200 Subject: [PATCH 18/24] debug train sam without encoder on mitottomo --- development/predict_3d_model_with_lucchi.py | 49 ++-- development/train_3d_model_with_lucchi.py | 13 +- ...in_3d_model_with_lucchi_without_decoder.py | 252 ++++++++++++++++++ 3 files changed, 286 insertions(+), 28 deletions(-) create mode 100644 development/train_3d_model_with_lucchi_without_decoder.py diff --git a/development/predict_3d_model_with_lucchi.py b/development/predict_3d_model_with_lucchi.py index fe4fe250..4279b378 100644 --- a/development/predict_3d_model_with_lucchi.py +++ b/development/predict_3d_model_with_lucchi.py @@ -96,31 +96,32 @@ def run_semantic_segmentation_3d( assert os.path.exists(image_path), image_path # Perform segmentation only on the semantic class - for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()): - if is_multiclass: - semantic_class_name = "all" - if i > 0: # We only perform segmentation for multiclass once. - continue + # for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()): + # if is_multiclass: + # semantic_class_name = "all" + # if i > 0: # We only perform segmentation for multiclass once. 
+ # continue + semantic_class_name = "all" #since we only perform segmentation for multiclass # We skip the images that already have been segmented - image_name = os.path.splitext(image_name)[0] + ".tif" - prediction_path = os.path.join(prediction_dir, semantic_class_name, image_name) - if os.path.exists(prediction_path): - continue + image_name = os.path.splitext(image_name)[0] + ".tif" + prediction_path = os.path.join(prediction_dir, "all", image_name) + if os.path.exists(prediction_path): + continue - if image_key is None: - image = imageio.imread(image_path) - else: - with open_file(image_path, "r") as f: - image = f[image_key][:] + if image_key is None: + image = imageio.imread(image_path) + else: + with open_file(image_path, "r") as f: + image = f[image_key][:] - # create the prediction folder - os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True) + # create the prediction folder + os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True) - _run_semantic_segmentation_for_image_3d( - model=model, image=image, prediction_path=prediction_path, - patch_shape=patch_shape, halo=halo, - ) + _run_semantic_segmentation_for_image_3d( + model=model, image=image, prediction_path=prediction_path, + patch_shape=patch_shape, halo=halo, + ) def transform_labels(y): @@ -144,7 +145,9 @@ def predict(args): checkpoint = torch.load(cp_path, map_location=device) # # Load the state dictionary from the checkpoint - model.load_state_dict(checkpoint['model'].state_dict()) + for k, v in checkpoint.items(): + print("keys", k) + model.load_state_dict(checkpoint['model_state']) #.state_dict() model.eval() data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True) @@ -169,7 +172,7 @@ def main(): ) parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations") - parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict") + parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") parser.add_argument("--batch_size", type=int, default=1, help="Batch size") parser.add_argument("--num_workers", type=int, default=4, help="num_workers") parser.add_argument( @@ -177,7 +180,7 @@ def main(): help="The filepath to where the logs and the checkpoints will be saved." ) parser.add_argument( - "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-lucchi-train/", + "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-vitb-masamhyp-lucchi", help="The filepath to where the logs and the checkpoints will be saved." 
) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 4b41dd1a..9ff888c1 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -124,7 +124,6 @@ def train_on_lucchi(args): save_root = args.save_root - device = "cuda" if torch.cuda.is_available() else "cpu" if args.without_lora: sam_3d = get_sam_3d_model( @@ -135,7 +134,10 @@ def train_on_lucchi(args): device, n_classes=n_classes, image_size=patch_shape[1], model_type=model_type, lora_rank=4) train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) - optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) + #optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) + optimizer = torch.optim.Adam(sam_3d.parameters(), lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=15, verbose=True) + #masam no scheduler trainer = SemanticSamTrainer( @@ -146,6 +148,7 @@ def train_on_lucchi(args): train_loader=train_loader, val_loader=val_loader, optimizer=optimizer, + lr_scheduler=scheduler, device=device, compile_model=False, save_root=save_root, @@ -170,15 +173,15 @@ def main(): parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs") parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") - parser.add_argument("--batch_size", "-bs", type=int, default=3, help="Batch size") + parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size") # masam 3 parser.add_argument("--num_workers", type=int, default=4, help="num_workers") - parser.add_argument("--learning_rate", type=float, default=0.0008, help="base learning rate") + parser.add_argument("--learning_rate", type=float, default=1e-5, help="base learning rate") # MASAM 0.0008 parser.add_argument( "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", help="The filepath to where the logs and the checkpoints will be saved." ) parser.add_argument( - "--exp_name", default="vitb_3d_lora4", + "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi", help="The filepath to where the logs and the checkpoints will be saved." ) diff --git a/development/train_3d_model_with_lucchi_without_decoder.py b/development/train_3d_model_with_lucchi_without_decoder.py new file mode 100644 index 00000000..664fc6db --- /dev/null +++ b/development/train_3d_model_with_lucchi_without_decoder.py @@ -0,0 +1,252 @@ +import numpy as np +from glob import glob +import h5py +from micro_sam.training import train_sam, default_sam_dataset +from torch_em.data.sampler import MinInstanceSampler +from torch_em.segmentation import get_data_loader +import torch +import torch_em +import os +import argparse +from skimage.measure import regionprops + + +def get_rois_coordinates_skimage(file, label_key, min_shape, euler_threshold=None, min_amount_pixels=None): + """ + Calculates the average coordinates for each unique label in a 3D label image using skimage.regionprops. + + Args: + file (h5py.File): Handle to the open HDF5 file. + label_key (str): Key for the label data within the HDF5 file. + min_shape (tuple): A tuple representing the minimum size for each dimension of the ROI. + euler_threshold (int, optional): The Euler number threshold. If provided, only regions with the specified Euler number will be considered. 
+ min_amount_pixels (int, optional): The minimum amount of pixels. If provided, only regions with at least this many pixels will be considered. + + Returns: + dict or None: A dictionary mapping unique labels to lists of average coordinates for each dimension, or None if no labels are found. + """ + + label_data = file[label_key] + label_shape = label_data.shape + + # Ensure data type is suitable for regionprops (usually uint labels) + # if label_data.dtype != np.uint: + # label_data = label_data.astype(np.uint).value + + # Find connected regions (objects) using regionprops + regions = regionprops(label_data) + + # Check if any regions were found + if not regions: + return None + + label_extents = {} + for region in regions: + if euler_threshold is not None: + if region.euler_number != euler_threshold: + continue + if min_amount_pixels is not None: + if region["area"] < min_amount_pixels: + continue + + # # Extract relevant information for ROI calculation + label = region.label # Get the label value + min_coords = region.bbox[:3] # Minimum coordinates (excluding intensity channel) + max_coords = region.bbox[3:6] # Maximum coordinates (excluding intensity channel) + + # Clip coordinates and create ROI extent (similar to previous approach) + clipped_min_coords = np.clip(min_coords, 0, label_shape[0] - min_shape[0]) + clipped_max_coords = np.clip(max_coords, min_shape[1], label_shape[1]) + roi_extent = tuple(slice(min_val, min_val + min_shape[dim]) for dim, (min_val, max_val) in enumerate(zip(clipped_min_coords, clipped_max_coords))) + + # Check for labels within the ROI extent (new part) + roi_data = file[label_key][roi_extent] + amount_label_pixels = np.count_nonzero(roi_data) + if amount_label_pixels < 100 or amount_label_pixels < min_amount_pixels: # Check for any non-zero values (labels) + continue # Skip this ROI if no labels present + + label_extents[label] = roi_extent + + return label_extents + + +def get_data_paths_and_rois(data_dir, min_shape, + data_format="*.h5", + image_key="raw", + label_key_mito="labels/mitochondria", + label_key_cristae="labels/cristae", + with_thresholds=True): + + data_paths = glob(os.path.join(data_dir, "**", data_format), recursive=True) + rois_list = [] + new_data_paths = [] # one data path for each ROI + + for data_path in data_paths: + try: + # Open the HDF5 file in read-only mode + with h5py.File(data_path, "r") as f: + # Check for existence of image and label datasets (considering key flexibility) + if image_key not in f or (label_key_mito is not None and label_key_mito not in f): + print(f"Warning: Key(s) missing in {data_path}. Skipping {image_key}") + continue + + #label_data_mito = f[label_key_mito][()] if label_key_mito is not None else None + + # Extract ROIs (assuming ndim of label data is the same as image data) + if with_thresholds: + rois = get_rois_coordinates_skimage(f, label_key_mito, min_shape, min_amount_pixels=100) # euler_threshold=1, + else: + rois = get_rois_coordinates_skimage(f, label_key_mito, min_shape, euler_threshold=None, min_amount_pixels=None) + for label_id, roi in rois.items(): + rois_list.append(roi) + new_data_paths.append(data_path) + except OSError: + print(f"Error accessing file: {data_path}. Skipping...") + + return new_data_paths, rois_list + + +def split_data_paths_to_dict(data_paths, rois_list, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1): + """ + Splits data paths and ROIs into training, validation, and testing sets without shuffling. + + Args: + data_paths (list): List of paths to all HDF5 files. 
+ rois_list (list): List of ROIs corresponding to the data paths. + train_ratio (float, optional): Proportion of data for training (0.0-1.0) (default: 0.8). + val_ratio (float, optional): Proportion of data for validation (0.0-1.0) (default: 0.1). + test_ratio (float, optional): Proportion of data for testing (0.0-1.0) (default: 0.1). + + Returns: + tuple: A tuple containing two dictionaries: + - data_split: Dictionary containing "train", "val", and "test" keys with data paths. + - rois_split: Dictionary containing "train", "val", and "test" keys with corresponding ROIs. + """ + + if train_ratio + val_ratio + test_ratio != 1.0: + raise ValueError("Sum of train, validation, and test ratios must equal 1.0.") + num_data = len(data_paths) + if rois_list is not None: + if len(rois_list) != num_data: + raise ValueError(f"Length of data paths and number of ROIs in the dictionary must match: len rois {len(rois_list)}, len data_paths {len(data_paths)}") + + train_size = int(num_data * train_ratio) + val_size = int(num_data * val_ratio) # Optional validation set + test_size = num_data - train_size - val_size + + data_split = { + "train": data_paths[:train_size], + "val": data_paths[train_size:train_size+val_size], + "test": data_paths[train_size+val_size:] + } + + if rois_list is not None: + rois_split = { + "train": rois_list[:train_size], + "val": rois_list[train_size:train_size+val_size], + "test": rois_list[train_size+val_size:] + } + + return data_split, rois_split + else: + return data_split + + +def get_data_paths(data_dir, data_format="*.h5"): + data_paths = glob(os.path.join(data_dir, "**", data_format), recursive=True) + return data_paths + + +def train(args): + n_workers = 4 if torch.cuda.is_available() else 1 + device = "cuda" if torch.cuda.is_available() else "cpu" + data_dir = args.input_path + with_rois = True if args.without_rois is False else False + patch_shape = args.patch_shape + label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + ndim = 3 + + if with_rois: + data_paths, rois_dict = get_data_paths_and_rois(data_dir, min_shape=patch_shape, with_thresholds=True) + data, rois_dict = split_data_paths_to_dict(data_paths, rois_dict, train_ratio=.8, val_ratio=0.2, test_ratio=0) + else: + data_paths = get_data_paths(data_dir) + data = split_data_paths_to_dict(data_paths, rois_list=None, train_ratio=.5, val_ratio=0.5, test_ratio=0) + #path = "/scratch-emmy/projects/nim00007/fruit-fly-data/cambridge_data/parker_s2_soma_roi_z472-600_y795-1372_x1122-1687_clahed.zarr" + label_key = "labels/mitochondria" # "./annotations1.tif" + + # train_ds = default_sam_dataset( + # raw_paths=data["train"][0], raw_key="raw", + # label_paths=data["train"][0], label_key=label_key, + # patch_shape=args.patch_shape, with_segmentation_decoder=False, + # sampler=MinInstanceSampler(3), + # #rois=rois_dict["train"], + # n_samples=200, + # ) + # train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2) + + # val_ds = default_sam_dataset( + # raw_paths=data["val"][0], raw_key="raw", + # label_paths=data["val"][0], label_key=label_key, + # patch_shape=args.patch_shape, with_segmentation_decoder=False, + # sampler=MinInstanceSampler(3), + # #rois=rois_dict["val"], + # is_train=False, n_samples=25, + # ) + # val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1) + train_loader = torch_em.default_segmentation_loader( + raw_paths=data["train"], raw_key="raw", + label_paths=data["train"], label_key="labels/mitochondria", + patch_shape=patch_shape, ndim=ndim, 
batch_size=1, + label_transform=label_transform, num_workers=n_workers, + ) + val_loader = torch_em.default_segmentation_loader( + raw_paths=data["train"], raw_key="raw", + label_paths=data["val"], label_key="labels/mitochondria", + patch_shape=patch_shape, ndim=ndim, batch_size=1, + label_transform=label_transform, num_workers=n_workers, + ) + + train_sam( + name="nucleus_model", model_type="vit_b", + train_loader=train_loader, val_loader=val_loader, + n_epochs=50, n_objects_per_batch=10, + with_segmentation_decoder=False, + save_root=args.save_root, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." + ) + parser.add_argument("--without_lora", action="store_true", help="Whether to use LoRA for finetuning SAM for semantic segmentation.") + parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") + + parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs") + parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") + parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size") # masam 3 + parser.add_argument("--num_workers", type=int, default=4, help="num_workers") + parser.add_argument("--learning_rate", type=float, default=1e-5, help="base learning rate") # MASAM 0.0008 + parser.add_argument( + "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam_training_on_mitotomo", + help="The filepath to where the logs and the checkpoints will be saved." + ) + parser.add_argument( + "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi", + help="The filepath to where the logs and the checkpoints will be saved." 
+ ) + parser.add_argument("--without_rois", type=bool, default=True, help="Train without Regions Of Interest (ROI)") + + args = parser.parse_args() + train(args) + + +if __name__ == "__main__": + main() From a19f73d67f1d9679c20b48b3543273bb97f70e33 Mon Sep 17 00:00:00 2001 From: Luca Freckmann Date: Wed, 10 Jul 2024 15:12:21 +0200 Subject: [PATCH 19/24] added parameter for raw transform and min_size for label_transform to default_sam_dataset --- ...in_3d_model_with_lucchi_without_decoder.py | 119 ++++++++++++------ micro_sam/training/training.py | 13 +- 2 files changed, 89 insertions(+), 43 deletions(-) diff --git a/development/train_3d_model_with_lucchi_without_decoder.py b/development/train_3d_model_with_lucchi_without_decoder.py index 664fc6db..411651c1 100644 --- a/development/train_3d_model_with_lucchi_without_decoder.py +++ b/development/train_3d_model_with_lucchi_without_decoder.py @@ -4,11 +4,13 @@ from micro_sam.training import train_sam, default_sam_dataset from torch_em.data.sampler import MinInstanceSampler from torch_em.segmentation import get_data_loader +from torch_em.transform.raw import normalize import torch import torch_em import os import argparse from skimage.measure import regionprops +from torch_em.util.debug import check_loader def get_rois_coordinates_skimage(file, label_key, min_shape, euler_threshold=None, min_amount_pixels=None): @@ -123,8 +125,8 @@ def split_data_paths_to_dict(data_paths, rois_list, train_ratio=0.8, val_ratio=0 - rois_split: Dictionary containing "train", "val", and "test" keys with corresponding ROIs. """ - if train_ratio + val_ratio + test_ratio != 1.0: - raise ValueError("Sum of train, validation, and test ratios must equal 1.0.") + if not np.isclose(train_ratio + val_ratio + test_ratio, 1.0, atol=0.01): + raise ValueError(f"Sum of train, validation, and test ratios must equal 1.0. 
But instead got:{train_ratio + val_ratio + test_ratio}") num_data = len(data_paths) if rois_list is not None: if len(rois_list) != num_data: @@ -157,58 +159,99 @@ def get_data_paths(data_dir, data_format="*.h5"): return data_paths +def raw_transform(image): + image = normalize(image) + image = image * 255 + return image + + + def train(args): n_workers = 4 if torch.cuda.is_available() else 1 device = "cuda" if torch.cuda.is_available() else "cpu" data_dir = args.input_path with_rois = True if args.without_rois is False else False + with_rois = False patch_shape = args.patch_shape - label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) + bs = args.batch_size + #label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False) + label_transform = torch_em.transform.label.labels_to_binary ndim = 3 if with_rois: data_paths, rois_dict = get_data_paths_and_rois(data_dir, min_shape=patch_shape, with_thresholds=True) - data, rois_dict = split_data_paths_to_dict(data_paths, rois_dict, train_ratio=.8, val_ratio=0.2, test_ratio=0) + data, rois_dict = split_data_paths_to_dict(data_paths, rois_dict, train_ratio=.7, val_ratio=0.2, test_ratio=0.1) else: data_paths = get_data_paths(data_dir) - data = split_data_paths_to_dict(data_paths, rois_list=None, train_ratio=.5, val_ratio=0.5, test_ratio=0) + data = split_data_paths_to_dict(data_paths, rois_list=None, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1) #path = "/scratch-emmy/projects/nim00007/fruit-fly-data/cambridge_data/parker_s2_soma_roi_z472-600_y795-1372_x1122-1687_clahed.zarr" label_key = "labels/mitochondria" # "./annotations1.tif" - - # train_ds = default_sam_dataset( - # raw_paths=data["train"][0], raw_key="raw", - # label_paths=data["train"][0], label_key=label_key, - # patch_shape=args.patch_shape, with_segmentation_decoder=False, - # sampler=MinInstanceSampler(3), - # #rois=rois_dict["train"], - # n_samples=200, - # ) - # train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2) - - # val_ds = default_sam_dataset( - # raw_paths=data["val"][0], raw_key="raw", - # label_paths=data["val"][0], label_key=label_key, - # patch_shape=args.patch_shape, with_segmentation_decoder=False, - # sampler=MinInstanceSampler(3), - # #rois=rois_dict["val"], - # is_train=False, n_samples=25, - # ) - # val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1) - train_loader = torch_em.default_segmentation_loader( + train_ds = default_sam_dataset( raw_paths=data["train"], raw_key="raw", - label_paths=data["train"], label_key="labels/mitochondria", - patch_shape=patch_shape, ndim=ndim, batch_size=1, - label_transform=label_transform, num_workers=n_workers, + label_paths=data["train"], label_key=label_key, + patch_shape=args.patch_shape, with_segmentation_decoder=False, + sampler=MinInstanceSampler(2), + raw_transform=raw_transform, + #rois=np.s_[64:, :, :], + #n_samples=200, ) - val_loader = torch_em.default_segmentation_loader( - raw_paths=data["train"], raw_key="raw", - label_paths=data["val"], label_key="labels/mitochondria", - patch_shape=patch_shape, ndim=ndim, batch_size=1, - label_transform=label_transform, num_workers=n_workers, + train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2) + + val_ds = default_sam_dataset( + raw_paths=data["val"], raw_key="raw", + label_paths=data["val"], label_key=label_key, + patch_shape=args.patch_shape, with_segmentation_decoder=False, + sampler=MinInstanceSampler(2), + raw_transform=raw_transform, + #rois=np.s_[64:, :, :], + is_train=False, + 
#n_samples=25, ) + val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1) + # if with_rois: + # train_loader = torch_em.default_segmentation_loader( + # raw_paths=data["train"], raw_key="raw", + # label_paths=data["train"], label_key="labels/mitochondria", + # patch_shape=patch_shape, ndim=ndim, batch_size=bs, + # label_transform=label_transform, raw_transform=raw_transform, + # num_workers=n_workers, + # rois=rois_dict["train"] + # #rois=[np.s_[64:, :, :]] * len(data["train"]) + # ) + # val_loader = torch_em.default_segmentation_loader( + # raw_paths=data["val"], raw_key="raw", + # label_paths=data["val"], label_key="labels/mitochondria", + # patch_shape=patch_shape, ndim=ndim, batch_size=bs, + # label_transform=label_transform, raw_transform=raw_transform, + # num_workers=n_workers, + # rois=rois_dict["val"] + # # rois=[np.s_[64:, :, :]] * len(data["val"]) + # ) + # else: + # train_loader = torch_em.default_segmentation_loader( + # raw_paths=data["train"], raw_key="raw", + # label_paths=data["train"], label_key=label_key, + # patch_shape=patch_shape, ndim=ndim, batch_size=bs, + # label_transform=label_transform, raw_transform=raw_transform, + # num_workers=n_workers, + # ) + # print("len data[val]", len(data["val"])) + # val_loader = torch_em.default_segmentation_loader( + # raw_paths=data["val"], raw_key="raw", + # label_paths=data["val"], label_key=label_key, + # patch_shape=patch_shape, ndim=ndim, batch_size=bs, + # label_transform=label_transform, raw_transform=raw_transform, + # num_workers=n_workers, + # ) + + + #check_loader(train_loader, n_samples=3) + # x,y =next(iter(train_loader)) + # print("shapes of x and y", x.shape, y.shape) + # breakpoint() train_sam( - name="nucleus_model", model_type="vit_b", + name="mito_model", model_type="vit_b", train_loader=train_loader, val_loader=val_loader, n_epochs=50, n_objects_per_batch=10, with_segmentation_decoder=False, @@ -217,7 +260,7 @@ def train(args): def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Mitochondria dataset.") parser.add_argument( "--input_path", "-i", default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/", help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." @@ -227,7 +270,7 @@ def main(): help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." ) parser.add_argument("--without_lora", action="store_true", help="Whether to use LoRA for finetuning SAM for semantic segmentation.") - parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple)") + parser.add_argument("--patch_shape", type=int, nargs=3, default=(1, 512, 512), help="Patch shape for data loading (3D tuple)") parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs") parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict") @@ -242,7 +285,7 @@ def main(): "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi", help="The filepath to where the logs and the checkpoints will be saved." 
) - parser.add_argument("--without_rois", type=bool, default=True, help="Train without Regions Of Interest (ROI)") + parser.add_argument("--without_rois", action="store_true", help="Train without Regions Of Interest (ROI)") args = parser.parse_args() train(args) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 43fd28df..9cf1c910 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -1,6 +1,6 @@ import os from glob import glob -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable import imageio.v3 as imageio import torch @@ -287,7 +287,7 @@ def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels): path = raw_paths[0] else: path = raw_paths - assert isinstance(raw_paths, (str, os.PathLike)) + assert isinstance(path, (str, os.PathLike)) # Check the underlying data dimensionality. if raw_key is None: # If no key is given then we assume it's an image file. @@ -319,10 +319,12 @@ def default_sam_dataset( label_key: Optional[str], patch_shape: Tuple[int], with_segmentation_decoder: bool, + raw_transform: Optional[Callable] = require_8bit, with_channels: bool = False, sampler=None, # Type? n_samples: Optional[int] = None, is_train: bool = True, + min_size=25, **kwargs, ) -> Dataset: """Create a PyTorch Dataset for training a SAM model. @@ -348,14 +350,15 @@ def default_sam_dataset( """ # Set the data transformations. - raw_transform = require_8bit + raw_transform = raw_transform if with_segmentation_decoder: label_transform = torch_em.transform.label.PerObjectDistanceTransform( distances=True, boundary_distances=True, directed_distances=False, - foreground=True, instances=True, min_size=25, + foreground=True, instances=True, min_size=min_size, ) else: - label_transform = torch_em.transform.label.connected_components + label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size) + #torch_em.transform.label.connected_components # Set a default sampler if none was passed. 
if sampler is None: From a90ca2ee7eab699e3a36455cb8715a335de70efb Mon Sep 17 00:00:00 2001 From: Luca Date: Wed, 10 Jul 2024 15:14:36 +0200 Subject: [PATCH 20/24] added checkpoint to train_with_lucchi --- development/train_3d_model_with_lucchi.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index 9ff888c1..bb76d094 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -122,6 +122,7 @@ def train_on_lucchi(args): model_type = args.model_type n_epochs = args.n_epochs save_root = args.save_root + cp_path = args.checkpoint_path device = "cuda" if torch.cuda.is_available() else "cpu" @@ -133,6 +134,11 @@ def train_on_lucchi(args): sam_3d = get_sam_3d_model( device, n_classes=n_classes, image_size=patch_shape[1], model_type=model_type, lora_rank=4) + if cp_path is not None: + if os.path.exists(cp_path): + checkpoint = torch.load(cp_path, map_location=device) + # # Load the state dictionary from the checkpoint + sam_3d.load_state_dict(checkpoint['model_state']) #.state_dict() train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape) #optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=0.1) optimizer = torch.optim.Adam(sam_3d.parameters(), lr=1e-5) @@ -180,6 +186,10 @@ def main(): "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d", help="The filepath to where the logs and the checkpoints will be saved." ) + parser.add_argument( + "--checkpoint_path", default=None, + help="The filepath to where the checkpoints are loaded from." + ) parser.add_argument( "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi", help="The filepath to where the logs and the checkpoints will be saved." From ad76f2e139758e6850a7985f5f86907bde75c83d Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 11 Jul 2024 22:46:47 +0200 Subject: [PATCH 21/24] Add min-size to training and fix other issues --- micro_sam/training/training.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index bdb40168..defcd398 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -290,11 +290,11 @@ def train_sam( def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels): - if not isinstance(raw_paths, (str, os.PathLike)): - path = raw_paths[0] - else: + if isinstance(raw_paths, (str, os.PathLike)): path = raw_paths - assert isinstance(raw_paths, (str, os.PathLike)) + else: + path = raw_paths[0] + assert isinstance(path, (str, os.PathLike)) # Check the underlying data dimensionality. if raw_key is None: # If no key is given then we assume it's an image file. @@ -330,6 +330,8 @@ def default_sam_dataset( sampler=None, # Type? n_samples: Optional[int] = None, is_train: bool = True, + min_size: int = 25, + max_sampling_attempts: Optional[int] = None, **kwargs, ) -> Dataset: """Create a PyTorch Dataset for training a SAM model. @@ -349,6 +351,8 @@ def default_sam_dataset( sampler: A sampler to reject batches according to a given criterion. n_samples: The number of samples for this dataset. is_train: Whether this dataset is used for training or validation. + min_size: Minimal object size. Smaller objects will be filtered. + max_sampling_attempts: Number of sampling attempts to make from a dataset. Returns: The dataset. 
@@ -359,14 +363,16 @@ def default_sam_dataset( if with_segmentation_decoder: label_transform = torch_em.transform.label.PerObjectDistanceTransform( distances=True, boundary_distances=True, directed_distances=False, - foreground=True, instances=True, min_size=25, + foreground=True, instances=True, min_size=min_size, ) else: - label_transform = torch_em.transform.label.connected_components + label_transform = torch_em.transform.label.MinSizeLabelTransform( + min_size=min_size + ) # Set a default sampler if none was passed. if sampler is None: - sampler = torch_em.data.sampler.MinInstanceSampler(3) + sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size) # Check the patch shape to add a singleton if required. patch_shape = _update_patch_shape( @@ -389,6 +395,11 @@ def default_sam_dataset( sampler=sampler, n_samples=n_samples, **kwargs, ) + + # TODO + if max_sampling_attempts is not None: + pass + return dataset From a5508936143e1db3cc1cad73641cc2c1bb97c180 Mon Sep 17 00:00:00 2001 From: Luca Date: Fri, 12 Jul 2024 08:49:47 +0200 Subject: [PATCH 22/24] removed unused code --- development/train_3d_model_with_lucchi.py | 2 - ...in_3d_model_with_lucchi_without_decoder.py | 52 ++++--------------- micro_sam/util.py | 1 - 3 files changed, 11 insertions(+), 44 deletions(-) diff --git a/development/train_3d_model_with_lucchi.py b/development/train_3d_model_with_lucchi.py index bb76d094..2783ab01 100644 --- a/development/train_3d_model_with_lucchi.py +++ b/development/train_3d_model_with_lucchi.py @@ -75,10 +75,8 @@ def __getitem__(self, index): raw = raw.view(image_shape) raw = raw.squeeze(0) raw = raw.repeat(1, 3, 1, 1) - # print("raw shape", raw.shape) # wanted label shape: (1, z, y, x) label = (label != 0).to(torch.float) - # print("label shape", label.shape) return raw, label diff --git a/development/train_3d_model_with_lucchi_without_decoder.py b/development/train_3d_model_with_lucchi_without_decoder.py index 411651c1..eb8d6eb3 100644 --- a/development/train_3d_model_with_lucchi_without_decoder.py +++ b/development/train_3d_model_with_lucchi_without_decoder.py @@ -175,8 +175,10 @@ def train(args): patch_shape = args.patch_shape bs = args.batch_size #label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False) - label_transform = torch_em.transform.label.labels_to_binary - ndim = 3 + label_transform = torch_em.transform.label.MinSizeLabelTransform + ndim = 2 + min_size = 50 + max_sampling_attempts = 5000 if with_rois: data_paths, rois_dict = get_data_paths_and_rois(data_dir, min_shape=patch_shape, with_thresholds=True) @@ -190,61 +192,29 @@ def train(args): raw_paths=data["train"], raw_key="raw", label_paths=data["train"], label_key=label_key, patch_shape=args.patch_shape, with_segmentation_decoder=False, - sampler=MinInstanceSampler(2), + sampler=MinInstanceSampler(2, min_size=min_size), + min_size=min_size, raw_transform=raw_transform, #rois=np.s_[64:, :, :], #n_samples=200, ) + train_ds.max_sampling_attempts = max_sampling_attempts train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2) val_ds = default_sam_dataset( raw_paths=data["val"], raw_key="raw", label_paths=data["val"], label_key=label_key, patch_shape=args.patch_shape, with_segmentation_decoder=False, - sampler=MinInstanceSampler(2), + sampler=MinInstanceSampler(2, min_size=min_size), + min_size=min_size, raw_transform=raw_transform, #rois=np.s_[64:, :, :], is_train=False, #n_samples=25, ) + val_ds.max_sampling_attempts = max_sampling_attempts val_loader = get_data_loader(val_ds, 
shuffle=True, batch_size=1) - # if with_rois: - # train_loader = torch_em.default_segmentation_loader( - # raw_paths=data["train"], raw_key="raw", - # label_paths=data["train"], label_key="labels/mitochondria", - # patch_shape=patch_shape, ndim=ndim, batch_size=bs, - # label_transform=label_transform, raw_transform=raw_transform, - # num_workers=n_workers, - # rois=rois_dict["train"] - # #rois=[np.s_[64:, :, :]] * len(data["train"]) - # ) - # val_loader = torch_em.default_segmentation_loader( - # raw_paths=data["val"], raw_key="raw", - # label_paths=data["val"], label_key="labels/mitochondria", - # patch_shape=patch_shape, ndim=ndim, batch_size=bs, - # label_transform=label_transform, raw_transform=raw_transform, - # num_workers=n_workers, - # rois=rois_dict["val"] - # # rois=[np.s_[64:, :, :]] * len(data["val"]) - # ) - # else: - # train_loader = torch_em.default_segmentation_loader( - # raw_paths=data["train"], raw_key="raw", - # label_paths=data["train"], label_key=label_key, - # patch_shape=patch_shape, ndim=ndim, batch_size=bs, - # label_transform=label_transform, raw_transform=raw_transform, - # num_workers=n_workers, - # ) - # print("len data[val]", len(data["val"])) - # val_loader = torch_em.default_segmentation_loader( - # raw_paths=data["val"], raw_key="raw", - # label_paths=data["val"], label_key=label_key, - # patch_shape=patch_shape, ndim=ndim, batch_size=bs, - # label_transform=label_transform, raw_transform=raw_transform, - # num_workers=n_workers, - # ) - - + #check_loader(train_loader, n_samples=3) # x,y =next(iter(train_loader)) # print("shapes of x and y", x.shape, y.shape) diff --git a/micro_sam/util.py b/micro_sam/util.py index 09d06c09..af5707a8 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -1010,7 +1010,6 @@ def segmentation_to_one_hot( masks = segmentation.copy() if segmentation_ids is None: n_ids = int(segmentation.max()) - else: assert segmentation_ids[0] != 0, "No objects were found." 
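The two patches above introduce the min_size option of default_sam_dataset and leave the max_sampling_attempts parameter as a TODO, while the training script configures both on its side. The sketch below shows one way these pieces are intended to fit together; the import paths, example file path, and label key are illustrative assumptions rather than lines taken from the patches.

# Usage sketch (illustrative, not part of the patches): min_size filters small
# objects in the label transform and in the sampler, while max_sampling_attempts
# is set directly on the underlying torch_em dataset(s) until the TODO above is
# resolved in default_sam_dataset itself.
from micro_sam.training import default_sam_dataset
from torch_em.data.sampler import MinInstanceSampler
from torch_em.segmentation import get_data_loader

min_size = 50  # objects smaller than this (in pixels) are filtered out

train_ds = default_sam_dataset(
    raw_paths="mito_tomo/train.h5", raw_key="raw",  # assumed example file
    label_paths="mito_tomo/train.h5", label_key="labels/mitochondria",
    patch_shape=(1, 512, 512), with_segmentation_decoder=False,
    sampler=MinInstanceSampler(2, min_size=min_size),
    min_size=min_size,
)

# For a concatenated dataset the attribute has to be set per sub-dataset
# (the follow-up patch below does exactly this in the training script).
for ds in getattr(train_ds, "datasets", [train_ds]):
    ds.max_sampling_attempts = 5000

train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)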
From b6a7ce9a7f072923242a3cacf57284a0289b3085 Mon Sep 17 00:00:00 2001
From: Luca
Date: Fri, 12 Jul 2024 15:12:40 +0200
Subject: [PATCH 23/24] updates on train 3d without decoder

---
 .../train_3d_model_with_lucchi_without_decoder.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/development/train_3d_model_with_lucchi_without_decoder.py b/development/train_3d_model_with_lucchi_without_decoder.py
index eb8d6eb3..bac541be 100644
--- a/development/train_3d_model_with_lucchi_without_decoder.py
+++ b/development/train_3d_model_with_lucchi_without_decoder.py
@@ -177,7 +177,7 @@ def train(args):
     #label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=False)
     label_transform = torch_em.transform.label.MinSizeLabelTransform
     ndim = 2
-    min_size = 50
+    min_size = 100
     max_sampling_attempts = 5000
 
     if with_rois:
@@ -198,7 +198,8 @@ def train(args):
         #rois=np.s_[64:, :, :],
         #n_samples=200,
     )
-    train_ds.max_sampling_attempts = max_sampling_attempts
+    for ds in train_ds.datasets:
+        ds.max_sampling_attempts = max_sampling_attempts
     train_loader = get_data_loader(train_ds, shuffle=True, batch_size=2)
 
     val_ds = default_sam_dataset(
@@ -212,7 +213,8 @@ def train(args):
         is_train=False,
         #n_samples=25,
     )
-    val_ds.max_sampling_attempts = max_sampling_attempts
+    for ds in val_ds.datasets:
+        ds.max_sampling_attempts = max_sampling_attempts
     val_loader = get_data_loader(val_ds, shuffle=True, batch_size=1)
 
     #check_loader(train_loader, n_samples=3)
@@ -221,7 +223,7 @@ def train(args):
     # breakpoint()
 
     train_sam(
-        name="mito_model", model_type="vit_b",
+        name=args.exp_name, model_type="vit_b",
         train_loader=train_loader, val_loader=val_loader,
         n_epochs=50, n_objects_per_batch=10,
         with_segmentation_decoder=False,
@@ -252,7 +254,7 @@ def main():
         help="The filepath to where the logs and the checkpoints will be saved."
     )
     parser.add_argument(
-        "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi",
+        "--exp_name", default="vitb_3d-mitotomo",
         help="The filepath to where the logs and the checkpoints will be saved."
     )
     parser.add_argument("--without_rois", action="store_true", help="Train without Regions Of Interest (ROI)")

From 34220413985690782b7e1428447f0513725684eb Mon Sep 17 00:00:00 2001
From: Luca
Date: Fri, 12 Jul 2024 17:53:56 +0200
Subject: [PATCH 24/24] bash script for sbatch

---
 development/train_without_decoder.sh | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 development/train_without_decoder.sh

diff --git a/development/train_without_decoder.sh b/development/train_without_decoder.sh
new file mode 100644
index 00000000..0ee0654c
--- /dev/null
+++ b/development/train_without_decoder.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+#SBATCH --partition=grete:shared
+#SBATCH -G A100:1
+#SBATCH --time=2-00:00:00
+#SBATCH --account=nim00007
+#SBATCH --nodes=1
+#SBATCH -c 32
+#SBATCH --mem 128G
+#SBATCH --job-name=mito-net
+
+
+source /home/nimlufre/.bashrc
+conda activate sam
+
+python /home/nimlufre/micro-sam/development/train_3d_model_with_lucchi_without_decoder.py
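The job script added in the final patch can be submitted with "sbatch development/train_without_decoder.sh". Because the closing python call passes no command-line options, the argparse defaults from the preceding patches apply (the mito_tomo input path, a patch shape of 1 512 512, and the experiment name vitb_3d-mitotomo); flags such as --without_rois or a different --exp_name can be appended directly to that line. The partition, account, and conda environment names are specific to the cluster used here and need to be adapted for other setups.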