<a href="https://colab.research.google.com/github/joaosMart/fish-species-class-siglip/blob/main/Code/fish-detection/SigLIP_fish_detection_prediction_savings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to Use the Complete Fish Detection Pipeline

## Prerequisites
Install required packages in Google Colab:
```bash
!pip install transformers open_clip_torch
```

## Set up your video data



### Manual video list
Define your video list in the Part 1 code cell, where `fish_path_list` is assigned, as follows:

```python
fish_path_list = [
    "/path/to/video1.mp4",
    "/path/to/video2.mp4",
    "/path/to/video3.mp4"
]
```

## Run the complete pipeline
Execute the entire notebook - it will automatically:

### Part 1: Generate Fish Detection Scores
- Load SigLIP Vision Transformer model
- Process each video frame-by-frame
- Generate probability scores: `[no_fish_probability, fish_probability]`
- Save with checkpointing every 10 videos

### Part 2: Filter and Analyze Results
- Apply probability threshold (default: 0.977989)
- Filter frames containing fish above threshold
- Generate comprehensive statistical analysis

## Configuration
Adjust the detection threshold in Part 2:

```python
DETECTION_THRESHOLD = 0.977989  # Modify as needed
```

## Outputs
The pipeline generates two files:
1. **Raw scores**: `Scores-ViT-SO400M-14-SigLIP.pkl`
2. **Filtered results**: `fish_detection_results.json`

## Resume capability
If processing is interrupted, simply re-run the notebook - it automatically resumes from the last checkpoint.

## Analysis results
The final section provides detailed statistics including:
- Total fish frames detected
- Video distribution (with/without fish)
- Probability distributions and percentiles
- Frame-level statistics and percentages

In [None]:
# Install required packages (run in Colab)
# !pip install transformers open_clip_torch


# Import libraries for Part 1 (Score Generation)
import torch
import torch.nn.functional as F
import open_clip
import cv2
from PIL import Image
import pickle
import os
import time
import pandas as pd
from tqdm import tqdm

# Import libraries for Part 2 (Analysis)
import json
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from statistics import mean, median

## Fish Detection Score Generation

In [None]:
print("=" * 60)
print("PART 1: GENERATING FISH DETECTION SCORES")
print("=" * 60)

# Setup device and model: use the GPU when one is available, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load SigLIP model with the 'webli' pretrained weights. The second return
# value of create_model_and_transforms is discarded; only the third
# (kept as `preprocess_val`) is used on video frames.
model_name = 'ViT-SO400M-14-SigLIP'
model, _, preprocess_val = open_clip.create_model_and_transforms(model_name, pretrained='webli')
model = model.to(device)
# The open_clip README notes that models are created in train mode by
# default; switch to eval mode so layers such as dropout / stochastic depth
# (if the architecture has any) are inactive while scoring frames.
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)

# Define text prompts for fish detection.
# NOTE(review): the prompt wording feeds directly into the scores, and the
# detection threshold used later was presumably tuned against these exact
# strings - confirm before editing them.
positive_prompts = tokenizer([
    "Salmon-like fish swimming",
    "An underwater photo of a salmon-like fish seen clearly swimming.",
    "Image of salmon-like fish in a contained environment.",
    "A photo of a salmon-like fish in a controlled river environment."
], context_length=model.context_length).to(device)

negative_prompts = tokenizer([
    "An image of an empty white water container.",
    "A contained environment with nothing in it.",
    "An image of a empty container with nothing in it."
], context_length=model.context_length).to(device)

# Encode text features once; they are reused for every frame of every video.
print("Encoding text prompts...")
with torch.no_grad(), torch.cuda.amp.autocast():
    pos_text_features = model.encode_text(positive_prompts)
    neg_text_features = model.encode_text(negative_prompts)
    # Average the embeddings within each prompt group, then stack them so that
    # row 0 is the negative (no fish) class and row 1 is the positive (fish)
    # class. The stacked rows are L2-normalized afterwards.
    text_features = torch.stack((neg_text_features.mean(axis=0), pos_text_features.mean(axis=0)))
    text_features = F.normalize(text_features, dim=-1)

print("Text features encoded successfully!")

# Load video data.
# Keep a `fish_path_list` that was already defined (e.g. in an earlier cell);
# only fall back to an empty list when no such name exists, so a user-supplied
# list is not silently discarded.
try:
    fish_path_list
except NameError:
    fish_path_list = []

# Print the instructions only when there is nothing to process.
if not fish_path_list:
    print("No videos configured. Please define fish_path_list manually:")
    print("fish_path_list = ['/path/to/video1.mp4', '/path/to/video2.mp4', ...]")

# Video processing function
def _score_frame_batch(frame_tensors, text_features):
    """Score one batch of preprocessed frames against the text features.

    Uses the module-level ``model`` and ``device``.

    Args:
        frame_tensors: Non-empty list of tensors, each produced by
            ``preprocess_val(image).unsqueeze(0)``.
        text_features: Text embeddings with one row per class.

    Returns:
        A list with one NumPy array of softmax probabilities per frame; the
        column order follows the row order of ``text_features``.
    """
    batch_tensor = torch.cat(frame_tensors).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        image_features = model.encode_image(batch_tensor)
        image_features = F.normalize(image_features, dim=-1)
        # Scaled similarities between frames and class embeddings, turned
        # into per-frame probabilities with a softmax over the classes.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return list(text_probs.cpu().numpy())


def process_video_batch(video_path, text_features, batch_size=128):
    """Process a single video and return probability scores for each frame.

    Frames are read sequentially with OpenCV, converted from BGR to RGB,
    preprocessed with the module-level ``preprocess_val`` and scored in
    batches of ``batch_size``.

    Args:
        video_path: Path of the video file to read.
        text_features: Text embeddings with one row per class.
        batch_size: Number of frames scored per forward pass.

    Returns:
        A list with one probability array per successfully read frame. The
        list is empty when the video cannot be opened or has no frames.
    """
    cap = cv2.VideoCapture(video_path)
    results = []
    start_time = time.time()

    # try/finally guarantees the capture handle is released even if decoding
    # or inference raises part-way through the video.
    try:
        if not cap.isOpened():
            print(f"Warning: could not open video {video_path}")

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        batch = []
        for _ in tqdm(range(frame_count), desc=f"Processing {os.path.basename(video_path)}", leave=False):
            ret, frame = cap.read()
            if not ret:
                # The reported frame count can exceed what is readable.
                break

            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            batch.append(preprocess_val(image).unsqueeze(0))

            if len(batch) == batch_size:
                results.extend(_score_frame_batch(batch, text_features))
                batch = []

        # Process any remaining frames that did not fill a whole batch
        if batch:
            results.extend(_score_frame_batch(batch, text_features))
    finally:
        cap.release()

    processing_time = time.time() - start_time
    print(f"Processed {len(results)} frames in {processing_time:.2f} seconds")
    return results

# Checkpoint management
CHECKPOINT_PATH = '/path/to/Scores-ViT-SO400M-14-SigLIP.pkl'

def save_checkpoint(results, processed_videos, filename=CHECKPOINT_PATH):
    """Save progress to a pickle file.

    Args:
        results: Mapping of video path to its list of per-frame scores.
        processed_videos: Set of video paths that have been processed.
        filename: Destination path of the checkpoint file.

    The file stores a dict with the keys ``'results'`` and
    ``'processed_videos'``.
    """
    directory = os.path.dirname(filename)
    # A bare file name has no directory component, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    if directory:
        os.makedirs(directory, exist_ok=True)

    checkpoint_data = {
        'results': results,
        'processed_videos': processed_videos
    }

    # Write to a temporary sibling file and swap it into place afterwards, so
    # an interruption during the write cannot leave a truncated checkpoint
    # that would break the resume step.
    temp_filename = f"{filename}.tmp"
    with open(temp_filename, 'wb') as f:
        pickle.dump(checkpoint_data, f)
    os.replace(temp_filename, filename)
    print(f"Checkpoint saved to {filename}")

def load_checkpoint(filename=CHECKPOINT_PATH):
    """Restore saved progress from a pickle checkpoint.

    Returns:
        A ``(results, processed_videos)`` pair read from the checkpoint's
        ``'results'`` and ``'processed_videos'`` entries, or an empty dict
        and an empty set when ``filename`` is not an existing file.
    """
    # Nothing saved yet: start from a clean state.
    if not os.path.isfile(filename):
        print(f"No checkpoint found at {filename}")
        return {}, set()

    with open(filename, 'rb') as checkpoint_file:
        saved_state = pickle.load(checkpoint_file)
    print(f"Loaded checkpoint from {filename}")
    return saved_state['results'], saved_state['processed_videos']

def process_videos(video_files, text_features, checkpoint_interval=10):
    """Process all videos with checkpointing.

    Progress is loaded from the checkpoint first; videos already recorded
    there are skipped. A checkpoint is written whenever the number of
    processed videos is a multiple of ``checkpoint_interval`` and once more
    at the end.

    Args:
        video_files: Iterable of video paths to score.
        text_features: Text embeddings passed through to
            ``process_video_batch``.
        checkpoint_interval: Number of processed videos between periodic
            checkpoints. A falsy value (0 or None) disables the periodic
            checkpoints; the final save still happens.

    Returns:
        Mapping of video path to its per-frame scores, including any entries
        restored from the checkpoint.
    """
    overall_results, processed_videos = load_checkpoint()

    # Count against the videos actually requested: the checkpoint may hold
    # videos that are not in this list, and the list may contain duplicates,
    # so len(video_files) - len(processed_videos) can be wrong or negative.
    unique_videos = list(dict.fromkeys(video_files))
    pending_videos = [path for path in unique_videos if path not in processed_videos]
    already_done = len(unique_videos) - len(pending_videos)

    print("Starting video processing...")
    print(f"Total videos to process: {len(unique_videos)}")
    print(f"Already processed: {already_done}")
    print(f"Remaining: {len(pending_videos)}")

    progress = tqdm(pending_videos, desc="Processing videos")
    for position, video_path in enumerate(progress, start=already_done + 1):
        print(f"\nProcessing video {position}/{len(unique_videos)}: {os.path.basename(video_path)}")
        results = process_video_batch(video_path, text_features)
        overall_results[video_path] = results
        processed_videos.add(video_path)

        # Guard against a zero interval, which would raise ZeroDivisionError.
        if checkpoint_interval and len(processed_videos) % checkpoint_interval == 0:
            save_checkpoint(overall_results, processed_videos)

    # Save final results
    save_checkpoint(overall_results, processed_videos)
    print(f"\nScore generation complete! Processed {len(overall_results)} videos.")
    return overall_results

# Kick off Part 1 only when at least one video path has been supplied;
# otherwise report that scoring is skipped.
if not fish_path_list:
    print("Skipping video processing - no video list provided")
else:
    results = process_videos(fish_path_list, text_features)

## Fish Detection Analysis and Filtering

In [None]:

print("\n" + "=" * 60)
print("PART 2: ANALYZING FISH DETECTION RESULTS")
print("=" * 60)

# Load the generated scores. This path must point at the file that Part 1
# wrote its checkpoint to.
# NOTE: unpickling can execute arbitrary code; only load score files that
# were produced by this notebook.
file_name = "/path/to/Scores-ViT-SO400M-14-SigLIP.pkl"

try:
    with open(file_name, 'rb') as file:
        data = pickle.load(file)
except FileNotFoundError:
    print(f"Score file not found at {file_name}")
    print("Please run Part 1 first to generate scores")
    data = None
else:
    print("Loaded data successfully!")
    print(f"Number of videos processed: {len(data['results'])}")
    # Take the first video's scores without indexing a list: a score file
    # that contains no videos would otherwise raise IndexError here.
    first_video_scores = next(iter(data['results'].values()), None)
    if first_video_scores is not None:
        print(f"Sample video frames: {len(first_video_scores)}")
if data is not None:
    # Processing functions
    def process_video_cpu(video_path, frame_data, threshold=0.9898):
        """Collect the frames of one video whose fish score reaches the threshold.

        Args:
            video_path: Identifier stored under ``"video_name"`` in the result.
            frame_data: Sequence of per-frame score pairs; element ``[1]`` of
                each pair is read as the fish probability.
            threshold: Minimum probability (inclusive) for a frame to be kept.

        Returns:
            Dict with ``"video_name"``, ``"total_frames"`` and ``"fish_frames"``
            (a list of ``{"frame": index, "probability": value}`` entries).
        """
        detections = []
        for frame_index, frame_scores in enumerate(frame_data):
            # Convert to a plain float so the value is JSON-serializable.
            fish_probability = float(frame_scores[1])
            if fish_probability >= threshold:
                detections.append({"frame": frame_index, "probability": fish_probability})

        return {
            "video_name": video_path,
            "total_frames": len(frame_data),
            "fish_frames": detections,
        }

    def process_all_videos_gpu(data, output_file, batch_size=512, threshold=0.5):
        """Process all videos and apply threshold filtering.

        Despite the name, the filtering itself is delegated to
        ``process_video_cpu``; ``batch_size`` only controls how many videos
        are handled per progress-bar step.

        Args:
            data: Loaded score data; ``data['results']`` maps each video path
                to its per-frame scores.
            output_file: Path of the JSON file to write.
            batch_size: Number of videos per progress step.
            threshold: Minimum fish probability for a frame to be kept.

        Returns:
            Mapping of video path to ``{"total_frames", "fish_frames"}``, the
            same structure that is written to ``output_file``.
        """
        video_paths = list(data['results'])
        all_results = {}

        print(f"Applying threshold filtering (threshold = {threshold})...")
        for start in tqdm(range(0, len(video_paths), batch_size), desc="Processing batches"):
            for path in video_paths[start:start + batch_size]:
                result = process_video_cpu(path, data['results'][path], threshold)
                all_results[result['video_name']] = {
                    "total_frames": result['total_frames'],
                    "fish_frames": result['fish_frames']
                }

        # Save results to JSON file. A bare file name has no directory part,
        # and os.makedirs('') raises, so only create a directory if present.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file, 'w') as f:
            json.dump(all_results, f, indent=2)

        print(f"Filtered results saved to {output_file}")
        return all_results

    # Threshold filtering: frames whose fish probability reaches
    # DETECTION_THRESHOLD are kept and written to `output_file`.
    DETECTION_THRESHOLD = 0.977989  # Adjust this threshold as needed
    output_file = '/path/to/fish_detection_results.json'

    filtered_results = process_all_videos_gpu(
        data,
        output_file,
        threshold=DETECTION_THRESHOLD,
    )

    # -------------------------------------------------------------------------
    # PART 3: STATISTICAL ANALYSIS
    # -------------------------------------------------------------------------

    # Section banner, emitted as three lines after a leading blank line.
    print("\n" + "=" * 60, "PART 3: STATISTICAL ANALYSIS", "=" * 60, sep="\n")

    def load_json_file(file_path):
        """Read and parse a JSON file, reporting problems instead of raising.

        Returns:
            The parsed content, or ``None`` (after printing an error message)
            when the file does not exist or does not contain valid JSON.
        """
        # Opening and parsing are handled separately so each failure mode has
        # its own handler.
        try:
            handle = open(file_path, 'r')
        except FileNotFoundError:
            print(f"Error: File {file_path} not found")
            return None

        with handle:
            try:
                parsed = json.load(handle)
            except json.JSONDecodeError:
                print(f"Error: File {file_path} is not valid JSON")
                return None
        return parsed

    def calculate_probability_stats(data):
        """Summarize the probabilities of every detected fish frame.

        Args:
            data: Mapping of video name to an entry whose ``'fish_frames'``
                list holds dicts with a ``'probability'`` value.

        Returns:
            Dict with the minimum, maximum and mean probability plus the
            25th/50th/75th percentiles, or ``None`` when no frame was
            detected in any video.
        """
        # Flatten the per-video detections into one list of probabilities.
        detection_scores = [
            detection['probability']
            for video_entry in data.values()
            for detection in video_entry['fish_frames']
        ]

        if len(detection_scores) == 0:
            return None

        q25, q50, q75 = np.percentile(np.array(detection_scores), [25, 50, 75])

        return {
            'min_probability': min(detection_scores),
            'max_probability': max(detection_scores),
            'mean_probability': np.mean(detection_scores),
            'q25': q25,
            'q50': q50,
            'q75': q75,
        }

    def calculate_fish_frame_stats(data):
        """Calculate statistics about frames with fish across all videos.

        Args:
            data: Mapping of video name to an entry with ``'total_frames'``
                (frame count) and ``'fish_frames'`` (list of detections).

        Returns:
            Dict of counts and averages over all videos, over videos with at
            least one fish frame, and over videos with none. For an empty
            ``data`` mapping every statistic is 0 instead of raising.
        """
        def _mean_or_zero(values):
            # statistics.mean raises StatisticsError on an empty sequence.
            return mean(values) if values else 0

        frames_per_video = [len(video['fish_frames']) for video in data.values()]
        total_frames_per_video = [video['total_frames'] for video in data.values()]
        paired_counts = list(zip(total_frames_per_video, frames_per_video))

        # Separate videos with and without fish
        videos_with_fish = [(total, fish) for total, fish in paired_counts if fish > 0]
        videos_without_fish = [(total, fish) for total, fish in paired_counts if fish == 0]

        stats = {
            'total_videos': len(data),
            'total_fish_frames': sum(frames_per_video),
            'average_fish_frames_per_video': _mean_or_zero(frames_per_video),
            'median_fish_frames_per_video': median(frames_per_video) if frames_per_video else 0,
            # default=0 keeps max/min from raising ValueError on empty input.
            'max_fish_frames': max(frames_per_video, default=0),
            'min_fish_frames': min(frames_per_video, default=0),
            'videos_with_fish': len(videos_with_fish),
            'videos_without_fish': len(videos_without_fish),
            'average_total_frames_all_videos': _mean_or_zero(total_frames_per_video),
            'average_total_frames_videos_with_fish': _mean_or_zero([total for total, _ in videos_with_fish]),
            'average_total_frames_videos_without_fish': _mean_or_zero([total for total, _ in videos_without_fish]),
            'average_fish_frames_in_positive_videos': _mean_or_zero([fish for _, fish in videos_with_fish])
        }

        # Percentage of frames with fish for each video; a video reporting
        # zero total frames counts as 0% to avoid dividing by zero.
        percentages = [
            (fish / total) * 100 if total > 0 else 0
            for total, fish in paired_counts
        ]

        stats['average_percentage_frames_with_fish'] = _mean_or_zero(percentages)
        return stats

    def print_fish_frame_analysis(file_path):
        """Load the filtered results and print a full statistical report.

        Args:
            file_path: Path of the JSON file with the filtered results.

        Returns:
            ``(stats, prob_stats)`` as computed by
            ``calculate_fish_frame_stats`` and ``calculate_probability_stats``,
            or ``None`` when the file could not be loaded or holds no videos.
        """
        dataset = load_json_file(file_path)
        if not dataset:
            return None

        stats = calculate_fish_frame_stats(dataset)
        prob_stats = calculate_probability_stats(dataset)

        share_of_videos_with_fish = stats['videos_with_fish'] / stats['total_videos'] * 100

        # Assemble the report line by line, then emit it in order.
        report_lines = [
            "\n=== FISH FRAME ANALYSIS ===",
            f"\nThreshold used: {DETECTION_THRESHOLD}",
            "\nOverall Statistics:",
            f"Total number of videos analyzed: {stats['total_videos']}",
            f"Total frames containing fish: {stats['total_fish_frames']}",
            "\nFrame Distribution:",
            f"Average total frames per video (all videos): {stats['average_total_frames_all_videos']:.2f}",
            f"Average total frames in videos WITH fish: {stats['average_total_frames_videos_with_fish']:.2f}",
            f"Average total frames in videos WITHOUT fish: {stats['average_total_frames_videos_without_fish']:.2f}",
            "\nFish Frame Statistics:",
            f"Average frames with fish per video (all videos): {stats['average_fish_frames_per_video']:.2f}",
            f"Average frames with fish in videos containing fish: {stats['average_fish_frames_in_positive_videos']:.2f}",
            f"Median frames with fish per video: {stats['median_fish_frames_per_video']}",
            f"Maximum frames with fish in a video: {stats['max_fish_frames']}",
            f"Minimum frames with fish in a video: {stats['min_fish_frames']}",
            "\nVideo Distribution:",
            f"Videos containing fish: {stats['videos_with_fish']}",
            f"Videos without fish: {stats['videos_without_fish']}",
            f"Percentage of videos with fish: {share_of_videos_with_fish:.1f}%",
            "\nPercentage Analysis:",
            f"Average percentage of frames with fish per video: {stats['average_percentage_frames_with_fish']:.2f}%",
        ]

        # The probability section only exists when at least one frame was detected.
        if prob_stats:
            report_lines += [
                "\nProbability Statistics (for detected fish frames):",
                f"Minimum probability: {prob_stats['min_probability']:.6f}",
                f"Maximum probability: {prob_stats['max_probability']:.6f}",
                f"Mean probability: {prob_stats['mean_probability']:.6f}",
                "\nProbability Quantiles:",
                f"25th percentile (Q1): {prob_stats['q25']:.6f}",
                f"50th percentile (Median): {prob_stats['q50']:.6f}",
                f"75th percentile (Q3): {prob_stats['q75']:.6f}",
            ]

        for line in report_lines:
            print(line)

        return stats, prob_stats

    # Print the report for the filtered results file and keep whatever the
    # analysis function returns for later inspection.
    analysis_results = print_fish_frame_analysis(file_path=output_file)

# Final summary. `output_file` is only assigned when the score file was loaded
# (data is not None); printing it unconditionally raised NameError and claimed
# completion even though nothing had been analyzed.
print("\n" + "=" * 60)
if data is None:
    print("PIPELINE INCOMPLETE - NO SCORES AVAILABLE")
    print("=" * 60)
    print("Please run Part 1 first to generate scores")
else:
    print("PIPELINE COMPLETE!")
    print("=" * 60)
    print("Files generated:")
    print(f"1. Raw scores: {CHECKPOINT_PATH}")
    print(f"2. Filtered results: {output_file}")
    print("\nYou can now use these files for further analysis or visualization.")