Skip to content
Merged

Resnet #15770

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
f2593f3
first commit
FrancescoSaverioZuppichini Feb 22, 2022
cc1d1c2
ResNet model correctly implemented.
FrancescoSaverioZuppichini Feb 22, 2022
1a6d045
minor changes
FrancescoSaverioZuppichini Feb 23, 2022
b0089cf
Apply suggestions from code review
FrancescoSaverioZuppichini Feb 24, 2022
c55e5fc
new test format
FrancescoSaverioZuppichini Feb 24, 2022
9a9e3c7
minor changes from conversations
FrancescoSaverioZuppichini Feb 24, 2022
b54dc28
minor changes from conversations
FrancescoSaverioZuppichini Feb 24, 2022
d95c607
make style + quality
FrancescoSaverioZuppichini Feb 24, 2022
c30f3e9
readded the tests
FrancescoSaverioZuppichini Feb 24, 2022
c5489df
test + README
FrancescoSaverioZuppichini Feb 24, 2022
ca2fbbb
minor changes from conversations
FrancescoSaverioZuppichini Feb 24, 2022
9a17143
error in README
FrancescoSaverioZuppichini Feb 24, 2022
bc26ea3
make fix-copies
FrancescoSaverioZuppichini Feb 24, 2022
a60983a
removed regression for classification head
FrancescoSaverioZuppichini Feb 25, 2022
5c05428
make quality
FrancescoSaverioZuppichini Feb 25, 2022
4accbd7
fixed loss control flow
FrancescoSaverioZuppichini Feb 25, 2022
7eedf7f
fixed loss control flow
FrancescoSaverioZuppichini Feb 25, 2022
db8c38c
resolved conversations
FrancescoSaverioZuppichini Feb 25, 2022
3e2b79c
Apply suggestions from code review
FrancescoSaverioZuppichini Feb 25, 2022
1c6885d
READMEs
FrancescoSaverioZuppichini Feb 25, 2022
79b2ba0
index.mdx
FrancescoSaverioZuppichini Feb 25, 2022
eae2cef
minor changes
FrancescoSaverioZuppichini Feb 25, 2022
717cfd5
updated tests and models
FrancescoSaverioZuppichini Mar 1, 2022
69240c4
unused import
FrancescoSaverioZuppichini Mar 1, 2022
45db40c
outputs
FrancescoSaverioZuppichini Mar 1, 2022
a0597a5
Update docs/source/model_doc/resnet.mdx
FrancescoSaverioZuppichini Mar 1, 2022
d37ea95
added embeddings_size
FrancescoSaverioZuppichini Mar 3, 2022
2c4034f
Apply suggestions from code review
FrancescoSaverioZuppichini Mar 3, 2022
d793bb7
conversation
FrancescoSaverioZuppichini Mar 3, 2022
e18d9da
added push to hub
FrancescoSaverioZuppichini Mar 3, 2022
398642f
test
FrancescoSaverioZuppichini Mar 3, 2022
b4b8c83
embedding_size
FrancescoSaverioZuppichini Mar 3, 2022
5c89040
make fix-copies
FrancescoSaverioZuppichini Mar 3, 2022
ee8783b
resolved conversations
FrancescoSaverioZuppichini Mar 3, 2022
7a6febf
CI
FrancescoSaverioZuppichini Mar 7, 2022
974337c
changed organization
FrancescoSaverioZuppichini Mar 8, 2022
8a8ca31
minor changes
FrancescoSaverioZuppichini Mar 8, 2022
8517727
CI
FrancescoSaverioZuppichini Mar 8, 2022
367e886
minor changes
FrancescoSaverioZuppichini Mar 8, 2022
73d3a66
conversations
FrancescoSaverioZuppichini Mar 10, 2022
a5d9f29
conversation
FrancescoSaverioZuppichini Mar 10, 2022
7e08499
doc
FrancescoSaverioZuppichini Mar 10, 2022
d8553c3
tests
FrancescoSaverioZuppichini Mar 10, 2022
1a4521c
removed unused docstring
FrancescoSaverioZuppichini Mar 10, 2022
3a3ae6d
conversation
FrancescoSaverioZuppichini Mar 11, 2022
8708d65
removed unused outputs
FrancescoSaverioZuppichini Mar 11, 2022
2b4c8bc
CI
FrancescoSaverioZuppichini Mar 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
Expand Down
1 change: 1 addition & 0 deletions README_ko.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
Expand Down
1 change: 1 addition & 0 deletions README_zh-hans.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ conda install -c huggingface transformers
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (来自 Google Research) 伴随论文 [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) 由 Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang 发布。
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (来自 Google Research) 伴随论文 [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) 由 Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya 发布。
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (来自 Google Research) 伴随论文 [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) 由 Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder 发布。
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (来自 Facebook), 伴随论文 [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) 由 Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov 发布。
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (来自 ZhuiyiTechnology), 伴随论文 [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) 由 Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu 发布。
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (来自 NVIDIA) 伴随论文 [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) 由 Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo 发布。
Expand Down
1 change: 1 addition & 0 deletions README_zh-hant.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ conda install -c huggingface transformers
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](https://huggingface.co/docs/transformers/model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](https://huggingface.co/docs/transformers/model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](https://huggingface.co/docs/transformers/master/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper a [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@
title: Reformer
- local: model_doc/rembert
title: RemBERT
- local: model_doc/resnet
title: ResNet
- local: model_doc/retribert
title: RetriBERT
- local: model_doc/roberta
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ conversion utilities for the following models.
1. **[REALM](https://huggingface.co/transformers/model_doc/realm.html)** (from Google Research) released with the paper [REALM: Retrieval-Augmented Language Model Pre-Training](https://arxiv.org/abs/2002.08909) by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.
1. **[Reformer](model_doc/reformer)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RemBERT](model_doc/rembert)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/abs/2010.12821) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[ResNet](model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
Expand Down Expand Up @@ -230,6 +231,7 @@ Flax), PyTorch, and/or TensorFlow.
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
Expand Down
50 changes: 50 additions & 0 deletions docs/source/model_doc/resnet.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# ResNet

## Overview

The ResNet model was proposed in [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. Our implementation follows the small changes made by [Nvidia](https://catalog.ngc.nvidia.com/orgs/nvidia/resources/resnet_50_v1_5_for_pytorch), we apply the `stride=2` for downsampling in bottleneck's `3x3` conv and not in the first `1x1`. This is generally known as "ResNet v1.5".

ResNet introduced residual connections, they allow to train networks with an unseen number of layers (up to 1000). ResNet won the 2015 ILSVRC & COCO competition, one important milestone in deep computer vision.

The abstract from the paper is the following:

*Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize, and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of up to 152 layers---8x deeper than VGG nets but still having lower complexity. An ensemble of these residual nets achieves 3.57% error on the ImageNet test set. This result won the 1st place on the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers.
The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are foundations of our submissions to ILSVRC & COCO 2015 competitions, where we also won the 1st places on the tasks of ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.*

Tips:

- One can use [`AutoFeatureExtractor`] to prepare images for the model.

The figure below illustrates the architecture of ResNet. Taken from the [original paper](https://arxiv.org/abs/1512.03385).

<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/resnet_architecture.png"/>

This model was contributed by [Francesco](https://huggingface.co/Francesco). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).

## ResNetConfig

[[autodoc]] ResNetConfig


## ResNetModel

[[autodoc]] ResNetModel
- forward


## ResNetForImageClassification

[[autodoc]] ResNetForImageClassification
- forward
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@
"models.realm": ["REALM_PRETRAINED_CONFIG_ARCHIVE_MAP", "RealmConfig", "RealmTokenizer"],
"models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"],
"models.rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"],
"models.resnet": ["RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ResNetConfig"],
"models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
"models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"],
Expand Down Expand Up @@ -1323,6 +1324,14 @@
"load_tf_weights_in_rembert",
]
)
_import_structure["models.resnet"].extend(
[
"RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"ResNetForImageClassification",
"ResNetModel",
"ResNetPreTrainedModel",
]
)
_import_structure["models.retribert"].extend(
["RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "RetriBertModel", "RetriBertPreTrainedModel"]
)
Expand Down Expand Up @@ -2573,6 +2582,7 @@
from .models.realm import REALM_PRETRAINED_CONFIG_ARCHIVE_MAP, RealmConfig, RealmTokenizer
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .models.rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from .models.resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer
Expand Down Expand Up @@ -3451,6 +3461,12 @@
RemBertPreTrainedModel,
load_tf_weights_in_rembert,
)
from .models.resnet import (
RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ResNetForImageClassification,
ResNetModel,
ResNetPreTrainedModel,
)
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/modeling_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,31 @@ class SemanticSegmentationModelOutput(ModelOutput):
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ImageClassifierOutput(ModelOutput):
"""
Base class for outputs of image classification models.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.

Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
realm,
reformer,
rembert,
resnet,
retribert,
roberta,
roformer,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
("maskformer", "MaskFormerConfig"),
("poolformer", "PoolFormerConfig"),
("convnext", "ConvNextConfig"),
("resnet", "ResNetConfig"),
("yoso", "YosoConfig"),
("swin", "SwinConfig"),
("vilt", "ViltConfig"),
Expand Down Expand Up @@ -133,6 +134,7 @@
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("resnet", "RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("swin", "SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("vilt", "VILT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
Expand Down Expand Up @@ -220,6 +222,7 @@
("maskformer", "MaskFormer"),
("poolformer", "PoolFormer"),
("convnext", "ConvNext"),
("resnet", "ResNet"),
("yoso", "YOSO"),
("swin", "Swin"),
("vilt", "ViLT"),
Expand Down
Loading