In [4]:
import torch
import torchvision.transforms as transforms
from PIL import Image
from vit_pytorch import ViT
import numpy as np
import matplotlib.pyplot as plt

# Load the ViT model
vit_model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)
checkpoint = torch.load("pretrained_vit_model_skin_or_eye_detection.pth")
vit_model.load_state_dict(checkpoint)
vit_model.eval()

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Function to predict class and confidence level
def predict(image_path, model):
    img = Image.open(image_path)
    img = transform(img).unsqueeze(0)
    with torch.no_grad():
        output = model(img)
    probs = torch.nn.functional.softmax(output, dim=-1)[0] * 100
    confidence, predicted_class = torch.max(probs, dim=-1)
    return predicted_class.item(), confidence.item()

# Function to display image, class, and confidence
def display_prediction(image_path, model):
    predicted_class, confidence = predict(image_path, model)
    classes = np.loadtxt("imagenet_classes.txt", str, delimiter='\t')
    img = Image.open(image_path)
    plt.imshow(img)
    plt.title(f'Prediction: {classes[predicted_class]} - Confidence: {confidence:.2f}%')
    plt.axis('off')
    plt.show()

# Call the function with the image path
image_path = "skin.jpg"
display_prediction(image_path, vit_model)


TypeError: Expected state_dict to be dict-like, got <class 'torchvision.models.vision_transformer.VisionTransformer'>.