From 582d085bb2c54e20907bfdfae24d0e9e37070ca6 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 30 Sep 2022 18:55:41 +0530
Subject: [PATCH] Add expected output to the sample code for
 `ViTMSNForImageClassification` (#19183)

* chore: add expected output to the sample code.

* add: imagenet-1k labels to the model config.

* chore: apply code formatting.

* chore: change the expected output.
---
 .../models/vit_msn/convert_msn_to_pytorch.py         | 9 +++++++++
 src/transformers/models/vit_msn/modeling_vit_msn.py  | 3 +++
 2 files changed, 12 insertions(+)

diff --git a/src/transformers/models/vit_msn/convert_msn_to_pytorch.py b/src/transformers/models/vit_msn/convert_msn_to_pytorch.py
index 535f5f742d631..f04d26d5eb886 100644
--- a/src/transformers/models/vit_msn/convert_msn_to_pytorch.py
+++ b/src/transformers/models/vit_msn/convert_msn_to_pytorch.py
@@ -15,11 +15,13 @@
 """Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
 
 import argparse
+import json
 
 import torch
 from PIL import Image
 
 import requests
+from huggingface_hub import hf_hub_download
 from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel
 from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
@@ -147,6 +149,13 @@ def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
     config = ViTMSNConfig()
     config.num_labels = 1000
 
+    repo_id = "datasets/huggingface/label-files"
+    filename = "imagenet-1k-id2label.json"
+    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = {int(k): v for k, v in id2label.items()}
+    config.id2label = id2label
+    config.label2id = {v: k for k, v in id2label.items()}
+
     if "s16" in checkpoint_url:
         config.hidden_size = 384
         config.intermediate_size = 1536
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index a190c42caa707..f40d5278c06be 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -632,6 +632,8 @@ def forward(
         >>> from PIL import Image
         >>> import requests
 
+        >>> torch.manual_seed(2)
+
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
 
@@ -644,6 +646,7 @@ def forward(
         >>> # model predicts one of the 1000 ImageNet classes
         >>> predicted_label = logits.argmax(-1).item()
         >>> print(model.config.id2label[predicted_label])
+        Kerry blue terrier
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
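
For reference, a minimal standalone sketch of the doctest this patch updates, runnable as a plain script. The `facebook/vit-msn-small` checkpoint name is an assumption (the `from_pretrained` lines sit outside the hunks shown above), and the seed mirrors the `torch.manual_seed(2)` call the patch adds, presumably so that the randomly initialized classification head, and hence the printed label, stays reproducible.

# Standalone sketch of the patched doctest (checkpoint name assumed, see note above).
import requests
import torch
from PIL import Image

from transformers import ViTFeatureExtractor, ViTMSNForImageClassification

torch.manual_seed(2)  # seed before from_pretrained so the randomly initialized head is reproducible

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/vit-msn-small")  # assumed checkpoint
model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small")

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])  # the patched doctest expects "Kerry blue terrier"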