In [170]:
import io

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import SamModel, SamProcessor

In [193]:
class Processor:
    """Pre-/post-processing around the SlimSAM ``SamProcessor``.

    ``preprocess`` turns images into model inputs; ``postprocess`` maps the
    per-image predictions back onto each image's original resolution.
    """

    def __init__(self):
        self.processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")

    def preprocess(self, image):
        """Run the SAM processor on ``image`` (one image or a list of images).

        Returns the processor's ``BatchFeature`` with PyTorch tensors. Per the
        transformers SAM docs it holds ``pixel_values``, ``original_sizes`` and
        ``reshaped_input_sizes``; ``postprocess`` reads all three.
        """
        return self.processor(image, return_tensors="pt")

    def postprocess(self, inputs, outputs):
        """Resize each prediction in ``outputs`` back to its original image size.

        SamProcessor resizes every image so its longest side fits the model
        input and then pads it (bottom/right) to the square ``pixel_values``
        canvas. A prediction therefore covers the *padded* canvas. Resizing it
        directly to the original size would stretch the padding into the
        result. Instead, mirroring ``SamProcessor.post_process_masks``:
          1. upsample to the canvas size,
          2. crop away the padding using ``reshaped_input_sizes``,
          3. resize the remaining image area to ``original_sizes``.

        Args:
            inputs: mapping returned by ``preprocess``; ``pixel_values``,
                ``reshaped_input_sizes`` and ``original_sizes`` are read.
            outputs: batch-first prediction tensor, e.g. ``(batch, 1, H, W)``
                as produced by ``Network``.

        Returns:
            List with one tensor per image: ``(orig_h, orig_w)`` for a
            single-channel prediction, ``(C, orig_h, orig_w)`` otherwise.
        """
        # Spatial size of the padded model-input canvas (torch.Size -> ints).
        canvas_h, canvas_w = inputs["pixel_values"].shape[-2:]

        results = []
        for i in range(outputs.shape[0]):
            # Sizes arrive as tensors; interpolate and slicing need plain ints.
            orig_h, orig_w = (int(v) for v in inputs["original_sizes"][i])
            valid_h, valid_w = (int(v) for v in inputs["reshaped_input_sizes"][i])

            # Place the prediction on the padded canvas ...
            canvas = F.interpolate(
                outputs[i:i + 1],
                size=(canvas_h, canvas_w),
                mode="bilinear",
                align_corners=False,  # same convention as SAM's own post-processing
            )
            # ... drop the bottom/right padding ...
            cropped = canvas[..., :valid_h, :valid_w]
            # ... and stretch the real image area back to the original size.
            resized = F.interpolate(
                cropped,
                size=(orig_h, orig_w),
                mode="bilinear",
                align_corners=False,
            )

            # Remove the batch dim and a single channel dim only; unlike a bare
            # squeeze() this never collapses an original height/width of 1.
            results.append(resized[0].squeeze(0))

        return results


In [194]:
class Network(nn.Module):
    """Frozen SlimSAM image encoder followed by a small trainable upsampling head.

    ``forward`` takes the processor output (anything indexable by
    ``"pixel_values"``) and returns a single-channel map squashed into
    [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.sam = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
        # Alias kept so that ``model.encoder(inputs)`` works alongside ``decoder``.
        self.encoder = self.image_encoder
        self.decoder = self.initialize_decoder()

        # Only the decoder head is trainable; every SAM weight stays fixed.
        self.sam.requires_grad_(False)

    def image_encoder(self, inputs):
        """Return SAM image embeddings for ``inputs["pixel_values"]``.

        Per image these are 256 channels on a 64 x 64 grid (C x H x W).
        """
        pixel_values = inputs["pixel_values"]
        return self.sam.get_image_embeddings(pixel_values)

    def initialize_decoder(self):
        """Build the head that maps 256-channel embeddings to one channel.

        Each stride-2 transposed conv doubles the spatial size; the final 3x3
        conv (padding 1) keeps it and projects to a single channel.
        """
        layers = [
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),  # 256x64x64  -> 64x128x128
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),   # 64x128x128 -> 32x256x256
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=3, padding=1),             # 32x256x256 -> 1x256x256
            nn.Sigmoid(),
        ]
        return nn.Sequential(*layers)

    def forward(self, inputs):
        """Encode ``inputs`` with the frozen SAM encoder and decode to one channel."""
        return self.decoder(self.encoder(inputs))


# Example: run the model on a sample image

In [195]:
# Network() creates its decoder head with fresh random weights, so seed the
# RNG to make the example prediction reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Network().to(device)
processor = Processor()

In [198]:
url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"

# Bounded wait and an explicit HTTP error, instead of an indefinite hang or an
# opaque PIL decode error when the server returns an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
# .content is fully read (and content-decoded), so no open stream is left behind.
image = Image.open(io.BytesIO(response.content)).convert("RGB")

# Two copies of the same image form a batch of size 2.
inputs = processor.preprocess([image, image]).to(device)
with torch.no_grad():
    outputs = model(inputs)
    prediction = processor.postprocess(inputs, outputs)

# One map per input image, at the image's original (height, width).
assert len(prediction) == 2
assert all(p.shape == (image.height, image.width) for p in prediction)