# Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # Limit PyTorch to seeing 1 GPU only.
import glob
import yaml

from time import time
from collections import OrderedDict

import torch
import numpy as np

import src.data_augmentation as data_augmentation
import src.graph_construction as gc
import src.graph_networks as gn
import src.merge_split_networks as msn
import src.delete_network as delnet
import src.sample_tree_cem as stc
import src.network_config as nc
import src.losses as losses
import src.visualize_graph as visualize_graph
import src.constants as constants
import src.util.utilities as util_

# Choose image

In [None]:
# Load images
# Each example file is a pickled dict stored via np.save; the loop below reads
# its 'rgb', 'xyz' and 'label' entries into pre-allocated 480x640 arrays.
example_images_dir = os.path.join(os.path.abspath('.'), 'example_images', '')  # keeps trailing separator
OSD_image_files = sorted(glob.glob(os.path.join(example_images_dir, 'OSD_*.npy')))
OCID_image_files = sorted(glob.glob(os.path.join(example_images_dir, 'OCID_*.npy')))
image_files = OSD_image_files + OCID_image_files
N = len(image_files)
if N == 0:
    # Fail early with a readable message instead of an IndexError further down.
    raise FileNotFoundError(
        f"No OSD_*.npy or OCID_*.npy files found in {example_images_dir}"
    )

orig_rgb_imgs = np.zeros((N, 480, 640, 3), dtype=np.uint8)    # RGB exactly as stored on disk
rgb_imgs = np.zeros((N, 480, 640, 3), dtype=np.float32)       # standardized RGB (network input)
xyz_imgs = np.zeros((N, 480, 640, 3), dtype=np.float32)
label_imgs = np.zeros((N, 480, 640), dtype=np.uint8)

for i, img_file in enumerate(image_files):
    # NOTE(review): allow_pickle=True unpickles the file contents, which can
    # execute arbitrary code. Only load example files from a trusted source.
    d = np.load(img_file, allow_pickle=True, encoding='bytes').item()

    # RGB
    orig_rgb_imgs[i] = d['rgb']
    rgb_imgs[i] = data_augmentation.standardize_image(orig_rgb_imgs[i])

    # XYZ
    xyz_imgs[i] = d['xyz']

    # Label
    label_imgs[i] = d['label']

# Choose image
img_index = 0  # TODO: select an index in range(N)
if not -N <= img_index < N:
    raise IndexError(
        f"img_index={img_index} is out of range for {N} loaded example images"
    )
rgb_img = data_augmentation.array_to_tensor(rgb_imgs[img_index])
xyz_img = data_augmentation.array_to_tensor(xyz_imgs[img_index])
label_img = data_augmentation.array_to_tensor(label_imgs[img_index])

# Display all loaded examples with their ground-truth labels overlaid
util_.gallery([util_.visualize_segmentation(orig_rgb_imgs[i], label_imgs[i])
               for i in range(N)],
              width='300px')

# Load ResNet50+FPN

Pre-trained on COCO2017.

In [None]:
# Backbone used to extract RGB features for the segmentation graph.
rn50_fpn = gc.get_resnet50_fpn_model(pretrained=True)

# Feature extraction is inference-only, so gradient tracking is switched off.
with torch.no_grad():
    rgb_img_features = gc.extract_rgb_img_features(
        rn50_fpn,
        rgb_img,
    )

# Use UOIS-Net-3D to get masks.

Obtain initial instance segmentation masks. You can find the models for UOIS-Net-3D [here](https://github.com/chrisdxie/uois).

In [None]:
# Load UOIS-Net-3D
uoisnet3d_cfg_filename = 'configs/uoisnet3d.yaml'
dsn_filename = ''  # TODO: add path to saved model for UOIS-Net-3D
rrn_filename = ''  # TODO: add path to saved model for UOIS-Net-3D
uois_net = util_.load_uoisnet_3d(
    uoisnet3d_cfg_filename,
    dsn_filename,
    rrn_filename,
)

# Put data into a batch (a leading batch dimension of size 1 is added)
batch = {
    'rgb' : rgb_img.unsqueeze(0),
    'xyz' : xyz_img.unsqueeze(0),
}
# Use a dedicated name here: `N` already holds the number of loaded example
# images and must not be silently overwritten by the batch size.
num_images = batch['rgb'].shape[0]
print("Number of images: {0}".format(num_images))

### Compute segmentation masks ###
st_time = time()
fg_masks, center_offsets, initial_masks, seg_masks = uois_net.run_on_batch(batch)
total_time = time() - st_time
print('Total time taken for Segmentation: {0} seconds'.format(round(total_time, 3)))
print('FPS: {0}'.format(round(num_images / total_time,3)))

# Keep only the single image of the batch
seg_masks = seg_masks[0]  # [H, W]
fg_mask = fg_masks[0] == constants.OBJECTS_LABEL  # [H, W]

# Display results
util_.gallery({'Instance Masks' : util_.get_color_mask(seg_masks.cpu().numpy()),
               'Foreground Mask' : util_.get_color_mask(fg_mask.cpu().numpy())
              }, width='300px')

# Construct and visualize segmentation graph

In [None]:
# Build the segmentation graph from the backbone features, the XYZ image and
# the initial instance masks.
segmentation_graph = gc.construct_segmentation_graph(rgb_img_features, xyz_img, seg_masks)

# Render the graph on top of the RGB image as loaded from disk (not the standardized one)
viz_graph = visualize_graph.visualize_graph(
    orig_rgb_imgs[img_index],
    segmentation_graph,
    mode='seg_graph_on_rgb',
)

# Collect the panels in display order
image_dict = OrderedDict()
image_dict['Original RGB'] = orig_rgb_imgs[img_index]
image_dict['Predicted Mask'] = util_.get_color_mask(util_.copy_to_numpy(seg_masks))
image_dict['Graph'] = viz_graph
image_dict['GT Mask'] = util_.get_color_mask(util_.copy_to_numpy(label_img))
print("Unique labels:", torch.unique(seg_masks))

# Score the initial prediction against the ground-truth labels.
# Index 0 of orig_masks is the background mask and is excluded.
object_masks = segmentation_graph.orig_masks[1:]
score = losses.compute_graph_score(object_masks, label_img)
print(f"Predicted mask score: {score:.05f}")

# Show all panels side by side
util_.gallery(image_dict, width='230px')

# Load Networks

You can find the models for RICE [here](https://github.com/chrisdxie/rice).

In [None]:
# --- SplitNet ---------------------------------------------------------------
sn_filename = ''  # TODO: add path to saved model
splitnet_config = nc.get_splitnet_config('configs/splitnet.yaml')
sn_wrapper = msn.SplitNetWrapper(splitnet_config)
sn_wrapper.load(sn_filename)

# --- MergeNet ---------------------------------------------------------------
# Built on the SplitNet model loaded above: it reuses a copy of the SplitNet
# config with the loaded model attached under 'splitnet_model'.
merge_net_config = splitnet_config.copy()
merge_net_config['splitnet_model'] = sn_wrapper.model
mn_wrapper = msn.MergeBySplitWrapper(merge_net_config)

# --- DeleteNet --------------------------------------------------------------
delnet_filename = ''  # TODO: add path to saved model
deletenet_config = nc.get_deletenet_config('configs/deletenet.yaml')
dn_wrapper = delnet.DeleteNetWrapper(deletenet_config)
dn_wrapper.load(delnet_filename)

# --- SGS-Net ----------------------------------------------------------------
sgsnet_filename = ''  # TODO: add path to saved model
sgsnet_config = nc.get_sgsnet_config('configs/sgsnet.yaml')
sgsnet_wrapper = gn.SGSNetWrapper(sgsnet_config)
sgsnet_wrapper.load(sgsnet_filename)

# Run SampleTreeCEM

Note: run these cells a few times to obtain different outputs since RICE is stochastic.

In [None]:
# Load RICE
# yaml.safe_load is used instead of yaml.load(f): calling yaml.load without an
# explicit Loader is deprecated in PyYAML 5.1+ and raises a TypeError in
# PyYAML 6, and safe_load refuses to construct arbitrary Python objects.
with open('configs/rice.yaml', 'r') as f:
    rice_config = yaml.safe_load(f)

# The three sampling operators RICE uses to perturb a segmentation graph
sample_operator_networks = {
    'mergenet_wrapper' : mn_wrapper,
    'splitnet_wrapper' : sn_wrapper,
    'deletenet_wrapper' : dn_wrapper,
}
rice = stc.SampleTreeCEMWrapper(
    rn50_fpn,
    sample_operator_networks,
    sgsnet_wrapper,
    rice_config,
)

In [None]:
# Run RICE!
# Same image as before, now accompanied by the initial masks and the foreground mask.
batch = dict(
    rgb=rgb_img.unsqueeze(0),
    xyz=xyz_img.unsqueeze(0),
    seg_masks=seg_masks.unsqueeze(0),
    fg_mask=fg_mask.unsqueeze(0),
)
sample_tree = rice.run_on_batch(batch, verbose=True)  # this is where the magic happens!

# Get GT scores
def _get_gt_score(graph):
    """Score one candidate graph against the ground-truth label image."""
    object_masks = graph.orig_masks[1:]  # Drop BG mask
    return losses.compute_graph_score(object_masks, label_img)

# Map every node id in the sample tree to the GT score of its graph
id_to_gt_score = util_.parallel_map_dict(
    _get_gt_score,
    {node.id : node.graph for node in sample_tree.all_nodes()},
)

# The stdlib garbage collector is imported under an alias because `gc` is
# already bound to src.graph_construction in this notebook.
import gc as garbage_collection
garbage_collection.collect()
torch.cuda.empty_cache()

In [None]:
# Plot sample tree: one gallery row per tree depth
all_nodes = sample_tree.all_nodes()  # fetched once; reused for every depth
for depth in range(sample_tree.max_depth() + 1):

    nodes_at_depth = [node for node in all_nodes
                      if node.depth() == depth]
    image_dict = OrderedDict()

    for node in nodes_at_depth:
        # Title shows the node id, its ground-truth score and the SGS-Net predicted score
        title = (node.id +
                 f' GT: {id_to_gt_score[node.id].item():0.3f}.' +
                 f' Pred: {node.graph.sgs_net_score.item():0.3f}')
        # Draw on the RGB image as loaded from disk, like every other
        # visualize_graph call in this notebook (rgb_img is the standardized
        # network-input tensor, not a displayable image).
        image_dict[title] = visualize_graph.visualize_graph(orig_rgb_imgs[img_index],
                                                            node.graph,
                                                            mode='seg_graph_on_rgb')

    util_.gallery(image_dict, width='225px')

In [None]:
# Plot RICE outputs (best graph and uncertainty estimates)
# Scores are read from the nodes themselves, so the chosen node always matches
# its score without relying on all_graphs() and all_nodes() sharing an order.
# float() is applied per score before building the array.
# NOTE(review): sgs_net_score is read with .item() elsewhere in this notebook,
# so it is presumably a single-element tensor - confirm.
all_nodes = sample_tree.all_nodes()
scores = np.array([float(node.graph.sgs_net_score) for node in all_nodes])
best_node = all_nodes[int(np.argmax(scores))]

# Print some stuff
print(f"Best graph: {best_node.id}, "
      f"GT score: {id_to_gt_score[best_node.id]:.05f}, "
      f"Best GT score: {max(id_to_gt_score.values()):.05f}, "
      f"score: {best_node.graph.sgs_net_score:.05f}")

# Generate some nice images
base_prediction_img = visualize_graph.visualize_graph(orig_rgb_imgs[img_index],
                                                      segmentation_graph,
                                                      mode='seg_graph_on_rgb')
prediction_img = visualize_graph.visualize_graph(orig_rgb_imgs[img_index],
                                                 best_node.graph,
                                                 mode='seg_graph_on_rgb')
contour_mean, contour_std = rice.contour_uncertainties(sample_tree)
contour_img = util_.visualize_contour_img(contour_mean,
                                          contour_std,
                                          orig_rgb_imgs[img_index])

# Display
to_plot = {
    'RGB' : orig_rgb_imgs[img_index],
    'Original graph': base_prediction_img,
    'Best graph': prediction_img,
    'Contour Mean/Stddev' : contour_img,
}
util_.gallery(to_plot, width='300px')