In [12]:
import torch
import yaml
import json

from yolov5 import run as yolo_run
from segment_anything import sam_register, sam_run
from segment_anything.utils import get_masked_images
from mobilenet.utils import preprocess
from mobilenet import run as mobilenet_run

In [2]:
# Load configuration
config_path = '/Users/lgl/code/machine_learning/wildlife-segment-anything/config/inference.yaml'
with open(config_path) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [3]:
# Run YOLOv5 inference
prompts = yolo_run(
    weights=config['yolo'],
    source=config['dataset'],
    data=config['config'],
    device=config['device'],
)

YOLOv5 🚀 2024-1-19 Python-3.8.18 torch-2.1.2 CPU

Fusing layers... 
Model summary: 157 layers, 7037095 parameters, 0 gradients, 15.8 GFLOPs
image 1/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/bear.jpg: 480x640 1 bear, 260.9ms
image 2/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/bird.jpg: 576x640 1 bird, 308.6ms
image 3/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/cat.jpg: 448x640 1 cat, 244.4ms
image 4/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/cow.jpg: 384x640 2 cows, 216.3ms
image 5/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/dog.jpg: 448x640 2 dogs, 237.6ms
image 6/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/elephant.jpg: 480x640 1 elephant, 256.5ms
image 7/10 /Users/lgl/code/machine_learning/wildlife-segment-anything/data/images/coco/giraffe.jpg: 448x640 1 giraffe, 232.3ms
ima

In [4]:
# Run SAM inference
sam = sam_register(
    checkpoint=config['sam_checkpoint'],
    model_type=config['sam_model_type'],
    device=config['device'],
)
masks = sam_run(
    dataset=config['dataset'],
    prompt=prompts,
    predictor=sam,
)

Processed 1/10 images
Processed 2/10 images
Processed 3/10 images
Processed 4/10 images
Processed 5/10 images
Processed 6/10 images
Processed 7/10 images
Processed 8/10 images
Processed 9/10 images
Processed 10/10 images


In [5]:
# Get masked images
masked_images = get_masked_images(
    dataset=config['dataset'],
    masks=masks,
)

In [6]:
# Run MobileNet inference
# Load MobileNet model
mobilenet = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
mobilenet.eval()

# Preprocess images
mobilenet_input = preprocess(masked_images)

Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0


In [9]:
# Run inference
results = mobilenet_run(
    input_batches=mobilenet_input,
    label_path=config['labels'],
)

Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/lgl/.cache/torch/hub/pytorch_vision_v0.10.0


In [10]:
results

{'dog.jpg': {'flat-coated retriever': 0.3171379566192627},
 'horse.jpg': {'African elephant': 0.18369567394256592},
 'elephant.jpg': {'Indian elephant': 0.6400818228721619},
 'zebra.jpg': {'zebra': 0.8779024481773376},
 'sheep.jpg': {'ox': 0.3848831355571747},
 'bird.jpg': {'indigo bunting': 0.19207172095775604},
 'cow.jpg': {'Chihuahua': 0.15208041667938232},
 'cat.jpg': {'Arctic fox': 0.6235840320587158},
 'giraffe.jpg': {'banded gecko': 0.3448532521724701},
 'bear.jpg': {'ice bear': 0.21838708221912384}}

In [13]:
with open('/Users/lgl/code/machine_learning/wildlife-segment-anything/results/results.json','w') as f:
    json.dump(results, f)