In [1]:
import io
import json

import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50


In [2]:

# Build ResNet-50 with pre-trained weights and switch it to evaluation mode
# (nn.Module.eval() returns the module itself, so the call can be chained).
model = resnet50(pretrained=True).eval()

# Per-channel statistics handed to Normalize
# (presumably the usual ImageNet mean/std -- TODO confirm against the weights used).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize -> 224x224 centre crop -> tensor -> normalise.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
]
transform = transforms.Compose(preprocessing_steps)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/jayanthkomarraju/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:03<00:00, 31.6MB/s]


In [3]:

# Load the image
# NOTE(review): the URL embeds an access token in its query string; consider
# reading it from configuration instead of committing it with the notebook.
url = "https://firebasestorage.googleapis.com/v0/b/tree-hops.appspot.com/o/plants%2FV3ZBN1N68mSLhg2mw9fQVUiiQum2_1703375486488_indoor-plants-1643136651.jpeg?alt=media&token=940e10eb-5f73-4b3a-a878-977f17c41c1c"

# Fail fast on HTTP errors or a hung connection instead of handing an error
# page (or a never-ending stream) to PIL.
image_response = requests.get(url, timeout=30)
image_response.raise_for_status()

# Decode from the fully downloaded (and content-decoded) bytes, then force
# three channels: `transform` normalises with 3-element mean/std, so a
# grayscale or RGBA image would otherwise fail inside Normalize.
image = Image.open(io.BytesIO(image_response.content)).convert("RGB")

# Transform the image
input_tensor = transform(image)
# Add a leading batch dimension so the model receives a batch of one image.
input_batch = input_tensor.unsqueeze(0)


In [4]:

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the network and the input batch to the chosen device.
# nn.Module.to() returns the module itself, so rebinding `model` is a no-op rename.
model = model.to(device)
input_batch = input_batch.to(device)


In [5]:

# Forward pass: run the model on the single-image batch.
# torch.no_grad() turns off gradient tracking inside the block, which is all
# that is needed here because the result is only read, never back-propagated.
with torch.no_grad():
    output = model(input_batch)


In [6]:

# Load ImageNet class labels
# NOTE(review): the file is fetched from the `master` branch, so its contents
# can change between runs; pin a commit if reproducibility matters.
labels_path = 'https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json'

# Bound the wait and surface HTTP errors explicitly; otherwise a 404/500 body
# would be fed to the JSON parser and fail with a misleading decode error.
labels_response = requests.get(labels_path, timeout=30)
labels_response.raise_for_status()

# Parsed JSON payload; later cells index it with class ids.
labels = labels_response.json()


In [7]:

# Decode the results
# output[0] selects the entry for the first (and only) item in the batch.
logits = output[0]
# Softmax over the class dimension turns the raw scores into probabilities.
probabilities = logits.softmax(dim=0)
# Keep the five highest probabilities together with their class indices.
top5_prob, top5_catid = probabilities.topk(5)


In [8]:

# Flag the image as a plant when any of the top-5 labels contains the
# substring "plant" (case-insensitive); stop at the first match.
# NOTE(review): for the sample image (its file name contains "indoor-plants")
# the recorded run printed False -- a plain substring match on the label text
# may be too narrow; confirm the intended behaviour.
is_plant = False
for class_id in top5_catid.tolist():
    if "plant" in labels[class_id].lower():
        is_plant = True
        break

print("Is the subject a plant:", is_plant)


Is the subject a plant: False


In [9]:
# Show the five largest softmax probabilities returned by torch.topk.
top5_prob

tensor([0.4870, 0.2654, 0.0493, 0.0348, 0.0162])

In [10]:
# Show the class indices that correspond to the five largest probabilities.
top5_catid

tensor([738, 943, 923, 924, 952])

In [11]:
# Display the downloaded label list for inspection.
# NOTE(review): this renders every entry; a slice such as labels[:10] would
# keep the saved notebook output short.
labels

['tench',
 'goldfish',
 'great white shark',
 'tiger shark',
 'hammerhead shark',
 'electric ray',
 'stingray',
 'cock',
 'hen',
 'ostrich',
 'brambling',
 'goldfinch',
 'house finch',
 'junco',
 'indigo bunting',
 'American robin',
 'bulbul',
 'jay',
 'magpie',
 'chickadee',
 'American dipper',
 'kite',
 'bald eagle',
 'vulture',
 'great grey owl',
 'fire salamander',
 'smooth newt',
 'newt',
 'spotted salamander',
 'axolotl',
 'American bullfrog',
 'tree frog',
 'tailed frog',
 'loggerhead sea turtle',
 'leatherback sea turtle',
 'mud turtle',
 'terrapin',
 'box turtle',
 'banded gecko',
 'green iguana',
 'Carolina anole',
 'desert grassland whiptail lizard',
 'agama',
 'frilled-necked lizard',
 'alligator lizard',
 'Gila monster',
 'European green lizard',
 'chameleon',
 'Komodo dragon',
 'Nile crocodile',
 'American alligator',
 'triceratops',
 'worm snake',
 'ring-necked snake',
 'eastern hog-nosed snake',
 'smooth green snake',
 'kingsnake',
 'garter snake',
 'water snake',
 'vin