
Add how to preprocess mask for finetuning with SAM #27361

Closed
rwood-97 opened this issue Nov 8, 2023 · 1 comment · Fixed by #27463
Labels
Feature request · Vision

Comments

@rwood-97
Contributor

rwood-97 commented Nov 8, 2023

Feature request

The SAM image processor takes images as input and resizes them so that the longest edge is 1024 (using default values). This is the size expected as input for the SAM model.
For inference this works fine, as only the images need resizing. For fine-tuning as per this tutorial, however, you need to resize both your images and your masks, because the SAM model produces pred_masks of size 256x256. If I don't resize my masks, I get "ground truth has different shape (torch.Size([2, 1, 768, 1024])) from input (torch.Size([2, 1, 256, 256]))" when trying to calculate the loss.
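For illustration, here is a minimal sketch of where the shapes collide (assuming the monai DiceCELoss used in the tutorial; any loss that requires matching shapes fails the same way):

import torch
from monai.losses import DiceCELoss

seg_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")

pred_masks = torch.zeros(2, 1, 256, 256)      # shape of the model's pred_masks
ground_truth = torch.zeros(2, 1, 768, 1024)   # unresized ground truth masks
# seg_loss(pred_masks, ground_truth)  # raises the shape mismatch error above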

To fix this, I've currently written resize and pad functions in my code:

import numpy as np
from PIL import Image

def resize_mask(image):
    # resize so the longest edge is 256, preserving aspect ratio
    longest_edge = 256

    # get new size
    w, h = image.size
    scale = longest_edge * 1.0 / max(h, w)
    new_h = int(h * scale + 0.5)
    new_w = int(w * scale + 0.5)

    resized_image = image.resize((new_w, new_h), resample=Image.Resampling.BILINEAR)
    return resized_image

def pad_mask(image):
    # pad bottom and right with zeros up to 256x256,
    # matching how the SAM image processor pads images
    pad_height = 256 - image.height
    pad_width = 256 - image.width

    padding = ((0, pad_height), (0, pad_width))
    padded_image = np.pad(image, padding, mode="constant")
    return padded_image

def process_mask(image):
    resized_mask = resize_mask(image)
    padded_mask = pad_mask(resized_mask)
    return padded_mask
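
A quick sanity check on a dummy mask (hypothetical size):

mask = Image.new("L", (1024, 768))  # dummy 1024x768 mask
out = process_mask(mask)
print(out.shape)  # (256, 256): resized to 256x192, then padded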

I have then added this to my definition of SAMDataset:

class SAMDataset(Dataset):
    def __init__(self, dataset, processor, transform = None):
        self.dataset = dataset
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        if self.transform:
            image = self.transform(item["pixel_values"])
        else:
            image = item["pixel_values"]
        
        # get bounding box prompt
        padded_mask = process_mask(item["label"])
        prompt = get_bounding_box(padded_mask)

        # prepare image and prompt for the model
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")

        # remove batch dimension which the processor adds by default
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # add ground truth segmentation
        inputs["ground_truth_mask"] = padded_mask

        return inputs
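
For context, this is how I then build the dataset and dataloader (sketch following the tutorial; assumes a dataset with "pixel_values" and "label" columns):

from torch.utils.data import DataLoader
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
train_dataset = SAMDataset(dataset=dataset, processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)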

This seems to work fine.

What I think would be good is to allow masks as input to the SAM image processor. For example, the Segformer image processor takes both images and masks as inputs and resizes both to the size expected by the Segformer model.
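For reference, the Segformer usage looks something like this (sketch; the checkpoint name is just an example):

from transformers import SegformerImageProcessor

seg_processor = SegformerImageProcessor.from_pretrained("nvidia/mit-b0")
inputs = seg_processor(images=image, segmentation_maps=mask, return_tensors="pt")
# inputs contains both "pixel_values" and "labels", resized together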

I have also seen there is a post_process_masks method in the SAM image processor, but I am unsure how to use it in the tutorial I'm following. If you think this is a better approach than what I am suggesting, please could you explain where I would add it in the code from the tutorial notebook.
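From the docs, I believe the call looks something like this (sketch; assumes outputs from SamModel and inputs from SamProcessor):

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
# returns masks upscaled back to each image's original size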

Motivation

Easier fine-tuning of the SAM model.

Your contribution

I could try to write a PR for this and/or make a PR to update the notebook instead.

@amyeroberts
Collaborator

Hi @rwood-97, thanks for raising this issue!

Agreed - being able to pass in the masks to the image processor would be ideal! Feel free to ping me on a PR for review if you'd like to open one :)
