In [1]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pillow_avif
import torch
from PIL import Image
from torchvision.transforms.functional import (
    normalize,
    pad,
    resize,
    to_pil_image,
)

from segment_anything import sam_model_registry, SamPredictor

### The image encoder

In [2]:
# List the SAM model types that can be built from the registry.
print(sam_model_registry.keys())

dict_keys(['default', 'vit_h', 'vit_l', 'vit_b'])


In [3]:
# Build the 'default' SAM model type (vit_h, matching the checkpoint name) and
# load its weights from a checkpoint path relative to this notebook.
sam = sam_model_registry['default'](checkpoint="../models/sam_vit_h_4b8939.pth")
print(sam)

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0-31): 32 x Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate='none')
        )
      )
    )
    (neck): Sequential(
      (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): LayerNorm2d()
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): LayerNorm2d

In [4]:
# Pull out the image encoder submodule: a ViT backbone (16x16 patch embedding,
# 32 transformer blocks) followed by a conv "neck" projecting to 256 channels.
image_encoder = sam.image_encoder
print(image_encoder)

ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-31): 32 x Block(
      (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=1280, out_features=5120, bias=True)
        (lin2): Linear(in_features=5120, out_features=1280, bias=True)
        (act): GELU(approximate='none')
      )
    )
  )
  (neck): Sequential(
    (0): Conv2d(1280, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LayerNorm2d()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): LayerNorm2d()
  )
)


### Transform images

In [5]:
# Preprocessing settings read from the loaded model. The helper functions
# defined below (resize, tensor conversion, normalization, padding) use these
# instead of hard-coding values.
model_image_params = dict(
    image_mode=sam.image_format,       # expected channel order
    image_size=image_encoder.img_size, # encoder input side length
    pixel_mean=sam.pixel_mean,         # per-channel normalization mean
    pixel_std=sam.pixel_std,           # per-channel normalization std
    device=sam.device,                 # where input tensors must live
)
print(model_image_params)

{'image_mode': 'RGB', 'image_size': 1024, 'pixel_mean': tensor([[[123.6750]],

        [[116.2800]],

        [[103.5300]]]), 'pixel_std': tensor([[[58.3950]],

        [[57.1200]],

        [[57.3750]]]), 'device': device(type='cpu')}


In [6]:
def resize_longest_side(
    image: np.ndarray,
    target_length: int,
) -> np.ndarray:
    """
    Resizes an image so that its longest side equals `target_length`,
    preserving the aspect ratio.

    Resizing goes through PIL (torchvision's to_pil_image + resize). The
    verification at the end of this notebook checks that the resulting
    embeddings match SamPredictor's exactly.

    Arguments:
      image: array of shape HWC, e.g. uint8 pixels in [0, 255].
      target_length: desired length, in pixels, of the longest side.
    Returns:
      np.ndarray: the resized image, converted back from PIL to a NumPy array.
    Raises:
      ValueError: if the image has a zero-sized height or width.
    """
    h, w = image.shape[0], image.shape[1]
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}")
    scale = target_length / max(h, w)
    # Add 0.5 before int() so the truncation rounds to the nearest pixel.
    newh = int(h * scale + 0.5)
    neww = int(w * scale + 0.5)
    # to_pil_image accepts a CHW tensor or an HWC ndarray; resize() then runs
    # on the PIL image and the result is converted back to an ndarray.
    pil_image = to_pil_image(image)
    return np.array(resize(pil_image, [newh, neww]))

In [7]:
def to_image_tensor(
    image: np.ndarray,
    device: "str | torch.device",
) -> torch.Tensor:
    """
    Converts an HWC image array into a float32 tensor of shape (1, C, H, W).

    Arguments:
      image (np.ndarray): The image for calculating masks. Expects an
        image in HWC uint8 format, with pixel values in [0, 255].
        Non-contiguous views (e.g. image[..., ::-1]) are accepted.
      device (str | torch.device): device on which to create the tensor,
        e.g. 'cpu' or a model's `.device`.
    Returns:
      torch.Tensor: float32 tensor of shape (1, C, H, W). Pixel values are
        cast to float but not rescaled.
    """
    # torch.as_tensor rejects arrays with negative strides (such as the
    # channel-reversed view image[..., ::-1]), so lay the data out
    # contiguously first. This is a no-op for already-contiguous arrays.
    image = np.ascontiguousarray(image)
    image_torch = torch.as_tensor(image, device=device).float()
    # permute(2, 0, 1): HWC -> CHW.
    # contiguous() rewrites the permuted elements linearly in memory.
    # unsqueeze(0) prepends the batch dimension.
    return image_torch.permute(2, 0, 1).contiguous().unsqueeze(0)

In [8]:
@torch.no_grad()
def compute_embeddings(
    image: np.ndarray,
    image_mode: str = 'RGB',
) -> torch.Tensor:
    """
    Calculates the image embeddings for the provided image.

    Pipeline:
      1. Reverse the channel order if it differs from the model's.
      2. Resize so the longest side equals the encoder input size.
      3. Convert to a (1, C, H, W) float tensor.
      4. Normalize with the model's pixel mean and std.
      5. Pad the right and bottom edges up to a square input.
      6. Run the image encoder.

    Reads the notebook globals `model_image_params` and `image_encoder`.

    Arguments:
      image: np.ndarray
        Expects an image in HWC uint8 format with pixel values in [0, 255]
        and exactly 3 color channels.
      image_mode: channel order of `image`, either 'RGB' or 'BGR'.
    Returns:
      torch.Tensor: the image encoder output (torch.Size([1, 256, 64, 64])
        for the checkpoint loaded in this notebook).
    Raises:
      ValueError: if `image_mode` is unsupported or `image` is not an HWC
        array with 3 channels.
    """
    # Raise instead of `assert`: asserts are stripped when Python runs with -O.
    if image_mode not in ('RGB', 'BGR'):
        raise ValueError(f"image_mode must be 'RGB' or 'BGR', got {image_mode!r}")
    # RGBA/grayscale input would otherwise fail deep inside normalize().
    if image.ndim != 3 or image.shape[-1] != 3:
        raise ValueError(
            f"Expected an HWC image with 3 channels, got shape {image.shape}"
        )

    if image_mode != model_image_params['image_mode']:
        # Reverse the last (color channel) axis. In NumPy, `...` selects all
        # preceding axes, so this equals image[:, :, ::-1].
        image = image[..., ::-1]

    target_size = model_image_params['image_size']
    image = resize_longest_side(image, target_size)
    image_tensor = to_image_tensor(image, device=model_image_params['device'])

    # normalize() expects an image tensor of shape CHW or BCHW, plus
    # per-channel mean/std (squeezed here to 1-D).
    image_tensor = normalize(
        image_tensor,
        model_image_params['pixel_mean'].squeeze(),
        model_image_params['pixel_std'].squeeze(),
    )

    # Skip B and C to read H and W.
    h, w = image_tensor.shape[-2:]
    padh = target_size - h
    padw = target_size - w
    # torchvision's pad() takes (left, top, right, bottom): pad right and bottom only.
    image_tensor = pad(image_tensor, (0, 0, padw, padh))

    return image_encoder(image_tensor)

### Verify embeddings

In [9]:
IMAGE_PATH = '../../../fashion/images/brands/erdem-01.webp'

# Close the file handle deterministically and force 3-channel RGB. Both
# compute_embeddings and SamPredictor.set_image need an HWC array with 3
# channels, and webp/avif files may decode as RGBA, palette or grayscale.
with Image.open(IMAGE_PATH) as pil_image:
    image = np.array(pil_image.convert('RGB'))

In [10]:
# Time the hand-written preprocessing + image encoder pass (runs on CPU here).
%time embeddings = compute_embeddings(image)

CPU times: user 59.5 s, sys: 11.5 s, total: 1min 10s
Wall time: 7.16 s


In [11]:
# Batch of 1, 256 neck channels, and a 64x64 grid (1024 input / 16-pixel patches).
print(embeddings.shape)

torch.Size([1, 256, 64, 64])


In [12]:
# Reference implementation: the official predictor wrapping the same model.
predictor = SamPredictor(sam)

In [13]:
# Compute the reference embeddings with the official pipeline, timed for comparison.
%time predictor.set_image(image)

CPU times: user 56.6 s, sys: 8.73 s, total: 1min 5s
Wall time: 6.59 s


In [14]:
# Confirm the predictor has registered an image before reading its features.
predictor.is_image_set

True

In [15]:
# torch.equal requires identical shape and values: the hand-written pipeline
# must reproduce the predictor's embeddings exactly, not just approximately.
torch.equal(embeddings, predictor.features)

True