
Using a mask prompt for boundary refinement #169

Open
Gpoxolcku opened this issue Apr 11, 2023 · 29 comments
Labels
how-to Request for help on how to do something specific

Comments

@Gpoxolcku

Hi, I have a roughly labeled dataset and I'm trying to feed its labels into SAM as a prompt. I want SAM to refine the segmentation labels and improve my dataset's quality. In my case I don't use any additional prompt artifacts like points or boxes (though it works pretty well with such prompts). According to the paper, a pure mask prompt should be supported as well. But the results I get are unreliable: the output mask mostly repeats the input one, sometimes even making it slightly worse. Is there a code snippet for building prompts from external masks?
Thanks in advance!

nikhilaravi added the how-to label on Apr 12, 2023
@maoyj1998

I've encountered the same problem. I found that the input mask size must be 256x256. When I resize my mask to this size, the output segmentation is a mess and makes no sense. Does anyone have a clue?

@rmokady

rmokady commented Apr 18, 2023

Encountered the same issue

@kampelmuehler

kampelmuehler commented Apr 19, 2023

I can confirm this in their Colab too: if you use only mask_input without any sparse guidance, the results do not change.

If you want to do the same with a mask from another source, you must at least zero-pad it to square dimensions first and then resize it to 256x256. It might also need some sort of normalization to work properly.

Output of the first stage, obtained using a single guidance point:
[image: first-stage output mask]

Output when feeding the logits of the first stage back as mask_input, without any additional queries:
[image: second-stage output mask]

@qraleq

qraleq commented Apr 19, 2023

I'm observing the same behavior.
@Gpoxolcku Did you manage to figure it out? I would like to use SAM for exactly the same use case you're mentioning.

@Gpoxolcku
Author

> I'm observing the same behavior.
> @Gpoxolcku Did you manage to figure it out? I would like to use SAM for exactly the same use case you're mentioning.

Not yet, unfortunately. The only workaround I've come up with is sampling points inside an instance's mask (with different sampling strategies, e.g. weighted by the distance transform) as the "positive" class and sampling more points outside the mask as the "negative" class, then combining those sparse inputs with the coarse binary mask into a prompt and feeding it to SAM. But that's still not good enough to refine the dataset.
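A minimal sketch of that sampling step, assuming a non-empty binary mask with some background around it. Positive points are drawn with probability proportional to a city-block distance transform (computed here with repeated 4-neighbour erosions, a numpy-only stand-in for e.g. OpenCV's Euclidean distance transform), so they tend to land away from the uncertain boundary; negatives are drawn uniformly outside the mask. The function names are hypothetical. Coordinates come out in (x, y) order, which is what `SamPredictor.predict(point_coords=..., point_labels=...)` expects.

```python
import numpy as np


def erode4(mask: np.ndarray) -> np.ndarray:
    """One step of 4-neighbour binary erosion; outside the image counts as background."""
    p = np.pad(mask, 1, mode="constant", constant_values=False)
    return p[1:-1, 1:-1] & p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]


def cityblock_distance(mask: np.ndarray) -> np.ndarray:
    """City-block distance from each foreground pixel to the nearest background pixel."""
    dist = np.zeros(mask.shape, dtype=np.int32)
    cur = mask.astype(bool)
    step = 0
    while cur.any():
        step += 1
        dist[cur] = step  # these pixels survived step - 1 erosions
        cur = erode4(cur)
    return dist


def sample_prompt_points(mask, n_pos=5, n_neg=5, seed=0):
    """Distance-weighted positives inside `mask`, uniform negatives outside it."""
    rng = np.random.default_rng(seed)
    mask = np.asarray(mask, dtype=bool)
    dist = cityblock_distance(mask).astype(np.float64)

    ys, xs = np.nonzero(mask)
    weights = dist[ys, xs] / dist[ys, xs].sum()  # deeper pixels are more likely
    pos = rng.choice(len(ys), size=min(n_pos, len(ys)), replace=False, p=weights)

    bg_ys, bg_xs = np.nonzero(~mask)
    neg = rng.choice(len(bg_ys), size=min(n_neg, len(bg_ys)), replace=False)

    coords = np.concatenate([
        np.stack([xs[pos], ys[pos]], axis=1),  # (x, y) order
        np.stack([bg_xs[neg], bg_ys[neg]], axis=1),
    ])
    labels = np.concatenate([np.ones(len(pos), dtype=int), np.zeros(len(neg), dtype=int)])
    return coords, labels
```

The result can be passed as `point_coords` / `point_labels` together with `mask_input` in a single `predict()` call; the number of points (`n_pos`, `n_neg`) is a free parameter.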

@kampelmuehler

@Gpoxolcku
Interesting that you had success with this; I found it didn't really work well with many query points. How many did you use?

@Davidyao99

Davidyao99 commented Apr 20, 2023

@kampelmuehler when you say pad the mask, is this so that the mask fits over the transformed input image to the model (which is also padded to a square)?

@tldrafael

tldrafael commented Apr 20, 2023

Using only the mask logits did not work for me; it produced nonsense results.

On the other hand, querying positive and negative points from the binary mask yielded better results, and they improve a lot if you do it iteratively: for a number of rounds, feed the predictor random samples of positive and negative points from the binary mask along with the best logits from the previous round (the paper mentions 11 iterations; see Appendix A, Training Algorithm).

The number of query points does not seem to matter much; it mainly affects how closely the output sticks to the original mask. On my data, although the iterative process improved the SAM output, it didn't refine fine details.
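A rough sketch of that loop, following the description in this comment (fresh random positive/negative points from the binary mask each round, plus the best low-res logits from the previous round as `mask_input`); it is not necessarily the exact training procedure from the paper. `predict` is assumed to be a callable with the keyword interface of `SamPredictor.predict` (for real use, `predictor.predict` after `predictor.set_image(image)`); `iterative_refine` and its defaults are hypothetical.

```python
import numpy as np


def iterative_refine(predict, ref_mask, n_iters=11, n_pos=3, n_neg=3, seed=0):
    """Repeatedly prompt a SAM-style predictor with random points from `ref_mask`,
    feeding the best low-res logits of each round back in as `mask_input`.

    `predict` is assumed to accept the keyword arguments of SamPredictor.predict
    and to return (masks, scores, low_res_logits).
    """
    rng = np.random.default_rng(seed)
    ref_mask = np.asarray(ref_mask, dtype=bool)
    fg_y, fg_x = np.nonzero(ref_mask)
    bg_y, bg_x = np.nonzero(~ref_mask)
    labels = np.array([1] * n_pos + [0] * n_neg)

    mask_input, best_mask = None, None
    for _ in range(n_iters):
        # fresh random positives inside and negatives outside the reference mask
        pi = rng.choice(len(fg_y), size=n_pos, replace=len(fg_y) < n_pos)
        ni = rng.choice(len(bg_y), size=n_neg, replace=len(bg_y) < n_neg)
        coords = np.concatenate([
            np.stack([fg_x[pi], fg_y[pi]], axis=1),  # (x, y) order
            np.stack([bg_x[ni], bg_y[ni]], axis=1),
        ])
        masks, scores, logits = predict(
            point_coords=coords,
            point_labels=labels,
            mask_input=mask_input,
            multimask_output=False,
        )
        best = int(np.argmax(scores))
        best_mask = masks[best]
        mask_input = logits[best][None, :, :]  # 1x256x256 for the next round
    return best_mask, mask_input
```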

@antoniocandito

I'd also be interested in boundary refinement using a mask prompt.

First, I tried using a bbox, and the object was delineated with good accuracy.
Second, I converted the bbox to a binary mask, but the model keeps generating no contours for this prompt.
I did resize the binary mask to 1x256x256.

Any help with this, please?

@kampelmuehler

@Davidyao99 yes, precisely

@GoGoPen

GoGoPen commented May 3, 2023

The mask prompt and the bbox prompt need to be provided together to generate a proper mask.
The mask should be rescaled by its longest side, the same way the image is preprocessed, and then padded to 256x256. Also note that the mask should be a binary mask.
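A numpy-only sketch of that preprocessing, assuming a 2-D mask at the original image resolution. The target size uses the same round-to-nearest arithmetic as `ResizeLongestSide.get_preprocess_shape` in segment_anything, and padding goes on the bottom and right, which is also where `Sam.preprocess` pads the image. The helper names and the `fill` parameter are hypothetical; nearest-neighbour resampling keeps the mask two-valued.

```python
import numpy as np


def preprocess_shape(h, w, long_side=256):
    """(new_h, new_w) after scaling the longest side to `long_side`, rounded to nearest,
    mirroring ResizeLongestSide.get_preprocess_shape in segment_anything."""
    scale = long_side / max(h, w)
    return int(h * scale + 0.5), int(w * scale + 0.5)


def mask_to_256(mask, fill=0.0):
    """Nearest-neighbour resize of a 2-D mask to longest side 256, then pad the
    bottom and right with `fill` up to 256x256. Returns a 1x256x256 float32 array."""
    mask = np.asarray(mask)
    h, w = mask.shape
    nh, nw = preprocess_shape(h, w)
    # nearest-neighbour source index for every target row / column
    rows = np.minimum((np.arange(nh) * h / nh).astype(int), h - 1)
    cols = np.minimum((np.arange(nw) * w / nw).astype(int), w - 1)
    out = np.full((256, 256), fill, dtype=np.float32)
    out[:nh, :nw] = mask[rows[:, None], cols[None, :]]
    return out[None]
```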

@cip8

cip8 commented May 15, 2023

@antoniocandito , did you manage to make the mask input work?

@GoGoPen that's really useful information, thanks for sharing!
I tried with binary masks, then I tried converting their values to logits (similar to what SAM returns), but with no success.
Could you help us out with an approximate example of how to convert from binary to the SAM-accepted format? 💙🖖

@markushunter

What is the proper way to pad the mask? Do you pad only the bottom and right, or do you center the mask in the target dimensions and pad the top, bottom, left, and right?

@markushunter

Not sure if this is the proper way to get the mask output, but this is what I discovered...

The SamPredictor.predict() docstring describes mask_input like this:

            mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.

Looking at a histogram of the values the model produces in the low_res_masks_np output of predict(), they are not a boolean mask; they are floats.

[image: histogram of low_res_masks_np values]

In Sam.py, mask_threshold is hardcoded to 0.0. Thresholding the low_res_masks_np output at 0 produced a reasonable mask.

Looking at a version of low_res_masks_np thresholded and scaled by 128, the model pads the mask on the bottom and right only.

By making a custom mask_input for SamPredictor.predict() where negative locations are -8 and positive locations are 1, with -8 padding on the bottom and right of the mask as necessary, subsequent runs of Segment Anything produced a mask. However, the mask still wasn't perfect.
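A sketch of the mask construction described above, assuming the binary mask has already been resized so that its longest side is 256. Background becomes -8, foreground +1, and the bottom/right padding up to 256x256 is filled with -8, as in the comment. `fake_logits` is a hypothetical name, and the -8/+1 values are empirical choices from this thread, not documented constants.

```python
import numpy as np


def fake_logits(binary_mask, neg=-8.0, pos=1.0):
    """Map a binary mask (longest side already 256) to logit-like values and
    pad it on the bottom/right with `neg` to the 1x256x256 shape SAM expects."""
    h, w = binary_mask.shape
    out = np.full((1, 256, 256), neg, dtype=np.float32)
    out[0, :h, :w] = np.where(np.asarray(binary_mask, dtype=bool), pos, neg)
    return out
```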

@cip8

cip8 commented May 16, 2023

Very good observations @markushunter! I'll try to add padding to the bottom-right and see if the results change.

I don't fully understand the part about assigning values of -8 / 1. Do you mean that the binary mask values (0/1) should be replaced with -8 and 1 because of the 0.0 mask_threshold in Sam.py?

Thanks for the info! 💙🖖

@markushunter

@cip8 Yes, instead of using 0 or 1 for the values in the mask, you need to represent the negative space with a number far below zero. Since SAM thresholds the mask at the floating-point value 0.0, having the negative space at 0.0 isn't good enough.

The histogram seemed to imply that negative space in the output mask has values around -8 to -10, so I just ran with -8.

@cip8

cip8 commented May 16, 2023

The docs say that logits from a previous run can be used for this mask_input.

These logits are indeed floats and look like this:

(1, 256, 256) / float32 
[[[-11.90418   -12.534466  -13.846361  ... -19.109943  -19.418356  -18.89853  ]
  [-12.359286  -15.481771  -14.45459   ... -20.847857  -19.149311  -20.422709 ]
  [-11.877727  -13.034173  -13.75271   ... -18.832436  -20.500711  -19.762798 ]
  ...
  [ -2.2596788   7.5174465   6.552994  ... -12.575503  -11.790027  -11.399892 ]
  [ -2.4307566   8.219517    6.204507  ... -10.520343  -11.538195  -10.084738 ]
  [ -2.2835727   5.5459557   5.4847026 ... -11.283685  -11.467551  -9.843957 ]]]

From what I understand, they represent probabilities for the mask. Do you know if that's accurate?
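These values are logits (log-odds), not probabilities: a sigmoid maps them into (0, 1), and because sigmoid(0) = 0.5, thresholding the logits at `mask_threshold = 0.0` is the same as thresholding the probabilities at 0.5. A small numpy illustration using a few of the values printed above (`logits_to_prob` is a hypothetical helper name):

```python
import numpy as np


def logits_to_prob(logits):
    """Elementwise sigmoid: log-odds -> probability."""
    return 1.0 / (1.0 + np.exp(-logits))


logits = np.array([-11.9, -2.26, 0.0, 5.55, 8.22], dtype=np.float32)
probs = logits_to_prob(logits)
binary = logits > 0.0  # same test SAM applies with mask_threshold = 0.0
```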

@cip8

cip8 commented May 17, 2023

Grayscale mask to SAM mask_input:

Based on the info discussed so far, this is how I implemented a conversion between grayscale and SAM's mask_input in my code:

import cv2
import numpy as np


class Segmentix:
    # [...] (other methods omitted)
    def resize_mask(
        self, ref_mask: np.ndarray, longest_side: int = 256
    ) -> tuple[np.ndarray, int, int]:
        """
        Resize an image to have its longest side equal to the specified value.

        Args:
            ref_mask (np.ndarray): The image to be resized.
            longest_side (int, optional): The length of the longest side after resizing. Default is 256.

        Returns:
            tuple[np.ndarray, int, int]: The resized image and its new height and width.
        """
        height, width = ref_mask.shape[:2]
        if height > width:
            new_height = longest_side
            new_width = int(width * (new_height / height))
        else:
            new_width = longest_side
            new_height = int(height * (new_width / width))

        return (
            cv2.resize(
                ref_mask, (new_width, new_height), interpolation=cv2.INTER_NEAREST
            ),
            new_height,
            new_width,
        )

    def pad_mask(
        self,
        ref_mask: np.ndarray,
        new_height: int,
        new_width: int,
        pad_all_sides: bool = False,
    ) -> np.ndarray:
        """
        Add padding to an image to make it square.

        Args:
            ref_mask (np.ndarray): The image to be padded.
            new_height (int): The height of the image after resizing.
            new_width (int): The width of the image after resizing.
            pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

        Returns:
            np.ndarray: The padded image.
        """
        pad_height = 256 - new_height
        pad_width = 256 - new_width
        if pad_all_sides:
            padding = (
                (pad_height // 2, pad_height - pad_height // 2),
                (pad_width // 2, pad_width - pad_width // 2),
            )
        else:
            padding = ((0, pad_height), (0, pad_width))

        # The padding value defaults to 0 when the np.pad mode is 'constant'.
        return np.pad(ref_mask, padding, mode="constant")

    def reference_to_sam_mask(
        self, ref_mask: np.ndarray, threshold: int = 127, pad_all_sides: bool = False
    ) -> np.ndarray:
        """
        Convert a grayscale mask to a -1/+1 mask, resize it to have its longest side equal to 256, and add padding to make it square.

        Args:
            ref_mask (np.ndarray): The grayscale mask to be processed.
            threshold (int, optional): The threshold value for the binarization. Default is 127.
            pad_all_sides (bool, optional): Whether to pad all sides of the image equally. If False, padding will be added to the bottom and right sides. Default is False.

        Returns:
            np.ndarray: The processed 1x256x256 mask (foreground 1, background -1, padding 0).
        """

        # Convert a grayscale mask to a two-valued mask.
        # Values over the threshold are set to 1, the rest to -1.
        # float32 is safe for cv2.resize; the int64 array that bool arithmetic
        # produces is not accepted by OpenCV.
        ref_mask = np.where(ref_mask > threshold, 1.0, -1.0).astype(np.float32)

        # Resize to have the longest side 256.
        resized_mask, new_height, new_width = self.resize_mask(ref_mask)

        # Add padding to make it square.
        square_mask = self.pad_mask(resized_mask, new_height, new_width, pad_all_sides)

        # Expand the mask's dimensions to 1xHxW (1x256x256).
        return np.expand_dims(square_mask, axis=0)

Usage example:

# Convert reference mask to SAM format & run predictor.
sam_mask: np.ndarray = self.reference_to_sam_mask(reference_mask)
masks, scores, logits = predictor.predict(
    multimask_output=False,
    box=np.array(ref_bbox),
    mask_input=sam_mask,
)

Experimental findings

  • Not sure if my implementation is 100% correct, but I can see some slight improvements in my final results.
  • Results are underwhelming: even though the provided mask_input is a good quality foreground / background separation mask, it seems to make little difference when used.
  • Some rough edges or small holes in the foreground are fixed for most of my test images. A small number look worse than when using a bounding box only. But again, my implementation could be inaccurate (any further help is highly appreciated🙏)
  • Masks cannot be used by themselves, they need a bounding box or input points to work. When used alone, SAM's resulting mask will be all-black (no foreground detected).
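Since the mask prompt seems to need a box or points alongside it, one option is to derive the box from the reference mask itself. A minimal sketch (the helper name is hypothetical), returning the tight box in the XYXY pixel format that `SamPredictor.predict(box=...)` takes:

```python
import numpy as np


def mask_to_box(mask):
    """Tight [x_min, y_min, x_max, y_max] box around the foreground of a 2-D mask."""
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        raise ValueError("mask has no foreground pixels")
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])
```

For example: `predictor.predict(box=mask_to_box(reference_mask), mask_input=sam_mask, multimask_output=False)`.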

Improvement suggestions

  • Expand mask selector's role: to better cover use-cases where SAM is used as part of a specialized segmentation pipeline, I would expand the quality and importance of the mask_input. This selector has (in theory) the potential to be much more powerful than bounding boxes or input points, but it seems to barely influence results in practice.

  • Improve documentation: the feature is not well documented. We still don't know for sure how it works or why its results seem so inconsistent.

@AkshitSharma1

Hi all,

I am trying to use SAM to refine a cell-segmentation foreground/background mask predicted by another model. I have tried the iterative approach and the grayscale-to-mask_input approach (as mentioned by @cip8 sir), but neither helped. Could someone please guide me? All my images are grayscale, of size (256, 256).

@qraleq

qraleq commented Jun 15, 2023

Hi all, has anyone managed to solve this problem efficiently?
I see marginal improvements using the approach @cip8 proposed.

@danigarciaoca

danigarciaoca commented Jun 17, 2023

Hi everyone,

As this demo notebook explains, the mask input is not really a mask: it is supposed to be the low-resolution mask output from a previous iteration (prediction):

> If available, a mask from a previous iteration can also be supplied to the model to aid in prediction

So for now, it seems it is not possible to prompt with an accurate mask (or at least not with good results).

Hope it helps!

@cip8

cip8 commented Jun 17, 2023

> As this demo notebook explains, the mask input is not really a mask: it is supposed to be the low-resolution mask output from a previous iteration (prediction):

Is there a law that says we are not allowed to "fake" these logits? 😃

So far in this conversation people came with different conclusions on how to replicate the behavior of these masks, where the threshold point is, etc. I don't think an answer that doesn't take into consideration the rest of the thread is helpful. Anyone can say "this can't be done", but that's not a real hacker mentality and rarely achieves anything.

@cip8

cip8 commented Jun 17, 2023

> Hi all, has anyone managed to solve this problem efficiently? I see marginal improvements using the approach @cip8 proposed.

I think the problem lies in the "weight" associated with this extra mask parameter.
My intuition is that the model doesn't put much importance on it, focusing instead on points and bounding boxes. Even when you supply it with "clean" logits from another run, the results change only marginally.

Maybe the next version will put more importance on this param, and maybe accept sizes greater than 256x256; this would make the model easier to include in existing image processing pipelines.

As a trick to bypass this I extract a grid of points from the mask and pass it to SAM instead - the results are much better than the minor changes provided by using the mask_input.
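The comment does not specify the grid spacing or whether negative points were used, so the following is only a sketch under those assumptions: a regular grid with a hypothetical `step`, keeping the grid points that land on foreground pixels and labelling them all positive. Background grid points could be kept as negatives in the same way.

```python
import numpy as np


def grid_points_in_mask(mask, step=32):
    """Regular grid of (x, y) points, keeping only those on foreground pixels."""
    mask = np.asarray(mask, dtype=bool)
    h, w = mask.shape
    ys, xs = np.meshgrid(
        np.arange(step // 2, h, step), np.arange(step // 2, w, step), indexing="ij"
    )
    keep = mask[ys, xs]
    coords = np.stack([xs[keep], ys[keep]], axis=1)  # (x, y) order for SAM
    labels = np.ones(len(coords), dtype=int)  # all positive
    return coords, labels
```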

I wish someone from Meta could clarify this for us 🙏

💙🖖

@danigarciaoca

danigarciaoca commented Jun 17, 2023

@cip8 I apologize if my response was not to your liking or not what you were looking for.

If you had taken the time to read the paper in depth and replicate the SAM architecture (not just read the docs...), you would understand the purpose of this mask_input better.

Of course it is possible to replicate it; it is just coding and imitating. My point was about replicating it with the desired results. Once again, if you read this thread with all its comments, you can see that nobody has gotten the "refinement" results that everyone (including me) was expecting. This is because mask_input is meant to be used in conjunction with a point (or box) prompt, not by itself.

PS: if it was so straightforward, Meta would have released it for mask prompting...

@cip8

cip8 commented Jun 17, 2023

> @cip8 I apologize if my response was not to your liking or not what you were looking for.

It's not about that @dankresio - every contribution is of course helpful and I appreciate your reply, truly! It just seemed to me that your answer didn't take into consideration what was discussed before & I'm also quite easily-triggered by "can't be done" type of answers 😅 I also apologize for my harsh reply 💙🖖

@shelper

shelper commented Oct 9, 2023

I wonder whether, instead of "extracting a grid of points from the mask and passing it to SAM", shrinking the mask prompt by a certain number of pixels (so that the sampled points don't fall outside the ground-truth mask) and sampling a few points on the edge of the shrunken mask would give better results. To me, it might constrain SAM in a way similar to a mask prompt.
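An untested sketch of that idea, assuming a single roughly convex object: the mask is eroded by `shrink` pixels (repeated 4-neighbour erosion in numpy, standing in for `cv2.erode` or `scipy.ndimage.binary_erosion`), then `n_points` pixels are picked evenly from the boundary of the eroded mask, ordered by angle around the centroid. All names and defaults are hypothetical, and whether this beats a plain grid of points is exactly the open question.

```python
import numpy as np


def erode4(mask):
    """One step of 4-neighbour binary erosion; outside the image counts as background."""
    p = np.pad(mask, 1, mode="constant", constant_values=False)
    return p[1:-1, 1:-1] & p[:-2, 1:-1] & p[2:, 1:-1] & p[1:-1, :-2] & p[1:-1, 2:]


def shrunk_edge_points(mask, shrink=3, n_points=8):
    """Erode `mask` by `shrink` pixels and pick `n_points` evenly spaced pixels on the
    boundary of the eroded mask. Returns (x, y) coords and all-positive labels."""
    m = np.asarray(mask, dtype=bool)
    for _ in range(shrink):
        m = erode4(m)
    edge = m & ~erode4(m)  # foreground pixels with a background 4-neighbour
    ys, xs = np.nonzero(edge)
    if len(xs) == 0:
        return np.empty((0, 2), dtype=int), np.empty(0, dtype=int)
    # order boundary pixels by angle around the centroid so spacing follows the outline
    order = np.argsort(np.arctan2(ys - ys.mean(), xs - xs.mean()))
    idx = np.linspace(0, len(order) - 1, num=min(n_points, len(order))).astype(int)
    pick = order[idx]
    coords = np.stack([xs[pick], ys[pick]], axis=1)
    return coords, np.ones(len(coords), dtype=int)
```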

I might test this and report back if I do. Others may post updates if anyone here has time to try it out.

@jerome223

It seems that converting self-made masks to logits was implemented in micro-sam for ellipse and polygon prompts, and it seems to work correctly.
Function: def _compute_logits_from_mask(mask, eps=1e-3):

Link:
https://github.com/computational-cell-analytics/micro-sam/blob/83997ff4a471cd2159fda4e26d1445f3be79eb08/micro_sam/prompt_based_segmentation.py#L375-L388
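For reference, a minimal sketch of the same idea, assuming `eps` is used to keep probabilities away from 0 and 1 before an inverse sigmoid (this is an illustration, not the micro-sam code; `mask_to_logits` is a hypothetical name). With `eps=1e-3`, foreground maps to log(999) ≈ 6.91 and background to about -6.91, a scale comparable to the -8 used earlier in the thread. The result still has to be resized and padded to 1x256x256 before use as `mask_input`.

```python
import numpy as np


def mask_to_logits(mask, eps=1e-3):
    """Binary mask -> logits: inverse sigmoid of (1 - eps) on foreground, eps on background."""
    p = np.where(np.asarray(mask, dtype=bool), 1.0 - eps, eps)
    return np.log(p / (1.0 - p)).astype(np.float32)
```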

@m13uz

m13uz commented Dec 28, 2023

You guys might want to check out these two repositories and try building some sort of pipeline that stitches everything together:

https://github.com/danielgatis/rembg/tree/main
https://github.com/hkchengrex/CascadePSP

@voplica

voplica commented May 2, 2024

Does anyone know how to convert Ultralytics (YOLO) masks to the mask_input for SAM?
I tried both approaches above (reusing _compute_logits_from_mask and reusing reference_to_sam_mask), but those seem to be meant for different mask types.
Please, if anyone has any clue here, share it and let me buy you a coffee 🙏
