<a href="https://colab.research.google.com/github/joaosMart/fish-species-class-siglip/blob/main/Code/fish-detection/SigLIP_fish_detection_prediction_savings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to Use

## Prerequisites
Install the required packages if you are running this in Google Colab:

```bash
!pip install transformers open_clip_torch opencv-python pillow torch
```

## Setup your video list
Create a list called `fish_path_list` containing the full file paths to your video files:

```python
fish_path_list = [
    "/path/to/video1.mp4",
    "/path/to/video2.mp4",
    "/path/to/video3.mp4"
]
```

## Run the code
Execute the script - it will automatically process each video frame-by-frame, generate fish detection probability scores, and save results to a pickle file with checkpointing every 10 videos.

## Output
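Results are saved to the pickle file set by `CHECKPOINT_PATH` (named `Scores-ViT-SO400M-14-SigLIP.pkl` in the code below). The file holds a dictionary with two keys: `results`, which maps each video path to its list of per-frame probability scores `[no_fish_probability, fish_probability]`, and `processed_videos`, the set of video paths that have already been processed.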

## Resume capability
If processing is interrupted, simply re-run the script - it will automatically resume from the last checkpoint.

In [None]:
import torch
import torch.nn.functional as F
import open_clip
import cv2
from PIL import Image
import pickle
import os
import time
from tqdm import tqdm

# Setup device and model
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model_name = 'ViT-SO400M-14-SigLIP'
# Only the model and the third returned value (used below as the per-frame
# preprocessing transform) are needed; the second returned value is discarded.
model, _, preprocess_val = open_clip.create_model_and_transforms(model_name, pretrained='webli')
model = model.to(device)
# open_clip hands models back in training mode; switch to eval mode so that
# train-only layers (dropout, stochastic depth, ...) are disabled for inference.
model.eval()
tokenizer = open_clip.get_tokenizer(model_name)

# Text prompts describing frames that DO contain a fish.
positive_prompts = tokenizer([
    "Salmon-like fish swimming",
    "An underwater photo of a salmon-like fish seen clearly swimming.",
    "Image of salmon-like fish in a contained environment.",
    "A photo of a salmon-like fish in a controlled river environment."
], context_length=model.context_length).to(device)

# Text prompts describing frames that do NOT contain a fish.
negative_prompts = tokenizer([
    "An image of an empty white water container.",
    "A contained environment with nothing in it.",
    "An image of a empty container with nothing in it."
], context_length=model.context_length).to(device)

# Encode text features
# Each prompt group is averaged into a single embedding, giving a 2-row matrix:
# row 0 = "no fish" (negative prompts), row 1 = "fish" (positive prompts).
# The rows are L2-normalised after averaging.
with torch.no_grad(), torch.cuda.amp.autocast():
    pos_text_features = model.encode_text(positive_prompts)
    neg_text_features = model.encode_text(negative_prompts)
    text_features = torch.stack((neg_text_features.mean(axis=0), pos_text_features.mean(axis=0)))
    text_features = F.normalize(text_features, dim=-1)

def _score_frames(frames, text_features):
    """Return one probability row per frame for a list of preprocessed frame tensors."""
    batch_tensor = torch.cat(frames).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        image_features = model.encode_image(batch_tensor)
        image_features = F.normalize(image_features, dim=-1)
        # Scaled image/text similarities turned into a distribution over the
        # rows of `text_features`.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    return text_probs.cpu().numpy()


def process_video_batch(video_path, text_features, batch_size=128):
    """Score every frame of a video against the text embeddings.

    Frames are read sequentially, converted from BGR to RGB, preprocessed with
    ``preprocess_val`` and scored in batches of ``batch_size``.

    Args:
        video_path: Path of the video file to read with OpenCV.
        text_features: Matrix of text embeddings, one row per class. In this
            notebook row 0 is "no fish" and row 1 is "fish".
        batch_size: Number of frames sent through the model at once.

    Returns:
        A list with one NumPy array per decoded frame, each holding the
        softmax probabilities ordered like the rows of ``text_features``.
        The list is empty when the video cannot be opened or has no frames.
    """
    cap = cv2.VideoCapture(video_path)
    results = []
    try:
        if not cap.isOpened():
            # Keep the original best-effort behaviour (empty result) but make
            # the failure visible instead of silent.
            print(f"Warning: could not open video {video_path}")
            return results

        # The reported frame count is container metadata and may be missing or
        # inaccurate, so it only sizes the progress bar; the loop itself runs
        # until the decoder stops returning frames.
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        batch = []
        with tqdm(
            total=frame_count if frame_count > 0 else None,
            desc=f"Processing {os.path.basename(video_path)}",
            leave=False,
        ) as progress:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                batch.append(preprocess_val(image).unsqueeze(0))
                progress.update(1)

                if len(batch) >= batch_size:
                    results.extend(_score_frames(batch, text_features))
                    batch = []

        # Process any remaining frames
        if batch:
            results.extend(_score_frames(batch, text_features))
    finally:
        # Always free the decoder, even if preprocessing or inference fails.
        cap.release()
    return results

# Set the path for saving checkpoints
CHECKPOINT_PATH = '/path/to/Scores-ViT-SO400M-14-SigLIP.pkl'

def save_checkpoint(results, processed_videos, filename=CHECKPOINT_PATH):
    """Persist the current results and the set of processed videos to disk.

    The data is pickled as a dictionary with the keys ``'results'`` and
    ``'processed_videos'``.

    Args:
        results: Per-video results to store.
        processed_videos: Collection of videos that have been processed.
        filename: Destination pickle file. Its parent directory is created
            when it does not exist yet.
    """
    # A bare filename has an empty dirname, and os.makedirs('') raises.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    checkpoint_data = {
        'results': results,
        'processed_videos': processed_videos
    }
    # Write to a temporary sibling file and swap it in afterwards, so an
    # interruption mid-write cannot corrupt the previous checkpoint.
    tmp_filename = f"{filename}.tmp"
    with open(tmp_filename, 'wb') as f:
        pickle.dump(checkpoint_data, f)
    os.replace(tmp_filename, filename)
    print(f"Checkpoint saved to {filename}")

def load_checkpoint(filename=CHECKPOINT_PATH):
    """Load previously saved progress, if any.

    Args:
        filename: Pickle file to read the checkpoint from.

    Returns:
        A ``(results, processed_videos)`` tuple taken from the checkpoint, or
        ``({}, set())`` when no file exists at ``filename``.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at {filename}")
        # Nothing saved yet: start with empty results and no processed videos.
        return {}, set()

    # NOTE: unpickling can execute arbitrary code; only load checkpoint files
    # that come from a trusted source.
    with open(filename, 'rb') as checkpoint_file:
        saved_state = pickle.load(checkpoint_file)
    print(f"Loaded checkpoint from {filename}")
    return saved_state['results'], saved_state['processed_videos']

def process_videos(video_files, text_features, checkpoint_interval=10):
    """Score every video in ``video_files``, resuming from the checkpoint.

    Videos already recorded in the loaded checkpoint are skipped. Progress is
    checkpointed whenever the number of processed videos is a multiple of
    ``checkpoint_interval``, and once more when the function exits, whether it
    finishes normally or is interrupted by an exception.

    Args:
        video_files: Iterable of video paths to process.
        text_features: Text embeddings forwarded to ``process_video_batch``.
        checkpoint_interval: Save after this many processed videos. A value of
            0 disables the periodic saves; the final save still happens.

    Returns:
        Dictionary mapping each video path to its per-frame results.
    """
    overall_results, processed_videos = load_checkpoint()

    try:
        for video_path in tqdm(video_files):
            if video_path in processed_videos:
                continue

            overall_results[video_path] = process_video_batch(video_path, text_features)
            processed_videos.add(video_path)

            # Guard against a zero interval, which would raise on the modulo.
            if checkpoint_interval and len(processed_videos) % checkpoint_interval == 0:
                save_checkpoint(overall_results, processed_videos)
    finally:
        # Save final results. Running this in `finally` also keeps the videos
        # finished since the last periodic checkpoint if a later video fails
        # or the run is interrupted.
        save_checkpoint(overall_results, processed_videos)

    return overall_results

# Run the pipeline on your list of video paths.
# NOTE: `fish_path_list` is not defined in this cell; it must already exist as a
# list of video file paths (see "Setup your video list" above) before this runs.
# `results` receives the dictionary returned by process_videos.
results = process_videos(fish_path_list, text_features)