-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Api cleanup #648
Api cleanup #648
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .build_sam import sam_model_registry | ||
from .peft_sam import PEFT_Sam |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
from typing import Type | ||
from typing import Any, List, Dict, Type | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from segment_anything.modeling.image_encoder import window_partition, window_unpartition | ||
from segment_anything.modeling import Sam | ||
|
||
from .util import get_sam_model | ||
from ..util import get_sam_model | ||
|
||
|
||
def get_3d_sam_model( | ||
def get_sam_3d_model( | ||
device, | ||
n_classes, | ||
image_size, | ||
|
@@ -18,15 +18,8 @@ def get_3d_sam_model( | |
model_type="vit_b", | ||
checkpoint_path=None, | ||
): | ||
if lora_rank is None: | ||
use_lora = False | ||
rank = None | ||
freeze_encoder_ = freeze_encoder | ||
else: | ||
use_lora = True | ||
rank = lora_rank | ||
freeze_encoder_ = False | ||
|
||
# Make sure not to freeze the encoder when using LoRA. | ||
freeze_encoder_ = freeze_encoder if lora_rank is None else False | ||
_, sam = get_sam_model( | ||
model_type=model_type, | ||
device=device, | ||
|
@@ -35,8 +28,7 @@ def get_3d_sam_model( | |
flexible_load_checkpoint=True, | ||
num_multimask_outputs=n_classes, | ||
image_size=image_size, | ||
use_lora=use_lora, | ||
rank=rank, | ||
lora_rank=lora_rank, | ||
) | ||
|
||
sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) | ||
|
@@ -46,11 +38,10 @@ def get_3d_sam_model( | |
|
||
class Sam3DWrapper(nn.Module): | ||
def __init__(self, sam_model: Sam, freeze_encoder: bool): | ||
""" | ||
Initializes the Sam3DWrapper object. | ||
"""Initializes the Sam3DWrapper object. | ||
|
||
Args: | ||
sam_model (Sam): The Sam model to be wrapped. | ||
sam_model: The Sam model to be wrapped. | ||
""" | ||
super().__init__() | ||
sam_model.image_encoder = ImageEncoderViT3DWrapper( | ||
|
@@ -63,25 +54,42 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool): | |
for param in self.sam_model.image_encoder.parameters(): | ||
param.requires_grad = False | ||
|
||
# FIXME | ||
# - handling of the image size here is wrong, this only works for square images | ||
# - this does not take care of resizing | ||
# unclear how batches are handled | ||
def forward(self, batched_input, multimask_output, image_size) -> torch.Tensor: | ||
return self._forward_train(batched_input, multimask_output, image_size) | ||
def forward( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
self, | ||
batched_input: List[Dict[str, Any]], | ||
multimask_output: bool | ||
) -> List[Dict[str, torch.Tensor]]: | ||
"""Predict 3D masks for the current inputs. | ||
|
||
Unlike the original SAM, this model only supports automatic segmentation and does not support prompts. | ||
|
||
Args: | ||
batched_input: A list over input images, each a dictionary with the following keys: | ||
'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. | ||
'original_size': The original size of the image (HxW) before transformation. | ||
multimask_output: Whether to predict with the multi- or single-mask head of the mask decoder. | ||
|
||
Returns: | ||
A list over input images, where each element is a dictionary with the following keys: | ||
'masks': Mask prediction for this object. | ||
'iou_predictions': IOU score prediction for this object. | ||
'low_res_masks': Low resolution mask prediction for this object. | ||
""" | ||
batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0) | ||
original_size = batched_input[0]["original_size"] | ||
assert all(inp["original_size"] == original_size for inp in batched_input) | ||
|
||
def _forward_train(self, batched_input, multimask_output, image_size): | ||
# dimensions: [b, 3, d, h, w] | ||
shape = batched_input.shape | ||
shape = batched_images.shape | ||
assert shape[1] == 3 | ||
batch_size, d_size, hw_size = shape[0], shape[2], shape[-2] | ||
# Transpose the axes, so that the depth axis is the first axis and the channel | ||
# axis is the second axis. This is expected by the transformer! | ||
batched_input = batched_input.transpose(1, 2) | ||
assert batched_input.shape[1] == d_size | ||
batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size) | ||
batched_images = batched_images.transpose(1, 2) | ||
assert batched_images.shape[1] == d_size | ||
batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size) | ||
|
||
input_images = self.sam_model.preprocess(batched_input) | ||
input_images = self.sam_model.preprocess(batched_images) | ||
image_embeddings = self.sam_model.image_encoder(input_images, d_size) | ||
sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder( | ||
points=None, boxes=None, masks=None | ||
|
@@ -95,8 +103,8 @@ def _forward_train(self, batched_input, multimask_output, image_size): | |
) | ||
masks = self.sam_model.postprocess_masks( | ||
low_res_masks, | ||
input_size=(image_size, image_size), | ||
original_size=(image_size, image_size) | ||
input_size=batched_images.shape[-2:], | ||
original_size=original_size, | ||
) | ||
|
||
# Bring the masks and low-res masks into the correct shape: | ||
|
@@ -112,11 +120,12 @@ def _forward_train(self, batched_input, multimask_output, image_size): | |
masks = masks.transpose(1, 2) | ||
low_res_masks = low_res_masks.transpose(1, 2) | ||
|
||
outputs = { | ||
"masks": masks, | ||
"iou_predictions": iou_predictions, | ||
"low_res_logits": low_res_masks | ||
} | ||
# Make the output compatible with the SAM output. | ||
outputs = [{ | ||
"masks": mask.unsqueeze(0), | ||
"iou_predictions": iou_pred, | ||
"low_res_logits": low_res_mask.unsqueeze(0) | ||
} for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)] | ||
return outputs | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import os | ||
from math import ceil, floor | ||
from typing import List, Optional, Union | ||
from typing import Dict, List, Optional, Union | ||
|
||
import numpy as np | ||
|
||
|
@@ -43,8 +43,8 @@ def get_trainable_sam_model( | |
checkpoint_path: Optional[Union[str, os.PathLike]] = None, | ||
freeze: Optional[List[str]] = None, | ||
return_state: bool = False, | ||
use_lora: bool = False, | ||
rank: Optional[int] = None, | ||
lora_rank: Optional[int] = None, | ||
lora_kwargs: Optional[Dict] = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks good to me. Thanks! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For me as well 🤝 |
||
flexible_load_checkpoint: bool = False, | ||
**model_kwargs | ||
) -> TrainableSAM: | ||
|
@@ -59,9 +59,11 @@ def get_trainable_sam_model( | |
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder | ||
By default nothing is frozen and the full model is updated. | ||
return_state: Whether to return the full checkpoint state. | ||
use_lora: Whether to use the low rank adaptation method for finetuning. | ||
rank: The rank of the decomposition matrices for updating weights in each attention layer. | ||
lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. | ||
If None then LoRA is not used. | ||
lora_kwargs: Keyword arguments for the PEFT wrapper class. | ||
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. | ||
model_kwargs: Additional keyword arguments for the `util.get_sam_model`. | ||
|
||
Returns: | ||
The trainable segment anything model. | ||
|
@@ -74,8 +76,7 @@ def get_trainable_sam_model( | |
checkpoint_path=checkpoint_path, | ||
return_sam=True, | ||
return_state=True, | ||
use_lora=use_lora, | ||
rank=rank, | ||
lora_rank=lora_rank, | ||
flexible_load_checkpoint=flexible_load_checkpoint, | ||
**model_kwargs | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed these names to match the order of `sam_3d` that we use everywhere else, cc @lufre1