
[WIP] Add ResNet-RS models #554

Merged · 22 commits · May 4, 2021

Conversation

@amaarora (Contributor)

No description provided.

@amaarora changed the title from "Add ResNet-RS models" to "[WIP] Add ResNet-RS models" on Apr 10, 2021

amaarora commented Apr 10, 2021

Hey @rwightman !

I am almost there with contributing Resnet-RS models to timm (my first model contribution)! Starting this [WIP] PR to get things wrapped up.

I am very confident that the only things remaining in this implementation currently are:

  1. Pretrained model weights which have been officially released here.
  2. Switching the stride sizes for the first two convolutions in the residual path of the downsampling block (mentioned in Section 4.1 of the Resnet-RS paper).

Everything else, as far as I can tell, has already been covered. As per section 4.1 of the paper:

First, the 7×7 convolution in the stem is replaced by three smaller 3×3 convolutions, as first proposed in InceptionV3 (Szegedy et al., 2016).

This has been done by adding stem_type='deep' to model_args.

Second, the stride sizes are switched for the first two convolutions in the residual path of the downsampling blocks.

This hasn't been covered yet.

Third, the stride-2 1×1 convolution in the skip connection path of the downsampling blocks is replaced by stride-2 2×2 average pooling and then a non-strided 1×1 convolution.

Covered off by setting avg_down=True, similar to the existing resnet50d architecture.

Fourth, the stride2 3×3 max pool layer is removed and the downsampling occurs in the first 3×3 convolution in the next bottleneck block.

Covered off by adding a new parameter skip_stem_max_pool to the model config and first_conv_stride to the make_blocks function.

For all experiments we use a Squeeze-and-Excitation ratio of 0.25 based on preliminary experiments.

Covered by adding block_args=dict(attn_layer='se'), similar to the senets.
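
To put the pieces together, here's a minimal sketch of how these settings combine in a timm-style model definition. This is illustrative rather than the exact final code: skip_stem_max_pool is the new flag described above, and Bottleneck / _create_resnet are the existing helpers in timm's resnet.py.

def resnetrs50(pretrained=False, **kwargs):
    # Illustrative ResNet-RS 50 config: deep 3x3 stem, avg-pool downsampling
    # in the skip path, no stem max pool, SE attention in every block.
    model_args = dict(
        block=Bottleneck, layers=[3, 4, 6, 3],
        stem_width=32, stem_type='deep',    # three 3x3 convs replace the 7x7 stem
        avg_down=True,                      # avg pool + non-strided 1x1 conv in skip path
        skip_stem_max_pool=True,            # no stride-2 max pool after the stem
        block_args=dict(attn_layer='se'),   # Squeeze-and-Excitation in every block
        **kwargs)
    return _create_resnet('resnetrs50', pretrained, **model_args)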


Since I haven't ported weights from TF to PyTorch before, I was hoping for some guidance on how to do this.

Also FYI:
I have kicked off a run on Imagenette, and will later kick off an ImageNet run for resnetrs50 to begin with, to compare performance with resnet-50.

Pretty happy with the implementation thus far, but keen to hear your feedback too.

@amaarora (Contributor, Author)

Ok @rwightman, happy days! The implementation performs better on Imagenette. The graph below compares resnetrs50 with resnet50: Resnet-RS 50 gets 85.9% top-1 compared to 84.5% for Resnet-50. I have kicked off a training run for ImageNet, but given I only have a single V100, it should take at least 3-4 days before we see results.

[image: Imagenette top-1 accuracy curves, resnetrs50 vs resnet50]

@JulienMaille

Do you have any plans for smaller resnet (18, 34)? Does it make sense?

@amaarora (Contributor, Author)

Do you have any plans for smaller resnet (18, 34)? Does it make sense?

Hi @JulienMaille :)

FYI, the smallest ResNet-RS model is ResNet-RS 50.

@rwightman (Collaborator)

@amaarora I thought that was ImageNet 85.9 for a sec, then realized it was Imagenette ;)

One thought re the stem config, do the RS models push the stride 2 into the blocks or replace the maxpool with a conv? I thought they replaced? Or do they do both depending on the model size?

@JulienMaille don't think it makes much sense to define a basic-block RS model. I generally use a ResNet26 def as the smallest bottleneck model, (2,2,2,2) like the 34 but with bottleneck blocks, although I've seen 14s (1,1,1,1).

An FYI: there are already resnet models better than these RS ones here. I trained an ecaresnet26t that's almost 80% top-1, an ecaresnet50t that's over 82, and a resnet50d that's close to 80.5. The RS models use SE, so the 50d score is impressive by comparison, and the ECA models are roughly similar to an SE in throughput but with a lower param count. I have some comparable SE models as well, but only larger ones like the seresnet152d (84.35 vs 83 for the best 152 RS) were trained with recent hparams.

@rwightman (Collaborator)

Basically I've already been exploring the ideas in the RS paper here: heavier augs and regularization w/ different resolution scaling. I went a bit further with aug/regularization.

@rwightman (Collaborator)

@amaarora also, the stem changes will likely break feature extraction unless the feature_info locations are adjusted based on the stem config.

@amaarora (Contributor, Author)

One thought re the stem config, do the RS models push the stride 2 into the blocks or replace the maxpool with a conv? I thought they replaced?

Good catch! This has been a little confusing, but going by the implementation, you're right: they replace the maxpool with a stride-2 3x3 conv.

This can be inferred from the Resnet-RS config here, where replace_stem_max_pool is set to True. Based on this config, in the code implementation here, the MaxPool is replaced by a stride-2 3x3 conv.

FYI, this conflicts with the paper, which writes in section 4.1 under the Resnet-D architecture subheading:

Fourth, the stride2 3×3 max pool layer is removed and the downsampling occurs in the first 3×3 convolution in the next bottleneck block.

Note the usage of the word "removed" and not "replaced".

I guess in such cases we follow the TF implementation and take it to be the source of truth? @rwightman
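
For clarity, a small sketch of what following the TF implementation means for the stem (channel counts and module names are illustrative, not the exact timm code):

import torch.nn as nn

# Classic ResNet stem pooling:
#   stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# ResNet-RS (per the TF config's replace_stem_max_pool=True) swaps it for a
# stride-2 3x3 conv followed by norm and activation:
stem_pool = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)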

@amaarora (Contributor, Author)

FYI I am working on fixing the failed tests.


amaarora commented Apr 12, 2021

@rwightman Finally passing! :)

How do I add pretrained weights, please? If I can add pretrained weights to these models now, then they are ready to be merged IMHO.

EDIT: Please ignore the above. It's still not working and I'm not sure why.

@amaarora (Contributor, Author)

I keep getting error code 137. It appears that's due to an OOM error rather than the code change.

Source from here https://stackoverflow.com/questions/43268156/process-finished-with-exit-code-137-in-pycharm

Still trying to figure out why; I can't replicate it unless I run all tests for all models.

@amaarora reopened this on May 2, 2021

amaarora commented May 3, 2021

@rwightman - the TensorFlow checkpoint has momentum terms alongside the convolution weights. Do you know what to do with these, please? I have been able to match the TF kernel weights to the PyTorch convs, but I'm not sure what to do with the EMA and momentum variables. See the example below:

...
conv2d_58/kernel
conv2d_58/kernel/ExponentialMovingAverage
conv2d_58/kernel/Momentum
conv2d_59/kernel
conv2d_59/kernel/ExponentialMovingAverage
conv2d_59/kernel/Momentum
...
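
For anyone following along, a sketch of how such a checkpoint can be inspected and filtered (path illustrative). This assumes the Momentum entries are optimizer slots and the ExponentialMovingAverage entries are the averaged copies of the weights, which is the standard TF1 training setup; the EMA copies are typically the ones preferred for evaluation.

import tensorflow as tf

reader = tf.train.load_checkpoint('resnet-rs-50/model.ckpt')
weights = {}
for name in reader.get_variable_to_shape_map():
    if name.endswith('/Momentum'):
        continue  # optimizer state, not needed for inference
    if name.endswith('/ExponentialMovingAverage'):
        # prefer the EMA copy, stored under the base variable name
        weights[name.rsplit('/', 1)[0]] = reader.get_tensor(name)
    elif name not in weights:
        weights[name] = reader.get_tensor(name)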

@rwightman (Collaborator)

rwightman commented May 3, 2021 via email


amaarora commented May 3, 2021

Thanks @rwightman!! I am very close to getting the model weights in now.

Just one last question: I am left with num_batches_tracked-type state dict entries. I don't know what the equivalent is in TF.

From https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py, I have mapped:

 # BatchNorm parameter mapping: PyTorch state dict key -> TF variable name
 mapping_bn[module_name+'.weight'] = tf_layer_name + '/gamma'
 mapping_bn[module_name+'.bias'] = tf_layer_name + '/beta'
 mapping_bn[module_name+'.running_mean'] = tf_layer_name + '/moving_mean'
 mapping_bn[module_name+'.running_var'] = tf_layer_name + '/moving_variance'

Not sure what to do with these:

....
'layer3.5.bn3.num_batches_tracked',
'layer4.0.bn1.num_batches_tracked',
'layer4.0.bn2.num_batches_tracked',
'layer4.0.bn3.num_batches_tracked',
....


amaarora commented May 3, 2021

Okay, done! I have loaded the TensorFlow weights into PyTorch; now I am trying to figure out how to benchmark the model and verify that the weights actually get around 80% top-1 accuracy. :)


amaarora commented May 3, 2021

@rwightman 79.114% top-1 accuracy on ImageNet-1k for Resnetrs-50! Hell yeah! :D

Just to confirm - here's what I did:
I used the timm training script train.py, but didn't call train_one_epoch, only validate. The summary.csv says 79.114% top-1! I think that's right?

Now let me get the weights for all the resnet-rs models; it's pretty easy from here on. How do we upload the pretrained weights?

It would be really helpful if you could also benchmark, or get another set of eyes on this for QA.


amaarora commented May 3, 2021


I just ignored these num_batches_tracked weights.
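
A sketch of how that works out in practice (converted.pt standing in for the output of the conversion script; path illustrative): loading with strict=False leaves those buffers at their defaults and reports them as the only missing keys, which makes a handy sanity check.

import timm
import torch

# converted.pt: output of the TF -> PyTorch conversion (path illustrative)
converted_state_dict = torch.load('converted.pt')
model = timm.create_model('resnetrs50')
result = model.load_state_dict(converted_state_dict, strict=False)
assert not result.unexpected_keys
assert all(k.endswith('num_batches_tracked') for k in result.missing_keys)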


amaarora commented May 3, 2021

All the model weights are here: https://www.kaggle.com/aroraaman/resnetrs

You should now be able to do:

import timm
import torch

m = timm.create_model('resnetrs50')
m.load_state_dict(torch.load('<path to weights>'))


rwightman commented May 3, 2021

@amaarora congrats! Yeah, num_batches_tracked is a PyTorch-specific thing that can be ignored for this...

you know there is a validate.py, right? :) You can just call it and try different image sizes, crop pcts, and image interpolations to find what works best... One thing that's not clear from the paper and the table in the official repo is whether each validation result is at the train size (i.e. 160x160 for the 50), or if they're doing the train-test resolution thing and testing at a higher res, i.e. 224 when trained at 160.
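
For example, something along these lines (paths and checkpoint name illustrative; flag names per timm's validate.py):

python validate.py /path/to/imagenet/val --model resnetrs50 --checkpoint resnetrs50.pt --img-size 224 --crop-pct 0.94 --interpolation bicubic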

@rwightman (Collaborator)

@amaarora FYI, the kaggle weight link doesn't work


amaarora commented May 3, 2021

@amaarora FYI, the kaggle weight link doesn't work

Sorry, that's because the dataset was private. Could you please try again? :)


amaarora commented May 3, 2021

you know there is a validate.py, right?

It must be the excitement of getting the resnetrs weights to work; I just completely missed it. :D

one thing that's not clear from the paper and the table in the official repo is whether each validation result is at the train size (i.e. 160x160 for the 50), or if they're doing the train-test resolution thing and testing at a higher res, i.e. 224 when trained at 160.

I agree. I ran benchmarks yesterday, testing at the default 224x224 for all models, and found validation accuracy to be within 1% of the reported results here.

@rwightman (Collaborator)

@amaarora for the models that have checkpoints for more than one size, did you include only the largest one?


amaarora commented May 3, 2021

@rwightman Sorry, I should have clarified: I went for the smallest one, but I can totally update to the largest ones? Or both? It should only take a couple of minutes from here to update the weights.


amaarora commented May 3, 2021

FYI here is the messy script that I wrote to port the weights from TF.


rwightman commented May 3, 2021

@amaarora k, good, that makes sense (with regard to what I observed). I picked a few to do a quick eval; the 152 wasn't checking out but the others were. I think the largest for each model makes sense; I don't see much value in having them all, and definitely prefer having the best possible one for each arch.

@@ -318,7 +334,7 @@ class Bottleneck(nn.Module):

     def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
                  reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+                 attn_layer=None, aa_layer=None, drop_block=None, drop_path=None, **kwargs):
@amaarora (Contributor, Author)

@rwightman FYI, this is a change that you might want to review specifically: I have added **kwargs to the Bottleneck block that get passed to the attention layers.

This is to pass in reduction_ratio=0.25 for the SE layers, as mentioned in the paper.
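
To illustrate the intent, a minimal standalone sketch of this kind of kwargs pass-through (the SELayer and create_attn here are simplified stand-ins for illustration, not timm's actual implementations):

import torch
import torch.nn as nn

class SELayer(nn.Module):
    # Simplified squeeze-and-excitation block for illustration.
    def __init__(self, channels, reduction_ratio=0.0625):
        super().__init__()
        rd_channels = max(1, int(channels * reduction_ratio))
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, rd_channels, 1), nn.ReLU(inplace=True),
            nn.Conv2d(rd_channels, channels, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.fc(x)

def create_attn(attn_layer, channels, **kwargs):
    # **kwargs forwarded from the block lets block_args tune the layer,
    # e.g. reduction_ratio=0.25 per the ResNet-RS paper.
    return SELayer(channels, **kwargs) if attn_layer == 'se' else None

se = create_attn('se', 512, reduction_ratio=0.25)
out = se(torch.randn(1, 512, 7, 7))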

@@ -341,7 +357,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, b
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)

-        self.se = create_attn(attn_layer, outplanes)
+        self.se = create_attn(attn_layer, outplanes, **kwargs)
@amaarora (Contributor, Author)

These **kwargs get passed to Bottleneck via block_args in the model config, as in the L1112 model definition for resnetrs152.


amaarora commented May 3, 2021

@amaarora k, good, that makes sense (with regard to what I observed). I picked a few to do a quick eval; the 152 wasn't checking out but the others were. I think the largest for each model makes sense; I don't see much value in having them all, and definitely prefer having the best possible one for each arch.

Okay, thanks! I'll go back and port the weights for the largest ones where there are two sets of weights available, run the benchmark scripts, and share the better-performing weights in my Kaggle dataset. Does this sound good? It won't be long.


amaarora commented May 3, 2021

@rwightman Here are my benchmarking results.

ResNet-RS 101

For resnetrs101_i192.pt I get 81.25 top-1, as opposed to 80.524 for resnetrs101.pt (which is i160); both are evaluated at validation image size 224x224.
Let's use resnetrs101_i192.pt for ResNet-RS 101.

ResNet-RS 152

For the newly added resnetrs152_i224.pt, I get 81.68 ImageNet top-1 when predicting at image size 224x224 (the paper reports 82%).
For the same weights, I get 82.474 top-1 when predicting at image size 256x256.
For resnetrs152_i256.pt, I get 81.732 top-1 when predicting at image size 256x256.
Let's use the resnetrs152_i224.pt weights for ResNet-RS 152.

ResNet-RS 350

For resnetrs350_i320.pt I get 83.224 top-1 at validation image size 320x320, whereas for resnetrs350.pt (trained at i256) I get 83.426 top-1 at validation image size 256x256.

I suggest we use resnetrs350.pt for ResNet-RS 350.

@rwightman merged commit 560eae3 into huggingface:master on May 4, 2021
@rwightman (Collaborator)

@amaarora merged and made a few additional changes... curious if you compared the EMA vs non-EMA weights on any of the models?

The 420 weights are pretty weak; actually the 420, the i320 350, and the 270 weights are all kind of meh, while the i256 350 weights are great by comparison (they validate up to ~84.4 with some res scaling). I can't get the 420 past 84.2/84.3. So I'd be curious to check whether the 420 weights are any better for the non-EMA (assuming the one I have is the EMA).


amaarora commented May 4, 2021

Hey @rwightman - thanks for your help!

curious if you compared the EMA vs non-EMA weights on any of the models?

Sorry I didn't.

can't get the 420 past 84.2/84.3.

What you have are the non-EMA weights for all models. Please allow ~1 hour for me to go back and share the EMA weights as a separate Kaggle dataset with you.

@rwightman (Collaborator)

@amaarora oh, okay, that could be a fairly significant difference, curious to see how it stacks up


amaarora commented May 4, 2021

@rwightman EMA weights uploaded here https://www.kaggle.com/aroraaman/resnetrsema


amaarora commented May 4, 2021

@rwightman
So the ResNet-RS 270 EMA weights get me 83.58% top-1 (the paper reports 83.8%) at img size 256x256.
The ResNet-RS 420 EMA weights get me 84.238% top-1 (the paper reports 84.4%) at img size 320x320.
And the ResNet-RS 350 i320 EMA weights get me 84.312% top-1.
