## Import dependencies

In [2]:
import os
import numpy as np
from pycocotools.coco import COCO
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms

### Define Model

In [3]:
class MinimalSam(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()

        def depthwise_conv(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, in_c, kernel_size=3, padding=1, groups=in_c, bias=False),
                nn.BatchNorm2d(in_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_c, out_c, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            )
        
        self.encoder_stage1 = depthwise_conv(in_channels, 48)
        self.down1 = nn.Conv2d(48, 48, kernel_size=3, stride=2, padding=1, bias=False)
        self.bottleneck = depthwise_conv(48, 96)
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), depthwise_conv(96, 48))
        self.output_head = nn.Conv2d(48, 1, kernel_size=1)

    def forward(self, x):
        x = self.encoder_stage1(x)
        x = self.down1(x)
        x = self.bottleneck(x)
        x = self.up1(x)
        x = self.output_head(x)
        return x

### Test Model Definition

In [4]:
model = MinimalSam()

print('The model:')
print(model)

print('\n\nJust one layer:')
print(model.encoder_stage1)

print('\n\nModel params:')
for param in model.parameters():
    print(param)

print('\n\nLayer params:')
for param in model.encoder_stage1.parameters():
    print(param)

The model:
MinimalSam(
  (encoder_stage1): Sequential(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(3, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (down1): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bottleneck): Sequential(
    (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
    (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (up1): Sequential

### Define Dataset

In [None]:
class MinimalSamDataset(Dataset):
    def __init__(self, annotation_file: str, img_dir: str, img_size: int):
        super().__init__()

        self.img_dir = img_dir
        self.img_size = img_size

        self.coco = COCO(annotation_file)
        self.ann_ids = self.coco.getAnnIds()
        self.anns = [ann for ann in self.coco.loadAnns(self.ann_ids) if ann.get("iscrowd", 0) == 0]

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # mean and std of ImageNet
        ])

    def __len__(self):
        return len(self.anns)
    
    def __getitem__(self, index):
        ann = self.anns[index]
        img_id = ann['image_id']
        img = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img['file_name'])

        image = Image.open(img_path).convert("RGB")
        mask = self.coco.annToMask(ann)

        cropped_image, cropped_mask = self._crop(image, mask)

        image_tensor = self.transform(cropped_image)
        mask_tensor = torch.tensor(np.array(cropped_mask), dtype=torch.float32).unsqueeze(0) # why unsqueeze??

        return image_tensor, mask_tensor
    
    def _crop(self, image, mask):
        ys, xs = np.where(mask > 0)
        if len(xs) == 0:
            center_x, center_y = image.size[0] // 2, image.size[1] // 2
        else:
            min_x, max_x = xs.min(), xs.max()
            min_y, max_y = ys.min(), ys.max()
            mask_center_x = (min_x + max_x) // 2
            mask_center_y = (min_y + max_y) // 2

            if mask[mask_center_y, mask_center_x]:
                center_x, center_y = mask_center_x, mask_center_y
            else:
                distances = (xs - mask_center_x) ** 2 + (ys - mask_center_y) ** 2
                closest_idx = np.argmin(distances)
                center_x, center_y = xs[closest_idx], ys[closest_idx]

        left = max(0, center_x - self.img_size // 2)
        top = max(0, center_y - self.img_size // 2)
        right = min(image.size[0], left + self.img_size)
        bottom = min(image.size[1], top + self.img_size)

        cropped_img = image.crop((left, top, right, bottom)).resize((self.img_size, self.img_size), Image.BILINEAR)
        cropped_mask = Image.fromarray(mask[top:bottom, left:right]).resize((self.img_size, self.img_size), Image.NEAREST)

        return cropped_img, cropped_mask  #, center_x, center_y, left, top

### Test Dataset Definition

In [6]:
annotation_file = "../dataset/annotations/instances_train2017.json"
img_dir = "../dataset/train2017"
img_size = 96

dataset = MinimalSamDataset(annotation_file, img_dir, img_size)

loading annotations into memory...


Done (t=8.11s)
creating index...
index created!


In [9]:
print("Length of dataset:", dataset.__len__())
print(dataset.__getitem__(0))

Length of dataset: 849949
(tensor([[[ 2.1804,  2.1975,  2.1804,  ..., -1.8953, -1.8953, -1.8953],
         [ 2.2489,  2.2489,  2.1975,  ..., -1.7069, -1.7925, -1.8953],
         [ 2.1975,  2.2147,  2.1804,  ..., -2.0323, -1.9295, -1.8097],
         ...,
         [ 1.3242,  1.3242,  1.3584,  ...,  0.5878,  1.3755,  1.5982],
         [ 1.1187,  0.9646,  0.7248,  ...,  1.5468,  1.2899,  1.2214],
         [ 0.7248,  0.6734,  0.7077,  ...,  1.3584,  1.4612,  1.3755]],

        [[ 1.7808,  1.9384,  2.0259,  ..., -1.8081, -1.8081, -1.8431],
         [ 1.7983,  1.8683,  1.9909,  ..., -1.7206, -1.8081, -1.8782],
         [ 1.7808,  1.7633,  1.8859,  ..., -2.0357, -1.9132, -1.7381],
         ...,
         [-1.6331, -1.5280, -1.4055,  ..., -1.1253, -0.6001, -0.3901],
         [-1.3529, -1.3704, -1.5105,  ..., -0.5126, -0.6001, -0.5301],
         [-1.6155, -1.6506, -1.4930,  ..., -0.5826, -0.4951, -0.5126]],

        [[ 0.5834,  0.6879,  0.8622,  ..., -1.6127, -1.5779, -1.5953],
         [ 0.6879,