[Project] support SAM inferencer #2897
Merged

Changes from all commits (8 commits, all authored by xiexinch):

- 5a8d9e1  add sam modules
- f32b348  support api
- 2fcd19b  update comments
- 5c33988  Merge remote-tracking branch 'upstream/dev-1.x' into sam-inferencer
- 9ac35ae  mv to project
- 6ed098d  minor change
- 7a82831  add readme
- d09a9dc  update
@@ -0,0 +1,40 @@

# Introducing the Segment Anything Model (SAM) Inference Demo!

Welcome to the Segment Anything (SA) Inference Demo, a user-friendly implementation based on the original Segment Anything project. Our demo allows you to experience the power and versatility of the Segment Anything Model (SAM) through an easy-to-use API.

With this inference demo, you can explore the capabilities of the Segment Anything Model and witness its effectiveness in various tasks and image distributions. For more information on the original project, dataset, and model, please visit the official website at https://segment-anything.com.

### Prerequisites

- Python 3.10
- PyTorch 1.13
- MMEngine >= v0.7.2
- MMCV >= v2.0.0
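A quick way to confirm the environment matches these prerequisites is to print the installed versions (a minimal sketch, not part of this PR; it assumes the packages above are importable under their usual names):

```python
import torch
import mmengine
import mmcv
import mmseg

print('PyTorch:', torch.__version__)        # expect 1.13.x
print('MMEngine:', mmengine.__version__)    # expect >= 0.7.2
print('MMCV:', mmcv.__version__)            # expect >= 2.0.0
print('MMSegmentation:', mmseg.__version__)
```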
### Installation

We assume that you have already installed PyTorch. If not, please follow the instructions on the [PyTorch website](https://pytorch.org/).

**1. Install MMEngine & MMCV**

```shell
pip install openmim
mim install mmengine
mim install 'mmcv>=2.0.0'
```

**2. Install MMPretrain**

```shell
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
```

**3. Install MMSegmentation**

```shell
pip install mmsegmentation
```
### Usage

Open the `sam_image_demo.ipynb` notebook and follow the instructions to run the demo.
@@ -0,0 +1,2 @@

from .modeling import *  # noqa
from .utils import *  # noqa
@@ -0,0 +1,12 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .sam import SAM
from .transformer import TwoWayTransformer

__all__ = ['SAM', 'MaskDecoder', 'PromptEncoder', 'TwoWayTransformer']
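Once the project directory is on the Python path, the components listed in `__all__` can be imported directly (a usage sketch; it assumes the package root is named `sam`, as in the file paths in this PR):

```python
# Hypothetical usage: import the modeling components re-exported above.
from sam.modeling import SAM, MaskDecoder, PromptEncoder, TwoWayTransformer
```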
@@ -0,0 +1,45 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Type

import torch
import torch.nn as nn


class MLPBlock(nn.Module):

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
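As a sanity check on `LayerNorm2d` (an author's sketch, not part of the PR): it normalizes over the channel dimension of an NCHW tensor with per-channel affine parameters, so a freshly initialized instance should match `nn.LayerNorm` applied to a channels-last view of the same tensor:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 8, 4, 4)  # (N, C, H, W)
ln2d = LayerNorm2d(8)        # defined above; weight=1, bias=0 at init

# Reference: nn.LayerNorm normalizes the last dim, so move C last and back.
ref = nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert torch.allclose(ln2d(x), ref, atol=1e-5)
```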
projects/sam_inference_demo/sam/modeling/mask_decoder.py (196 additions, 0 deletions)
@@ -0,0 +1,196 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Borrowed from https://github.com/facebookresearch/segment-anything

from typing import List, Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from mmseg.registry import MODELS
from .common import LayerNorm2d


@MODELS.register_module()
class MaskDecoder(nn.Module):

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: dict,
        num_multimask_outputs: int = 3,
        act_cfg: dict = dict(type='GELU'),
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
"""Predicts masks given an image and prompt embeddings, using a | ||
tranformer architecture. | ||
|
||
Borrowed from https://github.com/facebookresearch/segment-anything | ||
|
||
Arguments: | ||
transformer_dim (int): the channel dimension of the transformer | ||
transformer (nn.Module): the transformer used to predict masks | ||
num_multimask_outputs (int): the number of masks to predict | ||
when disambiguating masks | ||
activation (nn.Module): the type of activation to use when | ||
upscaling masks | ||
iou_head_depth (int): the depth of the MLP used to predict | ||
mask quality | ||
iou_head_hidden_dim (int): the hidden dimension of the MLP | ||
used to predict mask quality | ||
""" | ||
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = MODELS.build(transformer)

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        activation = MODELS.build(act_cfg)
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2,
                stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation,
            nn.ConvTranspose2d(
                transformer_dim // 4,
                transformer_dim // 8,
                kernel_size=2,
                stride=2),
            activation,
        )
        self.output_hypernetworks_mlps = nn.ModuleList([
            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
            for i in range(self.num_mask_tokens)
        ])

        self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim,
                                       self.num_mask_tokens, iou_head_depth)

    def forward(
        self,
        image_embeddings: Tensor,
        image_pe: Tensor,
        sparse_prompt_embeddings: Tensor,
        dense_prompt_embeddings: Tensor,
        multimask_output: bool,
    ) -> Tuple[Tensor, Tensor]:
        """Predict masks given image and prompt embeddings.

        Borrowed from https://github.com/facebookresearch/segment-anything

        Arguments:
            image_embeddings (Tensor): the embeddings from the image encoder
            image_pe (Tensor): positional encoding with the shape of
                image_embeddings
            sparse_prompt_embeddings (Tensor): the embeddings of
                the points and boxes
            dense_prompt_embeddings (Tensor): the embeddings of the mask
                inputs
            multimask_output (bool): Whether to return multiple masks or a
                single mask.

        Returns:
            Tensor: batched predicted masks
            Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: Tensor,
        image_pe: Tensor,
        sparse_prompt_embeddings: Tensor,
        dense_prompt_embeddings: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Predicts masks.

        See 'forward' for more details.
        """
        # Concatenate output tokens
        output_tokens = torch.cat(
            [self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](
                mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
            b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
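To make the decoder's input/output contract concrete, here is a small shape-tracing sketch (an author's illustration, not part of the PR). Because TwoWayTransformer's implementation is not shown in this diff, a stub with the same calling convention is registered in its place; `StubTransformer` is a hypothetical name:

```python
import torch
from torch import nn

from mmseg.registry import MODELS
from sam.modeling import MaskDecoder  # assumes the project package is importable


@MODELS.register_module()
class StubTransformer(nn.Module):
    """Hypothetical stand-in mimicking TwoWayTransformer's calling convention.

    Returns (queries, keys): queries keep the token shape (B, N, C),
    keys are the flattened image features (B, H*W, C).
    """

    def forward(self, src, pos_src, tokens):
        return tokens, src.flatten(2).permute(0, 2, 1)


decoder = MaskDecoder(
    transformer_dim=256, transformer=dict(type='StubTransformer'))

image_embeddings = torch.randn(1, 256, 64, 64)  # from an image encoder
image_pe = torch.randn(1, 256, 64, 64)          # positional encoding
sparse_prompts = torch.randn(1, 2, 256)         # e.g. two point prompts
dense_prompts = torch.randn(1, 256, 64, 64)     # mask-input embeddings

masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompts,
                          dense_prompts, multimask_output=True)
print(masks.shape)     # torch.Size([1, 3, 256, 256]) after 4x upscaling
print(iou_pred.shape)  # torch.Size([1, 3]), one IoU estimate per mask
```

Note how `multimask_output=True` drops the first (single-mask) channel, leaving `num_multimask_outputs` candidate masks plus their predicted quality scores.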
Review comment: Should we add prerequisites for mmengine, mmcv and mmseg?