### Sending a request to the TorchServe encoder service for the Segment Anything Model (SAM)

When running the services locally, make sure to select the correct port: 70* for the CPU service, 80* for the GPU service.

Note that the GPU service doesn't serve the decoder model, since the decoder can be run on the CPU service.

The CPU service supports both the encoder (slow) and the decoder (fast).
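
As a minimal sketch of the port convention above, the helper below builds prediction URLs from a service name and a model name. The helper `prediction_url` and the `INFERENCE_PORTS` mapping are hypothetical names introduced here for illustration; the ports 7080 and 8080 and the model names match the endpoints used later in this notebook.

```python
# Assumed mapping, following the 70*/80* convention described above.
INFERENCE_PORTS = {"cpu": 7080, "gpu": 8080}

def prediction_url(service: str, model: str, host: str = "127.0.0.1") -> str:
    """Build a prediction URL for the given service ('cpu' or 'gpu') and model."""
    if service == "gpu" and model.endswith("_decode"):
        # The GPU service does not serve the decoder, so fail early.
        raise ValueError("the decoder is only served by the CPU service")
    return f"http://{host}:{INFERENCE_PORTS[service]}/predictions/{model}"

encode_url = prediction_url("gpu", "sam_vit_h_encode")
decode_url = prediction_url("cpu", "sam_vit_h_decode")
print(encode_url)
print(decode_url)
```

Swapping `"gpu"` for `"cpu"` in the first call points the encoder request at the CPU service instead.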

In [None]:
import httpx
import base64
from PIL import Image
from PIL.ImageOps import autocontrast
from io import BytesIO
import numpy as np
import matplotlib.pyplot as plt
import json

# default localhost endpoints after starting both containers, see README
# encode_url="http://127.0.0.1:7080/predictions/sam_vit_h_encode"
encode_url="http://127.0.0.1:8080/predictions/sam_vit_h_encode"

pth_slick = "../data/tile_with_slick_512_512.png"
input_point_not_on_slick = (10, 120)
input_point_on_slick = (6, 120)
input_label = 1

We'll run SAM on a small subset of a Sentinel-1 image that captured an oil slick left on the ocean by a shipping vessel.

In [None]:
img_slick = Image.open(pth_slick)
autocontrast(img_slick, cutoff=0, ignore=None, mask=None, preserve_tone=False)

Read the image file as bytes, then base64-encode the bytes into a string so the image can be sent in the JSON body of a POST request.

In [None]:
with open(pth_slick, 'rb') as f:
    byte_string = f.read()
    base64_string = base64.b64encode(byte_string).decode('utf-8')

payload = {"encoded_image": base64_string}
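
To check that nothing is lost in this encoding step, the sketch below round-trips some bytes through base64 and JSON. The helper `build_encode_payload` is a hypothetical name, and the sample bytes are a stand-in for a real PNG file rather than a valid image.

```python
import base64
import json

def build_encode_payload(image_bytes: bytes) -> dict:
    """Wrap raw image bytes in the JSON-friendly payload used by the encoder request."""
    return {"encoded_image": base64.b64encode(image_bytes).decode("utf-8")}

# Stand-in bytes (PNG signature plus filler), not a real image.
sample_bytes = b"\x89PNG\r\n\x1a\n" + bytes(range(16))
sample_payload = build_encode_payload(sample_bytes)

# Serialize to JSON, as the HTTP client would, then decode on the "receiving" side.
body = json.dumps(sample_payload)
recovered = base64.b64decode(json.loads(body)["encoded_image"])
print(recovered == sample_bytes)
```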

Let's run the image encoder locally. Use the CPU endpoint if you don't have a GPU. Timings depend on the hardware: on a 1080 Ti GPU, encoding the image takes about 2 seconds, while on the CPU the unoptimized model takes over a minute.

In [None]:
%%time
try:
    response = httpx.post(encode_url, json=payload, timeout=None)
except (BrokenPipeError, httpx.RemoteProtocolError, ConnectionResetError) as e:
    # The service may still be starting up; rerun this cell after a short wait.
    print(f"wait and try again: {e!r}")
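
Instead of rerunning the cell by hand, the "wait and try again" step could be automated. The following is a minimal sketch, not part of the services themselves: `call_with_retry` is a hypothetical helper, and the attempt count and delay are arbitrary assumptions.

```python
import time

def call_with_retry(func, retry_on, attempts=3, delay_s=5.0):
    """Call func(); on one of the retry_on exceptions, wait and try again."""
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return func()
        except retry_on as err:
            last_error = err
            print(f"attempt {attempt} failed: {err!r}")
            if attempt < attempts:
                time.sleep(delay_s)
    raise RuntimeError(f"no response after {attempts} attempts") from last_error

# Possible usage with the request from the cell above (left commented out so
# that this sketch runs without a live service):
# response = call_with_retry(
#     lambda: httpx.post(encode_url, json=payload, timeout=None),
#     (BrokenPipeError, httpx.RemoteProtocolError, ConnectionResetError),
# )
```

Raising after the final attempt, rather than only printing, keeps later cells from running against an undefined `response`.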

In [None]:
response

Access the image embeddings for the oil slick scene and convert them to a NumPy array. The image embeddings represent the features of the image, from which we can produce mask predictions.

In [None]:
encoded_embedding_string = response.json()['image_embedding']
base64_bytes = base64.b64decode(encoded_embedding_string)
image_embedding = np.frombuffer(base64_bytes, dtype=np.float32)
image_embedding
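
The array above is flat, because `np.frombuffer` has no shape information. SAM's image encoder produces an embedding of shape `(1, 256, 64, 64)`, so the flat array can be checked against that shape and reshaped. The sketch below assumes the service serializes the values in that order; `decode_embedding` is a hypothetical helper, and the fake embedding stands in for a real response.

```python
import base64
import numpy as np

# Shape of SAM's image embedding; assumed to match the service's byte layout.
EMBEDDING_SHAPE = (1, 256, 64, 64)

def decode_embedding(encoded: str, shape=EMBEDDING_SHAPE) -> np.ndarray:
    """Decode a base64 string of float32 values and reshape it, checking the size."""
    flat = np.frombuffer(base64.b64decode(encoded), dtype=np.float32)
    expected = int(np.prod(shape))
    if flat.size != expected:
        raise ValueError(f"expected {expected} float32 values, got {flat.size}")
    return flat.reshape(shape)

# Synthetic embedding, encoded the same way the response is decoded above.
fake = np.arange(np.prod(EMBEDDING_SHAPE), dtype=np.float32).reshape(EMBEDDING_SHAPE)
fake_string = base64.b64encode(fake.tobytes()).decode("utf-8")
decoded = decode_embedding(fake_string)
print(decoded.shape)
```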

Next, we send the image embeddings to the decoder service.

In [None]:
img_shape = np.array(img_slick).shape
decode_payload = {
    "image_embeddings": encoded_embedding_string,
    "image_shape": img_shape,
    "input_label": input_label,
    "input_point": input_point_on_slick
}

In [None]:
%%time
decode_url="http://127.0.0.1:7080/predictions/sam_vit_h_decode" # make sure to select correct port. 70* for cpu, 80* for gpu
response = httpx.post(decode_url, json=decode_payload, timeout=None)

In [None]:
response

In [None]:
encoded_masks_string = response.json()['masks']
base64_bytes_masks = base64.b64decode(encoded_masks_string)
masks = np.frombuffer(base64_bytes_masks, dtype=bool)

Four masks are predicted for the single point prompt, each with its own confidence score. With minimal prompting it can be ambiguous which object is desired, so SAM tries to predict several valid masks. See the SAM paper https://arxiv.org/pdf/2304.02643.pdf for details.

In [None]:
masks = masks.reshape((1,4,512, 512))
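
The reshape above hardcodes the number of masks and the tile size. As a sketch under the assumption that the response contains one boolean byte per pixel per mask, the mask count can instead be derived from the buffer length. `decode_masks` is a hypothetical helper, and the small synthetic masks stand in for a real response.

```python
import base64
import numpy as np

def decode_masks(encoded: str, height: int, width: int) -> np.ndarray:
    """Decode a base64 string of boolean masks into shape (1, n_masks, height, width)."""
    flat = np.frombuffer(base64.b64decode(encoded), dtype=bool)
    per_mask = height * width
    if flat.size % per_mask != 0:
        raise ValueError(f"{flat.size} values do not divide into {height}x{width} masks")
    return flat.reshape((1, flat.size // per_mask, height, width))

# Synthetic stand-in: four 8x8 masks, only the second one has any pixels set.
fake_masks = np.zeros((1, 4, 8, 8), dtype=bool)
fake_masks[0, 1, 2:5, 2:5] = True
fake_string = base64.b64encode(fake_masks.tobytes()).decode("utf-8")

decoded_masks = decode_masks(fake_string, 8, 8)
areas = decoded_masks.sum(axis=(2, 3))[0]  # pixel count per mask
print(decoded_masks.shape, areas)
```

Comparing the per-mask pixel counts is a quick way to see how the candidate masks differ before plotting them.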

Setting up plotting functions

In [None]:
def show_mask(mask, ax):
    color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   

With our encoder and decoder service, we get a solid mask prediction by just supplying a point on the object of interest!

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(np.array(img_slick))
show_mask(masks[0,1,:,:], plt.gca())
input_point_arr = np.array(input_point_on_slick)[np.newaxis,:]
input_label_arr = np.array(input_label)[np.newaxis]
show_points(input_point_arr, input_label_arr, plt.gca())
plt.axis('off')
plt.show() 

Next, let's test the geospatial endpoint. For geospatial imagery, a georeferenced mask is more useful than an unreferenced NumPy array, because we can plot the predictions on a map and associate them with other geospatial data.

In [None]:
import rasterio
import io
from skimage import img_as_ubyte
with rasterio.open("../data/sample-georeferenced_burn_scar.tif") as dataset:
    arr = dataset.read()
    bbox = dataset.bounds
    crs = "EPSG:32610"


arr = img_as_ubyte(arr).transpose((1,2,0))

img = Image.fromarray(arr)

# Create byte stream
buffered = io.BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue())
base64_string = img_str.decode('utf-8')
payload = {"encoded_image": base64_string}

In [None]:
%%time
try:
    response = httpx.post(encode_url, json=payload, timeout=None)
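except (BrokenPipeError, httpx.RemoteProtocolError, ConnectionResetError) as e:
    # Same handling as for the slick image: rerun this cell after a short wait.
    print(f"wait and try again: {e!r}")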

In [None]:
encoded_embedding_string = response.json()['image_embedding']
base64_bytes = base64.b64decode(encoded_embedding_string)
image_embedding = np.frombuffer(base64_bytes, dtype=np.float32)
image_embedding

In [None]:
input_point_on_burn = (220,120)

We'll test SAM on a Sentinel-2 image of a burn scar in an agricultural region of the USA. With a single point prompt, SAM does a decent job of delineating the burn scar, but it includes some incorrect pixels in the mask.

In [None]:
plt.figure(figsize=(10,10))
plt.imshow(np.array(img))
input_point_arr = np.array(input_point_on_burn)[np.newaxis,:]
input_label_arr = np.array(input_label)[np.newaxis]
show_points(input_point_arr, input_label_arr, plt.gca())
plt.axis('off')
plt.show() 

This is our source CRS. The decoder service reprojects its outputs to WGS84 regardless of the source CRS, which can only be supplied as an EPSG code.

In [None]:
crs
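
Since the source CRS has to be an EPSG code, it can help to validate the string before sending it. This is a minimal sketch that assumes the service expects the `EPSG:<number>` form used above; `parse_epsg` is a hypothetical helper and not part of the service.

```python
import re

# Assumed format: "EPSG:" followed by a 4 to 6 digit code, e.g. "EPSG:32610".
_EPSG_PATTERN = re.compile(r"^EPSG:(\d{4,6})$", re.IGNORECASE)

def parse_epsg(crs_string: str) -> int:
    """Return the numeric EPSG code, or raise ValueError for any other CRS format."""
    match = _EPSG_PATTERN.match(crs_string.strip())
    if match is None:
        raise ValueError(f"expected a CRS like 'EPSG:32610', got {crs_string!r}")
    return int(match.group(1))

epsg_code = parse_epsg("EPSG:32610")
print(epsg_code)
```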

In [None]:
img_shape = img.size
decode_payload = {
    "image_embeddings": encoded_embedding_string,
    "image_shape": img_shape,
    "input_label": input_label,
    "input_point": input_point_on_burn,
    "crs":crs,
    "bbox": list(bbox),
}

In [None]:
%%time
decode_url="http://127.0.0.1:7080/predictions/sam_vit_h_decode" # make sure to select correct port. 70* for cpu, 80* for gpu
response = httpx.post(decode_url, json=decode_payload, timeout=None)

In [None]:
response

In [None]:
geojson_masks = response.json()['geojsons']

In [None]:
type(geojson_masks)

In [None]:
type(geojson_masks[0])

In [None]:
geojson_masks[0][0:1000]

In [None]:
geojson_dict = json.loads(geojson_masks[0])

In [None]:
with open('multi_polygon.geojson', 'w') as f:
    f.write(geojson_masks[3])

In [None]:
len(geojson_masks)

Check out the GeoJSON result in your favorite GIS! As with the slick image, we get 4 masks, each represented as a MultiPolygon.
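
Before opening a GIS, a quick sanity check is to count the polygons in each returned GeoJSON string. The sketch below is hedged: the exact structure of the service's output is not shown here, so `count_polygons` (a hypothetical helper) accepts a bare geometry, a Feature, or a FeatureCollection, and the sample string is synthetic.

```python
import json

def count_polygons(geojson_text: str) -> int:
    """Count the polygons in a GeoJSON string (geometry, Feature, or FeatureCollection)."""
    def geometries(node):
        kind = node.get("type")
        if kind == "FeatureCollection":
            for feature in node.get("features", []):
                yield from geometries(feature)
        elif kind == "Feature":
            if node.get("geometry"):
                yield node["geometry"]
        else:
            yield node

    total = 0
    for geom in geometries(json.loads(geojson_text)):
        if geom.get("type") == "MultiPolygon":
            total += len(geom["coordinates"])
        elif geom.get("type") == "Polygon":
            total += 1
    return total

# Synthetic MultiPolygon with two small squares, standing in for a service response.
square_a = [[[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]]]
square_b = [[[2, 2], [3, 2], [3, 3], [2, 3], [2, 2]]]
sample = json.dumps({"type": "MultiPolygon", "coordinates": [square_a, square_b]})
print(count_polygons(sample))
```

With a real response, something like `[count_polygons(g) for g in geojson_masks]` would give one count per mask.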