
Add SemanticSam3dLogger #643

Merged (11 commits, Jun 28, 2024)
micro_sam/sam_3d_wrapper.py (23 additions, 7 deletions)
@@ -6,16 +6,32 @@
 from segment_anything.modeling.image_encoder import window_partition, window_unpartition
 from segment_anything.modeling import Sam

-from .util import get_sam_model
-
-
-def get_3d_sam_model(device, n_classes, image_size, model_type="vit_b"):
-    predictor, sam = get_sam_model(
-        return_sam=True, model_type=model_type, device=device, num_multimask_outputs=n_classes,
-        flexible_load_checkpoint=True, image_size=image_size,
+from .util import get_sam_model, _load_checkpoint, _handle_checkpoint_loading
+
+
+def get_3d_sam_model(
+    device,
+    n_classes,
+    image_size,
+    model_type="vit_b",
+    checkpoint_path=None,
+):
+    _, sam = get_sam_model(
Contributor: Can't we just pass the checkpoint here? Then we don't need to duplicate the code below.

Contributor (author): Ah, that won't work, as we need to pass the weights to the wrapper model (to initialize the adapter blocks).

Contributor: Hmm, I don't fully understand. Let's discuss tomorrow :)

Contributor (author): Okay, now I see where your suspicion comes from.

TLDR: I think we might have to merge get_3d_sam_model into get_sam_model for the best possible design (which would allow flexibly loading SAM checkpoints for finetuning, and 3d-SAM checkpoints for downstream semantic inference).

I have a plan for this and will take care of it first thing in the morning. Thanks for spotting it.

Contributor: I think the current design should work, at least for initialization. We can revisit this later to discuss how we handle actually loading the trained 3d models.

Contributor (author): Okay, I'll leave it as is for now then. Thanks!

+        model_type=model_type,
+        device=device,
+        return_sam=True,
+        flexible_load_checkpoint=True,
+        num_multimask_outputs=n_classes,
+        image_size=image_size,
     )

     sam_3d = Sam3DWrapper(sam)
     sam_3d.to(device)

+    if checkpoint_path is not None:
+        _, model_state = _load_checkpoint(checkpoint_path)
+        sam_3d = _handle_checkpoint_loading(sam_3d, model_state)
+
     return sam_3d
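
For orientation, a minimal usage sketch of the updated signature; the class count, image size, and checkpoint path below are hypothetical:

```python
import torch

from micro_sam.sam_3d_wrapper import get_3d_sam_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Fresh initialization from SAM weights (the adapter blocks start untrained).
sam_3d = get_3d_sam_model(device, n_classes=3, image_size=512)

# Resuming from a trained 3d checkpoint (path is hypothetical).
sam_3d = get_3d_sam_model(
    device, n_classes=3, image_size=512,
    checkpoint_path="checkpoints/sam3d/best.pt",
)
```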


micro_sam/training/semantic_sam_trainer.py (80 additions, 11 deletions)
@@ -1,10 +1,36 @@
 import time

+import numpy as np
+
 import torch
 import torch.nn as nn

 from torch_em.loss import DiceLoss
 from torch_em.trainer import DefaultTrainer
+from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im
+
+
+class CustomDiceLoss(nn.Module):
+    def __init__(self, num_classes: int, softmax: bool = True) -> None:
+        super().__init__()
+        self.num_classes = num_classes
+        self.dice_loss = DiceLoss()
+        self.softmax = softmax
+
+    def _one_hot_encoder(self, input_tensor):
+        # Expand integer labels (B, 1, ...) into one-hot channels (B, C, ...).
+        tensor_list = []
+        for i in range(self.num_classes):
+            temp_prob = input_tensor == i  # * torch.ones_like(input_tensor)
+            tensor_list.append(temp_prob)
+        output_tensor = torch.cat(tensor_list, dim=1)
+        return output_tensor.float()
+
+    def __call__(self, pred, target):
+        if self.softmax:
+            pred = torch.softmax(pred, dim=1)
+        target = self._one_hot_encoder(target)
+        loss = self.dice_loss(pred, target)
+        return loss
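
As a quick illustration of the shapes CustomDiceLoss works with (all values hypothetical):

```python
import torch

loss_fn = CustomDiceLoss(num_classes=3)

pred = torch.randn(2, 3, 16, 16)              # logits: (batch, classes, H, W)
target = torch.randint(0, 3, (2, 1, 16, 16))  # integer labels with a channel dim

# Internally: softmax over dim 1, then the target is one-hot encoded
# to (2, 3, 16, 16) before the dice loss is computed.
print(loss_fn(pred, target))
```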


@@ -13,31 +39,34 @@ class SemanticSamTrainer(DefaultTrainer):
     def __init__(
         self,
         convert_inputs,
-        num_classes: int = 1,
+        num_classes: int,
         **kwargs
     ):
-        loss = DiceLoss()
-        metric = DiceLoss()
+        assert num_classes > 1
+
+        loss = CustomDiceLoss(num_classes=num_classes)
+        metric = CustomDiceLoss(num_classes=num_classes)
         super().__init__(loss=loss, metric=metric, **kwargs)

         self.convert_inputs = convert_inputs
         self.num_classes = num_classes
-        self.compute_ce_loss = nn.BCELoss() if num_classes == 1 else nn.CrossEntropyLoss()
+        self.compute_ce_loss = nn.CrossEntropyLoss()
         self._kwargs = kwargs

     def _compute_loss(self, y, masks):
+        target = y.to(self.device, non_blocking=True)
         # Compute dice loss for the predictions
-        dice_loss = self.loss(masks, y.to(self.device, non_blocking=True))
+        dice_loss = self.loss(masks, target)

         # Compute cross entropy loss for the predictions
-        ce_loss = self.compute_ce_loss(masks, y.to(self.device, non_blocking=True))
+        ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long())

         net_loss = dice_loss + ce_loss
         return net_loss

     def _get_model_outputs(self, batched_inputs):
         image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs)
-        batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=(self.num_classes > 1))
+        batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True)
         masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
         return masks
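
The squeeze(1).long() on the cross-entropy branch reflects PyTorch's expected target layout; a minimal sketch (shapes illustrative):

```python
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()

masks = torch.randn(2, 3, 16, 16)                # logits: (batch, classes, H, W)
y = torch.randint(0, 3, (2, 1, 16, 16)).float()  # labels stored as (batch, 1, H, W)

# nn.CrossEntropyLoss expects integer class indices of shape (batch, H, W),
# so the channel dim is squeezed away and the labels are cast to long.
print(ce(masks, y.squeeze(1).long()))
```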

@@ -56,11 +85,13 @@ def _train_epoch_impl(self, progress, forward_context, backprop):

             backprop(net_loss)

+            self._iteration += 1
+
             if self.logger is not None:
                 lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
-                self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True)
+                predictions = torch.softmax(masks, dim=1)
+                self.logger.log_train(self._iteration, net_loss, lr, x, y, predictions, log_gradients=False)

-            self._iteration += 1
             if self._iteration >= self.max_iteration:
                 break
             progress.update(1)
@@ -86,11 +117,13 @@ def _validate_impl(self, forward_context):

         loss_val /= len(self.val_loader)
         metric_val /= len(self.val_loader)
+        dice_metric = 1 - (metric_val / self.num_classes)
         print()
-        print(f"The Average Validation Metric Score for the Current Epoch is {1 - metric_val}")
+        print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")

         if self.logger is not None:
-            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks)
+            predictions = torch.softmax(masks, dim=1)
+            self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, predictions)

         return metric_val
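
A note on the rescaling, assuming CustomDiceLoss accumulates the dice loss across all classes: dividing by num_classes gives the mean per-class dice loss, and one minus that is the reported dice score. Illustrative arithmetic:

```python
num_classes = 3
metric_val = 0.6  # hypothetical dice loss accumulated over the 3 classes

dice_metric = 1 - (metric_val / num_classes)
print(dice_metric)  # 0.8: mean per-class dice score
```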

@@ -107,3 +140,39 @@ def _get_model_outputs(self, batched_inputs):
         # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
         masks = batched_outputs["masks"]
         return masks


+class SemanticSamLogger3D(TensorboardLogger):
+    def log_images(self, step, x, y, prediction, name, gradients=None):
+        # For volumetric data, log only the central slice; the prediction is
+        # sliced one axis later than the image since it carries a class channel.
+        selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :]
+        selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2]
+
+        image = normalize_im(x[selection_image].cpu())
+        self.tb.add_image(tag=f"{name}/input", img_tensor=image, global_step=step)
+
+        prediction = torch.softmax(prediction, dim=1)
+        im, im_name = self.make_image(image, y, prediction, selection, gradients)
+        im_name = f"{name}/{im_name}"
+        self.tb.add_image(tag=im_name, img_tensor=im, global_step=step)
+
+    def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False):
+        self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
+
+        # The embedding visualisation function currently doesn't support gradients,
+        # so we can't log them even if log_gradients is true.
+        log_grads = log_gradients
+        if self.have_embeddings:
+            log_grads = False
+
+        if step % self.log_image_interval == 0:
+            gradients = prediction.grad if log_grads else None
+            self.log_images(step, x, y, prediction, "train", gradients=gradients)
+
+    def log_validation(self, step, metric, loss, x, y, prediction):
+        self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
+        self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
+        self.log_images(step, x, y, prediction, "validation")
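
Finally, a hedged sketch of how the pieces might be wired together. The data loaders, convert_inputs callable, and run name are hypothetical placeholders, and the remaining kwargs follow torch_em's DefaultTrainer conventions, which may differ in detail:

```python
import torch

from micro_sam.sam_3d_wrapper import get_3d_sam_model
from micro_sam.training.semantic_sam_trainer import (
    SemanticSamTrainer, SemanticSamLogger3D,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_3d_sam_model(device, n_classes=3, image_size=512)

trainer = SemanticSamTrainer(
    name="semantic-sam-3d",             # hypothetical run name
    model=model,
    convert_inputs=lambda x, y: (x, y),  # hypothetical input-conversion callable
    num_classes=3,
    train_loader=train_loader,           # hypothetical torch_em data loaders
    val_loader=val_loader,
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-5),
    device=device,
    logger=SemanticSamLogger3D,          # the logger class, instantiated by the trainer
    log_image_interval=100,
)
trainer.fit(int(1e4))
```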