## 1. Check GPU & Install Dependencies

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    # Colab menu path for switching this notebook to a GPU runtime.
    print("⚠️ No GPU detected! Go to Runtime → Change runtime type → GPU")

In [None]:
# Install additional dependencies
!pip install -q albumentations tensorboard timm
print("‚úÖ Dependencies installed!")

## 2. Mount Google Drive

Your otolith images should be organized in Google Drive like this:
```
MyDrive/
└── otolith_data/
    ├── train/
    │   ├── age_01/  (images of 1-year-old fish)
    │   ├── age_02/  (images of 2-year-old fish)
    │   ├── age_03/
    │   └── ...
    └── val/
        ├── age_01/
        ├── age_02/
        └── ...
```

The first number in each folder name is used as the age label (e.g. `age_05` → 5); folders without a number are ignored.

In [None]:
from google.colab import drive
# Makes Google Drive available under /content/drive (prompts for authorization).
drive.mount('/content/drive')
print("✅ Google Drive mounted!")

In [None]:
# ⚠️ CONFIGURE YOUR DATA PATH HERE
# Change this to match your Google Drive folder structure

DATA_DIR = "/content/drive/MyDrive/otolith_data"  # <-- CHANGE THIS

import os
if os.path.exists(DATA_DIR):
    print(f"✅ Found data directory: {DATA_DIR}")
    print("\nContents:")
    # Sorted so the listing is stable between runs.
    for item in sorted(os.listdir(DATA_DIR)):
        item_path = os.path.join(DATA_DIR, item)
        if os.path.isdir(item_path):
            # Counts every file below item_path (recursively), not only images.
            count = sum(len(files) for _, _, files in os.walk(item_path))
            print(f"  📁 {item}/ ({count} files)")
else:
    print(f"❌ Directory not found: {DATA_DIR}")
    print("Please update DATA_DIR to point to your otolith images folder")

## 3. Define Model & Dataset

In [None]:
import os
import re
import json
import numpy as np
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from typing import Optional, Tuple, List, Dict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import timm

from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

print("✅ All imports successful!")

In [None]:
@dataclass
class Config:
    """Paths and hyper-parameters for one training run.

    Only the dataclass fields below are serialised by ``dataclasses.asdict``;
    ``device`` is a computed property and is not part of that dict.
    """
    # --- Data ---
    data_dir: str = "/content/drive/MyDrive/otolith_data"  # root holding train/ and val/
    image_size: int = 224  # images are resized to image_size x image_size

    # --- Model ---
    # Any timm model name, e.g. efficientnet_b0, efficientnet_b2, resnet50,
    # vit_base_patch16_224 (timm's name for the ViT-B/16 architecture).
    model_name: str = "efficientnet_b0"
    pretrained: bool = True

    # --- Training ---
    batch_size: int = 32
    epochs: int = 50
    lr: float = 0.0001
    weight_decay: float = 0.01

    # --- Task ---
    task: str = "regression"  # "regression" or "classification"
    num_classes: int = 20  # classification only: ages are class indices 0 .. num_classes-1

    # --- Output ---
    output_dir: str = "/content/drive/MyDrive/otolith_models"

    @property
    def device(self):
        """CUDA device when one is available, otherwise CPU (re-checked on each access)."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

In [None]:
class OtolithDataset(Dataset):
    """Dataset of otolith images labelled with fish age.

    Images are read from ``<data_dir>/<split>/<age folder>/``; the first run of
    digits in an age folder's name is the label (``age_05`` -> 5).

    Args:
        data_dir: Root folder containing one sub-folder per split.
        split: Sub-folder to read; ``"train"`` enables random augmentation.
        image_size: Images are resized to ``image_size x image_size``.
        task: ``"regression"`` yields float32 labels; any other value yields
            int64 labels (class indices).
    """

    # Lower-case file extensions accepted as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")

    def __init__(self, data_dir: str, split: str = "train",
                 image_size: int = 224, task: str = "regression"):
        self.data_dir = Path(data_dir) / split
        self.split = split
        self.task = task
        self.image_size = image_size

        # List of (image path, age) pairs.
        self.samples = []
        self._load_from_folders()

        self.transform = self._get_transforms()

        print(f"  {split}: {len(self.samples)} images")

    def _load_from_folders(self):
        """Fill ``self.samples`` from the ``age_XX/`` sub-folders of ``self.data_dir``.

        Folders without a number in their name, hidden files (e.g. macOS
        ``._name.jpg`` metadata files, which are not images) and files with an
        unsupported extension are skipped. Folders and files are visited in
        sorted order so the sample list is reproducible across runs.
        """
        if not self.data_dir.exists():
            print(f"⚠️ Directory not found: {self.data_dir}")
            return

        for age_folder in sorted(self.data_dir.iterdir()):
            if not age_folder.is_dir():
                continue

            # Extract age from folder name (e.g., "age_05" -> 5)
            match = re.search(r'(\d+)', age_folder.name)
            if not match:
                continue
            age = int(match.group(1))

            for img_path in sorted(age_folder.iterdir()):
                if (img_path.is_file()
                        and not img_path.name.startswith(".")
                        and img_path.suffix.lower() in self.IMAGE_EXTENSIONS):
                    self.samples.append((str(img_path), age))

    def _get_transforms(self):
        """Return the albumentations pipeline: augmentation for train, resize + normalize otherwise."""
        # NOTE(review): GaussNoise(var_limit=...) and ShiftScaleRotate are
        # deprecated in newer albumentations releases (they may warn or ignore
        # arguments) — confirm against the installed version.
        if self.split == "train":
            return A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(p=0.5),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=45, p=0.5),
                A.OneOf([
                    A.GaussNoise(var_limit=(10, 50)),
                    A.GaussianBlur(blur_limit=3),
                    A.MotionBlur(blur_limit=3),
                ], p=0.3),
                A.OneOf([
                    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2),
                    A.CLAHE(clip_limit=2),
                    A.Equalize(),
                ], p=0.5),
                # ImageNet statistics, matching the pretrained backbones.
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ])
        else:
            return A.Compose([
                A.Resize(self.image_size, self.image_size),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ToTensorV2(),
            ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx`` after the split's transforms."""
        img_path, age = self.samples[idx]

        # The context manager closes the file handle PIL keeps open lazily.
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))

        transformed = self.transform(image=image)
        image = transformed["image"]

        if self.task == "regression":
            label = torch.tensor(float(age), dtype=torch.float32)
        else:
            label = torch.tensor(age, dtype=torch.long)

        return image, label

In [None]:
class OtolithAgeNet(nn.Module):
    """Transfer-learning network that estimates fish age from an otolith image.

    A timm backbone (created with ``num_classes=0`` so it returns pooled
    features) feeds a small MLP head.

    Args:
        model_name: timm model identifier, e.g. ``"efficientnet_b0"``.
        pretrained: Whether timm loads pretrained backbone weights.
        task: ``"regression"`` -> one value per image (output shape ``(B,)``);
            any other value -> ``num_classes`` logits per image (``(B, num_classes)``).
        num_classes: Size of the classification head (unused for regression).
    """

    def __init__(self, model_name: str = "efficientnet_b0",
                 pretrained: bool = True, task: str = "regression",
                 num_classes: int = 20):
        super().__init__()

        self.task = task
        self.num_classes = num_classes

        # Load backbone
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0  # Remove classifier
        )

        self.feature_dim = self._probe_feature_dim()

        # Build head
        if task == "regression":
            self.head = nn.Sequential(
                nn.Linear(self.feature_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(64, 1)
            )
        else:
            self.head = nn.Sequential(
                nn.Linear(self.feature_dim, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_classes)
            )

        print(f"Model: {model_name}, Features: {self.feature_dim}, Task: {task}")

    def _probe_feature_dim(self, image_size: int = 224) -> int:
        """Return the backbone's output width by passing one random image through it.

        The probe runs in eval mode and the previous mode is restored afterwards:
        in train mode BatchNorm layers would update their running statistics
        with this random input (``torch.no_grad`` does not prevent that),
        perturbing the pretrained statistics before training starts.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                features = self.backbone(torch.randn(1, 3, image_size, image_size))
        finally:
            self.backbone.train(was_training)
        return int(features.shape[1])

    def forward(self, x):
        features = self.backbone(x)
        output = self.head(features)
        if self.task == "regression":
            # (B, 1) -> (B,) to match the shape of float regression labels.
            output = output.squeeze(-1)
        return output

## 4. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train (switched to train mode).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use 'mean' reduction (the default
            for MSELoss / CrossEntropyLoss) so it can be re-weighted per sample.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, mae)``: loss averaged over samples (so a smaller final
        batch is not over-weighted) and the mean absolute error in years. For
        classification outputs (2-D logits) the arg-max class is the predicted age.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    all_preds, all_labels = [], []

    pbar = tqdm(loader, desc="Training", leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

        preds = outputs.detach()
        if preds.ndim > 1:
            # Classification logits (B, C) -> predicted class index (B,).
            preds = preds.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    if n_samples == 0:
        raise ValueError("train_epoch: loader yielded no batches")

    avg_loss = total_loss / n_samples
    mae = mean_absolute_error(all_labels, all_preds)

    return avg_loss, mae


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating it.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function; assumed to use 'mean' reduction so it can be
            re-weighted per sample.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, mae, rmse, r2, all_preds, all_labels)``: sample-averaged
        loss, error metrics in years, and the per-sample predictions and labels.
        For classification outputs (2-D logits) the arg-max class is used as
        the predicted age.

    Raises:
        ValueError: If ``loader`` yields no batches (e.g. an empty val split).
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validating", leave=False):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            # Classification logits (B, C) -> predicted class index (B,).
            preds = outputs.argmax(dim=1) if outputs.ndim > 1 else outputs
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("validate: loader yielded no batches (is the validation set empty?)")

    avg_loss = total_loss / n_samples
    mae = mean_absolute_error(all_labels, all_preds)
    rmse = np.sqrt(mean_squared_error(all_labels, all_preds))
    r2 = r2_score(all_labels, all_preds)

    return avg_loss, mae, rmse, r2, all_preds, all_labels

In [None]:
def plot_training_history(history):
    """Plot loss, MAE and validation R² curves side by side.

    Args:
        history: Dict of per-epoch lists with keys ``train_loss``, ``val_loss``,
            ``train_mae``, ``val_mae`` and ``val_r2`` (all the same length).
    """
    # Number epochs from 1 to match the "Epoch k/N" lines printed during training.
    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # (axis, train key, validation key, y label, title) for the paired panels.
    paired_panels = [
        (axes[0], 'train_loss', 'val_loss', 'Loss', 'Training Loss'),
        (axes[1], 'train_mae', 'val_mae', 'MAE (years)', 'Mean Absolute Error'),
    ]
    for ax, train_key, val_key, ylabel, title in paired_panels:
        ax.plot(epochs, history[train_key], label='Train')
        ax.plot(epochs, history[val_key], label='Validation')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    # R² is only tracked on the validation set.
    axes[2].plot(epochs, history['val_r2'])
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('R² Score')
    axes[2].set_title('Validation R² Score')
    axes[2].grid(True)

    plt.tight_layout()
    plt.show()


def plot_predictions(preds, labels):
    """Show a scatter plot of predicted against actual age.

    A dashed red ``y = x`` line starting at the origin marks perfect predictions.

    Args:
        preds: Predicted ages, one per sample.
        labels: Actual ages, aligned with ``preds``.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.scatter(labels, preds, alpha=0.5)

    # Reference diagonal reaching the largest value on either axis.
    upper = max(max(labels), max(preds))
    ax.plot([0, upper], [0, upper], 'r--', label='Perfect prediction')

    ax.set_xlabel('Actual Age (years)')
    ax.set_ylabel('Predicted Age (years)')
    ax.set_title('Predicted vs Actual Age')
    ax.legend()
    ax.grid(True)
    ax.axis('equal')
    plt.show()

## 5. Configure & Start Training

⚠️ **Update the configuration below before running!**

In [None]:
# ============================================
# ⚠️ CONFIGURE YOUR TRAINING HERE
# ============================================

config = Config(
    # Reuse the path set (and checked) in the "Mount Google Drive" section, so
    # the data location only has to be configured in one place.
    data_dir=DATA_DIR,

    # Model settings
    model_name="efficientnet_b0",  # Options: efficientnet_b0, efficientnet_b2, resnet50

    # Training settings
    batch_size=32,    # Reduce to 16 or 8 if you run out of memory
    epochs=50,        # More epochs = better results (if not overfitting)
    lr=0.0001,        # Learning rate

    # Image size
    image_size=224,   # 224 for efficientnet_b0, 288 for b2

    # Where to save the model
    output_dir="/content/drive/MyDrive/otolith_models",
)

print("Configuration:")
print(f"  Data: {config.data_dir}")
print(f"  Model: {config.model_name}")
print(f"  Epochs: {config.epochs}")
print(f"  Batch size: {config.batch_size}")
print(f"  Device: {config.device}")

In [None]:
# Create datasets
print("Loading datasets...")
train_dataset = OtolithDataset(config.data_dir, "train", config.image_size, config.task)
val_dataset = OtolithDataset(config.data_dir, "val", config.image_size, config.task)

if len(train_dataset) == 0:
    raise ValueError("❌ No training images found! Check your data_dir path.")
# Validation needs at least one batch; fail here with a clear message instead
# of deep inside the training loop.
if len(val_dataset) == 0:
    raise ValueError("❌ No validation images found! Check that data_dir contains a 'val' folder.")

# Pinned host memory only helps when batches are copied to a GPU.
use_pin_memory = config.device.type == "cuda"

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory
)

print("\n✅ Data loaded successfully!")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

In [None]:
# Visualize some training samples.
# train_dataset applies random augmentation, so these are augmented views and
# will differ each time the cell runs.
print("Sample training images:")
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# ImageNet statistics used by A.Normalize in the dataset; undone for display.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= len(train_dataset):
        continue  # fewer images than panels: leave this panel blank

    img, label = train_dataset[i]
    # CHW tensor -> HWC array, then denormalize into [0, 1] for imshow.
    img_display = img.numpy().transpose(1, 2, 0) * imagenet_std + imagenet_mean
    ax.imshow(np.clip(img_display, 0, 1))
    ax.set_title(f"Age: {float(label):.0f} years")

plt.tight_layout()
plt.show()

In [None]:
# Classification uses ages directly as class indices, so CrossEntropyLoss needs
# every age in [0, num_classes). Check up front: an out-of-range label otherwise
# fails mid-training with an opaque error (a device-side assert on GPU).
if config.task != "regression":
    max_age = max(
        (age for _, age in train_dataset.samples + val_dataset.samples),
        default=0,
    )
    if max_age >= config.num_classes:
        raise ValueError(
            f"❌ Found age {max_age} but num_classes={config.num_classes}; "
            f"set num_classes to at least {max_age + 1}."
        )

# Create model
print("Creating model...")
model = OtolithAgeNet(
    model_name=config.model_name,
    pretrained=config.pretrained,
    task=config.task,
    num_classes=config.num_classes
)
model = model.to(config.device)

# Loss function
if config.task == "regression":
    criterion = nn.MSELoss()
else:
    criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.lr,
    weight_decay=config.weight_decay
)

# Learning rate scheduler: cosine decay over all epochs down to lr / 100
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=config.epochs,
    eta_min=config.lr / 100
)

print(f"\n✅ Model ready! Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Create output directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = Path(config.output_dir) / f"run_{timestamp}"
run_dir.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {run_dir}")

# Save config
with open(run_dir / "config.json", "w") as f:
    json.dump(asdict(config), f, indent=2, default=str)

print("\n" + "="*60)
print("üöÄ STARTING TRAINING")
print("="*60 + "\n")

In [None]:
# Training loop
history = {
    'train_loss': [], 'train_mae': [],
    'val_loss': [], 'val_mae': [], 'val_rmse': [], 'val_r2': []
}
best_val_mae = float('inf')
patience_counter = 0
early_stop_patience = 10

for epoch in range(config.epochs):
    print(f"\nEpoch {epoch + 1}/{config.epochs}")
    print("-" * 40)
    
    # Train
    train_loss, train_mae = train_epoch(
        model, train_loader, criterion, optimizer, config.device
    )
    
    # Validate
    val_loss, val_mae, val_rmse, val_r2, preds, labels = validate(
        model, val_loader, criterion, config.device
    )
    
    # Update scheduler
    scheduler.step()
    
    # Record history
    history['train_loss'].append(train_loss)
    history['train_mae'].append(train_mae)
    history['val_loss'].append(val_loss)
    history['val_mae'].append(val_mae)
    history['val_rmse'].append(val_rmse)
    history['val_r2'].append(val_r2)
    
    # Print metrics
    print(f"  Train Loss: {train_loss:.4f}, Train MAE: {train_mae:.2f} years")
    print(f"  Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.2f} years")
    print(f"  Val RMSE: {val_rmse:.2f} years, Val R¬≤: {val_r2:.4f}")
    print(f"  LR: {scheduler.get_last_lr()[0]:.6f}")
    
    # Save best model
    if val_mae < best_val_mae:
        best_val_mae = val_mae
        patience_counter = 0
        
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_mae': val_mae,
            'val_r2': val_r2,
            'config': asdict(config),
        }, run_dir / "checkpoint_best.pt")
        print(f"  ‚úÖ New best model saved! (MAE: {val_mae:.2f})")
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\n‚ö†Ô∏è Early stopping triggered after {epoch + 1} epochs")
            break
    
    # Save latest checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'history': history,
    }, run_dir / "checkpoint_latest.pt")

print("\n" + "="*60)
print("‚úÖ TRAINING COMPLETE!")
print(f"   Best Validation MAE: {best_val_mae:.2f} years")
print(f"   Model saved to: {run_dir}")
print("="*60)

## 6. Evaluate Results

In [None]:
# Show the per-epoch curves recorded by the training loop
plot_training_history(history=history)

In [None]:
# Load best model and evaluate.
# map_location puts the tensors on the current device regardless of where they
# were saved. weights_only=False because checkpoints from earlier versions of
# this notebook may contain NumPy scalars, which the weights-only loader (the
# default in recent PyTorch) rejects. This file was written by this notebook;
# never do this with untrusted checkpoints.
checkpoint = torch.load(
    run_dir / "checkpoint_best.pt",
    map_location=config.device,
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])

# Final validation
val_loss, val_mae, val_rmse, val_r2, preds, labels = validate(
    model, val_loader, criterion, config.device
)

print("\n📊 Final Model Performance:")
print(f"   MAE: {val_mae:.2f} years")
print(f"   RMSE: {val_rmse:.2f} years")
print(f"   R² Score: {val_r2:.4f}")

# Plot predictions
plot_predictions(preds, labels)

## 7. Export Model for Production

In [None]:
# Export to ONNX format (for production deployment)
model.eval()
dummy_input = torch.randn(1, 3, config.image_size, config.image_size).to(config.device)

onnx_path = run_dir / "model.onnx"
# NOTE(review): newer PyTorch releases are moving torch.onnx.export to a
# dynamo-based exporter that may not support opset 11 — confirm against the
# installed version if the export fails.
torch.onnx.export(
    model,
    dummy_input,
    str(onnx_path),  # plain str path for broad exporter compatibility
    export_params=True,
    opset_version=11,
    input_names=['image'],
    output_names=['age'],
    # Allow any batch size at inference time.
    dynamic_axes={'image': {0: 'batch_size'}, 'age': {0: 'batch_size'}}
)

print(f"\n✅ Model exported to ONNX: {onnx_path}")
print(f"   File size: {os.path.getsize(onnx_path) / 1e6:.1f} MB")

In [None]:
# Save a summary file
summary = {
    "model_name": config.model_name,
    "image_size": config.image_size,
    "task": config.task,
    "training_epochs": len(history['train_loss']),
    "best_val_mae": float(best_val_mae),
    "best_val_r2": float(max(history['val_r2'])),
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "timestamp": timestamp,
}

with open(run_dir / "summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print("\nüìÅ Saved files:")
for f in run_dir.iterdir():
    print(f"   {f.name} ({os.path.getsize(f) / 1e6:.1f} MB)")

## 8. Test on Single Image

In [None]:
def predict_age(model, image_path, config):
    """Predict the age (years) of the fish in a single otolith image.

    Args:
        model: Trained network; switched to eval mode by this call.
        image_path: Path to the image file.
        config: Object providing ``image_size`` and ``device``.

    Returns:
        The predicted age as a float. For a classification model (output with
        more than one value) this is the arg-max class index.
    """
    model.eval()

    # The context manager closes the file handle PIL keeps open lazily.
    with Image.open(image_path) as pil_image:
        image = np.array(pil_image.convert("RGB"))

    # Same preprocessing as the non-training split of OtolithDataset.
    transform = A.Compose([
        A.Resize(config.image_size, config.image_size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    transformed = transform(image=image)
    image_tensor = transformed["image"].unsqueeze(0).to(config.device)

    with torch.no_grad():
        output = model(image_tensor)

    if output.numel() > 1:
        # Classification logits (1, num_classes): the predicted class is the age.
        return float(output.argmax(dim=-1).item())
    return output.item()


# Test on an image (update path as needed)
# test_image = "/content/drive/MyDrive/otolith_data/val/age_05/sample.jpg"
# age = predict_age(model, test_image, config)
# print(f"Predicted age: {age:.1f} years")

## 🎉 Done!

Your trained model is saved in Google Drive at:
- `otolith_models/run_YYYYMMDD_HHMMSS/checkpoint_best.pt` - Best PyTorch model
- `otolith_models/run_YYYYMMDD_HHMMSS/model.onnx` - ONNX format for production

### Next Steps:
1. Download the model files
2. Copy them to your Ocean project's `ai-services/models/` folder
3. Update the otolith analyzer to use the trained model

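### To use in your app:
Copy `Config`, `OtolithAgeNet` and `predict_age` from this notebook into your app, then:
```python
# Load and use the trained model
import torch

# The checkpoint stores the training config next to the weights.
# weights_only=False: older checkpoints may contain NumPy scalars, which the
# weights-only loader rejects — only load checkpoints you trust.
checkpoint = torch.load("models/checkpoint_best.pt", map_location="cpu", weights_only=False)
config = Config(**checkpoint['config'])

model = OtolithAgeNet(
    model_name=config.model_name,
    pretrained=False,  # weights come from the checkpoint
    task=config.task,
    num_classes=config.num_classes,
)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(config.device)
model.eval()

# Predict
age = predict_age(model, "path/to/otolith.jpg", config)
```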