In [1]:
import os
import re
import csv
import glob
import random
from pathlib import Path
from tabulate import tabulate
from PIL import Image, ImageOps

import clip 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
device = "cuda:3" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN101", device=device, jit=False)

TAXA = [
    ("Salmo salar", "Atlantic salmon", "Actinopterygii", "Salmoniformes", "Salmonidae", "Salmo"),
    ("Oncorhynchus mykiss", "Rainbow trout", "Actinopterygii", "Salmoniformes", "Salmonidae", "Oncorhynchus"),
    ("Gadus morhua", "Atlantic cod", "Actinopterygii", "Gadiformes", "Gadidae", "Gadus"),
    ("Melanogrammus aeglefinus", "Haddock", "Actinopterygii", "Gadiformes", "Gadidae", "Melanogrammus"),
    ("Scomber scombrus", "Atlantic mackerel", "Actinopterygii", "Scombriformes", "Scombridae", "Scomber"),
    ("Thunnus thynnus", "Atlantic bluefin tuna", "Actinopterygii", "Scombriformes", "Scombridae", "Thunnus"),
    ("Coryphaena hippurus", "Mahi-mahi", "Actinopterygii", "Carangiformes", "Coryphaenidae", "Coryphaena"),
    ("Xiphias gladius", "Swordfish", "Actinopterygii", "Xiphiiformes", "Xiphiidae", "Xiphias"),
    ("Clupea harengus", "Atlantic herring", "Actinopterygii", "Clupeiformes", "Clupeidae", "Clupea"),
    ("Sardina pilchardus", "European pilchard", "Actinopterygii", "Clupeiformes", "Clupeidae", "Sardina"),
    ("Engraulis encrasicolus", "European anchovy", "Actinopterygii", "Clupeiformes", "Engraulidae", "Engraulis"),
    ("Amphiprion ocellaris", "Ocellaris clownfish", "Actinopterygii", "Blenniiformes", "Pomacentridae", "Amphiprion"),
    ("Pomacanthus imperator", "Emperor angelfish", "Actinopterygii", "Acanthuriformes", "Pomacanthidae", "Pomacanthus"),
    ("Pterois volitans", "Red lionfish", "Actinopterygii", "Scorpaeniformes", "Scorpaenidae", "Pterois"),
    ("Zebrasoma flavescens", "Yellow tang", "Actinopterygii", "Acanthuriformes", "Acanthuridae", "Zebrasoma"),
    ("Hippocampus kuda", "Common seahorse", "Actinopterygii", "Syngnathiformes", "Syngnathidae", "Hippocampus"),
    ("Betta splendens", "Siamese fighting fish", "Actinopterygii", "Anabantiformes", "Osphronemidae", "Betta"),
    ("Paracheirodon innesi", "Neon tetra", "Actinopterygii", "Characiformes", "Characidae", "Paracheirodon"),
    ("Carassius auratus", "Goldfish", "Actinopterygii", "Cypriniformes", "Cyprinidae", "Carassius"),
    ("Cyprinus carpio", "Common carp", "Actinopterygii", "Cypriniformes", "Cyprinidae", "Cyprinus"),
    ("Poecilia reticulata", "Guppy", "Actinopterygii", "Cyprinodontiformes", "Poeciliidae", "Poecilia"),
    ("Astatotilapia burtoni", "Burton’s mouthbrooder", "Actinopterygii", "Cichliformes", "Cichlidae", "Astatotilapia"),
    ("Oreochromis niloticus", "Nile tilapia", "Actinopterygii", "Cichliformes", "Cichlidae", "Oreochromis"),
    ("Pterophyllum scalare", "Freshwater angelfish", "Actinopterygii", "Cichliformes", "Cichlidae", "Pterophyllum"),
    ("Micropterus salmoides", "Florida bass", "Actinopterygii", "Centrarchiformes", "Centrarchidae", "Micropterus"),
    ("Lepomis macrochirus", "Bluegill sunfish", "Actinopterygii", "Centrarchiformes", "Centrarchidae", "Lepomis"),
    ("Esox lucius", "Northern pike", "Actinopterygii", "Esociformes", "Esocidae", "Esox"),
    ("Ictalurus punctatus", "Channel catfish", "Actinopterygii", "Siluriformes", "Ictaluridae", "Ictalurus"),
    ("Silurus glanis", "Wels catfish", "Actinopterygii", "Siluriformes", "Siluridae", "Silurus"),
    ("Electrophorus electricus", "Electric eel", "Actinopterygii", "Gymnotiformes", "Gymnotidae", "Electrophorus"),
    ("Arapaima gigas", "Arapaima", "Actinopterygii", "Osteoglossiformes", "Arapaimidae", "Arapaima"),
    ("Osteoglossum bicirrhosum", "Silver arowana", "Actinopterygii", "Osteoglossiformes", "Osteoglossidae", "Osteoglossum"),
    ("Anguilla anguilla", "European eel", "Actinopterygii", "Anguilliformes", "Anguillidae", "Anguilla"),
    ("Muraena helena", "Mediterranean moray", "Actinopterygii", "Anguilliformes", "Muraenidae", "Muraena"),
    ("Lophius piscatorius", "Monkfish", "Actinopterygii", "Lophiiformes", "Lophiidae", "Lophius"),
    ("Hippoglossus hippoglossus", "Atlantic halibut", "Actinopterygii", "Pleuronectiformes", "Pleuronectidae", "Hippoglossus"),
    ("Pleuronectes platessa", "European plaice", "Actinopterygii", "Pleuronectiformes", "Pleuronectidae", "Pleuronectes"),
    ("Sphyraena barracuda", "Great barracuda", "Actinopterygii", "Carangiformes", "Sphyraenidae", "Sphyraena"),
    ("Dicentrarchus labrax", "European seabass", "Actinopterygii", "Acanthuriformes", "Moronidae", "Dicentrarchus"),
    ("Lutjanus campechanus", "Northern red snapper", "Actinopterygii", "Acanthuriformes", "Lutjanidae", "Lutjanus"),
    ("Epinephelus itajara", "Goliath grouper", "Actinopterygii", "Perciformes", "Epinephelidae", "Epinephelus"),
    ("Cheilinus undulatus", "Humphead wrasse", "Actinopterygii", "Labriformes", "Labridae", "Cheilinus"),
    ("Gobius niger", "Black goby", "Actinopterygii", "Gobiiformes", "Gobiidae", "Gobius"),
    ("Carcharodon carcharias", "Great white shark", "Chondrichthyes", "Lamniformes", "Lamnidae", "Carcharodon"),
    ("Galeocerdo cuvier", "Tiger shark", "Chondrichthyes", "Carcharhiniformes", "Carcharhinidae", "Galeocerdo"),
    ("Sphyrna lewini", "Scalloped hammerhead", "Chondrichthyes", "Carcharhiniformes", "Sphyrnidae", "Sphyrna"),
    ("Raja clavata", "Thornback ray", "Chondrichthyes", "Rajiformes", "Rajidae", "Raja"),
    ("Mobula birostris", "Giant manta ray", "Chondrichthyes", "Myliobatiformes", "Mobulidae", "Mobula"),
    ("Takifugu rubripes", "Japanese pufferfish", "Actinopterygii", "Tetraodontiformes", "Tetraodontidae", "Takifugu"),
    ("Diodon hystrix", "Porcupinefish", "Actinopterygii", "Tetraodontiformes", "Diodontidae", "Diodon"),]

def make_captions(binom, common, cls, order, family, genus):
    return [
        f"{binom}",
        f"a fish from family {family}",
        f"a fish from order {order}",
        f"a fish from class {cls}",]

def build_clip_lists(root="dataCLIP"):
    root = Path(root)
    image_paths, text_options = [], []
    for species in TAXA:
        binom, common, cls, order, family, genus = species
        captions = make_captions(binom, common, cls, order, family, genus)
        folder_path = root / binom.replace(" ", "_")
        for img in sorted(folder_path.glob("*")):
            if img.suffix.lower() in {".jpg", ".jpeg", ".png", ".webp"}:
                image_paths.append(str(img))
                text_options.append(captions)
    assert len(image_paths) == len(text_options)
    return image_paths, text_options

class CLIP_dataset(Dataset):
    def __init__(self, list_image_path, list_txt_options):
        self.image_path = list_image_path
        self.text_options = list_txt_options

    def __len__(self):
        return len(self.image_path)

    def __getitem__(self, idx):
        img_path = self.image_path[idx]

        caption_list = self.text_options[idx]
        caption = random.choice(caption_list)

        img = Image.open(img_path)
        img = ImageOps.exif_transpose(img)
        if img.mode == "P" and ("transparency" in img.info or img.info.get("transparency") is not None):
            img = img.convert("RGBA").convert("RGB")
        else:
            img = img.convert("RGB")
        img.load()

        image = preprocess(img)
        title = clip.tokenize(caption)[0]
        return image, title

image_data, text_data = build_clip_lists("dataCLIP")
dataset = CLIP_dataset(image_data, text_data)
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True) 

In [3]:
# Tokenize the text prompts
prompts = ["Salmonidae", "Sphyraenidae",
           "Pomacanthidae", "Epinephelidae",
           "Moronidae", "Gymnotidae"]

text_tokens = clip.tokenize(prompts).to(device)

@torch.no_grad()
def clip_predict(image_path, text_tokens, texts, model, preprocess, device, topk=5):
    # Model evaluation mode
    model.eval()

    # Load and process image
    img = Image.open(image_path)
    image = preprocess(img).unsqueeze(0).to(device)

    # Infer image
    logits_per_image, _ = model(image, text_tokens)        
    probs = logits_per_image.softmax(dim=-1).squeeze(0)            

    # Get top k results
    k = min(topk, len(texts))
    scores, idx = torch.topk(probs, k=k, largest=True, sorted=True)
    return [(float(scores[i]), texts[int(idx[i])]) for i in range(k)]

def run_evaluation(imgFolder):
    # Assert image folder
    if not os.path.exists(imgFolder):
        print(f"{imgFolder} does not exist.")
        return 
    if not len(glob.glob(f"{imgFolder}/*")):
        print(f"{imgFolder} is empty.")
        return 

    # Run evaluation and tabulate
    rows = []
    for name in glob.glob(f"{imgFolder}/*"):
        results = clip_predict(name, text_tokens, prompts, model, preprocess, device, topk=3)
        name_split = os.path.basename(name).split(".")[0].split("_")
        
        row = [f"{name_split[0]} {name_split[1]}"] + [name_split[2]] + [f"{label} ({score:.3f})" for score, label in results]
        rows.append(row)

    headers = ["Image", "True Family", "Top-1", "Top-2", "Top-3"]
    print(tabulate(rows, headers=headers, tablefmt="fancy_grid"), "\n")

run_evaluation("zeroCLIP")

╒═══════════════════════════╤═══════════════╤════════════════════╤═══════════════════════╤═══════════════════════╕
│ Image                     │ True Family   │ Top-1              │ Top-2                 │ Top-3                 │
╞═══════════════════════════╪═══════════════╪════════════════════╪═══════════════════════╪═══════════════════════╡
│ Salmo hucho               │ Salmonidae    │ Salmonidae (0.999) │ Moronidae (0.001)     │ Pomacanthidae (0.000) │
├───────────────────────────┼───────────────┼────────────────────┼───────────────────────┼───────────────────────┤
│ Sphyraena novaehollandiae │ Sphyraenidae  │ Salmonidae (0.986) │ Moronidae (0.005)     │ Pomacanthidae (0.004) │
├───────────────────────────┼───────────────┼────────────────────┼───────────────────────┼───────────────────────┤
│ Centropyge boylei         │ Pomacanthidae │ Salmonidae (0.442) │ Pomacanthidae (0.313) │ Moronidae (0.093)     │
├───────────────────────────┼───────────────┼────────────────────┼──────────────

In [4]:
model = model.float().to(device)

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-6,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)
EPOCH = 25

for epoch in range(EPOCH):
    model.train()
    epoch_loss_sum = 0.0
    epoch_count = 0
    step_idx = 0
    print (15*"-", f"Epoch {epoch+1}", 15*"-")
    for batch in train_dataloader:
        optimizer.zero_grad()

        images, texts = batch
        images = images.to(device)
        texts  = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)
        logits_per_image = logits_per_image.contiguous()
        logits_per_text  = logits_per_text.contiguous()
        
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        total_loss.backward()
        optimizer.step()

        bs = images.size(0)
        epoch_loss_sum += float(total_loss.detach()) * bs
        epoch_count += bs
        
        step_idx += 1
        if step_idx % 10 == 0:
            running_avg = epoch_loss_sum / max(1, epoch_count)
            print(f"Epoch {epoch+1} | Step {step_idx}/{len(train_dataloader)} | Running average step loss: {running_avg:.4f}")
    epoch_avg = epoch_loss_sum / max(1, epoch_count)
    print(f"Epoch {epoch+1} | Epoch average loss: {epoch_avg:.4f}\n")
    run_evaluation("zeroCLIP")


--------------- Epoch 1 ---------------
Epoch 1 | Step 10/79 | Running average step loss: 3.5526
Epoch 1 | Step 20/79 | Running average step loss: 3.2833
Epoch 1 | Step 30/79 | Running average step loss: 3.1651
Epoch 1 | Step 40/79 | Running average step loss: 3.0949
Epoch 1 | Step 50/79 | Running average step loss: 3.0095
Epoch 1 | Step 60/79 | Running average step loss: 2.9676
Epoch 1 | Step 70/79 | Running average step loss: 2.9173
Epoch 1 | Epoch average loss: 2.8797

╒═══════════════════════════╤═══════════════╤═══════════════════════╤═══════════════════════╤══════════════════════╕
│ Image                     │ True Family   │ Top-1                 │ Top-2                 │ Top-3                │
╞═══════════════════════════╪═══════════════╪═══════════════════════╪═══════════════════════╪══════════════════════╡
│ Salmo hucho               │ Salmonidae    │ Salmonidae (0.283)    │ Epinephelidae (0.177) │ Moronidae (0.160)    │
├───────────────────────────┼───────────────┼──────────

In [5]:
torch.save({
        'epoch': 25,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': total_loss,
        }, f"Res101_25F.pt")

In [6]:
prompts = ["Salmonidae", "Sphyraenidae",
           "Pomacanthidae", "Epinephelidae",
           "Morinidae", "Gymnotidae"]
text_tokens = clip.tokenize(prompts).to(device)
run_evaluation("zeroCLIP")

╒═══════════════════════════╤═══════════════╤═══════════════════════╤═══════════════════════╤═══════════════════════╕
│ Image                     │ True Family   │ Top-1                 │ Top-2                 │ Top-3                 │
╞═══════════════════════════╪═══════════════╪═══════════════════════╪═══════════════════════╪═══════════════════════╡
│ Salmo hucho               │ Salmonidae    │ Salmonidae (0.988)    │ Morinidae (0.010)     │ Sphyraenidae (0.001)  │
├───────────────────────────┼───────────────┼───────────────────────┼───────────────────────┼───────────────────────┤
│ Sphyraena novaehollandiae │ Sphyraenidae  │ Sphyraenidae (0.995)  │ Gymnotidae (0.004)    │ Epinephelidae (0.000) │
├───────────────────────────┼───────────────┼───────────────────────┼───────────────────────┼───────────────────────┤
│ Centropyge boylei         │ Pomacanthidae │ Pomacanthidae (0.869) │ Morinidae (0.069)     │ Gymnotidae (0.033)    │
├───────────────────────────┼───────────────┼───────────

In [7]:
prompts = ["Salmonidae", "Sphyraenidae",
           "Pomacanthidae", "Epinephelidae",
           "Morinidae", "Gymnotidae", "Electrophorus"]
text_tokens = clip.tokenize(prompts).to(device)
run_evaluation("zeroCLIP")

╒═══════════════════════════╤═══════════════╤═══════════════════════╤═══════════════════════╤═══════════════════════╕
│ Image                     │ True Family   │ Top-1                 │ Top-2                 │ Top-3                 │
╞═══════════════════════════╪═══════════════╪═══════════════════════╪═══════════════════════╪═══════════════════════╡
│ Salmo hucho               │ Salmonidae    │ Salmonidae (0.986)    │ Morinidae (0.010)     │ Electrophorus (0.001) │
├───────────────────────────┼───────────────┼───────────────────────┼───────────────────────┼───────────────────────┤
│ Sphyraena novaehollandiae │ Sphyraenidae  │ Sphyraenidae (0.986)  │ Electrophorus (0.009) │ Gymnotidae (0.004)    │
├───────────────────────────┼───────────────┼───────────────────────┼───────────────────────┼───────────────────────┤
│ Centropyge boylei         │ Pomacanthidae │ Pomacanthidae (0.818) │ Morinidae (0.065)     │ Electrophorus (0.058) │
├───────────────────────────┼───────────────┼───────────