**Install needed dependencies**

In [None]:
!pip install speciesnet fiftyone

**Mount Google Drive**

In [None]:
# Mount Google Drive so the model weights stored there are reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

**Unzip model weights**

First, download the SpeciesNet PyTorch model weights (v4.0.1b) from Kaggle: https://www.kaggle.com/api/v1/models/google/speciesnet/pyTorch/v4.0.1b/1/download

Then move the downloaded archive (`speciesnet-pytorch-v4.0.1b-v1.tar.gz`) to the root of your Google Drive (`MyDrive`).

In [None]:
!mkdir -p /content/drive/MyDrive/SpeciesNet
!tar -xvzf /content/drive/MyDrive/speciesnet-pytorch-v4.0.1b-v1.tar.gz -C /content/drive/MyDrive/SpeciesNet

**Imports**

In [None]:
import copy
import json
import os
import random
import shutil
import subprocess

import fiftyone as fo
import numpy as np
from PIL import Image
from speciesnet import DEFAULT_MODEL
from speciesnet import SUPPORTED_MODELS
from speciesnet import only_one_true
from speciesnet import SpeciesNet
from speciesnet.ensemble_prediction_combiner import PredictionType
from speciesnet.utils import load_partial_predictions
from speciesnet.utils import prepare_instances_dict

**Download sample data (or upload your own images to `/images`)**

In [None]:
# Download one randomly chosen image from each of the first NUM_LOCATIONS
# location folders of the public LILA Idaho camera-trap dataset.
# A dedicated, seeded RNG makes the selection reproducible across runs;
# change DOWNLOAD_SEED to sample a different set of images.
IMAGES_DIR = "/images"
NUM_LOCATIONS = 23
MAX_IMAGE_NUM = 10000  # upper bound (inclusive) for the random image index
DOWNLOAD_SEED = 0

os.makedirs(IMAGES_DIR, exist_ok=True)
download_rng = random.Random(DOWNLOAD_SEED)

for loc_idx in range(NUM_LOCATIONS):
    image_num = download_rng.randint(0, MAX_IMAGE_NUM)
    img_path = f"gs://public-datasets-lila/idaho-camera-traps/public/loc_{loc_idx:04}/loc_{loc_idx:04}_im_{image_num:06}.jpg"
    print(f"Downloading: {img_path}")
    result = subprocess.run(["gsutil", "-m", "cp", "-r", img_path, IMAGES_DIR])
    if result.returncode != 0:
        # The randomly chosen index may not exist in this folder; report it and move on.
        print(f"  gsutil failed (exit code {result.returncode}); skipping")

# Alternative source (a whole sequence from Missouri Camera Traps):
#   gsutil -m cp -r "gs://public-datasets-lila/missouricameratraps/images/Set1/1.60-Red_Fox/SEQ75195" "/images"


**Initialize the SpeciesNet model**

In [None]:
# Load SpeciesNet from the weights extracted to Google Drive. multiprocessing=True
# goes with the run_mode="multi_process" used by the detect/classify calls below.
speciesnet_model_dir = "/content/drive/MyDrive/SpeciesNet"
model = SpeciesNet(speciesnet_model_dir, multiprocessing=True)

**Make instances dictionary**

In [None]:

# Location metadata passed to SpeciesNet for every image in /images.
# These must be plain strings: a trailing comma (`country = "USA",`) would
# silently turn them into 1-tuples such as ("USA",).
country = "USA"  # ISO 3166-1 alpha-3 country code (format expected by SpeciesNet — confirm in its docs)
admin1_region = "ID"  # US state code for Idaho, matching the idaho-camera-traps sample data
assert isinstance(country, str) and isinstance(admin1_region, str)

instances_dict = prepare_instances_dict(
    folders=["/images"],
    country=country,
    admin1_region=admin1_region,
)
os.makedirs("/content/outputs", exist_ok=True)
with open("/content/outputs/instances.json", "w") as outfile:
    json.dump(instances_dict, outfile, indent=4)


**Get detections**

In [None]:
# Run SpeciesNet's detection stage over every image listed in instances_dict,
# then persist the returned predictions ourselves.
dets_dict = model.detect(
    instances_dict=instances_dict,
    run_mode="multi_process",
    progress_bars=True,
    predictions_json=None,
)
dets_json_path = "/content/outputs/dets.json"
with open(dets_json_path, "w") as dets_file:
    json.dump(dets_dict, dets_file, indent=4)

**Save crops of detections (to force classification of all detections)**

In [None]:
# Crop every detection out of its source image and save each crop as its own
# file. Each crop is paired with a full-frame bbox [0, 0, 1, 1], so the
# classifier sees exactly one detection per file and every detection gets a
# classification. temp_2_original maps each crop path back to its source
# image path and original (normalized) bbox.
temp_path = os.path.join(
    "/content",
    "temp",
    "speciesnet_crops",
)
print(temp_path)
# prepare_instances_dict() below picks up every file in this folder, so crops
# left over from an earlier run would be classified too and would have no
# entry in temp_2_original. Always start from an empty directory.
shutil.rmtree(temp_path, ignore_errors=True)
os.makedirs(temp_path, exist_ok=True)


def bbox_to_pixel_box(bbox, img_w, img_h):
    """Convert a normalized [x, y, width, height] bbox into a PIL crop box.

    Args:
        bbox: [x, y, w, h], all expressed as fractions of the image size.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        (xmin, ymin, xmax, ymax) in pixels, rounded outward (floor/ceil) and
        clamped to the image bounds.
    """
    x, y, w, h = bbox
    xmin = max(0, int(np.floor(img_w * x)))
    ymin = max(0, int(np.floor(img_h * y)))
    xmax = min(img_w, int(np.ceil(img_w * (x + w))))
    ymax = min(img_h, int(np.ceil(img_h * (y + h))))
    return xmin, ymin, xmax, ymax


temp_2_original = {}
cropped_dets = []
crop_idx = 0
for sample in dets_dict["predictions"]:
    # .get(): a sample may come back without a "detections" key
    # (e.g. if detection failed for it — assumed, confirm against SpeciesNet output).
    detections = sample.get("detections", [])
    if not detections:
        continue
    _, ext = os.path.splitext(sample["filepath"])
    # The context manager closes the underlying file once all crops are saved.
    with Image.open(sample["filepath"]) as img:
        img_w, img_h = img.size
        for det in detections:
            crop_box = bbox_to_pixel_box(det["bbox"], img_w, img_h)
            xmin, ymin, xmax, ymax = crop_box
            if xmax <= xmin or ymax <= ymin:
                continue  # degenerate box after clamping: nothing to crop
            det_filepath = os.path.join(temp_path, f"{crop_idx}{ext}")
            img.crop(crop_box).save(det_filepath)
            cropped_det = copy.deepcopy(det)
            cropped_det["filepath"] = det_filepath
            cropped_det["bbox"] = [0.0, 0.0, 1.0, 1.0]
            cropped_dets.append(cropped_det)
            temp_2_original[det_filepath] = {
                "filepath": sample["filepath"],
                "bbox": det["bbox"],
            }
            crop_idx += 1

cropped_dets_dict = {"predictions": cropped_dets}
with open("/content/outputs/cropped_dets.json", "w") as outfile:
    json.dump(cropped_dets_dict, outfile, indent=4)
cropped_instances_dict = prepare_instances_dict(
    folders=[temp_path],
    country=country,
    admin1_region=admin1_region,
)
with open("/content/outputs/cropped_instances.json", "w") as outfile:
    json.dump(cropped_instances_dict, outfile, indent=4)

**Classify all detections**

In [None]:
# Get classification for all bounding boxes
cropped_labels_dict = model.classify(
    instances_dict=cropped_instances_dict,
    detections_dict=cropped_dets_dict,
    run_mode="multi_process",
    # batch_size=self.batch_size,
    progress_bars=True,
    predictions_json=None,
)
with open("/content/outputs/cropped_labels.json", "w") as outfile:
  json.dump(cropped_labels_dict, outfile, indent=4)

**Store results in FiftyOne**

The species the model can predict are listed here: https://github.com/google/cameratrapai/blob/main/model_cards/v4.0.1b.md#label-distribution

In [None]:
# Build a FiftyOne dataset that shows every detection on its ORIGINAL image,
# labelled with SpeciesNet's classification of the corresponding crop.
# Species not in `target_labels`, or scoring below `conf_threshold`, are shown as "other".
fo_name = "SpeciesNet_Demo"  # warning: an existing dataset with this name is deleted!
conf_threshold = 0.4  # minimum score for a target species to be accepted
# Species of interest; compared against the last ";"-separated field of each
# classifier class string.
target_labels = {
    "human",
    "elk",
    "white-tailed deer",
    "mule deer",
    "coyote",
    "red fox",
    "domestic cattle",
    "pronghorn",
    "puma",
    "bobcat",
    "black-billed magpie",
}


def pick_label(classifications, targets, threshold):
    """Choose a display label and confidence for one classified crop.

    Walks the (class, score) pairs in the order the classifier returned them
    and returns the first species in `targets` whose score is >= `threshold`.
    If none qualifies, returns ("other", highest score seen), or
    ("other", 0.0) when there are no scores.
    """
    best_conf = 0.0
    for full_label, score in zip(classifications.get("classes", []),
                                 classifications.get("scores", [])):
        species_label = full_label.split(";")[-1]
        if species_label in targets and score >= threshold:
            return species_label, score
        best_conf = max(best_conf, score)
    return "other", best_conf


print(f"Mapping {len(temp_2_original)} crops back to their original images")

# Deliberately replace any previous dataset with the same name (see fo_name).
if fo.dataset_exists(fo_name):
    fo.delete_dataset(fo_name)
dataset = fo.Dataset(fo_name)
dataset.persistent = True

fo_dets = {}
skipped = 0
for cropped_label in cropped_labels_dict["predictions"]:
    origin = temp_2_original.get(cropped_label["filepath"])
    if origin is None:
        # Not a crop produced by the cropping step (e.g. a stale file): nothing to map it to.
        skipped += 1
        continue
    # .get(): a prediction may lack "classifications" (e.g. if classification
    # failed — assumed, confirm against SpeciesNet output); it then becomes "other".
    label, conf = pick_label(
        cropped_label.get("classifications", {}), target_labels, conf_threshold
    )
    fo_dets.setdefault(origin["filepath"], []).append(
        fo.Detection(label=label, bounding_box=origin["bbox"], confidence=conf)
    )
if skipped:
    print(f"Skipped {skipped} predictions with no matching crop")

samples = []
for filepath, dets in fo_dets.items():
    sample = fo.Sample(filepath=filepath)
    sample["SpeciesNet"] = fo.Detections(detections=dets)
    samples.append(sample)
dataset.add_samples(samples)
dataset.save()

In [None]:
# Open the FiftyOne App on the dataset built above.
session = fo.launch_app(dataset=dataset)

In [None]:
# Close the FiftyOne App session opened in the previous cell.
session.close()