## Import dependencies

In [None]:
from osg.utils.general_utils import load_r3d_data, create_r3d_observation_graph, get_spatial_referents, draw_observations_graph
from osg.vlm_library import vlm_library

## Setup

In [None]:
## setup
# Working folder passed to the VLM library, the R3D loader and the graph builder.
# Plain string: the original f-prefix had no placeholders.
tmp_fldr = "results/"
vlm_instance = vlm_library(vl_model="owl_vit", data_src="r3d", seg_model="mobile_sam", tmp_fldr=tmp_fldr)

## load data
data_file = "../data/sample_r3d_data/bdai_kitchen.r3d"
posed_dataset, observation_data, env_pointcloud = load_r3d_data(data_file, tmp_fldr, depth_confidence_cutoff=0.7, pcd_downsample=False)

observations_graph, node_id2key, node_key2id, node_coords = create_r3d_observation_graph(observation_data, tmp_fldr=tmp_fldr)

## draw observations graph (optional: uncomment to visualize)
# draw_observations_graph(observations_graph, node_coords, plt_size=(10,10),axis=False)

In [None]:
# Display the data stored on node 0 of the observation graph (e.g. 'rgb_pil', 'pose', 'intrinsics').
observations_graph.nodes[0]

## Composable Referent Descriptors

In [None]:
## Composable Referent Descriptor (CRD)
    # CRDs are propositional expressions that represent specific referent instances by chaining comparators that encode descriptive spatial information. 
    # For more details see: https://arxiv.org/abs/2402.11498

## CRD Syntax
    # referent_1::isbetween(referent_2,referent_3)  :denotes that referent_1 is between referent_2 and referent_3.
    # referent_1::isabove(referent_2)               :denotes that referent_1 is above referent_2.
    # referent_1::isbelow(referent_2)               :denotes that referent_1 is below referent_2.
    # referent_1::isleftof(referent_2)              :denotes that referent_1 is left of referent_2.
    # referent_1::isrightof(referent_2)             :denotes that referent_1 is right of referent_2.
    # referent_1::isnextto(referent_2)              :denotes that referent_1 is close to referent_2.
    # referent_1::isinfrontof(referent_2)           :denotes that referent_1 is in front of referent_2.
    # referent_1::isbehind(referent_2)              :denotes that referent_1 is behind referent_2.

## Examples
    # Desired referent:   table behind the fridge
    # CRD representation: table::isbehind(fridge) 

    # Desired referent:    chair between the green laptop and the yellow box below the sofa
    # CRD representation:  chair::isbetween(green_laptop,yellow_box::isbelow(sofa))

    # Desired referent:    brown bag between the television and the kettle on the left of the green seat
    # CRD representation:  brown_bag::isbetween(television, kettle::isleftof(green_seat))

## Ground referents and filter instances via spatial constraints

In [None]:
# List the referent names or composable referent descriptors (CRDs) to ground, one string per referent.
# Example: referents_to_ground = ["coffee_machine", "red_cup::isnextto(microwave)"]
referents_to_ground = ["fridge"]

## Extract spatial information (target referents plus their spatial comparators)
referent_spatial_details = get_spatial_referents(referents_to_ground)
print("referent_spatial_details: ", referent_spatial_details, "\n")

In [None]:
## Spatial grounding: ground the referents over the observation graph and filter instances by their spatial constraints
relevant_element_details = vlm_instance.spatial_grounding(
    observations_graph,
    referent_spatial_details,
    visualize=True,
    use_segmentation=True,
    multiprocessing=True,
    workers=3,
)

In [None]:
# Report how many referent instances survived spatial-constraint filtering.
print("\nReferents after spatial constraint filtering:", len(relevant_element_details))
# List the mask/tracking id of every remaining element.
print("Filtered elements \n", [element['mask_id'] for element in relevant_element_details])

### Implementation Grounds

In [None]:
## Plan for a multiprocessing/batched approach:
# 1. Split the observation data into chunks.
# 2. Run OWL-ViT in batches to label the images in each chunk.
# 3. Generate SAM embeddings for each chunk in batches.

In [None]:
def process_node_r3d_batch():
    """Placeholder for a batched variant of ``process_node_r3d`` (not implemented yet).

    The original cell contained only this ``def`` line with no body, which is a
    syntax error. It now raises instead, so the notebook cell runs and any
    accidental call fails loudly.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("process_node_r3d_batch is not implemented yet")

In [None]:
def _nearest_pixel_with_depth(center_pixel, mask_pixel_coords, pixel_depths):
    """Return ``(pixel, depth)`` for the mask pixel closest to ``center_pixel`` with non-zero depth.

    Distance is Euclidean in pixel space, compared squared (same ordering, no sqrt).
    On ties, the pixel that appears first in ``mask_pixel_coords`` wins.
    Returns ``None`` when no mask pixel has a non-zero depth.
    """
    best = None
    best_dist_sq = None
    for coord, depth in zip(mask_pixel_coords, pixel_depths):
        if depth == 0.0:
            continue
        dist_sq = (coord[0] - center_pixel[0]) ** 2 + (coord[1] - center_pixel[1]) ** 2
        if best_dist_sq is None or dist_sq < best_dist_sq:
            best, best_dist_sq = (coord, depth), dist_sq
    return best


def process_node_r3d(self, node_idx ,observation_graph,propositions,visualize, use_segmentation=True):
    """Detect, mask and back-project task-relevant elements seen at one observation node.

    Written as a method body (takes ``self``). It calls ``self.label_observation``,
    ``self.segment``, ``self.estimate_depth`` and ``self.plot_boxes``, and it
    reads and extends ``self.grounded_elements``.

    Args:
        node_idx: key of the node in ``observation_graph.nodes``.
        observation_graph: graph whose node data holds ``'rgb_pil'``, ``'pose'``
            (with ``'pose_matrix'``), ``'intrinsics'`` and optionally
            ``'depth_data'``. The node's ``'annotated_img'`` is reset to ``{}``.
        propositions: labels passed to ``self.label_observation``.
        visualize: if true, ``self.plot_boxes`` is called with
            ``file_name="observation_<node_idx>.png"``.
        use_segmentation: if true, masks come from ``self.segment``. Otherwise
            the detection bounding boxes are used as masks.

    Returns:
        A list with one dict per newly recorded mask. Each dict has keys
        ``mask_label``, ``mask_id``, ``origin_obsnode``, ``mask``,
        ``mask_center_pixel``, ``mask_center_pixel_depth``,
        ``mask_all_pixels``, ``mask_all_pixels_depth``, ``mask_depth``,
        ``sam_embedding``, ``origin_nodeimg``, ``origin_nodedepthimg``,
        ``origin_nodepose`` and ``worldframe_3d_position``.

    NOTE(review): get_center_pixel_depth, get_mask_pixels_depth,
    get_bounding_box_center_depth, get_bounding_box_pixels_depth and
    get_xyz_coordinate are not imported in this notebook. They must be in
    scope (e.g. from the osg package) before this runs.
    """
    node = observation_graph.nodes[node_idx]
    node_element_details = []
    print(f"Evaluating Waypoint at Node {node_idx}")
    node['annotated_img'] = {}
    bounds, labels, confidence = self.label_observation(node['rgb_pil'], propositions, threshold=0.1)

    present_propositions = list(set(labels))
    tracking_ids = []

    # Segment and back-project only if some detected label is not yet grounded.
    # all([]) is True, so a node with no detections skips all processing.
    # NOTE(review): grounded_elements is extended below with tracking ids
    # ("<label>_<node>_<i>"), but here it is probed with bare labels. Confirm
    # whether bare labels are ever stored there; otherwise this check rarely skips.
    if not all(element in self.grounded_elements for element in present_propositions):
        import torch  # torch.is_tensor is used below; torch is not imported elsewhere in this notebook

        if use_segmentation:
            print(f"      Detected Task relevant elements: {present_propositions} || Segmenting to obtain masks ...")
            masks, sam_embedding = self.segment(node['rgb_pil'], bounds)
        else:
            print(f"      Detected Task relevant elements: {present_propositions} || Using bounding boxes as masks ...")
            masks = bounds
            sam_embedding = None

        # Prefer depth stored on the node; fall back to estimation only when the key is absent.
        # (The previous bare `except:` also swallowed unrelated errors and KeyboardInterrupt.)
        try:
            depth_data = node['depth_data']
        except KeyError:
            print("      No existing depth data || Estimating depth image...")
            _, depth_data = self.estimate_depth(node['rgb_pil'])
        else:
            print(f"      Loaded depth data from waypoint node {node_idx}, ")

        for i, mask in enumerate(masks):
            label = labels[i]
            print(f"      Processing {label} mask")
            if use_segmentation:
                actual_mask = mask[0].cpu()
                center_pixel, center_pixel_depth = get_center_pixel_depth(actual_mask, depth_data[0])
                mask_pixel_coords, pixel_depths, average_depth = get_mask_pixels_depth(actual_mask, depth_data[0])
            else:
                actual_mask = mask
                center_pixel, center_pixel_depth = get_bounding_box_center_depth(actual_mask, depth_data[0])
                mask_pixel_coords, pixel_depths, average_depth = get_bounding_box_pixels_depth(actual_mask, depth_data[0])

            print(f"         Mask {label}_{str(node_idx)}_{i} || Original Center pixel: {center_pixel} || Center pixel depth: {center_pixel_depth}")
            # Use the average depth over the mask rather than the single center-pixel reading.
            center_pixel_depth = average_depth

            # Adaptive depth choice. Use the average mask depth if non-zero. Otherwise
            # move the center to the nearest mask pixel with depth, or skip the mask.
            # A value that is neither != 0 nor == 0 (e.g. NaN) keeps mask_depth at 0.0, as before.
            mask_depth = 0.0
            if center_pixel_depth != 0.0:
                mask_depth = center_pixel_depth
            elif center_pixel_depth == 0.0:
                nearest = _nearest_pixel_with_depth(center_pixel, mask_pixel_coords, pixel_depths)
                if nearest is None:
                    # No sensor depth anywhere in the mask; the monocular model is not used here.
                    print("         Not using depth model, moving on to next mask")
                    continue
                center_pixel, mask_depth = nearest
                print(f"         Center pixel depth empty,obtained new pixel from mask || pixel:{center_pixel}, depth: {mask_depth}")

            print(f"         Mask {label}_{str(node_idx)}_{i} || Chosen Center pixel: {center_pixel} || Average mask depth: {average_depth} || Chosen Mask depth: {mask_depth}")

            tracking_id = label + "_" + str(node_idx) + "_" + str(i)
            tracking_ids.append(tracking_id)
            if tracking_id in self.grounded_elements:
                continue  # only record masks not seen before

            print(f"         Backprojectig 3D ray using pixel: {center_pixel} & depth: {mask_depth}m for {tracking_id}...")
            center_y, center_x = center_pixel
            # NOTE(review): get_xyz_coordinate receives the full depth map, not mask_depth.
            # The depth used for back-projection may therefore differ from the reported mask_depth.
            transformed_point = get_xyz_coordinate(depth_data[0], node['pose']['pose_matrix'], node['intrinsics'], center_x, center_y)

            print(f"         Recording mask info for {tracking_id}...")
            node_element_details.append({
                "mask_label": label,
                "mask_id": tracking_id,
                "origin_obsnode": node_idx,
                "mask": actual_mask.cpu() if torch.is_tensor(actual_mask) else actual_mask,
                "mask_center_pixel": center_pixel,
                "mask_center_pixel_depth": center_pixel_depth,
                "mask_all_pixels": mask_pixel_coords,
                "mask_all_pixels_depth": pixel_depths,
                "mask_depth": mask_depth,
                "sam_embedding": sam_embedding.cpu() if torch.is_tensor(sam_embedding) else sam_embedding,
                "origin_nodeimg": node['rgb_pil'],
                "origin_nodedepthimg": depth_data,
                "origin_nodepose": node['pose'],
                "worldframe_3d_position": transformed_point,
            })
        self.grounded_elements.extend(tracking_ids)

        if visualize:
            # Save an annotated image of the detections (to the tmp folder, per the original comment).
            self.plot_boxes(node['rgb_pil'], bounds, labels, confidence, plt_size=8, file_name=f"observation_{node_idx}.png")

    return node_element_details




## Example: getting the XYZ coordinates of a single pixel

In [None]:
# Example: look up the XYZ coordinates of one pixel in a batched XYZ array.
# NOTE(review): `xyz_output` is only assigned in a later cell (np.load of
# results/relevant_elements_alldetails.npy). That file appears to hold element
# records rather than a (batch, H, W, 3) XYZ array. Define a suitable
# `xyz_output` before running this cell.
batch_index, y, x = 0, 100, 200  # batch index, row (height), column (width)

# Index order is [batch, y, x, xyz].
xyz_coordinates = xyz_output[batch_index, y, x, :]
print(xyz_coordinates)

In [None]:
import numpy as np

# Load the saved relevant-element details. allow_pickle=True is needed for
# object arrays (presumably element dicts here). Unpickling can execute
# arbitrary code, so only load trusted, locally generated files.
xyz_output = np.load("results/relevant_elements_alldetails.npy", allow_pickle=True)

In [None]:
# Display the first entry of the loaded array.
xyz_output[0]