In [1]:
import torch
import transformers

from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
img = Image.open("D:/my_gestures/val/01_one/0.jpg")

In [3]:
model_dir = "D:/saved_models/vit-base-patch16-224-in21k-finetuned-gestures/"

feature_extractor = transformers.ViTImageProcessor.from_pretrained(model_dir)
model = transformers.ViTForImageClassification.from_pretrained(model_dir)

In [4]:
encoding = feature_extractor(img.convert('RGB'), return_tensors='pt')
print(encoding.pixel_values.shape)

torch.Size([1, 3, 224, 224])


PyTorch API

In [12]:
import time

device = torch.device('cuda')

model.eval()
with torch.no_grad():
    start_time = time.time()
    model.to(device)
    encoding = encoding.to(device)
    outputs = model(encoding.pixel_values)
    elapsed = time.time() - start_time
    logits = outputs.logits
    sm = torch.nn.Softmax(dim=-1)(logits)

print(f"Elapsed time: {elapsed * 1000:.4f} ms")

predicted_class_idx = sm.argmax(-1).item()
confidence = sm[-1][predicted_class_idx]

print(f"Logits: {logits}")
print(f"SoftMax: {sm}")
print(f"Confidence: {confidence:.4f}")
print(f"Predicted class ID: {predicted_class_idx}")
print(f"Predicted class: {model.config.id2label[predicted_class_idx]}")

Elapsed time: 12.9848 ms
Logits: tensor([[-1.1792,  6.7489, -0.3102, -1.8317, -1.7011, -1.2327]],
       device='cuda:0')
SoftMax: tensor([[3.5975e-04, 9.9804e-01, 8.5785e-04, 1.8734e-04, 2.1348e-04, 3.4101e-04]],
       device='cuda:0')
Confidence: 0.9980
Predicted class ID: 1
Predicted class: 01_one


Pipeline API

In [13]:
pipe = transformers.pipeline('image-classification', model_dir, device=0)

start_time = time.time()
outputs = pipe(img)
elapsed = time.time() - start_time
print(f"Elapsed time: {elapsed * 1000:.4f} ms")

for o in outputs:
    print(o)



Elapsed time: 49.9973 ms
{'score': 0.9980406165122986, 'label': '01_one'}
{'score': 0.0008578469860367477, 'label': '02_two'}
{'score': 0.0003597516333684325, 'label': '00_fist'}
{'score': 0.000341010803822428, 'label': '05_five'}
{'score': 0.0002134784881491214, 'label': '04_four'}
