# Imports

In [None]:
%matplotlib inline
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = "5"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # for debugging GPU stuff
import time, random

from PIL import Image

import open3d  # Need to import this before torch
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt

from importlib import reload


sys.path.append('/home/chrisxie/projects/ssc/')
import util.flowlib as flowlib

In [None]:
# Config file
# Parse the PointGroup YAML config for notebook use.
# NOTE(review): get_parser_notebook presumably populates util.config.cfg, so it
# must stay before the `from util.config import cfg` line below -- do not reorder.
cfg_file = '/home/chrisxie/local_installations/PointGroup/config/pointgroup_TOD.yaml'
from util.config import get_parser_notebook
get_parser_notebook(cfg_file=cfg_file, pretrain_path=None)

from util.config import cfg
from util.log import logger
import util.utils as utils  # only referenced by the commented-out checkpoint_restore call below

# Testing function definitions

In [None]:
def init():
    """Prepare global test state: semantic label ids, config logging, RNG seeds.

    Side effects: binds the module-level ``semantic_label_idx`` (read later when
    mapping per-point semantic predictions to per-proposal semantic ids), logs
    ``cfg``, and seeds Python's ``random``, NumPy and torch (CPU and all CUDA
    devices) with ``cfg.test_seed``.
    """
    # TODO(chrisdxie): Set directory to write results to: ~/projects/ssc/external
    # (The result/backup-directory setup of the original PointGroup test script is
    # not performed here; the test loop builds its own save_dir.)
    global semantic_label_idx
    semantic_label_idx = list(range(3))  # i.e. [0, 1, 2]

    logger.info(cfg)

    seed = cfg.test_seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

In [None]:
def non_max_suppression(ious, scores, threshold):
    """Greedy non-maximum suppression over a precomputed IoU matrix.

    Repeatedly keeps the highest-scoring remaining candidate and discards every
    other remaining candidate whose IoU with it is strictly greater than
    ``threshold`` (NaN IoUs never suppress anything).

    Args:
        ious: (n, n) array of pairwise IoUs.
        scores: (n,) array of candidate scores.
        threshold: IoU above which a lower-scoring candidate is dropped.

    Returns:
        int32 array of kept candidate indices, in descending score order.
    """
    order = scores.argsort()[::-1]
    keep = []
    while len(order):
        best, rest = order[0], order[1:]
        keep.append(best)
        overlaps = ious[best, rest]
        # Keep survivors with ~(iou > threshold), not (iou <= threshold), so NaN stays.
        order = rest[~(overlaps > threshold)]
    return np.array(keep, dtype=np.int32)

In [None]:
# Some utils


import functools


@functools.lru_cache(maxsize=None)
def _load_palette(palette_abspath):
    """Load an (N, 3) uint8 palette from a text file; cached per path.

    The cache means the file is read once per session instead of once per saved
    image; edits to the file after the first load are not picked up.
    """
    palette = np.loadtxt(palette_abspath, dtype=np.uint8).reshape(-1, 3)
    palette.setflags(write=False)  # shared through the cache: guard against mutation
    return palette


def imwrite_indexed(filename, array,
                    palette_abspath='/home/chrisxie/projects/random_stuff/palette.txt'):
    """Save indexed png with palette.

    Args:
        filename: output path; always written in PNG format.
        array: 2D array of palette indices. Callers in this notebook pass uint8,
            which PIL needs for a palette image.
        palette_abspath: text file with one "R G B" row per palette entry.
            Defaults to the original hard-coded filepath.

    Raises:
        Exception: if ``array`` is not 2D. This is checked before the palette
            file is read.
    """
    if np.atleast_3d(array).shape[2] != 1:
        raise Exception("Saving indexed PNGs requires 2D array.")

    color_palette = _load_palette(palette_abspath)

    im = Image.fromarray(array)
    im.putpalette(color_palette.ravel())
    im.save(filename, format='PNG')


def normalize(x):
    """Min-max scale ``x`` to [0, 1]; works for numpy arrays and torch tensors.

    A constant input (max == min) is returned as all zeros instead of the NaNs
    that 0/0 would produce. NaNs already present in ``x`` still propagate.
    """
    shifted = x - x.min()
    value_range = x.max() - x.min()
    if value_range == 0:
        # Multiply by 0.0 to keep the shape and get a float result, as the division would.
        return shifted * 0.0
    return shifted / value_range


def inpaint_with_valid_mask(values, valid_mask, H=480, W=640):
    """Use valid_mask to inpaint per-point values back into an H x W image.

    Args:
        values: (nValid,) or (nValid, C) numpy array or torch tensor. Tensors may
            live on any device and may require grad; they are detached before
            conversion. Row k is written to the k-th pixel selected by valid_mask.
        valid_mask: index into the flattened H*W pixel grid. In this notebook it
            is a length-H*W slice of batch['valid_mask'] (presumably boolean).
        H, W: output image size.

    Returns:
        float32 array of shape (H, W) for single-channel values, else (H, W, C).
        Pixels not selected by valid_mask are 0.
    """
    if torch.is_tensor(values):
        # detach() so tensors that require grad can be converted too
        values = values.detach().cpu().numpy()

    if values.ndim == 1:
        values = values[..., None]
    channels = values.shape[-1]

    img = np.zeros((H * W, channels), dtype=np.float32)
    img[valid_mask] = values

    if channels == 1:  # Return a [H, W] image instead of [H, W, 1]
        return img.reshape((H, W))
    return img.reshape((H, W, channels))

    
def inpaint_cluster_img(clusters, valid_mask, H=480, W=640):
    """Paint binary cluster masks into one H x W uint8 instance-id image.

    Args:
        clusters: (nCluster, nValid) 0/1 membership array or torch tensor. Each
            row is one cluster over the points selected by valid_mask.
        valid_mask: pixel selector passed to inpaint_with_valid_mask.
        H, W: output image size; forwarded to inpaint_with_valid_mask.

    Returns:
        (H, W) uint8 image. Unassigned pixels are 0, and cluster row k gets id k + 2.

    Raises:
        AssertionError: if two clusters cover the same pixel.
        ValueError: if there are too many clusters for uint8 ids.
    """
    first_object_id = 2  # OBJECT starts here (ids 0/1 are never assigned to clusters)
    if torch.is_tensor(clusters):
        # One device->host copy for all clusters instead of one per cluster.
        clusters = clusters.detach().cpu().numpy()

    num_clusters = clusters.shape[0]
    max_id = np.iinfo(np.uint8).max
    if first_object_id + num_clusters - 1 > max_id:
        raise ValueError(
            f"{num_clusters} clusters do not fit in uint8 instance ids (max {max_id})")

    instances_img = np.zeros((H, W), dtype=np.uint8)
    for offset in range(num_clusters):
        cluster_id = first_object_id + offset
        temp = inpaint_with_valid_mask(clusters[offset], valid_mask, H=H, W=W).astype(np.uint8)
        assert np.all(~np.logical_and(instances_img > 0, temp > 0)), 'uh oh, the clusters overlap'
        instances_img += temp * cluster_id
    return instances_img
    
    
def subplotter(images, max_plots_per_row=4, fig_index_start=1):
    """Plot images side by side, max_plots_per_row per figure.

    Args:
        images: an Iterable of [H, W, C] (or [H, W]) np.arrays. If images is a
            dict (any Mapping subclass such as OrderedDict also counts), the
            values are the arrays and the keys are used as subplot titles.
        max_plots_per_row: subplots per figure; each row is its own figure.
        fig_index_start: matplotlib figure number of the first row.
    """
    has_titles = isinstance(images, dict)
    if has_titles:
        entries = list(images.items())
    else:
        # Materialize so generators and other one-shot iterables work too.
        entries = [(None, image) for image in images]

    num_rows = int(np.ceil(len(entries) / max_plots_per_row))
    for row in range(num_rows):
        plt.figure(fig_index_start + row, figsize=(max_plots_per_row * 5, 5))
        row_entries = entries[row * max_plots_per_row:(row + 1) * max_plots_per_row]
        for col, (title, image) in enumerate(row_entries):
            plt.subplot(1, max_plots_per_row, col + 1)
            if has_titles:
                plt.title(title)
            plt.imshow(image)


# Testing script

In [None]:
##### init
# Log the config, bind the global semantic_label_idx, and seed random/numpy/torch.
init()

In [None]:
##### get model version and data version
# The experiment name is the config file name without its extension (e.g.
# "pointgroup_TOD"), following the "<model>_<data>" naming convention.
# os.path.splitext (rather than chopping 5 characters) also handles ".yml" configs.
exp_name = os.path.splitext(os.path.basename(cfg.config))[0]
print(exp_name)
name_parts = exp_name.split('_')
model_name = name_parts[0]
print(model_name)
data_name = name_parts[1]
print(data_name)

In [None]:
##### model
# Build the network on the GPU. Failures raise instead of calling exit(0):
# exit(0) ends the session (in Jupyter it shuts the kernel down) and reports
# success, whereas a raise stops "Run All" with a visible error.
logger.info('=> creating model ...')
if model_name == 'pointgroup':
    from model.pointgroup.pointgroup import PointGroup as Network
    from model.pointgroup.pointgroup import model_fn_decorator
else:
    raise NotImplementedError("Error: no model - " + model_name)
model = Network(cfg)

use_cuda = torch.cuda.is_available()
logger.info('cuda available: {}'.format(use_cuda))
if not use_cuda:
    # explicit check: a bare assert disappears under `python -O`
    raise RuntimeError('CUDA is required to run this notebook, but it is not available.')
model = model.cuda()

# logger.info(model)
logger.info('#classifier parameters: {}'.format(sum(p.nelement() for p in model.parameters())))

In [None]:
##### model_fn (criterion)
# Forward function, called below as model_fn(batch, model, num_iters_trained);
# test=True presumably selects test-time behavior -- confirm in model_fn_decorator.
model_fn = model_fn_decorator(test=True)

##### load model
# NOTE(review): checkpoint restoration is commented out, and no other visible cell
# loads weights into `model`. Unless Network(cfg) loads weights itself, the test
# loop runs on freshly initialized weights -- confirm before trusting saved results.
# utils.checkpoint_restore(model,
#                          cfg.exp_path,
#                          cfg.config.split('/')[-1][:-5],
#                          use_cuda,
#                          cfg.test_epoch)
# resume from the latest epoch, or specify the epoch to restore

In [None]:
##### dataset
# Import (and reload, to pick up local edits) data.<name> and build its test loader.
# The config's dataset and the name parsed from the config file must agree.
# Unsupported or mismatched combinations now raise. Previously the OSD branch and
# unknown cfg.dataset values fell through to a NameError on `dataset`.
import importlib

SUPPORTED_DATASETS = ('TOD', 'OCID', 'OSD')
if cfg.dataset not in SUPPORTED_DATASETS:
    raise NotImplementedError(f"Unsupported cfg.dataset: {cfg.dataset}")
if data_name != cfg.dataset:
    raise NotImplementedError("Error: no data loader - " + data_name)

dataset_module = reload(importlib.import_module(f'data.{cfg.dataset}'))
dataset = dataset_module.Dataset()
dataset.testLoader()

In [None]:
#### Test loop
# Run the model over the whole test split and write one indexed PNG of
# predicted instance ids per input image under `save_dir`.
from tqdm.notebook import tqdm

H = 480; W = 640
save_dir = '/home/chrisxie/projects/ssc/external/'
if data_name == 'OCID':
    save_dir = os.path.join(save_dir, 'OCID_results')
elif data_name == 'OSD':
    save_dir = os.path.join(save_dir, 'OSD_results')
elif data_name == 'TOD':
    save_dir = os.path.join(save_dir, 'TODv5_results', 'test_set')
else:
    raise NotImplementedError(f"Testing on {data_name} not implemented...")
save_dir = os.path.join(save_dir, 'PointGroup')


def _select_clusters(preds, N):
    """Filter the model's proposals down to the final instance masks.

    Steps, in order:
        1. sigmoid score > cfg.TEST_SCORE_THRESH
        2. point count > cfg.TEST_NPOINT_THRESH
        3. IoU-based NMS with cfg.TEST_NMS_THRESH

    Args:
        preds: model_fn output dict with 'semantic', 'score' and 'proposals'.
        N: total number of points in the batch.

    Returns:
        (clusters, cluster_scores, cluster_semantic_id):
            clusters: (nCluster, N) int 0/1 masks.
            cluster_scores: (nCluster,) scores.
            cluster_semantic_id: (nCluster,) semantic ids.
        All are on the device of the scores.
    """
    semantic_pred = preds['semantic'].max(1)[1]  # (N) long
    scores_pred = torch.sigmoid(preds['score'].view(-1))  # (nProposal,)
    device = scores_pred.device

    # proposals_idx: (sumNPoint, 2), dim 0 = cluster_id, dim 1 = point idx in N
    # proposals_offset: (nProposal + 1)
    proposals_idx, proposals_offset = preds['proposals']
    proposals_pred = torch.zeros((proposals_offset.shape[0] - 1, N), dtype=torch.int, device=device)
    proposals_pred[proposals_idx[:, 0].long(), proposals_idx[:, 1].long()] = 1

    # Each proposal's semantic id comes from the prediction at its first listed point.
    first_points = proposals_idx[:, 1][proposals_offset[:-1].long()].long()
    semantic_id = torch.tensor(semantic_label_idx, device=device)[semantic_pred[first_points]]

    keep = scores_pred > cfg.TEST_SCORE_THRESH
    scores_pred, proposals_pred, semantic_id = scores_pred[keep], proposals_pred[keep], semantic_id[keep]

    keep = proposals_pred.sum(1) > cfg.TEST_NPOINT_THRESH
    scores_pred, proposals_pred, semantic_id = scores_pred[keep], proposals_pred[keep], semantic_id[keep]

    if scores_pred.shape[0] == 0:
        pick_idxs = np.empty(0, dtype=np.int64)
    else:
        masks = proposals_pred.float()  # (nProposal, N)
        intersection = masks @ masks.t()  # (nProposal, nProposal)
        pointnum = masks.sum(1)  # (nProposal,)
        cross_ious = intersection / (pointnum[:, None] + pointnum[None, :] - intersection)
        pick_idxs = non_max_suppression(cross_ious.cpu().numpy(), scores_pred.cpu().numpy(), cfg.TEST_NMS_THRESH)
    pick_idxs = torch.as_tensor(pick_idxs, dtype=torch.long, device=device)
    return proposals_pred[pick_idxs], scores_pred[pick_idxs], semantic_id[pick_idxs]


def _save_prediction(instances_img, label_path, out_root):
    """Write instances_img as <out_root>/<dir of label_path>/<label stem>.png."""
    # Strip a leading '/': for an absolute label path, os.path.join would discard
    # out_root and write next to (possibly over) the ground-truth labels.
    rel_dir, file_base = os.path.split(label_path.lstrip('/'))
    out_dir = os.path.join(out_root, rel_dir)
    os.makedirs(out_dir, exist_ok=True)
    file_name = os.path.join(out_dir, file_base.rsplit('.', 1)[0] + '.png')
    imwrite_indexed(file_name, instances_img.astype(np.uint8))


# num_iters_trained = cfg.test_epoch * len(dataset.train_data_loader)
num_iters_trained = 300000  # Force it to cluster
with torch.no_grad():
    model = model.eval()
    for batch in tqdm(dataset.test_data_loader):
        N = batch['feats'].shape[0]
        preds = model_fn(batch, model, num_iters_trained)
        clusters, _, _ = _select_clusters(preds, N)
        # One device->host copy for all clusters instead of one per cluster per image.
        clusters = clusters.cpu().numpy()

        offsets = batch['offsets'].tolist()  # point ranges per image: [offsets[k], offsets[k+1])
        label_paths = batch['label_abs_path']
        num_imgs = len(offsets) - 1
        if len(label_paths) != num_imgs:
            raise ValueError(f"batch has {num_imgs} images but {len(label_paths)} label paths")

        valid_masks = batch['valid_mask'].numpy()  # flat, H*W entries per image
        for img_num in range(num_imgs):
            valid_mask = valid_masks[img_num * H * W:(img_num + 1) * H * W]
            instances_img = inpaint_cluster_img(
                clusters[:, offsets[img_num]:offsets[img_num + 1]], valid_mask)
            _save_prediction(instances_img, label_paths[img_num], save_dir)

# Legacy Code: single-batch inspection and visualization

In [None]:
# Legacy: iterate the test loader manually to inspect one batch at a time.
d_iter = iter(dataset.test_data_loader)

In [None]:
# Fetch the next test batch; re-run this cell to step to another batch.
batch = next(d_iter)

In [None]:
img_num = 0  # which image of the batch to inspect

H = 480; W = 640
# Points of image `img_num` occupy [first_pt, last_pt) in the per-point tensors.
first_pt, last_pt = batch['offsets'][img_num], batch['offsets'][img_num + 1]
# valid_mask is stored flat, H*W entries per image.
valid_mask = batch['valid_mask'][img_num * H * W:(img_num + 1) * H * W].numpy()
# plt.imshow(valid_mask.reshape(H,W), vmin=0, vmax=1)
rgb_vals = batch['feats'][first_pt:last_pt]
rgb_img = inpaint_with_valid_mask(rgb_vals, valid_mask)

plot_dict = {'rgb': normalize(rgb_img)}
if 'instance_labels' in batch:
    gt_instances_img = inpaint_with_valid_mask(batch['instance_labels'][first_pt:last_pt], valid_mask)
    print(np.unique(gt_instances_img))
    # Map ignored points (cfg.ignore_label) to 0 for display.
    gt_instances_img[gt_instances_img == cfg.ignore_label] = 0
    plot_dict['instance_labels'] = gt_instances_img

subplotter(plot_dict)

In [None]:
# Single-batch forward pass for inspection.
# num_iters_trained = cfg.test_epoch * len(dataset.train_data_loader)
num_iters_trained = 300000  # large value to force clustering (see the test loop's note)
# Eval mode must be set before the forward pass. The model.eval() in the next
# cell runs after `preds` is computed, so it cannot affect these predictions.
model.eval()
with torch.no_grad():
    start1 = time.time()
    preds = model_fn(batch, model, num_iters_trained)
    end1 = time.time() - start1

In [None]:
# Legacy: post-process `preds` for the single inspected batch, using the same steps
# as the test loop: score threshold -> point-count threshold -> NMS.
# Globals set here (N, semantic_pred, pt_offsets, clusters, ...) are used by the
# visualization cells below.
# NOTE(review): this model.eval() runs after `preds` was computed in the previous
# cell, so it does not affect these predictions; it only sets the mode for later calls.
model.eval()
N = batch['feats'].shape[0]  # total number of points in the batch

##### get predictions (#1 semantic_pred, pt_offsets; #2 scores, proposals_pred)
semantic_scores = preds['semantic']  # (N, nClass) float32, cuda
semantic_pred = semantic_scores.max(1)[1]  # (N) long, cuda

pt_offsets = preds['pt_offsets']    # (N, 3), float32, cuda

scores = preds['score']   # (nProposal, 1) float, cuda
scores_pred = torch.sigmoid(scores.view(-1))

proposals_idx, proposals_offset = preds['proposals']
# proposals_idx: (sumNPoint, 2), int, cpu, dim 0 for cluster_id, dim 1 for corresponding point idxs in N
# proposals_offset: (nProposal + 1), int, cpu
proposals_pred = torch.zeros((proposals_offset.shape[0] - 1, N), dtype=torch.int, device=scores_pred.device) # (nProposal, N), int, cuda
proposals_pred[proposals_idx[:, 0].long(), proposals_idx[:, 1].long()] = 1

# Semantic id per proposal, taken from the prediction at its first listed point.
semantic_id = torch.tensor(semantic_label_idx, device=scores_pred.device)[semantic_pred[proposals_idx[:, 1][proposals_offset[:-1].long()].long()]] # (nProposal), long

##### score threshold
score_mask = (scores_pred > cfg.TEST_SCORE_THRESH)
scores_pred = scores_pred[score_mask]
proposals_pred = proposals_pred[score_mask]
semantic_id = semantic_id[score_mask]

##### npoint threshold
proposals_pointnum = proposals_pred.sum(1)
npoint_mask = (proposals_pointnum > cfg.TEST_NPOINT_THRESH)
scores_pred = scores_pred[npoint_mask]
proposals_pred = proposals_pred[npoint_mask]
semantic_id = semantic_id[npoint_mask]

##### nms
if semantic_id.shape[0] == 0:
    # NOTE(review): np.empty(0) is float64 and is used as an index below;
    # np.empty(0, dtype=np.int64) would make the intent explicit.
    pick_idxs = np.empty(0)
else:
    proposals_pred_f = proposals_pred.float()  # (nProposal, N), float, cuda
    intersection = torch.mm(proposals_pred_f, proposals_pred_f.t())  # (nProposal, nProposal), float, cuda
    proposals_pointnum = proposals_pred_f.sum(1)  # (nProposal), float, cuda
    # Pairwise IoU = |A & B| / (|A| + |B| - |A & B|)
    proposals_pn_h = proposals_pointnum.unsqueeze(-1).repeat(1, proposals_pointnum.shape[0])
    proposals_pn_v = proposals_pointnum.unsqueeze(0).repeat(proposals_pointnum.shape[0], 1)
    cross_ious = intersection / (proposals_pn_h + proposals_pn_v - intersection)
    pick_idxs = non_max_suppression(cross_ious.cpu().numpy(), scores_pred.cpu().numpy(), cfg.TEST_NMS_THRESH)  # (nCluster,) int32 indices of kept proposals
clusters = proposals_pred[pick_idxs]
cluster_scores = scores_pred[pick_idxs]
cluster_semantic_id = semantic_id[pick_idxs]

In [None]:
# Inpaint the per-point predictions of image `img_num` back into H x W images.
pts_begin, pts_end = batch['offsets'][img_num], batch['offsets'][img_num + 1]
point_range = slice(pts_begin, pts_end)
pt_offsets_img = inpaint_with_valid_mask(pt_offsets[point_range], valid_mask)
semantic_pred_img = inpaint_with_valid_mask(semantic_pred[point_range], valid_mask)
instances_img = inpaint_cluster_img(clusters[:, point_range], valid_mask)

In [None]:
def _num_instances(label_img):
    """Count distinct labels in label_img other than 0 and 1."""
    return sum(1 for label in np.unique(label_img) if label not in [0, 1])


num_pred_instances = _num_instances(instances_img)
# NOTE(review): semantic class 2 is presumably the object class -- confirm.
fg_mask = semantic_pred_img == 2
masked_offsets = pt_offsets_img * fg_mask[..., None]  # zero offsets outside the foreground

plot_dict = {}
plot_dict['Point Offsets'] = flowlib.flow_to_image(masked_offsets)
plot_dict['Semantic Prediction'] = semantic_pred_img
plot_dict[f"Instances Prediction. Num: {num_pred_instances}"] = instances_img
if 'instance_labels' in batch:
    num_gt_instances = _num_instances(gt_instances_img)
    plot_dict[f"Instances GT. Num: {num_gt_instances}"] = gt_instances_img

subplotter(plot_dict, max_plots_per_row=3)

In [None]:
# One binary mask per label in instances_img, skipping the smallest label
# (0 = unassigned whenever any pixel has no cluster).
mask_labels = np.unique(instances_img)[1:]
subplotter([instances_img == label for label in mask_labels])