<a href="https://colab.research.google.com/github/joaosMart/fish-species-class-siglip/blob/update-readme-comprehensive/Code/fish-classification/ResNet_50_finetuned.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet-50 Fish Species Classification Benchmark


Fine-tuned ResNet-50 Benchmark for Fish Species Classification

This notebook implements a ResNet-50 model fine-tuned for classifying three salmonid species
(Atlantic Salmon, Brown/Sea Trout, and Arctic Char) from underwater monitoring footage.

Paper: "Temporal Aggregation of Vision-Language Features for High-Accuracy
       Fish Classification in Automated Monitoring"

Key Results:
- Fine-tuned ResNet-50 (this notebook): 95.3% macro F1-score
- SigLIP + temporal pooling (outperforms ResNet-50): 96.8% macro F1-score
- Learning-curve analysis shows that the temporal aggregation methods require ~76% less training data


Usage:
1. Prepare your data by extracting frames from the videos (see the Frame Extraction section)
2. Organize the frames into one folder per species class
3. Run this notebook to train and evaluate ResNet-50
4. Compare the results with the SigLIP temporal aggregation methods

In [None]:
# =============================================================================
# 1. IMPORTS AND SETUP
# =============================================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import shutil
import os
import gc
import json
import logging
from datetime import datetime
from typing import Dict, Tuple, List
import random

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR
import torchvision.transforms as transforms
import torchvision.models as models

# Computer vision and video processing
from PIL import Image
import cv2
from decord import VideoReader
from tqdm.notebook import tqdm

# Machine learning utilities
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (accuracy_score, f1_score, precision_recall_fscore_support,
                           confusion_matrix, classification_report)
from sklearn.preprocessing import LabelEncoder

# Configure logging
logging.basicConfig(level=logging.INFO)

# Seed the Python, NumPy and PyTorch RNGs so data splits and weight
# initialisation are repeatable between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

if device.type == 'cuda':
    torch.cuda.manual_seed(SEED)
    # Prefer deterministic cuDNN kernels over autotuned (faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

print(f"Using device: {device}")


## Data Loading and Preprocessing


In [None]:


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy arrays and scalars into native Python values.

    Arrays become (nested) lists; NumPy integers, floats and booleans become
    ``int``, ``float`` and ``bool``. Any other type is deferred to the base
    encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        return super().default(obj)

class NPZDataLoader:
    """Read per-video metadata (central frame index and species label) from NPZ files.

    NOTE: this class was previously named ``DataLoader``. That name shadowed
    ``torch.utils.data.DataLoader`` imported at the top of the notebook, so every
    later ``DataLoader(dataset, batch_size=..., shuffle=...)`` call resolved to
    this class and failed with a TypeError.

    The NPZ files presumably hold the SigLIP features used elsewhere in the
    project (TODO confirm); only their ``middle_frame`` and ``fish_species``
    entries are read here.
    """

    def __init__(self, data_dir: str):
        self.data_dir = data_dir

    def load_npz_files(self) -> Tuple[List[str], np.ndarray, List[int]]:
        """
        Load all NPZ files from the directory and extract video paths, labels and central frames.

        File names are expected to look like ``<riverID>_<videoID>[_...].npz``; the
        video path is rebuilt as ``data/<riverID>_vid/<videoID>.mp4``. Files that
        cannot be parsed or read are logged and skipped.

        Returns:
            video_paths: List of video file paths
            labels: Array of fish species labels (strings)
            central_frames: List of central frame indices
        """
        import glob

        video_path_list = []
        labels_list = []
        central_frames_list = []

        # Sorted so the sample order (and any seeded split built on it) is
        # reproducible; glob's order depends on the filesystem.
        npz_files = sorted(glob.glob(os.path.join(self.data_dir, "*.npz")))
        logging.info(f"Found {len(npz_files)} NPZ files")

        for npz_file in npz_files:
            try:
                parts = Path(npz_file).stem.split('_')
                riverID = parts[0]
                videoID = parts[1]

                # Construct video path (adjust path as needed for your setup)
                video_path = f"data/{riverID}_vid/{videoID}.mp4"

                # allow_pickle=True lets a crafted file execute arbitrary code:
                # only load NPZ files from a trusted source. The context manager
                # closes the underlying file handle.
                with np.load(npz_file, allow_pickle=True) as data:
                    middle_frame = data['middle_frame'].item()
                    fish_species = str(data['fish_species'].item())

            except Exception as e:
                logging.error(f"Error processing file {npz_file}: {str(e)}")
                continue

            video_path_list.append(video_path)
            labels_list.append(fish_species)
            central_frames_list.append(middle_frame)

        labels_array = np.array(labels_list)

        # Log data distribution
        unique_labels, counts = np.unique(labels_array, return_counts=True)
        logging.info("Dataset distribution:")
        for label, count in zip(unique_labels, counts):
            percentage = (count / len(labels_array)) * 100
            logging.info(f"  {label}: {count} samples ({percentage:.2f}%)")

        return video_path_list, labels_array, central_frames_list


## Frame Extraction


In [None]:

def extract_and_save_frames(video_paths_dict: dict, output_dir: str):
    """
    Extract one frame per video and save it as a JPEG in a per-label folder.

    WARNING: ``output_dir`` is deleted (``shutil.rmtree``) and recreated first,
    so anything already in it is lost.

    Args:
        video_paths_dict: Dictionary with {video_path: [frame_number, label]}
        output_dir: Directory to save extracted frames

    Frames are written to ``output_dir/<label>/<river>_<video>_frame<N>.jpg``,
    where ``<river>`` is the part of the video's parent-folder name before the
    first underscore and ``<video>`` is the video file's stem. Videos that fail
    to decode or whose frame cannot be written are reported and skipped.
    """
    output_dir = Path(output_dir)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True)

    # One sub-directory per class label
    for label in {info[1] for info in video_paths_dict.values()}:
        (output_dir / label).mkdir(exist_ok=True)

    for video_path, (frame_num, label) in tqdm(video_paths_dict.items(),
                                               desc="Extracting frames"):
        try:
            vr = VideoReader(video_path)
            video = Path(video_path)

            # e.g. "data/riverA_vid/vid1.mp4" -> river "riverA", video "vid1"
            river_name = video.parent.name.split('_')[0]

            frame = vr[frame_num].asnumpy()

            frame_path = output_dir / label / f"{river_name}_{video.stem}_frame{frame_num}.jpg"
            # The frame is converted RGB -> BGR for OpenCV. cv2.imwrite reports
            # failure through its return value instead of raising, so check it.
            if not cv2.imwrite(str(frame_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
                raise IOError(f"cv2.imwrite failed for {frame_path}")

        except Exception as e:
            print(f"Error processing {video_path}: {str(e)}")
            continue


## Dataset Classes


In [None]:


class FishFrameDataset(Dataset):
    """Image-folder dataset laid out as ``root_dir/<class_name>/*.{jpg,jpeg,png}``.

    Class indices follow the sorted sub-directory names (hidden entries are
    skipped). Items are ``(image, label_tensor)``; ``image`` is an RGB PIL image,
    or the output of ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform

        print("Scanning directories...")
        # Get all valid class directories
        self.classes = sorted([
            d.name for d in self.root_dir.iterdir()
            if d.is_dir() and not d.name.startswith('.')
        ])
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}

        # Index dataset
        self.samples = []
        print("Indexing dataset...")
        for class_name in tqdm(self.classes, desc="Loading classes"):
            class_dir = self.root_dir / class_name
            # Sorted so the sample order, and any seeded torch.randperm split
            # over it, is reproducible; glob order depends on the filesystem.
            image_files = sorted(
                f for f in class_dir.glob('*')
                if f.suffix.lower() in ('.jpg', '.jpeg', '.png')
                and not f.name.startswith('.')
            )
            self.samples.extend([(str(img_path), class_name) for img_path in image_files])

        print(f"Found {len(self.samples)} images across {len(self.classes)} classes")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for ``idx``.

        If the image cannot be read, the following samples (wrapping around) are
        tried instead. This is bounded to one full pass, so a directory of
        unreadable files raises RuntimeError instead of recursing forever.
        """
        # Index once up front so out-of-range indices still raise IndexError
        # (Python's fallback iteration protocol relies on it).
        self.samples[idx]
        num_samples = len(self.samples)

        for offset in range(num_samples):
            img_path, class_name = self.samples[(idx + offset) % num_samples]
            try:
                with Image.open(img_path) as image:
                    image = image.convert('RGB')

                if self.transform:
                    image = self.transform(image)

                label_idx = self.class_to_idx[class_name]
                return image, torch.tensor(label_idx)
            except Exception as e:
                print(f"Error reading image {img_path}: {str(e)}")

        raise RuntimeError(f"No readable image found under {self.root_dir}")

class TransformSubset(Dataset):
    """View of ``dataset`` restricted to ``indices`` that applies ``transform`` on access.

    ``dataset[i]`` must return an ``(image, label)`` pair; only the image is
    transformed, and a falsy ``transform`` leaves it untouched. This lets the
    train/val/test splits of one base dataset use different preprocessing.
    """

    def __init__(self, dataset, indices, transform):
        self.dataset, self.indices, self.transform = dataset, indices, transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample, target = self.dataset[self.indices[idx]]
        return (self.transform(sample) if self.transform else sample), target



## ResNet-50 Architecture

In [None]:


class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The block outputs ``expansion * out_channels`` channels, and the spatial
    ``stride`` is applied by the 3x3 convolution. With ``downsample=True``, the
    shortcut gets a 1x1 convolution + batch-norm projection so that its shape
    matches the main path.

    Sub-module names (conv1..bn3, downsample) must stay as they are, because the
    notebook loads torchvision's ResNet-50 ``state_dict`` into this model.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()
        expanded = self.expansion * out_channels

        # 1x1 channel reduction
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 spatial convolution (carries the stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 channel expansion
        self.conv3 = nn.Conv2d(out_channels, expanded, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)

        self.relu = nn.ReLU(inplace=True)

        # Optional projection shortcut (created last, as in the reference layout)
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )
            if downsample
            else None
        )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)

class ResNet(nn.Module):
    """ResNet backbone with a linear classifier, built from a ``(block, n_blocks, channels)`` config.

    ``forward`` returns ``(logits, features)``. ``features`` is the flattened
    global-average-pooled output of the last stage, and ``logits`` is
    ``fc(features)``. Sub-module names follow torchvision's ResNet, so its
    ``state_dict`` can be loaded directly.
    """

    def __init__(self, config, output_dim):
        super().__init__()

        block, n_blocks, channels = config
        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; each stage after the first halves the resolution.
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels now equals the channel count produced by layer4.
        self.fc = nn.Linear(self.in_channels, output_dim)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Build one stage of ``n_blocks`` blocks and advance ``self.in_channels``.

        Only the first block can change the stride or width. It therefore gets a
        projection shortcut whenever either changes, matching torchvision. Before
        this fix, a stride change alone (with equal widths) skipped the
        projection, and the residual addition failed on mismatched shapes. For
        the ResNet-50 config, every strided stage also changes width, so the
        built network is unchanged.
        """
        out_channels = block.expansion * channels
        downsample = stride != 1 or self.in_channels != out_channels

        layers = [block(self.in_channels, channels, stride, downsample)]
        for _ in range(1, n_blocks):
            layers.append(block(out_channels, channels))

        self.in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)

        return x, h

# ResNet-50 configuration
from collections import namedtuple
# Architecture spec consumed by ResNet.__init__, which unpacks it as
# (block class, number of blocks per stage, base channel width per stage).
ResNetConfig = namedtuple('ResNetConfig', ['block', 'n_blocks', 'channels'])

# Standard ResNet-50 layout: 3/4/6/3 Bottleneck blocks over four stages. With
# Bottleneck.expansion = 4, the stages output 256/512/1024/2048 channels, which
# lets torchvision's resnet50 state_dict load into ResNet(resnet50_config, ...).
resnet50_config = ResNetConfig(block=Bottleneck,
                               n_blocks=[3, 4, 6, 3],
                               channels=[64, 128, 256, 512])




## Training Utilities


In [None]:


class EarlyStopping:
    """Stop training once the validation loss stops improving.

    A call counts as an improvement when ``val_loss <= best_loss - min_delta``.
    The first call always does. An improvement resets the counter and saves
    ``model.state_dict()`` to ``path``. After ``patience`` consecutive
    non-improving calls, ``early_stop`` becomes True.

    Previously, any loss within ``min_delta`` *above* the best was accepted as
    the new best (and checkpointed). With the default ``min_delta=0``, behavior
    is unchanged.
    """

    def __init__(self, patience=7, min_delta=0, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model, path='checkpoint.pt'):
        if self.best_loss is None or val_loss <= self.best_loss - self.min_delta:
            # First call, or an improvement of at least min_delta.
            self.best_loss = val_loss
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model, path):
        """Save ``model``'s state_dict to ``path`` and record ``val_loss`` as the best so far."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), path)
        self.val_loss_min = val_loss

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one epoch over ``train_loader``.

    ``model(inputs)`` must return ``(logits, features)``; only the logits are
    passed to ``criterion``.

    Returns:
        ``(mean_loss, accuracy)``: the average of the per-batch losses and the
        fraction of correctly classified training samples.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    pbar = tqdm(train_loader, desc='Training')
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs, _ = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        num_batches += 1
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # running_loss sums per-batch losses, so average over batches. Dividing
        # by the sample count, as before, shrank the displayed loss by roughly the
        # batch size (assuming a mean-reduced criterion).
        pbar.set_postfix({'loss': running_loss / num_batches, 'acc': 100. * correct / total})

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    return running_loss / num_batches, correct / total

def validate_epoch(model, val_loader, criterion, device):
    """Score ``model`` on ``val_loader`` with gradients disabled.

    Returns:
        A 4-tuple ``(mean_loss, accuracy, predictions, labels)``: the average
        per-batch loss, the fraction of correctly classified samples, and flat
        lists of predicted and true class indices in loader order.
    """
    model.eval()

    loss_total = 0.0
    hits = 0
    seen = 0
    predictions = []
    targets = []

    with torch.no_grad():
        for batch_inputs, batch_targets in tqdm(val_loader, desc='Validation'):
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits, _ = model(batch_inputs)
            loss_total += criterion(logits, batch_targets).item()

            batch_preds = logits.max(1)[1]
            seen += batch_targets.size(0)
            hits += batch_preds.eq(batch_targets).sum().item()

            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())

    mean_loss = loss_total / len(val_loader)
    return mean_loss, hits / seen, predictions, targets

def evaluate_model(model, test_loader, device, class_names):
    """
    Evaluate ``model`` on ``test_loader`` and plot its confusion matrix.

    ``class_names[i]`` must name label index ``i``.

    Returns:
        metrics: Dictionary with ``accuracy``, ``weighted_f1``, ``macro_f1``,
        ``per_class_f1`` (class name -> F1) and ``confusion_matrix``
        (rows = true, columns = predicted, in ``class_names`` order).
    """
    model.eval()
    all_preds = []
    all_labels = []

    print("Evaluating model on test set...")
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, _ = model(inputs)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Pin the label set for per-class scores and the confusion matrix so they
    # line up with class_names even when a class is absent from this split.
    # Without it, sklearn reports only the labels it sees, and zip() below would
    # pair scores with the wrong class names.
    label_ids = list(range(len(class_names)))

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted')
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    per_class_f1 = f1_score(all_labels, all_preds, labels=label_ids, average=None)

    # Print results
    print("\n=== Test Set Metrics ===")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Weighted F1-score: {weighted_f1:.4f}")
    print(f"Macro F1-score: {macro_f1:.4f}")

    print("\nPer-class F1 scores:")
    for class_name, f1 in zip(class_names, per_class_f1):
        print(f"  {class_name}: {f1:.4f}")

    # Create confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix - ResNet-50')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    return {
        'accuracy': accuracy,
        'weighted_f1': weighted_f1,
        'macro_f1': macro_f1,
        'per_class_f1': dict(zip(class_names, per_class_f1)),
        'confusion_matrix': cm
    }




## Learning Curve Analysis


In [None]:
def generate_learning_curve(model_class, model_config, dataset, test_loader,
                           n_iterations=15, n_repeats=5, num_epochs=25,
                           batch_size=64, class_names=None):
    """
    Generate a learning curve by fine-tuning on increasingly larger random subsets.

    For each of ``n_iterations`` subset sizes, evenly spaced from 10% to 100% of
    ``dataset``:
    1. The random subset is split 80/20 into train/validation.
    2. A model initialised from torchvision's pretrained ResNet-50 weights is
       trained with early stopping on the validation loss.
    3. The best checkpoint's macro F1 on ``test_loader`` is recorded.
    Each size is repeated ``n_repeats`` times.

    Args:
        model_class: ResNet class (called as ``model_class(model_config, output_dim=...)``)
        model_config: ResNet configuration
        dataset: Pool of untransformed ``(image, label)`` samples; the transforms
            are applied here. It must not contain the samples behind
            ``test_loader``, or test images leak into training.
        test_loader: Test data loader
        n_iterations: Number of different training sizes
        n_repeats: Number of repeats per size
        num_epochs: Maximum epochs per training run
        batch_size: Training batch size
        class_names: Class names in label-index order (also sets the classifier
            size); defaults to ``['Bleikja', 'Lax', 'Urriði']``.

    Returns:
        Dict with ``train_sizes`` (subset sizes before the 80/20 split),
        ``test_scores_mean`` / ``test_scores_std`` per size, and ``all_scores``.
        A repeat that raises is recorded as 0.0, which pulls its mean down.

    Note: runs share the ``checkpoint.pt`` file used by EarlyStopping.
    """
    if class_names is None:
        class_names = ['Bleikja', 'Lax', 'Urriði']
    num_classes = len(class_names)

    print("Generating learning curve for ResNet-50...")

    # Calculate training subset sizes
    total_size = len(dataset)
    train_sizes = [int(frac * total_size) for frac in np.linspace(0.1, 1.0, n_iterations)]

    all_scores = []

    # Image preprocessing (ImageNet statistics expected by the pretrained weights)
    pretrained_size = 224
    pretrained_means = [0.485, 0.456, 0.406]
    pretrained_stds = [0.229, 0.224, 0.225]

    train_transforms = transforms.Compose([
        transforms.Resize((pretrained_size, pretrained_size)),
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomCrop(pretrained_size, padding=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=pretrained_means, std=pretrained_stds)
    ])

    eval_transforms = transforms.Compose([
        transforms.Resize((pretrained_size, pretrained_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=pretrained_means, std=pretrained_stds)
    ])

    for size in train_sizes:
        size_scores = []
        print(f"\nTraining with {size}/{total_size} samples ({size/total_size:.1%})")

        for repeat in range(n_repeats):
            print(f"  Repeat {repeat+1}/{n_repeats}")

            # Pre-bind everything that holds model tensors so `finally` can
            # release it even when a repeat fails part-way.
            model = pretrained_model = optimizer = scheduler = None
            train_loader = val_loader = None

            try:
                # Random subset, split 80/20 into train/val. Validation images use
                # the deterministic eval pipeline, not the augmentations.
                subset_indices = torch.randperm(total_size)[:size]
                n_train = int(0.8 * size)
                train_subset = TransformSubset(dataset, subset_indices[:n_train], train_transforms)
                val_subset = TransformSubset(dataset, subset_indices[n_train:], eval_transforms)

                # Fully qualified so a notebook-level name that shadows DataLoader cannot break this.
                train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True)
                val_loader = torch.utils.data.DataLoader(val_subset, batch_size=batch_size, shuffle=False)

                # Initialize model and copy in the pretrained weights (fresh classifier head)
                model = model_class(model_config, output_dim=num_classes).to(device)
                pretrained_model = models.resnet50(pretrained=True)
                pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, num_classes)
                model.load_state_dict(pretrained_model.state_dict())

                # Training setup
                criterion = nn.CrossEntropyLoss()
                optimizer = optim.Adam(model.parameters(), lr=1e-4)
                scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
                early_stopping = EarlyStopping(patience=5, verbose=False)

                for epoch in range(num_epochs):
                    train_epoch(model, train_loader, criterion, optimizer, device)
                    val_loss, _, _, _ = validate_epoch(model, val_loader, criterion, device)

                    scheduler.step(val_loss)
                    early_stopping(val_loss, model)

                    if early_stopping.early_stop:
                        break

                # Load the best checkpoint and evaluate on the held-out test set
                model.load_state_dict(torch.load('checkpoint.pt', map_location=device))
                metrics = evaluate_model(model, test_loader, device, class_names)
                size_scores.append(metrics['macro_f1'])

            except Exception as e:
                print(f"    Error in repeat {repeat+1}: {str(e)}")
                size_scores.append(0.0)  # Failed run recorded as a zero score

            finally:
                # The optimizer and scheduler also reference the model parameters,
                # so drop them too before freeing memory.
                model = pretrained_model = optimizer = scheduler = None
                train_loader = val_loader = None
                torch.cuda.empty_cache()
                gc.collect()

        all_scores.append(size_scores)

        # Print progress
        mean_score = np.mean(size_scores)
        std_score = np.std(size_scores)
        print(f"  Mean F1: {mean_score:.4f} ± {std_score:.4f}")

    # Calculate final statistics
    test_scores_mean = [np.mean(scores) for scores in all_scores]
    test_scores_std = [np.std(scores) for scores in all_scores]

    # Plot learning curve
    plt.figure(figsize=(10, 6))
    plt.plot(train_sizes, test_scores_mean, 'o-', color='red',
             label='ResNet-50 Fine-tuned', linewidth=2, markersize=6)
    plt.fill_between(train_sizes,
                     [max(0, m - s) for m, s in zip(test_scores_mean, test_scores_std)],
                     [min(1, m + s) for m, s in zip(test_scores_mean, test_scores_std)],
                     alpha=0.2, color='red')

    plt.xlabel('Training Set Size')
    plt.ylabel('Macro F1-Score')
    plt.title('ResNet-50 Learning Curve')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

    return {
        'train_sizes': train_sizes,
        'test_scores_mean': test_scores_mean,
        'test_scores_std': test_scores_std,
        'all_scores': all_scores
    }




## Main Execution


In [None]:


def main():
    """Run the full ResNet-50 benchmark on frames extracted to ``fish_frames/``.

    Steps:
    1. Load the image folders and split them 70/15/15 into train/val/test.
    2. Fine-tune a ResNet-50 initialised from torchvision's pretrained weights,
       using inverse-frequency class weights, layer-wise learning rates, LR
       reduction on plateau, and early stopping.
    3. Plot the training history and evaluate on the test split.
    4. Build a learning curve from the non-test images.
    5. Save ``resnet50_results.json`` and ``resnet50_best_model.pth``.

    Returns:
        ``(model, test_metrics, results)``, or ``None`` when the data directory
        is missing or contains no images.
    """

    print("="*50)
    print("ResNet-50 Fish Species Classification Benchmark")
    print("="*50)

    # Configuration
    BATCH_SIZE = 64 if torch.cuda.is_available() else 32
    NUM_EPOCHS = 50
    PATIENCE = 7

    # Image preprocessing parameters (ImageNet statistics)
    pretrained_size = 224
    pretrained_means = [0.485, 0.456, 0.406]
    pretrained_stds = [0.229, 0.224, 0.225]

    print("Configuration:")
    print(f"  Device: {device}")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Max epochs: {NUM_EPOCHS}")
    print(f"  Early stopping patience: {PATIENCE}")

    # 1. Load dataset
    print("\n1. Loading dataset...")

    # If you have extracted frames, load them directly
    data_dir = "fish_frames"  # Adjust path as needed

    if not os.path.exists(data_dir):
        print(f"Dataset directory {data_dir} not found.")
        print("Please ensure you have extracted frames from videos using the data preparation section.")
        return

    base_dataset = FishFrameDataset(root_dir=data_dir)
    if len(base_dataset) == 0:
        print(f"No images found under {data_dir}.")
        return

    # Label indices follow base_dataset.classes (sorted folder names). Derive the
    # output size and class names from it instead of hard-coding them.
    class_names = list(base_dataset.classes)
    OUTPUT_DIM = len(class_names)  # Number of fish species

    # Print dataset statistics
    print("\nDataset Statistics:")
    print(f"  Total samples: {len(base_dataset)}")
    print(f"  Classes: {base_dataset.classes}")

    class_counts = {}
    for _, class_name in base_dataset.samples:
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

    for class_name, count in class_counts.items():
        percentage = (count / len(base_dataset)) * 100
        print(f"    {class_name}: {count} samples ({percentage:.1f}%)")

    # 2. Create data splits
    print("\n2. Creating data splits...")

    train_transforms = transforms.Compose([
        transforms.Resize((pretrained_size, pretrained_size)),
        transforms.RandomRotation(5),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomCrop(pretrained_size, padding=10),
        transforms.ToTensor(),
        transforms.Normalize(mean=pretrained_means, std=pretrained_stds)
    ])

    eval_transforms = transforms.Compose([
        transforms.Resize((pretrained_size, pretrained_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=pretrained_means, std=pretrained_stds)
    ])

    # Create train/val/test split (70/15/15)
    total_size = len(base_dataset)
    train_size = int(0.7 * total_size)
    val_size = int(0.15 * total_size)

    indices = torch.randperm(total_size)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    train_dataset = TransformSubset(base_dataset, train_indices, train_transforms)
    val_dataset = TransformSubset(base_dataset, val_indices, eval_transforms)
    test_dataset = TransformSubset(base_dataset, test_indices, eval_transforms)

    # Fully qualified so a notebook-level name that shadows DataLoader cannot break this.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Validation samples: {len(val_dataset)}")
    print(f"  Test samples: {len(test_dataset)}")

    # 3. Initialize and train model
    print("\n3. Training ResNet-50 model...")

    model = ResNet(resnet50_config, OUTPUT_DIM).to(device)

    # Load pretrained weights (with a fresh OUTPUT_DIM-way classifier head)
    pretrained_model = models.resnet50(pretrained=True)
    IN_FEATURES = pretrained_model.fc.in_features
    pretrained_model.fc = nn.Linear(IN_FEATURES, OUTPUT_DIM)
    model.load_state_dict(pretrained_model.state_dict())

    print(f"  Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # Inverse-frequency class weights for the imbalanced data. They are computed
    # from the training split in label-index order; they used to be hard-coded
    # counts that silently went stale whenever the data changed.
    train_targets = torch.tensor(
        [base_dataset.class_to_idx[base_dataset.samples[i][1]] for i in train_indices.tolist()],
        dtype=torch.long,
    )
    train_class_counts = torch.bincount(train_targets, minlength=OUTPUT_DIM).float()
    # clamp avoids division by zero if a class is missing from the training split
    class_weights = 1.0 / train_class_counts.clamp(min=1.0)
    class_weights = (class_weights / class_weights.sum()).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Layer-wise (differential) learning rates: smaller for earlier, more generic layers
    FOUND_LR = 6.03e-04
    params = [
        {'params': model.conv1.parameters(), 'lr': FOUND_LR / 10},
        {'params': model.bn1.parameters(), 'lr': FOUND_LR / 10},
        {'params': model.layer1.parameters(), 'lr': FOUND_LR / 8},
        {'params': model.layer2.parameters(), 'lr': FOUND_LR / 6},
        {'params': model.layer3.parameters(), 'lr': FOUND_LR / 4},
        {'params': model.layer4.parameters(), 'lr': FOUND_LR / 2},
        {'params': model.fc.parameters()}
    ]

    optimizer = optim.Adam(params, lr=FOUND_LR)
    # `verbose` is deprecated in recent PyTorch; LR reductions are reported manually below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
    early_stopping = EarlyStopping(patience=PATIENCE, verbose=True)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    print("  Starting training...")

    for epoch in range(NUM_EPOCHS):
        print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}')

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, _, _ = validate_epoch(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%')
        print(f'  Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%')

        # Learning rate scheduling
        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        if [group['lr'] for group in optimizer.param_groups] != lrs_before:
            print("  Learning rate reduced")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("  Early stopping triggered")
            break

    # Load best model
    model.load_state_dict(torch.load('checkpoint.pt', map_location=device))
    print("  Training completed. Best model loaded.")

    # 4. Plot training curves
    print("\n4. Plotting training curves...")

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    epochs_range = range(1, len(train_losses) + 1)
    ax1.plot(epochs_range, train_losses, 'b-', label='Training Loss', linewidth=2)
    ax1.plot(epochs_range, val_losses, 'r-', label='Validation Loss', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    ax2.plot(epochs_range, [acc*100 for acc in train_accs], 'b-', label='Training Accuracy', linewidth=2)
    ax2.plot(epochs_range, [acc*100 for acc in val_accs], 'r-', label='Validation Accuracy', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Training and Validation Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # 5. Evaluate on test set
    print("\n5. Evaluating on test set...")

    test_metrics = evaluate_model(model, test_loader, device, class_names)

    # 6. Generate learning curve
    print("\n6. Generating learning curve...")

    # Draw learning-curve subsets only from train+val images. The curve is scored
    # on test_loader, so including test images here would leak them into training.
    learning_curve_pool = Subset(base_dataset, torch.cat([train_indices, val_indices]).tolist())

    learning_curve_data = generate_learning_curve(
        model_class=ResNet,
        model_config=resnet50_config,
        dataset=learning_curve_pool,
        test_loader=test_loader,
        n_iterations=10,  # Reduced for faster execution
        n_repeats=3,      # Reduced for faster execution
        num_epochs=20,    # Reduced for faster execution
        batch_size=BATCH_SIZE
    )

    # 7. Save results
    print("\n7. Saving results...")

    results = {
        'model': 'ResNet-50',
        'timestamp': datetime.now().isoformat(),
        'final_metrics': test_metrics,
        'training_history': {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accuracies': train_accs,
            'val_accuracies': val_accs
        },
        'learning_curve': learning_curve_data,
        'configuration': {
            'batch_size': BATCH_SIZE,
            'num_epochs': NUM_EPOCHS,
            'learning_rate': FOUND_LR,
            'early_stopping_patience': PATIENCE,
            'device': str(device),
            'class_names': class_names
        }
    }

    with open('resnet50_results.json', 'w') as f:
        json.dump(results, f, indent=2, cls=NumpyEncoder)

    torch.save(model.state_dict(), 'resnet50_best_model.pth')

    print("  Results saved to 'resnet50_results.json'")
    print("  Model saved to 'resnet50_best_model.pth'")

    # 8. Print summary
    print("\n" + "="*50)
    print("TRAINING SUMMARY")
    print("="*50)
    print(f"Final Test Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Final Test Macro F1: {test_metrics['macro_f1']:.4f}")
    print(f"Final Test Weighted F1: {test_metrics['weighted_f1']:.4f}")
    print("\nPer-class F1 Scores:")
    # `score`, not `f1_score`, so sklearn's f1_score is not shadowed inside main().
    for class_name, score in test_metrics['per_class_f1'].items():
        print(f"  {class_name}: {score:.4f}")

    # Compare with paper results
    print("\nComparison with Paper Results:")
    print(f"  ResNet-50 (this run): {test_metrics['macro_f1']:.3f}")
    print("  SigLIP + Temporal Pooling (paper): 0.968")
    print(f"  Performance gap: {0.968 - test_metrics['macro_f1']:.3f}")

    return model, test_metrics, results




## Utility Functions

In [None]:


def visualize_sample_predictions(model, test_loader, device, class_names, num_samples=9):
    """Plot test images annotated with true label, predicted label and confidence.

    Images are taken in loader order until ``num_samples`` have been drawn.
    Titles are green for correct predictions and red for incorrect ones.

    Args:
        model: Classifier whose forward pass returns ``(logits, _)``; it is
            switched to eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the images are moved to for inference.
        class_names: Sequence mapping a class index to its display name.
        num_samples: Number of images to show. The grid is sized to fit
            (the default of 9 gives the original 3x3 layout).

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    model.eval()

    # Near-square grid that fits num_samples (9 -> 3x3, 10 -> 3 rows x 4 cols).
    ncols = int(np.ceil(np.sqrt(num_samples)))
    nrows = -(-num_samples // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    # Statistics used to undo the input normalisation for display; computed once.
    # NOTE(review): assumes the test transform normalised with these ImageNet values — confirm.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    samples_shown = 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            outputs, _ = model(batch_images.to(device))
            _, predicted = torch.max(outputs, 1)
            probabilities = torch.softmax(outputs, dim=1)
            predicted = predicted.cpu()
            probabilities = probabilities.cpu()

            for image, label, pred, probs in zip(batch_images, batch_labels, predicted, probabilities):
                if samples_shown >= num_samples:
                    break

                # Denormalize image and convert CHW -> HWC for matplotlib
                img = torch.clamp(image.cpu() * std + mean, 0, 1)
                img_np = img.permute(1, 2, 0).numpy()

                pred_idx = pred.item()
                true_label = class_names[label.item()]
                pred_label = class_names[pred_idx]
                confidence = probs[pred_idx].item()

                # Color: green if correct, red if incorrect
                color = 'green' if true_label == pred_label else 'red'

                ax = axes[samples_shown]
                ax.imshow(img_np)
                ax.set_title(
                    f'True: {true_label}\nPred: {pred_label}\nConf: {confidence:.3f}',
                    color=color, fontsize=10
                )
                ax.axis('off')

                samples_shown += 1

            if samples_shown >= num_samples:
                break

    # Hide grid cells that received no image (grid larger than num_samples,
    # or the loader ran out of images first).
    for ax in axes[samples_shown:]:
        ax.axis('off')

    fig.tight_layout()
    fig.suptitle('Sample Predictions (Green=Correct, Red=Incorrect)', y=1.02)
    plt.show()

def load_pretrained_model(model_path, model_config, output_dim, device):
    """Build a ``ResNet`` and restore its weights from a saved ``state_dict``.

    Args:
        model_path: Path to a checkpoint produced by
            ``torch.save(model.state_dict(), ...)``.
        model_config: Architecture configuration forwarded to ``ResNet``.
        output_dim: Number of output units forwarded to ``ResNet``.
        device: Device the model is created on and the checkpoint tensors
            are mapped to.

    Returns:
        The model with the loaded weights, switched to evaluation mode.
    """
    model = ResNet(model_config, output_dim).to(device)

    # A state_dict holds only tensors in plain containers, so restrict the
    # unpickler with weights_only=True instead of allowing arbitrary objects
    # to be reconstructed from the file (requires torch >= 1.13).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    print(f"Model loaded from {model_path}")
    return model

def count_parameters(model):
    """Return the total number of scalar weights in ``model`` that require gradients.

    Frozen tensors (``requires_grad=False``) are excluded from the count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# =============================================================================
# 10. DATA PREPARATION (IF NEEDED)
# =============================================================================

def prepare_data_from_npz(data_dir, output_dir, class_names=('Bleikja', 'Lax', 'Urriði')):
    """
    Prepare training data by extracting frames from videos based on NPZ files.

    Each video listed in the NPZ metadata is mapped to its central frame and
    species label, and the mapping is passed to ``extract_and_save_frames``.
    Afterwards the number of ``.jpg`` files in each class folder is printed.

    Args:
        data_dir: Directory containing NPZ files with feature extraction metadata
        output_dir: Directory to save extracted frames organized by class
        class_names: Class sub-folders of ``output_dir`` whose image counts are
            reported. Defaults to the three species folders used in this notebook.

    Raises:
        ValueError: If the loaded video paths, species labels and central
            frames do not all have the same length.
    """

    print("Preparing data from NPZ files...")

    # Load NPZ file information.
    # NOTE(review): `DataLoader` must resolve to an NPZ metadata loader that
    # provides `load_npz_files()`. The file's imports bind `DataLoader` to
    # `torch.utils.data.DataLoader`, which has no such method, unless it is
    # redefined elsewhere — confirm the intended loader class.
    npz_loader = DataLoader(data_dir)
    video_paths, species, central_frames = npz_loader.load_npz_files()

    if not (len(video_paths) == len(species) == len(central_frames)):
        raise ValueError(
            "NPZ metadata length mismatch: "
            f"{len(video_paths)} video paths, {len(species)} species labels, "
            f"{len(central_frames)} central frames"
        )

    # Map each video path to [central_frame, species]; a path that appears
    # more than once keeps its last entry.
    full_vid_paths = {
        path: [frame, label]
        for path, frame, label in zip(video_paths, central_frames, species)
    }

    print(f"Found {len(full_vid_paths)} videos to process")

    # Extract frames
    extract_and_save_frames(full_vid_paths, output_dir)

    print(f"Frames extracted to {output_dir}")

    # Print class distribution
    for class_name in class_names:
        class_dir = Path(output_dir) / class_name
        if class_dir.exists():
            count = len(list(class_dir.glob('*.jpg')))
            print(f"  {class_name}: {count} images")

# =============================================================================
# 11. EXECUTION
# =============================================================================

if __name__ == "__main__":
    import traceback

    # Example usage
    print("ResNet-50 Fish Species Classification")
    print("=" * 40)

    # Option 1: If you need to prepare data from NPZ files
    # prepare_data_from_npz("path/to/npz/files", "fish_frames")

    # Option 2: Run main training and evaluation
    try:
        model, metrics, results = main()
    except Exception as e:
        # Top-level boundary: keep the friendly hint, but also show where the
        # failure happened instead of hiding the traceback, and do not report
        # success afterwards.
        print(f"Error during execution: {str(e)}")
        traceback.print_exc()
        print("Please ensure your data is properly prepared and paths are correct.")
    else:
        # Optional: visualize some predictions. main() does not return its
        # test loader, so one must be rebuilt here first, e.g.:
        # visualize_sample_predictions(model, test_loader, device, ['Bleikja', 'Lax', 'Urriði'])
        print("\nBenchmark completed!")