# SAM segmentation
Image segmentation using Segment Anything Model (SAM)

In [1]:
# Source dataset: name of the zip archive (without extension) that is unpacked
# under /content, and the sub-directory inside it whose images are segmented.
D_NAME_SRC = 'SASVAR_set'
SUB_DIRECTORY = 'NuevosProductos_crop'

# Optional dataset of previously segmented images. When set, images whose
# names already appear there are skipped and new results are written into it.
# Leave as None to process every source image.
D_NAME_TARGET = None  # e.g. 'set-green'

# Preparation

In [2]:
import os
import sys
from pathlib import Path


import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


In [3]:
# Show the GPU attached to this runtime (driver / CUDA version, memory usage).
!nvidia-smi

Wed Jun 14 19:13:55 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   48C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Installation

In [4]:
# Working directory of the kernel; the weights folder is created below it.
HOME = os.getcwd()
print(f"HOME: {HOME}")

HOME: /content


In [5]:
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-errokiun
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-errokiun
  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment-anything
  Building wheel for segment-anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment-anything: filename=segment_anything-1.0-py3-none-any.whl size=36589 sha256=67215a52d1f65f139108b6e3b9166b46676e68366fec9a5d9a986e90786a5a36
  Stored in directory: /tmp/pip-ephem-wheel-cache-gjocz7hl/wheels/10/cf/59/9ccb2f0a1bcc81d4fbd0e501680b5d088d690c6cfbc02dc99

In [6]:
# Get pretrained weights
!mkdir {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, ". Exist:", os.path.isfile(CHECKPOINT_PATH))

/content/weights/sam_vit_h_4b8939.pth . Exist: True


## Load Model

In [7]:
# Which entry of the SAM model registry to build.
MODEL_TYPE = "vit_h"

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [8]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Look up the registry entry for MODEL_TYPE and call it with the downloaded
# checkpoint, then move the resulting model to the selected device.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
sam = sam.to(device=DEVICE)

# Automatic mask generator wrapping the model (constructed with defaults).
mask_generator = SamAutomaticMaskGenerator(sam)

# Segment

`SamAutomaticMaskGenerator` returns a `list` of masks, where each mask is a `dict` containing various information about the mask:

* `segmentation` - `[np.ndarray]` - the mask with `(H, W)` shape, and `bool` type
* `area` - `[int]` - the area of the mask in pixels
* `bbox` - `[List[int]]` - the bounding box of the mask in `xywh` format
* `predicted_iou` - `[float]` - the model's own prediction for the quality of the mask
* `point_coords` - `[List[List[float]]]` - the sampled input point that generated this mask
* `stability_score` - `[float]` - an additional measure of mask quality
* `crop_box` - `[List[int]]` - the crop of the image used to generate this mask in `xywh` format

In [9]:
def blend(im1, im2, alpha):
    """Linearly mix two inputs.

    Returns ``im1`` weighted by ``1 - alpha`` plus ``im2`` weighted by
    ``alpha``; ``alpha`` may be a scalar or anything that combines with
    the inputs under ``*`` and ``+``.
    """
    keep = 1-alpha
    return keep*im1 + alpha*im2

def assemble_png(im, sam_mask):
    """Build a 4-channel image from ``im`` and a boolean SAM mask.

    Pixels selected by ``sam_mask`` keep the colour of ``im`` and get an
    alpha of 255; all other pixels are painted white (255) with alpha 0.
    The colour channels stay in the order of ``im``. The result is a float
    array with the alpha plane stacked behind the colour planes.
    """
    # Opacity plane: 255 inside the mask, 0 elsewhere.
    alpha = np.zeros(sam_mask.shape)
    alpha[sam_mask] = 255

    # Per-channel blend weights in [0, 1], one copy of the plane per colour.
    weights = np.dstack([alpha] * 3) / 255

    # Replace the background with white.
    white = np.full(im.shape, 255.0)
    foreground = blend(white, im, weights)

    return np.dstack((foreground, alpha))

def create_out_dir(path, tag='out'):
    """Create a sibling directory named ``<last component of path>_<tag>``.

    Args:
        path: Directory path (``str`` or ``os.PathLike``). A trailing path
            separator is ignored, so ``'a/b/'`` is treated like ``'a/b'``.
        tag: Suffix appended to the directory name after an underscore.

    Returns:
        The path of the output directory as a string. The directory is
        created if it does not exist, and its location is printed.
    """
    # Drop trailing separators so the last component is never empty, and let
    # os.path do the splitting so a leading root ('/foo') is preserved.
    trimmed = os.fspath(path).rstrip(os.sep)
    parent, name = os.path.split(trimmed)

    output_dir = os.path.join(parent, f'{name}_{tag}')
    os.makedirs(output_dir, exist_ok=True)
    print(f'    - Output directory: {output_dir}')

    return output_dir

In [10]:
# Unzip source dataset
# The archive is read from Google Drive. Extraction is skipped when
# /content/<D_NAME_SRC> already exists, so re-running the cell is cheap.
# NOTE(review): no Drive mount is visible in this notebook - presumably Drive
# is mounted beforehand; confirm.
# NOTE(review): unzip extracts into the current working directory; this assumes
# the cwd is /content and the archive has a top-level <D_NAME_SRC> folder - TODO confirm.
drive_path = f'/content/drive/MyDrive/SASVAR/waste_datasets/{D_NAME_SRC}.zip'
local_path = f'/content/{D_NAME_SRC}'
if not os.path.exists(local_path):
    !unzip -qq {drive_path}

In [11]:
# Unzip target dataset
# Contains previously labeled images. Only runs when D_NAME_TARGET is set;
# otherwise drive_path_target / local_path_target are never defined.
# Extraction is skipped when /content/<D_NAME_TARGET> already exists.
# NOTE(review): as with the source archive, this assumes the cwd is /content and
# the archive has a top-level <D_NAME_TARGET> folder - TODO confirm.
if D_NAME_TARGET:
    drive_path_target = f'/content/drive/MyDrive/{D_NAME_TARGET}.zip'
    local_path_target = f'/content/{D_NAME_TARGET}'
    if not os.path.exists(local_path_target):
        !unzip -qq {drive_path_target}

In [12]:
# Check already segmented images
sub_path = f'{local_path}/{SUB_DIRECTORY}'

# Recursive listing; keep regular files only ('**/*' also yields directories).
src_files = [f for f in Path(sub_path).glob('**/*') if f.is_file()]

# Already segmented files
if D_NAME_TARGET:
    # Segmented results are written as '<name before the first dot>.png', so
    # source and target are matched on that key rather than on the full file
    # name; otherwise a source 'x.jpg' would never match its result 'x.png'.
    src_filenames = {f.name.split('.')[0] for f in src_files}
    trg_filenames = {
        f.name.split('.')[0]
        for f in Path(local_path_target).glob('**/*')
        if f.is_file()
    }

    print(f'{len(src_filenames)} source files | {len(trg_filenames)} target files')

    # Files to segment: sources with no matching result yet (set lookup).
    files = [f for f in src_files if f.name.split('.')[0] not in trg_filenames]
else:
    files = src_files

print(f'{len(files)} images to segment')

261 images to segment


In [13]:
# Results go into the target dataset when one is configured, otherwise into a
# sibling '<sub_path>_out' directory.
if D_NAME_TARGET:
    output_dir = local_path_target
else:
    output_dir = create_out_dir(sub_path)


err = []          # paths of the images that could not be processed
err_reasons = {}  # path -> short description of what went wrong

for filepath in tqdm(files, total=len(files)):
    try:
        image_bgr = cv2.imread(str(filepath))
        # cv2.imread signals failure by returning None instead of raising.
        if image_bgr is None:
            raise ValueError('cv2.imread could not read the file as an image')

        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        sam_result = mask_generator.generate(image_rgb)

        # Sort image segments by area, largest first.
        # The largest segment is taken to be the background and the second
        # largest to be the target element.
        sam_result = sorted(sam_result, key=lambda item: item['area'], reverse=True)
        if len(sam_result) < 2:
            raise ValueError(f'only {len(sam_result)} segment(s) found, need at least 2')
        sam_mask = sam_result[1]['segmentation']

        # Assemble PNG file with segment. The BGR image is passed on purpose:
        # cv2.imwrite expects BGR(A) channel order.
        out = assemble_png(image_bgr, sam_mask)
        # assemble_png works in floating point; convert explicitly to 8-bit
        # instead of relying on the writer's implicit conversion.
        out = np.clip(np.rint(out), 0, 255).astype(np.uint8)

        out_filename = f'{filepath.name.split(".")[0]}.png'
        out_filepath = os.path.join(output_dir, out_filename)
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(out_filepath, out):
            raise OSError(f'cv2.imwrite could not write {out_filepath}')

    except Exception as e:
        # Best effort: remember the failure and carry on with the next image.
        err.append(str(filepath))
        err_reasons[str(filepath)] = f'{type(e).__name__}: {e}'

print(f'{len(err)} Errors on images: {err}')
for failed_path, reason in err_reasons.items():
    print(f'    - {failed_path}: {reason}')


    - Output directory: /content/SASVAR_set/NuevosProductos_crop_out


100%|██████████| 261/261 [29:28<00:00,  6.78s/it]

0 Errors on images: []





# Compress and save

In [14]:
# Recursively zip the output directory into '<D_NAME_TARGET>_out.zip' in the
# current working directory.
# NOTE(review): D_NAME_TARGET is None in the configuration cell, so the archive
# presumably ends up named 'None_out.zip'. The copy cell below uses the same
# name, so any rename has to be made in both cells together.
!zip -r {D_NAME_TARGET}_out.zip {output_dir}

  adding: content/SASVAR_set/NuevosProductos_crop_out/ (stored 0%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_408_0_130623024607_0.png (deflated 8%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_145_0_130623003753_0.png (deflated 5%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_439_0_130623021830_2.png (deflated 9%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_693_0_130623012538_0.png (deflated 6%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_145_0_130623003753_1.png (deflated 7%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_648_0_130623023747_0.png (deflated 3%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_651_0_130623025514_2.png (deflated 4%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_281_0_130623015915_0.png (deflated 4%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_683_0_130623015053_0.png (deflated 7%)
  adding: content/SASVAR_set/NuevosProductos_crop_out/O_646_0_130623022141_1.p

In [15]:
# Copy the archive created by the zip cell above to the root of Google Drive.
# NOTE(review): with D_NAME_TARGET = None this presumably copies 'None_out.zip'
# and overwrites any earlier file of that name on Drive; the name must stay in
# sync with the zip cell.
!cp {D_NAME_TARGET}_out.zip /content/drive/MyDrive/