<a href="https://colab.research.google.com/github/joaosMart/fish-species-class-siglip/blob/update-readme-comprehensive/Code/species-classification/Feature_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Feature Extraction for Fish Species Classification

From: "Temporal Aggregation of Vision-Language Features for High-Accuracy Fish Classification in Automated Monitoring"

This notebook implements the feature extraction phase of our fish classification pipeline, extracting SigLIP vision-language features from selected video frames for temporal aggregation and species classification.

## REQUIRED INPUTS:
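- filtered_data_class.json: JSON file containing per-video metadata from the detection phase. Note: the frame-selection cell below reads a `frame_predictions` list (entries with `frame_number` and `original_probability`) and a `fish_species` label from each video entry; the structure shown here is what that selection step produces for each video.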
```
  Structure: {
    video_path: {
        selected_frames: [list],
        middle_frame: int,
        fish_species: str,
        mean_probability: float
        }
}
```

## OUTPUTS:
- SigLIP Features: NPZ files containing normalized feature vectors for each video
  - features: dict mapping frame_number -> feature_vector
  - frame_numbers: list of selected frame numbers
  - middle_frame: central frame number for single-frame baseline
  - fish_species: ground truth species label
- ResNet-50 Features: NPZ files containing the ResNet-50 feature vector (`features`) and species label (`fish_species`) for the central frame of each video only, used as a baseline for comparison
- Extracted Frames: JPEG images of the middle frames. These are not required for the SigLIP model; they are only used to obtain the features for the ResNet-50 baseline.


In [None]:
# Install required packages.
# The leading `!` is a notebook shell escape, so this line only works when
# run as a notebook cell (it is not valid in a plain Python script).
!pip install transformers open_clip_torch decord --quiet

In [None]:
import json
import torch
import torch.nn.functional as F
import open_clip
from PIL import Image
import numpy as np
from pathlib import Path
from decord import VideoReader
from decord import cpu, gpu
import decord
from tqdm import tqdm
import os
import cv2
import torchvision.models as models
from torchvision import transforms
import concurrent.futures
import multiprocessing
from collections import Counter

## Data Loading and Configuration


In [None]:
def load_filtered_data(data_path):
    """
    Load filtered video data from a JSON file.

    The file is opened explicitly as UTF-8 so that video paths or species
    names containing non-ASCII characters load the same way on every
    platform, instead of depending on the locale's default encoding.

    Args:
        data_path (str): Path to filtered_data_class.json

    Returns:
        dict: Parsed JSON content, keyed by video path

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(data_path, 'r', encoding='utf-8') as file:
        filtered_data = json.load(file)

    print(f"Loaded {len(filtered_data)} videos from {data_path}")
    return filtered_data

# Load filtered data from the detection phase.
# The path is relative to the current working directory; update it to your
# actual data location. `filtered_data` is the input of the frame-selection
# step further down in this notebook.
filtered_data_file = 'filtered_data_class.json'
filtered_data = load_filtered_data(filtered_data_file)




## Frame Selection Strategy: Central Frame Expansion


In [None]:
def select_frames_by_center_expansion(video_data, window_size=11):
    """
    Select the window of consecutive frames with the highest mean probability.

    Candidate windows are visited starting at the center of the video and
    moving outwards (right neighbour before left neighbour at each distance).
    Because a window only replaces the current best when its mean is strictly
    higher, ties are resolved in favour of the window closest to the center.

    Every valid window position is considered for both odd and even
    ``window_size`` values, including the case where the video has exactly
    ``window_size`` frames.

    Args:
        video_data (dict): Must contain 'frame_predictions', a list of dicts
            with 'frame_number' and 'original_probability' keys.
        window_size (int): Number of consecutive frames to analyze (default 11)

    Returns:
        tuple: (selected frames list, middle frame number, sequence mean probability).
            When there are fewer predictions than ``window_size`` all frames
            are returned; with no predictions at all the result is
            ``([], None, 0)``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    # Get frame predictions sorted by frame number
    frame_preds = sorted(
        video_data['frame_predictions'],
        key=lambda x: x['frame_number']
    )
    num_frames = len(frame_preds)

    if num_frames < window_size:
        # Fewer frames than the window size: fall back to using all of them
        selected_frames = [pred['frame_number'] for pred in frame_preds]
        if not selected_frames:
            return selected_frames, None, 0
        mean_prob = sum(pred['original_probability'] for pred in frame_preds) / num_frames
        return selected_frames, selected_frames[num_frames // 2], mean_prob

    center_idx = num_frames // 2
    half_window = window_size // 2
    # Largest start index for which a full window still fits in the sequence
    last_start_idx = num_frames - window_size

    max_mean = float('-inf')
    best_start_idx = 0

    # Expand from center to find best sequence
    for offset in range(num_frames):
        # At offset 0 both directions coincide, so evaluate the center only once
        if offset == 0:
            candidates = (center_idx,)
        else:
            candidates = (center_idx + offset, center_idx - offset)

        for center in candidates:
            start_idx = center - half_window

            # Bounds are checked on the window itself (not on the center) so
            # that even window sizes do not lose their last valid position.
            if start_idx < 0 or start_idx > last_start_idx:
                continue

            window = frame_preds[start_idx:start_idx + window_size]
            current_mean = sum(pred['original_probability'] for pred in window) / window_size

            if current_mean > max_mean:
                max_mean = current_mean
                best_start_idx = start_idx

    # Get the best sequence
    best_sequence = frame_preds[best_start_idx:best_start_idx + window_size]
    selected_frames = [pred['frame_number'] for pred in best_sequence]
    middle_frame = selected_frames[window_size // 2]

    return selected_frames, middle_frame, max_mean

def process_videos_dict_with_center_expansion(videos_dict, window_size=11):
    """
    Apply center-expansion frame selection to every video in a dictionary.

    Args:
        videos_dict (dict): Maps video path -> video data (needs
            'frame_predictions' and 'fish_species')
        window_size (int): Number of consecutive frames to analyze

    Returns:
        dict: Maps video path -> dict with 'selected_frames', 'middle_frame',
            'mean_probability' and 'fish_species'
    """
    def summarize(entry):
        # Run the selection first, then attach the ground-truth label
        frames, center_frame, avg_prob = select_frames_by_center_expansion(
            entry, window_size
        )
        return {
            'selected_frames': frames,
            'middle_frame': center_frame,
            'mean_probability': avg_prob,
            'fish_species': entry['fish_species'],
        }

    return {path: summarize(entry) for path, entry in videos_dict.items()}

# Process videos to select optimal frame sequences.
# No window_size is passed, so the function's default (11 frames) is used.
print("Selecting optimal frame sequences using center expansion strategy...")
window_mean_frames = process_videos_dict_with_center_expansion(filtered_data)

# One result entry is produced per input video
print(f"Processed {len(window_mean_frames)} videos for frame selection")



## Quality Analysis and Filtering


This code was used only for the initial iteration of the experiments; for the later stages of the project the threshold was set to 0 so that no data is filtered out. We kept it because we thought it could be interesting to have available.

In [None]:
def analyze_video_quality(results_dict, threshold=0.999999, num_samples=10):
    """
    Print a summary of how many videos reach a mean-probability threshold.

    Videos whose 'mean_probability' is at or above ``threshold`` count as high
    quality, the rest as low quality. A random sample of the low quality
    videos is printed with their scores.

    Args:
        results_dict (dict): Dictionary from process_videos_dict_with_center_expansion
        threshold (float): Minimum mean probability for a high quality video
        num_samples (int): Maximum number of low quality videos to print
    """
    import random

    scores_above = {}
    scores_below = {}
    for video_path, entry in results_dict.items():
        score = entry['mean_probability']
        bucket = scores_above if score >= threshold else scores_below
        bucket[video_path] = score

    # Summary counts
    print(f"High quality videos (score >= {threshold}): {len(scores_above)}")
    print(f"Low quality videos (score < {threshold}): {len(scores_below)}")

    # Nothing more to show when there is nothing to sample
    how_many = min(num_samples, len(scores_below))
    if how_many <= 0:
        return

    print(f"\nSample of lowest scoring videos (path, score):")
    for video_path, score in random.sample(list(scores_below.items()), how_many):
        print(f"{Path(video_path).name}: {score:.10f}")

def filter_high_quality_videos(results_dict, threshold=0.0):
    """
    Keep only the videos whose mean probability reaches the threshold.

    Note: the default threshold=0.0 is used to keep all videos, as described
    in the paper methodology.

    Args:
        results_dict (dict): Dictionary from process_videos_dict_with_center_expansion
        threshold (float): Minimum 'mean_probability' for a video to be kept

    Returns:
        dict: New dictionary with the entries that passed the threshold
    """
    kept = {}
    for video_path, entry in results_dict.items():
        if entry['mean_probability'] >= threshold:
            kept[video_path] = entry

    total_count = len(results_dict)
    kept_count = len(kept)

    print(f"Original number of videos: {total_count}")
    print(f"Number of videos kept (score >= {threshold}): {kept_count}")
    print(f"Number of videos removed: {total_count - kept_count}")

    return kept

# Analyze video quality and apply filtering.
# Both calls use threshold=0.0, so every video whose mean probability is
# >= 0 is kept; `filtered_results` is what the main pipeline processes.
analyze_video_quality(window_mean_frames, threshold=0.0)
filtered_results = filter_high_quality_videos(window_mean_frames, threshold=0.0)

## SigLIP Feature Extraction

In [None]:
class FeatureExtractor:
    """
    SigLIP-based feature extractor for fish species classification.

    Uses ViT-SO400M-14-SigLIP model as feature extractor. Video decoding is
    done with Decord on the CPU; the model runs on ``self.device``.
    """

    def __init__(self, model_name='ViT-SO400M-14-SigLIP', device=None):
        """
        Load the SigLIP model and configure Decord.

        Args:
            model_name (str): open_clip model name
            device (str): Device for the model; defaults to 'cuda' when
                available, otherwise 'cpu'

        Raises:
            Exception: Whatever open_clip raises while loading the model is
                printed and re-raised.
        """
        # The model uses the GPU when available; Decord decodes on the CPU
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"SigLIP model using device: {self.device}")

        # Make Decord hand back torch tensors instead of its native arrays.
        # NOTE(review): this looks like a process-wide Decord setting, so other
        # Decord readers created afterwards would also return torch tensors -
        # confirm against the Decord documentation.
        decord.bridge.set_bridge('torch')
        # Decode videos on the CPU
        self.ctx = cpu(0)

        # Create SigLIP model and preprocessing transform
        try:
            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
                model_name,
                pretrained='webli'
            )
            self.model = self.model.to(self.device)
            self.model.eval()
            print(f"Successfully loaded {model_name} model")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def _extract_frames_from_video(self, video_path, frame_numbers):
        """
        Extract specific frames from video using Decord with CPU processing.

        Args:
            video_path (str): Path to video file
            frame_numbers (list): List of frame indices to extract

        Returns:
            list: PIL Images, or None if extraction fails (the error is printed)
        """
        try:
            # Load video with Decord on CPU
            vr = VideoReader(str(video_path), ctx=self.ctx)

            # Direct frame access
            frames = vr.get_batch(frame_numbers)

            # Convert each decoded frame to a PIL Image for the preprocess transform
            return [Image.fromarray(frame.numpy()) for frame in frames]

        except Exception as e:
            print(f"Error processing video {video_path}: {str(e)}")
            return None

    def extract_features(self, video_path, frame_numbers, batch_size=256):
        """
        Extract SigLIP features for specific frames from a video.

        Frames are processed in chunks of ``batch_size``. A chunk whose frames
        cannot be decoded is skipped, so the result may contain fewer entries
        than ``frame_numbers`` (or be empty).

        Args:
            video_path (str): Path to video file
            frame_numbers (list): Frame indices to process
            batch_size (int): Batch size for processing (default 256)

        Returns:
            dict: Mapping from frame_number to L2-normalized feature vector
                (numpy array)
        """
        features_dict = {}
        frames_batches = [frame_numbers[i:i + batch_size]
                          for i in range(0, len(frame_numbers), batch_size)]

        # Mixed precision only makes sense when the model actually runs on CUDA.
        # On other devices autocast is created disabled, which avoids the
        # "CUDA is not available" warning of the old torch.cuda.amp API.
        use_amp = torch.device(self.device).type == 'cuda'
        amp_device_type = 'cuda' if use_amp else 'cpu'

        for batch in frames_batches:
            frames = self._extract_frames_from_video(video_path, batch)

            if frames is None:
                continue

            # Preprocess on the CPU, then move the whole batch to the device
            # in a single transfer instead of one transfer per frame
            processed_frames = torch.stack(
                [self.preprocess(frame) for frame in frames]
            ).to(self.device)

            # Extract features with automatic mixed precision (CUDA only)
            with torch.no_grad(), torch.autocast(device_type=amp_device_type, enabled=use_amp):
                batch_features = self.model.encode_image(processed_frames)
                # L2 normalize features as described in paper methodology
                batch_features = F.normalize(batch_features, dim=-1)

            # Store features keyed by the original frame number
            for idx, frame_num in enumerate(batch):
                features_dict[frame_num] = batch_features[idx].cpu().numpy()

        return features_dict

def process_videos_batch(videos_mapping, output_dir, model_name='ViT-SO400M-14-SigLIP', device=None):
    """
    Process a batch of videos and save their SigLIP features.

    One ``<dataset>_<video stem>_features.npz`` file is written per video,
    where ``<dataset>`` is the part of the parent folder name before the first
    underscore. A video counts as failed when no features could be extracted
    or when any exception is raised while processing it; failures are counted
    and processing continues with the next video.

    Args:
        videos_mapping (dict): Maps video path -> metadata with
            'selected_frames', 'middle_frame' and 'fish_species'
        output_dir (str): Directory to save feature files
        model_name (str): SigLIP model variant to use
        device (str): Device for computation
    """
    # Initialize the feature extractor
    extractor = FeatureExtractor(model_name=model_name, device=device)

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Track processing results
    results = Counter()

    pbar = tqdm(videos_mapping.items(), desc="Extracting SigLIP features")
    for video_path, data in pbar:
        try:
            # Extract features for selected frames
            features = extractor.extract_features(
                video_path,
                data['selected_frames']
            )

            if features:
                # Generate output filename with dataset information
                source = Path(video_path)
                dataset_name = source.parent.name.split('_')[0]
                output_path = output_dir / f"{dataset_name}_{source.stem}_features.npz"

                # Save features and metadata.
                # NOTE(review): `features` is a plain dict, which NumPy presumably
                # stores as a pickled object array - loaders likely need
                # allow_pickle=True; confirm against the downstream reader.
                np.savez(
                    output_path,
                    features=features,
                    frame_numbers=data['selected_frames'],
                    middle_frame=data['middle_frame'],
                    fish_species=data['fish_species']
                )
                results['success'] += 1
            else:
                results['failed'] += 1

        except Exception as e:
            print(f"Error processing {video_path}: {str(e)}")
            results['failed'] += 1

        # Refresh the counters after every video, including failed ones, so
        # the progress bar never shows stale numbers
        pbar.set_postfix({
            'Success': results['success'],
            'Failed': results['failed']
        })

    # Print final statistics
    print("\nSigLIP Feature Extraction Complete!")
    print(f"Successfully processed: {results['success']} videos")
    print(f"Failed to process: {results['failed']} videos")
    total_processed = results['success'] + results['failed']
    if total_processed > 0:
        completion_rate = results['success'] / total_processed * 100
        print(f"Total completion rate: {completion_rate:.2f}%")



## ResNet-50 Feature Extraction (Baseline Comparison)


In [None]:
def process_single_video_frame_extraction(args):
    """
    Process a single video and save its middle frame for ResNet-50 processing.

    The frame is written as
    ``<video stem>_<fish_species>_frame_<middle_frame>.jpg`` inside the output
    directory. Any error (missing metadata, unreadable video, failed write) is
    printed and reported through the return value instead of being raised.

    Args:
        args (tuple): (video_path, info, output_dir) where ``info`` provides
            'middle_frame' and 'fish_species'

    Returns:
        bool: True if the frame was written, False otherwise
    """
    video_path, info, output_dir = args
    try:
        # Get middle frame number
        middle_frame = info['middle_frame']

        # Create VideoReader object and extract frame
        vr = VideoReader(str(video_path), ctx=cpu(0))
        frame = vr[middle_frame]

        # Decord returns its native array (with .asnumpy()) by default, but a
        # torch tensor (with .numpy()) once the torch bridge has been enabled,
        # which FeatureExtractor does. Support both so this function works
        # regardless of which pipeline step ran first.
        frame = frame.asnumpy() if hasattr(frame, 'asnumpy') else frame.numpy()

        # Create output filename
        video_name = Path(video_path).stem
        fish_species = info['fish_species']
        output_path = Path(output_dir) / f"{video_name}_{fish_species}_frame_{middle_frame}.jpg"

        # Save frame (Decord yields RGB, OpenCV expects BGR). cv2.imwrite
        # signals failure through its return value rather than an exception.
        if not cv2.imwrite(str(output_path), cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)):
            raise IOError(f"could not write frame to {output_path}")
        return True
    except Exception as e:
        print(f"\nError processing {video_path}: {str(e)}")
        return False

def extract_middle_frames(video_dict, output_dir):
    """
    Save the middle frame of every video so ResNet-50 can process it later.

    The per-video work is delegated to process_single_video_frame_extraction
    and run on a small thread pool.

    Args:
        video_dict (dict): Maps video path -> frame information
        output_dir (str): Directory to save frames (created if missing)
    """
    frames_dir = Path(output_dir)
    frames_dir.mkdir(parents=True, exist_ok=True)

    # At most 4 workers (optimized for Colab)
    worker_count = min(4, multiprocessing.cpu_count())
    n_videos = len(video_dict)
    print(f"Extracting middle frames from {n_videos} videos using {worker_count} workers")

    # One (path, info, output_dir) job per video
    jobs = [(video_path, meta, frames_dir) for video_path, meta in video_dict.items()]

    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as pool:
        outcomes = tqdm(
            pool.map(process_single_video_frame_extraction, jobs),
            total=n_videos,
            desc="Extracting frames"
        )
        # Each outcome is a bool, so the sum is the number of successes
        n_ok = sum(outcomes)

    print(f"\nFrame extraction complete!")
    print(f"Successfully processed: {n_ok}/{n_videos} videos")
    print(f"Frames saved to: {frames_dir}")

class ResNetFeatureExtractor:
    """
    ResNet-50 feature extractor for baseline comparison.

    Implements ResNet-50 feature extraction as described in the paper
    for comparison with SigLIP-based approach. The final classification layer
    is removed so the network outputs pooled features.
    """

    def __init__(self, device=None):
        """
        Load the ImageNet-pretrained ResNet-50 backbone.

        Args:
            device (str): Device for the model; defaults to 'cuda' when
                available, otherwise 'cpu'

        Raises:
            Exception: Whatever torchvision raises while loading the model is
                printed and re-raised.
        """
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"ResNet-50 using device: {self.device}")

        try:
            # Prefer the weights enum of newer torchvision releases (the
            # `pretrained` flag is deprecated there); fall back to the old
            # flag when the enum does not exist. IMAGENET1K_V1 is expected to
            # be the checkpoint that `pretrained=True` used to select.
            if hasattr(models, 'ResNet50_Weights'):
                backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            else:
                backbone = models.resnet50(pretrained=True)

            # Drop the final classifier, keeping everything up to the pooling layer
            self.model = torch.nn.Sequential(*list(backbone.children())[:-1])
            self.model = self.model.to(self.device)
            self.model.eval()

            # ImageNet normalization as used in paper
            self.preprocess = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
            print("ResNet-50 model loaded successfully")
        except Exception as e:
            print(f"Error loading ResNet-50 model: {str(e)}")
            raise

    def extract_feature(self, image_path):
        """
        Extract ResNet-50 features from a single image.

        Args:
            image_path (str): Path to image file

        Returns:
            numpy.ndarray: L2-normalized feature vector (2048-dimensional per
                the paper), or None if the image could not be processed (the
                error is printed)
        """
        try:
            # Close the file handle as soon as the pixels are loaded, and force
            # three channels so grayscale/RGBA images match the 3-channel
            # normalization above.
            with Image.open(image_path) as image:
                rgb_image = image.convert('RGB')
            processed_image = self.preprocess(rgb_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                # Collapse the batch and pooled spatial dimensions into a vector
                feature = self.model(processed_image).flatten()
                # L2 normalize features for fair comparison
                feature = torch.nn.functional.normalize(feature, dim=-1)

            return feature.cpu().numpy()

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            return None

def extract_species_from_filename(filename):
    """
    Extract species label from generated filename.

    Frame images are named ``<video>_<species>_frame_<number>.jpg``. The
    trailing ``_frame_<number>`` part is removed first and the species is
    everything after the first underscore of the remainder, so species labels
    that contain underscores themselves (e.g. ``brown_trout``) are returned in
    full. This assumes the video name itself contains no underscore.

    Filenames that do not end in ``_frame_<number>`` fall back to returning
    the second underscore-separated token.

    Args:
        filename (str): Name of the frame image file

    Returns:
        str: Species label

    Raises:
        IndexError: If the filename contains no underscore at all.
    """
    stem = filename.rsplit('.', 1)[0]
    head, marker, frame_id = stem.rpartition('_frame_')

    if marker and frame_id.isdigit():
        _video_name, separator, species = head.partition('_')
        if separator and species:
            return species

    # Unknown naming scheme: keep the simple positional lookup
    return filename.split('_')[1]

def process_images_batch_resnet(image_folder, output_dir):
    """
    Run ResNet-50 over every extracted frame and store the features.

    For each ``*.jpg`` in ``image_folder`` a
    ``<image stem>_resnet_features.npz`` file with the feature vector and the
    species label (parsed from the filename) is written. Images whose features
    cannot be extracted are skipped.

    Args:
        image_folder (str): Directory containing extracted frames
        output_dir (str): Directory to save ResNet-50 features (created if missing)
    """
    resnet = ResNetFeatureExtractor()

    features_dir = Path(output_dir)
    features_dir.mkdir(exist_ok=True, parents=True)

    # Only JPEG frames directly inside the folder are considered
    jpg_files = list(Path(image_folder).glob('*.jpg'))
    print(f"Processing {len(jpg_files)} images with ResNet-50...")

    for jpg_file in tqdm(jpg_files, desc="Processing images"):
        try:
            vector = resnet.extract_feature(jpg_file)
            if vector is None:
                # Extraction already reported the error; nothing to save
                continue

            label = extract_species_from_filename(jpg_file.name)
            target = features_dir / f"{jpg_file.stem}_resnet_features.npz"
            np.savez(target, features=vector, fish_species=label)

        except Exception as e:
            print(f"Error processing {jpg_file}: {str(e)}")
            continue

    print("\nResNet-50 feature extraction complete!")


## Main Processing Pipeline


In [None]:


def main():
    """
    Run the complete feature extraction pipeline.

    Uses the module-level ``filtered_results`` and, in order: extracts SigLIP
    features, saves each video's middle frame as an image, and extracts
    ResNet-50 features from those images for the baseline comparison
    described in the paper methodology.
    """
    rule = "=" * 70

    # Output locations (relative to the working directory)
    siglip_dir = "features/ViT-SO400M-14-SigLIP"
    resnet_dir = "features/ResNet-50"
    frames_dir = "middle_frames"

    print(rule)
    print("Fish Species Classification - Feature Extraction Pipeline")
    print(rule)

    print(f"\nProcessing {len(filtered_results)} videos...")
    print(f"SigLIP features will be saved to: {siglip_dir}")
    print(f"ResNet-50 features will be saved to: {resnet_dir}")

    # Step 1: SigLIP features. The results were not actually filtered
    # (threshold = 0), so all videos are processed.
    print("\n1. Extracting SigLIP features for temporal aggregation...")
    process_videos_batch(filtered_results, siglip_dir)

    # Step 2: middle frames as images for the ResNet-50 baseline
    print("\n2. Extracting middle frames for ResNet-50 baseline...")
    extract_middle_frames(filtered_results, frames_dir)

    # Step 3: ResNet-50 features from the saved frames
    print("\n3. Extracting ResNet-50 features for baseline comparison...")
    process_images_batch_resnet(frames_dir, resnet_dir)

    print("\n" + rule)
    print("Feature extraction pipeline completed successfully!")
    print(rule)
    print("\nGenerated outputs:")
    print(f"- SigLIP features: {siglip_dir}/")
    print(f"- ResNet-50 features: {resnet_dir}/")
    print(f"- Middle frame images: {frames_dir}/")
    print("\nThese features can now be used for temporal aggregation and species classification.")

# Entry point: run the pipeline only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()