## Import dependencies

In [1]:
import os
import numpy as np
from pycocotools.coco import COCO
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import wandb

### Define Model

In [2]:
class MinimalSam(nn.Module):
    """Tiny encoder/decoder that predicts a single-channel mask logit map.

    Layout: a depthwise-separable stem (in_channels -> 48), a stride-2 3x3 conv
    downsample, a depthwise-separable bottleneck (48 -> 96), a x2 nearest-neighbour
    upsample followed by another depthwise-separable block (96 -> 48), and a 1x1
    conv head producing one channel of raw logits (no sigmoid is applied here).

    Args:
        in_channels: number of channels of the input tensor.

    Shapes:
        input:  (N, in_channels, H, W)
        output: (N, 1, H, W) logits, for any H and W (odd sizes included).
    """

    def __init__(self, in_channels=3):
        super().__init__()

        def depthwise_conv(in_c, out_c):
            # Depthwise-separable block: per-channel 3x3 conv (groups=in_c), then a
            # 1x1 pointwise conv to mix channels; each followed by BatchNorm + ReLU.
            return nn.Sequential(
                nn.Conv2d(in_c, in_c, kernel_size=3, padding=1, groups=in_c, bias=False),
                nn.BatchNorm2d(in_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_c, out_c, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            )

        self.encoder_stage1 = depthwise_conv(in_channels, 48)
        self.down1 = nn.Conv2d(48, 48, kernel_size=3, stride=2, padding=1, bias=False)
        self.bottleneck = depthwise_conv(48, 96)
        self.up1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), depthwise_conv(96, 48))
        self.output_head = nn.Conv2d(48, 1, kernel_size=1)

    def forward(self, x):
        height, width = x.shape[-2:]
        x = self.encoder_stage1(x)
        x = self.down1(x)       # (H, W) -> (ceil(H/2), ceil(W/2))
        x = self.bottleneck(x)
        x = self.up1(x)         # -> (2*ceil(H/2), 2*ceil(W/2)): one pixel too large for odd sizes
        x = self.output_head(x)
        # Trim the extra row/column produced for odd H or W so the logits align
        # pixel-for-pixel with the target mask (a no-op for even sizes).
        return x[..., :height, :width]

### Define Dataset

In [3]:
class MinimalSamDataset(Dataset):
    """COCO instance dataset yielding one object-centred crop per annotation.

    Each item comes from a single non-crowd COCO annotation. The image and the
    annotation's mask are cropped to an ``img_size`` x ``img_size`` window around
    a pixel that lies on the object (when the mask is non-empty), then resized.

    Args:
        annotation_file: path to a COCO-format instances JSON file.
        img_dir: directory containing the images referenced by ``file_name``.
        img_size: side length in pixels of the square crop returned.

    Items:
        image_tensor: float tensor (3, img_size, img_size), normalised with ImageNet stats.
        mask_tensor: float tensor (1, img_size, img_size); nonzero where the object is.
    """

    def __init__(self, annotation_file: str, img_dir: str, img_size: int):
        super().__init__()

        self.img_dir = img_dir
        self.img_size = img_size

        self.coco = COCO(annotation_file)
        self.ann_ids = self.coco.getAnnIds()
        # Crowd annotations (iscrowd=1) are skipped; only individual objects are used.
        self.anns = [ann for ann in self.coco.loadAnns(self.ann_ids) if ann.get("iscrowd", 0) == 0]

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # mean and std of ImageNet
        ])

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, index):
        ann = self.anns[index]
        img_info = self.coco.loadImgs(ann["image_id"])[0]  # image metadata, not pixels
        img_path = os.path.join(self.img_dir, img_info["file_name"])

        # Close the file handle immediately; convert() returns a loaded copy that
        # no longer depends on the open file (avoids fd leaks in long epochs).
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        mask = self.coco.annToMask(ann)

        cropped_image, cropped_mask = self._crop(image, mask)

        image_tensor = self.transform(cropped_image)
        # unsqueeze(0) adds a leading channel axis, (H, W) -> (1, H, W), so a batch of
        # masks is (N, 1, H, W) and matches the model's single-channel output.
        mask_tensor = torch.tensor(np.array(cropped_mask), dtype=torch.float32).unsqueeze(0)

        return image_tensor, mask_tensor

    def _crop(self, image, mask):
        """Crop ``image`` and ``mask`` around a pixel of the object and resize both.

        Returns:
            (cropped_img, cropped_mask): PIL images of size (img_size, img_size);
            the image is resized bilinearly, the mask with nearest-neighbour.
        """
        width, height = image.size
        center_x, center_y = self._crop_center(mask, (width // 2, height // 2))
        left, top, right, bottom = self._crop_box(center_x, center_y, width, height, self.img_size)

        target_size = (self.img_size, self.img_size)
        cropped_img = image.crop((left, top, right, bottom)).resize(target_size, Image.BILINEAR)
        cropped_mask = Image.fromarray(mask[top:bottom, left:right]).resize(target_size, Image.NEAREST)

        return cropped_img, cropped_mask

    @staticmethod
    def _crop_center(mask, fallback):
        """Return an (x, y) pixel on ``mask`` near its bounding-box centre.

        Uses the bbox centre if it lies on the mask; otherwise the mask pixel
        closest to it (non-convex shapes can have an empty bbox centre).
        Returns ``fallback`` unchanged when the mask has no nonzero pixel.
        """
        ys, xs = np.where(mask > 0)
        if xs.size == 0:
            return fallback
        bbox_cx = (xs.min() + xs.max()) // 2
        bbox_cy = (ys.min() + ys.max()) // 2
        if mask[bbox_cy, bbox_cx]:
            return int(bbox_cx), int(bbox_cy)
        closest = np.argmin((xs - bbox_cx) ** 2 + (ys - bbox_cy) ** 2)
        return int(xs[closest]), int(ys[closest])

    @staticmethod
    def _crop_box(center_x, center_y, width, height, size):
        """Return (left, top, right, bottom) of a ``size`` x ``size`` window containing the centre.

        The window is centred on (center_x, center_y) where possible and shifted,
        not shrunk, to stay inside a ``width`` x ``height`` image, so the later
        square resize does not distort the aspect ratio. It is smaller than
        ``size`` along an axis only when the image itself is smaller there.
        """
        half = size // 2
        left = max(0, min(int(center_x) - half, width - size))
        top = max(0, min(int(center_y) - half, height - size))
        return left, top, min(width, left + size), min(height, top + size)

### Define loss functions & performance metrics

In [4]:
def bce_dice_loss(pred_mask, gt_mask):
    """Sum of binary cross-entropy and soft Dice loss for mask logits.

    Args:
        pred_mask: raw logits, a 4-D (N, C, H, W) tensor; sigmoid is applied here.
        gt_mask: target mask of the same shape with values in [0, 1].

    Returns:
        Scalar tensor: mean BCE over all elements + (1 - mean per-sample Dice).
    """
    # The fused logits form is numerically stable: sigmoid followed by
    # F.binary_cross_entropy saturates for large-magnitude logits (its log is
    # clamped at -100) and is rejected under autocast.
    bce = F.binary_cross_entropy_with_logits(pred_mask, gt_mask)

    probs = torch.sigmoid(pred_mask)
    dims = (1, 2, 3)  # reduce per sample
    intersection = (probs * gt_mask).sum(dim=dims)
    denominator = probs.sum(dim=dims) + gt_mask.sum(dim=dims)
    dice = 1 - ((2 * intersection + 1e-6) / (denominator + 1e-6)).mean()
    return bce + dice

def compute_iou(pred_mask, target_mask, threshold=0.5):
    """Intersection-over-union between predicted logits and a ground-truth mask.

    The prediction is binarised with ``sigmoid(logits) > threshold``. The target
    is already a mask (not logits), so it is binarised with ``target > threshold``
    directly. Counts are pooled over every element of the batch, so this is a
    single IoU for the whole batch, not a mean of per-sample IoUs.

    Args:
        pred_mask: logits tensor.
        target_mask: ground-truth tensor of the same shape, values in [0, 1].
        threshold: binarisation threshold applied to both.

    Returns:
        IoU as a Python float; 1.0 when both binarised masks are empty.
    """
    with torch.no_grad():
        pred_binary = torch.sigmoid(pred_mask) > threshold
        target_binary = target_mask > threshold
        intersection = (pred_binary & target_binary).sum().item()
        union = (pred_binary | target_binary).sum().item()
    return intersection / union if union > 0 else 1.0

### Train model

In [None]:
# --- Training configuration --------------------------------------------------

# Data: COCO 2017 instance annotations and the matching image directory.
annotation_file = "../dataset/annotations/instances_train2017.json"
img_dir = "../dataset/train2017"
img_size = 96  # side length (px) of the square crops fed to the model

# Optimisation
BATCH_SIZE = 8
LEARNING_RATE = 3e-4
NUM_EPOCHS = 1

# Outputs: one checkpoint per epoch is written here.
OUTPUT_DIR = "../outputs/"

def _train_one_epoch(model, loader, optimizer, scheduler, device, epoch):
    """Run one optimisation pass over ``loader``; return batch-size-weighted (mean loss, mean IoU)."""
    model.train()
    total_loss, total_iou, samples = 0.0, 0.0, 0

    for images, masks in tqdm(loader, desc=f"Epoch {epoch + 1} - Train"):
        images, masks = images.to(device), masks.to(device)

        preds = model(images)
        loss = bce_dice_loss(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        batch_iou = compute_iou(preds.detach(), masks)
        wandb.log({"batch_loss": batch_loss, "batch_mIoU": batch_iou, "epoch": epoch + 1})

        batch_size = images.size(0)
        total_loss += batch_loss * batch_size
        total_iou += batch_iou * batch_size
        samples += batch_size

    if samples == 0:
        return float("nan"), float("nan")
    return total_loss / samples, total_iou / samples


def _evaluate(model, loader, device, epoch):
    """Return batch-size-weighted (mean loss, mean IoU) over ``loader`` without gradients."""
    model.eval()
    val_loss, val_iou, val_samples = 0.0, 0.0, 0

    with torch.no_grad():
        # The dataset yields (image, mask) pairs, so unpack exactly two items.
        for images, masks in tqdm(loader, desc=f"Epoch {epoch + 1} - Val"):
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            batch_size = images.size(0)
            val_loss += bce_dice_loss(preds, masks).item() * batch_size
            val_iou += compute_iou(preds, masks) * batch_size
            val_samples += batch_size

    if val_samples == 0:  # e.g. a tiny dataset whose 5% validation split is empty
        return float("nan"), float("nan")
    return val_loss / val_samples, val_iou / val_samples


def train(seed=42):
    """Train MinimalSam on COCO object crops, logging to W&B and checkpointing every epoch.

    Reads the module-level config: annotation_file, img_dir, img_size,
    LEARNING_RATE, NUM_EPOCHS, BATCH_SIZE, OUTPUT_DIR.

    Args:
        seed: seeds torch's global RNG and the 95/5 train/val split so runs are repeatable.
    """
    torch.manual_seed(seed)
    os.makedirs(OUTPUT_DIR, exist_ok=True)  # torch.save fails if the directory is missing

    wandb.init(
        project="PicoSAM2-scratch",
        config={
            "img_size": img_size,
            "epochs": NUM_EPOCHS,
            "lr": LEARNING_RATE,
            "batch_size": BATCH_SIZE,
            "seed": seed,
        },
    )
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = MinimalSam().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        # Linear warm-up over the first 1000 optimizer steps. The +1 keeps the very
        # first update from being taken with a learning rate of exactly 0.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / 1000))

        dataset = MinimalSamDataset(annotation_file, img_dir, img_size)
        train_len = int(len(dataset) * 0.95)
        split_generator = torch.Generator().manual_seed(seed)
        train_ds, val_ds = random_split(dataset, [train_len, len(dataset) - train_len], generator=split_generator)
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

        for epoch in range(NUM_EPOCHS):
            train_loss, train_iou = _train_one_epoch(model, train_loader, optimizer, scheduler, device, epoch)
            wandb.log({"train_loss": train_loss, "train_mIoU": train_iou, "epoch": epoch + 1})

            val_loss, val_iou = _evaluate(model, val_loader, device, epoch)
            wandb.log({"val_loss": val_loss, "val_mIoU": val_iou, "epoch": epoch + 1})

            torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"PicoSAM2_epoch{epoch + 1}.pt"))
    finally:
        # Close the run even on KeyboardInterrupt/errors so it is not left dangling.
        wandb.finish()

loading annotations into memory...
Done (t=9.01s)
creating index...
index created!


In [6]:
train()

[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: You chose "Don't visualize my results"


loading annotations into memory...
Done (t=9.58s)
creating index...
index created!


Epoch 1 - Train:   0%|          | 118/100932 [00:28<6:39:49,  4.20it/s]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7e401dd59b80>> (for post_run_cell), with arguments args (<ExecutionResult object at 7e401f0dbda0, execution_count=6 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7e401f0dbc50, raw_cell="train()" transformed_cell="train()
" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/home/cyril/git/giga-sam/notebooks/minimal-sam.ipynb#X13sZmlsZQ%3D%3D> result=None>,),kwargs {}:


ConnectionResetError: Connection lost