
Finetuning #5

Open
kdcd opened this issue Apr 5, 2023 · 81 comments

Labels
enhancement New feature or request

Comments

@kdcd commented Apr 5, 2023

Are there any plans to release scripts for finetuning the model?

Also, you did such great work! Thank you very much!

@codybum commented Apr 5, 2023

Information on fine tuning would be great.

@austinmw commented Apr 6, 2023

+1, I'd love to be able to fine-tune to improve performance on extremely difficult tiny-object tasks, for example segmenting vehicles in geospatial images.

@penguingiraffe2

This thread is referenced as the answer for similar questions, but I don't think there is an answer here for transfer learning.

@jindameias

Looking forward to finetuning.

@BenSpex commented Apr 7, 2023

I would love to be able to fine tune the model for specific datasets as well.

@hu-po commented Apr 7, 2023

Do we wait for Meta to provide a training/fine-tuning script? Or should the open source hivemind write it?

@TimWGY commented Apr 8, 2023

Has anyone tried the idea of what may be called "point prompt engineering"? i.e. training a separate model that learns how to put positive prompt points and negative prompt points, such that these points prompt SAM to select target objects from a custom dataset.

Or we can just summarize strategies and best practices in terms of placing positive and negative prompt points/prompt boxes, similar to how GPT/DALLE users summarize the best ways to write prompts.

Wonder if this could be one way to fine-tune the SAM model when only a limited amount of annotations are available. Happy to discuss more if anyone wants to work together and try it out.
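For anyone exploring this, the released SamPredictor API already accepts positive and negative point prompts, so a learned (or hand-tuned) point-placement strategy only needs to produce coordinates and labels. A minimal sketch, with placeholder checkpoint path, image path, and coordinates:

    import numpy as np
    import cv2
    from segment_anything import sam_model_registry, SamPredictor

    # Checkpoint path, image path, and point coordinates below are placeholders.
    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
    predictor = SamPredictor(sam)

    image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # Label 1 = positive point (inside the target object),
    # label 0 = negative point (background / distractor to exclude).
    point_coords = np.array([[320, 240], [300, 260], [50, 50]])
    point_labels = np.array([1, 1, 0])

    masks, scores, logits = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=True,
    )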

@openvino-book

+1, looking forward to fine-tuning the SAM model on a custom dataset. :)

@hu-po commented Apr 8, 2023

I am attempting some fine tuning in this repo. Perhaps people can find use in it. The biggest thing I figured out is that you have to break up the Sam model into its components in order for there to be a gradient path for fine-tuning.
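For readers landing here, the "break SAM into its components" idea looks roughly like the sketch below: run the image encoder and prompt encoder under torch.no_grad() and call the mask decoder with gradients enabled, so only the decoder is trained. This is a simplified sketch with dummy tensors, not the exact code from that repo; in practice the image must be preprocessed to SAM's 1024x1024 input and the boxes scaled to match.

    import torch
    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # placeholder path
    optimizer = torch.optim.Adam(sam.mask_decoder.parameters(), lr=1e-5)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    # Dummy stand-ins for a preprocessed image, one xyxy prompt box, and a target mask.
    image = torch.randn(1, 3, 1024, 1024)
    box_torch = torch.tensor([[100.0, 100.0, 400.0, 400.0]])
    gt_low_res_mask = torch.zeros(1, 1, 256, 256)

    # Encoders run without gradients; only the mask decoder receives updates.
    with torch.no_grad():
        image_embedding = sam.image_encoder(image)
        sparse_emb, dense_emb = sam.prompt_encoder(points=None, boxes=box_torch, masks=None)

    low_res_masks, iou_pred = sam.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )
    loss = loss_fn(low_res_masks, gt_low_res_mask)
    loss.backward()
    optimizer.step()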

@dlod-openvino

After some messing around I have gotten preliminary fine-tuning to work on my fork. The code is still super messy and early, but perhaps people can find use in it. The biggest thing I figured out is that you have to break up the Sam model into its components in order for there to be a gradient path for fine-tuning.

Could you please recommend the minimum hardware configuration for fine-tuning the SAM? eg. 4090 x 4?

@hu-po commented Apr 8, 2023

Could you please recommend the minimum hardware configuration for fine-tuning the SAM? eg. 4090 x 4?

I can get the smallest pre-trained model (vit_b) with a batch size of 1 in <5GB of GPU memory, but I think fine tuning with those settings would take forever.

@codybum commented Apr 8, 2023

I have access to 4 x A100s w/ 80GB if you want me to test something.

@JunMa11 commented Apr 9, 2023

hi @hu-po ,

Thank you very much for sharing the fine-tuning code. Would it be possible for you to give guidance on how to prepare a customized dataset (e.g., data format and folder structure)?

@hu-po commented Apr 9, 2023

hi @hu-po ,

Thank you very much for sharing the fine-tuning code. Would it be possible for you to give guidance on how to prepare a customized dataset (e.g., data format and folder structure)?

Thank me when I get it to work 😭 this is more complicated than anticipated.

@shakesBeardZ

+1, interested in fine-tuning it for coral reef images.

@AMInnovationTeam

+1, interested in fine-tuning it for road crack detection.

@maskani-moh

+1 🙌

@ariannaravera

+1 interested in fine-tuning!

@harry-s-grewal

+1, I'd like to do some vehicle detection on low quality images!

@nikhilaravi added the enhancement (New feature or request) label Apr 12, 2023
@javiermcebrian

+1, interested in fine-tuning the prompt encoder or mask decoder!

@francescodisalvo05

+1! I would be interested in fine-tuning the model for medical image analysis

@travishsu

I'm curious whether it is possible to point out an unknown object that has not been learned (like anomaly detection) via a text prompt if I fine-tune with custom data.

@imandrealombardo

+1!

@satpalsr

CC: @ericmintun @nikhilaravi

@Kenneth-X

@hu-po
Hi, nice work on sharing the finetuning script. Is "FragmentDataset" the dataset officially released at https://segment-anything.com/dataset/index.html?

@DuongTSon

To reproduce the training (and eventually finetuning) you could also use:

https://github.com/UX-Decoder/Semantic-SAM

Thanks for sharing.

From the code in the link, do you mean that in order to further improve the fine-tuned SAM we need to retrain the model on the SAM dataset with an architecture close to the original SAM? Currently, is there any way to fine-tune the encoder in the original SAM?

@nahidalam

The blog that @alex-encord shared works fine if you have 1 bounding box per image. But in reality you will have multiple bounding boxes in an image. The code below fails if you have multiple bounding boxes:

with torch.no_grad():
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )

Am I missing something?

@nhhung1810

Can you explain how it fails, @nahidalam? What is the error or stack trace?

@nahidalam

@nhhung1810 @alex-encord this is the error I get

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 19 for tensor number 1 in the list.

I get this error because I have 19 bounding boxes in my first image

Full stack trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[38], line 27
     24   box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
     25   box_torch = box_torch[None, :]
---> 27   sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
     28       points=None,
     29       boxes=box_torch,
     30       masks=None,
     31   )
     32 low_res_masks, iou_predictions = sam_model.mask_decoder(
     33   image_embeddings=image_embedding,
     34   image_pe=sam_model.prompt_encoder.get_dense_pe(),
   (...)
     37   multimask_output=False,
     38 )
     40 upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

File ~/SAM-experiments/env/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/SAM-experiments/env/lib/python3.8/site-packages/segment_anything/modeling/prompt_encoder.py:159, in PromptEncoder.forward(self, points, boxes, masks)
    157 if boxes is not None:
    158     box_embeddings = self._embed_boxes(boxes)
--> 159     sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
    161 if masks is not None:
    162     dense_embeddings = self._embed_masks(masks)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 19 for tensor number 1 in the list.

@nhhung1810 commented Sep 7, 2023

Hey @nahidalam
Maybe the problem is box_torch = box_torch[None, :]. I assume that your input box_torch has a batch size of 19, but then you run box_torch = box_torch[None, :], which makes the batch size 1 (since box_torch.shape becomes [1, ...]).

Check this line and this line; you will see that your batch size becomes 1.

I suggest removing that box_torch = box_torch[None, :] line and running again. If there is still a problem, try putting a breakpoint and using the debugger to see how the shapes change; it can help a lot.
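To spell that out, here is a minimal sketch reusing sam_model, image_embedding, and device from the snippet above (box_list is a placeholder for the N ground-truth boxes of one image, already scaled to the 1024x1024 model input space): pass the boxes as an (N, 4) tensor, and the prompt encoder returns one embedding per box, which the mask decoder broadcasts against the single image embedding.

    import torch

    # box_list: list/array of N xyxy boxes for one image, in the 1024x1024 input space.
    boxes = torch.as_tensor(box_list, dtype=torch.float, device=device)  # (N, 4)

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=None,
            boxes=boxes,   # note: no boxes[None, :] here
            masks=None,
        )
    # sparse_embeddings: (N, 2, 256), dense_embeddings: (N, 256, 64, 64)

    low_res_masks, iou_predictions = sam_model.mask_decoder(
        image_embeddings=image_embedding,                    # (1, 256, 64, 64), single image
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    # low_res_masks: (N, 1, 256, 256), one predicted mask per box.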

@wm-Githuber

The blog that @alex-encord shared works fine if you have 1 bounding box per image. But in reality you will have multiple bounding boxes in an image. The code below fails if you have multiple bounding boxes:

with torch.no_grad():
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )

Am I missing something?

Hi @nahidalam,
I have the same question as you, i.e., a real image will have multiple object bounding boxes. Should we pass the multiple boxes together? Thanks in advance.

@yogendra-yatnalkar

Question: Can we fine-tune the decoder to infer without a prompt?

Possible solution (just thinking out loud): fine-tune without a prompt or with a fixed static prompt, and hence infer without a prompt using the fine-tuned model?


  • Every tutorial I saw has one thing in common: fine-tuning the SAM model for a specific object also requires a prompt as input. During fine-tuning this is easier because we can take the bounding box of the ground-truth mask and provide it as the prompt input.

  • But during inference, we don't have any such bounding box already available.

Hence, are there any experiments where the prompt was set to some pre-fixed static value or skipped completely, and the SAM decoder was fine-tuned to segment only the object of interest?

The end result I want is a solution where, during inference, I do not need to pass any prompt as input.

@bertinma

Question: Can we fine-tune the decoder to infer without a prompt?

Possible solution (just thinking out loud): fine-tune without a prompt or with a fixed static prompt, and hence infer without a prompt using the fine-tuned model?

  • Every tutorial I saw has one thing in common: fine-tuning the SAM model for a specific object also requires a prompt as input. During fine-tuning this is easier because we can take the bounding box of the ground-truth mask and provide it as the prompt input.
  • But during inference, we don't have any such bounding box already available.

Hence, are there any experiments where the prompt was set to some pre-fixed static value or skipped completely, and the SAM decoder was fine-tuned to segment only the object of interest?

The end result I want is a solution where, during inference, I do not need to pass any prompt as input.

That is what I did: finetune without any prompt, because I do not have one at inference time.
I also have to predict specific classes, so I recreated a decoder with n_classes masks as output. That means I train the decoder from scratch, which takes very long to converge, and it does not seem to be the best model for this task.
IMO, SAM is powerful for automatic segmentation in an annotation tool, or for trimming objects in Photoshop, etc.

@yogendra-yatnalkar

@bertinma Thanks for your inputs. I was actually planning to try this out over the weekend.

If possible, can you please share some insights on how you fine-tuned the decoder? As in:

  1. Did you give a static point as input? (For example, the center of the image or something like that.)
  2. Or did you change the decoder architecture and remove all parts where the prompt encoding was being used?
  3. Or did you write a completely custom decoder which upsamples the 256x64x64 encoder embedding to the original image size using ConvTranspose2d or something else?

Thanks in advance, and thanks again for sharing your initial inputs.

@bertinma

@yogendra-yatnalkar I do not give any prompt; all my prompt values (points/masks/bboxes) are None.
I took a deep look to give you the best answer and discovered I'm wrong: I have to pass a list of points to the embedding computation function.

@yogendra-yatnalkar

Yes, as you said, the SAM model takes N point prompts and converts them to shape Nx256. Later, the encoder embedding and prompt embedding go through 2 cross-attention blocks in the decoder.

So, since your model did not converge but fine-tuning did not throw an error either, guessing from my current shallow knowledge:

  • when we do not pass any prompt, it might have selected a default prompt (usually a grid of points) and trained on it?
  • Not sure how this will go, but I will try to write a custom decoder and train it from scratch (directly from the encoder embedding).

Thanks again for your insights, this will surely help me.

@TKPhuong

@yogendra-yatnalkar I plan to do something similar. Could you share your results with me? Did you succeed? If you did, may I ask how you implemented it? Thank you very much.

@yogendra-yatnalkar

Hi @TKPhuong, yes, I tried it on a very small sample and the results are very promising. So far I have only tried it on some data which I cannot share. I will replicate it on a Kaggle or open-source dataset over the weekend and share that here.

@AdrienneBerghAMAT

Hi @yogendra-yatnalkar, interested to hear if you ended up replicating it with an open-source dataset? Thanks.

@yogendra-yatnalkar commented Oct 31, 2023

Hi @TKPhuong @bertinma @AdrienneBerghAMAT, sorry for the late response. I had planned to open-source this notebook along with a small blog post, but I did not find time for that, so I am sharing the Kaggle notebook right away:


How does SAM work (high level):

  • SAM encoder --> ViT + neck module (consisting of 2 Conv2d layers used to downsample the channels of the ViT output)
  • Input: 1024x1024x3
  • Output of the encoder: 256x64x64
  • This output goes into the decoder together with the prompt input and generates the final masks

What I tried with the model:

  • Removed the decoder
  • Froze the ViT part of the encoder and un-froze the Conv2d neck
  • Added a custom decoder made of multiple blocks of ConvTranspose2d + LayerNorm2d + ReLU + Dropout --> added 4 such blocks
  • The input to the decoder is of shape 256x64x64 and the output is of shape 1024x1024x1

Training:

  • I trained this SAM + custom-decoder model on an open Kaggle dataset for binary segmentation
  • The dataset has 1620 images.
  • To prove SAM's capability, I trained this model on only 135 images, i.e. around 8.3% of the total data, for just 11 epochs

Results:

  • The results are very promising; below is the output on a completely random test-set image
  • [image: left-most is the ground truth, middle is the model prediction, right-most is the input]

I really don't think that any other model trained for just 11 epochs on 135 images would generalize this well on a test-set image.

Future:

  • If this encoder is combined with a model like Mask2Former, then according to my current understanding it would show very good results.
  • The decoder is a very basic model currently, but it proves that this can act as a good starting point for further experimentation.

Will convert this long text into a blog post soon. Please let me know your thoughts on this experiment as well.
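For readers who want the gist in code, below is a rough, hypothetical sketch of the architecture described above (frozen ViT, trainable conv neck, custom ConvTranspose2d head); the exact block sizes and training details in the Kaggle notebook may differ.

    import torch
    import torch.nn as nn
    from segment_anything import sam_model_registry
    from segment_anything.modeling.common import LayerNorm2d

    class PromptlessSAM(nn.Module):
        """SAM image encoder + a small ConvTranspose2d head; no prompt encoder or SAM decoder."""

        def __init__(self, checkpoint="sam_vit_b_01ec64.pth"):  # placeholder checkpoint path
            super().__init__()
            sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
            self.encoder = sam.image_encoder              # outputs (B, 256, 64, 64)

            # Freeze the ViT blocks, keep the conv "neck" trainable.
            for p in self.encoder.parameters():
                p.requires_grad = False
            for p in self.encoder.neck.parameters():
                p.requires_grad = True

            # Four upsampling blocks: 64 -> 128 -> 256 -> 512 -> 1024 spatial resolution.
            chans = [256, 128, 64, 32, 16]
            blocks = []
            for c_in, c_out in zip(chans[:-1], chans[1:]):
                blocks += [
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2),
                    LayerNorm2d(c_out),
                    nn.ReLU(inplace=True),
                    nn.Dropout2d(0.1),
                ]
            blocks.append(nn.Conv2d(chans[-1], 1, kernel_size=1))  # binary mask logits
            self.decoder = nn.Sequential(*blocks)

        def forward(self, x):                              # x: (B, 3, 1024, 1024)
            return self.decoder(self.encoder(x))           # (B, 1, 1024, 1024)

Training this against binary ground-truth masks (e.g. with BCEWithLogitsLoss) then needs no prompt at all at inference time.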

@crapthings commented Dec 6, 2023

1k images
--batch_size_train 16
2 x H100 SXM 80GB
sam_hq hq_huge

On-Demand $4.69/hr

This run is without the 44k images they mention in the README.

@sudo-install-MW

Yes, as you said, the SAM model takes N point prompts and converts them to shape Nx256. Later, the encoder embedding and prompt embedding go through 2 cross-attention blocks in the decoder.

So, since your model did not converge but fine-tuning did not throw an error either, guessing from my current shallow knowledge:

  • when we do not pass any prompt, it might have selected a default prompt (usually a grid of points) and trained on it?
  • Not sure how this will go, but I will try to write a custom decoder and train it from scratch (directly from the encoder embedding).

Thanks again for your insights, this will surely help me.

This is very useful, thanks @yogendra-yatnalkar

@jez-moxmo

Has anyone tried the idea of what may be called "point prompt engineering"? i.e. training a separate model that learns how to put positive prompt points and negative prompt points, such that these points prompt SAM to select target objects from a custom dataset.

Or we can just summarize strategies and best practices in terms of placing positive and negative prompt points/prompt boxes, similar to how GPT/DALLE users summarize the best ways to write prompts.

Wonder if this could be one way to fine-tune the SAM model when only a limited amount of annotations are available. Happy to discuss more if anyone wants to work together and try it out.

https://github.com/vignywang/SAMFeat

@Jap8nted commented Jan 2, 2024

The blog that @alex-encord shared works fine if you have 1 bounding box per image. But in reality you will have multiple bounding boxes in an image. The code below fails if you have multiple bounding boxes:

with torch.no_grad():
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )

Am I missing something?

Hi @nahidalam, I have the same question as you, i.e., a real image will have multiple object bounding boxes. Should we pass the multiple boxes together? Thanks in advance.

Hi @wm-Githuber, @nahidalam,

Did you find an answer for this? Can one fine-tune with a mask containing multiple objects and multiple boxes, or how should one do this?

I created a training entry for each mask with one box and it seems to work fine, but I wonder if using multiple boxes/objects is possible, because it would make training faster. So far, nobody has done this.

@yogendra-yatnalkar


Hi readers, just 2 small updates from my previous message.

  1. Update 1:
    - My code had a small issue: my logic for un-freezing the 2 Conv2d blocks was not actually being applied. Some kind person on Kaggle corrected it, and the performance has increased slightly.
    - Earlier the IoU score on my test set was around 89%; now it is 91%+, and the edges of the segmentation maps have improved.

  2. Update 2:
    - I have added a visual explanation of what I did (PS: by no means is my way of doing this unique):
    SAM-promptless-task-specific-finetuning

The Kaggle notebook: https://www.kaggle.com/code/yogendrayatnalkar/promptless-taskspecific-finetuning-of-metaai-sam

@Raspberry-beans

@yogendra-yatnalkar Thanks for the nice work.

Can you specify the GPU memory usage of fine-tuning either SAM's original mask decoder or your custom decoder without using prompts?

I will be fine-tuning only the SAM mask decoder on my few-shot medical images (maximum 50), but I have only 8GB of GPU memory allocated from my university. Hence I am unsure whether 8GB is enough to fine-tune only the mask decoder on a few medical images.

Thanks!
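One common way to keep memory low for decoder-only fine-tuning (and roughly what the earlier comments in this thread describe) is to freeze everything except the mask decoder and precompute the image embeddings once under torch.no_grad(). A hedged sketch, where preprocessed_images is a placeholder for your own (1, 3, 1024, 1024) batches:

    import torch
    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").cuda()  # placeholder path

    # Train only the mask decoder; the heavy image encoder never needs gradients.
    for name, param in sam.named_parameters():
        param.requires_grad = name.startswith("mask_decoder")

    # Precompute and cache image embeddings once (this can even be done offline),
    # so training only keeps the small decoder and the 256x64x64 embeddings on the GPU.
    cached_embeddings = []
    with torch.no_grad():
        for image in preprocessed_images:          # placeholder iterable of (1, 3, 1024, 1024) tensors
            cached_embeddings.append(sam.image_encoder(image.cuda()).cpu())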

@williamhoole

Hi, I have fine-tuned a SAM model on a custom dataset; however, I want to use the automatic mask generator. When loading my fine-tuned model I get the following error: AttributeError: 'SamModel' object has no attribute 'image_encoder'

Does anyone know why this is, and are there any solutions?

@RMNT commented Mar 20, 2024

@williamhoole I do not know the answer, but I encounter the same error.

@RMNT commented Mar 20, 2024

I'm trying to do fine-tuning on a semantic segmentation task without bboxes, only labels. I resized the image and label to 1024x1024 and use the following code:

    # Assign data
    img, label, o_img_size, n_img_size = batch['image'], batch['label'], batch['original_image_size'], batch['image_size']
    
    # Map to variables
    img = Variable(img)
    label = Variable(label)

    # Get embeddings
    with torch.no_grad():
        image_embedding = self.model.image_encoder(img)
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=None,boxes=None,masks=label.float())
    
    # Get predictions
    low_res_masks, iou_predictions = self.model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=self.model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
        )
    upscaled_masks = self.model.postprocess_masks(low_res_masks, n_img_size, o_img_size)


I modify the decoder head like this:

    net = sam_model_registry[args.sam_model_type](checkpoint=args.ckpt)
    d = net.mask_decoder
    net.mask_decoder = MaskDecoder(transformer_dim=d.transformer_dim, transformer=d.transformer, num_multimask_outputs=args.num_classes)

    return net

However, I got an error from the decoder embedding fusion:

    from segment_anything/modeling/mask_decoder.py", line 127, in predict_masks
    src = src + dense_prompt_embeddings
    RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3

In this case, my sparse_embeddings has shape [B, 0, 256] and dense_embeddings has shape [B, 256, 256, 256]. Does anyone have ideas?

Where do you import the MaskDecoder from? Because I get this error when using from segment_anything.modeling import MaskDecoder:

Traceback (most recent call last):
  File "/home/raminta/PROJECTS/Raminta/farmvibes/main.py", line 94, in <module>
    transformer_dim=d.transformer_dim,
  File "/home/raminta/anaconda3/envs/farmvibes/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'SamMaskDecoder' object has no attribute 'transformer_dim'
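For context on the 64-vs-256 mismatch above: in the original segment_anything PromptEncoder, mask prompts are expected at 256x256 (4x the 64x64 image-embedding size) and are downscaled to 64x64 internally, so feeding a 1024x1024 label produces 256x256 dense embeddings that cannot be added to the 64x64 image features. A hedged sketch of one possible fix, resizing the label before using it as the mask prompt (variable names follow the snippet above):

    import torch.nn.functional as F

    # The prompt encoder's expected mask input size is 4 * image_embedding_size = 256x256.
    mask_prompt = F.interpolate(label.float(), size=(256, 256), mode="nearest")  # (B, 1, 256, 256)

    with torch.no_grad():
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=None,
            boxes=None,
            masks=mask_prompt,  # dense_embeddings becomes (B, 256, 64, 64), matching the image features
        )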


@williamhoole

I am currently out of office, but I figured out what caused my error.

I fine-tuned my model using the Segment Anything model from Hugging Face. The issue is that the Hugging Face SAM model's submodules are named differently, so when fine-tuning with the Hugging Face SAM model and weights you cannot use Meta's SAM automatic mask generator.

I resolved my issue by using Hugging Face's automatic mask generation pipeline.

Beginning of next week I am back in the office and can take a closer look at your issue. Hopefully this helps you further for now.
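As the comment above notes, the Hugging Face SamModel names its submodules differently, which is why Meta's SamAutomaticMaskGenerator cannot wrap it directly. A minimal sketch of the Hugging Face mask-generation pipeline route mentioned above (the model path is a placeholder):

    from PIL import Image
    from transformers import pipeline

    # The fine-tuned checkpoint path / hub id below is a placeholder.
    generator = pipeline(
        "mask-generation",
        model="path/to/your-finetuned-sam",   # or a hub model such as "facebook/sam-vit-base"
        device=0,
    )

    image = Image.open("example.jpg").convert("RGB")
    outputs = generator(image, points_per_batch=64)
    masks = outputs["masks"]   # list of binary masks covering the image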

@malotaibi44

I'm trying to do fine-tuning on a semantic segmentation task without bboxes, only labels. [...] RuntimeError: The size of tensor a (64) must match the size of tensor b (256) at non-singleton dimension 3. In this case, my sparse_embeddings has shape [B, 0, 256] and dense_embeddings has shape [B, 256, 256, 256]. Does anyone have ideas?

I squeezed the mask variable to have size (num_batches, 1, mask_size, mask_size) rather than (num_batches, 1, 1, mask_size, mask_size), and it works.
