### Demo of OCTO+ and other object placement methods

Install

In [None]:
!git clone https://github.com/octo-pearl/octo-pearl.git
%cd octo-pearl
%pip install -qe .

Download Weights

In [None]:
!mkdir -p weights
!wget -q -O weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
!wget -q -O weights/ram_plus_swin_large_14m.pth https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth
!wget -q -O weights/groundingdino_swint_ogc.pth https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
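Because `wget -q` is silent, a failed download is easy to miss, and `-O` can leave behind an empty output file. The helper below is a hypothetical sketch (`missing_weights` is not part of octo_pearl) that uses the file names from the commands above to report any checkpoint that is absent or zero bytes long before the models are loaded.

```python
from pathlib import Path

# File names copied from the wget commands; "weights" is the directory
# created by `mkdir -p weights`. The helper itself is hypothetical.
EXPECTED_WEIGHTS = (
    "sam_vit_h_4b8939.pth",
    "ram_plus_swin_large_14m.pth",
    "groundingdino_swint_ogc.pth",
)


def missing_weights(weights_dir="weights", min_bytes=1):
    """Return the expected checkpoint names that are absent or too small.

    `wget -O` may create the output file even when the download fails,
    so a file smaller than `min_bytes` is reported as missing too.
    """
    root = Path(weights_dir)
    missing = []
    for name in EXPECTED_WEIGHTS:
        path = root / name
        if not path.is_file() or path.stat().st_size < min_bytes:
            missing.append(name)
    return missing
```

Running `missing_weights()` from the repository root should return an empty list once all three downloads have succeeded.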

Define Input

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
OBJECT_TO_PLACE = "cupcake"
IMAGE_PATH = "assets/test_img.jpg"

image = Image.open(IMAGE_PATH)
plt.imshow(image)
plt.axis('off')
plt.show()

### Stage 1: Image Understanding
Generate a list of "tags" (the objects in the image), then filter it.

Note: These functions may take a while to run the first time they are called, because they have to load the models from disk into CPU/GPU memory. Subsequent calls should be much faster.
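The slow first call is the usual lazy-loading pattern: a model is read from disk once and then kept in memory for reuse. The sketch below shows that pattern with `functools.lru_cache`. `load_model` and `tag_image` are hypothetical stand-ins, not octo_pearl's internals, and the `time.sleep` stands in for reading a large checkpoint.

```python
import functools
import time

LOAD_COUNT = {"n": 0}  # counts real (uncached) loads, for demonstration only


@functools.lru_cache(maxsize=None)
def load_model(name):
    """Hypothetical loader: the slow disk read runs once per model name."""
    LOAD_COUNT["n"] += 1
    time.sleep(0.1)  # stands in for reading a multi-GB checkpoint
    return {"name": name}  # stands in for the loaded model object


def tag_image(image, model_name="ram_plus"):
    """Hypothetical tagging call: only the first call pays the load cost."""
    model = load_model(model_name)
    return f"tags for {image!r} from {model['name']}"
```

Calling `tag_image` twice loads the model once; the second call reuses the cached object.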

In [None]:
from octo_pearl.placement.tagging import get_tags_scp, get_tags_gpt4v, get_tags_ram
# Use one tagging method; uncomment a line below to try an alternative
# tags = get_tags_scp(image)
# tags = get_tags_gpt4v(image)
tags = get_tags_ram(image, threshold_multiplier=0.8)
tags

In [None]:
from octo_pearl.placement.filtering import filter_tags_clipseg, filter_tags_vilt, filter_tags_gdino
# Use one filtering method; uncomment a line below to try an alternative
# filtered_tags = filter_tags_clipseg(image, tags)
# filtered_tags = filter_tags_vilt(image, tags)
filtered_tags = filter_tags_gdino(image, tags)
filtered_tags
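Conceptually, the filtering step should remove tags that the tagger proposed but that a second model (CLIPSeg, ViLT or Grounding DINO above) cannot find in the image. The sketch below is an illustration only: `filter_tags` is hypothetical, `score_fn` stands in for a model call (for example, a detector's best confidence for that tag), and the `0.3` threshold is arbitrary. The real `filter_tags_*` functions may use different criteria.

```python
from typing import Callable, Iterable, List


def filter_tags(tags: Iterable[str], score_fn: Callable[[str], float],
                threshold: float = 0.3) -> List[str]:
    """Keep tags whose grounding score is at least `threshold`, in order.

    Duplicate tags are dropped. `score_fn` and `threshold` are
    illustrative assumptions, not octo_pearl parameters.
    """
    seen = set()
    kept = []
    for tag in tags:
        if tag in seen:  # skip repeated tags
            continue
        seen.add(tag)
        if score_fn(tag) >= threshold:
            kept.append(tag)
    return kept
```

For example, with scores `{"table": 0.9, "cup": 0.5, "dragon": 0.05}`, only `table` and `cup` survive.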

### Stage 2: Reasoning
Select which tag the chosen object should be placed on

Note: You can pass your OpenAI API key directly through the `api_key` parameter of `select_best_tag`. Alternatively, add `OPENAI_API_KEY` to an environment file at `octo-pearl/placement/.env`. The `.env` file should contain the following:
```
OPENAI_API_KEY=<your_api_key>
```
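One plausible way to resolve the key from these two sources is sketched below. The precedence (explicit argument, then process environment, then `.env` file) and the helper names `read_dotenv` and `resolve_api_key` are assumptions for illustration, not octo_pearl's actual code. The parser handles only simple `KEY=VALUE` lines; real projects usually rely on the `python-dotenv` package instead.

```python
import os
from pathlib import Path


def read_dotenv(path):
    """Parse simple KEY=VALUE lines from a .env file; skip blanks and comments."""
    values = {}
    p = Path(path)
    if not p.is_file():
        return values
    for line in p.read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def resolve_api_key(api_key="", dotenv_path=".env"):
    """Hypothetical precedence: explicit argument, environment, then .env."""
    if api_key:
        return api_key
    if os.environ.get("OPENAI_API_KEY"):
        return os.environ["OPENAI_API_KEY"]
    return read_dotenv(dotenv_path).get("OPENAI_API_KEY", "")
```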

The signature of `select_best_tag` is:
```
select_best_tag(filtered_tags: List[str], object_to_place: str, api_key: str = "") -> str
```
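Since the function returns a single `str` chosen from `filtered_tags`, a language-model reply has to be mapped back onto one of the candidates. The sketch below shows one hypothetical way to do that. The prompt wording and both helpers (`build_prompt`, `match_answer`) are assumptions, not the prompt or parsing used by `select_best_tag`.

```python
from typing import List


def build_prompt(filtered_tags: List[str], object_to_place: str) -> str:
    """Hypothetical prompt; select_best_tag's actual wording may differ."""
    options = "\n".join(f"- {tag}" for tag in filtered_tags)
    return (
        f"Which of the following objects would a {object_to_place} most "
        f"naturally be placed on?\n{options}\n"
        "Answer with exactly one item from the list."
    )


def match_answer(answer: str, filtered_tags: List[str]) -> str:
    """Map a free-form reply onto one of the candidate tags.

    Tries a case-insensitive exact match first, then the longest candidate
    contained in the reply (so "coffee table" wins over "table").
    Raises ValueError if no candidate appears in the reply.
    """
    cleaned = answer.strip().strip(".").lower()
    for tag in filtered_tags:
        if cleaned == tag.lower():
            return tag
    for tag in sorted(filtered_tags, key=len, reverse=True):
        if tag.lower() in cleaned:
            return tag
    raise ValueError(f"reply {answer!r} names none of {filtered_tags}")
```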

In [None]:
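from octo_pearl.placement.selecting import select_best_tag
# Replace the placeholder with your key, or omit `api_key` to use OPENAI_API_KEY from the .env file
selected_object = select_best_tag(filtered_tags, OBJECT_TO_PLACE, api_key="<your_api_key>")
selected_object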

### Stage 3: Locating
Select a 2D point (x, y) in the image that lies on the selected tag.

In [None]:
from octo_pearl.placement.locating import get_location_clipseg, get_location_gsam
# Use one locating method; uncomment the line below to try the CLIPSeg alternative
# x, y = get_location_clipseg(image, selected_object)
x, y = get_location_gsam(image, selected_object)
plt.imshow(image)
plt.scatter(x, y, c="red", s=50)
plt.axis('off')
plt.show()
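A simple way to turn a segmentation mask of the selected tag into a single point is to take the mask's centroid. Below is a hypothetical sketch of that idea, not the method used by `get_location_clipseg` or `get_location_gsam`. It assumes a rectangular, row-major boolean mask. Because the centroid of a non-convex region (an L-shaped desk, a ring) can fall outside the region, `placement_point` snaps to the nearest mask pixel in that case.

```python
from typing import Sequence, Tuple

Mask = Sequence[Sequence[bool]]


def mask_centroid(mask: Mask) -> Tuple[float, float]:
    """Return the (x, y) centroid of the True pixels in a row-major mask.

    Rows index y and columns index x, matching plt.scatter's (x, y) order.
    """
    xs = ys = n = 0
    for y, row in enumerate(mask):
        for x, on in enumerate(row):
            if on:
                xs += x
                ys += y
                n += 1
    if n == 0:
        raise ValueError("mask is empty")
    return xs / n, ys / n


def placement_point(mask: Mask) -> Tuple[int, int]:
    """Hypothetical locator: centroid pixel, snapped into the mask if needed."""
    cx, cy = mask_centroid(mask)
    ix, iy = round(cx), round(cy)
    if mask[iy][ix]:
        return ix, iy
    # Centroid fell outside a non-convex mask: use the nearest True pixel.
    return min(
        ((x, y) for y, row in enumerate(mask) for x, on in enumerate(row) if on),
        key=lambda p: (p[0] - cx) ** 2 + (p[1] - cy) ** 2,
    )
```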

### Evaluating
For the evaluation, we will use an image from the NYU Depth Dataset V2.

In [None]:
from octo_pearl.eval.pearl import placement_score
IMAGE_NAME = "000749.png"
image = Image.open(f"octo_pearl/eval/data/images/{IMAGE_NAME}")
plt.imshow(image)
plt.show()

We will use the object "computer", since it is included in PEARL for this image. This is the PEARL segmentation mask of valid locations for a computer:

In [None]:
from octo_pearl.eval.pearl import create_mask
OBJECT_TO_PLACE = "computer"
mask = create_mask(IMAGE_NAME, OBJECT_TO_PLACE)
plt.imshow(mask)
plt.show()
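When checking a point against a mask like this one, remember that image arrays are indexed row first: pixel (x, y) is `mask[y][x]`, not `mask[x][y]`. The helper below is a hypothetical illustration of that lookup, not the implementation of `placement_score`. Points outside the image bounds count as outside the mask.

```python
from typing import Sequence


def point_in_mask(mask: Sequence[Sequence[bool]], x: float, y: float) -> bool:
    """Return True if pixel (x, y) lies inside the mask.

    x is the column and y is the row, so the lookup is mask[y][x].
    """
    col, row = int(round(x)), int(round(y))
    if row < 0 or row >= len(mask) or col < 0 or col >= len(mask[row]):
        return False
    return bool(mask[row][col])
```

With a mask of shape (height, width), `point_in_mask(mask, 100, 325)` reads row 325, column 100.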

A computer would appear natural if it were centered on the table, at around (x, y) = (100, 325):

In [None]:
x = 100
y = 325
plt.imshow(image)
plt.scatter(x, y, c="red", s=50)
plt.show()
in_mask, pearl_score = placement_score(IMAGE_NAME, OBJECT_TO_PLACE, x, y)
print(f"In mask: {in_mask}")
print(f"PEARL score: {pearl_score}")

The computer would look less natural at the edge of the table, at around (200, 325):

In [None]:
x = 200
y = 325
plt.imshow(image)
plt.scatter(x, y, c="red", s=50)
plt.show()
in_mask, pearl_score = placement_score(IMAGE_NAME, OBJECT_TO_PLACE, x, y)
print(f"In mask: {in_mask}")
print(f"PEARL score: {pearl_score}")

The computer would look very unnatural on a cabinet, at around (200, 50):

In [None]:
x = 200
y = 50
plt.imshow(image)
plt.scatter(x, y, c="red", s=50)
plt.show()
in_mask, pearl_score = placement_score(IMAGE_NAME, OBJECT_TO_PLACE, x, y)
print(f"In mask: {in_mask}")
print(f"PEARL score: {pearl_score}")
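The three cells above repeat the same steps for different points. They can be collapsed into one loop that scores and ranks all candidates. `rank_points` below is a hypothetical helper. It assumes, as the cells above suggest, that the scoring function returns an `(in_mask, score)` pair, and that a higher PEARL score means a more natural placement.

```python
from typing import Callable, Dict, List, Tuple

Point = Tuple[int, int]
Row = Tuple[str, int, int, bool, float]


def rank_points(points: Dict[str, Point],
                score_fn: Callable[[int, int], Tuple[bool, float]]) -> List[Row]:
    """Score each labelled point and return rows sorted best-first.

    Each row is (label, x, y, in_mask, score). Points inside the mask come
    before points outside it, then rows are ordered by descending score.
    """
    rows = []
    for label, (x, y) in points.items():
        in_mask, score = score_fn(x, y)
        rows.append((label, x, y, in_mask, score))
    return sorted(rows, key=lambda r: (not r[3], -r[4]))
```

In this notebook it could be called as `rank_points({"center": (100, 325), "edge": (200, 325), "cabinet": (200, 50)}, lambda x, y: placement_score(IMAGE_NAME, OBJECT_TO_PLACE, x, y))`.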