<a href="https://colab.research.google.com/github/martintmv-git/RB-IBDM/blob/main/Experiments/Generating%20Masks%20with%20SAM/Annotated%20dataset%20masks/sam_full_dataset_generate_masks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RB-IBDM Image Segmentation
### Segment Anything by Meta AI
### Generating object masks with SAM for RB-IBDM and saving them in Google Drive

## Before starting

Make sure you are connected to a GPU.

In [None]:
# Print the driver/GPU status table to confirm a GPU is attached to this runtime.
!nvidia-smi

Sat Mar 30 17:48:31 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   67C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

**NOTE:** To make it easier for us to manage datasets, images and models we create a `HOME` constant.

In [None]:
import os
# Anchor for the paths built later in the notebook (weights and data folders).
HOME = os.getcwd()
print("HOME:", HOME)

HOME: /content


# Install Segment Anything Model (SAM) and dependencies

In [None]:
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone


# Download SAM weights

In [None]:
!mkdir -p {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

In [None]:
import os

# Where the download cell stored the SAM checkpoint.
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")

# Report whether the checkpoint file is really on disk before it is loaded.
checkpoint_exists = os.path.isfile(CHECKPOINT_PATH)
print(CHECKPOINT_PATH, "; exist:", checkpoint_exists)

/content/weights/sam_vit_h_4b8939.pth ; exist: True


# Load Insect Data from Google Drive

In [None]:
# Create a data directory under HOME (-p: no error if it already exists).
!mkdir -p {HOME}/data

import os
from google.colab import drive

# Mount Google Drive so the dataset and output folders below are reachable.
drive.mount('/content/drive')

# Folder on the mounted Drive that holds the input images
dataset_path = '/content/drive/MyDrive/diopsis_tests/images_clean'

# Folder on the mounted Drive that receives the generated masks
save_path = '/content/drive/MyDrive/diopsis_tests/diopsis_masks_generated'

# Create the output folder on first use; otherwise just report that it is there
if os.path.exists(save_path):
    print(f"Save directory already exists: {save_path}")
else:
    os.makedirs(save_path)
    print(f"Directory created for saving masks: {save_path}")

# Report how many entries the dataset folder contains
num_images = len(os.listdir(dataset_path))
print(f"Number of images read in dataset: {num_images}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Save directory already exists: /content/drive/MyDrive/diopsis_tests/diopsis_masks_generated
Number of images read in dataset: 27649


# Load Model

In [None]:
import torch

# Key used to pick the model builder from the SAM model registry.
MODEL_TYPE = "vit_h"

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Build the model from the downloaded checkpoint, then move it to DEVICE.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

# Automated Mask Generation

To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Here we pass the model that was loaded from the downloaded checkpoint in the previous section. Running on CUDA and with the default model is recommended.

In [None]:
# Wrap the loaded model in the automatic mask generator, keeping its default settings.
mask_generator = SamAutomaticMaskGenerator(model=sam)

### Generate masks with SAM

The script records every successfully processed image in `processed_images.txt`. If it has to be re-run after being interrupted, it first reads `processed_images.txt` to determine which images have already been processed and skips them.

This approach saves a lot of time and computational resources, especially when working with large datasets or when running processes that have a long execution time, as it prevents reprocessing of images that have already been handled. It's a simple yet effective way to make the run resumable.

In [None]:
import cv2
from PIL import Image
import numpy as np

def save_masks_to_drive(masks, save_path, image_name):
    """Save the first mask in ``masks`` as an 8-bit PNG inside ``save_path``.

    Args:
        masks: Sequence of mask arrays. Only ``masks[0]`` is written; it is
            multiplied by 255 and cast to ``uint8`` (presumably a boolean or
            0/1 array - confirm against the mask generator's output).
        save_path: Directory that receives the PNG file.
        image_name: Name used to build the output file name
            ``mask_<image_name>_0.png``.

    Returns:
        bool: ``True`` if the PNG was written. ``False`` if ``masks`` is empty
        or ``None``, or if building/saving the image raised an exception.
    """
    if not masks:
        # Return an explicit False (instead of silently falling through and
        # returning None) so callers always receive a boolean and the skipped
        # image is visible in the log.
        print(f"No masks to save for {image_name}")
        return False

    try:
        # Scale to the 0-255 range so the mask is visible as a grayscale image.
        first_mask = (masks[0] * 255).astype(np.uint8)  # Use only the first mask
        mask_file_path = os.path.join(save_path, f'mask_{image_name}_0.png')  # Name for the first mask
        Image.fromarray(first_mask).save(mask_file_path)
    except Exception as e:
        # Best effort: report the failure and let the caller carry on with the next image.
        print(f"Failed saving mask to drive for {image_name}: {e}")
        return False
    else:
        print(f"Successfully saved mask to drive: {mask_file_path}")
        return True

def _image_stem(filename):
    """Return ``filename`` without a trailing ``.jpg`` or ``.png`` extension."""
    stem, extension = os.path.splitext(filename)
    # Strip only a real trailing extension; a ".jpg"/".png" substring in the
    # middle of the name is left untouched.
    return stem if extension in ('.jpg', '.png') else filename


def process_images_in_batches(dataset_path, save_path, batch_size=100):
    """Generate a mask for every file in ``dataset_path`` and save it to ``save_path``.

    Progress is tracked in ``<save_path>/processed_images.txt``: each image
    whose mask was saved is appended there, and names already listed in that
    file are skipped, so an interrupted run can be resumed.

    Relies on the module-level ``mask_generator``, ``cv2`` and
    ``save_masks_to_drive``.

    Args:
        dataset_path: Directory containing the input images.
        save_path: Directory that receives the masks and the progress file.
        batch_size: Number of paths taken per slice while walking the dataset.
            Must be >= 1. Slicing does not change which images are processed.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently process nothing.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    processed_images_file = os.path.join(save_path, "processed_images.txt")
    if os.path.exists(processed_images_file):
        with open(processed_images_file, "r") as file:
            processed_images = set(file.read().splitlines())
    else:
        processed_images = set()

    # os.listdir returns entries in arbitrary order; sort them so every run
    # walks the dataset in the same, predictable order.
    image_paths = [
        os.path.join(dataset_path, f)
        for f in sorted(os.listdir(dataset_path))
        if os.path.isfile(os.path.join(dataset_path, f))
    ]

    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i + batch_size]
        for path in batch_paths:
            image_name = _image_stem(os.path.basename(path))
            if image_name in processed_images:
                print(f"Skipping already processed image: {path}")
                continue  # Skip this image because it has already been processed
                          # as generating masks for the whole dataset takes around 18h

            try:
                print(f"Processing: {path}")
                image_bgr = cv2.imread(path)
                if image_bgr is None:
                    # cv2.imread reports an unreadable file by returning None
                    # rather than raising; say so instead of failing in cvtColor.
                    print(f"Error processing {path}: image could not be read")
                    continue
                image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                sam_result = mask_generator.generate(image_rgb)
                # Largest-area mask first; save_masks_to_drive writes only the first one.
                masks = [mask['segmentation'] for mask in sorted(sam_result, key=lambda x: x['area'], reverse=True)]

                if save_masks_to_drive(masks, save_path, image_name):
                    # Append immediately so progress survives an interruption.
                    with open(processed_images_file, "a") as file:
                        file.write(f"{image_name}\n")
                    # Keep the in-memory view in sync with the progress file.
                    processed_images.add(image_name)
                    print(f"Successfully processed and saved masks for: {path}")
            except Exception as e:
                # Best effort: log the failure and continue with the next image.
                print(f"Error processing {path}: {e}")


# Start (or resume) mask generation for the dataset folder.
process_images_in_batches(
    dataset_path=dataset_path,
    save_path=save_path,
    batch_size=100,
)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Successfully saved mask to drive: /content/drive/MyDrive/diopsis_tests/diopsis_masks_generated/mask_126_20200810000545_5656_t_0.png
Successfully processed and saved masks for: /content/drive/MyDrive/diopsis_tests/images_clean/126_20200810000545_5656_t.jpg
Processing: /content/drive/MyDrive/diopsis_tests/images_clean/182_20200628034901_5089_t.jpg
Successfully saved mask to drive: /content/drive/MyDrive/diopsis_tests/diopsis_masks_generated/mask_182_20200628034901_5089_t_0.png
Successfully processed and saved masks for: /content/drive/MyDrive/diopsis_tests/images_clean/182_20200628034901_5089_t.jpg
Processing: /content/drive/MyDrive/diopsis_tests/images_clean/192_20200725054003_5170_t.jpg
Successfully saved mask to drive: /content/drive/MyDrive/diopsis_tests/diopsis_masks_generated/mask_192_20200725054003_5170_t_0.png
Successfully processed and saved masks for: /content/drive/MyDrive/diopsis_tests/images_clean/192_202007250

KeyboardInterrupt: 

## Checking if the number of images in the dataset matches the number of processed images

In [None]:
# Number of entries in the dataset folder.
num_images = len(os.listdir(dataset_path))
print(f"Number of images read in dataset: {num_images}")

# ------------------------------------------------------

# save_path also holds the processed_images.txt progress file, so counting
# every entry overstates the result. Count only the generated mask PNGs
# (named mask_<image_name>_0.png) and keep the two counts in separate names.
num_masks = sum(
    1
    for entry in os.listdir(save_path)
    if entry.startswith("mask_") and entry.endswith(".png")
)
print(f"Number of images read in processed images: {num_masks}")

Number of images read in dataset: 27649
Number of images read in processed images: 25617
