[Project] support SAM inferencer #2897

Merged 8 commits on Apr 19, 2023 (changes shown from 7 commits).
38 changes: 38 additions & 0 deletions projects/sam_inference_demo/README.md
@@ -0,0 +1,38 @@
# Introducing the Segment Anything Model (SAM) Inference Demo!

Welcome to the Segment Anything (SA) Inference Demo, a user-friendly implementation based on the original Segment Anything project. Our demo allows you to experience the power and versatility of the Segment Anything Model (SAM) through an easy-to-use API.

With this inference demo, you can explore the capabilities of the Segment Anything Model and witness its effectiveness in various tasks and image distributions. For more information on the original project, dataset, and model, please visit the official website at https://segment-anything.com.

### Prerequisites

- Python 3.10
- PyTorch 1.13

**Collaborator comment:** Should we also add mmengine, mmcv, and mmseg to the prerequisites?

### Installation

We assume that you have already installed PyTorch. If not, please follow the instructions on the [PyTorch website](https://pytorch.org/).

**1. Install MMEngine & MMCV**

```shell
pip install openmim
mim install mmengine
mim install 'mmcv>=1.0.0'
```

**Collaborator comment:** Doesn't mmpretrain require mmcv>=2.0.0?

**2. Install MMPretrain**

```shell
pip install git+https://github.com/open-mmlab/mmpretrain.git@dev
```

**3. Install MMSegmentation**

```shell
pip install mmsegmentation
```
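
After step 3, a quick import check can confirm that the three core packages resolve (a minimal sketch, not part of the original README):

```python
# Minimal install check (illustrative only, not from the demo itself).
import mmengine
import mmcv
import mmseg

print(mmengine.__version__, mmcv.__version__, mmseg.__version__)
```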

### Usage

Open the `sam_image_demo.ipynb` notebook and follow the instructions to run the demo.
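
If you want to experiment with the components outside the notebook, the `sam` package added by this project can be made importable as follows (a sketch only; it assumes you start Python from the mmsegmentation repository root, and the notebook may set this up differently):

```python
# Sketch: make the demo package importable from the repository root.
import sys

sys.path.append('projects/sam_inference_demo')

# These names are re-exported by sam/__init__.py via sam.modeling.
from sam import SAM, MaskDecoder, PromptEncoder, TwoWayTransformer
```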
2 changes: 2 additions & 0 deletions projects/sam_inference_demo/sam/__init__.py
@@ -0,0 +1,2 @@
from .modeling import * # noqa
from .utils import * # noqa
12 changes: 12 additions & 0 deletions projects/sam_inference_demo/sam/modeling/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .sam import SAM
from .transformer import TwoWayTransformer

__all__ = ['SAM', 'MaskDecoder', 'PromptEncoder', 'TwoWayTransformer']
45 changes: 45 additions & 0 deletions projects/sam_inference_demo/sam/modeling/common.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Type

import torch
import torch.nn as nn


class MLPBlock(nn.Module):

def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):

def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
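
As a quick sanity check (not part of the PR), `LayerNorm2d` normalizes each spatial position over its channels, so it should agree numerically with `nn.LayerNorm` applied to a channel-last view of the same tensor when both use `eps=1e-6`:

```python
# Illustrative check only: compare LayerNorm2d with channel-last nn.LayerNorm.
import torch
import torch.nn as nn

from sam.modeling.common import LayerNorm2d

x = torch.randn(2, 8, 4, 4)                      # (N, C, H, W)
ln2d = LayerNorm2d(8)                            # normalizes over C at each (h, w)
ln = nn.LayerNorm(8, eps=1e-6)                   # normalizes over the last dim

out_a = ln2d(x)
out_b = ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(torch.allclose(out_a, out_b, atol=1e-5))   # True
```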
196 changes: 196 additions & 0 deletions projects/sam_inference_demo/sam/modeling/mask_decoder.py
@@ -0,0 +1,196 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Borrowed from https://github.com/facebookresearch/segment-anything

from typing import List, Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from mmseg.registry import MODELS
from .common import LayerNorm2d


@MODELS.register_module()
class MaskDecoder(nn.Module):

def __init__(
self,
*,
transformer_dim: int,
transformer: dict,
num_multimask_outputs: int = 3,
act_cfg: dict = dict(type='GELU'),
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""Predicts masks given an image and prompt embeddings, using a
tranformer architecture.

Borrowed from https://github.com/facebookresearch/segment-anything

Arguments:
transformer_dim (int): the channel dimension of the transformer
            transformer (dict): config of the transformer used to predict masks
            num_multimask_outputs (int): the number of masks to predict
                when disambiguating masks
            act_cfg (dict): config of the activation to use when
                upscaling masks
iou_head_depth (int): the depth of the MLP used to predict
mask quality
iou_head_hidden_dim (int): the hidden dimension of the MLP
used to predict mask quality
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = MODELS.build(transformer)

self.num_multimask_outputs = num_multimask_outputs

self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

activation = MODELS.build(act_cfg)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(
transformer_dim, transformer_dim // 4, kernel_size=2,
stride=2),
LayerNorm2d(transformer_dim // 4),
activation,
nn.ConvTranspose2d(
transformer_dim // 4,
transformer_dim // 8,
kernel_size=2,
stride=2),
activation,
)
self.output_hypernetworks_mlps = nn.ModuleList([
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
])

self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim,
self.num_mask_tokens, iou_head_depth)

def forward(
self,
image_embeddings: Tensor,
image_pe: Tensor,
sparse_prompt_embeddings: Tensor,
dense_prompt_embeddings: Tensor,
multimask_output: bool,
) -> Tuple[Tensor, Tensor]:
"""Predict masks given image and prompt embeddings.

Borrowed from https://github.com/facebookresearch/segment-anything

Arguments:
image_embeddings (Tensor): the embeddings from the image encoder
image_pe (Tensor): positional encoding with the shape of
image_embeddings
sparse_prompt_embeddings (Tensor): the embeddings of
the points and boxes
dense_prompt_embeddings (Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.

Returns:
Tensor: batched predicted masks
Tensor: batched predictions of mask quality
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)

# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]

# Prepare output
return masks, iou_pred

def predict_masks(
self,
image_embeddings: Tensor,
image_pe: Tensor,
sparse_prompt_embeddings: Tensor,
dense_prompt_embeddings: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Predicts masks.

See 'forward' for more details.
"""
# Concatenate output tokens
output_tokens = torch.cat(
[self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(
sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape

# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1:(1 + self.num_mask_tokens), :]

# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](
mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
b, -1, h, w)

# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)

return masks, iou_pred


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):

def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output

def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
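
To make the decoder's input/output contract concrete, the following shape sketch (not part of the PR) builds a `MaskDecoder` through the registry. Because the exact config keys of `TwoWayTransformer` are not shown in this diff, a trivial stand-in transformer with the same call signature is registered instead; with SAM's usual `transformer_dim=256` and a 64x64 embedding grid, the decoder returns masks upscaled 4x to 256x256:

```python
# Shape sketch only (illustrative, not part of the PR). A dummy module stands in
# for TwoWayTransformer, whose real config keys are not shown in this diff.
import torch
from torch import nn

from mmseg.registry import MODELS
from sam.modeling import MaskDecoder


@MODELS.register_module()
class DummyTwoWayTransformer(nn.Module):
    """Returns (per-token outputs, flattened image features), like the real one."""

    def forward(self, src, pos_src, tokens):
        # src: (B, C, H, W) -> keys: (B, H*W, C); tokens are passed through.
        return tokens, src.flatten(2).permute(0, 2, 1).contiguous()


decoder = MaskDecoder(
    transformer_dim=256,
    transformer=dict(type='DummyTwoWayTransformer'),
)

image_embeddings = torch.randn(1, 256, 64, 64)   # from the image encoder
image_pe = torch.randn(1, 256, 64, 64)           # dense positional encoding
sparse_prompts = torch.randn(1, 2, 256)          # e.g. two point prompts
dense_prompts = torch.randn(1, 256, 64, 64)      # mask prompt embedding

masks, iou_pred = decoder(
    image_embeddings, image_pe, sparse_prompts, dense_prompts,
    multimask_output=True)
print(masks.shape)     # torch.Size([1, 3, 256, 256]) -- two 2x upscalings
print(iou_pred.shape)  # torch.Size([1, 3])
```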