In [1]:
import torch

from efficient_pipeline import EfficientNetPipeline
from inaturalist import FISH_CLASSES

In [2]:
pipeline = EfficientNetPipeline(
                image_size=128,
                batch_size=64,
                learning_rate=0.001,
                num_epochs=40,
                top_k=5,
                model_type="efficient",
                classes=FISH_CLASSES
            )



In [3]:
pipeline.job_id = "latest"

In [4]:
pipeline.data_setup(False, FISH_CLASSES)
pipeline.load_model()

[INFO]: Loading pre-trained weights
[INFO]: Fine-tuning all layers...
Loading model from models/model_efficient_fish_latest.pt




In [5]:
pipeline.evaluate()

(52.11055276381909, 74.02010050251256)

In [6]:
y_true = []
y_pred = []

for images, labels in pipeline.test_dataloader:
    images = images.to(pipeline.device)
    labels = labels.to(pipeline.device)
    outputs = pipeline.model(images)
    _, predictions = torch.topk(outputs.data, k=pipeline.k, dim=1, largest=True)
    _, predicted = torch.max(outputs.data, 1)
    
    y_true += labels.tolist()
    y_pred += predicted.tolist()

In [7]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

In [8]:
from inaturalist import iNaturalistDataset

test = iNaturalistDataset(
    root="./data",
    train=False,
    download=False,
    classes=FISH_CLASSES,
    transform=pipeline.all_transforms,
)

In [9]:
import json

with open("data/train.json") as f:
    data = json.load(f)

In [10]:
data["categories"]

[{'id': 0,
  'name': 'Lumbricus terrestris',
  'common_name': 'Common Earthworm',
  'supercategory': 'Animalia',
  'kingdom': 'Animalia',
  'phylum': 'Annelida',
  'class': 'Clitellata',
  'order': 'Haplotaxida',
  'family': 'Lumbricidae',
  'genus': 'Lumbricus',
  'specific_epithet': 'terrestris',
  'image_dir_name': '00000_Animalia_Annelida_Clitellata_Haplotaxida_Lumbricidae_Lumbricus_terrestris'},
 {'id': 1,
  'name': 'Sabella spallanzanii',
  'common_name': 'Mediterranean Fanworm',
  'supercategory': 'Animalia',
  'kingdom': 'Animalia',
  'phylum': 'Annelida',
  'class': 'Polychaeta',
  'order': 'Sabellida',
  'family': 'Sabellidae',
  'genus': 'Sabella',
  'specific_epithet': 'spallanzanii',
  'image_dir_name': '00001_Animalia_Annelida_Polychaeta_Sabellida_Sabellidae_Sabella_spallanzanii'},
 {'id': 2,
  'name': 'Serpula columbiana',
  'common_name': 'Serpula columbiana',
  'supercategory': 'Animalia',
  'kingdom': 'Animalia',
  'phylum': 'Annelida',
  'class': 'Polychaeta',
  'ord

In [11]:
order_dict = {v: k for k, v in test.id_to_order.items()}

In [12]:
y_pred_labeled = [data["categories"][order_dict[y]]["name"] for y in y_pred]
y_true_labeled = [data["categories"][order_dict[y]]["name"] for y in y_true]

In [13]:
labels = [data["categories"][k]["name"] for k in test.id_to_order.keys()]
labels

['Amia calva',
 'Anguilla rostrata',
 'Gymnothorax prasinus',
 'Latropiscis purpurissatus',
 'Dorosoma cepedianum',
 'Catostomus commersonii',
 'Campostoma anomalum',
 'Carassius auratus',
 'Cyprinus carpio',
 'Cyprinus rubrofuscus',
 'Luxilus chrysocephalus',
 'Luxilus cornutus',
 'Notemigonus crysoleucas',
 'Pimephales notatus',
 'Rhinichthys atratulus',
 'Scardinius erythrophthalmus',
 'Semotilus atromaculatus',
 'Squalius cephalus',
 'Cyprinodon variegatus',
 'Fundulus diaphanus',
 'Fundulus heteroclitus',
 'Fundulus notatus',
 'Gambusia affinis',
 'Gambusia holbrooki',
 'Megalops atlanticus',
 'Esox lucius',
 'Esox niger',
 'Culaea inconstans',
 'Gasterosteus aculeatus',
 'Gobiesox maeandricus',
 'Lepisosteus oculatus',
 'Lepisosteus osseus',
 'Lepisosteus platyrhincus',
 'Mugil cephalus',
 'Acanthurus coeruleus',
 'Acanthurus nigrofuscus',
 'Acanthurus olivaceus',
 'Acanthurus triostegus',
 'Naso lituratus',
 'Zebrasoma flavescens',
 'Ostorhinchus limenus',
 'Pseudocaranx georgia

In [14]:
# Create the confusion matrix
cm = confusion_matrix(y_true_labeled, y_pred_labeled)

In [15]:
# plt.figure(figsize = (20,20))
# sns.heatmap(cm, annot=True)
# plt.show()

In [16]:
print(classification_report(y_true_labeled, y_pred_labeled))

                             precision    recall  f1-score   support

        Abudefduf saxatilis       0.21      0.30      0.25        10
     Abudefduf sexfasciatus       0.67      0.80      0.73        10
       Abudefduf troschelii       0.38      0.30      0.33        10
       Abudefduf vaigiensis       0.44      0.40      0.42        10
    Acanthopagrus australis       0.78      0.70      0.74        10
       Acanthurus coeruleus       0.73      0.80      0.76        10
     Acanthurus nigrofuscus       0.58      0.70      0.64        10
       Acanthurus olivaceus       0.86      0.60      0.71        10
      Acanthurus triostegus       0.50      0.50      0.50        10
         Achoerodus viridis       0.70      0.70      0.70        10
         Aetobatus narinari       0.67      0.80      0.73        10
      Ambloplites rupestris       0.33      0.30      0.32        10
             Ameiurus melas       0.58      0.70      0.64        10
           Ameiurus natalis      

In [17]:
import pandas as pd

report = classification_report(y_true_labeled, y_pred_labeled, output_dict=True)
report_df = pd.DataFrame(report).transpose()

In [59]:
report_df.sort_values(by="f1-score", ascending=False)

Unnamed: 0,precision,recall,f1-score,support
Upeneichthys lineatus,0.909091,1.0,0.952381,10.0
Dicotylichthys punctulatus,0.833333,1.0,0.909091,10.0
Taeniura lymma,0.900000,0.9,0.900000,10.0
Cheilodactylus nigripes,0.900000,0.9,0.900000,10.0
Pomacanthus paru,1.000000,0.8,0.888889,10.0
...,...,...,...,...
Herichthys cyanoguttatus,0.166667,0.1,0.125000,10.0
Fundulus diaphanus,0.142857,0.1,0.117647,10.0
Oligocottus maculosus,0.125000,0.1,0.111111,10.0
Myliobatis tenuicaudatus,0.111111,0.1,0.105263,10.0


In [55]:
output = pipeline.predict("/home/nates/repos/inaturalist-species-detection/data/val/02789_Animalia_Chordata_Actinopterygii_Mugiliformes_Mugilidae_Mugil_cephalus/9f847924-7d56-4720-b5fb-f2ef5385078d.jpg")

In [56]:
_, predicted = torch.max(output.data, 1)

In [57]:
y = predicted.tolist()[0]

In [58]:
data["categories"][order_dict[y]]["name"]

'Gasterosteus aculeatus'