This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Train DETR using small dataset (3k examples) #419

Open
micheleantonazzi opened this issue Jul 22, 2021 · 1 comment

micheleantonazzi commented Jul 22, 2021

❓ How to fine-tune DETR using a small dataset (3k examples)

Hi everyone,
I'm using DETR for my master's thesis, which concerns the development of a door recognizer.
This is my first experience with transformers, so I would like to share the results of my experiments and ask for advice on how to improve them.
Before using my own door dataset, I'm testing DETR on a public dataset used in other research works, called DeepDoors2.
It has 3k examples, divided into 3 different types of doors: closed, opened, and semi-opened.

Retrain DETR

I tried to retrain DETR using the default parameters specified in your main.py.
I loaded it using the following code:

import torch
from torch import nn

self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=False)
self.model.query_embed = nn.Embedding(10, 256)  # 10 object queries, hidden size 256
self.model.class_embed = nn.Linear(self.model.transformer.d_model, 4)  # 4 output logits

As you can see, I replaced the class_embed layer and reduced the number of object queries to 10.
As reported in other issues (#125, #9), retraining DETR with fewer than 10-15k samples is not recommended. Indeed, my results are not good, as shown in the following figure.
[Figure: losses]

Fine-tune class_embed and bbox_embed

In the next experiment, I have fine-tuned only the layers that return the predicted labels and the bounding boxes.
I created the model using the following code:

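self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# Freeze all pretrained weights
for p in self.model.parameters():
    p.requires_grad = False
# Newly created heads are trainable by default
self.model.class_embed = nn.Linear(256, 4)
self.model.bbox_embed = MLP(256, 256, 4, 3)  # MLP as defined in DETR's models/detr.py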

The obtained results are better than the previous ones, as shown by the figure.
[Figure: losses_2]

However, the model does not seem to learn any further, and the predicted bounding boxes are not good.
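To put this setup in context, the number of parameters that remain trainable can be counted by hand. The sketch below is plain arithmetic, assuming the MLP helper is a stack of num_layers linear layers with biases whose intermediate layers all have hidden_dim outputs; linear_params and mlp_params are made-up names used only for illustration.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


def mlp_params(input_dim, hidden_dim, output_dim, num_layers):
    """Parameter count of a stack of `num_layers` linear layers.

    Assumes every layer but the last has `hidden_dim` outputs,
    as in MLP(256, 256, 4, 3).
    """
    dims_in = [input_dim] + [hidden_dim] * (num_layers - 1)
    dims_out = [hidden_dim] * (num_layers - 1) + [output_dim]
    return sum(linear_params(i, o) for i, o in zip(dims_in, dims_out))


class_head = linear_params(256, 4)      # nn.Linear(256, 4)
bbox_head = mlp_params(256, 256, 4, 3)  # MLP(256, 256, 4, 3)
trainable = class_head + bbox_head

print(class_head, bbox_head, trainable)  # 1028 132612 133640
```

Under these assumptions only about 134k parameters receive gradients, while the backbone and the transformer stay fixed at their pretrained weights. This may be one reason the losses stop improving early.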

Fine-tune the entire trained model (work in progress)

The last experiment concerns the fine-tuning of the entire pretrained model. It is still in progress, and I will publish the results soon.
I loaded the model using the following code:

self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
self.model.class_embed = nn.Linear(256, 4)

I have changed the following parameters:

  • learning rate: 2e-5
  • backbone learning rate: 1e-5
  • bbox_loss_coef: 1
  • giou_loss_coef: 1
  • eos_coef: 0.5
  • set_cost_class: 1
  • set_cost_bbox: 1
  • set_cost_giou: 1
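As a sketch, the settings listed above could be passed to DETR's main.py through its command-line flags, assuming training is launched through that script. The flag names are the ones defined in main.py's argument parser; build_command and class_weights are hypothetical helpers written for this example. The second helper illustrates what eos_coef does, following the way DETR's SetCriterion weights the classification loss: every real class has weight 1 and the trailing "no object" class is down-weighted.

```python
import shlex

# Hyperparameters from the list above, keyed by the flag names
# used in DETR's main.py argument parser.
HPARAMS = {
    "lr": 2e-5,
    "lr_backbone": 1e-5,
    "bbox_loss_coef": 1,
    "giou_loss_coef": 1,
    "eos_coef": 0.5,
    "set_cost_class": 1,
    "set_cost_bbox": 1,
    "set_cost_giou": 1,
}


def build_command(hparams, batch_size=1, script="main.py"):
    """Return the argv list for a training run (hypothetical helper)."""
    argv = ["python", script, "--batch_size", str(batch_size)]
    for name, value in hparams.items():
        argv += [f"--{name}", str(value)]
    return argv


def class_weights(num_classes, eos_coef):
    """Per-class weights of the classification loss.

    Every real class has weight 1; the trailing "no object" class
    is down-weighted to eos_coef.
    """
    return [1.0] * num_classes + [eos_coef]


command = build_command(HPARAMS)
print(shlex.join(command))
print(class_weights(3, HPARAMS["eos_coef"]))  # [1.0, 1.0, 1.0, 0.5]
```

Dataset and output paths are deliberately left out, since they depend on the local setup. With three door classes, an eos_coef of 0.5 means a missed "no object" prediction costs half as much as a misclassified door.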

Questions

  1. In the first two experiments, could there be some error? Are their results plausible?
  2. The last experiment is the most promising one, as reported in #125. Are the parameters I have set correct? Do you have any other suggestions for obtaining better results?

NB: I performed all the experiments with the same data augmentation as DETR (crop, resize, flip, etc.) and with a batch size of 1.

Thank you so much in advance for any suggestions.

micheleantonazzi changed the title from "Train DETR using small dataset (3k esamples)" to "Train DETR using small dataset (3k examples)" on Jul 22, 2021
@NielsRogge

Hi,

Looking at this, I would definitely go for option 3, fine-tuning the entire pretrained model (this is also what's done with models like BERT). Typically, the classification head is trained jointly with the rest of the model to get the best performance.

Btw, I have a notebook that illustrates how to easily fine-tune DETR on a custom dataset (balloons): https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DETR/Fine_tuning_DetrForObjectDetection_on_custom_dataset_(balloon).ipynb

It's part of my Transformers-Tutorials, where I show how to easily fine-tune Transformer-based models on custom data using HuggingFace Transformers. I added DETR to that library myself, implementing both DetrForObjectDetection and DetrForPanopticSegmentation.
