In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock, Bottleneck
from transformers import SamModel, SamProcessor

In [None]:
class Processor:
    """Thin wrapper around a Hugging Face ``SamProcessor``.

    ``preprocess`` turns an image into model inputs; ``postprocess`` resizes a
    network output back to the original image size and rescales it by 255.

    Args:
        checkpoint: Model id or local path handed to
            ``SamProcessor.from_pretrained``. Defaults to the SlimSAM
            checkpoint this notebook uses.
    """

    def __init__(self, checkpoint="Zigeng/SlimSAM-uniform-77"):
        self.processor = SamProcessor.from_pretrained(checkpoint)

    def preprocess(self, image):
        """Run the wrapped processor on ``image`` and return its output as PyTorch tensors."""
        return self.processor(image, return_tensors="pt")

    def postprocess(self, inputs, outputs):
        """Resize ``outputs`` to the original image size and multiply by 255.

        Args:
            inputs: Mapping with an ``'original_sizes'`` entry whose first
                element is a ``(height, width)`` pair (tensor or plain
                sequence). Only that first entry is used, so every item in
                the batch is resized to the same size.
            outputs: 4-D tensor accepted by ``F.interpolate``; a size-1
                dimension at index 1 is squeezed away afterwards.

        Returns:
            A new tensor holding the bilinearly resized values times 255.
            ``outputs`` itself is not modified.

        NOTE(review): the output is stretched directly to the original size.
        If the preprocessor pads images, the padded area is not cropped
        first -- confirm this matches how targets are prepared for training.
        """
        # Convert to plain ints so interpolate gets the documented size type
        # whether the sizes arrive as 0-dim tensors or as Python numbers.
        height, width = (int(side) for side in inputs['original_sizes'][0])

        resized = F.interpolate(
            outputs, size=(height, width),
            mode="bilinear", align_corners=True
        )

        # Out-of-place multiply: avoids mutating a view of the interpolated tensor.
        return resized.squeeze(1) * 255


In [None]:
class Network(nn.Module):
    """Frozen SAM image encoder followed by a trainable convolutional decoder.

    The decoder upsamples the SAM image embeddings and ends in a sigmoid, so
    the network emits a single-channel map with values in (0, 1).

    Args:
        checkpoint: Model id or local path handed to
            ``SamModel.from_pretrained``. Defaults to the SlimSAM checkpoint
            this notebook uses.
        use_bottleneck: Build the decoder's residual stages from torchvision
            ``Bottleneck`` blocks instead of ``BasicBlock``.
    """

    def __init__(self, checkpoint="Zigeng/SlimSAM-uniform-77", use_bottleneck=False):
        super().__init__()
        self.sam = SamModel.from_pretrained(checkpoint)
        self.encoder = self.image_encoder
        self.decoder = self.initialize_decoder(use_bottleneck=use_bottleneck)

        # Freeze every SAM parameter; only the decoder remains trainable.
        for param in self.sam.parameters():
            param.requires_grad = False

    def image_encoder(self, inputs):
        """Return SAM image embeddings for ``inputs["pixel_values"]``.

        The decoder below expects these embeddings to have 256 channels
        (256 x 64 x 64 per image according to the original author's notes).
        """
        return self.sam.get_image_embeddings(inputs["pixel_values"])

    def initialize_decoder(self, use_bottleneck=False):
        """Build the upsampling decoder.

        Args:
            use_bottleneck: Use ``Bottleneck`` residual blocks when true,
                ``BasicBlock`` otherwise.

        Returns:
            An ``nn.Sequential`` mapping a 256-channel feature map to a
            1-channel map at 4x the spatial resolution, squashed by a sigmoid.
        """
        # Choose block type: BasicBlock or Bottleneck
        Block = Bottleneck if use_bottleneck else BasicBlock

        def residual(channels):
            # The blocks are built without a downsample branch, so the input is
            # added to the block output and both must have `channels` channels.
            # A block emits planes * expansion channels (1 for BasicBlock,
            # 4 for Bottleneck), hence the division.
            return Block(inplanes=channels, planes=channels // Block.expansion)

        return nn.Sequential(
            # Upsample 256 x 64 x 64 -> 128 x 128 x 128
            nn.ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2)),
            nn.ReLU(),
            residual(128),

            # Upsample 128 x 128 x 128 -> 64 x 256 x 256
            nn.ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2)),
            nn.ReLU(),
            residual(64),

            # Reduce channels 64 x 256 x 256 -> 32 x 256 x 256
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            residual(32),

            # Final output 32 x 256 x 256 -> 1 x 256 x 256
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        """Encode ``inputs`` with SAM and decode the embeddings into the output map."""
        embeddings = self.encoder(inputs)
        depth = self.decoder(embeddings)
        return depth


# Example

In [None]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network and move it to the selected device, then create the
# matching input/output processor.
model = Network()
model = model.to(device)
processor = Processor()

preprocessor_config.json:   0%|          | 0.00/466 [00:00<?, ?B/s]

In [None]:
# Mount Google Drive at /content/drive so files stored there are reachable
# from this notebook. The import comes from the google.colab package, so this
# cell is expected to work only inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Dataset
# https://www.kaggle.com/datasets/soumikrakshit/nyu-depth-v2

# Dataset and DataLoader Reference
# https://www.kaggle.com/code/shreydan/monocular-depth-estimation-nyuv2