In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from PIL import Image
import gradio as gr

# Select the compute device: prefer CUDA, then Apple-silicon MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
# The previous torch.mps.is_available() is not part of the documented
# torch.mps API in many PyTorch releases, so on a machine without CUDA it
# could raise AttributeError instead of falling back to CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")

# Input pipeline for the ImageNet-pretrained ResNet-50 backbone. It takes a
# PIL image and returns a normalized float tensor of shape (3, 224, 224).
preprocess = transforms.Compose([
    # Uploaded images may be RGBA, grayscale ("L") or palette ("P") mode.
    # ToTensor would then yield 4 or 1 channels, which the 3-channel
    # Normalize below (and the backbone's first conv) cannot handle.
    # Converting is a no-op in content for images that are already RGB.
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
    # Per-channel ImageNet mean/std that the pretrained weights expect.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Backbone: torchvision ResNet-50 initialized with ImageNet weights.
resnet50_model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1,
)
# Replace the ImageNet classification layer with a pass-through so the
# forward pass yields the pooled feature vector. The head defined below
# consumes 2048 features.
resnet50_model.fc = nn.Identity()
resnet50_model = resnet50_model.to(device)  # Module.to returns the module itself

# Trained classifier head: 2048 backbone features -> 1024 hidden -> 1 logit.
# The layer order must stay fixed, because the checkpoint's state-dict keys
# are indexed by Sequential position (0, 1, 2).
fc_model = nn.Sequential(
    nn.Linear(in_features=2048, out_features=1024),
    nn.ReLU(),
    nn.Linear(in_features=1024, out_features=1),
).to(device)
# weights_only=True limits unpickling to tensors and plain containers.
# map_location places the loaded tensors on the chosen device.
fc_state_dict = torch.load("model.pth", weights_only=True, map_location=device)
fc_model.load_state_dict(fc_state_dict)

# End-to-end classifier: backbone features feed the trained head, producing
# one logit per image. Both parts already live on `device`, and .to() keeps
# the composite there.
model = nn.Sequential(resnet50_model, fc_model).to(device)
# Inference mode for layers such as BatchNorm: use running statistics
# instead of per-batch ones.
model.eval()

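# TODO: write a predict_image(image_pixels) function that:
#   1) receives the uploaded image from Gradio as a NumPy array,
#   2) converts it to a PIL image with Image.fromarray,
#   3) applies `preprocess`, adds a batch dimension with unsqueeze(0), and
#      moves the tensor to `device`,
#   4) runs `model` under torch.no_grad() and applies torch.sigmoid to its
#      single output logit, and
#   5) returns a string like "By {percentage}%, this is a good tire", where
#      percentage is the sigmoid probability multiplied by 100.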

# TODO: create and launch a Gradio Interface:
# demo = gr.Interface(
#     fn=predict_image,
#     inputs=[gr.Image()],
#     outputs=gr.Text()
# )
# demo.launch()