[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/edwinRNDR/nlnml/blob/master/NB20_segment_images.ipynb)


# Segment images with Segment Anything (SAM)

In [1]:
!pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-p74u6xke
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-p74u6xke
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment-anything
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment-anything: filename=segment_anything-1.0-py3-none-any.whl size=36588 sha256=9bde7a2ddd26030eba505dbab7fb89619dc7db078bdd451f0c739532e9b86b79
  Stored in directory: /tmp/pip-ephem-wheel-cache-35enswc5/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99d
Successfully built segment-anything
Installing collected packages: segment-anything
Successfully 

In [2]:
import os
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from glob import glob
from tqdm import tqdm
from PIL import Image
from os import path
import numpy as np
import cv2

# Connect to Google Drive

In [None]:
# Mount Google Drive at /content/gdrive so the dataset folder configured in the
# next cell is reachable from this runtime.
from google.colab import drive
drive.mount('/content/gdrive')

# This is where my dataset lives


In [None]:
# Root folder of the dataset on the mounted Drive. The segmentation cell reads
# images from its "normalized" subfolder and writes masks to its "masked" subfolder.
dataset_location = '/content/gdrive/MyDrive/nln-dataset'

# Segment images

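This can take several hours. Images whose first mask (`<name>-000.png`) already exists in the `masked` folder are skipped, so an interrupted run can be resumed by re-running the cell.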

In [None]:
# Source images are read from <dataset>/normalized/<label>/ and the per-mask
# PNGs are written to <dataset>/masked/<label>/.
dataset_location_normalized = f"{dataset_location}/normalized"
dataset_location_masked = f"{dataset_location}/masked"

# Checkpoint file fetched by the install cell, and the matching registry key.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"

# Build the model from the checkpoint, move it to the GPU, and wrap it in the
# automatic mask generator.
build_sam = sam_model_registry[model_type]
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(
    sam,
    points_per_batch=16,
    min_mask_region_area=500,
)

os.makedirs(dataset_location_masked, exist_ok=True)

# One sorted list of entries per label folder, then a single flat work list.
label_dirs = sorted(glob(f"{dataset_location_normalized}/*"))
to_process = [sorted(glob(f'{label_dir}/*')) for label_dir in label_dirs]
to_process_flat = [image_path for files in to_process for image_path in files]

def make_mask_image(image, anns, image_file, min_area=500):
    """Write one RGBA PNG per sufficiently large mask of ``image``.

    Each output holds the pixel values of ``image`` with the alpha channel
    set to zero outside the mask. Files are named ``<stem>-<index>.png``
    (index zero-padded to three digits, largest mask first) and placed in
    ``dataset_location_masked/<label>``, where ``<label>`` is the directory
    of ``image_file`` relative to ``dataset_location_normalized``.

    Parameters
    ----------
    image : array
        RGBA image; the calling loop passes the output of
        ``cv2.cvtColor(..., cv2.COLOR_RGB2RGBA)``.
        Assumed to be uint8 with shape (H, W, 4) -- TODO confirm for other callers.
    anns : list of dict
        Mask annotations; each entry is read for its ``'area'`` and
        ``'segmentation'`` keys. ``'segmentation'`` is assumed to be an
        (H, W) boolean mask matching ``image``.
    image_file : str
        Path of the source image, located below ``dataset_location_normalized``.
    min_area : int, optional
        Only annotations whose ``'area'`` is strictly greater than this value
        are written. Defaults to 500.

    Returns
    -------
    None
        Nothing is written when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    # Largest masks first; masks at or below the threshold are dropped.
    kept_anns = sorted(
        (ann for ann in anns if ann['area'] > min_area),
        key=lambda ann: ann['area'],
        reverse=True,
    )

    rp = path.relpath(image_file, dataset_location_normalized)
    label = path.dirname(rp)

    target_label = f"{dataset_location_masked}/{label}"
    os.makedirs(target_label, exist_ok=True)

    stem = path.splitext(path.basename(image_file))[0]

    # Write in descending index order so "-000.png" is produced last. The
    # processing loop treats the existence of "-000.png" as "image already
    # done", so an interrupted run must not leave that file next to an
    # incomplete set of masks.
    for index, ann in reversed(list(enumerate(kept_anns))):
        mask = ann['segmentation']

        # astype() returns a fresh uint8 copy, so the caller's image is untouched.
        masked = np.asarray(image).astype(np.uint8)
        # Keep the colour channels as-is; make everything outside the mask transparent.
        masked[np.logical_not(mask), 3] = 0

        target = f"{target_label}/{stem}-{index:03d}.png"
        Image.fromarray(masked).save(target)

# Segment every image that has not been processed yet and write its masks.
for image_file in tqdm(to_process_flat):

    # Mirror the <label>/<name> layout of the normalized folder under the masked folder.
    rp = path.relpath(image_file, dataset_location_normalized)
    label = path.dirname(rp)

    target_label = f"{dataset_location_masked}/{label}"
    target = f"{target_label}/{path.splitext(path.basename(image_file))[0]}-000.png"

    # Resume support: an existing first mask means this image was already handled.
    if path.exists(target):
        continue

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread returns None instead of raising when the entry cannot be
        # decoded (non-image file, sub-directory, corrupt download). Skip it
        # rather than letting cvtColor abort the whole run.
        tqdm.write(f"Skipping unreadable image: {image_file}")
        continue

    # Release cached GPU memory before each segmentation pass.
    torch.cuda.empty_cache()

    # OpenCV loads BGR; the mask generator is given RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    masks = mask_generator.generate(image)
    make_mask_image(cv2.cvtColor(image, cv2.COLOR_RGB2RGBA), masks, image_file)