In [None]:
using_colab = True
if using_colab:
    import torch
    import torchvision
    from google.colab.patches import cv2_imshow
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg

    # !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

## Set-up

In [None]:
#@title visualization utils
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    """Overlay `mask` on `ax` as a translucent RGBA image (alpha 0.6).

    Args:
      mask: array whose last two dims are (h, w); it is reshaped to (h, w, 1)
        and multiplied by the RGBA colour, so truthy pixels get the colour and
        falsy pixels become fully transparent black.
      ax: matplotlib Axes (anything with an ``imshow`` method).
      random_color: if True, pick a random RGB colour; otherwise use a fixed
        blue.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on `ax` as stars: label 1 green, label 0 red.

    Args:
      coords: (N, 2) array of (x, y) coordinates.
      labels: (N,) array aligned with `coords`; 1 = positive, 0 = negative.
      ax: matplotlib Axes.
      marker_size: marker area, passed to ``scatter`` as ``s``.
    """
    # Positive points are drawn first, then negative ones.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline the box (x0, y0, x1, y1) on `ax` as a 2px green rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                            edgecolor='green', facecolor=(0,0,0,0), lw=2)
    ax.add_patch(outline)


## Example image

In [None]:
image = cv2.imread('images/truck.jpg')
if image is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here instead of with an opaque cvtColor error.
    raise FileNotFoundError(
        "Could not read 'images/truck.jpg'; run the setup cell to download it.")
# OpenCV decodes to BGR channel order; convert to RGB for display/prediction.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
print('image.shape', image.shape)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()

## Selecting objects with SAM

In [None]:
import sys
from typing import Optional, Tuple
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# Build SAM from the ViT-B checkpoint fetched by the setup cell and place it
# on CPU. `sam_checkpoint` and `model_type` must name the same architecture.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)

class CustomSamPredictor(SamPredictor):
    """SamPredictor whose ``set_torch_image`` keeps its intermediate tensors.

    Besides the encoder features, the override stores the resized input
    (``transformed_image``) and the preprocessed encoder input
    (``input_image``) on the predictor for later inspection.
    """

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image,
        original_image_size,
    ) -> None:
        """Encode an already-transformed image tensor and cache the results.

        Args:
          transformed_image: tensor that must be 4-D with 3 channels in dim 1
            and a long spatial side equal to the encoder's ``img_size``.
          original_image_size: size of the image before transformation,
            stored as ``self.original_size``.
        """
        encoder_side = self.model.image_encoder.img_size
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == encoder_side
        ), f"set_torch_image input must be BCHW with long side {encoder_side}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        encoder_input = self.model.preprocess(transformed_image)
        self.features = self.model.image_encoder(encoder_input)
        # Extra state compared with a plain feature cache.
        self.input_image = encoder_input
        self.transformed_image = transformed_image
        self.is_image_set = True
# Predictor bound to the model built above; the cells below use its
# set_image / predict methods.
predictor = CustomSamPredictor(sam)

In [None]:
# Encode `image` once. Presumably this routes through the set_torch_image
# override above (confirm against segment_anything's SamPredictor.set_image),
# which caches `predictor.features` for the predict() calls below.
predictor.set_image(image)

In [None]:
# One positive prompt point at (x=500, y=375); label 1 = positive and 0 would
# be negative (drawn green/red by show_points).
input_point = np.array([[500, 375]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
# Sanity-check the shape and value range of the image given to the predictor.
print('image.shape', image.shape)
print('image.max', image.max())
print('image.min', image.min())
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()

In [None]:
# Query the cached embedding with the point prompt. multimask_output=True asks
# for several candidate masks; `scores` is aligned with `masks` (see the plot
# loop below) and `logits` holds the unthresholded mask logits, which are fed
# back as a mask prompt further down.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

In [None]:
# Predicted quality score for each candidate mask.
print(scores)

In [None]:
# Threshold the first candidate's logits at 0 and display the result as a 0/255
# image. cv2_imshow comes from google.colab, so this cell only runs on Colab.
from google.colab.patches import cv2_imshow
cv2_imshow((logits[0] > 0)*255)

In [None]:
# One figure per candidate: the mask over the image, the prompt point, and the
# candidate's predicted score in the title.
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()


In [None]:
# Raw logits of the first candidate; this candidate is fed back as
# `mask_input` (via logits[:1]) in the next prediction cell.
logits[0]

In [None]:
# Re-run prediction with the same point plus the first candidate's logits as a
# dense mask prompt. `logits[:1]` keeps a leading axis of size 1.
# NOTE(review): this feeds back candidate 0 rather than the best-scoring one
# (e.g. logits[np.argmax(scores)][None]); confirm which is intended.
masks_2, scores_2, logits_2 = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
    mask_input=logits[:1]
)

In [None]:
# Same visualization as above, for the predictions refined with mask_input.
for i, (mask, score) in enumerate(zip(masks_2, scores_2)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()


In [None]:
# Inspect the PyTorch model. Its submodule names (e.g. `blocks.N`, `neck`)
# appear to be mirrored by the `name=` arguments of the Flax modules below.
state_dict = sam.state_dict()
print(sam)

In [None]:
# Tabulate each PyTorch state-dict entry (name, shape, mean, std) and count the
# elements in the tabulated entries.
from tabulate import tabulate
import copy
# Deep copy so in-place edits to `torch_weights` cannot alter `sam` itself.
torch_weights = copy.deepcopy(sam.state_dict())
table = []
num_params = 0
for k in sorted(torch_weights):
  # Entries whose name contains 'mask_downscaling' are left out of both the
  # table and the count. NOTE(review): presumably because the JAX port does not
  # cover them yet; confirm.
  if 'mask_downscaling' in k:
    continue
  v = torch_weights[k]
  table.append((k, f'{v.shape}', f'{v.mean():.3f}', f'{v.std():.3f}'))
  # NOTE(review): state_dict() may include buffers as well as parameters,
  # so this can count more than trainable parameters.
  num_params += np.prod(np.asarray(v.shape))
table_str = tabulate(
    table, tablefmt="pipe", headers=["Names", "shape", "mean", "std"])
print(table_str)
print('num_params', num_params)

In [None]:
# NOTE(review): no visible cell imports ml_collections; keep this only if a
# later cell needs it. `%pip` would target the kernel's environment more reliably.
!pip install ml_collections

In [None]:
#@title Jax image-encoder
"""ViT with windows attention."""

import functools
from typing import Any, Optional

import flax.linen as nn
import jax
import jax.numpy as jnp

# Named Dense-kernel initializers; modules select one via their `kernel_init`
# string field.
KERNEL_INIT = dict(
    normal=nn.initializers.normal(stddev=0.02),
)


class HMAttention(nn.Module):
  """Multi-head self-attention with optional decomposed relative positions.

  Attributes:
    dim (int): Number of input channels (also the output width of `proj`).
    num_heads (int): Number of attention heads.
    qkv_bias (bool): If True, add a learnable bias to query, key, value.
    beit_like_qkv_bias (bool): If True, learn separate q and v biases and keep
      the k bias fixed at zero; `qkv_bias` is then ignored.
    use_rel_pos (bool): If True, add relative positional embeddings to the
      attention map.
    rel_pos_zero_init (bool): If True, zero initialize relative positional
      parameters. (Not read: the parameters are always zero-initialized.)
    input_size (tuple or None): (height, width) token grid used to size the
      relative positional parameters; required when `use_rel_pos` is True.
    kernel_init (str): Key into `KERNEL_INIT` for the Dense kernels.
    dtype: Computation dtype of the Dense layers.
  """
  dim: int
  num_heads: int = 8
  qkv_bias: bool = True
  beit_like_qkv_bias: bool = False
  use_rel_pos: bool = False
  rel_pos_zero_init: bool = True
  input_size: Optional[Any] = None
  kernel_init: str = 'normal'
  dtype: jnp.dtype = jnp.float32

  def get_rel_pos(self, q_size, k_size, rel_pos):
    """Get relative positional embeddings.

    Args:
      q_size (int): size of query q.
      k_size (int): size of key k.
      rel_pos (Tensor): relative position embeddings (L, C). Linearly resized
        along L when L != 2 * max(q_size, k_size) - 1.
    Returns:
      Embeddings of shape (q_size, k_size, C), indexed by relative position.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
      # Interpolate rel pos.
      rel_pos_resized = jax.image.resize(
          rel_pos,
          shape=(max_rel_dist, rel_pos.shape[1]),
          method='linear',
      )
    else:
      rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = jnp.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = jnp.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(
        q_size / k_size, 1.0)
    relative_coords = relative_coords.astype(jnp.int32).reshape(-1)
    return rel_pos_resized[relative_coords].reshape(q_size, k_size, -1)

  def add_decomposed_rel_pos(
      self, attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """Calculate decomposed Relative Positional Embeddings from paper:`mvitv2`.

    Args:
      attn (Tensor): attention map (B, q_h * q_w, k_h * k_w).
      q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
      rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
      rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
      q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
      k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
    Returns:
      attn (Tensor): attention map with added relative positional embeddings,
        same shape as the input `attn`.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    rh = self.get_rel_pos(q_h, k_h, rel_pos_h)
    rw = self.get_rel_pos(q_w, k_w, rel_pos_w)

    batch, _, dim = q.shape
    r_q = q.reshape(batch, q_h, q_w, dim)
    rel_h = jnp.einsum('bhwc,hkc->bhwk', r_q, rh)
    rel_w = jnp.einsum('bhwc,wkc->bhwk', r_q, rw)

    # Broadcast the height term over key columns and the width term over key
    # rows, then flatten back to (B, q_tokens, k_tokens).
    attn = (
        attn.reshape(batch, q_h, q_w, k_h, k_w) + rel_h[
            :, :, :, :, None] + rel_w[:, :, :, None, :]
    ).reshape(batch, q_h * q_w, k_h * k_w)

    return attn

  @nn.compact
  def __call__(self, x):
    """Self-attention over the spatial tokens of `x`.

    Args:
      x: (batch, height, width, C) array.
    Returns:
      (batch, height, width, dim) array.
    """
    batch, height, width, _ = x.shape
    head_dim = self.dim // self.num_heads
    if self.beit_like_qkv_bias:
      q_bias = self.param(
          'q_bias', nn.initializers.zeros, (self.dim,))
      v_bias = self.param(
          'v_bias', nn.initializers.zeros, (self.dim,))
      # Key bias is a constant zero, not a parameter.
      k_bias = jnp.zeros((self.dim,), dtype=jnp.float32)
      qkv_bias = jnp.concatenate([q_bias, k_bias, v_bias], axis=0)
      qkv = nn.Dense(
          self.dim * 3, use_bias=False, dtype=self.dtype,
          kernel_init=KERNEL_INIT[self.kernel_init], name='qkv')(
              x)  # batch x height x width x 3dim
      qkv = qkv + qkv_bias[None, None, None, :]
    else:
      qkv = nn.Dense(
          self.dim * 3, use_bias=self.qkv_bias, dtype=self.dtype,
          kernel_init=KERNEL_INIT[self.kernel_init], name='qkv')(
              x)  # batch x height x width x 3dim
    qkv = qkv.reshape(batch, height * width, 3, self.num_heads, -1).transpose(
        2, 0, 3, 1, 4)  # 3 x batch x num_heads x num_tokens x D
    qkv = qkv.reshape(3, batch * self.num_heads, height * width, -1)
    q, k, v = qkv[0], qkv[1], qkv[2]  # [batch * num_heads, num_tokens, D]
    attn = (q * (head_dim ** -0.5)) @ k.transpose(
        0, 2, 1)  # [batch * num_heads, num_tokens, num_tokens]
    if self.use_rel_pos:
      rel_pos_h = self.param(
          'rel_pos_h', nn.initializers.zeros,
          (2 * self.input_size[0] - 1, head_dim))
      # The width table is sized by the grid width (input_size[1]); using
      # input_size[0] here was only correct for square token grids.
      rel_pos_w = self.param(
          'rel_pos_w', nn.initializers.zeros,
          (2 * self.input_size[1] - 1, head_dim))
      attn = self.add_decomposed_rel_pos(
          attn, q, rel_pos_h, rel_pos_w,
          (height, width), (height, width))
    attn = jax.nn.softmax(attn)
    x = (attn @ v).reshape(batch, self.num_heads, height, width, -1).transpose(
        0, 2, 3, 1, 4).reshape(batch, height, width, -1)
    x = nn.Dense(
        self.dim, dtype=self.dtype, kernel_init=KERNEL_INIT[self.kernel_init],
        name='proj')(x)
    return x


class Mlp(nn.Module):
  """Two-layer perceptron: Dense ('lin1') -> exact GELU -> Dense ('lin2').

  Attributes:
    hidden_features: Width of the hidden layer.
    out_features: Width of the output layer.
    kernel_init: Key into `KERNEL_INIT` used for both kernels.
    dtype: Computation dtype of the Dense layers.
  """

  hidden_features: int
  out_features: int
  kernel_init: str = 'normal'
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x):
    init_fn = KERNEL_INIT[self.kernel_init]
    hidden = nn.Dense(
        self.hidden_features, dtype=self.dtype, kernel_init=init_fn,
        name='lin1')(x)
    hidden = nn.gelu(hidden, approximate=False)
    return nn.Dense(
        self.out_features, dtype=self.dtype, kernel_init=init_fn,
        name='lin2')(hidden)


class Block(nn.Module):
  """Transformer block with optional window attention and residual branches.

  Attributes:
    dim (int): Number of input channels.
    num_heads (int): Number of attention heads in each ViT block.
    mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
    qkv_bias (bool): If True, add a learnable bias to query, key, value.
    beit_like_qkv_bias (bool): no bias for k.
    drop_path (float): Stochastic depth rate; only applied when train=True.
    use_rel_pos (bool): If True, add relative positional embeddings to the
      attention map.
    rel_pos_zero_init (bool): If True, zero initialize relative positional
      parameters.
    window_size (int): Window size for window attention blocks. If it equals 0,
      then not use window attention.
    input_size (tuple or None): Token-grid resolution used to size the
      relative positional parameters; only used when window_size == 0
      (windowed blocks use (window_size, window_size)).
    kernel_init (str): Key into `KERNEL_INIT` for Dense kernels.
    layer_scale_init_value (float): If > 0, scale each residual branch by a
      learnable per-channel vector initialized to this value.
    dtype: Computation dtype.
  """
  dim: int
  num_heads: int
  mlp_ratio: float = 4.0
  qkv_bias: bool = True
  beit_like_qkv_bias: bool = False
  drop_path: float = 0.0
  use_rel_pos: bool = False
  rel_pos_zero_init: bool = True
  window_size: int = 0
  input_size: Optional[Any] = None
  kernel_init: str = 'normal'
  layer_scale_init_value: float = -1.0
  dtype: jnp.dtype = jnp.float32

  def window_partition(self, x):
    """Partition into non-overlapping windows with padding if needed.

    Args:
      x (array): input tokens with [B, H, W, C].
    Returns:
      windows: windows after partition with [B * num_windows, window_size,
        window_size, C].
      (Hp, Wp): padded height and width before partition
    """
    batch, h, w, c = x.shape

    pad_h = (self.window_size - h % self.window_size) % self.window_size
    pad_w = (self.window_size - w % self.window_size) % self.window_size
    if pad_h > 0 or pad_w > 0:
      # Zero-pad the end of the H axis by pad_h and of the W axis by pad_w.
      # The pad widths must follow the [B, H, W, C] axis order; swapping them
      # only goes unnoticed when pad_h == pad_w.
      x = jnp.pad(
          x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)),
          'constant', constant_values=0)
    hp, wp = h + pad_h, w + pad_w

    x = x.reshape(
        batch, hp // self.window_size, self.window_size,
        wp // self.window_size, self.window_size, c)
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(
        -1, self.window_size, self.window_size, c)
    return windows, (hp, wp)

  def window_unpartition(self, windows, pad_hw, hw):
    """Window unpartition into original sequences and removing padding.

    Args:
      windows (array): inputs: [B * num_windows, window_size, window_size, C].
      pad_hw (Tuple): padded height and width (Hp, Wp).
      hw (Tuple): original height and width (H, W) before padding.

    Returns:
      x: unpartitioned sequences with [B, H, W, C].
    """
    hp, wp = pad_hw
    h, w = hw
    batch = windows.shape[0] // (
        hp * wp // self.window_size // self.window_size)
    x = windows.reshape(
        batch,
        hp // self.window_size, wp // self.window_size,
        self.window_size, self.window_size, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(batch, hp, wp, -1)
    if hp > h or wp > w:
      x = x[:, :h, :w, :]
    return x

  def get_keep_pattern(self,
                       x: jnp.ndarray,
                       deterministic: bool):
    """Stochastic-depth (DropPath) multiplier for a residual branch.

    Returns 1.0 when `deterministic` or `drop_path` is 0. Otherwise returns a
    per-example keep mask of shape (batch, 1, ..., 1), rescaled by
    1 / (1 - drop_path) when drop_path < 1. Draws from the 'dropout' RNG.
    """
    if not deterministic and self.drop_path:
      shape = (x.shape[0],) + (1,) * (x.ndim - 1)
      drop_pattern = jax.random.bernoulli(
          self.make_rng('dropout'), self.drop_path, shape).astype(self.dtype)
      keep_pattern = (1. - drop_pattern)
      if self.drop_path < 1.:
        keep_pattern = keep_pattern / (1. - self.drop_path)
      return keep_pattern
    else:
      return 1.0

  @nn.compact
  def __call__(self, x, train=False):
    """Pre-norm attention and MLP branches, each added back residually.

    Args:
      x: (batch, H, W, dim) tokens.
      train: enables stochastic depth (requires a 'dropout' RNG).
    Returns:
      (batch, H, W, dim) tokens.
    """
    shortcut = x
    ln = functools.partial(nn.LayerNorm, epsilon=1e-6, dtype=self.dtype)
    x = ln(name='norm1')(x)
    # Window partition
    if self.window_size > 0:
      h, w = x.shape[1], x.shape[2]
      x, pad_hw = self.window_partition(x)

    x = HMAttention(
        self.dim,
        num_heads=self.num_heads,
        qkv_bias=self.qkv_bias,
        beit_like_qkv_bias=self.beit_like_qkv_bias,
        use_rel_pos=self.use_rel_pos,
        rel_pos_zero_init=self.rel_pos_zero_init,
        input_size=self.input_size if self.window_size == 0 else (
            self.window_size, self.window_size),
        kernel_init=self.kernel_init,
        dtype=self.dtype,
        name='attn')(x)
    # Reverse window partition
    if self.window_size > 0:
      x = self.window_unpartition(x, pad_hw, (h, w))

    if self.layer_scale_init_value > 0:
      gamma_1 = self.param(
          'gamma_1',
          nn.initializers.constant(self.layer_scale_init_value),
          (self.dim,))
      x = x * gamma_1
    x = shortcut + self.get_keep_pattern(x, not train) * x
    y = ln(name='norm2')(x)
    y = Mlp(
        int(self.dim * self.mlp_ratio),
        self.dim,
        kernel_init=self.kernel_init,
        dtype=self.dtype,
        name='mlp')(y)
    if self.layer_scale_init_value > 0:
      gamma_2 = self.param(
          'gamma_2',
          nn.initializers.constant(self.layer_scale_init_value),
          (self.dim,))
      y = y * gamma_2
    x = x + self.get_keep_pattern(y, not train) * y
    return x


class Neck(nn.Module):
  """Sam convolutional neck: 1x1 conv -> LayerNorm -> 3x3 conv -> LayerNorm.

  The submodules are named '0'..'3' in that order. Both convolutions are
  bias-free with `out_chans` filters and keep the spatial size.
  """
  out_chans: int = 768
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x):
    """Forward pass.

    Args:
      x: (batch_size, height, width, dim)
    Returns:
      x: (batch_size, height, width, out_chans)
    """
    conv = functools.partial(
        nn.Conv, self.out_chans, strides=(1, 1), use_bias=False,
        dtype=self.dtype)
    x = conv((1, 1), padding='VALID', name='0')(x)
    x = nn.LayerNorm(name='1')(x)
    x = conv((3, 3), padding=[(1, 1), (1, 1)], name='2')(x)
    return nn.LayerNorm(name='3')(x)


class ImageEncoderViT(nn.Module):
  """This ViT model in Sam.

  TODO(zhouxy): check difference from ViTDet:
    - neck block after transformers
    - no droppath (inference only?).

  Attributes:
    img_size (int): Input image size.
    patch_size (int): Patch size.
    in_chans (int): Number of input image channels.
    embed_dim (int): Patch embedding dimension.
    depth (int): Depth of ViT.
    num_heads (int): Number of attention heads in each ViT block.
    mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
    out_chans (int): Output channels of the neck.
    qkv_bias (bool): If True, add a learnable bias to query, key, value.
    beit_like_qkv_bias (bool): no bias for k.
    drop_path_rate (float): Stochastic depth rate of the last block; rates
      grow linearly from 0 across blocks.
    use_abs_pos (bool): If True, use absolute positional embeddings.
    use_rel_pos (bool): If True, add relative positional embeddings to the
      attention map.
    rel_pos_zero_init (bool): If True, zero initialize relative positional
      parameters.
    window_size (int): Window size for window attention blocks.
    window_block_indexes (list): Indexes for blocks using window attention.
    pretrain_img_size (int): input image size for pretraining models (not read
      by the forward pass).
    kernel_init (str): Key into `KERNEL_INIT` for Dense kernels.
    layer_scale_init_value (float): Layer-scale init for every block; <= 0
      disables layer scale.
    freeze_vit_layer (int): The output of block i is wrapped in stop_gradient
      when i + 1 < freeze_vit_layer; -1 freezes nothing.
    use_ln_pre (bool): If True, apply a LayerNorm before the blocks.
    dtype: Computation dtype of the patch embedding and blocks.
  """
  img_size: int = 1024
  patch_size: int = 16
  in_chans: int = 3
  embed_dim: int = 768
  depth: int = 12
  num_heads: int = 12
  mlp_ratio: float = 4.0
  out_chans: int = 256
  qkv_bias: bool = True
  beit_like_qkv_bias: bool = False
  drop_path_rate: float = 0.1
  use_abs_pos: bool = True
  use_rel_pos: bool = True
  rel_pos_zero_init: bool = True
  window_size: int = 14
  window_block_indexes: Any = (0, 1, 3, 4, 6, 7, 9, 10)
  pretrain_img_size: int = 224
  kernel_init: str = 'normal'
  layer_scale_init_value: float = -1.0
  freeze_vit_layer: int = -1
  use_ln_pre: bool = False
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self,
               x: jnp.ndarray,
               train: bool = False,):
    """Forward vit.

    Args:
      x: (batch_size, H, W, 3). With use_abs_pos, H and W are expected to be
        img_size so the patch grid matches `pos_embed`.
      train: bool; enables stochastic depth in the blocks.
    Returns:
      x: (batch_size, H // patch_size, W // patch_size, out_chans)
    """
    x = nn.Conv(
        self.embed_dim, (self.patch_size, self.patch_size),
        strides=(self.patch_size, self.patch_size),
        padding='VALID',
        dtype=self.dtype,
        name='patch_embed.proj')(x)
    if self.use_abs_pos:
      pos_embed = self.param(
          'pos_embed', nn.initializers.zeros,
          (1, self.img_size // self.patch_size,
           self.img_size // self.patch_size, self.embed_dim))
      x = x + pos_embed
    # Linearly increasing stochastic-depth rates, 0 -> drop_path_rate.
    # max(..., 1) avoids ZeroDivisionError when depth == 1 (rate 0 then).
    dp_rates = [
        self.drop_path_rate * i / max(self.depth - 1, 1)
        for i in range(self.depth)]
    if self.use_ln_pre:
      x = nn.LayerNorm(name='ln_pre')(x)

    for i in range(self.depth):
      x = Block(
          dim=self.embed_dim,
          num_heads=self.num_heads,
          mlp_ratio=self.mlp_ratio,
          qkv_bias=self.qkv_bias,
          beit_like_qkv_bias=self.beit_like_qkv_bias,
          drop_path=dp_rates[i],
          use_rel_pos=self.use_rel_pos,
          rel_pos_zero_init=self.rel_pos_zero_init,
          window_size=self.window_size if i in self.window_block_indexes else 0,
          input_size=(
              self.img_size // self.patch_size,
              self.img_size // self.patch_size),
          kernel_init=self.kernel_init,
          dtype=self.dtype,
          layer_scale_init_value=self.layer_scale_init_value,
          name=f'blocks.{i}',
          )(x, train=train)
      if i + 1 < self.freeze_vit_layer:
        x = jax.lax.stop_gradient(x)

    # NOTE(review): `dtype` is not forwarded here, so the neck always uses
    # Neck's default dtype (float32).
    x = Neck(out_chans=self.out_chans, name='neck')(x)
    return x


In [None]:
#@title Jax transformer
"""Sam transformer for running cross-attention."""

import math
from typing import Any

import flax.linen as nn
import jax.numpy as jnp


class TwoWayTransformer(nn.Module):
  """Transformer alternating attention between prompt tokens and image tokens.

  Attributes:
    depth: Number of TwoWayAttentionBlock layers.
    embedding_dim: Channel dimension of the tokens and image embedding.
    num_heads: Attention heads in every attention layer.
    mlp_dim: Hidden width of each block's MLP.
    activation: Activation used inside each block's MLP.
    attention_downsample_rate: Internal-dim reduction for cross attentions.
  """

  depth: int = 2
  embedding_dim: int = 256
  num_heads: int = 8
  mlp_dim: int = 2048
  activation: Any = nn.relu
  attention_downsample_rate: int = 2

  def setup(self):
    # Only the first block runs self-attention without the query PE.
    self.layers = [
        TwoWayAttentionBlock(
            embedding_dim=self.embedding_dim,
            num_heads=self.num_heads,
            mlp_dim=self.mlp_dim,
            activation=self.activation,
            attention_downsample_rate=self.attention_downsample_rate,
            skip_first_layer_pe=(idx == 0),
            name=f'layers.{idx}')
        for idx in range(self.depth)
    ]

    self.final_attn_token_to_image = Attention(
        self.embedding_dim, self.num_heads, self.attention_downsample_rate,
        name='final_attn_token_to_image')
    self.norm_final_attn = nn.LayerNorm(epsilon=1e-5, name='norm_final_attn')

  def __call__(self, image_embedding, image_pe, point_embedding):
    """Forward pass.

    Args:
      image_embedding: (batch_size, h, w, embedding_dim)
      image_pe: (batch_size, h, w, embedding_dim)
      point_embedding: (batch_size, num_points, embedding_dim)
    Returns:
      queries: (batch_size, num_points, embedding_dim) updated prompt tokens.
      keys: (batch_size, h * w, embedding_dim) updated image tokens.
    """
    batch_size = image_embedding.shape[0]
    channels = image_embedding.shape[3]
    # Flatten the spatial grid into a single token axis.
    keys = image_embedding.reshape((batch_size, -1, channels))
    key_pe = image_pe.reshape((batch_size, -1, channels))
    queries = point_embedding

    for layer in self.layers:
      queries, keys = layer(
          queries=queries,
          keys=keys,
          query_pe=point_embedding,
          key_pe=key_pe,
      )

    # Final token -> image attention, residual add, then LayerNorm.
    attn_out = self.final_attn_token_to_image(
        q=queries + point_embedding, k=keys + key_pe, v=keys)
    queries = self.norm_final_attn(queries + attn_out)

    return queries, keys


class TwoWayAttentionBlock(nn.Module):
  """Transformer block updating both prompt tokens and image tokens.

  Four stages, each followed by LayerNorm: token self-attention, token->image
  cross-attention, token MLP, and image->token cross-attention.

  Attributes:
    embedding_dim: Channel dimension of tokens.
    num_heads: Attention heads per attention layer.
    mlp_dim: Hidden width of the MLP.
    activation: MLP activation.
    attention_downsample_rate: Internal-dim reduction for cross attentions.
    skip_first_layer_pe: If True, self-attention uses the raw queries with no
      positional encoding and no residual connection.
  """

  embedding_dim: int
  num_heads: int
  mlp_dim: int = 2048
  activation: Any = nn.relu
  attention_downsample_rate: int = 2
  skip_first_layer_pe: bool = False

  def setup(self):
    self.self_attn = Attention(
        self.embedding_dim, self.num_heads, name='self_attn')
    self.norm1 = nn.LayerNorm(epsilon=1e-5, name='norm1')

    self.cross_attn_token_to_image = Attention(
        self.embedding_dim, self.num_heads, self.attention_downsample_rate,
        name='cross_attn_token_to_image')
    self.norm2 = nn.LayerNorm(epsilon=1e-5, name='norm2')

    self.mlp = MLPBlock(
        self.embedding_dim, self.mlp_dim, self.activation,
        name='mlp')
    self.norm3 = nn.LayerNorm(epsilon=1e-5, name='norm3')

    self.norm4 = nn.LayerNorm(epsilon=1e-5, name='norm4')
    self.cross_attn_image_to_token = Attention(
        self.embedding_dim, self.num_heads, self.attention_downsample_rate,
        name='cross_attn_image_to_token')

  def __call__(self, queries, keys, query_pe, key_pe):
    """Forward two-way attention block.

    Args:
      queries: (batch_size, query_tokens, embedding_dim)
      keys: (batch_size, key_tokens, embedding_dim)
      query_pe: (batch_size, query_tokens, embedding_dim)
      key_pe: (batch_size, key_tokens, embedding_dim)
    Returns:
      queries: (batch_size, query_tokens, embedding_dim)
      keys: (batch_size, key_tokens, embedding_dim)
    """
    # 1) Self-attention among the prompt tokens.
    if self.skip_first_layer_pe:
      # Output replaces the queries outright (no PE, no residual).
      queries = self.self_attn(q=queries, k=queries, v=queries)
    else:
      q_in = queries + query_pe
      queries = queries + self.self_attn(q=q_in, k=q_in, v=queries)
    queries = self.norm1(queries)

    # 2) Tokens attend to the image embedding.
    q_in = queries + query_pe
    k_in = keys + key_pe
    queries = self.norm2(
        queries + self.cross_attn_token_to_image(q=q_in, k=k_in, v=keys))

    # 3) Token-wise MLP.
    queries = self.norm3(queries + self.mlp(queries))

    # 4) Image tokens attend to the updated prompt tokens.
    q_in = queries + query_pe
    k_in = keys + key_pe
    keys = self.norm4(
        keys + self.cross_attn_image_to_token(q=k_in, k=q_in, v=queries))

    return queries, keys


class Attention(nn.Module):
  """Multi-head attention with an optionally reduced internal dimension.

  q, k and v are projected to `embedding_dim // downsample_rate` channels,
  split into `num_heads` heads, attended, and projected back to
  `embedding_dim`.
  """
  embedding_dim: int
  num_heads: int
  downsample_rate: int = 1

  def setup(self):
    self.internal_dim = self.embedding_dim // self.downsample_rate
    assert self.internal_dim % self.num_heads == 0, (
        'num_heads must divide embedding_dim.')

    self.q_proj = nn.Dense(self.internal_dim, name='q_proj')
    self.k_proj = nn.Dense(self.internal_dim, name='k_proj')
    self.v_proj = nn.Dense(self.internal_dim, name='v_proj')
    self.out_proj = nn.Dense(self.embedding_dim, name='out_proj')

  def _separate_heads(self, x):
    """(B, N, C) -> (B, num_heads, N, C // num_heads)."""
    batch, tokens, channels = x.shape
    split = x.reshape(batch, tokens, self.num_heads, channels // self.num_heads)
    return split.transpose((0, 2, 1, 3))

  def _recombine_heads(self, x):
    """(B, num_heads, N, C_per_head) -> (B, N, num_heads * C_per_head)."""
    batch, heads, tokens, per_head = x.shape
    return x.transpose((0, 2, 1, 3)).reshape(batch, tokens, heads * per_head)

  def __call__(self, q, k, v):
    """Forward attention module.

    Args:
      q: (batch_size, query_tokens, embedding_dim)
      k: (batch_size, key_tokens, embedding_dim)
      v: (batch_size, key_tokens, embedding_dim)
    Returns:
      out: (batch_size, query_tokens, embedding_dim)
    """
    q_heads = self._separate_heads(self.q_proj(q))
    k_heads = self._separate_heads(self.k_proj(k))
    v_heads = self._separate_heads(self.v_proj(v))

    # Scaled dot-product attention, normalized over the key axis.
    logits = q_heads @ jnp.swapaxes(k_heads, -1, -2)
    logits = logits / math.sqrt(q_heads.shape[-1])
    weights = nn.softmax(logits, axis=-1)

    return self.out_proj(self._recombine_heads(weights @ v_heads))


class MLPBlock(nn.Module):
  """Dense ('lin1', mlp_dim) -> activation -> Dense ('lin2', embedding_dim)."""
  embedding_dim: int
  mlp_dim: int
  activation: Any = nn.relu

  @nn.compact
  def __call__(self, x):
    hidden = self.activation(nn.Dense(self.mlp_dim, name='lin1')(x))
    return nn.Dense(self.embedding_dim, name='lin2')(hidden)


In [None]:
#@title Jax mask decoder.
r"""Sam mask decoder.

Pytorch reference:

https://github.com/facebookresearch/segment-anything/blob/HEAD/\
segment_anything/modeling/mask_decoder.py

"""

import flax.linen as nn
import jax.numpy as jnp
# from scenic.projects.segment_anything.modeling import transformer


class MaskDecoder(nn.Module):
  """Sam mask decoder.

  Predicts `num_multimask_outputs + 1` mask candidates and an IoU estimate for
  each from an image embedding plus sparse and dense prompt embeddings.

  Attributes:
    transformer_dim: Channel dimension of the image embedding and tokens.
    num_multimask_outputs: Masks returned when `multimask_output` is True; one
      extra single-mask token is always predicted as well.
    iou_head_depth: Number of Dense layers in the IoU prediction head.
    iou_head_hidden_dim: Hidden width of the IoU prediction head only.
  """

  transformer_dim: int = 256
  num_multimask_outputs: int = 3
  iou_head_depth: int = 3
  iou_head_hidden_dim: int = 256

  def setup(self):
    self.iou_token = self.param(
        'iou_token.weight',
        nn.initializers.normal(stddev=1.),
        (1, self.transformer_dim))
    self.mask_tokens = self.param(
        'mask_tokens.weight',
        nn.initializers.normal(stddev=1.),
        (self.num_multimask_outputs + 1, self.transformer_dim))
    self.output_upscaling = OutputScaling(
        transformer_dim=self.transformer_dim, name='output_upscaling')

    # One hypernetwork MLP per mask token. Each maps a token to the
    # transformer_dim // 8 channels of the upscaled embedding. The hidden width
    # is transformer_dim, as in the PyTorch decoder. It previously used
    # iou_head_hidden_dim, which only matched while both were 256.
    self.output_hypernework_mlps = [
        MLP(hidden_dim=self.transformer_dim,
            output_dim=self.transformer_dim // 8, num_layers=3,
            name=f'output_hypernetworks_mlps.{i}',
           ) for i in range(self.num_multimask_outputs + 1)]

    self.iou_prediction_head = MLP(
        hidden_dim=self.iou_head_hidden_dim,
        output_dim=self.num_multimask_outputs + 1,
        num_layers=self.iou_head_depth,
        name='iou_prediction_head')

    self.transformer = TwoWayTransformer(name='transformer')

  def predict_masks(
      self, image_embeddings, image_pe,
      sparse_prompt_embeddings, dense_prompt_embeddings):
    """Predict masks for a single image.

    Args:
      image_embeddings: (H, W, embed_dim)
      image_pe: (H, W, embed_dim)
      sparse_prompt_embeddings: (num_prompts, num_points, embed_dim)
      dense_prompt_embeddings: (num_prompts, H, W, embed_dim)
    Returns:
      masks: (num_prompts, num_multimask_outputs + 1, h', w')
      iou_pred: (num_prompts, num_multimask_outputs + 1)
    """
    output_tokens = jnp.concatenate(
        [self.iou_token, self.mask_tokens],
        axis=0)  # (num_multimask_outputs + 2, transformer_dim)
    num_prompts = sparse_prompt_embeddings.shape[0]
    output_tokens = jnp.broadcast_to(
        output_tokens[None],
        (num_prompts, self.num_multimask_outputs + 2, self.transformer_dim))
    tokens = jnp.concatenate(
        [output_tokens, sparse_prompt_embeddings], axis=1,
    )  # (num_prompts, num_multimask_outputs + 2 + num_points, embed_dim)

    # Give each prompt its own copy of the single image's embedding.
    src = jnp.repeat(
        image_embeddings[None], tokens.shape[0],
        axis=0)  # (num_prompts, H, W, D)
    src = src + dense_prompt_embeddings
    pos_src = jnp.repeat(
        image_pe[None], tokens.shape[0], axis=0)  # (num_prompts, H, W, D)
    num_prompts, h, w, d = src.shape

    hs, src = self.transformer(src, pos_src, tokens)
    # Token 0 is the IoU token; the next num_multimask_outputs + 1 are masks.
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1: (1 + self.num_multimask_outputs + 1), :]

    src = src.reshape(num_prompts, h, w, d)
    upscaled_embedding = self.output_upscaling(src)  # (num_prompts, h', w', d)
    hyper_in_list = []
    for i in range(self.num_multimask_outputs + 1):
      hyper_in_list.append(
          self.output_hypernework_mlps[i](
              mask_tokens_out[:, i, :])  # (num_prompts, d)
      )
    hyper_in = jnp.stack(hyper_in_list, axis=1)  # (num_prompts, num_masks, d)
    num_prompts, h, w, d = upscaled_embedding.shape
    masks = hyper_in @ upscaled_embedding.reshape(
        num_prompts, h * w, d).transpose(
            0, 2, 1)  # (num_prompts, num_masks, h'w')
    masks = masks.reshape(num_prompts, self.num_multimask_outputs + 1, h, w)

    iou_pred = self.iou_prediction_head(iou_token_out)
    return masks, iou_pred

  @nn.compact
  def __call__(
      self, image_embeddings, image_pe,
      sparse_prompt_embeddings, dense_prompt_embeddings,
      multimask_output: bool = True):
    """Forward model for a single image.

    Args:
      image_embeddings: (H, W, embed_dim)
      image_pe: (H, W, embed_dim)
      sparse_prompt_embeddings: (num_prompts, num_points, embed_dim)
      dense_prompt_embeddings: (num_prompts, H, W, embed_dim)
      multimask_output: bool; if True return the multimask candidates
        (indices 1:), otherwise only the single-mask output (index 0).
    Returns:
      masks: (num_prompts, num_multimask_outputs, h', w'),
        num_multimask_outputs = 3 if multimask_output is True, otherwise 1.
      iou_pred: (num_prompts, num_multimask_outputs)
    """
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )
    if multimask_output:
      return masks[:, 1:], iou_pred[:, 1:]
    else:
      return masks[:, :1], iou_pred[:, :1]


class MLP(nn.Module):
  """Stack of Dense layers with ReLU between them.

  The first `num_layers - 1` layers project to `hidden_dim` and are followed
  by ReLU; the final layer projects to `output_dim` with no activation.
  Layers are named 'layers.0' ... 'layers.{num_layers - 1}' (presumably to
  mirror the PyTorch module's `layers` list so checkpoint names match).
  """
  hidden_dim: int
  output_dim: int
  num_layers: int

  @nn.compact
  def __call__(self, x):
    last_index = self.num_layers - 1
    for layer_index in range(last_index):
      hidden_layer = nn.Dense(self.hidden_dim, name=f'layers.{layer_index}')
      x = nn.relu(hidden_layer(x))
    output_layer = nn.Dense(self.output_dim, name=f'layers.{last_index}')
    return output_layer(x)


class OutputScaling(nn.Module):
  """Upsamples decoder features 4x with two stride-2 transposed convolutions.

  Stage 1: ConvTranspose to transformer_dim // 4 channels, LayerNorm, exact
  GELU. Stage 2: ConvTranspose to transformer_dim // 8 channels, exact GELU.
  Features are channels-last.

  The submodule names '0', '1' and '3' must stay fixed because parameters are
  matched to the PyTorch checkpoint by name (presumably the indices of the
  PyTorch `nn.Sequential`, index 2 being the parameter-free activation).
  """
  transformer_dim: int

  @nn.compact
  def __call__(self, x):
    def _upsample2x(num_features, layer_name):
      # A 2x2 kernel with stride 2 doubles each spatial dimension.
      return nn.ConvTranspose(
          num_features, kernel_size=(2, 2), strides=(2, 2),
          transpose_kernel=True, name=layer_name)

    hidden = _upsample2x(self.transformer_dim // 4, '0')(x)
    hidden = nn.gelu(nn.LayerNorm(name='1')(hidden), approximate=False)
    hidden = _upsample2x(self.transformer_dim // 8, '3')(hidden)
    return nn.gelu(hidden, approximate=False)


In [None]:
#@title Jax prompt-encoder
r"""Sam prompt encoder.

Pytorch reference:

https://github.com/facebookresearch/segment-anything/blob/HEAD/\
segment_anything/modeling/prompt_encoder.py

"""

from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp


class PromptEncoder(nn.Module):
  """Sam prompt encoder for points, boxes and masks.

  Sparse prompts (points and/or boxes) become per-prompt token embeddings.
  A dense prompt (mask) becomes a per-prompt feature map; when no mask is
  given, the learned `no_mask_embed` is broadcast over the image-embedding
  grid instead.

  Attributes:
    embed_dim: channel dimension of every produced embedding.
    image_embedding_size: (H, W) of the image-embedding grid.
    input_image_size: (height, width) of the model input; used to normalize
      absolute point/box coordinates.
    num_point_embeddings: number of learned point-type embeddings. Index 0 is
      added to negative points, 1 to positive points, 2 and 3 to the top-left
      and bottom-right box corners.
    mask_in_chans: hidden channels of the mask-downscaling network.
  """

  embed_dim: int = 256
  image_embedding_size: Tuple[int, int] = (1024 // 16, 1024 // 16)
  input_image_size: Tuple[int, int] = (1024, 1024)
  num_point_embeddings: int = 4  # pos/neg point + 2 box corners
  mask_in_chans: int = 16

  def setup(self):
    self.pe_layer = PositionEmbeddingRandom(
        self.embed_dim // 2, name='pe_layer')
    point_embeddings = []
    # TODO(zhouxy): check if `nn.initializers.normal(stddev=1.)` is the same as
    # pytorch nn.Embedding default initialization.
    for i in range(self.num_point_embeddings):
      point_embeddings.append(self.param(
          f'point_embeddings.{i}.weight',
          nn.initializers.normal(stddev=1.),
          (1, self.embed_dim)))
    self.point_embeddings = point_embeddings
    del point_embeddings
    self.not_a_point_embed = self.param(
        'not_a_point_embed.weight',
        nn.initializers.normal(stddev=1.),
        (1, self.embed_dim))
    self.no_mask_embed = self.param(
        'no_mask_embed.weight',
        nn.initializers.normal(stddev=1.),
        (1, self.embed_dim))
    self.mask_downscaling = MaskDownScaling(
        mask_in_chans=self.mask_in_chans, embed_dim=self.embed_dim,
        name='mask_downscaling')

  def get_dense_pe(self):
    """Returns the (H, W, embed_dim) positional encoding of the image grid."""
    return self.pe_layer(self.image_embedding_size)

  def _embed_points(self, points, labels, pad):
    """Embed points.

    Args:
      points: (num_prompts, num_points, 2). Absolute (x, y) coordinates.
      labels: (num_prompts, num_points). 1 positive, 0 negative, -1 ignored.
      pad: bool; if True append one ignored (label -1) point per prompt.
    Returns:
      point_embeddings: (num_prompts, num_points (+1 if pad), embed_dim)
    """
    # Shift to center of pixel following:
    # https://github.com/facebookresearch/segment-anything/blob/main/\
    # segment_anything/modeling/prompt_encoder.py#L80
    points = points + 0.5
    if pad:
      padding_point = jnp.zeros((points.shape[0], 1, 2), dtype=jnp.float32)
      padding_label = - jnp.ones((labels.shape[0], 1), dtype=jnp.float32)
      points = jnp.concatenate([points, padding_point], axis=1)
      labels = jnp.concatenate([labels, padding_label], axis=1)
    point_embedding = self.pe_layer.forward_with_coords(
        points, self.input_image_size)  # (num_prompts, num_points, embed_dim)
    # Ignored points: replace the positional encoding by `not_a_point_embed`.
    ignored_points = labels[..., None] == -1  # (num_prompts, num_points, 1)
    point_embedding = point_embedding * (1 - ignored_points) + (
        self.not_a_point_embed[None] * ignored_points)
    neg_points = labels[..., None] == 0  # (num_prompts, num_points, 1)
    point_embedding += neg_points * self.point_embeddings[0][None]
    pos_points = labels[..., None] == 1  # (num_prompts, num_points, 1)
    point_embedding += pos_points * self.point_embeddings[1][None]
    return point_embedding

  def _embed_boxes(self, boxes):
    """Embed boxes given as (num_prompts, 4) absolute (x0, y0, x1, y1).

    Returns:
      (num_prompts, 2, embed_dim): top-left and bottom-right corner tokens.
    """
    boxes = boxes + 0.5  # Shift to pixel centers, as for points.
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(
        coords, self.input_image_size)  # (num_prompts, 2, embed_dim)
    lt_emb = corner_embedding[:, 0, :] + self.point_embeddings[2]
    rb_emb = corner_embedding[:, 1, :] + self.point_embeddings[3]
    corner_embedding = jnp.stack(
        [lt_emb, rb_emb], axis=1)  # (num_prompts, 2, embed_dim)
    return corner_embedding

  def _embed_masks(self, masks):
    """Downscale dense mask prompts into embed_dim-channel feature maps."""
    mask_embedding = self.mask_downscaling(masks)
    return mask_embedding

  @nn.compact
  def __call__(self, points, point_labels, boxes=None, masks=None):
    """Embeds the prompts of one image.

    Points and boxes may be given together; their tokens are concatenated
    (point tokens first), following the PyTorch reference implementation.

    Args:
      points: (num_prompts, num_points, 2) absolute (x, y) coordinates, or
        None.
      point_labels: (num_prompts, num_points): labels of each point. 1 means
        positive points, 0 means negative points (shouldn't be included in the
        mask), and -1 means padded/ ignored points.
      boxes: (num_prompts, 4) absolute (x0, y0, x1, y1) boxes, or None.
      masks: dense mask prompts or None. NOTE(review): MaskDownScaling is
        channels-last, so this presumably needs a trailing channel axis, e.g.
        (num_prompts, H_in, W_in, 1) — confirm.
    Returns:
      sparse_embeddings: (num_prompts, num_tokens, embed_dim). num_tokens is
        num_points + 1 for points only (one padding token), num_points + 2
        for points with boxes, 2 for boxes only, 0 for neither.
      dense_embeddings: mask embedding, or, when masks is None,
        `no_mask_embed` broadcast to (num_prompts, H, W, embed_dim) with
        (H, W) = image_embedding_size.
    Raises:
      ValueError: if points, boxes and masks are all None.
    """
    if points is not None:
      num_prompts = points.shape[0]
    elif boxes is not None:
      num_prompts = boxes.shape[0]
    elif masks is not None:
      num_prompts = masks.shape[0]
    else:
      raise ValueError('At least one of points, boxes or masks is required.')

    sparse_parts = []
    if points is not None:
      # The padding point is only added when no box tokens follow.
      sparse_parts.append(self._embed_points(
          points, point_labels, pad=(boxes is None)))
    if boxes is not None:
      sparse_parts.append(self._embed_boxes(boxes))
    if not sparse_parts:
      sparse_embeddings = jnp.zeros(
          (num_prompts, 0, self.embed_dim), dtype=jnp.float32)
    elif len(sparse_parts) == 1:
      sparse_embeddings = sparse_parts[0]
    else:
      sparse_embeddings = jnp.concatenate(sparse_parts, axis=1)

    if masks is not None:
      dense_embeddings = self._embed_masks(masks)
    else:
      dense_embeddings = jnp.broadcast_to(
          self.no_mask_embed[:, None, None],
          (num_prompts, self.image_embedding_size[0],
           self.image_embedding_size[1], self.embed_dim,)
      )
    return sparse_embeddings, dense_embeddings


class PositionEmbeddingRandom(nn.Module):
  """Fourier-feature positional encoding with random Gaussian frequencies.

  Coordinates normalized to [0, 1] are mapped to [-1, 1], projected by a
  (2, num_pos_feats) Gaussian matrix and encoded as [sin, cos], giving
  2 * num_pos_feats output channels. The matrix is a module parameter but is
  wrapped in `stop_gradient`, so no gradient flows to it through here.

  Attributes:
    num_pos_feats: number of random frequencies (half the output channels).
    scale: stddev of the Gaussian initializer; None or <= 0 means 1.0.
  """

  num_pos_feats: int
  scale: Optional[float] = None

  def setup(self):
    if self.scale is None or self.scale <= 0.0:
      init_stddev = 1.0
    else:
      init_stddev = self.scale
    self.positional_encoding_gaussian_matrix = self.param(
        'positional_encoding_gaussian_matrix',
        nn.initializers.normal(stddev=init_stddev),
        (2, self.num_pos_feats),
    )

  def _pe_encoding(self, coords):
    """Encodes (..., 2) coordinates in [0, 1] into (..., 2 * num_pos_feats)."""
    frozen_freqs = jax.lax.stop_gradient(
        self.positional_encoding_gaussian_matrix)
    phases = 2 * jnp.pi * ((2 * coords - 1) @ frozen_freqs)
    return jnp.concatenate([jnp.sin(phases), jnp.cos(phases)], axis=-1)

  @nn.compact
  def __call__(self, size):
    """Dense encoding of every pixel center of an (h, w) grid.

    Args:
      size: (h, w) of the grid.
    Returns:
      pe: (h, w, 2 * num_pos_feats)
    """
    h, w = size
    # Pixel centers (index + 0.5) normalized by the grid extent; the values
    # are exactly those of a cumulative sum of ones minus 0.5.
    ys = (jnp.arange(h, dtype=jnp.float32) + 0.5) / h
    xs = (jnp.arange(w, dtype=jnp.float32) + 0.5) / w
    grid_x = jnp.broadcast_to(xs[None, :], (h, w))
    grid_y = jnp.broadcast_to(ys[:, None], (h, w))
    return self._pe_encoding(jnp.stack([grid_x, grid_y], axis=-1))

  def forward_with_coords(self, coords_input, image_size):
    """Encodes sparse absolute (x, y) coordinates.

    Args:
      coords_input: (num_prompts, num_points, 2) absolute (x, y) coordinates.
      image_size: (height, width); y is divided by height and x by width.
    Returns:
      embedding: (num_prompts, num_points, self.num_pos_feats * 2)
    """
    normalized = jnp.stack(
        [coords_input[:, :, 0] / image_size[1],
         coords_input[:, :, 1] / image_size[0]],
        axis=-1)
    return self._pe_encoding(normalized)


class MaskDownScaling(nn.Module):
  """Downscales a dense mask prompt 4x into `embed_dim` channels.

  Two 2x2 stride-2 convolutions, each followed by LayerNorm and exact GELU,
  then a 1x1 projection to `embed_dim`. Inputs are channels-last, e.g.
  (N, H, W, 1). The submodule names ('0', '1', '3', '4', '6') must stay
  fixed: checkpoint conversion matches parameters by name.

  Attributes:
    mask_in_chans: channels after the second conv (the first uses 1/4 of it).
    embed_dim: output channels.
  """
  mask_in_chans: int = 16
  embed_dim: int = 256

  @nn.compact
  def __call__(self, x):
    # padding='VALID' reproduces PyTorch Conv2d's default padding=0. Flax's
    # default 'SAME' is identical for even spatial sizes, but for odd sizes it
    # pads and rounds the output size up, diverging from the reference.
    x = nn.Conv(
        self.mask_in_chans // 4, kernel_size=(2, 2), strides=(2, 2),
        padding='VALID', name='0')(x)
    x = nn.LayerNorm(name='1')(x)
    x = nn.gelu(x, approximate=False)
    x = nn.Conv(
        self.mask_in_chans, kernel_size=(2, 2), strides=(2, 2),
        padding='VALID', name='3')(x)
    x = nn.LayerNorm(name='4')(x)
    x = nn.gelu(x, approximate=False)
    # 1x1 projection to the prompt-embedding width.
    x = nn.Conv(
        self.embed_dim, kernel_size=(1, 1), strides=(1, 1),
        name='6')(x)
    return x


In [None]:
#@title Util functions
"""Util functions for Segment Anything models."""

import jax.numpy as jnp
import numpy as np
# from scenic.projects.segment_anything.modeling import nms as nms_lib


def build_point_grid(points_per_side):
  """Returns a regular grid of points inside the unit square.

  Points sit at the centers of a points_per_side x points_per_side tiling of
  [0, 1] x [0, 1], in row-major order (x varies fastest).

  Args:
    points_per_side: number of points along each axis.
  Returns:
    (points_per_side ** 2, 2) array of normalized (x, y) coordinates.
  """
  half_cell = 1. / (2 * points_per_side)
  centers = jnp.linspace(half_cell, 1 - half_cell, points_per_side)
  grid_x, grid_y = jnp.meshgrid(centers, centers)  # 'xy' indexing
  return jnp.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)


def batched_mask_to_box(masks):
  """Convert binary masks (..., h, w) to tight boxes (..., 4).

  Args:
    masks: boolean (or 0/1) masks with any number of leading dims, typically
      (n, h, w).
  Returns:
    Boxes [left, top, right, bottom] in inclusive pixel coordinates, one per
    mask. Empty masks give [0, 0, 0, 0]. When there are no masks at all, a
    float32 zero array of shape (..., 4) ((0, 4) for (0, h, w) input) is
    returned.
  """
  batch_shape = masks.shape[:-2]
  if masks.shape[0] == 0 or 0 in batch_shape:
    out_shape = batch_shape + (4,) if batch_shape else (0, 4)
    return jnp.zeros(out_shape, dtype=jnp.float32)

  h, w = masks.shape[-2:]
  in_height = jnp.max(masks, axis=-1)  # (..., h)
  in_height_coords = in_height * jnp.arange(h)[None]  # (..., h)
  bottom_edges = jnp.max(in_height_coords, axis=-1)  # (...,)
  # Mark "0" as "h" so that we can take min.
  in_height_coords = in_height_coords + h * (1 - in_height)  # (..., h)
  top_edges = jnp.min(in_height_coords, axis=-1)  # (...,)

  in_width = jnp.max(masks, axis=-2)  # (..., w)
  in_width_coords = in_width * jnp.arange(w)[None]  # (..., w)
  right_edges = jnp.max(in_width_coords, axis=-1)  # (...,)
  # Mark "0" as "w" so that we can take min.
  in_width_coords = in_width_coords + w * (1 - in_width)  # (..., w)
  left_edges = jnp.min(in_width_coords, axis=-1)  # (...,)

  # mark empty mask as [0, 0, 0, 0]
  empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
  out = jnp.stack(
      [left_edges, top_edges, right_edges, bottom_edges], axis=-1)  # (..., 4)
  # `[..., None]` (not `[:, None]`) so any number of leading dims broadcasts.
  out = out * (1 - empty_filter)[..., None]
  return out


def batched_mask_to_box_np(masks):
  """NumPy version: convert binary masks (..., h, w) to boxes (..., 4).

  Args:
    masks: boolean (or 0/1) numpy masks with any number of leading dims,
      typically (n, h, w).
  Returns:
    Boxes [left, top, right, bottom] in inclusive pixel coordinates, one per
    mask. Empty masks give [0, 0, 0, 0]. When there are no masks at all, a
    float32 zero array of shape (..., 4) ((0, 4) for (0, h, w) input) is
    returned.
  """
  batch_shape = masks.shape[:-2]
  if masks.shape[0] == 0 or 0 in batch_shape:
    out_shape = batch_shape + (4,) if batch_shape else (0, 4)
    return np.zeros(out_shape, dtype=np.float32)

  h, w = masks.shape[-2:]
  in_height = np.max(masks, axis=-1)  # (..., h)
  in_height_coords = in_height * np.arange(h)[None]  # (..., h)
  bottom_edges = np.max(in_height_coords, axis=-1)  # (...,)
  # Mark "0" as "h" so that we can take min.
  in_height_coords = in_height_coords + h * (1 - in_height)  # (..., h)
  top_edges = np.min(in_height_coords, axis=-1)  # (...,)

  in_width = np.max(masks, axis=-2)  # (..., w)
  in_width_coords = in_width * np.arange(w)[None]  # (..., w)
  right_edges = np.max(in_width_coords, axis=-1)  # (...,)
  # Mark "0" as "w" so that we can take min.
  in_width_coords = in_width_coords + w * (1 - in_width)  # (..., w)
  left_edges = np.min(in_width_coords, axis=-1)  # (...,)

  # mark empty mask as [0, 0, 0, 0]
  empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
  out = np.stack(
      [left_edges, top_edges, right_edges, bottom_edges], axis=-1)  # (..., 4)
  # `[..., None]` (not `[:, None]`) so any number of leading dims broadcasts.
  out = out * (1 - empty_filter)[..., None]
  return out


def calculate_stability_score(
    mask_logits, mask_threshold, stability_score_offset):
  """Ratio of mask areas at a raised vs. a lowered binarization threshold.

  A value close to 1 means the mask barely changes when the threshold moves
  by +/- stability_score_offset.

  Args:
    mask_logits: (..., h, w) mask logits.
    mask_threshold: nominal binarization threshold.
    stability_score_offset: how far the threshold is moved each way.
  Returns:
    (...,) area(logits > t + offset) / area(logits > t - offset).
  """
  def _area_above(threshold):
    return (mask_logits > threshold).sum(axis=(-2, -1))

  strict_area = _area_above(mask_threshold + stability_score_offset)
  loose_area = _area_above(mask_threshold - stability_score_offset)
  return strict_area / loose_area


def nms(boxes, scores, iou_threshold, num_outputs=100):
  """Greedy box non-maximum suppression with a fixed-size output.

  The previous implementation called `nms_lib.non_max_suppression_padded`,
  but the `nms_lib` import is commented out in this notebook, so every call
  raised NameError. This is a self-contained JAX version with the same
  signature. It is jit/vmap friendly because the output length is static.

  Args:
    boxes: (n, 4) boxes as [x0, y0, x1, y1].
    scores: (n,) scores; higher is better.
    iou_threshold: a box is suppressed when its IoU with an already selected
      box is strictly greater than this value.
    num_outputs: length of the returned index vector.
  Returns:
    (num_outputs,) int32 indices into `boxes`, in descending score order.
    When fewer than `num_outputs` boxes survive, the remaining slots are 0.
  """
  from jax import lax  # Local import: this cell only imports jax.numpy.

  boxes = jnp.asarray(boxes, dtype=jnp.float32)
  scores = jnp.asarray(scores, dtype=jnp.float32)
  num_boxes = boxes.shape[0]
  if num_boxes == 0:
    return jnp.zeros((num_outputs,), dtype=jnp.int32)

  x0, y0, x1, y1 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
  areas = (x1 - x0) * (y1 - y0)
  box_ids = jnp.arange(num_boxes)

  def _iou_with(idx):
    """IoU of box `idx` against every box, shape (n,)."""
    inter_w = jnp.maximum(
        jnp.minimum(x1[idx], x1) - jnp.maximum(x0[idx], x0), 0.0)
    inter_h = jnp.maximum(
        jnp.minimum(y1[idx], y1) - jnp.maximum(y0[idx], y0), 0.0)
    inter = inter_w * inter_h
    union = areas[idx] + areas - inter
    # Guard degenerate (zero-area) boxes against 0 / 0.
    return inter / jnp.maximum(union, 1e-9)

  def _step(i, state):
    keep, alive = state
    idx = jnp.argmax(jnp.where(alive, scores, -jnp.inf)).astype(jnp.int32)
    has_candidate = alive[idx]  # False once every box has been removed.
    keep = keep.at[i].set(jnp.where(has_candidate, idx, 0))
    suppressed = (_iou_with(idx) > iou_threshold) | (box_ids == idx)
    return keep, alive & ~suppressed

  keep = jnp.zeros((num_outputs,), dtype=jnp.int32)
  alive = jnp.ones((num_boxes,), dtype=bool)
  keep, _ = lax.fori_loop(0, num_outputs, _step, (keep, alive))
  return keep


In [None]:
#@title Jax SAM model
r"""Segment Anything Model.

Pytorch reference:

https://github.com/facebookresearch/segment-anything/blob/HEAD/\
segment_anything/modeling/sam.py

"""
from typing import Any

from flax import linen as nn
# import jax
import jax.numpy as jnp
import ml_collections
import dataclasses
# from scenic.projects.segment_anything.modeling.image_encoder import ImageEncoderViT
# from scenic.projects.segment_anything.modeling.mask_decoder import MaskDecoder
# from scenic.projects.segment_anything.modeling.prompt_encoder import PromptEncoder

# Default per-channel RGB normalization statistics for Sam.preprocess, in
# [0, 255] pixel units. Numerically they equal (0.485, 0.456, 0.406) * 255 and
# (0.229, 0.224, 0.225) * 255.
PIXEL_MEAN = (123.675, 116.28, 103.53)
PIXEL_STD = (58.395, 57.12, 57.375)

class Sam(nn.Module):
  """Segment anything model.

  Default parameters following
  https://github.com/facebookresearch/segment-anything/blob/main/
  segment_anything/automatic_mask_generator.py#L35

  Attributes:
    mask_threshold: threshold to convert output logits to binary masks.
    pixel_mean: per-channel RGB mean subtracted in `preprocess`.
    pixel_std: per-channel RGB std divided by in `preprocess`.
    max_objects: number of output objects in "segment anything" mode.
    points_per_side: number of point anchors per side in "segment anything"
      mode; the grid has points_per_side ** 2 prompts. NOTE(review): typed
      Optional, but `generate` has no code path for None — confirm.
    points_per_batch: batch size for processing point anchors; must divide
      points_per_side ** 2 (asserted in `generate`).
    pred_iou_thresh: score threshold in "segment anything" mode.
    box_nms_thresh: NMS IoU threshold.
    stability_score_thresh: threshold for filtering with a stability metric;
      the filter is skipped when this is <= 0.
    stability_score_offset: used in computing the stability metric.
    pre_nms_topk: new hyper-parameter in this implementation. Used for keeping a
      fixed shape after filtering mask predictions.
    image_encoder_args: kwargs for the ImageEncoderViT backbone.
    prompt_encoder_args: kwargs for the PromptEncoder.
    mask_decoder_args: kwargs for the MaskDecoder.
  """
  mask_threshold: float = 0.0
  pixel_mean: Any = PIXEL_MEAN
  pixel_std: Any = PIXEL_STD
  max_objects: int = 100
  points_per_side: Optional[int] = 32
  points_per_batch: int = 64
  pred_iou_thresh: float = 0.88
  box_nms_thresh: float = 0.7
  stability_score_thresh: float = 0.95
  stability_score_offset: float = 1.0
  pre_nms_topk: int = 1536
  image_encoder_args: ml_collections.ConfigDict = dataclasses.field(
      default_factory=ml_collections.ConfigDict)
  prompt_encoder_args: ml_collections.ConfigDict = dataclasses.field(
      default_factory=ml_collections.ConfigDict)
  mask_decoder_args: ml_collections.ConfigDict = dataclasses.field(
      default_factory=ml_collections.ConfigDict)

  def setup(self):
    # The submodule names become the top-level keys of the param tree.
    # pylint: disable=not-a-mapping
    self.image_encoder = ImageEncoderViT(
        **self.image_encoder_args, name='image_encoder')
    self.prompt_encoder = PromptEncoder(
        **self.prompt_encoder_args, name='prompt_encoder')
    self.mask_decoder = MaskDecoder(
        **self.mask_decoder_args, name='mask_decoder')
    # pylint: enable=not-a-mapping

  @nn.compact
  def __call__(
      self, image, point_coords, point_labels, padding_mask=None,
      image_embeddings=None, boxes=None, mask_inputs=None,
      multimask_output: bool = True, return_image_embedding: bool = False,
      upsample_mask: bool = True, return_batch_as_list: bool = True,
      train: bool = False, debug: bool = False):
    """Forward Sam model.

    Args:
      image: (batch_size, H, W, 3). Input pixels in RGB values [0, 255].
      point_coords: (batch_size, num_prompts, num_points, 2). Input point
        prompts. In absolute range [0, image.shape[1 or 2]].
      point_labels: (batch_size, num_prompts, num_points). 1: positive points;
        0: negative points. -1: padded/ ignored points.
      padding_mask: (batch_size, H, W). Indicate which pixels in the input are
        padded. 1: not padded; 0: padded. This is used to match the pytorch
        preprocessing process: normalize then pad, while in Jax we need to pad
        first.
      image_embeddings: cached image embeddings if they are provided.
        (batch_size, H', W', D). Exactly one of image / image_embeddings must
        be given.
      boxes: (batch_size, num_prompts, 4); box prompts;
      mask_inputs: (batch_size, num_prompts, 1, H, W); mask prompts.
        NOTE(review): the example cells pass channels-last masks such as
        (1, 256, 256, 1); the layout documented here may be stale — confirm.
      multimask_output: bool. If false, C = 1, otherwise,
        C = self.mask_decoder_args.num_multimask_outputs
      return_image_embedding: bool; if True, add 'image_embedding' to the
        results.
      upsample_mask: bool; If False, only return the 4x downsampled masks. This
        saves memory.
      return_batch_as_list: If True, return a list where each item is the
        results of a single image; If False, return a dict with batched results.
      train: bool
      debug: bool; unused.
    Returns:
      ret: a list (batch) of dicts, or a single dict whose values are stacked
        along a new leading batch axis when return_batch_as_list is False.
        Keys:
        'masks': (num_prompts, C, H, W) boolean masks with
          H = W = self.image_encoder.img_size; only when upsample_mask.
        'iou_predictions': (num_prompts, C). Predicted mask quality scores.
        'low_res_logits': (num_prompts, C, H', W'). The output mask of the
          mask decoder. The final masks are resized from this.
        'image_embedding': (H', W', D); only when return_image_embedding.
    """
    del debug
    msg = 'One of "image" or "image_embedding" should be provided!'
    assert image is not None or image_embeddings is not None, msg
    assert image is None or image_embeddings is None, msg
    if image_embeddings is None:
      assert image is not None
      image_embeddings = self.get_image_embeddings(
          image, padding_mask=padding_mask,
          train=train)  # (batch_size, H', W', D)

    ret = []
    # Prompts are encoded and decoded one image at a time.
    for b, curr_embedding in enumerate(image_embeddings):
      curr_point_coords = point_coords[b] if point_coords is not None else None
      curr_point_labels = point_labels[b] if point_labels is not None else None
      box_prompt = boxes[b] if boxes is not None else None
      mask_prompt = mask_inputs[b] if mask_inputs is not None else None
      sparse_embeddings, dense_embeddings = self.prompt_encoder(
          curr_point_coords, curr_point_labels,
          boxes=box_prompt, masks=mask_prompt)
      low_res_masks, iou_predictions = self.mask_decoder(
          image_embeddings=curr_embedding,
          image_pe=self.prompt_encoder.get_dense_pe(),
          sparse_prompt_embeddings=sparse_embeddings,
          dense_prompt_embeddings=dense_embeddings,
          multimask_output=multimask_output,
      )
      # Masks are upsampled to a square of the encoder's input size (the
      # padded frame), not to the original image size.
      size = self.image_encoder.img_size
      out = {
          'iou_predictions': iou_predictions,
          'low_res_logits': low_res_masks,
      }
      if upsample_mask:
        masks = self.postprocess_masks(
            low_res_masks, size, size) > self.mask_threshold
        out['masks'] = masks
      ret.append(out)
    if return_image_embedding:
      for batch_i, image_embedding in enumerate(image_embeddings):
        ret[batch_i]['image_embedding'] = image_embedding
    if not return_batch_as_list:
      ret = {k: jnp.stack([ret[i][k] for i in range(len(ret))], axis=0)
             for k in ret[0].keys()}
    return ret

  def get_image_embeddings(self, image, padding_mask=None, train=False):
    """Normalizes (batch_size, H, W, 3) images and runs the image encoder."""
    image = self.preprocess(image, padding_mask)  # (batch_size, H, W, 3)
    image_embeddings = self.image_encoder(
        image, train=train)  # (batch_size, H', W', D)
    return image_embeddings

  @staticmethod
  def postprocess_masks(masks, h, w):
    """Bilinearly resize (N, C, h', w') mask logits to (N, C, h, w)."""
    masks = jax.image.resize(
        masks, (masks.shape[0], masks.shape[1], h, w),
        method='bilinear', antialias=False)
    return masks

  @staticmethod
  def postprocess_to_orig(
      lowres_masks, unpad_size, orig_size, mask_threshold=0.0):
    """Crop padding, resize to the original image size and binarize.

    The resize runs on the first local CPU device.

    Args:
      lowres_masks: (n, h', w') mask logits in the padded model-input frame.
      unpad_size: (h, w) of the resized image before padding.
      orig_size: (H, W) of the original image.
      mask_threshold: logit threshold for binarization.
    Returns:
      masks: (n, H, W) boolean masks.
      boxes: (n, 4) numpy boxes computed with `batched_mask_to_box_np`.
    """
    lowres_h, lowres_w = lowres_masks.shape[1:]
    unpad_h, unpad_w = unpad_size
    # Size of the un-padded region expressed in low-res pixels.
    down_ratio = max(lowres_h, lowres_w) / max(unpad_h, unpad_w)
    h, w = int(unpad_h * down_ratio), int(unpad_w * down_ratio)
    orig_h, orig_w = orig_size

    masks = (
        jax.image.resize(
            jax.device_put(
                lowres_masks[:, :h, :w],
                device=jax.local_devices(backend='cpu')[0],
            ),
            (lowres_masks.shape[0], orig_h, orig_w),
            method='bilinear',
            antialias=False,
        )
        > mask_threshold
    )
    boxes = batched_mask_to_box_np(np.asarray(masks))
    return masks, boxes

  def preprocess(self, inputs, padding_mask=None):
    """Preprocess images. Normalize pixels for non-padded pixels.

    Args:
      inputs: (batch_size, H, W, 3) RGB pixels.
      padding_mask: optional (batch_size, H, W); 1 for real pixels, 0 for
        padding.
    Returns:
      (batch_size, H, W, 3) normalized pixels; padded pixels are exactly 0.
    """
    mean = jnp.asarray(self.pixel_mean, dtype=jnp.float32).reshape(1, 1, 1, 3)
    std = jnp.asarray(self.pixel_std, dtype=jnp.float32).reshape(1, 1, 1, 3)
    inputs = (inputs - mean) / std
    if padding_mask is not None:
      inputs = inputs * padding_mask[..., None]  # Padded pixels remain 0
    return inputs

  def generate(
      self, image=None, padding_mask=None, upsample_mask=True,
      image_embedding=None, return_image_embedding=False):
    """Automatically generate masks for all objects.

    This function is from the original SamAutomaticMaskGenerator at
    https://github.com/facebookresearch/segment-anything/blob/HEAD/
    segment_anything/automatic_mask_generator.py.

    Here we merge it inside the Sam flax model, as we don't use a separate
    predictor class.

    Here are a few key differences compared to the original implementation:

      - The original implementation did filtering inside each prompt-batch. We
        can't do this in jax as the filtering changes the data shape. Instead,
        we do a filtering after concatenating the raw outputs from all batches,
        and use an additional parameter "pre_nms_topk" to control the output
        shape. By default "pre_nms_topk" is half of all predicted masks
        (32 ** 2 point prompts x 3 masks = 3072).

      - We move mask upsampling (i.e., "postprocess_masks") to the very end of
        the process (after NMS), to save peak memory. This means the box-NMS and
        the stability_score are computed on the 4x-downsampled masks. This
        introduces small errors compared to the original implementation.

      - We don't support the multi-crop testing in the original code as this is
        not enabled in the default config.

    Args:
      image: a single image, (H x W x 3)
      padding_mask: (H x W); required when image_embedding is given.
      upsample_mask: bool; If False, only return the 4x downsampled masks. This
        saves memory.
      image_embedding: image embeddings if they are provided. (H', W', D). If
        not provided, image must be not None.
      return_image_embedding: bool
    Returns:
      Result dict of that image, with keys:
        'masks': (self.max_objects, H, W) boolean masks; only when
          upsample_mask.
        'boxes': (self.max_objects, 4). Boxes from 'masks'; only when
          upsample_mask.
        'iou_predictions': (self.max_objects,). Predicted mask quality scores,
          zeroed for masks rejected by the IoU / stability filters.
        'low_res_logits': (self.max_objects, H', W'). The output mask of the
          mask decoder. The final masks are resized from this.
        'low_res_boxes': (self.max_objects, 4). Boxes from the thresholded
          low-res logits; these are the boxes used for NMS.
        'stability_score': (self.max_objects,). A measurement of how stable the
          mask is when self.mask_threshold changes.
        'image_embedding': (H', W', D); only when return_image_embedding.
    """
    msg = 'One of "image" or "image_embedding" should be provided!'
    assert image is not None or image_embedding is not None, msg
    assert image is None or image_embedding is None, msg
    if image_embedding is None:
      padding_mask = padding_mask if padding_mask is not None else (
          jnp.ones((image.shape[0], image.shape[1]), dtype=jnp.float32))
      image_embedding = self.get_image_embeddings(
          image[None], padding_mask=padding_mask[None])[0]  # (H', W', D)
    else:
      nopadding_msg = 'Padding_mask should be provided if using image_embedding'
      assert padding_mask is not None, nopadding_msg

    point_grid = build_point_grid(
        self.points_per_side)[:, None]  # (points_per_side ** 2, 1, 2)
    # Ignore padded region in creating grid.
    valid_h = padding_mask.max(axis=1).sum()
    valid_w = padding_mask.max(axis=0).sum()
    point_grid = point_grid * jnp.asarray(
        [valid_w, valid_h], dtype=jnp.float32).reshape(1, 1, 2)
    point_labels = jnp.ones(
        (point_grid.shape[0], point_grid.shape[1]),
        dtype=jnp.int32)  # (points_per_side ** 2, 1)

    num_prompts = point_grid.shape[0]
    bs = self.points_per_batch
    assert num_prompts % bs == 0, num_prompts
    num_batches = num_prompts // bs
    low_res_masks, iou_predictions = [], []
    for b in range(num_batches):
      in_points = point_grid[b * bs: (b + 1) * bs]
      in_labels = point_labels[b * bs: (b + 1) * bs]
      sparse_embeddings_cur, dense_embeddings_cur = self.prompt_encoder(
          in_points, in_labels)
      low_res_masks_cur, iou_predictions_cur = self.mask_decoder(
          image_embeddings=image_embedding,
          image_pe=self.prompt_encoder.get_dense_pe(),
          sparse_prompt_embeddings=sparse_embeddings_cur,
          dense_prompt_embeddings=dense_embeddings_cur,
          multimask_output=True,
      )  # low_res_masks: (bs, 3, h', w')
      low_res_masks.append(low_res_masks_cur)
      iou_predictions.append(iou_predictions_cur)
    ret = {}
    if return_image_embedding:
      ret['image_embedding'] = image_embedding
    del image_embedding

    low_res_masks = jnp.concatenate(
        low_res_masks, axis=0)
    iou_predictions = jnp.concatenate(iou_predictions, axis=0)
    low_res_masks = low_res_masks.reshape(
        (-1,) + low_res_masks.shape[-2:])  # (points_per_side ** 2 * 3, h', w')
    iou_predictions = iou_predictions.reshape(-1)  # (points_per_side ** 2 * 3,)
    keep_mask = iou_predictions > self.pred_iou_thresh

    # Note: the original code computes stability_score on upsampled masks.
    stability_score = calculate_stability_score(
        low_res_masks,
        self.mask_threshold, self.stability_score_offset)
    if self.stability_score_thresh > 0.0:
      keep_mask = keep_mask & (stability_score > self.stability_score_thresh)

    # Rejected masks get score 0 instead of being dropped, to keep shapes static.
    iou_predictions = iou_predictions * keep_mask

    _, inds = jax.lax.top_k(iou_predictions, k=self.pre_nms_topk)
    iou_predictions = jnp.take_along_axis(iou_predictions, inds, axis=0)
    low_res_masks = jnp.take_along_axis(
        low_res_masks, inds[:, None, None], axis=0)

    # Note: the original code run NMS on upsampled masks.
    low_res_boxes = batched_mask_to_box(
        low_res_masks > self.mask_threshold)
    keep_inds = nms(
        low_res_boxes, iou_predictions,
        iou_threshold=self.box_nms_thresh,
        num_outputs=self.max_objects)  # (max_objects,)
    low_res_masks = jnp.take_along_axis(
        low_res_masks, keep_inds[:, None, None], axis=0)
    # NOTE(review): `stability_score` is indexed with `keep_inds`, which point
    # into the top-k-reordered arrays, while `stability_score` itself was not
    # reordered by `inds` — confirm this is intended.
    ret.update({
        'iou_predictions': jnp.take_along_axis(
            iou_predictions, keep_inds, axis=0),
        'low_res_logits': low_res_masks,
        'low_res_boxes': jnp.take_along_axis(
            low_res_boxes, keep_inds[:, None], axis=0),
        'stability_score': jnp.take_along_axis(
            stability_score, keep_inds, axis=0),
    })
    if upsample_mask:
      size = self.image_encoder.img_size
      masks = self.postprocess_masks(
          low_res_masks[None], size, size)[0] > self.mask_threshold
      boxes = batched_mask_to_box(masks)
      ret['masks'] = masks
      ret['boxes'] = boxes
    return ret

  def batch_generate(self, image, padding_mask, upsample_mask=True):
    """`generate` vmapped over a leading batch axis of images and masks."""
    return jax.vmap(lambda x, y: self.generate(x, y, upsample_mask))(
        image, padding_mask)

In [None]:
# Sam with all default hyper-parameters and default sub-module args.
sam_model = Sam()

In [None]:
# Initialize the parameter tree with dummy inputs. Only the shapes matter:
# the values are later replaced by the converted PyTorch weights.
rng = {'dropout': jax.random.PRNGKey(0), 'params': jax.random.PRNGKey(0)}
S = 1024  # side length of the square (padded) model input
num_prompts, num_points = 1, 1
inp = jax.random.normal(jax.random.PRNGKey(0), (1, S, S, 3))
# Dummy channels-last mask prompt; presumably passed so that the mask
# downscaling parameters are also created during init.
mask_inputs = jax.random.normal(jax.random.PRNGKey(0), (1, S // 4, S // 4, 1))
point_coords = jnp.zeros((1, num_prompts, num_points, 2), jnp.float32)
point_labels = jnp.zeros((1, num_prompts, num_points), jnp.int32)
sam_vars = sam_model.init(
    rng, inp, point_coords, point_labels, padding_mask=None,
      image_embeddings=None, boxes=None, mask_inputs=mask_inputs)

In [None]:
import flax

In [None]:
# Summarize the freshly initialized parameter tree: one row per parameter
# (dotted name, shape, mean, std), followed by the total parameter count.
from tabulate import tabulate
import copy

flattened_tree = flax.traverse_util.flatten_dict(sam_vars['params'], sep='.')
table = []
num_params = 0
for k in sorted(flattened_tree):
  v = flattened_tree[k]
  table.append((k, f'{v.shape}', f'{v.mean():.3f}', f'{v.std():.3f}'))
  num_params += jnp.prod(jnp.asarray(v.shape))
table_str = tabulate(
    table, tablefmt="pipe", headers=["Names", "shape", "mean", "std"])
print(table_str)
print('num_params', num_params)

In [None]:
def dfs(k, v, converted_torch_weight):
  """Recursively pair a Flax param tree with flat PyTorch weights.

  Flax leaf names 'kernel', 'scale' and 'embedding' are looked up as
  'weight' in the PyTorch dict. Dense kernels are transposed (out, in) ->
  (in, out), and 4-D conv kernels (out, in, kh, kw) -> (kh, kw, in, out),
  except for the listed 2-D embeddings and `image_encoder.pos_embed`, which
  already share the Flax layout.

  Args:
    k: dotted PyTorch-style name of the current subtree ('' at the root).
    v: a Flax param subtree (nested dict) or a leaf array.
    converted_torch_weight: dict of PyTorch state-dict names to numpy arrays.
  Returns:
    (entries, tree): `entries` lists (name, shape) for every leaf; `tree`
    mirrors `v` with each leaf replaced by the re-laid-out PyTorch array, or
    by the original Flax leaf when the name is missing from the checkpoint.
    Missing names and shape mismatches are only printed, not raised.
  """
  if not isinstance(v, jnp.ndarray):
    entries, tree = [], {}
    for flax_key, child in v.items():
      torch_key = flax_key
      is_weight_leaf = isinstance(child, jnp.ndarray) and flax_key in (
          'kernel', 'scale', 'embedding')
      if is_weight_leaf and 'proposal_generator.scales' not in k:
        torch_key = 'weight'
      child_name = '{}.{}'.format(k, torch_key) if k else torch_key
      child_entries, child_tree = dfs(child_name, child, converted_torch_weight)
      entries.extend(child_entries)
      tree[flax_key] = child_tree
    return entries, tree

  if k not in converted_torch_weight:
    print(f'{k} not in checkpoint')
    return [(k, v.shape)], v

  # 2-D tensors whose PyTorch layout already matches Flax.
  untransposed_2d = (
      'not_a_point_embed', 'positional_encoding_gaussian_matrix', 'rel_pos',
      'point_embeddings', 'iou_token', 'mask_tokens', 'no_mask_embed')
  data = converted_torch_weight[k]
  if len(v.shape) == 2 and not any(name in k for name in untransposed_2d):
    data = np.transpose(data, (1, 0))
  if len(v.shape) == 4 and (
      'output_upscaling' in k or 'image_encoder.pos_embed' not in k):
    data = np.transpose(data, (2, 3, 1, 0))
  if data.shape != v.shape:
    print('Wrong shape! {} {} {}'.format(
        k, data.shape, v.shape))
  return [(k, data.shape)], data

# List of (old_substring, new_substring) renames, presumably meant as the
# `name_map` argument of `map_names`. Empty: no renaming is needed for SAM.
# (The "COMMEN" spelling is kept in case other code references this name.)
COMMEN_NAME_MAP = []

def map_names(state_dict, name_map):
  """Return a copy of `state_dict` with keys rewritten by substring rules.

  Each (ori_name, new_name) pair in `name_map` is applied with str.replace,
  in order, so later rules see the output of earlier ones. If two keys map
  to the same new name, the later entry wins.
  """
  def _rename(key):
    for old, new in name_map:
      key = key.replace(old, new)
    return key

  return {_rename(key): value for key, value in state_dict.items()}

# Materialize every PyTorch tensor as a host numpy array, keyed by its
# state-dict name, for matching against the Flax parameter tree.
converted_torch_weight = {
    name: tensor.cpu().numpy() for name, tensor in torch_weights.items()}

In [None]:
# Build the Flax param tree (`tree`) from the PyTorch weights, then compare
# the total parameter counts of both sides as a sanity check.
ret, tree = dfs('', sam_vars['params'], converted_torch_weight)
num_params = 0
for k, v in converted_torch_weight.items():
  num_params += np.prod(v.shape)
print('#params in loaded model:', num_params)
num_params = 0
for k, v in ret:
  num_params += np.prod(v)  # here `v` is already a shape tuple
print('#params in converted model:', num_params)

In [None]:
# Same summary table as for the initialized params, now for the converted
# tree; means/stds should reflect the loaded PyTorch weights.
from tabulate import tabulate
import copy

flattened_tree = flax.traverse_util.flatten_dict(tree, sep='.')
table = []
num_params = 0
for k in sorted(flattened_tree):
  v = flattened_tree[k]
  table.append((k, f'{v.shape}', f'{v.mean():.3f}', f'{v.std():.3f}'))
  num_params += jnp.prod(jnp.asarray(v.shape))
table_str = tabulate(
    table, tablefmt="pipe", headers=["Names", "shape", "mean", "std"])
print(table_str)
print('num_params', num_params)

In [None]:
# Reuse the PyTorch predictor's resized image: NCHW -> NHWC, zero-padded to
# S x S, plus a padding mask that is 1 on real pixels and 0 on padding.
transformed_image = predictor.transformed_image.cpu().numpy().transpose(0, 2, 3, 1)
inp = np.zeros((1, S, S, 3), np.float32)
padding_mask = np.zeros((1, S, S), np.float32)
inp[0, :transformed_image.shape[1], :transformed_image.shape[2]] = transformed_image#[..., ::-1]
padding_mask[0, :transformed_image.shape[1], :transformed_image.shape[2]] = 1
# Rescale the clicked point from original-image pixels into the resized frame
# whose longer side is S; shape (batch=1, prompts=1, points=1, 2).
point_coords = np.asarray(input_point.copy(), dtype=np.float32).reshape(1, 1, 1, 2) # jnp.zeros((1, num_prompts, num_points, 2), jnp.float32)
point_coords[..., 0] = point_coords[..., 0] / max(image.shape[:2]) * S
point_coords[..., 1] = point_coords[..., 1] / max(image.shape[:2]) * S
point_labels = jnp.asarray(input_label).reshape(1, 1, 1) # jnp.zeros((1, num_prompts, num_points), jnp.int32)

In [None]:
# Run the converted Flax model on the point prompt; `ret` is a list with one
# result dict per image.
ret = sam_model.apply(
    {'params': tree},
    inp,
    point_coords,
    point_labels,
    padding_mask,
    train=False)


In [None]:
# Show every multimask output of the first prompt over the resized input
# image, titled with its predicted IoU score.
transformed_image = predictor.transformed_image.cpu().numpy().transpose(0, 2, 3, 1)[0].astype(np.uint8)
for i, (mask, score) in enumerate(zip(ret[0]['masks'][0], ret[0]['iou_predictions'][0])):
    plt.figure(figsize=(10,10))
    plt.imshow(transformed_image)
    show_mask(mask, plt.gca())
    show_points(point_coords[0, 0], input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()


In [None]:
# Shapes of every output for the first image.
jax.tree_util.tree_map(lambda x: x.shape, ret[0])

In [None]:
# Second pass: feed the first low-res logit map of each prompt back as a
# channels-last dense mask prompt, shape (num_prompts, h', w', 1).
ret_with_mask_prompt = sam_model.apply(
    {'params': tree},
    inp,
    point_coords,
    point_labels,
    padding_mask,
    mask_inputs=ret[0]['low_res_logits'][:, 0, :, :, None],
    train=False)

In [None]:
# Print the output shapes, then visualize the masks predicted with the
# additional mask prompt.
print(jax.tree_util.tree_map(lambda x: x.shape, ret_with_mask_prompt[0]))
transformed_image = predictor.transformed_image.cpu().numpy().transpose(0, 2, 3, 1)[0].astype(np.uint8)
for i, (mask, score) in enumerate(zip(ret_with_mask_prompt[0]['masks'][0], ret_with_mask_prompt[0]['iou_predictions'][0])):
    plt.figure(figsize=(10,10))
    plt.imshow(transformed_image)
    show_mask(mask, plt.gca())
    show_points(point_coords[0, 0], input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()


In [None]:
# Save the converted parameters as a legacy (non-Orbax) Flax checkpoint at
# step 0. overwrite=True keeps this cell re-runnable: without it, saving again
# into a directory that already holds step 0 raises instead of replacing it.
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
out_path = 'sam_vit_b'
checkpoints.save_checkpoint(out_path, {'params': tree}, 0, overwrite=True)

In [None]:
# from google.colab import files
# files.download(f'{out_path}/checkpoint_0')