In [1]:
import os
import glob
import random
from pathlib import Path
from tabulate import tabulate
from PIL import Image, ImageOps
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

# Prefer GPU index 3 (the original configuration) but fall back to cuda:0 when the
# machine has fewer GPUs; a non-existent CUDA ordinal fails as soon as it is used.
PREFERRED_CUDA_INDEX = 3
if torch.cuda.is_available():
    _cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = f"cuda:{_cuda_index}"
else:
    device = "cpu"

# Seed Python's and PyTorch's RNGs so DataLoader shuffling, random flips and the
# new head's initialisation are repeatable across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Reference taxonomy for every species in the training set.
# Each entry is (binomial name, common name, class, order, family, genus).
# Only two fields are consumed downstream: the binomial names the image folder
# (spaces replaced by underscores) and the family is the classification label.
# Several species share a family, so there are fewer classes than species.
# NOTE(review): "Florida bass" is usually the common name of Micropterus
# floridanus; M. salmoides is the largemouth bass. Confirm the intended species.
TAXA = [
    ("Salmo salar", "Atlantic salmon", "Actinopterygii", "Salmoniformes", "Salmonidae", "Salmo"),
    ("Oncorhynchus mykiss", "Rainbow trout", "Actinopterygii", "Salmoniformes", "Salmonidae", "Oncorhynchus"),
    ("Gadus morhua", "Atlantic cod", "Actinopterygii", "Gadiformes", "Gadidae", "Gadus"),
    ("Melanogrammus aeglefinus", "Haddock", "Actinopterygii", "Gadiformes", "Gadidae", "Melanogrammus"),
    ("Scomber scombrus", "Atlantic mackerel", "Actinopterygii", "Scombriformes", "Scombridae", "Scomber"),
    ("Thunnus thynnus", "Atlantic bluefin tuna", "Actinopterygii", "Scombriformes", "Scombridae", "Thunnus"),
    ("Coryphaena hippurus", "Mahi-mahi", "Actinopterygii", "Carangiformes", "Coryphaenidae", "Coryphaena"),
    ("Xiphias gladius", "Swordfish", "Actinopterygii", "Xiphiiformes", "Xiphiidae", "Xiphias"),
    ("Clupea harengus", "Atlantic herring", "Actinopterygii", "Clupeiformes", "Clupeidae", "Clupea"),
    ("Sardina pilchardus", "European pilchard", "Actinopterygii", "Clupeiformes", "Clupeidae", "Sardina"),
    ("Engraulis encrasicolus", "European anchovy", "Actinopterygii", "Clupeiformes", "Engraulidae", "Engraulis"),
    ("Amphiprion ocellaris", "Ocellaris clownfish", "Actinopterygii", "Blenniiformes", "Pomacentridae", "Amphiprion"),
    ("Pomacanthus imperator", "Emperor angelfish", "Actinopterygii", "Acanthuriformes", "Pomacanthidae", "Pomacanthus"),
    ("Pterois volitans", "Red lionfish", "Actinopterygii", "Scorpaeniformes", "Scorpaenidae", "Pterois"),
    ("Zebrasoma flavescens", "Yellow tang", "Actinopterygii", "Acanthuriformes", "Acanthuridae", "Zebrasoma"),
    ("Hippocampus kuda", "Common seahorse", "Actinopterygii", "Syngnathiformes", "Syngnathidae", "Hippocampus"),
    ("Betta splendens", "Siamese fighting fish", "Actinopterygii", "Anabantiformes", "Osphronemidae", "Betta"),
    ("Paracheirodon innesi", "Neon tetra", "Actinopterygii", "Characiformes", "Characidae", "Paracheirodon"),
    ("Carassius auratus", "Goldfish", "Actinopterygii", "Cypriniformes", "Cyprinidae", "Carassius"),
    ("Cyprinus carpio", "Common carp", "Actinopterygii", "Cypriniformes", "Cyprinidae", "Cyprinus"),
    ("Poecilia reticulata", "Guppy", "Actinopterygii", "Cyprinodontiformes", "Poeciliidae", "Poecilia"),
    ("Astatotilapia burtoni", "Burton’s mouthbrooder", "Actinopterygii", "Cichliformes", "Cichlidae", "Astatotilapia"),
    ("Oreochromis niloticus", "Nile tilapia", "Actinopterygii", "Cichliformes", "Cichlidae", "Oreochromis"),
    ("Pterophyllum scalare", "Freshwater angelfish", "Actinopterygii", "Cichliformes", "Cichlidae", "Pterophyllum"),
    ("Micropterus salmoides", "Florida bass", "Actinopterygii", "Centrarchiformes", "Centrarchidae", "Micropterus"),
    ("Lepomis macrochirus", "Bluegill sunfish", "Actinopterygii", "Centrarchiformes", "Centrarchidae", "Lepomis"),
    ("Esox lucius", "Northern pike", "Actinopterygii", "Esociformes", "Esocidae", "Esox"),
    ("Ictalurus punctatus", "Channel catfish", "Actinopterygii", "Siluriformes", "Ictaluridae", "Ictalurus"),
    ("Silurus glanis", "Wels catfish", "Actinopterygii", "Siluriformes", "Siluridae", "Silurus"),
    ("Electrophorus electricus", "Electric eel", "Actinopterygii", "Gymnotiformes", "Gymnotidae", "Electrophorus"),
    ("Arapaima gigas", "Arapaima", "Actinopterygii", "Osteoglossiformes", "Arapaimidae", "Arapaima"),
    ("Osteoglossum bicirrhosum", "Silver arowana", "Actinopterygii", "Osteoglossiformes", "Osteoglossidae", "Osteoglossum"),
    ("Anguilla anguilla", "European eel", "Actinopterygii", "Anguilliformes", "Anguillidae", "Anguilla"),
    ("Muraena helena", "Mediterranean moray", "Actinopterygii", "Anguilliformes", "Muraenidae", "Muraena"),
    ("Lophius piscatorius", "Monkfish", "Actinopterygii", "Lophiiformes", "Lophiidae", "Lophius"),
    ("Hippoglossus hippoglossus", "Atlantic halibut", "Actinopterygii", "Pleuronectiformes", "Pleuronectidae", "Hippoglossus"),
    ("Pleuronectes platessa", "European plaice", "Actinopterygii", "Pleuronectiformes", "Pleuronectidae", "Pleuronectes"),
    ("Sphyraena barracuda", "Great barracuda", "Actinopterygii", "Carangiformes", "Sphyraenidae", "Sphyraena"),
    ("Dicentrarchus labrax", "European seabass", "Actinopterygii", "Acanthuriformes", "Moronidae", "Dicentrarchus"),
    ("Lutjanus campechanus", "Northern red snapper", "Actinopterygii", "Acanthuriformes", "Lutjanidae", "Lutjanus"),
    ("Epinephelus itajara", "Goliath grouper", "Actinopterygii", "Perciformes", "Epinephelidae", "Epinephelus"),
    ("Cheilinus undulatus", "Humphead wrasse", "Actinopterygii", "Labriformes", "Labridae", "Cheilinus"),
    ("Gobius niger", "Black goby", "Actinopterygii", "Gobiiformes", "Gobiidae", "Gobius"),
    ("Carcharodon carcharias", "Great white shark", "Chondrichthyes", "Lamniformes", "Lamnidae", "Carcharodon"),
    ("Galeocerdo cuvier", "Tiger shark", "Chondrichthyes", "Carcharhiniformes", "Carcharhinidae", "Galeocerdo"),
    ("Sphyrna lewini", "Scalloped hammerhead", "Chondrichthyes", "Carcharhiniformes", "Sphyrnidae", "Sphyrna"),
    ("Raja clavata", "Thornback ray", "Chondrichthyes", "Rajiformes", "Rajidae", "Raja"),
    ("Mobula birostris", "Giant manta ray", "Chondrichthyes", "Myliobatiformes", "Mobulidae", "Mobula"),
    ("Takifugu rubripes", "Japanese pufferfish", "Actinopterygii", "Tetraodontiformes", "Tetraodontidae", "Takifugu"),
    ("Diodon hystrix", "Porcupinefish", "Actinopterygii", "Tetraodontiformes", "Diodontidae", "Diodon"),]


# Label space: one class per distinct family (field 4 of each TAXA entry),
# indexed in alphabetical order so the mapping is stable across runs.
families = sorted({entry[4] for entry in TAXA})
family_to_idx = dict(zip(families, range(len(families))))
idx_to_family = dict(enumerate(families))
num_classes = len(families)

print(f"Number of unique families: {num_classes}")
print(f"Families: {families}\n")

Number of unique families: 41
Families: ['Acanthuridae', 'Anguillidae', 'Arapaimidae', 'Carcharhinidae', 'Centrarchidae', 'Characidae', 'Cichlidae', 'Clupeidae', 'Coryphaenidae', 'Cyprinidae', 'Diodontidae', 'Engraulidae', 'Epinephelidae', 'Esocidae', 'Gadidae', 'Gobiidae', 'Gymnotidae', 'Ictaluridae', 'Labridae', 'Lamnidae', 'Lophiidae', 'Lutjanidae', 'Mobulidae', 'Moronidae', 'Muraenidae', 'Osphronemidae', 'Osteoglossidae', 'Pleuronectidae', 'Poeciliidae', 'Pomacanthidae', 'Pomacentridae', 'Rajidae', 'Salmonidae', 'Scombridae', 'Scorpaenidae', 'Siluridae', 'Sphyraenidae', 'Sphyrnidae', 'Syngnathidae', 'Tetraodontidae', 'Xiphiidae']



In [3]:
# Preprocessing pipelines. The normalisation uses the standard ImageNet channel
# statistics, matching the ImageNet-pretrained weights loaded for the model.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
INPUT_SIZE = (224, 224)


def _compose(*augmentations):
    """Build Resize -> [augmentations...] -> ToTensor -> Normalize."""
    return transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# Training adds a random horizontal flip; evaluation stays deterministic.
train_transform = _compose(transforms.RandomHorizontalFlip())
eval_transform = _compose()

def build_dataset_lists(root="dataCLIP", taxa=None, label_map=None,
                        extensions=(".jpg", ".jpeg", ".png", ".webp")):
    """Collect image paths and integer family labels from per-species folders.

    Images are expected under ``<root>/<Genus>_<species>/`` (the binomial with
    spaces replaced by underscores). Species whose folder is missing are
    skipped on purpose, so a partial dataset still loads.

    Args:
        root: dataset root directory.
        taxa: iterable of ``(binomial, common, class, order, family, genus)``
            tuples; defaults to the module-level ``TAXA``.
        label_map: mapping from family name to class index; defaults to the
            module-level ``family_to_idx``.
        extensions: file suffixes (case-insensitive) treated as images.

    Returns:
        ``(image_paths, labels)``: parallel lists of path strings and ints,
        ordered by species then by sorted file path.

    Raises:
        KeyError: if a species' family is missing from ``label_map``.
    """
    # Resolve defaults at call time so the globals need not exist at definition.
    if taxa is None:
        taxa = TAXA
    if label_map is None:
        label_map = family_to_idx
    allowed = {ext.lower() for ext in extensions}

    root = Path(root)
    image_paths, labels = [], []

    for binom, _common, _cls, _order, family, _genus in taxa:
        family_idx = label_map[family]
        folder_path = root / binom.replace(" ", "_")

        if not folder_path.is_dir():
            continue

        for img in sorted(folder_path.iterdir()):
            # is_file() rejects sub-directories that merely have an image-like name.
            if img.is_file() and img.suffix.lower() in allowed:
                image_paths.append(str(img))
                labels.append(family_idx)

    return image_paths, labels

class FishDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs built from parallel lists.

    Args:
        image_paths: sequence of image file paths.
        labels: sequence of integer class indices aligned with ``image_paths``.
        transform: optional callable applied to the decoded RGB ``PIL.Image``.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = self._load_rgb(self.image_paths[idx])
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

    @staticmethod
    def _load_rgb(img_path):
        """Decode ``img_path`` into an EXIF-oriented RGB image, closing the file."""
        # The context manager releases the file handle deterministically instead
        # of relying on garbage collection; convert() returns an in-memory copy
        # that stays valid after the file is closed.
        with Image.open(img_path) as img:
            img = ImageOps.exif_transpose(img)
            # Palette images carrying a transparency entry are expanded via RGBA
            # before dropping alpha (same two-step conversion as before).
            if img.mode == "P" and "transparency" in img.info:
                return img.convert("RGBA").convert("RGB")
            return img.convert("RGB")

# Build dataset: one entry per image found under dataCLIP/<Genus_species>/.
image_data, label_data = build_dataset_lists("dataCLIP")
if not image_data:
    # Fail with an explicit message; an empty dataset with shuffle=True would
    # otherwise surface as a less obvious sampler error from DataLoader.
    raise RuntimeError(
        "No training images found under 'dataCLIP' "
        "(expected one sub-folder per species, e.g. 'dataCLIP/Salmo_salar')."
    )

# Families with no images still get an output logit but are never a positive target.
images_per_family = Counter(label_data)
empty_families = [idx_to_family[i] for i in range(num_classes) if images_per_family[i] == 0]
if empty_families:
    print(f"Warning: {len(empty_families)} families have no training images: {empty_families}")

dataset = FishDataset(image_data, label_data, transform=train_transform)
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [5]:
# Model setup: ResNet-101 with ImageNet weights; its final fully-connected layer
# is replaced by a two-layer MLP head emitting one logit per family.
model = models.resnet101(weights='IMAGENET1K_V1')
num_features = model.fc.in_features
classifier_head = nn.Sequential(
    nn.Linear(in_features=num_features, out_features=1024),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.3),
    nn.Linear(in_features=1024, out_features=num_classes),
)
model.fc = classifier_head
model = model.to(device)

@torch.no_grad()
def predict_image(image_path, model, transform, device, topk=5):
    """Classify one image file and return its top-k family predictions.

    Args:
        image_path: path to an image file readable by Pillow.
        model: classifier whose output indices map through the module-level
            ``idx_to_family``.
        transform: callable turning an RGB ``PIL.Image`` into a tensor; a
            batch dimension is added here.
        device: device the input tensor is moved to (should match ``model``).
        topk: maximum number of predictions; clipped to the number of outputs.

    Returns:
        List of ``(probability, family_name)`` tuples, highest probability first.

    Note:
        Calls ``model.eval()`` and leaves the model in eval mode; callers that
        continue training must call ``model.train()`` again.
    """
    model.eval()

    # Close the file handle deterministically; convert() yields an independent
    # in-memory copy that remains valid after the ``with`` block.
    with Image.open(image_path) as img:
        img = ImageOps.exif_transpose(img)  # apply EXIF orientation
        if img.mode == "P" and "transparency" in img.info:
            img = img.convert("RGBA").convert("RGB")
        else:
            img = img.convert("RGB")

    image = transform(img).unsqueeze(0).to(device)
    probs = torch.softmax(model(image), dim=1).squeeze(0)

    # Derive k from the actual output width rather than the global class count.
    k = min(topk, probs.numel())
    scores, indices = torch.topk(probs, k=k, largest=True, sorted=True)

    return [
        (score, idx_to_family[index])
        for score, index in zip(scores.tolist(), indices.tolist())
    ]

def run_evaluation(imgFolder):
    if not os.path.exists(imgFolder):
        print(f"{imgFolder} does not exist.")
        return 
    
    img_files = sorted(glob.glob(f"{imgFolder}/*"))  # Sort for consistent ordering
    if not len(img_files):
        print(f"{imgFolder} is empty.")
        return 

    rows = []
    eval_results = []  # Store results for accurate counting
    
    for name in img_files:
        results = predict_image(name, model, eval_transform, device, topk=3)
        name_split = os.path.basename(name).split(".")[0].split("_")
        
        # Extract true family from filename (format: Genus_species_Family)
        if len(name_split) >= 3:
            true_family = name_split[2]
        else:
            true_family = "Unknown"
        
        # Get predicted families
        pred_families = [label for _, label in results]
        
        # Build table row
        row = [f"{name_split[0]} {name_split[1] if len(name_split) > 1 else ''}"] + [true_family]
        row += [f"{label} ({score:.3f})" for score, label in results]
        rows.append(row)
        
        # Store for accuracy calculation (only if true family is in our training set)
        if true_family in family_to_idx:
            is_top1_correct = (pred_families[0] == true_family)
            is_top3_correct = (true_family in pred_families)
            eval_results.append({
                'name': os.path.basename(name),
                'true_family': true_family,
                'pred_top1': pred_families[0],
                'pred_top3': pred_families,
                'top1_correct': is_top1_correct,
                'top3_correct': is_top3_correct
            })

    # Print table
    headers = ["Image", "True Family", "Top-1", "Top-2", "Top-3"]
    print(tabulate(rows, headers=headers, tablefmt="fancy_grid"))
    
    # Calculate and print accuracy
    if len(eval_results) > 0:
        correct_top1 = sum(r['top1_correct'] for r in eval_results)
        correct_top3 = sum(r['top3_correct'] for r in eval_results)
        total = len(eval_results)
        
        print(f"\nTop-1 Accuracy: {correct_top1}/{total} = {100*correct_top1/total:.2f}%")
        print(f"Top-3 Accuracy: {correct_top3}/{total} = {100*correct_top3/total:.2f}%")
    else:
        print("\nNo valid images found with families in the training set.\n")

# Baseline before any fine-tuning: the new classifier head is still randomly
# initialised, so predictions are expected to be close to uniform.
print("\n".join(("=" * 30, "INITIAL EVALUATION (Before Training)", "=" * 30)))
run_evaluation("zeroCLIP")

INITIAL EVALUATION (Before Training)
╒═══════════════════════════╤═══════════════╤═════════════════════╤════════════════════╤═══════════════════════╕
│ Image                     │ True Family   │ Top-1               │ Top-2              │ Top-3                 │
╞═══════════════════════════╪═══════════════╪═════════════════════╪════════════════════╪═══════════════════════╡
│ Centropyge boylei         │ Pomacanthidae │ Poeciliidae (0.030) │ Sphyrnidae (0.030) │ Coryphaenidae (0.030) │
├───────────────────────────┼───────────────┼─────────────────────┼────────────────────┼───────────────────────┤
│ Electrophorus electricus  │ Gymnotidae    │ Xiphiidae (0.031)   │ Siluridae (0.030)  │ Lophiidae (0.030)     │
├───────────────────────────┼───────────────┼─────────────────────┼────────────────────┼───────────────────────┤
│ Gymnotus carapo           │ Gymnotidae    │ Gymnotidae (0.031)  │ Muraenidae (0.030) │ Scombridae (0.030)    │
├───────────────────────────┼───────────────┼──────────────

In [6]:
# Training setup
EPOCH = 25
LOG_EVERY = 10            # print running stats every N optimizer steps
EVAL_FOLDER = "zeroCLIP"  # held-out images scored after every epoch
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Cosine decay over the whole run; stepped once per epoch (T_max=EPOCH).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCH)

steps_per_epoch = len(train_dataloader)
if steps_per_epoch == 0:
    raise RuntimeError("train_dataloader is empty; check the training image folders.")

# Training loop
for epoch in range(EPOCH):
    # Re-enable training mode every epoch: the per-epoch evaluation goes through
    # predict_image, which calls model.eval().
    model.train()
    epoch_loss_sum = 0.0
    epoch_correct = 0
    epoch_total = 0

    print("=" * 50)
    print(f"Epoch {epoch+1}/{EPOCH}")
    print("=" * 50)

    for step_idx, (images, labels) in enumerate(train_dataloader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Metrics only: detach so argmax/compare build no autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        batch_total = labels.size(0)
        # The criterion returns a batch mean; weight by batch size so the
        # epoch average is per-sample even with a short final batch.
        epoch_loss_sum += loss.item() * batch_total
        epoch_correct += (predicted == labels).sum().item()
        epoch_total += batch_total

        if step_idx % LOG_EVERY == 0:
            running_loss = epoch_loss_sum / epoch_total
            running_acc = 100 * epoch_correct / epoch_total
            print(f"Step {step_idx}/{steps_per_epoch} | Loss: {running_loss:.4f} | Acc: {running_acc:.2f}%")

    scheduler.step()
    epoch_avg_loss = epoch_loss_sum / epoch_total
    epoch_acc = 100 * epoch_correct / epoch_total
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Average Loss: {epoch_avg_loss:.4f}")
    print(f"  Training Accuracy: {epoch_acc:.2f}%\n")

    # Run evaluation after each epoch
    print("-" * 30)
    print(f"EVALUATION AFTER EPOCH {epoch+1}")
    print("-" * 30)
    run_evaluation(EVAL_FOLDER)

print("=" * 30)
print("TRAINING COMPLETE")
print("=" * 30)

# Save the trained model (currently disabled). The file name reflects the
# ResNet-101 backbone built in the model-setup cell.
#torch.save(model.state_dict(), 'resnet101_fish_family_classifier.pth')
#print("Model saved to: resnet101_fish_family_classifier.pth")

Epoch 1/25
Step 10/79 | Loss: 3.6704 | Acc: 4.38%
Step 20/79 | Loss: 3.6011 | Acc: 8.75%
Step 30/79 | Loss: 3.5213 | Acc: 10.83%
Step 40/79 | Loss: 3.4668 | Acc: 11.72%
Step 50/79 | Loss: 3.3521 | Acc: 17.50%
Step 60/79 | Loss: 3.2477 | Acc: 21.04%
Step 70/79 | Loss: 3.1535 | Acc: 24.38%

Epoch 1 Summary:
  Average Loss: 3.0772
  Training Accuracy: 26.32%

------------------------------
EVALUATION AFTER EPOCH 1
------------------------------
╒═══════════════════════════╤═══════════════╤═══════════════════════╤═══════════════════════╤════════════════════════╕
│ Image                     │ True Family   │ Top-1                 │ Top-2                 │ Top-3                  │
╞═══════════════════════════╪═══════════════╪═══════════════════════╪═══════════════════════╪════════════════════════╡
│ Centropyge boylei         │ Pomacanthidae │ Pomacentridae (0.137) │ Cichlidae (0.134)     │ Acanthuridae (0.124)   │
├───────────────────────────┼───────────────┼───────────────────────┼─────────