In [5]:
import gradio as gr
import torch 
import torchvision
import torchvision.transforms as transforms
import requests
from torchvision.models import resnet18, ResNet18_Weights


# Per-channel mean/std used to standardise inputs for the ImageNet-pretrained
# weights loaded below; applied after pixel values have been scaled to [0, 1].
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Pretrained ResNet-18. eval() puts layers such as batch-norm into inference
# behaviour so predictions do not depend on batch statistics.
resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
resnet.eval()

# Download human-readable labels for ImageNet.
# A timeout keeps the cell from hanging indefinitely on a stalled connection,
# and raise_for_status() fails loudly instead of turning an HTTP error page
# into bogus class labels.
response = requests.get("https://git.io/JJkYN", timeout=10)
response.raise_for_status()
# splitlines() copes with both LF and CRLF line endings and does not leave a
# trailing empty label behind when the file ends with a newline.
labels = response.text.splitlines()

def classify(img):
    """Classify one image with the module-level pretrained ResNet.

    Args:
        img: Image as a channel-last numpy array (height, width, channels),
            which is what the Gradio image input hands over by default.
            Assumed to be 3-channel uint8 RGB -- TODO confirm if the input
            component is ever configured differently. May be ``None`` when
            the user submits without providing an image.

    Returns:
        A dict mapping each class label to its softmax probability as a
        Python float, or ``None`` when ``img`` is ``None``.

    Raises:
        IndexError: if ``labels`` has fewer entries than the model has
            output classes.
    """
    # Nothing to classify; avoid a TypeError inside torch.from_numpy.
    if img is None:
        return None

    # Numpy image is channel last. PyTorch is channel first.
    pixels = torch.from_numpy(img).permute(2, 0, 1)

    # Pre-processing: resize the shorter side to 256, take the central
    # 224x224 crop, scale to [0, 1], then standardise per channel.
    pixels = torchvision.transforms.Resize(256, antialias=True)(pixels)
    pixels = torchvision.transforms.CenterCrop(224)(pixels).float() / 255.
    # unsqueeze(0) adds the batch dimension of size 1 the model expects.
    batch = normalize(pixels).unsqueeze(0)

    # Inference only, so no gradients are needed.
    with torch.no_grad():
        logits = resnet(batch)
        # Logits -> probabilities over the class dimension, then drop the
        # batch dimension again.
        probs = torch.nn.functional.softmax(logits, dim=1).squeeze(0)

    # torch -> plain Python floats.
    scores = probs.cpu().tolist()

    # Take the class count from the model output instead of hard-coding it,
    # but still fail (with a readable message) if the label list is short.
    if len(labels) < len(scores):
        raise IndexError(
            f"expected at least {len(scores)} labels, got {len(labels)}"
        )

    # zip stops at the shorter sequence, so surplus label entries are ignored.
    return dict(zip(labels, scores))
    

# Build the demo UI around classify().
# No `shape` is given to gr.Image: classify() already performs
# Resize(256) + CenterCrop(224) itself, so pre-cropping the upload to
# 224x224 would make that step upscale and then cut off the image border.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Image(),
    # Only the five most probable classes are displayed.
    outputs=gr.Label(num_top_classes=5),
    title="1k Object Recognition",
    examples=['assets/wonder_cat.jpg', 'assets/aki_dog.jpg','assets/birdie1.jpg'],
    description="Demonstrates a pre-trained model from torchvision for image classification.",
    allow_flagging="never",
)

# Start the local server without opening a browser tab.
demo.launch(inbrowser=False)

AttributeError: partially initialized module 'gradio' has no attribute 'inputs' (most likely due to a circular import)