In [1]:
import transformers
import torch
import torchvision
import cv2
from PIL import Image
import json
import numpy as np
from matplotlib import pyplot as plt
from glob import glob
import os

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
import pycocotools.mask as mask_util

import torch

def normalize(tensor, mean, std):
    """Standardize an image per channel while leaving fully black pixels at zero.

    Args:
        tensor: Image tensor whose first dimension is the channel axis.
        mean: 1-D tensor of per-channel means (reshaped to broadcast over the image).
        std: 1-D tensor of per-channel standard deviations.

    Returns:
        A tensor shaped like ``tensor`` holding ``(tensor - mean) / std`` wherever at
        least one channel of the pixel is positive, and 0 everywhere else.
    """
    standardized = (tensor - mean.view(-1, 1, 1)) / std.view(-1, 1, 1)

    # A pixel counts as "non-black" when any of its channels is strictly positive.
    non_black = (tensor > 0).any(dim=0, keepdim=True)

    return torch.where(non_black, standardized, torch.zeros_like(tensor))

class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO-style dataset that also decodes every annotation's segmentation into a mask.

    The dataset expects this layout under ``coco_folder``::

        coco_annotation/MotionNet_train.json   (or MotionNet_valid.json)
        train/origin/                          (or valid/origin/)

    Args:
        coco_folder: Root folder holding the annotation files and image directories.
        image_mean: Per-channel mean, stored as ``self.pixel_mean``.
        image_std: Per-channel std, stored as ``self.pixel_std``.
        train: Selects the training split when True, the validation split otherwise.

    NOTE(review): the statistics are currently only stored (normalization is disabled
    in ``__getitem__``). The default ``image_std`` of zeros would divide by zero if
    normalization were re-enabled without passing explicit values.
    """

    def __init__(self, coco_folder, image_mean=[0, 0, 0], image_std=[0, 0, 0], train=True):
        ann_file = os.path.join(coco_folder, "coco_annotation", "MotionNet_train.json" if train else "MotionNet_valid.json")
        image_dir = os.path.join(coco_folder, "train/origin" if train else "valid/origin")
        super().__init__(image_dir, ann_file)
        self.pixel_mean = torch.tensor(image_mean)
        self.pixel_std = torch.tensor(image_std)

    def __getitem__(self, idx):
        """Return ``(img, target, masks)`` for one image, or None if it has no masks.

        ``target`` is ``{'image_id': ..., 'annotations': [...]}`` and contains only the
        annotations that have a non-empty segmentation, so ``target['annotations'][i]``
        always corresponds to ``masks[i]``. Each mask is a boolean PIL image with the
        same size as ``img``.
        """
        img, annotations = super().__getitem__(idx)
        image_id = self.ids[idx]

        # No transform is passed to the parent class, so img is a PIL image here;
        # PIL reports size as (width, height) while pycocotools wants (height, width).
        width, height = img.size

        kept_annotations = []
        masks = []
        for ann in annotations:
            segm = ann.get("segmentation")
            if not segm:  # Skip annotations without a segmentation
                continue
            rle = mask_util.frPyObjects(segm, height, width)
            rle = mask_util.merge(rle)
            mask = mask_util.decode(rle).astype(bool)
            masks.append(Image.fromarray(mask))
            kept_annotations.append(ann)

        # Filter out instances without masks
        if not masks:
            return None

        target = {'image_id': image_id, 'annotations': kept_annotations}

        # Does not work well in practice
        #img = normalize(img, self.pixel_mean, self.pixel_std)
        return img, target, masks

In [3]:
# Per-channel statistics given on the 0-255 scale and rescaled to the 0-1 range.
image_mean = [channel / 255 for channel in (123.675, 116.28, 103.53)]
image_std = [channel / 255 for channel in (58.395, 57.12, 57.375)]

# Both splits share the same root folder and normalization statistics.
dataset_kwargs = dict(
    coco_folder='./partnetsim-1024-fixed-viewpoints/coco',
    image_mean=image_mean,
    image_std=image_std,
)
train_dataset = CocoDetection(train=True, **dataset_kwargs)
val_dataset = CocoDetection(train=False, **dataset_kwargs)

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
loading annotations into memory...
Done (t=0.06s)
creating index...
index created!


In [6]:
from torch.utils.data import DataLoader
from transformers import AutoProcessor
from torch.nn.utils.rnn import pad_sequence

# Load the preprocessing pipeline published with the "facebook/sam-vit-base" checkpoint.
processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
# Switch off the image processor's normalization flag.
# NOTE(review): presumably this skips the mean/std normalization so images reach the
# model un-normalized, matching the disabled normalize() call in the dataset; confirm.
processor.image_processor.do_normalize = False

import random
import numpy as np

def collate_fn(batch, max_jitter=70):
    """Turn a list of dataset samples into one box-prompted SAM training batch.

    Args:
        batch: List of ``(img, target, masks)`` samples; entries may be None for images
            without masks and are dropped.
        max_jitter: Upper bound (in pixels, inclusive) of the random amount by which
            each side of a ground-truth box is pushed outwards. Values of 70 and 20
            have been tried.

    Returns:
        The processor output with an extra ``"masks"`` entry of shape
        ``(batch, max_boxes, H, W)``, or None when every sample in ``batch`` was None.
        Images with fewer boxes than the largest one in the batch are padded with
        all-zero boxes and all-zero masks.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None

    imgs = [sample[0] for sample in samples]
    mask_lists = [sample[2] for sample in samples]

    augmented_bboxes = []
    for img, target, _ in samples:
        # Clamp against the size of the image the box belongs to, not the first image.
        img_width, img_height = img.size

        boxes = []
        for ann in target["annotations"]:
            # The dataset only builds masks for annotations that have a segmentation,
            # so boxes must be restricted the same way to stay aligned with the masks.
            if not ann.get("segmentation"):
                continue

            x_min, y_min, w, h = ann["bbox"]  # COCO boxes are [x, y, width, height]
            x_max = x_min + w
            y_max = y_min + h

            # Apply random variations to bbox coordinates
            x_min = max(0, x_min - random.randint(0, max_jitter))
            x_max = min(img_width, x_max + random.randint(0, max_jitter))
            y_min = max(0, y_min - random.randint(0, max_jitter))
            y_max = min(img_height, y_max + random.randint(0, max_jitter))

            # SAM box prompts are corner-format [x_min, y_min, x_max, y_max].
            boxes.append([x_min, y_min, x_max, y_max])

        augmented_bboxes.append(boxes)

    # Pad bboxes to have the same length within a batch
    max_bboxes = max(len(boxes) for boxes in augmented_bboxes)
    padded_bboxes = [boxes + [[0, 0, 0, 0]] * (max_bboxes - len(boxes)) for boxes in augmented_bboxes]

    inputs = processor(images=imgs, input_boxes=padded_bboxes, return_tensors="pt").to(torch.float32)

    # One (num_masks, H, W) tensor per image, zero-padded along the mask axis.
    masks = [torch.tensor(np.array(mask_list), dtype=torch.int) for mask_list in mask_lists]
    masks = pad_sequence(masks, batch_first=True).to(device)
    inputs["masks"] = masks

    return inputs


# Training batches are shuffled groups of four images; validation feeds one image at a time.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=1,
    collate_fn=collate_fn,
)

In [7]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it must be
            reshapeable to ``(height, width, 1)``.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: Use a random RGB color when True, otherwise a fixed blue.
            The alpha channel is 0.6 in both cases.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.append(rgb, 0.6)

    height, width = mask.shape[-2:]
    # Broadcasting (h, w, 1) * (1, 1, 4) colors the mask and leaves the background at zero.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

In [8]:
import pytorch_lightning as pl
from transformers import SamModel, SamProcessor
import torch
from torch.nn import functional as F

class FocalLoss(torch.nn.Module):
    """Binary focal loss on raw logits, averaged over every element.

    Args:
        alpha: Constant factor applied to every element of the loss.
        gamma: Exponent of the ``(1 - p_t)`` modulating term; larger values reduce
            the contribution of elements that are already classified well.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the mean focal loss between logits ``inputs`` and binary ``targets``."""
        logits = inputs.float()
        labels = targets.float()
        bce = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
        # exp(-BCE) is the probability the model assigns to the correct label.
        prob_correct = torch.exp(-bce)
        modulator = (1 - prob_correct) ** self.gamma
        return (self.alpha * modulator * bce).mean()


class DiceLoss(torch.nn.Module):
    """Soft Dice loss on raw logits, computed per mask and then averaged.

    Args:
        smooth: Small constant added to numerator and denominator so that an empty
            prediction/target pair does not divide by zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return ``1 - mean(dice)`` where dice is computed over the last two dims.

        Args:
            inputs: Logits of shape ``(..., H, W)``.
            targets: Binary masks with the same shape as ``inputs``.

        The reduction runs over the trailing (height, width) dimensions, so every
        leading index (batch, mask, ...) gets its own Dice score. For 3-D input this
        is identical to reducing over dims (1, 2); for ``(B, N, H, W)`` input it keeps
        the mask axis separate instead of mixing masks with image rows.
        """
        probs = torch.sigmoid(inputs)
        targets = targets.to(probs.dtype)
        spatial_dims = (-2, -1)
        intersection = torch.sum(probs * targets, dim=spatial_dims)
        denominator = torch.sum(probs, dim=spatial_dims) + torch.sum(targets, dim=spatial_dims)
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        return 1.0 - dice.mean()

class SAM(pl.LightningModule):
    """Lightning module that fine-tunes a Hugging Face ``SamModel`` with box prompts.

    The training objective is ``20 * focal + dice + iou`` where the focal and Dice
    terms compare upsampled mask logits with the ground-truth masks and the IoU term
    regresses the model's predicted IoU score towards the measured IoU.

    Args:
        lr: Learning rate for parameters whose name contains ``"decoder"``.
        lr_backbone: Learning rate for parameters whose name contains
            ``"shared_image_embedding"``.
        weight_decay: AdamW weight decay.
        processor: Processor object kept on the module as ``self.processor``.
        model_name: Checkpoint name or path given to ``SamModel.from_pretrained``.
    """

    def __init__(self, lr, lr_backbone, weight_decay, processor, model_name):
        super().__init__()
        self.model = SamModel.from_pretrained(model_name)
        # Output only one mask per image
        # NOTE(review): single-mask output is actually requested via the
        # multimask_output=False argument in common_step; confirm this attribute is read anywhere.
        self.model.multimask_output = False
        self.processor = processor
        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()

    def forward(self, pixel_values, bboxes):
        """Run the wrapped model on ``pixel_values`` prompted with ``bboxes``."""
        outputs = self.model(pixel_values=pixel_values, input_boxes=bboxes)
        return outputs

    def common_step(self, batch, batch_idx):
        """Compute the total loss and its components for one batch.

        Returns:
            ``(loss_total, loss_dict)``; ``(None, {})`` when ``batch`` is None
            (every sample in the batch was filtered out for having no masks).
        """
        if batch is None:
            # Case of empty batch (no masks)
            return None, {}

        gt_masks = batch.pop("masks").float()
        outputs = self.model(**batch, multimask_output=False)

        # Upsample the low-resolution mask logits to the ground-truth resolution.
        # assumes pred_masks is (batch, boxes, 1, h, w) with a single mask per box; TODO confirm
        mask_logits = F.interpolate(
            outputs["pred_masks"].squeeze(2),
            size=gt_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).float()

        loss_focal = self.focal_loss(mask_logits, gt_masks)
        loss_dice = self.dice_loss(mask_logits, gt_masks)

        # Measured IoU of every predicted mask, used as the regression target for the
        # model's IoU head. The predictions are logits, so the binarization threshold
        # is 0 (probability 0.5), and the union excludes the double-counted overlap.
        with torch.no_grad():
            pred_masks = (mask_logits > 0.0).float()
            intersection = torch.sum(pred_masks * gt_masks, dim=(-2, -1))
            union = (
                torch.sum(pred_masks, dim=(-2, -1))
                + torch.sum(gt_masks, dim=(-2, -1))
                - intersection
            )
            epsilon = 1e-7
            batch_iou = intersection / (union + epsilon)

        # One predicted score per (image, box) pair, flattened to match batch_iou.
        iou_scores = outputs.iou_scores.reshape(-1)
        loss_iou = F.mse_loss(iou_scores, batch_iou.reshape(-1), reduction='mean')

        loss_total = 20. * loss_focal + loss_dice + loss_iou

        return loss_total, {"focal_loss": loss_focal, "dice_loss": loss_dice, "iou_loss": loss_iou}

    def training_step(self, batch, batch_idx):
        """Log and return the training loss; return None for an empty batch."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        if loss is None:
            # Returning None makes Lightning skip the optimizer step for this batch.
            self.log("training_loss", 0)
            return None

        self.log("training_loss", loss)
        for k, v in loss_dict.items():
            self.log("train_" + k, v.item())
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss; return 0 for an empty batch."""
        loss, loss_dict = self.common_step(batch, batch_idx)
        if loss is None:
            self.log("validation_loss", 0)
            return 0

        self.log("validation_loss", loss)
        for k, v in loss_dict.items():
            self.log("validation_" + k, v.item())
        return loss

    def configure_optimizers(self):
        """Build an AdamW optimizer with two name-selected parameter groups."""
        # Parameters are selected purely by name; anything matching neither filter is
        # not handed to the optimizer and therefore is never updated.
        # NOTE(review): the second group matches only "shared_image_embedding", not the
        # vision encoder. Confirm that leaving the encoder out is intended.
        param_dicts = [
            {
                "params": [p for n, p in self.named_parameters() if "decoder" in n and p.requires_grad],
                "lr": self.lr,
            },
            {
                "params": [p for n, p in self.named_parameters() if "shared_image_embedding" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]

        optimizer = torch.optim.AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)
        return optimizer

    def train_dataloader(self):
        """Return the notebook-level training DataLoader."""
        return train_dataloader

    def val_dataloader(self):
        """Return the notebook-level validation DataLoader."""
        return val_dataloader


In [None]:
# Useful for GPUs with tensor cores: 'medium' lets float32 matrix multiplications use a
# faster, lower-precision internal format, trading some numerical accuracy for speed.
torch.set_float32_matmul_precision('medium')

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from datetime import datetime

wandb_logger = WandbLogger(project="SAM-pl-finetune")

# Timestamped run directory: checkpoints are written here, and its last path component
# is reused further down as the suffix of the Hub repository name.
dirpath = f"checkpoints/SAM/{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"

# Keep the hyperparameters in one place so exactly what is passed to the model is logged.
# (model.hparams would be empty because SAM never calls save_hyperparameters().)
hyperparams = {
    "lr": 1e-4,
    "lr_backbone": 1e-5,
    "weight_decay": 1e-4,
    "model_name": "facebook/sam-vit-base",
}

model = SAM(processor=processor, **hyperparams).to(device)

# Write checkpoints into the run directory instead of the logger's default location.
checkpoint_callback = ModelCheckpoint(dirpath=dirpath)

trainer = pl.Trainer(
    max_steps=3000,
    gradient_clip_val=0.2,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    accelerator=device,
    devices=1,
)

wandb_logger.log_hyperparams(hyperparams)

trainer.fit(model)

In [None]:
# Publish the fine-tuned weights and the processor to one Hub repository named after the run.
run_name = dirpath.split('/')[-1]
hub_repo_id = f"diliash/sam-{run_name}"
model.model.push_to_hub(hub_repo_id)
processor.push_to_hub(hub_repo_id)