# Zero-Shot Image Segmentation For Segmented SLAM

## Import libraries

In [1]:
import os
import time

import cv2
import numpy as np
import torch
import rasterio
from matplotlib import pyplot as plt
import samgeo as samsam
from samgeo.text_sam import LangSAM
from PIL import Image

In [2]:
import tensorflow as tf

# Check if GPU is available and restrict TensorFlow to the first one.
# The device list is queried once and reused; indexing it unconditionally
# would raise IndexError on a CPU-only machine.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if gpus:
    tf.config.set_visible_devices(gpus[0], 'GPU')


2024-04-30 18:15:42.000785: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-30 18:15:42.094837: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Num GPUs Available:  2


## Select Image and Path

## Initialize LangSAM class

In [3]:
# Create the LangSAM instance once at module level; this `sam` object is what
# the rest of the file passes as `self` to the standalone predict() and
# show_anns() functions.
sam = LangSAM()


final text_encoder_type: bert-base-uncased


## Custom Function Implementation of `predict`

This works around a bug in the native `predict` function, where the object does not update its mask when zero objects are detected.

In [4]:
def predict(
    self,
    image,
    text_prompt,
    box_threshold,
    text_threshold,
    output=None,
    mask_multiplier=255,
    dtype=np.uint8,
    save_args={},
    return_results=False,
    return_coords=False,
    **kwargs,
):
    """
    Run both GroundingDINO and SAM model prediction.

    The stored results (``self.masks``, ``self.boxes``, ``self.phrases``,
    ``self.logits`` and ``self.prediction``) are refreshed on every call, also
    when nothing is detected; in that case ``self.prediction`` is an all-zero
    mask and ``self.masks`` is an empty tensor.

    Parameters:
        image (Image | str): Input PIL Image, or a path/URL to a raster file
            that is opened with rasterio.
        text_prompt (str): Text prompt for the model.
        box_threshold (float): Box threshold for the prediction.
        text_threshold (float): Text threshold for the prediction.
        output (str, optional): Output path (without extension) for the binary
            mask; ``'.png'`` is appended. Defaults to None (nothing is saved).
        mask_multiplier (int, optional): Value written where any mask is set.
            Defaults to 255.
        dtype (np.dtype, optional): Data type of the stored mask. Defaults to
            np.uint8.
        save_args (dict, optional): Accepted for interface compatibility; it is
            neither read nor modified by this implementation. Defaults to {}.
        return_results (bool, optional): Whether to return the results.
            Defaults to False.
        return_coords (bool, optional): Whether to return the first two
            coordinates of every box. Only consulted when ``return_results``
            is False. Defaults to False.
        **kwargs: Accepted and ignored.

    Returns:
        tuple | list | None: ``(masks, boxes, phrases, logits)`` when
        ``return_results`` is True; otherwise a list of ``(box[0], box[1])``
        tuples when ``return_coords`` is True; otherwise None.

    Raises:
        ValueError: If ``image`` is a string that is not an existing path.
    """

    if isinstance(image, str):
        if image.startswith("http"):
            # Imported here because nothing at file level brings
            # `download_file` into scope.
            from samgeo.common import download_file

            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        self.source = image

        # Load the raster and move the band axis last: (bands, H, W) -> (H, W, bands)
        with rasterio.open(image) as src:
            image_np = src.read().transpose((1, 2, 0))
        # Keep only the first three bands (drops e.g. an alpha channel)
        image_pil = Image.fromarray(image_np[:, :, :3])
    else:
        image_pil = image
        image_np = np.array(image_pil)

    self.image = image_pil

    boxes, logits, phrases = self.predict_dino(
        image_pil, text_prompt, box_threshold, text_threshold
    )

    # Single-channel canvas; shape[:2] also works for 2-D (grayscale) input.
    canvas_shape = image_np.shape[:2]
    mask_overlay = np.zeros(canvas_shape, dtype=dtype)
    masks = torch.tensor([])

    if len(boxes) == 0:  # No "object" instances found
        print("No "+text_prompt+" found in the image.")
    else:
        masks = self.predict_sam(image_pil, boxes)
        masks = masks.squeeze(1)

        # Union of all instance masks. A boolean OR is used instead of a
        # running integer sum so that narrow dtypes such as uint8 cannot wrap
        # around and silently erase pixels.
        detected = np.zeros(canvas_shape, dtype=bool)
        for mask in masks:
            if isinstance(mask, torch.Tensor):
                # .cpu() first in case the mask lives on the GPU
                mask = mask.cpu().numpy().astype(dtype)
            detected |= mask > 0

        # Binary mask in {0, mask_multiplier}, cast back to the requested
        # dtype so it can be handed to PIL and matches the empty case.
        mask_overlay = (detected * mask_multiplier).astype(dtype)

    if output is not None:
        Image.fromarray(mask_overlay).save(output+'.png')

    self.masks = masks
    self.boxes = boxes
    self.phrases = phrases
    self.logits = logits
    self.prediction = mask_overlay

    if return_results:
        return masks, boxes, phrases, logits

    if return_coords:
        boxlist = []
        for box in self.boxes:
            box = box.cpu().numpy()
            boxlist.append((box[0], box[1]))
        return boxlist

In [5]:
def show_anns(
    self,
    figsize=(12, 10),
    axis="off",
    cmap="viridis",
    alpha=0.4,
    add_boxes=True,
    box_color="r",
    box_linewidth=1,
    title=None,
    output=None,
    blend=True,
    **kwargs,
):
    """Draw the stored prediction mask over the input image and optionally save it.

    Reads ``self.prediction``, ``self.image`` and ``self.boxes``. The figure
    is closed before returning and is never displayed, so the only visible
    result is the file written to ``output``.

    Args:
        figsize (tuple, optional): The figure size. Defaults to (12, 10).
        axis (str, optional): Whether to show the axis. Defaults to "off".
        cmap (str, optional): The colormap for the annotations. Defaults to "viridis".
        alpha (float, optional): The alpha value for the annotations. Defaults to 0.4.
        add_boxes (bool, optional): Whether to show the bounding boxes. Defaults to True.
        box_color (str, optional): The color for the bounding boxes. Defaults to "r".
        box_linewidth (int, optional): The line width for the bounding boxes. Defaults to 1.
        title (str, optional): The title for the image. Defaults to None.
        output (str, optional): The path to the output image. Defaults to None.
        blend (bool, optional): If True the rendered figure (image plus mask)
            is saved; if False the raw prediction is written through
            ``self.array_to_image``. Defaults to True.
        kwargs (dict, optional): Additional arguments for
            matplotlib.pyplot.savefig(). ``dpi`` defaults to 100 and
            ``bbox_inches`` to "tight" when not supplied.

    Returns:
        None
    """

    anns = self.prediction

    if anns is None:
        print("Please run predict() first.")
        return
    elif len(anns) == 0:
        print("No objects found in the image.")
        return

    # Imported after the guards so the early exits do not need matplotlib.
    import warnings
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches

    kwargs.setdefault("dpi", 100)
    kwargs.setdefault("bbox_inches", "tight")

    # Silence warnings only while drawing; the process-wide filters are
    # restored when the context exits.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        fig = plt.figure(figsize=figsize)
        try:
            plt.imshow(self.image)

            if add_boxes:
                axes = plt.gca()
                for box in self.boxes:
                    # Boxes are indexed as (x0, y0, x1, y1) below
                    box = box.cpu().numpy()  # Convert the tensor to a numpy array
                    rect = patches.Rectangle(
                        (box[0], box[1]),
                        box[2] - box[0],
                        box[3] - box[1],
                        linewidth=box_linewidth,
                        edgecolor=box_color,
                        facecolor="none",
                    )
                    axes.add_patch(rect)

            plt.imshow(anns, cmap=cmap, alpha=alpha)

            if title is not None:
                plt.title(title)
            plt.axis(axis)

            if output is not None:
                if blend:
                    plt.savefig(output, **kwargs)
                else:
                    # NOTE(review): relies on `self` providing array_to_image()
                    # and a `source` attribute - confirm for the object passed in.
                    self.array_to_image(self.prediction, output, self.source)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

## Custom implementation of `change_color`

Recolors a NumPy image array in place: every pure-white pixel (255, 255, 255) is replaced with the RGB value given by the input tuple `tup`.

In [6]:
def change_color(tup, data):
    """Recolor pure-white pixels of `data` in place.

    Args:
        tup: Sequence whose first three items are the replacement R, G, B values.
        data: Array indexed as [row, column, channel] with at least three
            channels; it is modified in place.

    Returns:
        None
    """
    rgb = data[:, :, :3]  # view onto the colour channels, so writes reach `data`
    is_white = (data[:, :, 0] == 255) & (data[:, :, 1] == 255) & (data[:, :, 2] == 255)
    rgb[is_white] = [tup[0], tup[1], tup[2]]

## Input Classes

In [7]:
# Segmentation targets, stored as one flat list that alternates
# text prompt (even index) and the RGB colour for that prompt (odd index).
classes = [
    'wall', (255, 0, 0),
    'chair', (0, 255, 0),
    'floor', (0, 0, 255),
    'human', (150, 150, 150),
]

## Segment the Image

## Segment and Save to `final_mask_path` in drive

In [9]:

import rospy
from sensor_msgs.msg import Image as R_Image
from cv_bridge import CvBridge
import threading

# Network settings for the ROS client, written into the process environment.
os.environ.update(
    {
        'ROS_MASTER_URI': 'http://172.24.16.118:11311',
        'ROS_IP': '172.24.16.118',
        'ROS_HOSTNAME': 'csis-inspire',
    }
)
#os.environ['PYTHONPATH']='/opt/ros/noetic/lib/python3/dist-packages'

# Batch-style settings; not referenced by the code visible in this file.
input_path, ext = 'seg_user', '.jpg'
upper, lower = 500, 1

# Directory prefix used for the temporary per-class mask images.
final_mask_path = "/home/seg_user/deepak_sem2/final-masks/"


class Nodo(object):
    """ROS node that segments camera frames and republishes a colour-coded mask.

    Frames arrive on ``/camera/color/image_raw``. A worker thread runs
    ``predict``/``show_anns`` once per entry in the module-level ``classes``
    list and merges the recoloured masks into ``seg_image``, which the main
    loop displays and publishes on ``image_segmented``.
    """

    def __init__(self):
        # Params
        self.image = None  # latest frame as returned by CvBridge.imgmsg_to_cv2
        self.seg_image = None  # latest combined colour mask (numpy array)
        self.image_PIL = None  # latest frame as a PIL image
        self.wid = 0  # width of the latest frame
        self.hei = 0  # height of the latest frame
        self.br = CvBridge()
        self.ready = False  # True while a frame is waiting to be segmented
        # Node cycle rate (in Hz).
        self.loop_rate = rospy.Rate(10)

        # Publishers
        self.pub = rospy.Publisher('image_segmented', R_Image, queue_size=10)

        # Subscribers
        rospy.Subscriber("/camera/color/image_raw", R_Image, self.callback)

    def callback(self, msg):
        """Store the incoming camera message as a numpy array and a PIL image."""
        frame = self.br.imgmsg_to_cv2(msg)
        frame_pil = Image.fromarray(frame)
        self.wid, self.hei = frame_pil.size
        self.image_PIL = frame_pil
        # Assigned last: start() uses `self.image is not None` as its signal,
        # so the PIL copy must already be in place by then.
        self.image = frame

    def _segment_frame(self, frame):
        """Return the combined colour mask (uint8, RGB) for one PIL frame."""
        width = frame.size[0]
        layers = []
        for label, colour in zip(classes[::2], classes[1::2]):
            fmp = final_mask_path + 'temp-' + label + '.png'
            predict(sam, frame, label, box_threshold=0.35, text_threshold=1)  # Higher threshold indicates greater confidence in prediction
            show_anns(
                sam,
                cmap="Greys_r",
                add_boxes=False,
                alpha=1,
                output=fmp,
                blend=True,
                bbox_inches='tight',
                pad_inches=0,
            )
            try:
                with Image.open(fmp) as rendered:
                    # Scale back to the input width, keeping the aspect ratio.
                    scale = width / float(rendered.size[0])
                    new_size = (width, int(float(rendered.size[1]) * scale))
                    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
                    # same filter under its current name.
                    resized = rendered.resize(new_size, Image.LANCZOS)
            finally:
                # The rendered figure is only a temporary hand-off file.
                if os.path.exists(fmp):
                    os.remove(fmp)

            layer = np.array(resized.convert('RGB'))
            change_color(colour, layer)
            layers.append(layer)

        # Add the layers in a wider dtype and saturate at 255, so overlapping
        # masks cannot wrap around in uint8.
        total = np.sum(layers, axis=0, dtype=np.uint16)
        return np.clip(total, 0, 255).astype(np.uint8)

    def prediction(self):
        """Worker loop: segment the latest frame whenever ``ready`` is set."""
        # NOTE(review): predict() drives the LangSAM models; whether this
        # TensorFlow device scope affects them is not visible here - confirm.
        with tf.device('/device:GPU:0'):
            while not rospy.is_shutdown():
                # Snapshot the frame so the callback cannot swap it mid-run.
                frame = self.image_PIL
                if not self.ready or frame is None:
                    # Yield instead of spinning at full speed while idle.
                    time.sleep(0.01)
                    continue

                self.seg_image = self._segment_frame(frame)
                self.ready = False

    def start(self):
        """Start the worker thread, then display and publish until shutdown."""
        thx = threading.Thread(target=self.prediction)
        thx.start()
        while not rospy.is_shutdown():
            self.loop_rate.sleep()
            frame = self.image
            if frame is not None:
                self.ready = True
                cv2.imshow("input", frame)
            segmented = self.seg_image
            if segmented is not None:
                cv2.imshow("output", segmented)
                self.pub.publish(self.br.cv2_to_imgmsg(segmented))
            cv2.waitKey(10)
def main():
    """Register the anonymous "image_sam2" ROS node and run a Nodo until shutdown."""
    rospy.init_node("image_sam2", anonymous=True)
    Nodo().start()
# Run the node when executed as a script. The cleanup sits in `finally` so the
# OpenCV windows are closed and cached CUDA memory is released even when
# main() exits with an exception (for example on interrupt).
try:
    if __name__ == "__main__":
        main()
finally:
    cv2.destroyAllWindows()
    torch.cuda.empty_cache()


2024-04-30 18:16:48.459189: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1928] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 34861 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:17:00.0, compute capability: 8.0


No chair found in the image.
No chair found in the image.
No human found in the image.
No chair found in the image.
No chair found in the image.
No human found in the image.
No chair found in the image.
No floor found in the image.
No floor found in the image.
No wall found in the image.
No wall found in the image.
No floor found in the image.
No floor found in the image.
No wall found in the image.
No floor found in the image.
No wall found in the image.
No floor found in the image.
No human found in the image.
No wall found in the image.
No floor found in the image.
No human found in the image.
No floor found in the image.
No human found in the image.
No wall found in the image.
No floor found in the image.
No wall found in the image.
No chair found in the image.
No human found in the image.
No chair found in the image.
No human found in the image.
No chair found in the image.
No human found in the image.
No chair found in the image.
No chair found in the image.
No human found in the

In [None]:
# import rospy
# from sensor_msgs.msg import Image

# def image_segmented_callback(data):
#     # This function will be called whenever a message is received on the /image_segmented topic
#     # You can process the segmented image data here
#     # For demonstration, let's print out the height and width of the image
#     height = data.height
#     width = data.width
#     rospy.loginfo("Received segmented image with height: {} and width: {}".format(height, width))
#     two_d_array = [[0 for j in range(width)] for i in range(height)]
#     for i in range (0,width):
#         for j in range (0,height):
#             two_d_array[i][j]=get_pixel_rgb(i,j)
#     return two_d_array

# def point_cloud_subscriber():
#     # Initialize the ROS node
#     rospy.init_node('image_segmented_subscriber', anonymous=True)

#     # Subscribe to the /image_segmented topic
#     rospy.Subscriber("/camera/color/image_raw", Image, image_segmented_callback)

#     # Spin() simply keeps python from exiting until this node is stopped
#     rospy.spin()

# if __name__ == '__main__':
#     try:
#         point_cloud_subscriber()
#     except rospy.ROSInterruptException:
#         pass


In [None]:
# import rospy
# from sensor_msgs.msg import PointCloud2
# import sensor_msgs.point_cloud2 as pc2
# import numpy as np

# # Callback function to process incoming point cloud data
# def point_cloud_callback(msg):
#     # Parse point cloud data
#     points = []
#     for point in pc2.read_points(msg, field_names=("x", "y", "z", "rgb"), skip_nans=True):
#         points.append(point)

#     # Example: Given x and y coordinates, change RGB color values
#     target_x = 1.0  # Example x-coordinate
#     target_y = 2.0  # Example y-coordinate
#     for i, point in enumerate(points):
#         x, y, z, rgb = point
#         if x == target_x and y == target_y:
#             # Modify RGB color values (e.g., set to red)
#             rgb = (255, 0, 0)
#             # Update the point in the list
#             points[i] = (x, y, z, rgb)
#             break

#     # Publish the modified point cloud data
#     modified_msg = pc2.create_cloud_xyzrgb(msg.header, points)
#     pub.publish(modified_msg)

# # Initialize ROS node, subscriber, and publisher
# rospy.init_node("point_cloud_modifier")
# rospy.Subscriber("/input/point_cloud_topic", PointCloud2, point_cloud_callback)
# pub = rospy.Publisher("/output/modified_point_cloud_topic", PointCloud2, queue_size=10)

# # Spin to keep the node running
# rospy.spin()
