In [None]:
from pathlib import Path
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.optim.lr_scheduler import OneCycleLR
import timm
from tqdm.auto import tqdm
from PIL import Image, UnidentifiedImageError, ImageFile

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
ImageFile.LOAD_TRUNCATED_IMAGES = True

BASE_DIR    = Path().resolve()           # repo root
DATASET_DIR = BASE_DIR / "dataset"
train_dir   = DATASET_DIR / "train"
val_dir     = DATASET_DIR / "valid"
test_dir    = DATASET_DIR / "test"

for d in (train_dir, val_dir, test_dir):
    print(f"{d}: {'OK' if d.exists() else 'MISSING'}")

C:\Users\hjmso\Programming_Projects\fishingapp\machinelearning\dataset\train: OK
C:\Users\hjmso\Programming_Projects\fishingapp\machinelearning\dataset\valid: OK
C:\Users\hjmso\Programming_Projects\fishingapp\machinelearning\dataset\test: OK


In [3]:
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2,0.2,0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])
val_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225])
])

In [4]:
train_ds = ImageFolder(str(train_dir), transform=train_tfms)
val_ds   = ImageFolder(str(val_dir),   transform=val_tfms)
test_ds  = ImageFolder(str(test_dir),  transform=val_tfms)

# 0‑based class names
class_names = train_ds.classes
num_classes = len(class_names)
print(f"Found {num_classes} classes:", class_names[:5], "…", class_names[-5:])

# (Optional) regenerate species_mapping.json for inference
with open("species_mapping.json","w") as f:
    json.dump(class_names, f, indent=2)

Found 549 classes: ['Abalistes_stellatus', 'Abudefduf_saxatilis', 'Acanthemblemaria_spinosa', 'Acanthochromis_polyacanthus', 'Acanthurus_achilles'] … ['Zebrasoma_desjardinii', 'Zebrasoma_flavescens', 'Zebrasoma_scopas', 'Zebrasoma_veliferum', 'Zebrasoma_xanthurum']


In [5]:
counts = torch.bincount(torch.tensor(train_ds.targets))
weights = 1.0 / counts.float()
sample_weights = [weights[t] for t in train_ds.targets]
train_sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

In [6]:
batch_size  = 32
num_workers = 0

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=num_workers,
    pin_memory=True
)
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True
)

# quick sanity‐check
imgs, labs = next(iter(train_loader))
print("Batch OK:", imgs.shape, labs.shape)

Batch OK: torch.Size([32, 3, 224, 224]) torch.Size([32])


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = timm.create_model(
    "efficientnet_b0",
    pretrained=True,
    num_classes=num_classes
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

num_epochs = 20
scheduler = OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    anneal_strategy="cos"
)

In [8]:
def train_one_epoch():
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for imgs, labs in tqdm(train_loader, desc="Train"):
        imgs, labs = imgs.to(device), labs.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss   = criterion(logits, labs)
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * imgs.size(0)
        preds = logits.argmax(dim=1)
        correct   += (preds == labs).sum().item()
        total     += labs.size(0)
    return running_loss/total, correct/total

def validate():
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, labs in tqdm(val_loader, desc="Val  "):
            imgs, labs = imgs.to(device), labs.to(device)
            logits = model(imgs)
            loss   = criterion(logits, labs)
            val_loss += loss.item() * imgs.size(0)
            preds    = logits.argmax(dim=1)
            correct  += (preds == labs).sum().item()
            total    += labs.size(0)
    return val_loss/total, correct/total

In [9]:
best_val_acc = 0.0
for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch()
    val_loss,   val_acc   = validate()
    print(f"Epoch {epoch:02d} | "
          f"Train: loss={train_loss:.4f}, acc={train_acc:.4f} | "
          f" Val: loss={val_loss:.4f}, acc={val_acc:.4f}")

    # save best checkpoint
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "best_val_acc": best_val_acc
        }, "checkpoint_best_2.pt")

Train:   0%|          | 0/940 [00:00<?, ?it/s]


UnidentifiedImageError: cannot identify image file <_io.BufferedReader name='C:\\Users\\hjmso\\Programming_Projects\\fishingapp\\machinelearning\\dataset\\train\\Neoglyphidodon_crossi\\85.jpg'>

### Some Notes on Training
1. Previously attempted static StepLR:
    - no warm up, jumped straight to fixed lr=1e-4
    - coarse halving for decay every handful of epochs
    - this causes slow initial learning, risk of getting stuck in sub-optimal minima, then too-abrupt drops later
2. What I changed:
    - warm up phase for first 10% slowly raise LR from almost 0 to 1e-3 (prevents network from jittering on random initial weights)
    - high LR exploration around midpoint to escape poor local minima (encourage discovery of flatter more generalized solutions)
    - smooth cosine decay in latter 90% (gradually anneal LR back to 0, avoids sudden large drops that can throw optimizer off track)