Skip to content

Commit

Permalink
Add expected output to the sample code for `ViTMSNForImageClassificat…
Browse files Browse the repository at this point in the history
…ion` (#19183)

* chore: add expected output to the sample code.

* add: imagenet-1k labels to the model config.

* chore: apply code formatting.

* chore: change the expected output.
  • Loading branch information
sayakpaul committed Sep 30, 2022
1 parent 368b649 commit 582d085
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/transformers/models/vit_msn/convert_msn_to_pytorch.py
Expand Up @@ -15,11 +15,13 @@
"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""

import argparse
import json

import torch
from PIL import Image

import requests
from huggingface_hub import hf_hub_download
from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

Expand Down Expand Up @@ -147,6 +149,13 @@ def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config = ViTMSNConfig()
config.num_labels = 1000

repo_id = "datasets/huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}

if "s16" in checkpoint_url:
config.hidden_size = 384
config.intermediate_size = 1536
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/vit_msn/modeling_vit_msn.py
Expand Up @@ -632,6 +632,8 @@ def forward(
>>> from PIL import Image
>>> import requests
>>> torch.manual_seed(2)
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
Expand All @@ -644,6 +646,7 @@ def forward(
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
Kerry blue terrier
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

Expand Down

0 comments on commit 582d085

Please sign in to comment.