In [14]:
# Import dependencies
import torch
import torchvision
import torch.nn as nn
import torchvision.models
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights #pretrained model

# Standard ImageNet per-channel statistics (the values torchvision's pretrained models expect).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before inference:
#   1. resize to exactly 224x224 (aspect ratio is not preserved),
#   2. center-crop to 224 — effectively a no-op after the square resize, kept as-is,
#   3. convert to a float tensor of shape (C, H, W), scaled to [0, 1] for 8-bit images,
#   4. normalize each channel with the ImageNet mean/std above.
val_transforms = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Animal species the classifier can output. classify_image indexes this list with the
# predicted logit index, so the order must match the model's output units.
animal_classes = [
    'cat', 'cow', 'coyote', 'deer', 'dog', 'donkey', 'fox',
    'horse', 'owl', 'pig', 'possum', 'raccoon', 'sheep', 'wolf',
]

# Category codes; each code is an index into predator_classes below.
NONPREDATOR, PREDATOR, BOTH = 0, 1, 2

# Species -> category code (keys kept in the same order as animal_classes).
# NOTE(review): 'deer' is mapped to BOTH although deer are herbivores — confirm intent.
predator_mapping = {
    'cat': BOTH,
    'cow': NONPREDATOR,
    'coyote': PREDATOR,
    'deer': BOTH,
    'dog': BOTH,
    'donkey': NONPREDATOR,
    'fox': PREDATOR,
    'horse': NONPREDATOR,
    'owl': PREDATOR,
    'pig': NONPREDATOR,
    'possum': PREDATOR,
    'raccoon': PREDATOR,
    'sheep': NONPREDATOR,
    'wolf': PREDATOR,
}

# Human-readable category names, indexed by the codes above.
predator_classes = ['nonpredator', 'predator', 'both']

# Load the trained classifier. The file holds a whole pickled nn.Module (it is called
# directly and has .eval() invoked on it), not just a state_dict.
#  - map_location='cpu' keeps every parameter on the CPU. Without it, a checkpoint saved
#    from a GPU run would load onto CUDA and mismatch the CPU input tensors.
#  - weights_only=False is required to unpickle a full module on PyTorch >= 2.6, where the
#    default became True. This runs arbitrary pickled code: only load trusted checkpoints.
MODEL_PATH = 'animal_classifier_model.pt'
model = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
model.eval()  # inference only: freeze dropout / batch-norm statistics

# Function to classify an image and return the animal and predator class
def classify_image(image):
    model.eval()
    with torch.no_grad():
        image = val_transforms(image).unsqueeze(0)
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
        animal_class = animal_classes[predicted.item()]
        predator_class = predator_classes[predator_mapping[animal_class]]
        return animal_class, predator_class

In [18]:
import gradio as gr
from PIL import Image

# Gradio callback wired into the interface below.
def predict(image):
    """Classify an uploaded image and format the result for display.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        str: e.g. ``"Animal: fox, Category: predator"``, or a prompt asking for an
        image when none was provided.
    """
    # Gradio passes None for an empty image box; bail out instead of crashing in the transforms.
    if image is None:
        return "Please upload an image."
    animal_class, predator_class = classify_image(image)
    return f"Animal: {animal_class}, Category: {predator_class}"

# Build the web UI: a PIL image upload in, a plain text result out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title="Animal Classifier",
)
# share=True also publishes a temporary public *.gradio.live URL; anyone with the link
# can query the model while this kernel keeps running.
interface.launch(share=True)

Running on local URL:  http://127.0.0.1:7865
Running on public URL: https://4a06fed12aed765422.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


