# SONY A6700 to Segment Anything Model (SAM) Pipeline

This notebook creates a pipeline for image segmentation:
1. Import necessary libraries and SAM
2. Take image with GPhoto2 Open-source camera control library
3. Augment and process image
4. Send image through SAM
5. Display segmented image

Requirements:
* Camera compatible with GPhoto2
* Linux operating system
* GPU for SAM processing


## 1. Import necessary libraries and SAM

In [None]:
# import python libraries
import numpy as np
import torch 
import matplotlib.pyplot as plt
import cv2 
import sys
import os
import time
import gphoto2 as gp
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
sys.path.append("..")

# SAM configuration: checkpoint file on disk and the matching registry key.
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Prefer the GPU, but fall back to the CPU (expect it to be much slower)
# instead of failing in sam.to() when CUDA is not available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from the checkpoint and move it to the selected device.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Mask generator used later to segment the captured image.
mask_generator = SamAutomaticMaskGenerator(sam)


## 2. Take image with GPhoto2 Open-source camera control library

In [None]:
# initialize camera object
camera = gp.Camera()
camera.init()

try:
    print("Capturing image....")
    file_path = camera.capture(gp.GP_CAPTURE_IMAGE)

    # name and folder of the captured file, as reported by the camera
    filename = file_path.name
    folder = file_path.folder

    # define where to store image, local directory in this case
    target_path = os.path.join(os.getcwd(), "test1.jpg")

    # download file from camera to host
    print(f"Saving image to {target_path}")
    camera_file = camera.file_get(folder, filename, gp.GP_FILE_TYPE_NORMAL)
    camera_file.save(target_path)

except gp.GPhoto2Error as ex:
    print(f"An error occured: {ex}")

else:
    # report success only when capture and download both completed;
    # printing this from `finally` would also announce it after a failure
    print("Success")

finally:
    # always release the camera, whether or not the capture succeeded
    camera.exit()


## 3. Augment and process image

In [None]:
# load image into cv2
# cv2.imread does not raise on a missing or unreadable file; it returns None,
# so check explicitly to fail here with a clear message instead of inside
# cvtColor.
image_bgr = cv2.imread(target_path)
if image_bgr is None:
    raise FileNotFoundError(f"Could not read image at {target_path}")

# OpenCV loads pixels in BGR channel order; convert to RGB so matplotlib
# displays the colors correctly. `image` is also what later cells pass to SAM.
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(15,15))
plt.imshow(image)
plt.axis('off')
plt.show()


## 4. Send image through SAM

In [None]:
# Time how long mask generation takes for the loaded image.
start = time.time()
masks = mask_generator.generate(image)
end = time.time()

# Elapsed time is end minus start; the reverse order reports a negative duration.
print(f"SAM took {end - start} seconds to segment.")

## 5. Display segmented image

In [None]:
def show_anns(anns):
    """Overlay every annotation mask on the current matplotlib axes.

    Each entry of ``anns`` is read through its ``'area'`` and
    ``'segmentation'`` keys. Masks are painted from largest to smallest area
    into one RGBA overlay, each with a random color at 0.35 opacity, so
    smaller masks end up on top of larger ones. Nothing is drawn and None is
    returned when ``anns`` is empty.

    NOTE(review): ``'segmentation'`` is presumably a boolean HxW mask, since it
    is used to index the overlay array; confirm against the mask generator.
    """
    if not len(anns):
        return

    largest_first = sorted(anns, key=lambda entry: entry['area'], reverse=True)

    axes = plt.gca()
    # keep the axis limits set by whatever was drawn before the overlay
    axes.set_autoscale_on(False)

    # RGBA canvas sized like the first (largest) mask, fully transparent at start
    mask_shape = largest_first[0]['segmentation'].shape
    overlay = np.ones((mask_shape[0], mask_shape[1], 4))
    overlay[:, :, 3] = 0

    for entry in largest_first:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[entry['segmentation']] = rgba

    axes.imshow(overlay)

# Draw the photo on a 15x15 inch figure, then paint the SAM masks over it.
fig, ax = plt.subplots(figsize=(15, 15))
ax.imshow(image)
show_anns(masks)  # draws on the current axes, i.e. the one created above
ax.axis('off')
plt.show()