<a href="https://colab.research.google.com/github/cspringbett/octomatic/blob/main/octomatic_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Mount Google Drive

In [1]:
# Make the project folder on Google Drive (model, range GeoJSONs, metadata)
# available under /content/drive; force_remount=True requests a fresh mount.
from google.colab import drive

drive.mount(mountpoint="/content/drive", force_remount=True)

Mounted at /content/drive


Load Model, Species Ranges, and Metadata

In [2]:
import json
import pandas as pd
import os
import numpy as np
import tensorflow as tf
import geopandas as gpd
import folium
from shapely.geometry import Point
from geopy.geocoders import Nominatim
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.cm as cm
import matplotlib.colors as mcolors


# Constants
IMAGE_SIZE = (224, 224)        # (width, height) passed to PIL's Image.resize before prediction
CONFIDENCE_THRESHOLD = 0.75    # minimum top-class probability required to name a species
# Order must match the model's output units: the predicted index i is
# reported as species_names[i].
species_names = ['macrotritopus_defilippi', 'octopus_americanus', 'octopus_briareus', 'octopus_furvus', 'octopus_insularis']

# Single project root on Drive, so relocating the project means editing one line.
PROJECT_DIR = "/content/drive/MyDrive/OctoMatic_SpeciesID"
OUTPUT_DIR = f"{PROJECT_DIR}/output"
GEOJSON_DIR = f"{PROJECT_DIR}/species_ranges"
METADATA_PATH = f"{PROJECT_DIR}/species_info.json"

# Load the fold-1 MobileNetV2 checkpoint saved in OUTPUT_DIR.
model = tf.keras.models.load_model(f"{OUTPUT_DIR}/MobileNetV2_fold1.h5")

# Load GeoJSON range polygons per species. A missing or unreadable file is
# reported and skipped; that species is then listed as "not filtered" by the
# range check instead of breaking the app.
geojson_ranges = {}
for species in species_names:
    path = f"{GEOJSON_DIR}/{species}.geojson"
    try:
        geojson_ranges[species] = gpd.read_file(path)
    except Exception as e:
        print(f"⚠️ Could not load {species} range: {e}")

# Load per-species metadata (classify_image reads the "description" and
# "example_image" keys). JSON is UTF-8 by specification, so don't depend on
# the locale's default encoding; descriptions may contain non-ASCII text.
with open(METADATA_PATH, encoding="utf-8") as f:
    species_info = json.load(f)




⚠️ Could not load octopus_furvus range: /content/drive/MyDrive/OctoMatic_SpeciesID/species_ranges/octopus_furvus.geojson: No such file or directory


Geocode City/State Input to Coordinates

In [3]:
def geocode_location(location_str):
    """Look up a free-text place name with the Nominatim geocoder.

    Args:
        location_str: Place name such as "Miami, FL". ``None``, empty or
            whitespace-only input returns immediately without a network call.

    Returns:
        ``(latitude, longitude)`` of the first match, or ``(None, None)`` when
        the input is blank or Nominatim finds nothing.

    Note:
        Exceptions raised by geopy (e.g. on timeout) are not caught here;
        callers are expected to handle them.
    """
    query = (location_str or "").strip()
    if not query:
        # Nothing to look up: skip the (rate-limited) remote service entirely.
        return None, None

    geolocator = Nominatim(user_agent="octomatic_app")
    loc = geolocator.geocode(query, timeout=10)
    if loc:
        return loc.latitude, loc.longitude
    return None, None


Location Range-Check Logic (predictions are annotated with range status, not altered)

In [4]:
from shapely.geometry import Point

def filter_predictions_by_geojson(preds, species_names, lat, lon, geojson_ranges, buffer_km=25):
    """Report, for each species, whether a location lies within its known range.

    Despite the name, nothing is filtered: ``preds`` is returned unchanged and
    only human-readable status lines are produced.

    Args:
        preds: Model output for one image; passed through untouched.
        species_names: Species slugs to report on, in output order.
        lat, lon: Query location in decimal degrees.
        geojson_ranges: Mapping of species slug -> GeoDataFrame of range
            polygons. Species absent from the mapping are reported as
            "not filtered".
        buffer_km: Radius around the point that still counts as in range.

    Returns:
        ``(preds, messages)`` with exactly one status line per species in
        ``species_names``.
    """
    # ~111 km per degree; this ignores longitude convergence towards the
    # poles, so the buffer is only approximate. shapely points are (x=lon, y=lat).
    # NOTE(review): assumes the range GeoDataFrames use lon/lat degrees
    # (EPSG:4326) — confirm if non-standard GeoJSON files are added.
    search_area = Point(lon, lat).buffer(buffer_km / 111)
    messages = []

    for species in species_names:
        if species not in geojson_ranges:
            messages.append(f"⚠️ {species}: No range file — not filtered")
            continue
        try:
            in_range = bool(geojson_ranges[species].geometry.intersects(search_area).any())
        except Exception as e:
            # Best-effort: a bad range file must not fail the whole request,
            # and it must not be reported as "outside" either.
            messages.append(f"⚠️ Error checking {species}: {e}")
            continue
        if in_range:
            messages.append(f"✅ {species}: Within known range")
        else:
            messages.append(f"❗ {species}: Outside known range")

    return preds, messages  # Keep predictions unchanged





Create Distribution Map

In [5]:
def create_range_map(lat, lon, predicted_species, gdf):
    """Render a Folium map of one species' range as an HTML string.

    Args:
        lat, lon: Map centre, also where the "Input Location" marker is placed.
        predicted_species: Name and hover tooltip for the range layer.
        gdf: GeoDataFrame of range polygons (as loaded from the species
            GeoJSON files).

    Returns:
        The map's notebook HTML representation, which the UI embeds in a
        ``gr.HTML`` component. If the range layer cannot be built, the error
        is printed and the map is returned with only the marker.
    """
    range_map = folium.Map(location=[lat, lon], zoom_start=6)

    input_marker = folium.Marker(
        location=[lat, lon],
        popup="Input Location",
        icon=folium.Icon(color="blue", icon="info-sign"),
    )
    input_marker.add_to(range_map)

    try:
        range_layer = folium.GeoJson(gdf, name=predicted_species, tooltip=predicted_species)
        range_layer.add_to(range_map)
    except Exception as e:
        print(f"Error adding GeoJSON to map: {e}")

    return range_map._repr_html_()


Classifier and Gradio UI

In [6]:
# === Classify Function ===
def _format_species_name(slug: str) -> str:
    """Turn a slug like "octopus_briareus" into "Octopus briareus"."""
    parts = slug.split("_", 1)
    return f"{parts[0].capitalize()} {parts[1]}" if len(parts) == 2 else slug


def _empty_confidence_df() -> pd.DataFrame:
    """Return neutral all-zero grey bars so the BarPlot renders without a prediction."""
    names = [_format_species_name(s) for s in species_names]
    return pd.DataFrame({
        "class": names,
        "confidence": [0.0] * len(names),
        "color": ["#d3d3d3"] * len(names),
    })


def _resolve_coordinates(location_str, lat_input, lon_input):
    """Choose coordinates from the numeric inputs, else by geocoding the text.

    Returns:
        ``(lat, lon, source, messages)``. ``source`` is ``"numeric"``,
        ``"geocoded"`` or ``None``; when it is ``None``, ``lat``/``lon`` are
        ``None`` and ``messages`` explains why no location is used.
    """
    messages = ["No location provided, no filtering applied."]

    # A) Numeric coordinates win when both are given. (0, 0) counts as
    #    "not provided" unless the text box explicitly reads "0,0".
    if lat_input is not None and lon_input is not None:
        try:
            lat, lon = float(lat_input), float(lon_input)
        except (TypeError, ValueError, OverflowError):
            messages = ["Invalid numeric coordinates."]
        else:
            explicit_origin = bool(location_str) and location_str.strip() == "0,0"
            if (lat, lon) != (0.0, 0.0) or explicit_origin:
                return lat, lon, "numeric", messages

    # B) Fall back to geocoding the free-text location.
    query = (location_str or "").strip()
    if not query:
        return None, None, None, messages
    try:
        lat, lon = geocode_location(query)
    except Exception as e:
        # Network/service failures are shown to the user instead of crashing the UI.
        return None, None, None, [f"Geocoding error: {e}"]
    if lat is None or lon is None:
        return None, None, None, [f"Could not geocode location: {location_str}"]
    return lat, lon, "geocoded", messages


def _build_confidence_df(preds, top_index):
    """Build the BarPlot frame with columns ``class``, ``confidence`` (%) and ``color``.

    Rows are sorted by confidence, highest first, and ``class`` becomes an
    ordered categorical so the plot keeps that order. Bars get a viridis
    gradient by rank, and the top prediction's bar is black.
    """
    import matplotlib  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    from pandas.api.types import CategoricalDtype

    data = pd.DataFrame({
        "class": [_format_species_name(s) for s in species_names],
        "confidence": [round(float(p) * 100, 2) for p in preds],
    })
    # Identify the top species by name: row positions change once we sort.
    top_class = data["class"].iloc[top_index]
    data = data.sort_values(by="confidence", ascending=False).reset_index(drop=True)

    palette = matplotlib.colormaps["viridis"](np.linspace(0.0, 1.0, len(data)))
    data["color"] = [
        "#000000" if name == top_class else mcolors.to_hex(rgba)
        for name, rgba in zip(data["class"], palette)
    ]

    # Force categorical y-axis ordering by confidence.
    data["class"] = data["class"].astype(
        CategoricalDtype(categories=data["class"].tolist(), ordered=True)
    )
    return data


def classify_image(image, location_str="", lat_input=None, lon_input=None):
    """Classify an octopus photo and build every output shown by the Gradio UI.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``.
        location_str: Optional free-text place name. It is geocoded only when
            usable numeric coordinates were not supplied.
        lat_input, lon_input: Optional numeric coordinates. ``(0, 0)`` is
            ignored unless ``location_str`` is exactly ``"0,0"`` (after
            stripping whitespace).

    Returns:
        ``(label, confidence_df, filtering_text, description,
        example_image_url, map_html)``, in the order of the button's
        ``outputs`` list.
    """
    if image is None:
        return (
            "Please upload an image.",
            _empty_confidence_df(),
            "No location provided.",   # geo-filtering details
            "",                        # description
            None,                      # example image
            "<p>No map available</p>", # map HTML
        )

    # --- preprocess & predict ---
    # Convert to RGB first: RGBA (e.g. PNG with transparency) or grayscale
    # uploads would otherwise yield 4 or 1 channels and break model.predict.
    img = image.convert("RGB").resize(IMAGE_SIZE)
    img_array = np.expand_dims(np.array(img) / 255.0, axis=0)
    preds = model.predict(img_array, verbose=0)[0]  # one score per species

    # --- optional geo range check ---
    lat, lon, used_source, filtering_msg = _resolve_coordinates(location_str, lat_input, lon_input)
    has_coords = lat is not None and lon is not None
    if has_coords:
        preds, filtering_msg = filter_predictions_by_geojson(
            preds, species_names, lat, lon, geojson_ranges
        )
        if used_source == "numeric":
            header = f"Using numeric coordinates: ({lat:.5f}, {lon:.5f})."
        else:
            header = f"Geocoded '{location_str.strip()}' → ({lat:.5f}, {lon:.5f})."
        filtering_msg = [header] + list(filtering_msg)

    # --- top prediction (taken after the range check, which may return new preds) ---
    top_index = int(np.argmax(preds))
    top_slug = species_names[top_index]
    top_confidence = float(preds[top_index])
    is_confident = top_confidence >= CONFIDENCE_THRESHOLD

    # --- range map for the top species ---
    map_html = "<p>No map available</p>"
    if top_slug in geojson_ranges:
        # Without a user location, centre on a fixed default view (20°N, 70°W).
        map_lat, map_lon = (lat, lon) if has_coords else (20.0, -70.0)
        map_html = create_range_map(map_lat, map_lon, top_slug, geojson_ranges[top_slug])

    if is_confident:
        top_pct = int(round(top_confidence * 100))
        label = f"Top Prediction: {_format_species_name(top_slug)} ({top_pct}%)"
    else:
        label = "Top Prediction: Unsure or Unknown Species"

    confidence_data = _build_confidence_df(preds, top_index)

    # --- metadata (only if confident) ---
    if is_confident:
        info = species_info.get(top_slug, {})
        description = info.get("description", "No description available.")
        example_img_url = info.get("example_image", None)
    else:
        description, example_img_url = "", None

    return label, confidence_data, "\n".join(filtering_msg), description, example_img_url, map_html





# === Gradio Interface ===
# Builds the web UI: inputs in the left column, tabbed results in the right.
# Components are laid out in the order they are created inside the `with`
# blocks, so statement order here defines the on-screen layout.

import gradio as gr

# Ocean base theme with a Google font and custom background/button colours.
with gr.Blocks(
    title="Octopus Species Classifier",
    theme=gr.themes.Ocean(
        text_size="lg",
        font=[gr.themes.GoogleFont('Sniglet'), 'ui-sans-serif', 'system-ui', 'sans-serif'],
    ).set(
        body_background_fill='*primary_100',
        body_background_fill_dark='*secondary_950',
        body_text_color_subdued='*neutral_700',
        background_fill_primary='*neutral_50',
        background_fill_secondary='*neutral_100',
        link_text_color_active='*neutral_700',
        button_secondary_background_fill='linear-gradient(120deg, *secondary_500 0%, *primary_300 60%, *primary_400 100%)',
        button_secondary_background_fill_dark='linear-gradient(120deg, *secondary_500 0%, *primary_300 60%, *primary_400 100%)'
    )
) as demo:
    gr.Markdown("# 🐙 OctoMatic: An AI Solution for Octopus Species Identification")
    gr.Markdown("Upload an image of an octopus, and optionally enter a location for distribution data.")

    with gr.Row():
        # ---- Inputs (left) ----
        with gr.Column(scale=1):
            # type="pil" hands classify_image a PIL image (or None when empty).
            image_input = gr.Image(type="pil", label="Upload Octopus Image")

            # Either a place name or numeric coordinates; classify_image
            # prefers the numbers and falls back to geocoding the text.
            with gr.Accordion("Optional Location as City OR Coordinates", open=True):
                location_input = gr.Textbox(
                    label="City/State (optional)",
                    placeholder="e.g., Miami, FL (optional)"
                )
                with gr.Row():
                    # Bounded to valid latitude/longitude ranges.
                    lat_input = gr.Number(
                        label="Latitude (optional)", precision=6,
                        value=None, minimum=-90, maximum=90,
                        placeholder= "e.g., 25.77"
                    )
                    lon_input = gr.Number(
                        label="Longitude (optional)", precision=6,
                        value=None, minimum=-180, maximum=180,
                        placeholder="e.g., -80.19"
                    )

            submit_btn = gr.Button("Classify")

        # ---- Outputs (right) ----
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Results"):
                    prediction_label = gr.Label(label="Top Prediction")
                    # Horizontal bars: x is the confidence (%) column, y the
                    # species name, from the DataFrame built by classify_image.
                    # NOTE(review): Gradio may treat `color` as a categorical
                    # series rather than literal hex colours — confirm how the
                    # "color" column actually renders.
                    confidence_plot = gr.BarPlot(
                        label="Prediction Confidence",
                        x="confidence",
                        y="class",
                        orientation="horizontal",
                        title="Prediction Confidence per Species",
                        x_title="Confidence (%)",
                        y_title="Species",
                        x_lim=(0, 100),
                        tooltip=["confidence"],
                        color="color"
                    )
                    with gr.Row():
                        metadata_image = gr.Image(label="Example Image", interactive=False)
                        species_info_output = gr.HTML(label="Species Info / Notes")

                with gr.Tab("Geography"):
                  # Read-only, newline-separated range-check messages.
                  filtering_output = gr.Textbox(
                      label="Geo-Filtering Details",
                      lines=8,           # a touch taller so it’s readable
                      interactive=False,
                      show_copy_button=True
                  )
                  # Receives the Folium map HTML (or a "No map available" placeholder).
                  map_output = gr.HTML(label="Species Range Map")


                with gr.Tab("Model Info"):
                    gr.Markdown("""
                    ### ℹ️ Model & Pipeline
                    - **Backbone**: MobileNetV2 (ImageNet pre-trained), fine-tuned
                    - **Training**: Stratified K-Fold cross-validation
                    - **Class Imbalance**: Class weights + augmentation
                    - **Evaluation**: Per-class precision/recall/F1, confusion matrices
                    - **Prediction UX**: Confidence thresholding, optional geo-filtering by range polygons
                    - **Notes**: Top prediction is highlighted in the confidence plot; map shows the predicted species distribution
                    """)

    # Wire up click → classify_image
    # The order of `outputs` must match the 6-tuple classify_image returns:
    # (label, confidence DataFrame, filtering text, description, example image, map HTML).
    submit_btn.click(
        fn=classify_image,
        inputs=[image_input, location_input, lat_input, lon_input],
        outputs=[
            prediction_label,
            confidence_plot,
            filtering_output,
            species_info_output,
            metadata_image,
            map_output
        ]
    )

# debug=True keeps the cell running and streams errors/logs. On hosted
# notebooks Gradio switches on share=True automatically (see cell output).
demo.launch(debug=True)


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://3e5d3e58ad6c6ce3ae.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://3e5d3e58ad6c6ce3ae.gradio.live




Gradio