In [None]:
# Copyright (c) 2023 William Locke

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

In [None]:
# NEON Tree Crowns / Evaluation Dataset copyright (c) 2023 weecology

# This work is licensed under CC BY 4.0

# NEON Tree Crowns website: https://www.weecology.org/data-projects/neon-crowns/
# NEON Tree Evaluation Github: https://github.com/weecology/NeonTreeEvaluation

This notebook is intended to be run in Google Colab with access to corresponding Google Drive files. If running locally or on another service, change import and install code accordingly.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lu-liang-geo/UAV_Tree_Detection/blob/main/notebooks/NEON_Tree_Evaluation_EDA.ipynb)

**NEON Tree Training Dataset**

There are only 16 images in the NEON training set, but they vary significantly in size and in the number of trees captured. The smallest image is 888 x 1153 pixels, while the five largest images are 10,000 x 10,000 pixels. Most are in the low thousands of pixels per side. (These sizes are for the RGB images. The Hyperspectral and CHM images have roughly 10x fewer pixels per side than their corresponding RGB images; they cover the same area, just at a lower resolution.)

The differing image sizes and tree densities also result in different numbers of trees per image. Most images contain hundreds to low thousands of trees (fewer than 2000). However, the two lowest-count images have 1 and 40 trees respectively, and the two highest have 3670 and 9730 respectively. Surprisingly, the low-count outliers are two of the largest images (10,000 x 10,000), while the high-count outliers are smaller images.
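
Normalizing tree counts by image area makes the mismatch between image size and tree count concrete. The sketch below assumes per-image records of the form `(basename, (height, width), num_trees)`, like the rows written to `NEON_train.txt` in the Image Stats section. The helper name `tree_density` and the example records are hypothetical and are not real NEON statistics.

```python
def tree_density(records):
    """Return {basename: trees per megapixel} for (basename, (height, width), num_trees) records.

    Hypothetical helper: the record format mirrors the per-image stats saved later in
    this notebook, but it is an assumption for illustration.
    """
    densities = {}
    for basename, (height, width), num_trees in records:
        megapixels = height * width / 1e6
        densities[basename] = num_trees / megapixels
    return densities

# Illustrative records only (not real NEON statistics)
example = [
    ('large_sparse', (10000, 10000), 40),
    ('small_dense', (2000, 2000), 3000),
]
for name, density in tree_density(example).items():
    print(f'{name:<15} {density:8.2f} trees / MP')
```

Because the NEON images appear to share roughly the same spatial resolution, trees per megapixel should be roughly comparable across images. It would not be comparable across datasets with different resolutions.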

All NEON images seem to have roughly the same spatial resolution as each other, and this resolution is slightly coarser than that of our images (i.e. a single pixel covers more ground area in the NEON dataset, so less detail is captured). I'm not sure of the actual range of spatial resolutions in the dataset.

All of this makes me wonder whether we should crop our training data to a fixed (or at least smaller) size before passing it to the model, and if so, what crop size and overlap to use. This would require additional work to make sure each bounding box is assigned to the correct image crop(s).
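
Here is a minimal sketch of one way to do this. It assumes square tiles with a fixed overlap and XYXY boxes in pixel coordinates. Each box is clipped to every tile it intersects and kept only if enough of its original area survives. The function names `tile_starts` and `tile_boxes` and the `min_visible` threshold are hypothetical choices, not part of the existing pipeline.

```python
import numpy as np

def tile_starts(length, tile, stride):
    """Start offsets along one axis so that tiles of size `tile` cover [0, length)."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] != length - tile:
        starts.append(length - tile)  # final tile sits flush with the image edge
    return starts

def tile_boxes(height, width, boxes, tile=400, overlap=50, min_visible=0.5):
    """Yield (x0, y0, x1, y1, local_boxes) for each tile of an image.

    boxes: Nx4 XYXY boxes in image pixel coordinates.
    local_boxes: boxes clipped to the tile and shifted into tile coordinates, keeping
    only boxes with at least `min_visible` of their original area inside the tile.
    """
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    stride = tile - overlap
    for y0 in tile_starts(height, tile, stride):
        for x0 in tile_starts(width, tile, stride):
            x1, y1 = min(x0 + tile, width), min(y0 + tile, height)
            clipped = boxes.copy()
            clipped[:, [0, 2]] = clipped[:, [0, 2]].clip(x0, x1)
            clipped[:, [1, 3]] = clipped[:, [1, 3]].clip(y0, y1)
            clipped_areas = (clipped[:, 2] - clipped[:, 0]) * (clipped[:, 3] - clipped[:, 1])
            keep = (clipped_areas > 0) & (clipped_areas >= min_visible * areas)
            local = clipped[keep] - [x0, y0, x0, y0]
            yield x0, y0, x1, y1, local
```

A box that straddles a tile border appears, clipped, in every tile that holds at least `min_visible` of its area. A larger overlap reduces how often a tree is cut in half, at the cost of more tiles.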

**NEON Tree Evaluation Dataset**

There are roughly 200 images in the NEON evaluation dataset, and they all have a uniform size of 400 x 400 for the RGB channels and 40 x 40 for the Hyperspectral and CHM channels (so the same 10x size difference between RGB and other channels as in the training set). The evaluation images are drawn from a larger set of forests than the training images, though there is some overlap between them (in terms of forests, presumably not individual images). I'm not sure if some of the evaluation images are cropped from larger, contiguous images, and if so whether they could be recombined back into larger images (or if there would be any advantage to doing so).

In [None]:
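import locale

# Force Python to report UTF-8 as the preferred encoding. This is a common Colab workaround
# for "A UTF-8 locale is required" errors when running shell commands after some installs.
def getpreferredencoding(do_setlocale=True):
  return "UTF-8"

locale.getpreferredencoding = getpreferredencoding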

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%capture
!unzip '/content/drive/MyDrive/UAV/Data/NEONTreeEvaluation/training.zip' -d "/content/training"
!unzip '/content/drive/MyDrive/UAV/Data/NEONTreeEvaluation/evaluation.zip' -d "/content"
!unzip '/content/drive/MyDrive/UAV/Data/NEONTreeEvaluation/annotations.zip' -d "/content"

In [None]:
%%capture
!pip install rasterio
!pip install supervision

In [None]:
import os
import glob
import torch
import rasterio
import numpy as np
import supervision as sv
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from collections import defaultdict

## NEONTreeDataset class definition

In [None]:
'''
This class definition is similar to the one in the project's GitHub repository, but it also gives access to the raw values of
several image channels for analysis purposes.
'''

class NEONTreeDataset(torch.utils.data.Dataset):
  def __init__(self, image_path, ann_path, prompt_path=None, check_values=False):
    '''
    params
      image_path (str): Path to top-level image directory (contains RGB, LiDAR, Hyperspectral, CHM)
      ann_path (str): Path to annotations
      prompt_path (str, optional): Path to a directory whose 'Boxes' subfolder holds pre-generated prompt boxes (.npy)
      check_values (bool): If True, assert that no channel contains negative (likely invalid) values

    __init__ stores filenames from the RGB folder, which are used to retrieve the matching files from the other
    folders. Because the RGB folder has four files not found in Hyperspectral or CHM, we remove any entries
    not found in those folders (or lacking an annotation file). We also manually remove four images we consider
    unsuitable for training / evaluation, either because they have a large number of invalid pixels or because
    they don't have good annotations.
    '''
    self.check_values = check_values
    self.image_path = image_path
    self.ann_path = ann_path
    self.prompt_path = prompt_path
    file_names = os.listdir(os.path.join(image_path, 'RGB'))
    problem_files = set(['2019_SJER_4_251000_4103000_image', '2019_TOOL_3_403000_7617000_image', 'TALL_043_2019', 'SJER_062_2018'])
    basenames = [name.split('.')[0] for name in file_names]
    basenames = [name for name in basenames if name not in problem_files]
    basenames = [name for name in basenames if os.path.exists(os.path.join(image_path, 'Hyperspectral', f'{name}_hyperspectral.tif'))]
    basenames = [name for name in basenames if os.path.exists(os.path.join(ann_path,f'{name}.xml'))]
    basenames = [name for name in basenames if os.path.exists(os.path.join(image_path, 'CHM', f'{name}_CHM.tif'))]
    self.basenames = list(set(basenames))

  def __len__(self):
    return len(self.basenames)

  def __getitem__(self, idx):
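    '''
    returns
      annotated_image (dict) with keys:
        basename: image name (filename without extension)
        rgb: HxWx3 ndarray of RGB channels
        multi: hxwx3 float32 ndarray of CHM, NDVI, and Red-Edge channels (h, w roughly 10x smaller than H, W;
               CHM and Red-Edge standardized per image)
        prompt (if prompt_path specified): Nx4 ndarray of prompt bounding boxes in XYXY format
        annotation: Nx4 ndarray of ground truth bounding boxes in XYXY format

        nir_raw, red_raw, edge_raw, chm_raw: hxw ndarrays of non-standardized channel values
        ndvi_raw (only if some NDVI values are undefined): hxw ndarray of NDVI values, with NaN values
               replaced by 4 to aid in visualization

    Currently I hardcode the multi channels, but later I may allow the user to specify them in __init__.
    '''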
    basename = self.basenames[idx]
    rgb_path = os.path.join(self.image_path, 'RGB', f'{basename}.tif')
    chm_path = os.path.join(self.image_path, 'CHM', f'{basename}_CHM.tif')
    hs_path = os.path.join(self.image_path, 'Hyperspectral', f'{basename}_hyperspectral.tif')
    ann_path = os.path.join(self.ann_path, f"{basename}.xml")
    if self.prompt_path:
      prompt_path = os.path.join(self.prompt_path, 'Boxes', f"{basename}.npy")
    annotated_image = {'basename': basename}

    # Open RGB path and two paths used to construct multi image. Save rgb_img in annotated_image.
    with rasterio.open(rgb_path) as img:
      rgb_img = img.read().transpose(1,2,0)
    with rasterio.open(hs_path) as img:
      hs_img = img.read()
    with rasterio.open(chm_path) as img:
      chm_img = img.read()
    annotated_image['rgb'] = rgb_img

    # Remove a blank row or column from each edge of the CHM and Hyperspectral images, detected by the CHM null
    # value of -9999.0 at a single pixel on that edge (top row, left column, bottom row, right column).
    if chm_img[0,0,1]==-9999.0:
      chm_img = chm_img[:,1:,:]
      hs_img = hs_img[:,1:,:]
    if chm_img[0,1,0]==-9999.0:
      chm_img = chm_img[:,:,1:]
      hs_img = hs_img[:,:,1:]
    if chm_img[0,-1,-2]==-9999.0:
      chm_img = chm_img[:,:-1,:]
      hs_img = hs_img[:,:-1,:]
    if chm_img[0,-2,-1]==-9999.0:
      chm_img = chm_img[:,:,:-1]
      hs_img = hs_img[:,:,:-1]

    assert not (chm_img==-9999.0).any()

    # Select NIR, Red, and Red-Edge bands based on the band wavelength reference chart
    # https://github.com/weecology/NeonTreeEvaluation/blob/master/neon_aop_bands.csv
    # Note that reference chart starts at 1 while we start at 0, so add 1 to the numbers below to get corresponding chart row.
    nir = 95
    red = 53
    edge = 69

    # Extract NIR, Red, and Red-Edge channels from Hyperspectral Image. Make CHM single channel.
    nir_img = hs_img[nir]
    red_img = hs_img[red]
    edge_img = hs_img[edge]
    chm_img = chm_img[0]

    # Check for any negative values in NIR, Red, Red-Edge, and CHM channels, which are likely invalid pixels.
    if self.check_values:
      assert rgb_img.min() >= 0, f'{basename} RGB values below 0 (min value: {rgb_img.min()}), check source image'
      assert nir_img.min() >= 0, f'{basename} NIR values below 0 (min value: {nir_img.min()}), check source image.'
      assert red_img.min() >= 0, f'{basename} Red values below 0 (min value: {red_img.min()}), check source image.'
      assert edge_img.min() >= 0, f'{basename} Red Edge values below 0 (min value: {edge_img.min()}), check source image.'
      assert chm_img.min() >= 0, f'{basename} Canopy Height values below 0 (min value: {chm_img.min()}), check source image.'

    # Clamp negative RGB, NIR, Red, and Red-Edge values to 0 and replace NaN CHM values with 0. This allows
    # processing images with invalid pixel values in some channels (we specifically allow this in 2019_OSBS_5 in the training set).
    rgb_img[rgb_img<0] = 0
    nir_img[nir_img<0] = 0
    red_img[red_img<0] = 0
    edge_img[edge_img<0] = 0
    chm_img = np.nan_to_num(chm_img, nan=0.0)

    # Save non-standardized channel values for analysis
    annotated_image['nir_raw'] = nir_img
    annotated_image['red_raw'] = red_img
    annotated_image['edge_raw'] = edge_img
    annotated_image['chm_raw'] = chm_img

    # Create NDVI from NIR and Red channels, cast to float so integer bands cannot overflow or wrap around.
    # NaN values (where NIR and Red are both 0) are converted to 0 for multi_img,
    # but set to 4 for ndvi_raw to highlight locations where NDVI can't be calculated normally.
    nir_f = nir_img.astype('float32')
    red_f = red_img.astype('float32')
    with np.errstate(divide='ignore', invalid='ignore'):
      _ndvi_img = (nir_f - red_f) / (nir_f + red_f)
    ndvi_img = np.nan_to_num(_ndvi_img, nan=0)
    if not np.isfinite(_ndvi_img).all():
      ndvi_raw = np.nan_to_num(_ndvi_img, nan=4)
      annotated_image['ndvi_raw'] = ndvi_raw

    # Standardize Red-Edge channel
    edge_img = (edge_img - edge_img.mean()) / edge_img.std()

    # Standardize CHM channel
    chm_img = (chm_img - chm_img.mean()) / chm_img.std()

    # Create multi-channel image of chm, NDVI, and Red-Edge, save in annotated_image.
    multi_img = np.stack([chm_img, ndvi_img, edge_img], axis=-1).astype('float32')
    annotated_image['multi'] = multi_img

    # If Prompt Boxes have already been generated (and self.prompt_path is not None), load prompt boxes
    # and save in annotated_image.
    if self.prompt_path:
      prompt = np.load(prompt_path)
      annotated_image['prompt'] = prompt

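    # Extract 'Tree' bounding boxes from the Pascal VOC-style annotation XML, save in annotated_image.
    xyxy = []
    tree = ET.parse(ann_path)
    root = tree.getroot()
    for obj in root.findall('object'):
      name = obj.find('name').text
      if name == 'Tree':
        bbox = obj.find('bndbox')
        # Look up coordinates by tag name rather than relying on the order of child elements.
        xyxy.append([int(float(bbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
    # reshape keeps the array Nx4 even when an image has no trees.
    annotation = np.array(xyxy, dtype=int).reshape(-1, 4)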
    annotated_image['annotation'] = annotation

    return annotated_image

  def get_image(self, basename, return_index=False):
    index = self.basenames.index(basename)
    if return_index:
      return index
    else:
      return self.__getitem__(index)
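
The edge-trimming step in `__getitem__` checks a single pixel on each edge and removes at most one row or column per side. If wider nodata borders ever appear, a more general sketch (an assumption, not what the class currently does) keeps the smallest window containing every row and column with a valid CHM value. The helper name `trim_nodata_border` is hypothetical. It assumes arrays in rasterio's (bands, H, W) layout, with -9999.0 as the CHM null value used above.

```python
import numpy as np

NODATA = -9999.0

def trim_nodata_border(chm, *others, nodata=NODATA):
    """Crop a (bands, H, W) CHM array, plus any arrays sharing its H and W, to the
    smallest window containing every row and column that has a valid CHM value.

    Hypothetical generalization of the one-pixel edge checks in NEONTreeDataset.
    """
    valid = (chm != nodata).any(axis=0)           # HxW: True where any CHM band is valid
    rows = np.flatnonzero(valid.any(axis=1))
    cols = np.flatnonzero(valid.any(axis=0))
    if rows.size == 0:
        raise ValueError('CHM contains no valid pixels')
    r0, r1 = rows[0], rows[-1] + 1
    c0, c1 = cols[0], cols[-1] + 1
    return tuple(arr[:, r0:r1, c0:c1] for arr in (chm, *others))
```

This only removes nodata borders. Nodata pixels in the interior would still be caught by the class's `assert not (chm_img==-9999.0).any()`.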

In [None]:
train_ds = NEONTreeDataset('/content/training', '/content/annotations', check_values=False)
val_ds = NEONTreeDataset('/content/evaluation', '/content/annotations', check_values=False)

## Image Stats

As mentioned above, the images in both the training and evaluation datasets vary in size (RGB and Multi) and in number of trees. The next code block saves text summaries of these statistics, both for individual images and for forest groups (identified by the uppercase NEON site code in each filename).
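
The code below takes the group name to be the first all-uppercase token of each filename: `SJER` in `2019_SJER_4_251000_4103000_image`, `DELA` in `DELA_047_2019`. A slightly stricter sketch assumes NEON site codes are always four uppercase letters. The helper name `site_code` is hypothetical.

```python
import re

SITE_CODE = re.compile(r'[A-Z]{4}')

def site_code(basename):
    """Return the first underscore-separated token that is exactly four uppercase letters."""
    for token in basename.split('_'):
        if SITE_CODE.fullmatch(token):
            return token
    raise ValueError(f'No four-letter site code found in {basename!r}')
```

Unlike indexing a list comprehension with `[0]`, this raises a descriptive error instead of a bare `IndexError` when a filename contains no site code.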

In [None]:
#@title Save Image Names, Sizes, Num Trees, and Forest Groups

def save_stats(dataset, image_file, group_file):
  '''Write per-image sizes and tree counts to image_file, and per-group map and tree counts to group_file.'''
  group_maps = defaultdict(int)
  group_trees = defaultdict(int)

  # Save text description of dataset RGB / Multi image sizes and tree counts
  with open(image_file, 'w') as f:
    f.write(f'{"Filename":<45} {"RGB Image Size":<20} {"Multi Image Size":<20} {"Num Trees"}\n')
    for sample in dataset:
      # Write individual image stats
      basename = sample['basename']
      rgb_shape = str(sample['rgb'].shape[:-1])
      multi_shape = str(sample['multi'].shape[:-1])
      num_trees = sample['annotation'].shape[0]
      f.write(f'{basename:<45} {rgb_shape:<20} {multi_shape:<20} {num_trees}\n')

      # Record group stats (group = first all-uppercase token in the filename)
      group_name = [name for name in basename.split('_') if name.isupper()][0]
      group_maps[group_name] += 1
      group_trees[group_name] += num_trees

  # Write group stats
  with open(group_file, 'w') as f:
    f.write(f'{"Group Name":<15} {"Num Maps":<10} {"Num Trees"}\n')
    for key in group_maps:
      f.write(f'{key:<15} {group_maps[key]:<10} {group_trees[key]}\n')

  return group_maps, group_trees

train_maps, train_trees = save_stats(train_ds, 'NEON_train.txt', 'NEON_train_groups.txt')
val_maps, val_trees = save_stats(val_ds, 'NEON_eval.txt', 'NEON_eval_groups.txt')

## Show Images

The following code blocks allow a user to view any RGB image in the chosen dataset, as well as any channel in the Multi image.

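The user can specify the image by index in the dataset, using `img = ds[index]`, or by name, using the method `img = ds.get_image(name)`. Note that an image's index may change each time the dataset is initialized. This is because the list of basenames is built from `os.listdir` and deduplicated with a `set`, and neither guarantees a consistent order across sessions.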
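
If a stable index would be more convenient, one option (not what `NEONTreeDataset` currently does) is to sort the deduplicated basenames. The same files then always map to the same indices, regardless of `os.listdir` or `set` ordering. Here is a minimal sketch with a hypothetical helper:

```python
def stable_basenames(file_names, excluded=()):
    """Strip extensions, drop excluded names and duplicates, and return a sorted list."""
    basenames = {name.split('.')[0] for name in file_names}
    return sorted(basenames - set(excluded))
```

Inside `__init__`, this would amount to replacing `list(set(basenames))` with `sorted(set(basenames))`.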

In [None]:
#@title Choose Image

# Choose dataset, train_ds or val_ds
ds = val_ds

# Choose image, x = an integer for choosing by index or x = a string for choosing by image name
x = 'DELA_047_2019'

if isinstance(x, int):
  img = ds[x]
else:
  img = ds.get_image(x)

In [None]:
#@title Show RGB Image

# Choose whether to include ground truth annotations with image
annotations = True
box_annotator = sv.BoxAnnotator(thickness=2, color=sv.Color.red())

plt.figure(figsize=(10,10))
plt.axis('off')
rgb_img = img['rgb']
boxes = img['annotation']
if annotations:
  boxes = sv.Detections(xyxy=boxes, confidence=np.ones(len(boxes)))
  bgr_img = box_annotator.annotate(scene=rgb_img[:,:,::-1].copy(), detections=boxes, skip_label=True)
  rgb_img = bgr_img[:,:,::-1]
plt.imshow(rgb_img)
plt.show()

In [None]:
#@title Visualize Canopy Height Model

plt.figure(figsize=(10,10))
plt.axis('off')
chm_img = img['multi'][:,:,0]
plt.imshow(chm_img, cmap='viridis')
plt.show()

In [None]:
#@title Visualize Red-Edge Channel

plt.figure(figsize=(10,10))
plt.axis('off')
edge_img = img['multi'][:,:,2]
plt.imshow(edge_img, cmap='viridis')
plt.show()

For some images, there are pixels where both the Red and NIR channels have a value of zero. For these pixels, `NDVI = (NIR - Red) / (NIR + Red)` evaluates to 0 / 0, which is undefined, so NumPy assigns them NaN (with a runtime warning rather than an error). For the Multi image, I simply replace these values with 0, the midpoint of the possible NDVI range (-1 to 1), though I'm still not sure this is the most appropriate solution. If there are pixels where NDVI cannot be calculated, the code block below shows two visualizations. The first fills those pixels with 0, as in the Multi image. The second fills them with 4, which makes them stand out clearly against the much darker background. Comparing the two shows which pixels have been filled with 0 in the Multi image and whether this seems an appropriate choice.
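
The toy example below uses made-up values, not NEON data. It shows how undefined NDVI pixels arise and how the two fill values differ. Note that a fill of 0 is indistinguishable from a genuine pixel where NIR equals Red, which is one reason the choice is debatable.

```python
import numpy as np

# Toy NIR and Red values; the last pixel has NIR = Red = 0.
nir = np.array([0.5, 0.3, 0.0], dtype=np.float32)
red = np.array([0.1, 0.3, 0.0], dtype=np.float32)

# 0 / 0 yields NaN; errstate silences NumPy's "invalid value" RuntimeWarning.
with np.errstate(divide='ignore', invalid='ignore'):
    ndvi = (nir - red) / (nir + red)

ndvi_multi = np.nan_to_num(ndvi, nan=0)   # fill used for the Multi image
ndvi_raw = np.nan_to_num(ndvi, nan=4)     # fill used only for visualization
print(ndvi_multi)
print(ndvi_raw)
```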

In [None]:
#@title Visualize NDVI Channel
ndvi = img['multi'][:,:,1]
if 'ndvi_raw' in img:
  # Top: undefined pixels filled with 0 (as in the Multi image). Bottom: undefined pixels filled with 4.
  fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(8,16))
  for ax in axs:
    ax.axis('off')
  axs[0].imshow(ndvi, cmap='viridis')
  axs[1].imshow(img['ndvi_raw'], cmap='viridis')
  fig.tight_layout()
  plt.show()
else:
  plt.figure(figsize=(10,10))
  plt.axis('off')
  plt.imshow(ndvi, cmap='viridis')
  plt.show()

## Multi Channel Statistics

The Multi image is composed of three channels: Canopy Height Model (CHM), Normalized Difference Vegetation Index (NDVI), and Red-edge. We selected these based on their availability in the data and what we thought might be most helpful for distinguishing individual trees from each other and from other objects.

The Multi image has three channels because it is being passed through the SAM encoder, which only accepts images with three channels (usually RGB). In the future, it might be worth training a new encoder to handle Multi images separately, but for now we are passing both RGB and Multi images through the same encoder, and so they must have the same number of channels (more on this below).

In [None]:
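# Collect all NIR, Red, CHM, and Red-edge raw values, plus all NDVI values after NaN
# replacement and all CHM and Red-edge values after standardization.
# Each image is loaded once and flattened, so this also works when image sizes differ (e.g. train_ds).
# The loop variable is named `sample` so the image chosen above as `img` is not overwritten.
nir_raw, red_raw, chm_raw, edge_raw = [], [], [], []
chm_all, ndvi_all, edge_all = [], [], []
for sample in ds:
  nir_raw.append(sample['nir_raw'].ravel())
  red_raw.append(sample['red_raw'].ravel())
  chm_raw.append(sample['chm_raw'].ravel())
  edge_raw.append(sample['edge_raw'].ravel())
  chm_all.append(sample['multi'][:,:,0].ravel())
  ndvi_all.append(sample['multi'][:,:,1].ravel())
  edge_all.append(sample['multi'][:,:,2].ravel())

nir_raw, red_raw, chm_raw, edge_raw = [np.concatenate(v) for v in (nir_raw, red_raw, chm_raw, edge_raw)]
chm_all, ndvi_all, edge_all = [np.concatenate(v) for v in (chm_all, ndvi_all, edge_all)]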

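The SAM encoder has its own saved mean and standard deviation values for standardizing RGB inputs, but these are not appropriate for the Multi images. We therefore standardize the CHM and Red-edge channels ourselves, per image, in `__getitem__`. NDVI is not standardized, since it already lies in [-1, 1]. The code block below shows the channel values before standardization (including the raw NIR and Red bands used to compute NDVI) and after.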
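
The per-image standardization in `__getitem__` computes z = (x - mean) / std for a single channel. The sketch below applies the same operation to every channel of an HxWxC array. It adds a guard for constant channels, where dividing by a zero standard deviation would otherwise produce NaN. The helper name `standardize_channels` and the `eps` threshold are hypothetical; this is not the notebook's current code.

```python
import numpy as np

def standardize_channels(img, eps=1e-8):
    """Return a float32 copy of an HxWxC array with each channel scaled to mean 0, std 1.

    Channels with (near-)zero standard deviation are only mean-centered.
    """
    img = img.astype(np.float32)
    mean = img.mean(axis=(0, 1), keepdims=True)
    std = img.std(axis=(0, 1), keepdims=True)
    std = np.where(std < eps, 1.0, std)
    return ((img - mean) / std).astype(np.float32)
```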

In [None]:
#@title Multi Channel Values
def print_stats(name, values):
  print(f'{name:<10} Max Value: {values.max():>7.2f}    Min Value: {values.min():>5.2f}    Mean Value: {values.mean():>7.2f}    STD: {values.std():>7.2f} \n')

print('BEFORE Standardization')
print_stats('CHM', chm_raw)
print_stats('NIR', nir_raw)
print_stats('Red', red_raw)
print_stats('Red Edge', edge_raw)
print()
print('AFTER Standardization (NDVI: NaN replacement only, not standardized)')
print_stats('CHM', chm_all)
print_stats('NDVI', ndvi_all)
print_stats('Red Edge', edge_all)

We want to know whether the channels in the Multi image carry distinct information, rather than repeating the same information across three channels. The correlation matrix of the CHM, NDVI, and Red-edge channels shows that they are not strongly (linearly) correlated, suggesting that each channel carries information the others do not. By comparison, the RGB channels are much more strongly correlated with each other than the Multi channels are. (However, because both images pass through the same encoder, this difference could also make the Multi image harder for the encoder to interpret; more on this below.)
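
For reference, each entry of the matrix returned by `np.corrcoef` is the Pearson correlation r = cov(x, y) / (std(x) std(y)). The small check below uses synthetic data, not the NEON channels, and compares a hand-written `pearson` helper (hypothetical) against `np.corrcoef`.

```python
import numpy as np

def pearson(x, y):
    """Pearson correlation coefficient of two 1-D arrays."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    xc, yc = x - x.mean(), y - y.mean()
    return (xc * yc).sum() / np.sqrt((xc ** 2).sum() * (yc ** 2).sum())

rng = np.random.default_rng(0)
a = rng.normal(size=1000)
b = 2 * a + rng.normal(scale=0.1, size=1000)   # strongly (linearly) correlated with a
c = rng.normal(size=1000)                       # generated independently of a
print(np.corrcoef(np.stack([a, b, c])).round(2))
```

Because Pearson correlation only measures linear association, a low value between two channels does not rule out a nonlinear relationship between them.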

In [None]:
#@title Multi Channel Correlation Matrix

chm_flat = chm_all.flatten()
ndvi_flat = ndvi_all.flatten()
edge_flat = edge_all.flatten()

multi_channels = np.stack([chm_flat, ndvi_flat, edge_flat])
r_multi = np.corrcoef(multi_channels)
print('Multi Channel Correlation Matrix \n')
print(r_multi)

In [None]:
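# Flatten each RGB image to (num_pixels, 3) and concatenate, so images of different sizes can be combined
rgb_all = np.concatenate([sample['rgb'].reshape(-1, 3) for sample in ds])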

In [None]:
red_flat = rgb_all[...,0].flatten()
green_flat = rgb_all[...,1].flatten()
blue_flat = rgb_all[...,2].flatten()

rgb_channels = np.stack([red_flat, green_flat, blue_flat])
r_rgb = np.corrcoef(rgb_channels)
print('RGB Channel Correlation Matrix \n')
print(r_rgb)

## SAM Encoder on Multi Images

As stated above, both the RGB and Multi images are passed through the same SAM encoder, which was trained on roughly 11 million images and so is a very powerful tool for extracting meaningful representations of images. However, it was trained exclusively on RGB images, and so it's questionable whether the representations it creates of non-RGB images would be as useful. Ultimately, this will have to be experimentally tested with our own Box Decoder model, which will learn to interpret the outputs of the SAM encoder both on RGB and Multi images. However, as an early proxy test, we can see how SAM's Mask Decoder (which was also trained exclusively on RGB images) handles the encodings of Multi images vs RGB images.
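
A complementary check that does not involve the decoder is to compare the two embeddings directly. The sketch below computes cosine similarity between the RGB and Multi embeddings at each spatial location. It assumes both are arrays of shape (1, C, H, W); SAM's standard image embeddings are 1x256x64x64. The helper name is hypothetical, and a torch tensor would first need `.detach().cpu().numpy()`.

```python
import numpy as np

def embedding_cosine_similarity(emb_a, emb_b, eps=1e-8):
    """Per-location cosine similarity between two (1, C, H, W) embeddings.

    Returns an HxW array in [-1, 1]; values near 1 mean the two feature vectors
    at that location point in nearly the same direction.
    """
    a = np.asarray(emb_a, dtype=np.float64)[0]   # C x H x W
    b = np.asarray(emb_b, dtype=np.float64)[0]
    dot = (a * b).sum(axis=0)
    norms = np.linalg.norm(a, axis=0) * np.linalg.norm(b, axis=0)
    return dot / np.maximum(norms, eps)
```

Once `rgb_features` and `multi_features` are computed, this could be called as `embedding_cosine_similarity(rgb_features.detach().cpu().numpy(), multi_features.detach().cpu().numpy())`, assuming `get_image_embedding` returns torch tensors.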

In [None]:
%%capture
# Copy SAM from the project's GitHub repository and download the ViT-H SAM checkpoint.
# (The %%capture cell magic must be the first line of the cell.)
%cd /content
if os.path.exists('/content/UAV_Tree_Detection'):
  !rm -r /content/UAV_Tree_Detection
!git clone https://github.com/lu-liang-geo/UAV_Tree_Detection.git
%cd /content/UAV_Tree_Detection
!pip install -q .
!mkdir -p /content/weights
%cd /content/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
%cd /content

In [None]:
from segment_and_detect_anything import sam_model_registry, SamPredictor

In [None]:
# Initialize SAM model and predictor
sam_model = sam_model_registry["vit_h"](checkpoint="/content/weights/sam_vit_h_4b8939.pth")
sam_predictor = SamPredictor(sam_model)

In [None]:
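def segment(sam_predictor: SamPredictor, boxes: np.ndarray) -> np.ndarray:
    '''Predict one mask per XYXY box, using whichever image features the predictor currently holds.'''
    result_masks = []
    for box in boxes:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            output_type='mask',
            multimask_output=False
        )
        # Keep the highest-scoring mask (with multimask_output=False there is only one)
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)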

In [None]:
# Embed RGB image and Multi image

sam_predictor.set_images(img['rgb'], img['multi'])
rgb_features, multi_features = sam_predictor.get_image_embedding()

In [None]:
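#@title Segmentation with RGB embeddings

# Use the RGB embedding explicitly, so this cell gives RGB results even if re-run after the Multi cell
sam_predictor.features = rgb_features

box_annotator = sv.BoxAnnotator(thickness=2, color=sv.Color.red())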
mask_annotator = sv.MaskAnnotator()

rgb_img = img['rgb']
boxes = img['annotation']
rgb_masks = segment(sam_predictor, boxes)

plt.figure(figsize=(10,10))
plt.axis('off')
detections = sv.Detections(xyxy=boxes, confidence=np.ones(len(boxes)), mask=rgb_masks, class_id=np.zeros(len(boxes), dtype=np.int64))
bgr_img = box_annotator.annotate(scene=rgb_img[:,:,::-1].copy(), detections=detections, skip_label=True)
bgr_img = mask_annotator.annotate(scene=bgr_img.copy(), detections=detections)
rgb_img = bgr_img[:,:,::-1]
plt.imshow(rgb_img)
plt.show()

In [None]:
#@title Segmentation with Multi embeddings

# Swap in the Multi embedding; masks are still drawn on the RGB image for comparison
sam_predictor.features = multi_features

box_annotator = sv.BoxAnnotator(thickness=2, color=sv.Color.red())
mask_annotator = sv.MaskAnnotator()

rgb_img = img['rgb']
boxes = img['annotation']
multi_masks = segment(sam_predictor, boxes)

plt.figure(figsize=(10,10))
plt.axis('off')
detections = sv.Detections(xyxy=boxes, confidence=np.ones(len(boxes)), mask=multi_masks, class_id=np.zeros(len(boxes), dtype=np.int64))
bgr_img = box_annotator.annotate(scene=rgb_img[:,:,::-1].copy(), detections=detections, skip_label=True)
bgr_img = mask_annotator.annotate(scene=bgr_img.copy(), detections=detections)
rgb_img = bgr_img[:,:,::-1]
plt.imshow(rgb_img)
plt.show()

Somewhat surprisingly, the Mask Decoder is able to interpret the Multi image embedding. It outputs masks similar to those it produces from the RGB embedding, though somewhat less accurately. This suggests the two embeddings are similar enough to support the same kinds of predictions, despite encoding different information. Of course, this is a very small test; it would be worth experimenting with other images and other channel combinations to see whether this similarity holds.
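
To go beyond visual comparison, one option is to score agreement between the RGB-based and Multi-based masks box by box. This minimal sketch assumes the masks from each cell are kept under separate names (e.g. `rgb_masks` and `multi_masks`) as (N, H, W) arrays in the same box order. The helper name is hypothetical.

```python
import numpy as np

def paired_mask_iou(masks_a, masks_b):
    """IoU between corresponding masks in two (N, H, W) arrays.

    Returns a length-N array; mask pairs that are empty in both inputs get IoU 1.0.
    """
    a = np.asarray(masks_a, dtype=bool)
    b = np.asarray(masks_b, dtype=bool)
    intersection = np.logical_and(a, b).sum(axis=(1, 2))
    union = np.logical_or(a, b).sum(axis=(1, 2))
    iou = np.ones(len(a), dtype=np.float64)
    nonempty = union > 0
    iou[nonempty] = intersection[nonempty] / union[nonempty]
    return iou
```

`paired_mask_iou(rgb_masks, multi_masks).mean()` would then give a single agreement score per image, which could be tracked across images and channel combinations. This measures agreement between the two embeddings' masks, not accuracy against ground truth.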