## 1. Check GPU & Install Dependencies

In [None]:
# Report the PyTorch build and whether a CUDA device is visible to this runtime.
import torch

print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("‚ö†Ô∏è No GPU detected! Go to Runtime ‚Üí Change runtime type ‚Üí GPU")
else:
    # Device 0 is the only one inspected; total memory is shown in decimal GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install additional dependencies
!pip install -q albumentations timm
print("‚úÖ Dependencies installed!")

## 2. Mount Google Drive & Load CSV

In [None]:
# Make Google Drive available under /content/drive for the CSV input and model outputs.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print("‚úÖ Google Drive mounted!")

In [None]:
import pandas as pd
import numpy as np
import os

# CONFIGURE: location of the species CSV on the mounted Drive.
CSV_PATH = "/content/drive/MyDrive/otolith_species.csv"  # <-- CHANGE THIS IF NEEDED

# Read the full table, then summarise its size, columns and species diversity.
df = pd.read_csv(CSV_PATH)
print(f"‚úÖ Loaded {df.shape[0]} records")
print(f"\nColumns: {df.columns.tolist()}")
print(f"\nUnique species: {df.scientificName.nunique()}")

In [None]:
# Keep only rows whose media link is present and starts with "http".
media_links = df['associatedMedia']
has_http_link = media_links.notna() & media_links.astype(str).str.startswith('http')
df = df.loc[has_http_link].copy()
print(f"Records with valid URLs: {len(df)}")

In [None]:
# Images per species; value_counts() lists the most frequent species first.
species_counts = df['scientificName'].value_counts()
print("Top 20 species by image count:")
print(species_counts.head(20))

# How many species clear each candidate minimum-sample threshold.
for position, threshold in enumerate((50, 20, 10)):
    lead = "\n\n" if position == 0 else ""
    n_qualifying = (species_counts >= threshold).sum()
    print(f"{lead}Species with >= {threshold} images: {n_qualifying}")

In [None]:
import matplotlib.pyplot as plt

# Bar chart of the 30 most frequent species.
plt.figure(figsize=(14, 6))
bar_ax = species_counts.iloc[:30].plot.bar()
bar_ax.set_title('Top 30 Species by Image Count')
bar_ax.set_xlabel('Species')
bar_ax.set_ylabel('Number of Images')
# Long binomial names overlap unless the tick labels are slanted.
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## 3. Filter & Prepare Dataset

We'll filter to species with enough samples for training.

In [None]:
# ============================================
# CONFIGURATION
# ============================================

MIN_SAMPLES_PER_SPECIES = 10  # Species with fewer images than this are dropped
MAX_SPECIES = 50  # Upper bound on the number of classes (most frequent species first)

# `species_counts` is ordered most-frequent-first, so slicing its index keeps the top N.
meets_minimum = species_counts >= MIN_SAMPLES_PER_SPECIES
valid_species = species_counts[meets_minimum].index[:MAX_SPECIES].tolist()
in_selection = df['scientificName'].isin(valid_species)
df_filtered = df.loc[in_selection].copy()

print(f"Selected {len(valid_species)} species with >= {MIN_SAMPLES_PER_SPECIES} samples")
print(f"Total images for training: {len(df_filtered)}")

# Integer class ids follow the alphabetical order of the species names.
ordered_species = sorted(valid_species)
species_to_idx = dict(zip(ordered_species, range(len(ordered_species))))
idx_to_species = dict(zip(species_to_idx.values(), species_to_idx.keys()))

df_filtered['label'] = df_filtered['scientificName'].map(species_to_idx)

print(f"\nNumber of classes: {len(species_to_idx)}")

In [None]:
# Bundle both lookup directions and the class count into one dictionary.
# Nothing is written to disk in this cell; `json` is imported here and used by
# the export cell further down the notebook.
import json

label_mapping = dict(
    species_to_idx=species_to_idx,
    idx_to_species=idx_to_species,
    num_classes=len(species_to_idx),
)

print("Label mapping (first 10):")
for species in list(species_to_idx)[:10]:
    print(f"  {species_to_idx[species]}: {species}")

## 4. Download Otolith Images

Download images from the AFORO server. This may take a while.

In [None]:
import requests
from pathlib import Path
from tqdm.notebook import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Local directory on the Colab VM that receives one sub-folder per class.
IMAGE_DIR = Path("/content") / "otolith_images"
IMAGE_DIR.mkdir(exist_ok=True)

def download_image(row):
    """Download one image into its class folder under ``IMAGE_DIR``.

    Args:
        row: Mapping with the keys ``associatedMedia`` (image URL),
            ``catalogNumber`` (used as the file stem) and ``label``
            (integer class index used for the ``class_NNN`` folder name).

    Returns:
        A ``(success, detail)`` tuple. On success ``detail`` is the local file
        path as a string; on failure it is a short reason (``"Invalid URL"``,
        ``"HTTP <status>"`` or the exception text). Network and file-system
        errors are reported through the return value instead of being raised.
    """
    # Function-scope imports keep this cell runnable on its own.
    import threading
    from urllib.parse import urlparse

    url = row.get('associatedMedia')

    # Skip if URL is missing or invalid (a pandas NaN is a float, not a str)
    if not url or not isinstance(url, str) or not url.startswith('http'):
        return False, "Invalid URL"

    # The catalog number becomes a file name, so path separators are neutralised.
    catalog_num = str(row['catalogNumber']).replace('/', '_').replace('\\', '_')
    label = row['label']

    # One folder per class index
    species_folder = IMAGE_DIR / f"class_{label:03d}"
    species_folder.mkdir(exist_ok=True)

    # Read the extension from the URL *path* so "?query" / "#fragment" parts
    # cannot leak into it; anything unrecognised falls back to 'tif'.
    ext = urlparse(url).path.rsplit('.', 1)[-1].lower()
    if ext not in ('jpg', 'jpeg', 'png', 'tif', 'tiff'):
        ext = 'tif'

    filename = species_folder / f"{catalog_num}.{ext}"

    # Skip if exists
    if filename.exists():
        return True, str(filename)

    # Download to a per-thread temporary name and rename afterwards, so an
    # interrupted transfer never leaves a truncated file under the final name
    # (which the "skip if exists" check above would then treat as complete).
    tmp_path = filename.with_name(f"{filename.name}.{threading.get_ident()}.part")
    try:
        response = requests.get(url, timeout=30)
        if response.status_code != 200:
            return False, f"HTTP {response.status_code}"
        with open(tmp_path, 'wb') as f:
            f.write(response.content)
        tmp_path.replace(filename)
        return True, str(filename)
    except Exception as e:
        # Best-effort cleanup of the partial file; the original error is what gets reported.
        try:
            if tmp_path.exists():
                tmp_path.unlink()
        except OSError:
            pass
        return False, str(e)

# Announce the size of the upcoming download job.
print(
    f"Will download {len(df_filtered)} images...",
    "This may take 10-20 minutes.\n",
    sep="\n",
)

In [None]:
# Download every selected record concurrently and tally the outcomes.
successful = 0
failed = 0
failed_urls = []  # up to 10 (url, reason) pairs kept for the summary below

# Convert to list of dicts for easier processing
records = df_filtered.to_dict('records')

with ThreadPoolExecutor(max_workers=10) as executor:
    # Each future is mapped back to its record so a failure can be tied to its URL.
    futures = {executor.submit(download_image, row): row for row in records}

    for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading"):
        try:
            success, result = future.result()
        except Exception as exc:
            # An unexpected error in one worker must not abort the whole batch.
            success, result = False, f"{type(exc).__name__}: {exc}"

        if success:
            successful += 1
        else:
            failed += 1
            if len(failed_urls) < 10:
                failed_urls.append((futures[future].get('associatedMedia'), result))

print(f"\n‚úÖ Downloaded: {successful}")
print(f"‚ùå Failed: {failed}")

if failed_urls:
    print(f"\nSample failures: {failed_urls[:5]}")

In [None]:
# Verify downloaded images
from PIL import Image

# Build the (path, class index) inventory from what is actually on disk.
class_counts = {}
all_images = []
unreadable = 0  # files PIL could not open/verify; excluded from the inventory

for class_folder in sorted(IMAGE_DIR.iterdir()):
    # Only "class_<digits>" directories belong to this notebook; anything else
    # (e.g. a checkpoint folder) would otherwise crash the int() conversion.
    prefix, _, suffix = class_folder.name.partition('_')
    if not (class_folder.is_dir() and prefix == 'class' and suffix.isdigit()):
        continue
    class_idx = int(suffix)

    images = []
    # Sorted listing keeps the inventory order (and thus the seeded split) reproducible.
    for img in sorted(class_folder.iterdir()):
        if not img.is_file():
            continue
        try:
            # Cheap integrity probe; how much verify() checks depends on the format.
            with Image.open(img) as probe:
                probe.verify()
        except Exception:
            unreadable += 1
            continue
        images.append(img)

    class_counts[class_idx] = len(images)
    all_images.extend((str(img), class_idx) for img in images)

print(f"Total images downloaded: {len(all_images)}")
print(f"Classes with images: {len(class_counts)}")
if unreadable:
    print(f"Skipped unreadable files: {unreadable}")
print(f"\nImages per class (first 10):")
for idx in sorted(class_counts.keys())[:10]:
    species_name = idx_to_species.get(idx, 'Unknown')[:40]
    print(f"  Class {idx} ({species_name}...): {class_counts[idx]} images")

In [None]:
# Visualize sample images
import random

# Show up to 12 randomly chosen images in a 3x4 grid, titled with species and class.
fig, axes = plt.subplots(3, 4, figsize=(16, 12))

sample_images = random.sample(all_images, min(12, len(all_images)))

# axes.flat walks the grid row by row, matching the sample order.
for ax, (img_path, label) in zip(axes.flat, sample_images):
    try:
        # Context manager closes the file once the pixels have been converted.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')
        ax.imshow(img)
        species_name = idx_to_species.get(label, 'Unknown')
        short_name = ' '.join(species_name.split()[:2])
        ax.set_title(f"{short_name}\n(Class {label})", fontsize=9)
    except Exception:
        ax.set_title("Error loading")

# Hide the frame of every panel, including slots left empty when fewer than
# 12 images are available.
for ax in axes.flat:
    ax.axis('off')

plt.suptitle('Sample Otolith Images by Species', fontsize=14)
plt.tight_layout()
plt.show()

## 5. Create Dataset & DataLoaders

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from collections import Counter

class OtolithDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for otolith images.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of class labels, parallel to ``image_paths``.
        transform: Callable invoked as ``transform(image=ndarray)`` that returns
            a mapping with an ``'image'`` entry (albumentations style). When
            omitted (or falsy), a resize + normalise + to-tensor pipeline is used.
        image_size: Side length used by the default transform and by the
            all-zero placeholder returned for unreadable files.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None, image_size=224):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.image_size = image_size
        self.transform = transform or self._default_transform()

    def _default_transform(self):
        """Return the deterministic pipeline: resize, normalise, convert to tensor."""
        return A.Compose([
            A.Resize(self.image_size, self.image_size),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; unreadable files give a zero image."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            # Context manager releases the file handle once the pixels are copied.
            with Image.open(img_path) as handle:
                image = np.array(handle.convert('RGB'))
            image = self.transform(image=image)['image']
        except Exception as exc:
            # Deliberate best-effort: keep the batch going with a blank image,
            # but say so - a silent zero tensor paired with a real label would
            # otherwise contaminate training and evaluation unnoticed.
            import warnings
            warnings.warn(f"Could not load {img_path!r} ({exc}); using an all-zero image")
            image = torch.zeros(3, self.image_size, self.image_size)

        return image, label

print("‚úÖ Dataset class defined")

In [None]:
# Unzip the (path, class index) pairs collected from disk into parallel lists.
image_paths = [path for path, _ in all_images]
labels = [class_idx for _, class_idx in all_images]

# Per-class image counts, based on the files that are actually present.
class_counts_check = Counter(labels)
print(f"Total classes before filtering: {len(class_counts_check)}")

# A class needs at least MIN_IMAGES files to be kept.
MIN_IMAGES = 3
valid_classes = set(
    filter(lambda class_idx: class_counts_check[class_idx] >= MIN_IMAGES, class_counts_check)
)
print(f"Classes with >= {MIN_IMAGES} images: {len(valid_classes)}")

# Drop every sample whose class did not make the cut, keeping the original order.
filtered_data = [pair for pair in zip(image_paths, labels) if pair[1] in valid_classes]
image_paths = [path for path, _ in filtered_data]
labels = [class_idx for _, class_idx in filtered_data]

print(f"Total images after filtering: {len(image_paths)}")

In [None]:
# Two-stage random split: 80% train, then the remaining 20% halved into val and test.
# No `stratify=` is passed (the original note says this is to avoid errors), so a
# class is not guaranteed to appear in every split.
first_stage = train_test_split(image_paths, labels, test_size=0.2, random_state=42)
train_paths, temp_paths, train_labels, temp_labels = first_stage

second_stage = train_test_split(temp_paths, temp_labels, test_size=0.5, random_state=42)
val_paths, test_paths, val_labels, test_labels = second_stage

for split_name, split_paths in (("Train", train_paths), ("Val", val_paths), ("Test", test_paths)):
    print(f"{split_name}: {len(split_paths)} images")

In [None]:
# Input resolution and mini-batch size shared by all three splits.
IMAGE_SIZE = 224
BATCH_SIZE = 32


def _normalize_to_tensor():
    """Return fresh instances of the closing steps used by every pipeline."""
    return [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


# Training pipeline: resize, random geometric / noise / photometric augmentation,
# then the shared normalisation tail.
train_transform = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.OneOf([
        A.GaussNoise(var_limit=(10, 50)),
        A.GaussianBlur(blur_limit=3),
    ], p=0.3),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    *_normalize_to_tensor(),
])

# Evaluation pipeline: deterministic resize plus the same normalisation tail.
val_transform = A.Compose([A.Resize(IMAGE_SIZE, IMAGE_SIZE), *_normalize_to_tensor()])

# Datasets: only the training split is augmented.
train_dataset = OtolithDataset(train_paths, train_labels, train_transform, IMAGE_SIZE)
val_dataset = OtolithDataset(val_paths, val_labels, val_transform, IMAGE_SIZE)
test_dataset = OtolithDataset(test_paths, test_labels, val_transform, IMAGE_SIZE)

# Loaders: identical settings except that only training batches are shuffled.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

for split_name, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

## 6. Create Model

In [None]:
import timm
import torch.nn as nn

class OtolithSpeciesClassifier(nn.Module):
    """Transfer-learning classifier: a timm backbone followed by a small MLP head.

    Args:
        num_classes: Number of output logits.
        model_name: Name passed to ``timm.create_model``.
        pretrained: Whether timm should load pretrained backbone weights.

    Attributes set here: ``backbone`` (created with ``num_classes=0``),
    ``feature_dim`` (width of the backbone output) and ``classifier`` (the head).
    """

    def __init__(self, num_classes, model_name='efficientnet_b0', pretrained=True):
        super().__init__()

        # Load pretrained backbone
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)

        # Measure the feature width with one probe forward pass. The probe runs
        # in eval mode: in train mode, normalisation layers would fold the
        # statistics of this meaningless batch into their (pretrained) running
        # estimates, and a batch of one can be rejected outright.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 3, 224, 224)
            self.feature_dim = self.backbone(probe).shape[1]
        self.backbone.train(was_training)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes)
        )

        print(f"Model: {model_name}")
        print(f"Feature dimension: {self.feature_dim}")
        print(f"Number of classes: {num_classes}")

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

# Size the output layer from the label encoding, not from the number of distinct
# labels that survived downloading/filtering. Labels keep their original indices
# from `species_to_idx`, so if any class was dropped, `len(set(labels))` would be
# smaller than the largest label: the loss would receive out-of-range targets and
# logit positions would no longer line up with `idx_to_species`.
# The max() also covers class folders with indices beyond the current encoding.
NUM_CLASSES = max(len(species_to_idx), max(labels, default=-1) + 1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = OtolithSpeciesClassifier(NUM_CLASSES, model_name='efficientnet_b0')
model = model.to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Device: {device}")

## 7. Training

In [None]:
import torch.optim as optim
from tqdm.notebook import tqdm

# Optimisation hyper-parameters
EPOCHS = 30
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-2

# Cross-entropy loss, AdamW, and a cosine schedule that decays the learning rate
# to 1% of its starting value over EPOCHS scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS,
    eta_min=LEARNING_RATE / 100,
)

for summary_line in (
    f"Training for {EPOCHS} epochs",
    f"Learning rate: {LEARNING_RATE}",
    f"Device: {device}",
):
    print(summary_line)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``loader``; update weights only when ``optimizer`` is given.

    Returns ``(mean_loss, accuracy_percent)``. The loss is averaged per sample:
    each batch loss is weighted by its batch size, so a short final batch does
    not count as much as a full one (this assumes ``criterion`` returns a
    per-batch mean). An empty loader yields ``(0.0, 0.0)`` instead of dividing
    by zero.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader, desc=desc, leave=False):
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100. * correct / total


def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one epoch.

    Args:
        model: Module to optimise; left in training mode.
        loader: Iterable of ``(images, labels)`` tensor batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    return _run_epoch(model, loader, criterion, device, optimizer=optimizer, desc="Training")


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Module to evaluate; left in eval mode.
        loader: Iterable of ``(images, labels)`` tensor batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    return _run_epoch(model, loader, criterion, device, optimizer=None, desc="Validating")

print("‚úÖ Training functions defined")

In [None]:
# Create output directory
from datetime import datetime

# All runs live under one Drive folder; each run gets its own timestamped sub-folder.
OUTPUT_DIR = Path("/content/drive/MyDrive/otolith_models") / "species_classifier"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
run_dir = OUTPUT_DIR / ("run_" + timestamp)
run_dir.mkdir(exist_ok=True)

print(f"Saving to: {run_dir}")

In [None]:
# Training loop: train, validate, step the LR schedule, checkpoint on improvement,
# and stop early after `early_stop_patience` epochs without improvement.
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}
best_val_acc = 0
patience_counter = 0
early_stop_patience = 10
checkpoint_path = run_dir / "checkpoint_best.pt"

print("\n" + "="*60)
print("üöÄ STARTING TRAINING")
print("="*60 + "\n")

for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")
    print("-" * 40)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step()

    # Record history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print metrics (the LR shown is the value after this epoch's scheduler step)
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
    print(f"  LR: {scheduler.get_last_lr()[0]:.6f}")

    # Save best model. The first epoch always writes a checkpoint, so the file
    # the evaluation cell loads exists even if validation accuracy never rises
    # above its initial value of 0.
    if val_acc > best_val_acc or epoch == 0:
        best_val_acc = val_acc
        patience_counter = 0

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'num_classes': NUM_CLASSES,
            'species_to_idx': species_to_idx,
            'idx_to_species': idx_to_species,
        }, checkpoint_path)
        print(f"  ‚úÖ New best model saved! (Acc: {val_acc:.2f}%)")
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\n‚ö†Ô∏è Early stopping after {epoch + 1} epochs")
            break

    print()

print("\n" + "="*60)
print("‚úÖ TRAINING COMPLETE!")
print(f"   Best Validation Accuracy: {best_val_acc:.2f}%")
print(f"   Model saved to: {run_dir}")
print("="*60)

## 8. Evaluate Results

In [None]:
# Loss and accuracy curves share one layout, so both panels are drawn by the same loop.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panel_specs = [
    ('loss', 'Loss', 'Training Loss'),
    ('acc', 'Accuracy (%)', 'Training Accuracy'),
]
for panel, (metric, y_label, panel_title) in zip(axes, panel_specs):
    panel.plot(history['train_' + metric], label='Train')
    panel.plot(history['val_' + metric], label='Validation')
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True)

plt.tight_layout()
# Saved before show() so the file on Drive contains the rendered figure.
plt.savefig(run_dir / "training_curves.png", dpi=150)
plt.show()

In [None]:
# Restore the best-scoring weights and score them once on the held-out test split.
# `map_location` lets a checkpoint written on a GPU load on whichever device is
# active now (e.g. after the runtime fell back to CPU).
checkpoint = torch.load(run_dir / "checkpoint_best.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

test_loss, test_acc = validate(model, test_loader, criterion, device)

print("\nüìä Test Set Results:")
print(f"   Loss: {test_loss:.4f}")
print(f"   Accuracy: {test_acc:.2f}%")

In [None]:
# Detailed evaluation
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Collect predictions and ground-truth labels for the whole test split.
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels_batch in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        # Plain Python ints keep the later set/dict/sort operations simple.
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels_batch.tolist())

# Get unique classes in test set
unique_classes = sorted(set(all_labels))
print(f"Classes in test set: {len(unique_classes)}")

# "Top" classes are those with the most test samples (ties broken by class
# index), rather than simply the ten lowest class indices.
test_support = {cls: all_labels.count(cls) for cls in unique_classes}
top_classes = sorted(unique_classes, key=lambda cls: -test_support[cls])[:10]
print("\nClassification Report (Top 10 classes):")
print(classification_report(
    all_labels, all_preds,
    labels=top_classes,
    target_names=[idx_to_species.get(i, f'Class {i}')[:30] for i in top_classes],
    zero_division=0
))

In [None]:
# Plot confusion matrix for the classes with the most test samples
TOP_N = min(15, len(unique_classes))

# Rank classes by how many test samples they have (ties broken by class index),
# so the "Top N" in the title reflects support rather than index order.
label_support = {cls: 0 for cls in unique_classes}
for true_label in all_labels:
    label_support[true_label] = label_support.get(true_label, 0) + 1
top_classes = sorted(unique_classes, key=lambda cls: -label_support[cls])[:TOP_N]
top_class_set = set(top_classes)

# Keep only samples whose true class is among the selected ones.
# NOTE: predictions that fall outside `top_classes` are not counted in any
# column, because `labels=` restricts the matrix to the listed classes.
kept_pairs = [(t, p) for t, p in zip(all_labels, all_preds) if t in top_class_set]
filtered_labels = [t for t, _ in kept_pairs]
filtered_preds = [p for _, p in kept_pairs]

if len(filtered_labels) > 0:
    cm = confusion_matrix(filtered_labels, filtered_preds, labels=top_classes)

    # Axis ticks show the first word of each species name, cut to 10 characters.
    tick_names = [idx_to_species.get(i, '?').split()[0][:10] for i in top_classes]

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_names,
                yticklabels=tick_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'Confusion Matrix (Top {TOP_N} Species)')
    plt.tight_layout()
    plt.savefig(run_dir / "confusion_matrix.png", dpi=150)
    plt.show()
else:
    print("Not enough data for confusion matrix")

## 9. Export Model for Production

In [None]:
# Trace the model once with a random single-image batch and write it out as ONNX.
model.eval()
dummy_input = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE).to(device)
onnx_path = run_dir / "species_classifier.onnx"

# Axis 0 of both the input and the output is declared dynamic, so the exported
# graph is not tied to the batch size of the dummy input.
input_name, output_name = 'image', 'logits'
batch_axis = {0: 'batch_size'}
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,
    opset_version=11,
    input_names=[input_name],
    output_names=[output_name],
    dynamic_axes={input_name: dict(batch_axis), output_name: dict(batch_axis)},
)

onnx_size_mb = os.path.getsize(onnx_path) / 1e6
print(f"‚úÖ ONNX model exported: {onnx_path}")
print(f"   File size: {onnx_size_mb:.1f} MB")

In [None]:
# Gather the model settings, label mappings and scores into one JSON file.
model_info = dict(
    model_name='efficientnet_b0',
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
    species_to_idx=species_to_idx,
    # Class ids are written as strings here, unlike the int keys used in memory.
    idx_to_species={str(class_id): name for class_id, name in idx_to_species.items()},
    best_val_acc=best_val_acc,
    test_acc=test_acc,
    train_samples=len(train_paths),
    val_samples=len(val_paths),
    test_samples=len(test_paths),
    timestamp=timestamp,
)

with (run_dir / "model_info.json").open("w") as info_file:
    json.dump(model_info, info_file, indent=2)

# List everything that ended up in the run folder, with sizes in decimal MB.
print("\nüìÅ Saved files:")
for saved in sorted(run_dir.iterdir()):
    print(f"   {saved.name} ({saved.stat().st_size / 1e6:.1f} MB)")

## 10. Test Single Image Prediction

In [None]:
def predict_species(model, image_path, device, idx_to_species, top_k=5, image_size=224):
    """Predict the most likely species for one otolith image.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        image_path: Path of the image file to classify.
        device: Device the input tensor is moved to before inference.
        idx_to_species: Mapping from class index to species name. Keys may be
            ints (as built in this notebook) or their string form (as written
            to ``model_info.json``).
        top_k: Maximum number of predictions to return; capped at the number
            of classes the model outputs.
        image_size: Side length the image is resized to. Should match the size
            used for training; defaults to 224.

    Returns:
        A list of ``{'species': str, 'confidence': float}`` dicts ordered from
        most to least likely.
    """
    model.eval()

    # Load and preprocess (the context manager closes the file after conversion)
    with Image.open(image_path) as handle:
        image = np.array(handle.convert('RGB'))

    transform = A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    image_tensor = transform(image=image)['image'].unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.softmax(outputs, dim=1)[0]
        # topk() raises when asked for more entries than there are classes.
        k = max(0, min(int(top_k), probs.numel()))
        top_probs, top_indices = probs.topk(k)

    results = []
    for prob, idx in zip(top_probs.cpu().tolist(), top_indices.cpu().tolist()):
        # Try the int key first, then the string key used by JSON-loaded mappings.
        species = idx_to_species.get(idx)
        if species is None:
            species = idx_to_species.get(str(idx), f'Unknown class {idx}')
        results.append({
            'species': species,
            'confidence': float(prob)
        })

    return results

# Run the classifier on one randomly picked test image and print the ranked results.
if len(test_paths) > 0:
    test_image = random.choice(test_paths)
    predictions = predict_species(model, test_image, device, idx_to_species)

    print("\nüîç Sample Prediction:")
    print(f"Image: {Path(test_image).name}")
    print(f"\nTop 5 Predictions:")
    for rank, pred in enumerate(predictions, start=1):
        species_label = pred['species'][:50]
        print(f"  {rank}. {species_label}")
        print(f"     Confidence: {pred['confidence']:.1%}")

In [None]:
# Side-by-side figure: the sampled test image and a bar chart of its predictions.
if len(test_paths) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    image_ax, bar_ax = axes

    # Left panel: the input image
    img = Image.open(test_image).convert('RGB')
    image_ax.imshow(img)
    image_ax.set_title(f"Input Image\n{Path(test_image).name}")
    image_ax.axis('off')

    # Right panel: text before any "(" in the species name, cut to 25 characters
    species_names = [p['species'].split('(')[0].strip()[:25] for p in predictions]
    confidences = [p['confidence'] for p in predictions]
    colors = ['green' if rank == 0 else 'steelblue' for rank, _ in enumerate(predictions)]

    # The lists are reversed before plotting so the first prediction ends up in
    # the top row of the horizontal bar chart.
    names_bottom_up = list(reversed(species_names))
    confidences_bottom_up = list(reversed(confidences))
    colors_bottom_up = list(reversed(colors))

    bar_ax.barh(names_bottom_up, confidences_bottom_up, color=colors_bottom_up)
    bar_ax.set_xlabel('Confidence')
    bar_ax.set_title('Species Predictions')
    bar_ax.set_xlim(0, 1)

    # Percentage label just to the right of each bar
    for row, conf in enumerate(confidences_bottom_up):
        bar_ax.text(conf + 0.02, row, f'{conf:.1%}', va='center')

    plt.tight_layout()
    plt.show()

## 🎉 Done!

Your trained species classifier is saved in Google Drive:
- `checkpoint_best.pt` - PyTorch model with label mappings
- `species_classifier.onnx` - ONNX format for production
- `model_info.json` - Model metadata and species list

### Next Steps:
1. Download the model files from Google Drive
2. Copy to your project: `ai-services/models/`
3. Integrate with your Ocean platform

### Integration Example:
```python
# In your otolith_analyzer.py
import torch
import json

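# Load model: rebuild the architecture first, then restore the trained weights.
# (OtolithSpeciesClassifier and predict_species are the definitions from this notebook.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('models/species_classifier/checkpoint_best.pt', map_location=device)
model = OtolithSpeciesClassifier(checkpoint['num_classes'], pretrained=False).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
idx_to_species = checkpoint['idx_to_species']

# Predict
species = predict_species(model, image_path, device, idx_to_species)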
```