Api cleanup #648

Merged
merged 4 commits on Jul 2, 2024
Changes from 1 commit
2 changes: 2 additions & 0 deletions micro_sam/models/__init__.py
@@ -0,0 +1,2 @@
from .build_sam import sam_model_registry
from .peft_sam import PEFT_Sam
File renamed without changes.
10 changes: 4 additions & 6 deletions micro_sam/training/peft_sam.py → micro_sam/models/peft_sam.py
@@ -53,9 +53,9 @@ def forward(self, x):


class PEFT_Sam(nn.Module):
"""Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/
"""Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
Inspired by https://github.com/JamesQFreeman/Sam_LoRA/

Args:
model: The Segment Anything model.
@@ -71,16 +71,14 @@ def __init__(
peft_module: nn.Module = LoRASurgery,
attention_layers_to_update: Union[List[int]] = None
):
super(PEFT_Sam, self).__init__()
super().__init__()

assert rank > 0

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
else: # Applies PEFT to the image encoder by default
self.peft_layers = list(
range(len(model.image_encoder.blocks))
)
self.peft_layers = list(range(len(model.image_encoder.blocks)))

self.peft_module = peft_module
self.peft_blocks = []
81 changes: 45 additions & 36 deletions micro_sam/sam_3d_wrapper.py → micro_sam/models/sam_3d_wrapper.py
@@ -1,15 +1,15 @@
from typing import Type
from typing import Any, List, Dict, Type

import torch
import torch.nn as nn

from segment_anything.modeling.image_encoder import window_partition, window_unpartition
from segment_anything.modeling import Sam

from .util import get_sam_model
from ..util import get_sam_model


def get_3d_sam_model(
def get_sam_3d_model(
Contributor Author:
I changed these names to match the order of sam_3d we use everywhere else, cc @lufre1
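
For reference, a minimal call with the renamed function could look like this (the argument values below are only illustrative, not taken from this PR):

```python
from micro_sam.models.sam_3d_wrapper import get_sam_3d_model

# Hypothetical example values.
sam_3d = get_sam_3d_model(
    device="cuda",
    n_classes=3,
    image_size=512,
    model_type="vit_b",
)
```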

device,
n_classes,
image_size,
@@ -18,15 +18,8 @@ def get_3d_sam_model(
model_type="vit_b",
checkpoint_path=None,
):
if lora_rank is None:
use_lora = False
rank = None
freeze_encoder_ = freeze_encoder
else:
use_lora = True
rank = lora_rank
freeze_encoder_ = False

# Make sure not to freeze the encoder when using LoRA.
freeze_encoder_ = freeze_encoder if lora_rank is None else False
_, sam = get_sam_model(
model_type=model_type,
device=device,
@@ -35,8 +28,7 @@
flexible_load_checkpoint=True,
num_multimask_outputs=n_classes,
image_size=image_size,
use_lora=use_lora,
rank=rank,
lora_rank=lora_rank,
)

sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_)
@@ -46,11 +38,10 @@

class Sam3DWrapper(nn.Module):
def __init__(self, sam_model: Sam, freeze_encoder: bool):
"""
Initializes the Sam3DWrapper object.
"""Initializes the Sam3DWrapper object.

Args:
sam_model (Sam): The Sam model to be wrapped.
sam_model: The Sam model to be wrapped.
"""
super().__init__()
sam_model.image_encoder = ImageEncoderViT3DWrapper(
@@ -63,25 +54,42 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool):
for param in self.sam_model.image_encoder.parameters():
param.requires_grad = False

# FIXME
# - handling of the image size here is wrong, this only works for square images
# - this does not take care of resizing
# unclear how batches are handled
def forward(self, batched_input, multimask_output, image_size) -> torch.Tensor:
return self._forward_train(batched_input, multimask_output, image_size)
def forward(
Contributor Author:
This now matches the input and output signature of the original SAM model.
cc @lufre1 @anwai98
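
To illustrate the new interface, a minimal usage sketch (shapes and sizes are made up, and `sam_3d` is assumed to come from `get_sam_3d_model`):

```python
import torch

# Hypothetical input: one volume with 8 slices of size 512x512,
# already transformed to the model's expected input format.
batched_input = [{
    "image": torch.rand(3, 8, 512, 512),  # 3 x D x H x W
    "original_size": (512, 512),          # H x W before transformation
}]

outputs = sam_3d(batched_input, multimask_output=True)
masks = outputs[0]["masks"]  # one output dictionary per input image, as in original SAM
```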

self,
batched_input: List[Dict[str, Any]],
multimask_output: bool
) -> List[Dict[str, torch.Tensor]]:
"""Predict 3D masks for the current inputs.

Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

Args:
batched_input: A list over input images, each a dictionary with the following keys:
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
'original_size': The original size of the image (HxW) before transformation.
multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

Returns:
A list over input images, where each element is a dictionary with the following keys:
'masks': Mask prediction for this object.
'iou_predictions': IOU score prediction for this object.
'low_res_masks': Low resolution mask prediction for this object.
"""
batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0)
original_size = batched_input[0]["original_size"]
assert all(inp["original_size"] == original_size for inp in batched_input)

def _forward_train(self, batched_input, multimask_output, image_size):
# dimensions: [b, 3, d, h, w]
shape = batched_input.shape
shape = batched_images.shape
assert shape[1] == 3
batch_size, d_size, hw_size = shape[0], shape[2], shape[-2]
# Transpose the axes, so that the depth axis is the first axis and the channel
# axis is the second axis. This is expected by the transformer!
batched_input = batched_input.transpose(1, 2)
assert batched_input.shape[1] == d_size
batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size)
batched_images = batched_images.transpose(1, 2)
assert batched_images.shape[1] == d_size
batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size)

input_images = self.sam_model.preprocess(batched_input)
input_images = self.sam_model.preprocess(batched_images)
image_embeddings = self.sam_model.image_encoder(input_images, d_size)
sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
points=None, boxes=None, masks=None
@@ -95,8 +103,8 @@ def _forward_train(self, batched_input, multimask_output, image_size):
)
masks = self.sam_model.postprocess_masks(
low_res_masks,
input_size=(image_size, image_size),
original_size=(image_size, image_size)
input_size=batched_images.shape[-2:],
original_size=original_size,
)

# Bring the masks and low-res masks into the correct shape:
@@ -112,11 +120,12 @@ def _forward_train(self, batched_input, multimask_output, image_size):
masks = masks.transpose(1, 2)
low_res_masks = low_res_masks.transpose(1, 2)

outputs = {
"masks": masks,
"iou_predictions": iou_predictions,
"low_res_logits": low_res_masks
}
# Make the output compatible with the SAM output.
outputs = [{
"masks": mask.unsqueeze(0),
"iou_predictions": iou_pred,
"low_res_logits": low_res_mask.unsqueeze(0)
} for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)]
return outputs


@@ -1,12 +1,13 @@
from contextlib import nullcontext
from typing import Any, List, Dict

import torch
import torch.nn as nn

from .util import get_sam_model
from ..util import get_sam_model


def get_simple_3d_sam_model(
def get_simple_sam_3d_model(
device,
n_classes,
image_size,
@@ -15,14 +16,6 @@ def get_simple_3d_sam_model(
model_type="vit_b",
checkpoint_path=None,
):
if lora_rank is None:
use_lora = False
rank = None
freeze_encoder_ = freeze_encoder
else:
use_lora = True
rank = lora_rank
freeze_encoder_ = False

_, sam = get_sam_model(
model_type=model_type,
@@ -31,10 +24,11 @@
return_sam=True,
image_size=image_size,
flexible_load_checkpoint=True,
use_lora=use_lora,
rank=rank,
lora_rank=lora_rank,
)

# Make sure not to freeze the encoder when using LoRA.
freeze_encoder_ = freeze_encoder if lora_rank is None else False
sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_)
sam_3d.to(device)
return sam_3d
@@ -142,8 +136,27 @@ def _apply_image_encoder(self, x, D):
encoder_features = torch.stack(encoder_features, 2)
return encoder_features

def forward(self, x, **kwargs):
B, D, C, H, W = x.shape
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool
) -> List[Dict[str, torch.Tensor]]:
"""Predict 3D masks for the current inputs.

Unlike the original SAM, this model only supports automatic segmentation and does not support prompts.

Args:
batched_input: A list over input images, each a dictionary with the following keys:
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model.
multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder.

Returns:
A list over input images, where each element is a dictionary with the following keys:
'masks': Mask prediction for this object.
"""
x = torch.stack([inp["image"] for inp in batched_input], dim=0)

B, C, D, H, W = x.shape
assert C == 3

with self.no_grad():
@@ -154,5 +167,5 @@ def forward(self, x, **kwargs):
out = decoder(out)
logits = self.out_conv(out)

outputs = {"masks": logits}
outputs = [{"masks": mask.unsqueeze(0)} for mask in logits]
return outputs
14 changes: 0 additions & 14 deletions micro_sam/training/semantic_sam_trainer.py
@@ -123,17 +123,3 @@ def _validate_impl(self, forward_context):
self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1))

return metric_val


class SemanticSamTrainer3D(SemanticSamTrainer):
def _get_model_outputs(self, batched_inputs):
model_input = torch.stack([inp["image"] for inp in batched_inputs]).to(self.device)
image_size = batched_inputs[0]["original_size"][-1]
batched_outputs = self.model(
model_input,
multimask_output=True,
image_size=image_size
)
# masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs])
masks = batched_outputs["masks"]
return masks
15 changes: 8 additions & 7 deletions micro_sam/training/util.py
@@ -1,6 +1,6 @@
import os
from math import ceil, floor
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import numpy as np

@@ -43,8 +43,8 @@ def get_trainable_sam_model(
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
freeze: Optional[List[str]] = None,
return_state: bool = False,
use_lora: bool = False,
rank: Optional[int] = None,
lora_rank: Optional[int] = None,
lora_kwargs: Optional[Dict] = None,
Contributor Author:
I removed the use_lora, since it is redundant with rank = None and renamed rank -> lora_rank to have a more explicit name.
I added lora_kwargs, so that we can also customize the PEFT SAM layers from here.
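
For illustration, a call with the new arguments could look like this (the concrete values and the `lora_kwargs` keys are hypothetical):

```python
from micro_sam.training.util import get_trainable_sam_model

trainable_sam = get_trainable_sam_model(
    model_type="vit_b",
    lora_rank=4,  # lora_rank=None would disable LoRA
    # Hypothetical kwargs, forwarded to the PEFT wrapper class.
    lora_kwargs={"attention_layers_to_update": [9, 10, 11]},
)
```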

Contributor:
Looks good to me. Thanks!

Contributor:
For me as well 🤝

flexible_load_checkpoint: bool = False,
**model_kwargs
) -> TrainableSAM:
@@ -59,9 +59,11 @@ def get_trainable_sam_model(
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
use_lora: Whether to use the low rank adaptation method for finetuning.
rank: The rank of the decomposition matrices for updating weights in each attention layer.
lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with LoRA.
If None then LoRA is not used.
lora_kwargs: Keyword arguments for the PEFT wrapper class.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.

Returns:
The trainable segment anything model.
@@ -74,8 +76,7 @@
checkpoint_path=checkpoint_path,
return_sam=True,
return_state=True,
use_lora=use_lora,
rank=rank,
lora_rank=lora_rank,
flexible_load_checkpoint=flexible_load_checkpoint,
**model_kwargs
)