From 6ffe8a729b873bdf470249c86103809b6f3f75b2 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 12:30:34 +0200
Subject: [PATCH 01/11] Add semantic sam 3d logger

---
 micro_sam/training/semantic_sam_trainer.py | 40 +++++++++++++++++++++-
 1 file changed, 39 insertions(+), 1 deletion(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index b3f1cc0a..d24dd960 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -1,10 +1,13 @@
 import time
 
+import numpy as np
+
 import torch
 import torch.nn as nn
 
 from torch_em.loss import DiceLoss
 from torch_em.trainer import DefaultTrainer
+from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im
 
 
 class SemanticSamTrainer(DefaultTrainer):
@@ -56,11 +59,12 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             backprop(net_loss)
 
+            self._iteration += 1
+
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
                 self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True)
 
-            self._iteration += 1
             if self._iteration >= self.max_iteration:
                 break
             progress.update(1)
@@ -107,3 +111,37 @@ def _get_model_outputs(self, batched_inputs):
         # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
         masks = batched_outputs["masks"]
         return masks
+
+
+class SemanticSamLogger3D(TensorboardLogger):
+    def log_images(self, step, x, y, prediction, name, gradients=None):
+
+        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
+
+        image = normalize_im(x[selection].cpu())
+        self.tb.add_image(tag=f"{name}/input",
+                          img_tensor=image,
+                          global_step=step)
+
+        im, im_name = self.make_image(image, y, prediction, selection, gradients)
+        im_name = f"{name}/{im_name}"
+        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)
+
+    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
+        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
+
+        # the embedding visualisation function currently doesn't support gradients,
+        # so we can't log them even if log_gradients is true
+        log_grads = log_gradients
+        if self.have_embeddings:
+            log_grads = False
+
+        if (step + 1) % self.log_image_interval == 0:
+            gradients = prediction.grad if log_grads else None
+            self.log_images(step, x, y, prediction, "train", gradients=gradients)
+
+    def log_validation(self, step, metric, loss, x, y, prediction):
+        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
+        # self.log_images(step, x, y, prediction, "validation")
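
A note on the slicing above: tensorboard can only display 2d images, so for 5d volumetric tensors log_images fixes the batch index and the middle index along axis 2 before handing the data on to make_image. A minimal sketch of the slicing mechanics, on a hypothetical 5d tensor:

    import numpy as np
    import torch

    # hypothetical 5d tensor; axis 2 is the one the logger halves
    t = torch.rand(1, 3, 8, 256, 256)

    # fix batch 0 and the middle index of axis 2, as in log_images
    selection = np.s_[0, :, t.shape[2] // 2]
    print(t[selection].shape)  # torch.Size([3, 256, 256])
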
From 88970bd8cbe587e9541acecb72c7bbaaca8fd076 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 14:33:17 +0200
Subject: [PATCH 02/11] Make get_3d_sam_model more flexible

---
 development/check_3d_model.py |  3 +--
 micro_sam/sam_3d_wrapper.py   | 30 +++++++++++++++++++++++-------
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/development/check_3d_model.py b/development/check_3d_model.py
index ac49609c..82ca2186 100644
--- a/development/check_3d_model.py
+++ b/development/check_3d_model.py
@@ -1,6 +1,5 @@
 import numpy as np
 import torch
-import micro_sam.util as util
 
 from micro_sam.sam_3d_wrapper import get_3d_sam_model
 from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D
@@ -9,7 +8,7 @@ def predict_3d_model():
     d_size = 8
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    sam_3d = get_3d_sam_model(device, d_size)
+    _, sam_3d = get_3d_sam_model(device, d_size)
 
     input_ = 255 * np.random.rand(1, d_size, 3, 1024, 1024).astype("float32")
     with torch.no_grad():
diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py
index ccb9968e..cb2dbf01 100644
--- a/micro_sam/sam_3d_wrapper.py
+++ b/micro_sam/sam_3d_wrapper.py
@@ -6,16 +6,32 @@
 from segment_anything.modeling.image_encoder import window_partition, window_unpartition
 from segment_anything.modeling import Sam
 
-from .util import get_sam_model
-
-
-def get_3d_sam_model(device, n_classes, image_size, model_type="vit_b"):
-    predictor, sam = get_sam_model(
-        return_sam=True, model_type=model_type, device=device, num_multimask_outputs=n_classes,
-        flexible_load_checkpoint=True, image_size=image_size,
+from .util import get_sam_model, _load_checkpoint, _handle_checkpoint_loading
+
+
+def get_3d_sam_model(
+    device,
+    n_classes,
+    image_size,
+    model_type="vit_b",
+    checkpoint_path=None,
+):
+    _, sam = get_sam_model(
+        model_type=model_type,
+        device=device,
+        return_sam=True,
+        flexible_load_checkpoint=True,
+        num_multimask_outputs=n_classes,
+        image_size=image_size,
     )
+
     sam_3d = Sam3DWrapper(sam)
     sam_3d.to(device)
+
+    if checkpoint_path is not None:
+        _, model_state = _load_checkpoint(checkpoint_path)
+        sam_3d = _handle_checkpoint_loading(sam_3d, model_state)
+
     return sam_3d
 
 

From 246a1bbb4d1d5f7d160004fcb4668d516b5404af Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 14:34:54 +0200
Subject: [PATCH 03/11] Revert changes on dev script

---
 development/check_3d_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/development/check_3d_model.py b/development/check_3d_model.py
index 82ca2186..ac49609c 100644
--- a/development/check_3d_model.py
+++ b/development/check_3d_model.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+import micro_sam.util as util
 
 from micro_sam.sam_3d_wrapper import get_3d_sam_model
 from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D
@@ -8,7 +9,7 @@ def predict_3d_model():
     d_size = 8
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    _, sam_3d = get_3d_sam_model(device, d_size)
+    sam_3d = get_3d_sam_model(device, d_size)
 
     input_ = 255 * np.random.rand(1, d_size, 3, 1024, 1024).astype("float32")
     with torch.no_grad():
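
As of patch 02, get_3d_sam_model accepts an optional checkpoint, which at this stage is loaded on top of the wrapped model via _load_checkpoint and _handle_checkpoint_loading (patch 11 later moves this responsibility into get_sam_model). A minimal usage sketch; the class count and checkpoint path below are made-up examples:

    from micro_sam.sam_3d_wrapper import get_3d_sam_model

    sam_3d = get_3d_sam_model(
        device="cuda",
        n_classes=3,      # example class count
        image_size=1024,
        model_type="vit_b",
        checkpoint_path="checkpoints/semantic_sam_3d.pt",  # hypothetical path; None keeps the vanilla SAM weights
    )
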
From fe6a49859d241e09bddd187819d7868a9e30a06c Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 27 Jun 2024 15:05:52 +0200
Subject: [PATCH 04/11] Fix logger

---
 micro_sam/training/semantic_sam_trainer.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index d24dd960..800182e5 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -29,11 +29,13 @@ def __init__(
         self._kwargs = kwargs
 
     def _compute_loss(self, y, masks):
+        target = y.to(self.device, non_blocking=True)
         # Compute dice loss for the predictions
-        dice_loss = self.loss(masks, y.to(self.device, non_blocking=True))
+        dice_loss = self.loss(masks, target)
+        breakpoint()
 
         # Compute cross entropy loss for the predictions
-        ce_loss = self.compute_ce_loss(masks, y.to(self.device, non_blocking=True))
+        ce_loss = self.compute_ce_loss(masks, target)
 
         net_loss = dice_loss + ce_loss
         return net_loss
@@ -63,7 +65,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True)
+                self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False)
 
             if self._iteration >= self.max_iteration:
                 break
@@ -116,9 +118,10 @@ def _get_model_outputs(self, batched_inputs):
 class SemanticSamLogger3D(TensorboardLogger):
     def log_images(self, step, x, y, prediction, name, gradients=None):
 
+        selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :]
         selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
 
-        image = normalize_im(x[selection].cpu())
+        image = normalize_im(x[selection_image].cpu())
         self.tb.add_image(tag=f"{name}/input",
                           img_tensor=image,
                           global_step=step)
@@ -137,11 +140,11 @@ def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
         if self.have_embeddings:
             log_grads = False
 
-        if (step + 1) % self.log_image_interval == 0:
+        if step % self.log_image_interval == 0:
             gradients = prediction.grad if log_grads else None
             self.log_images(step, x, y, prediction, "train", gradients=gradients)
 
     def log_validation(self, step, metric, loss, x, y, prediction):
         self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
         self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
-        # self.log_images(step, x, y, prediction, "validation")
+        self.log_images(step, x, y, prediction, "validation")

From cadcc19ba7b4159d036b4e6774cab0d3bce11446 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 15:52:34 +0200
Subject: [PATCH 05/11] Add updates to semanticsamtrainer

---
 micro_sam/training/semantic_sam_trainer.py | 36 ++++++++++++++++++----
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 800182e5..9d036dfb 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -10,32 +10,56 @@
 from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im
 
 
+class CustomDiceLoss(nn.Module):
+    def __init__(self, num_classes: int, softmax: bool = True) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.dice_loss = DiceLoss()
+        self.softmax = softmax
+
+    def _one_hot_encoder(self, input_tensor):
+        tensor_list = []
+        for i in range(self.num_classes):
+            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
+            tensor_list.append(temp_prob)
+        output_tensor = torch.cat(tensor_list, dim=1)
+        return output_tensor.float()
+
+    def __call__(self, pred, target):
+        if self.softmax:
+            pred = torch.softmax(pred, dim=1)
+        target = self._one_hot_encoder(target)
+        loss = self.dice_loss(pred, target)
+        return loss
+
+
 class SemanticSamTrainer(DefaultTrainer):
     """
     """
     def __init__(
         self,
         convert_inputs,
-        num_classes: int = 1,
+        num_classes: int,
         **kwargs
     ):
-        loss = DiceLoss()
-        metric = DiceLoss()
+        assert num_classes > 1
+
+        loss = CustomDiceLoss(num_classes=num_classes)
+        metric = CustomDiceLoss(num_classes=num_classes)
         super().__init__(loss=loss, metric=metric, **kwargs)
 
         self.convert_inputs = convert_inputs
         self.num_classes = num_classes
-        self.compute_ce_loss = nn.BCELoss() if num_classes == 1 else nn.CrossEntropyLoss()
+        self.compute_ce_loss = nn.CrossEntropyLoss()
         self._kwargs = kwargs
 
     def _compute_loss(self, y, masks):
         target = y.to(self.device, non_blocking=True)
         # Compute dice loss for the predictions
         dice_loss = self.loss(masks, target)
-        breakpoint()
 
         # Compute cross entropy loss for the predictions
-        ce_loss = self.compute_ce_loss(masks, target)
+        ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())
 
         net_loss = dice_loss + ce_loss
         return net_loss
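
The one-hot encoding in CustomDiceLoss is what lets the dice term consume the same integer label map as the cross entropy term. A short sketch of the shapes involved, with all sizes invented for illustration:

    import torch

    num_classes = 3
    # integer class ids with a singleton channel axis: (batch, 1, depth, height, width)
    target = torch.randint(num_classes, (2, 1, 8, 64, 64))

    # what _one_hot_encoder produces: one channel per class
    one_hot = torch.cat([(target == i) for i in range(num_classes)], dim=1).float()
    print(one_hot.shape)  # torch.Size([2, 3, 8, 64, 64])

    # nn.CrossEntropyLoss instead wants integer labels without the channel axis,
    # hence the target.squeeze(1).long() in _compute_loss
    labels = target.squeeze(1).long()
    print(labels.shape)  # torch.Size([2, 8, 64, 64])
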
From d73186659d17a8c2065e2829a0427806c4bdb251 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Thu, 27 Jun 2024 17:04:42 +0200
Subject: [PATCH 06/11] Minor changes (#644)

---
 micro_sam/training/semantic_sam_trainer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 9d036dfb..9ed212aa 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -66,7 +66,7 @@ def _compute_loss(self, y, masks):
 
     def _get_model_outputs(self, batched_inputs):
         image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
-        batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=(self.num_classes > 1))
+        batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
         masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
         return masks
 
@@ -116,8 +116,9 @@ def _validate_impl(self, forward_context):
         loss_val /= len(self.val_loader)
         metric_val /= len(self.val_loader)
+        dice_metric = 1 - (metric_val / self.num_classes)
 
         print()
-        print(f"The Average Validation Metric Score for the Current Epoch is {1 - metric_val}")
+        print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")
 
         if self.logger is not None:
             self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks)
@@ -150,6 +151,7 @@ def log_images(self, step, x, y, prediction, name, gradients=None):
                           img_tensor=image,
                           global_step=step)
 
+        prediction = torch.softmax(prediction, dim=1)
         im, im_name = self.make_image(image, y, prediction, selection, gradients)
         im_name = f"{name}/{im_name}"
         self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)

From 56b09e60b14e2a4df79447b9f21e15fb4e62421b Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 17:23:26 +0200
Subject: [PATCH 07/11] Minor update to semanticsamlogger

---
 micro_sam/training/semantic_sam_trainer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 9ed212aa..2b55bbad 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -89,7 +89,8 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False)
+                predictions = torch.softmax(masks, dim=1)
+                self.logger.log_train(self._iteration, net_loss, lr, x, y, predictions, log_gradients=False)
 
             if self._iteration >= self.max_iteration:
                 break
@@ -121,7 +122,8 @@ def _validate_impl(self, forward_context):
         print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")
 
         if self.logger is not None:
-            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks)
+            predictions = torch.softmax(masks, dim=1)
+            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, predictions)
 
         return metric_val
 
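
The metric rescaling in patch 06 implies that the raw validation value scales with the number of classes, since CustomDiceLoss evaluates one dice term per class channel; dividing by num_classes before subtracting from one turns it back into an average dice score. A worked example with made-up numbers:

    # assuming a per-class dice loss summing to 0.6 on the validation set
    num_classes = 3
    metric_val = 0.6
    dice_metric = 1 - (metric_val / num_classes)
    print(dice_metric)  # 0.8
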
From d37736b85b31dc20c6ebadea9cc6553382793ebe Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 17:43:09 +0200
Subject: [PATCH 08/11] Update logger class name for semanticsamlogger

---
 micro_sam/training/semantic_sam_trainer.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 2b55bbad..9c8d6a45 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -89,8 +89,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop):
 
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                predictions = torch.softmax(masks, dim=1)
-                self.logger.log_train(self._iteration, net_loss, lr, x, y, predictions, log_gradients=False)
+                self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False)
 
             if self._iteration >= self.max_iteration:
                 break
@@ -122,8 +121,7 @@ def _validate_impl(self, forward_context):
         print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")
 
         if self.logger is not None:
-            predictions = torch.softmax(masks, dim=1)
-            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, predictions)
+            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks)
 
         return metric_val
 
@@ -142,7 +140,7 @@ def _get_model_outputs(self, batched_inputs):
         return masks
 
 
-class SemanticSamLogger3D(TensorboardLogger):
+class SemanticSamLogger(TensorboardLogger):
     def log_images(self, step, x, y, prediction, name, gradients=None):
 
         selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :]

From f991b39ad8bfda0d27768a022bfff0000d4a7d0f Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 27 Jun 2024 23:38:56 +0200
Subject: [PATCH 09/11] Make SemanticSamLogger the default logger

---
 micro_sam/training/semantic_sam_trainer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 9c8d6a45..3387484d 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -46,7 +46,8 @@ def __init__(
 
         loss = CustomDiceLoss(num_classes=num_classes)
         metric = CustomDiceLoss(num_classes=num_classes)
-        super().__init__(loss=loss, metric=metric, **kwargs)
+        logger = SemanticSamLogger()
+        super().__init__(loss=loss, metric=metric, logger=logger, **kwargs)
 
         self.convert_inputs = convert_inputs
         self.num_classes = num_classes

From 3bab6edb27bb3466d3152abd2324405075ac6d28 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Fri, 28 Jun 2024 00:17:10 +0200
Subject: [PATCH 10/11] Minor fix to logger initialization

---
 micro_sam/training/semantic_sam_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py
index 3387484d..93228752 100644
--- a/micro_sam/training/semantic_sam_trainer.py
+++ b/micro_sam/training/semantic_sam_trainer.py
@@ -46,7 +46,7 @@ def __init__(
 
         loss = CustomDiceLoss(num_classes=num_classes)
         metric = CustomDiceLoss(num_classes=num_classes)
-        logger = SemanticSamLogger()
+        logger = SemanticSamLogger
         super().__init__(loss=loss, metric=metric, logger=logger, **kwargs)
 
         self.convert_inputs = convert_inputs
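
Patches 09 and 10 together document a detail of torch_em's DefaultTrainer: the logger argument takes the logger class rather than an instance, because the trainer constructs the logger itself during setup. That is why SemanticSamLogger() had to become SemanticSamLogger. A hedged construction sketch (the run name and loaders are hypothetical, and further required DefaultTrainer kwargs such as the optimizer are omitted):

    from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer

    trainer = SemanticSamTrainer(
        name="semantic-sam-3d",         # hypothetical run name
        model=sam_3d,                   # assumed to exist, e.g. from get_3d_sam_model
        convert_inputs=convert_inputs,  # assumed to exist
        num_classes=3,
        train_loader=train_loader,      # assumed to exist
        val_loader=val_loader,          # assumed to exist
    )
    # no logger kwarg needed: SemanticSamLogger is now wired up by default
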
From 6521a943a42306c35a4533ec595d7c004598eba0 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Fri, 28 Jun 2024 00:40:27 +0200
Subject: [PATCH 11/11] Remove extra checkpoint loading

---
 micro_sam/sam_3d_wrapper.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py
index cb2dbf01..1676652b 100644
--- a/micro_sam/sam_3d_wrapper.py
+++ b/micro_sam/sam_3d_wrapper.py
@@ -6,7 +6,7 @@
 from segment_anything.modeling.image_encoder import window_partition, window_unpartition
 from segment_anything.modeling import Sam
 
-from .util import get_sam_model, _load_checkpoint, _handle_checkpoint_loading
+from .util import get_sam_model
 
 
 def get_3d_sam_model(
@@ -19,6 +19,7 @@ def get_3d_sam_model(
     _, sam = get_sam_model(
         model_type=model_type,
         device=device,
+        checkpoint_path=checkpoint_path,
         return_sam=True,
         flexible_load_checkpoint=True,
         num_multimask_outputs=n_classes,
@@ -27,11 +28,6 @@ def get_3d_sam_model(
 
     sam_3d = Sam3DWrapper(sam)
     sam_3d.to(device)
-
-    if checkpoint_path is not None:
-        _, model_state = _load_checkpoint(checkpoint_path)
-        sam_3d = _handle_checkpoint_loading(sam_3d, model_state)
-
     return sam_3d
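
With the extra loading pass removed, weight handling lives entirely in get_sam_model: checkpoint_path is simply forwarded, and, judging by its name, flexible_load_checkpoint=True is what tolerates a state dict that does not match the vanilla SAM layout. The end state of the series as a usage sketch:

    import torch

    from micro_sam.sam_3d_wrapper import get_3d_sam_model

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # checkpoint_path=None falls back to the default vit_b weights;
    # a real path (hypothetical here) would load finetuned weights instead
    sam_3d = get_3d_sam_model(device, n_classes=3, image_size=1024, checkpoint_path=None)
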