This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Custom dataset training #45

Closed
ghost opened this issue Jun 3, 2020 · 3 comments
Labels
question Further information is requested

Comments


ghost commented Jun 3, 2020

❓ How to do something using DETR

I am trying to fine-tune the pretrained ResNet-50 DETR model with one more class on top of the COCO classes. I loaded the pretrained model like this:

model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)

Then I freeze all existing parameters and replace class_embed and bbox_embed with new heads, so that only those heads are trainable:

for param in model.parameters():
    param.requires_grad = False

classifier_class = nn.Sequential(nn.Linear(256,128), 
                                 nn.ReLU(), 
                                 nn.Dropout(p=0.2), 
                                 nn.Linear(128,93), 
                                 #nn.LogSoftmax(dim=1)
                                 )

model.class_embed = classifier_class

classifier_bbox = nn.Sequential(nn.Linear(256,256), 
                                nn.ReLU(), 
                                nn.Dropout(p=0.2), 
                                nn.Linear(256,256),
                                nn.ReLU(),
                                nn.Dropout(p=0.2),
                                nn.Linear(256,4),
                                nn.Sigmoid()
                                )

# Assumption: the snippet as posted omits this assignment, which the optimizer below relies on.
model.bbox_embed = classifier_bbox

I am using build_model to get my criterion and postprocessors:

dummy, criterion, postprocessors = build_model(data_args)

Optimizer:

optimizer = torch.optim.Adam([{'params': model.class_embed.parameters()}, 
                             {'params': model.bbox_embed.parameters()}], 
                             lr=data_args.lr, weight_decay=data_args.weight_decay)
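
The freeze-then-replace pattern described above can be checked in isolation. The sketch below is not the real DETR model: `TinyDetector` is a hypothetical stand-in with a `trunk` and the two heads, and the `lr` and `weight_decay` values are placeholders. It only illustrates that modules assigned after the freeze loop are trainable by default, so the optimizer can be built from the parameters that still require gradients.

```python
import torch
from torch import nn


class TinyDetector(nn.Module):
    """Hypothetical stand-in for DETR: a trunk plus class_embed and bbox_embed heads."""

    def __init__(self, hidden_dim=256, num_logits=92):
        super().__init__()
        self.trunk = nn.Linear(hidden_dim, hidden_dim)
        self.class_embed = nn.Linear(hidden_dim, num_logits)
        self.bbox_embed = nn.Linear(hidden_dim, 4)

    def forward(self, x):
        h = self.trunk(x)
        return {"pred_logits": self.class_embed(h), "pred_boxes": self.bbox_embed(h)}


model = TinyDetector()

# Freeze every parameter that exists at this point.
for param in model.parameters():
    param.requires_grad = False

# Heads created after the freeze have requires_grad=True by default.
model.class_embed = nn.Linear(256, 93)
model.bbox_embed = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 4))

# Optimize only what is still trainable (placeholder hyperparameters).
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4, weight_decay=1e-4)

frozen_names = [n for n, p in model.named_parameters() if not p.requires_grad]
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
print("frozen:", frozen_names)
print("trainable:", trainable_names)
```

Listing the parameter names this way is a quick sanity check that a replaced head was actually attached to the model before training starts.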

Now I am loading only the 'skyscraper' class with the data loader.

Unfortunately, I am getting this error:

RuntimeError: weight tensor should be defined either for all or no classes at /pytorch/aten/src/THCUNN/generic/SpatialClassNLLCriterion.cu:27

Here is the entire code:
https://colab.research.google.com/drive/1L3PLEiOVICgmjyK6JIDjEBFmraVEQYhz?usp=sharing

Contributor

fmassa commented Jun 3, 2020

Hi,

The error occurs because the criterion is created with a fixed number of classes, which no longer matches the output size of your new class_embed; see

self.num_classes = num_classes

and detr/models/detr.py, lines 95 to 97 in b7b62c0:

empty_weight = torch.ones(self.num_classes + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer('empty_weight', empty_weight)

You can probably just overwrite those values with the number of classes you want, and this error should disappear.
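
A minimal sketch of that overwrite, assuming the criterion stores `num_classes`, `eos_coef`, and an `empty_weight` buffer exactly as in the lines quoted above. `TinyCriterion` and `set_num_classes` are hypothetical names used for illustration, not DETR's actual criterion class, and `eos_coef=0.1` and the tensor shapes are assumed values. It first reproduces the size mismatch with a 93-way head, then patches the criterion.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyCriterion(nn.Module):
    """Hypothetical stand-in that mirrors only the attributes quoted above."""

    def __init__(self, num_classes, eos_coef):
        super().__init__()
        self.num_classes = num_classes
        self.eos_coef = eos_coef
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    def forward(self, pred_logits, target_classes):
        # pred_logits is [batch, queries, classes]; cross_entropy expects classes on dim 1.
        return F.cross_entropy(pred_logits.transpose(1, 2), target_classes, self.empty_weight)


def set_num_classes(criterion, num_classes):
    """Overwrite num_classes and rebuild empty_weight to match (illustrative helper)."""
    criterion.num_classes = num_classes
    empty_weight = torch.ones(num_classes + 1)
    empty_weight[-1] = criterion.eos_coef
    # Assigning a tensor to an existing buffer name keeps it registered as a buffer.
    criterion.empty_weight = empty_weight.to(criterion.empty_weight.device)


criterion = TinyCriterion(num_classes=91, eos_coef=0.1)  # weight has 92 entries
pred_logits = torch.randn(2, 100, 93)                     # head with 93 outputs
target_classes = torch.full((2, 100), 92, dtype=torch.int64)  # every query is "no object"

try:
    criterion(pred_logits, target_classes)
    raised_before_patch = False
except RuntimeError:
    raised_before_patch = True  # 92 weights vs. 93 logits

set_num_classes(criterion, 92)  # weight now has 93 entries
loss = criterion(pred_logits, target_classes)
print(raised_before_patch, criterion.empty_weight.shape, loss.item())
```

The key point is that the weight vector must have one entry per logit produced by class_embed, so `num_classes + 1` has to equal the size of the head's last layer.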

@fmassa fmassa added the question Further information is requested label Jun 3, 2020
Author

ghost commented Jun 4, 2020

Hi,

Thank you! It worked. Now I am training the class_embed and bbox_embed layers for an additional class 'skyscraper'. The goal is to add additional classes to the 91 classes in the pre-trained model. Hopefully, it will work.

Thanks!

Contributor

fmassa commented Jun 4, 2020

Great, keep us informed on how it works.

I believe I have answered your question, so I'm closing the issue, but please let us know if you have further questions or issues.

@fmassa fmassa closed this as completed Jun 4, 2020