In [None]:

from pathlib import Path
import torch
from typing import Dict, Tuple
from timeit import default_timer as timer
import gradio as gr

In [None]:
import sys
sys.path.append('..')
from modular import data_setup, engine, models, utils, helper

In [None]:
# Instantiate the ViT model together with the preprocessing transforms returned with it.
# NOTE(review): the argument 5 presumably sets the number of output classes; it must
# stay equal to len(class_names), since predict() pairs each output index with a label.
model, transforms = models.create_vit_model(
    5
)

In [None]:
# Restore the trained weights into `model`. map_location='cpu' lets a checkpoint that
# was saved from GPU tensors be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# On torch >= 1.13, weights_only=True restricts loading to tensors and plain containers.
state_dict = torch.load(
    '../models/vietnamese_landmark_vit.pth',
    map_location='cpu',
)
model.load_state_dict(state_dict=state_dict)

In [None]:
# Human-readable labels, one per model output. predict() pairs class_names[i] with the
# probability at output index i, so this order presumably has to match the label order
# used during training. TODO confirm against the training data setup.
class_names = [
    'Ha Long Bay',
    'Ho Chi Minh Mausoleum',
    'Hoi An Town',
    'Hue Imperial city',
    'Sapa Rice Terrace',
]

In [None]:
# Function for gradio demo

def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and report how long the prediction took.

    Args:
        img: The image supplied by the Gradio ``gr.Image(type='pil')`` input,
            i.e. a PIL image, or ``None`` if the user submitted without
            uploading one.

    Returns:
        A ``(pred_labels_and_probs, pred_time)`` tuple:
        - ``pred_labels_and_probs`` maps every name in ``class_names`` to its
          softmax probability. This is the dict format ``gr.Label`` accepts.
        - ``pred_time`` is the elapsed wall-clock time in seconds, covering
          preprocessing and the forward pass, rounded to 5 decimals.

    Raises:
        gr.Error: If ``img`` is ``None``. Gradio shows the message to the user.
    """
    # Gradio calls fn with None when the image box is empty. Without this guard the
    # call would fail inside `transforms` with an error that is unhelpful to the user.
    if img is None:
        raise gr.Error("Please upload an image to classify.")

    start_time = timer()

    # Preprocess with the transforms created alongside the model, then add a
    # leading batch dimension of size 1.
    img_batch = transforms(img).unsqueeze(0)

    # eval() switches off training-only behaviour; inference_mode() skips autograd
    # bookkeeping for the forward pass.
    model.eval()
    with torch.inference_mode():
        # Softmax over dim=1 turns the logits of the single batch row into probabilities.
        pred_probs = torch.softmax(model(img_batch), dim=1)

    # One {class name: probability} entry per class, paired by output index.
    # .tolist() converts the row to plain Python floats in one call.
    probs = pred_probs[0].tolist()
    pred_labels_and_probs = {name: probs[i] for i, name in enumerate(class_names)}

    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time

In [None]:
# Create a list of examples for demo

data_path = Path('../data/val')
example_list = list(data_path.glob("*.jpg"))

In [None]:
# Build a gradio interface

title = 'Vietnamese Landmark'
description = "A ViT model to classify Vietnamese landmarks"

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=[
        # Show every class. The count comes from class_names so it cannot drift
        # from the label list when classes are added or removed.
        gr.Label(num_top_classes=len(class_names), label="Predictions"),
        gr.Number(label='Prediction time (s)'),
    ],
    examples=example_list,
    title=title,
    description=description,
)

# share=True asks Gradio for a temporary public link in addition to the local
# server. Anyone with the URL can use the demo while this kernel is running.
demo.launch(debug=False, share=True)