In [1]:
"""Evaluate RPMNet. Also contains functionality to compute evaluation metrics given transforms

Example Usages:
    1. Evaluate RPMNet
        python eval.py --noise_type crop --resume [path-to-model.pth]

    2. Evaluate precomputed transforms (.npy file containing np.array of size (B, 3, 4) or (B, n_iter, 3, 4))
        python eval.py --noise_type crop --transform_file [path-to-transforms.npy]
"""
from collections import defaultdict
import json
import os
import pickle
import time
from typing import Dict, List

import numpy as np
import open3d  # Need to import before torch
import pandas as pd
from scipy import sparse
from tqdm import tqdm
import torch

from arguments import rpmnet_eval_arguments
from common.misc import prepare_logger
from common.torch import dict_all_to_device, CheckPointManager, to_numpy
from common.math import se3
from common.math_torch import se3
from common.math.so3 import dcm2euler
from data_loader.datasets import get_test_datasets
import models.rpmnet


def compute_metrics(data: Dict, pred_transforms) -> Dict:
    """Compute the per-instance registration metrics reported in the paper.

    Args:
        data: Mini-batch dict. Reads 'transform_gt', 'points_src', 'points_ref'
            and 'points_raw'; only the first three channels of each point
            tensor are used.
        pred_transforms: Predicted transforms for the batch, indexed as
            [instance, row, col] with rotation in [:, :3, :3] and translation
            in [:, :3, 3].

    Returns:
        Dict of per-instance arrays with keys 'r_mse', 'r_mae', 't_mse',
        't_mae', 'err_r_deg', 'err_t' and 'chamfer_dist'.
    """

    def pairwise_sq_dist(pts_a, pts_b):
        # Squared distance between every point in pts_a and every point in pts_b
        delta = pts_a[:, :, None, :] - pts_b[:, None, :, :]
        return torch.sum(delta ** 2, dim=-1)

    with torch.no_grad():
        gt_transforms = data['transform_gt']
        xyz_src = data['points_src'][..., :3]
        xyz_ref = data['points_ref'][..., :3]
        xyz_raw = data['points_raw'][..., :3]

        # Per-axis Euler angle and translation errors (Deep Closest Point convention)
        # TODO Change rotation to torch operations
        euler_gt = dcm2euler(gt_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        euler_pred = dcm2euler(pred_transforms[:, :3, :3].detach().cpu().numpy(), seq='xyz')
        euler_diff = euler_gt - euler_pred
        r_mse = np.mean(euler_diff ** 2, axis=1)
        r_mae = np.mean(np.abs(euler_diff), axis=1)

        trans_diff = gt_transforms[:, :3, 3] - pred_transforms[:, :3, 3]
        t_mse = torch.mean(trans_diff ** 2, dim=1)
        t_mae = torch.mean(torch.abs(trans_diff), dim=1)

        # Isotropic errors: rotation angle (degrees) and translation magnitude
        # of the residual transform between ground truth and prediction.
        residual = se3.concatenate(se3.inverse(gt_transforms), pred_transforms)
        trace = residual[:, 0, 0] + residual[:, 1, 1] + residual[:, 2, 2]
        # Clamp guards acos against values marginally outside [-1, 1]
        cos_angle = torch.clamp(0.5 * (trace - 1), min=-1.0, max=1.0)
        residual_rotdeg = torch.acos(cos_angle) * 180.0 / np.pi
        residual_transmag = residual[:, :, 3].norm(dim=-1)

        # Modified Chamfer distance: both directions are measured against
        # clouds derived from 'points_raw' rather than the network inputs.
        src_pred = se3.transform(pred_transforms, xyz_src)
        raw_pred = se3.transform(se3.concatenate(pred_transforms, se3.inverse(gt_transforms)), xyz_raw)
        dist_src = torch.min(pairwise_sq_dist(src_pred, xyz_raw), dim=-1)[0]
        dist_ref = torch.min(pairwise_sq_dist(xyz_ref, raw_pred), dim=-1)[0]
        chamfer_dist = torch.mean(dist_src, dim=1) + torch.mean(dist_ref, dim=1)

        metrics = {
            'r_mse': r_mse,
            'r_mae': r_mae,
            't_mse': to_numpy(t_mse),
            't_mae': to_numpy(t_mae),
            'err_r_deg': to_numpy(residual_rotdeg),
            'err_t': to_numpy(residual_transmag),
            'chamfer_dist': to_numpy(chamfer_dist),
        }

    return metrics


def summarize_metrics(metrics):
    """Reduce per-instance metrics to scalar summaries over all instances.

    Keys ending in 'mse' are reported as a root-mean value under the same name
    with 'mse' replaced by 'rmse'. Keys starting with 'err' yield both a
    '<key>_mean' and a '<key>_rmse' entry. Every other key is averaged and
    keeps its name.
    """
    summary = {}
    for name, values in metrics.items():
        if name.endswith('mse'):
            summary[name[:-3] + 'rmse'] = np.sqrt(np.mean(values))
            continue
        if name.startswith('err'):
            summary[name + '_mean'] = np.mean(values)
            summary[name + '_rmse'] = np.sqrt(np.mean(values ** 2))
            continue
        summary[name] = np.mean(values)

    return summary


def print_metrics(logger, summary_metrics: Dict, losses_by_iteration: List = None,
                  title: str = 'Metrics'):
    """Write the summary metrics to `logger` as formatted info lines.

    Args:
        logger: Object exposing an `info(str)` method.
        summary_metrics: Must contain 'r_rmse', 'r_mae', 't_rmse', 't_mae',
            'err_r_deg_mean', 'err_r_deg_rmse', 'err_t_mean', 'err_t_rmse'
            and 'chamfer_dist'.
        losses_by_iteration: Optional per-iteration losses; logged on one line
            when given.
        title: Heading printed above the metrics.
    """
    heading = title + ':'
    logger.info(heading)
    logger.info('=' * len(heading))  # underline as wide as the heading

    if losses_by_iteration is not None:
        formatted_losses = ('{:.5f}'.format(loss) for loss in losses_by_iteration)
        logger.info('Losses by iteration: {}'.format(' | '.join(formatted_losses)))

    sm = summary_metrics

    deepcp_fmt = 'DeepCP metrics:{:.4f}(rot-rmse) | {:.4f}(rot-mae) | {:.4g}(trans-rmse) | {:.4g}(trans-mae)'
    logger.info(deepcp_fmt.format(sm['r_rmse'], sm['r_mae'], sm['t_rmse'], sm['t_mae']))

    rot_fmt = 'Rotation error {:.4f}(deg, mean) | {:.4f}(deg, rmse)'
    logger.info(rot_fmt.format(sm['err_r_deg_mean'], sm['err_r_deg_rmse']))

    trans_fmt = 'Translation error {:.4g}(mean) | {:.4g}(rmse)'
    logger.info(trans_fmt.format(sm['err_t_mean'], sm['err_t_rmse']))

    logger.info('Chamfer error: {:.7f}(mean-sq)'.format(sm['chamfer_dist']))


def inference(data_loader, model: torch.nn.Module):
    """Runs inference over entire dataset

    Args:
        data_loader (torch.utils.data.DataLoader): Dataset loader
        model (model.nn.Module): Network model to evaluate

    Returns:
        pred_transforms_all: predicted transforms (B, n_iter, 3, 4) where B is total number of instances
        endpoints_out (Dict): Network endpoints. Only 'perm_matrices' is populated (when the model
            returns it): one list of sparse matrices (one per iteration) per data instance.
    """

    _logger.info('Starting inference...')
    model.eval()

    pred_transforms_all = []
    total_time = 0.0
    endpoints_out = defaultdict(list)
    total_rotation = []

    with torch.no_grad():
        for val_data in tqdm(data_loader):

            # Rotation angle (degrees) of the ground-truth transform, for the range statistics below
            gt_transforms = val_data['transform_gt']
            rot_trace = gt_transforms[:, 0, 0] + gt_transforms[:, 1, 1] + gt_transforms[:, 2, 2]
            rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi
            total_rotation.append(np.abs(to_numpy(rotdeg)))

            dict_all_to_device(val_data, _device)
            # NOTE(review): wall-clock time around the forward call only; no explicit device
            # synchronisation is done here, so GPU timings are presumably approximate.
            time_before = time.time()
            pred_transforms, endpoints = model(val_data, _args.num_reg_iter)
            total_time += time.time() - time_before

            # The per-batch 'beta'/'alpha' endpoints used to be appended to lists that were
            # never read, which only kept every batch's tensors alive until the end of the run.

            # The model returns one transform per registration iteration; stack them on axis 1
            if isinstance(pred_transforms[-1], torch.Tensor):
                pred_transforms_all.append(to_numpy(torch.stack(pred_transforms, dim=1)))
            else:
                pred_transforms_all.append(np.stack(pred_transforms, axis=1))

            # Saves match matrix. We only save the top matches to save storage/time.
            # However, this still takes quite a bit of time to save. Comment out if not needed.
            if 'perm_matrices' in endpoints:
                perm_matrices = to_numpy(torch.stack(endpoints['perm_matrices'], dim=1))
                thresh = np.percentile(perm_matrices, 99.9, axis=(2, 3))  # Only retain top 0.1% of entries
                below_thresh_mask = perm_matrices < thresh[:, :, None, None]
                perm_matrices[below_thresh_mask] = 0.0

                for i_data in range(perm_matrices.shape[0]):
                    sparse_perm_matrices = [sparse.coo_matrix(perm_matrices[i_data, i_iter, :, :])
                                            for i_iter in range(perm_matrices.shape[1])]
                    endpoints_out['perm_matrices'].append(sparse_perm_matrices)

    _logger.info('Total inference time: {}s'.format(total_time))
    total_rotation = np.concatenate(total_rotation, axis=0)
    _logger.info('Rotation range in data: {}(avg), {}(max)'.format(np.mean(total_rotation), np.max(total_rotation)))
    pred_transforms_all = np.concatenate(pred_transforms_all, axis=0)

    return pred_transforms_all, endpoints_out


def evaluate(pred_transforms, data_loader: torch.utils.data.dataloader.DataLoader):
    """Score predicted transforms against the ground truth served by `data_loader`.

    Args:
        pred_transforms: Predicted transforms as a numpy array, either
            (B, iter, 3/4, 4) or (B, 3/4, 4). The latter is treated as a
            single iteration.
        data_loader: Loader for dataset. Predictions are matched to batches
            by position, so it must yield instances in the same order the
            predictions were produced.

    Returns:
        Computed metrics (list with one dict per iteration), and the summary
        metrics of the last iteration.
    """

    _logger.info('Evaluating transforms...')

    if pred_transforms.ndim != 4:
        assert pred_transforms.ndim == 3 and \
               (pred_transforms.shape[1:] == (4, 4) or pred_transforms.shape[1:] == (3, 4))
        # Insert a singleton iteration axis
        pred_transforms = pred_transforms[:, None, :, :]
    pred_transforms = torch.from_numpy(pred_transforms).to(_device)

    num_iter = pred_transforms.shape[1]
    metrics_for_iter = [defaultdict(list) for _ in range(num_iter)]

    offset = 0  # index of the first prediction belonging to the current batch
    for data in tqdm(data_loader, leave=False):
        dict_all_to_device(data, _device)

        batch_size = data['points_src'].shape[0]
        batch_pred = pred_transforms[offset:offset + batch_size]
        for i_iter, accumulated in enumerate(metrics_for_iter):
            batch_metrics = compute_metrics(data, batch_pred[:, i_iter, :, :])
            for key, value in batch_metrics.items():
                accumulated[key].append(value)
        offset += batch_size

    for i_iter, accumulated in enumerate(metrics_for_iter):
        metrics_for_iter[i_iter] = {key: np.concatenate(chunks, axis=0)
                                    for key, chunks in accumulated.items()}
        summary_metrics = summarize_metrics(metrics_for_iter[i_iter])
        print_metrics(_logger, summary_metrics, title='Evaluation result (iter {})'.format(i_iter))

    return metrics_for_iter, summary_metrics


def save_eval_data(pred_transforms, endpoints, metrics, summary_metrics, save_pathdic):
    """Saves out the computed transforms, endpoints and metrics

    Args:
        pred_transforms: Predicted transforms, written to 'pred_transforms.npy'.
        endpoints: Dict of extra outputs. numpy arrays are written as '<key>.npy',
            everything else is pickled to '<key>.pickle'.
        metrics: Sequence of per-iteration metric dicts (each must contain 'r_mse' and
            't_mse'). Written to 'metrics.xlsx', one worksheet per iteration. The dicts
            are not modified.
        summary_metrics: Dict of scalar summaries, written to 'summary_metrics.json'.
        save_pathdic: Directory to write all outputs into (created if missing).
    """
    # The body used to read an undefined name `save_path`; bind it to the actual parameter.
    save_path = save_pathdic
    os.makedirs(save_path, exist_ok=True)

    # Save transforms
    np.save(os.path.join(save_path, 'pred_transforms.npy'), pred_transforms)

    # Save endpoints if any
    for k, value in endpoints.items():
        if isinstance(value, np.ndarray):
            np.save(os.path.join(save_path, '{}.npy'.format(k)), value)
        else:
            with open(os.path.join(save_path, '{}.pickle'.format(k)), 'wb') as fid:
                pickle.dump(value, fid)

    # Save metrics: Write each iteration to a different worksheet.
    # The context manager closes the workbook even if writing a sheet fails.
    with pd.ExcelWriter(os.path.join(save_path, 'metrics.xlsx')) as writer:
        for i_iter, iter_metrics in enumerate(metrics):
            # Report RMSE instead of MSE. Build a fresh dict (same column order as before:
            # remaining keys first, then the two RMSE columns) so the caller's data is untouched.
            sheet = {k: v for k, v in iter_metrics.items() if k not in ('r_mse', 't_mse')}
            sheet['r_rmse'] = np.sqrt(iter_metrics['r_mse'])
            sheet['t_rmse'] = np.sqrt(iter_metrics['t_mse'])
            pd.DataFrame.from_dict(sheet).to_excel(writer, sheet_name='Iter_{}'.format(i_iter + 1))

    # Save summary metrics (cast to plain floats so they are JSON serialisable)
    summary_metrics_float = {k: float(v) for k, v in summary_metrics.items()}
    with open(os.path.join(save_path, 'summary_metrics.json'), 'w') as json_out:
        json.dump(summary_metrics_float, json_out)

    _logger.info('Saved evaluation results to {}'.format(save_path))


def get_model():
    """Build the network selected by `_args.method` and load its checkpoint.

    Only 'rpmnet' is supported; any other method raises NotImplementedError.
    Requires `_args.resume` to point at the checkpoint to load.
    """
    _logger.info('Computing transforms using {}'.format(_args.method))
    if _args.method != 'rpmnet':
        raise NotImplementedError

    assert _args.resume is not None
    net = models.rpmnet.get_model(_args)
    net.to(_device)
    checkpoint_manager = CheckPointManager(os.path.join(_log_path, 'ckpt', 'models'))
    checkpoint_manager.load(_args.resume, net)
    return net


def main():
    """Evaluate on the test set: obtain transforms, score them, and save the results."""
    # Load data_loader
    test_loader = torch.utils.data.DataLoader(get_test_datasets(_args),
                                              batch_size=_args.val_batch_size, shuffle=False)

    transform_file = _args.transform_file
    if transform_file is None:
        # Feedforward transforms
        pred_transforms, endpoints = inference(test_loader, get_model())
    else:
        _logger.info('Loading from precomputed transforms: {}'.format(transform_file))
        pred_transforms = np.load(transform_file)
        endpoints = {}

    # Compute evaluation metrics
    eval_metrics, summary_metrics = evaluate(pred_transforms, data_loader=test_loader)

    save_eval_data(pred_transforms, endpoints, eval_metrics, summary_metrics, _args.eval_save_path)
    _logger.info('Finished')





In [2]:
# Parse a fixed argument list (no command line inside the notebook), then set up logging
parser = rpmnet_eval_arguments()
_args = parser.parse_args(args=["--resume=../logs/clean-trained.pth", "--num_reg_iter=10"])
_logger, _log_path = prepare_logger(_args, log_path=_args.eval_save_path)

# Select the device: CUDA only for the 'rpm'/'rpmnet' methods with a non-negative GPU index
os.environ['CUDA_VISIBLE_DEVICES'] = str(_args.gpu)
use_gpu = _args.gpu >= 0 and _args.method in ('rpm', 'rpmnet')
_device = torch.device('cuda:0') if use_gpu and torch.cuda.is_available() else torch.device('cpu')

# Load the checkpointed network; the trailing expression displays the module
model = get_model()
model.eval()

2020-06-24 03:28:02 li-Lenovo-Y520-15IKBM root[10932] INFO Command: /home/li/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py -f /run/user/1000/jupyter/kernel-cc4ba887-c871-4f2c-8e07-dc2c931eb22b.json
2020-06-24 03:28:02 li-Lenovo-Y520-15IKBM root[10932] INFO Source is from Commit c37e6873 (2020-04-25): Release of source code
2020-06-24 03:28:02 li-Lenovo-Y520-15IKBM root[10932] INFO Arguments: logdir: ../logs, dev: False, name: None, debug: False, dataset_path: ../datasets/modelnet40_ply_hdf5_2048, dataset_type: modelnet_hdf, num_points: 1024, noise_type: crop, rot_mag: 45.0, trans_mag: 0.5, partial: [0.7, 0.7], method: rpmnet, radius: 0.3, num_neighbors: 64, features: ['ppf', 'dxyz', 'xyz'], feat_dim: 96, no_slack: False, num_sk_iter: 5, num_reg_iter: 10, loss_type: mae, wt_inliers: 0.01, train_batch_size: 8, val_batch_size: 2, resume: ../logs/clean-trained.pth, gpu: 0, test_category_file: ./data_loader/modelnet40_half2.txt, transform_file: None, eval_save_path: ../eval_re

RPMNetEarlyFusion(
  (weights_net): ParameterPredictionNet(
    (prepool): Sequential(
      (0): Conv1d(4, 64, kernel_size=(1,), stride=(1,))
      (1): GroupNorm(8, 64, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
      (4): GroupNorm(8, 64, eps=1e-05, affine=True)
      (5): ReLU()
      (6): Conv1d(64, 64, kernel_size=(1,), stride=(1,))
      (7): GroupNorm(8, 64, eps=1e-05, affine=True)
      (8): ReLU()
      (9): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
      (10): GroupNorm(8, 128, eps=1e-05, affine=True)
      (11): ReLU()
      (12): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))
      (13): GroupNorm(16, 1024, eps=1e-05, affine=True)
      (14): ReLU()
    )
    (pooling): AdaptiveMaxPool1d(output_size=1)
    (postpool): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): GroupNorm(16, 512, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Linear(in_features=512, out_features

## Official test set

In [3]:
# main()
# Instead of running the full evaluation, build only the test dataset and an
# unshuffled loader so individual batches can be inspected below.
test_dataset = get_test_datasets(_args)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=_args.val_batch_size, shuffle=False)

2020-06-24 01:13:33 li-Lenovo-Y520-15IKBM root[6798] INFO Test transforms: SetDeterministic, SplitSourceRef, RandomCrop, RandomTransformSE3_euler, Resampler, RandomJitter, ShufflePoints
2020-06-24 01:13:33 li-Lenovo-Y520-15IKBM ModelNetHdf[6798] INFO Loading data from ../datasets/modelnet40_ply_hdf5_2048/test_files.txt for test
2020-06-24 01:13:33 li-Lenovo-Y520-15IKBM ModelNetHdf[6798] INFO Categories used: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39].
2020-06-24 01:13:34 li-Lenovo-Y520-15IKBM ModelNetHdf[6798] INFO Loaded 1266 test instances.


In [6]:
# Keep the second batch (index 1) of the test loader as `val_data`.
# Stop as soon as it is found instead of iterating the rest of the dataset.
# `val_data` stays '' if the loader yields fewer than two batches.
val_data = ''
for i, data in enumerate(test_loader):
    if i == 1:
        val_data = data
        break

In [7]:
# help(dict_all_to_device)
# Move the selected batch to the evaluation device
dict_all_to_device(val_data, _device)

In [8]:
# Run the network on the selected test batch without tracking gradients.
# Returns a list of predicted transforms plus the network endpoints.
model.eval()
with torch.no_grad():
    # The batch is moved to the device again here (already done in the previous cell)
    dict_all_to_device(val_data, _device)
    pred_transforms, endpoints = model(val_data, _args.num_reg_iter)

In [13]:
# Display the shapes of the reference and source point tensors of the batch
val_data['points_ref'].shape, val_data['points_src'].shape

(torch.Size([2, 717, 6]), torch.Size([2, 717, 6]))

## Load LiDAR clouds and compute the transform

In [10]:
import open3d as o3d
# Alternative input pairs tried previously:
# pcd1 = o3d.io.read_point_cloud("/home/li/Documents/pcl_tutorial/room_scan1_sub.pcd")
# pcd2 = o3d.io.read_point_cloud("/home/li/Documents/pcl_tutorial/room_scan2_sub.pcd")
# pcd1 = o3d.io.read_point_cloud("/home/li/Documents/pcl_tutorial/remaining_cloud011_sub.pcd")
# pcd2 = o3d.io.read_point_cloud("/home/li/Documents/pcl_tutorial/remaining_cloud012_sub.pcd")
# Read the two clouds to register (pcd1 is used as reference, pcd2 as source below)
pcd1 = o3d.io.read_point_cloud("/home/li/car1_map_sub.pcd")
pcd2 = o3d.io.read_point_cloud("/home/li/car2_map_sub.pcd")


# Point coordinates as arrays, and the centroid of each cloud (used to centre the clouds)
pc1 = np.asarray(pcd1.points)
pc2 = np.asarray(pcd2.points)
center1 = pc1.mean(axis=0)
center2 = pc2.mean(axis=0)

In [4]:
# np.asarray(pcd1.normals)
# Build a single-instance batch: centred xyz coordinates concatenated with the normals
# along the channel axis, with a leading batch axis added via [None].
# torch.from_numpy(...).float() yields the same float32 tensor as torch.Tensor([array])
# without going through the slow list-of-ndarrays construction path.
val_data = {}
val_data['points_ref'] = torch.from_numpy(np.concatenate(
    [np.asarray(pcd1.points) - center1, np.asarray(pcd1.normals)], axis=1)[None]).float()
val_data['points_src'] = torch.from_numpy(np.concatenate(
    [np.asarray(pcd2.points) - center2, np.asarray(pcd2.normals)], axis=1)[None]).float()

In [5]:
# Display the shapes of the reference and source tensors built from the LiDAR clouds
val_data['points_ref'].shape, val_data['points_src'].shape

(torch.Size([1, 3688, 6]), torch.Size([1, 3887, 6]))

In [6]:
# Register the two LiDAR clouds: move the batch to the device and run the network
# without tracking gradients.
model.eval()
with torch.no_grad():
    dict_all_to_device(val_data, _device)
    pred_transforms, endpoints = model(val_data, _args.num_reg_iter)

In [7]:
# Display the list of predicted transforms returned by the model
pred_transforms

[tensor([[[ 0.9988, -0.0480, -0.0098, -1.0657],
          [ 0.0483,  0.9983,  0.0315, -1.2318],
          [ 0.0082, -0.0320,  0.9995, -0.1457]]], device='cuda:0'),
 tensor([[[ 9.9733e-01, -7.3063e-02,  1.6052e-03, -1.0984e+00],
          [ 7.2903e-02,  9.9620e-01,  4.7739e-02, -1.6073e+00],
          [-5.0870e-03, -4.7494e-02,  9.9886e-01, -2.6481e-01]]],
        device='cuda:0'),
 tensor([[[ 0.9958, -0.0907,  0.0091, -1.1385],
          [ 0.0901,  0.9944,  0.0561, -1.8076],
          [-0.0141, -0.0551,  0.9984, -0.3168]]], device='cuda:0'),
 tensor([[[ 0.9945, -0.1037,  0.0134, -1.1950],
          [ 0.1027,  0.9928,  0.0609, -1.9474],
          [-0.0196, -0.0592,  0.9981, -0.3487]]], device='cuda:0'),
 tensor([[[ 0.9934, -0.1140,  0.0159, -1.2407],
          [ 0.1127,  0.9916,  0.0639, -2.0446],
          [-0.0231, -0.0617,  0.9978, -0.3709]]], device='cuda:0'),
 tensor([[[ 0.9923, -0.1223,  0.0175, -1.2687],
          [ 0.1209,  0.9905,  0.0658, -2.1093],
          [-0.0253, -0.0632,

In [8]:
# Print the last predicted transform of the first instance as a 4x4 matrix in plain text.
transform = pred_transforms[-1][0].tolist()
# Every entry is followed by a single space, so each of the three rows ends with a trailing space.
row_strs = [''.join(str(transform[r][c]) + ' ' for c in range(4)) + '\n' for r in range(3)]
# Append the homogeneous bottom row
T_ab_str = ''.join(row_strs) + '0 0 0 1'
print(T_ab_str)

0.9991929531097412 -0.0336252860724926 0.021978307515382767 -1.7121531963348389 
0.03207755833864212 0.9972147941589355 0.0673338770866394 -2.0785439014434814 
-0.024181202054023743 -0.06657449901103973 0.9974883794784546 -0.6212092041969299 
0 0 0 1


## Load single point cloud and manually transform it as a source cloud

In [11]:
import data_loader.transforms as Transforms
# NOTE: this rebinds the notebook-level name `se3`, which the first cell imported from
# common.math_torch. Functions defined earlier that look up `se3` when called
# (e.g. compute_metrics) will use this module from here on.
import common.math.se3 as se3
# Generate a transform with the RandomRotatorZ helper and show it
SE3_Z = Transforms.RandomRotatorZ()
transform_mat = SE3_Z.generate_transform()
print(transform_mat)

[[ 0.22437914  0.9745019   0.          0.        ]
 [-0.9745019   0.22437914  0.          0.        ]
 [ 0.          0.          1.          0.        ]]


In [14]:
import open3d as o3d
# Load one cloud and create the second cloud by applying the generated transform to it
pcd = o3d.io.read_point_cloud("/home/li/Lille_street_small_sub.pcd")
pc1 = np.asarray(pcd.points)
pc2 = se3.transform(transform_mat, pc1[:, :3])
# pc1.shape, pc2.shape
# Centre each cloud on its own centroid
center1 = pc1.mean(axis=0)
center2 = pc2.mean(axis=0)
# center1, center2
pc1 = pc1 - center1
pc2 = pc2 - center2

normals1 = np.asarray(pcd.normals)
# NOTE(review): the normals go through the same se3.transform call as the points. If that
# helper also applies the translation part, this is only correct because the transform
# printed above has a zero translation column; confirm before using other transforms.
normals2 = se3.transform(transform_mat, normals1)
# normals1, normals2

In [15]:
# Display the centroids that were subtracted from the two clouds
center1, center2

(array([-1.30791559e-08,  1.22741652e-08,  1.72647062e-08]),
 array([9.02660164e-09, 1.54997566e-08, 1.72647062e-08]))

In [16]:
# Build a single-instance batch from the centred clouds and their normals
# (xyz and normals concatenated along the channel axis, leading batch axis via [None]).
# torch.from_numpy(...).float() yields the same float32 tensor as torch.Tensor([array])
# without going through the slow list-of-ndarrays construction path.
val_data = {}
val_data['points_ref'] = torch.from_numpy(np.concatenate(
    [pc1, normals1], axis=1)[None]).float()
val_data['points_src'] = torch.from_numpy(np.concatenate(
    [pc2, normals2], axis=1)[None]).float()

In [17]:
# Display the shapes of the reference and source tensors of the synthetic pair
val_data['points_ref'].shape, val_data['points_src'].shape

(torch.Size([1, 3819, 6]), torch.Size([1, 3819, 6]))

In [18]:
# Register the synthetic pair: move the batch to the device and run the network
# without tracking gradients.
model.eval()
with torch.no_grad():
    dict_all_to_device(val_data, _device)
    pred_transforms, endpoints = model(val_data, _args.num_reg_iter)

In [19]:
# Display the list of predicted transforms returned by the model
pred_transforms

[tensor([[[ 0.5894,  0.1467,  0.7944, -0.0415],
          [ 0.7823,  0.1418, -0.6065,  0.0311],
          [-0.2016,  0.9790, -0.0312, -0.0805]]], device='cuda:0'),
 tensor([[[ 0.2754,  0.3249,  0.9047,  0.1849],
          [ 0.9526, -0.2190, -0.2113,  0.0454],
          [ 0.1295,  0.9200, -0.3698, -0.3583]]], device='cuda:0'),
 tensor([[[ 0.3019,  0.7024,  0.6446, -0.1066],
          [ 0.9388, -0.3368, -0.0726, -0.2666],
          [ 0.1661,  0.6270, -0.7611, -0.3061]]], device='cuda:0'),
 tensor([[[ 0.3098,  0.8609,  0.4035, -0.1492],
          [ 0.9352, -0.3525,  0.0341, -0.2603],
          [ 0.1716,  0.3668, -0.9143, -0.2259]]], device='cuda:0'),
 tensor([[[ 0.3221,  0.9147,  0.2440, -0.1329],
          [ 0.9303, -0.3536,  0.0973, -0.1719],
          [ 0.1753,  0.1956, -0.9649, -0.0875]]], device='cuda:0'),
 tensor([[[ 0.3333,  0.9325,  0.1393, -0.1240],
          [ 0.9256, -0.3517,  0.1397, -0.1467],
          [ 0.1792,  0.0824, -0.9803, -0.0359]]], device='cuda:0'),
 tensor([[[ 0.34

In [20]:
# Print the last predicted transform of the first instance as a 4x4 matrix in plain text.
transform = pred_transforms[-1][0].tolist()
# Every entry is followed by a single space, so each of the three rows ends with a trailing space.
row_strs = [''.join(str(transform[r][c]) + ' ' for c in range(4)) + '\n' for r in range(3)]
# Append the homogeneous bottom row
T_ab_str = ''.join(row_strs) + '0 0 0 1'
print(T_ab_str)

0.3284177780151367 0.8890770673751831 -0.3188788592815399 -0.15857288241386414 
0.9220850467681885 -0.22862330079078674 0.3122354745864868 -0.21190305054187775 
0.20469826459884644 -0.39657703042030334 -0.8948885798454285 -0.018927641212940216 
0 0 0 1


In [13]:
# Store the centred reference points back into `pcd` and write it to disk.
# NOTE: this path is the same file the cloud is read from in the loading cell,
# so the input file is overwritten with the centred points.
pcd.points = o3d.utility.Vector3dVector(pc1)
o3d.io.write_point_cloud("/home/li/Lille_street_small_sub.pcd", pcd)

# Write the rotated, centred cloud as a new point cloud (points only; no normals are set on pcd2)
pcd2 = o3d.geometry.PointCloud()
pcd2.points = o3d.utility.Vector3dVector(pc2)
o3d.io.write_point_cloud("/home/li/Lille_street_small_rotated.pcd", pcd2)

True

In [None]:
""" Train RPMNet

Example usage:
    python train.py --noise_type crop
    python train.py --noise_type jitter --train_batch_size 4
"""
from collections import defaultdict
import os
import random
from typing import Dict, List

from matplotlib.pyplot import cm as colormap
import numpy as np
import open3d  # Ensure this is imported before pytorch
from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.utils.data
from tqdm import tqdm

from arguments import rpmnet_train_arguments
from common.colors import BLUE, ORANGE
from common.misc import prepare_logger
from common.torch import dict_all_to_device, CheckPointManager, TorchDebugger
from common.math_torch import se3
from data_loader.datasets import get_train_datasets
from eval import compute_metrics, summarize_metrics, print_metrics
from models.rpmnet import get_model

# Set up arguments and logging
parser = rpmnet_train_arguments()
_args = parser.parse_args()
_logger, _log_path = prepare_logger(_args)
# Device selection: a non-negative --gpu index restricts CUDA to that device and uses it
# when CUDA is available; otherwise everything runs on the CPU.
if _args.gpu >= 0:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(_args.gpu)
    _device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
else:
    _device = torch.device('cpu')


def main():
    """Load the training and validation datasets and start the train/val loop."""
    training_split, validation_split = get_train_datasets(_args)
    run(training_split, validation_split)


def compute_losses(data: Dict, pred_transforms: List, endpoints: Dict,
                   loss_type: str = 'mae', reduction: str = 'mean') -> Dict:
    """Compute losses

    Args:
        data: Current mini-batch data
        pred_transforms: Predicted transform, to compute main registration loss
        endpoints: Endpoints for training. For computing outlier penalty
        loss_type: Registration loss type, either 'mae' (Mean absolute error, used in paper) or 'mse'
        reduction: Either 'mean' or 'none' (case-insensitive). Use 'none' to accumulate losses outside
                   (useful for accumulating losses for entire validation dataset)

    Returns:
        losses: Dict containing various fields. Total loss to be optimized is in losses['total']

    Raises:
        NotImplementedError: If loss_type is neither 'mse' nor 'mae'.
        ValueError: If reduction is neither 'mean' nor 'none'.
    """

    # Both registration losses follow the same recipe and differ only in the criterion
    criteria = {'mse': nn.MSELoss, 'mae': nn.L1Loss}
    if loss_type not in criteria:
        raise NotImplementedError

    # Normalise once so the criterion and the branches below agree on the value.
    # Anything else previously produced no losses and failed later when stacking an empty list.
    reduction = reduction.lower()
    if reduction not in ('mean', 'none'):
        raise ValueError("reduction must be 'mean' or 'none', got {!r}".format(reduction))
    per_instance = reduction == 'none'

    losses = {}
    num_iter = len(pred_transforms)
    src_xyz = data['points_src'][..., :3]

    # Registration loss to the groundtruth (does not take into account possible symmetries):
    # compare the source cloud under the predicted transform against the groundtruth transform.
    gt_src_transformed = se3.transform(data['transform_gt'], src_xyz)
    criterion = criteria[loss_type](reduction=reduction)
    for i in range(num_iter):
        pred_src_transformed = se3.transform(pred_transforms[i], src_xyz)
        loss = criterion(pred_src_transformed, gt_src_transformed)
        if per_instance:
            # Average over the last two axes, keeping one value per instance
            loss = torch.mean(loss, dim=[-1, -2])
        losses['{}_{}'.format(loss_type, i)] = loss

    # Penalize outliers: sums of the match matrix over each axis that fall short of 1
    for i in range(num_iter):
        perm_matrix = endpoints['perm_matrices'][i]
        ref_outliers_strength = (1.0 - torch.sum(perm_matrix, dim=1)) * _args.wt_inliers
        src_outliers_strength = (1.0 - torch.sum(perm_matrix, dim=2)) * _args.wt_inliers
        if per_instance:
            losses['outlier_{}'.format(i)] = torch.mean(ref_outliers_strength, dim=1) + \
                                             torch.mean(src_outliers_strength, dim=1)
        else:
            losses['outlier_{}'.format(i)] = torch.mean(ref_outliers_strength) + torch.mean(src_outliers_strength)

    discount_factor = 0.5  # Early iterations will be discounted
    total_losses = []
    for k in losses:
        # The iteration index is the numeric suffix of the key, e.g. 'mae_3' -> 3
        discount = discount_factor ** (num_iter - int(k[k.rfind('_')+1:]) - 1)
        total_losses.append(losses[k] * discount)
    losses['total'] = torch.sum(torch.stack(total_losses), dim=0)

    return losses


def save_summaries(writer: SummaryWriter, data: Dict, predicted: List, endpoints: Dict = None,
                   losses: Dict = None, metrics: Dict = None, step: int = 0):
    """Save tensorboard summaries

    Args:
        writer: Summary writer receiving meshes, histograms and scalars
        data: Mini-batch data. Point clouds are only logged if it contains 'points_src'
            (in which case 'points_ref' is read as well)
        predicted: Predicted transforms, one entry per registration iteration
        endpoints: Optional network endpoints; 'perm_matrices' is used if present
        losses: Optional dict of scalars, logged under 'losses/<name>'
        metrics: Optional dict of scalars, logged under 'metrics/<name>'
        step: Global step attached to every summary
    """

    # Only the first two instances of the batch are visualised.
    # NOTE(review): indexing with [0, 1] presumably requires a batch of at least two - confirm.
    subset = [0, 1]

    with torch.no_grad():
        # Save clouds
        if 'points_src' in data:

            points_src = data['points_src'][subset, ..., :3]
            points_ref = data['points_ref'][subset, ..., :3]

            # One colour per vertex: ORANGE for source points followed by BLUE for reference
            # points, matching the concatenation order of the clouds below
            colors = torch.from_numpy(
                np.concatenate([np.tile(ORANGE, (*points_src.shape[0:2], 1)),
                                np.tile(BLUE, (*points_ref.shape[0:2], 1))], axis=1))

            # First and last iteration (or just the first if there is only one)
            iters_to_save = [0, len(predicted)-1] if len(predicted) > 1 else [0]

            # Save the input clouds as 'iter_0', then the source cloud transformed by the
            # prediction of each saved iteration i as 'iter_{i+1}'
            concat_cloud_input = torch.cat((points_src, points_ref), dim=1)
            writer.add_mesh('iter_0', vertices=concat_cloud_input, colors=colors, global_step=step)
            for i_iter in iters_to_save:
                src_transformed_first = se3.transform(predicted[i_iter][subset, ...], points_src)
                concat_cloud_first = torch.cat((src_transformed_first, points_ref), dim=1)
                writer.add_mesh('iter_{}'.format(i_iter+1), vertices=concat_cloud_first, colors=colors, global_step=step)

            # Reference cloud coloured by the match matrix summed over dim 1
            if endpoints is not None and 'perm_matrices' in endpoints:
                color_mapper = colormap.ScalarMappable(norm=None, cmap=colormap.get_cmap('coolwarm'))
                for i_iter in iters_to_save:
                    ref_weights = torch.sum(endpoints['perm_matrices'][i_iter][subset, ...], dim=1)
                    # Drop the alpha channel returned by to_rgba
                    ref_colors = color_mapper.to_rgba(ref_weights.detach().cpu().numpy())[..., :3]
                    writer.add_mesh('ref_weights_{}'.format(i_iter), vertices=points_ref,
                                    colors=torch.from_numpy(ref_colors) * 255, global_step=step)

        # Histograms of the match matrix sums for every iteration (whole batch, not just the subset)
        if endpoints is not None:
            if 'perm_matrices' in endpoints:
                for i_iter in range(len(endpoints['perm_matrices'])):
                    src_weights = torch.sum(endpoints['perm_matrices'][i_iter], dim=2)
                    ref_weights = torch.sum(endpoints['perm_matrices'][i_iter], dim=1)
                    writer.add_histogram('src_weights_{}'.format(i_iter), src_weights, global_step=step)
                    writer.add_histogram('ref_weights_{}'.format(i_iter), ref_weights, global_step=step)

        # Write losses and metrics
        if losses is not None:
            for l in losses:
                writer.add_scalar('losses/{}'.format(l), losses[l], step)
        if metrics is not None:
            for m in metrics:
                writer.add_scalar('metrics/{}'.format(m), metrics[m], step)

        writer.flush()


def validate(data_loader, model: torch.nn.Module, summary_writer: SummaryWriter, step: int):
    """Perform a single validation run, and saves results into tensorboard summaries

    Args:
        data_loader: Loader over the validation set. Its `dataset` is also indexed directly
            to re-run individual instances.
        model: Network to validate. The caller is responsible for switching it to eval mode.
        summary_writer: Writer receiving the validation summaries
        step: Global step attached to the summaries

    Returns:
        Validation score (negated mean Chamfer distance, so higher is better)
    """

    _logger.info('Starting validation run...')

    with torch.no_grad():
        all_val_losses = defaultdict(list)
        all_val_metrics_np = defaultdict(list)
        for val_data in data_loader:
            dict_all_to_device(val_data, _device)
            pred_test_transforms, endpoints = model(val_data, _args.num_reg_iter)
            # reduction='none' keeps per-instance losses so they can be pooled over the dataset
            val_losses = compute_losses(val_data, pred_test_transforms, endpoints,
                                        loss_type=_args.loss_type, reduction='none')
            # Metrics are computed on the transform of the final iteration only
            val_metrics = compute_metrics(val_data, pred_test_transforms[-1])

            for k in val_losses:
                all_val_losses[k].append(val_losses[k])
            for k in val_metrics:
                all_val_metrics_np[k].append(val_metrics[k])

        all_val_losses = {k: torch.cat(all_val_losses[k]) for k in all_val_losses}
        all_val_metrics_np = {k: np.concatenate(all_val_metrics_np[k]) for k in all_val_metrics_np}
        mean_val_losses = {k: torch.mean(all_val_losses[k]) for k in all_val_losses}

    # Rerun on random and worst data instances and save to summary
    rand_idx = random.randint(0, all_val_losses['total'].shape[0] - 1)
    # "Worst" = largest registration loss at the final iteration
    worst_idx = torch.argmax(all_val_losses['{}_{}'.format(_args.loss_type, _args.num_reg_iter - 1)]).cpu().item()
    indices_to_rerun = [rand_idx, worst_idx]
    data_to_rerun = defaultdict(list)
    for i in indices_to_rerun:
        cur = data_loader.dataset[i]
        for k in cur:
            data_to_rerun[k].append(cur[k])
    for k in data_to_rerun:
        data_to_rerun[k] = torch.from_numpy(np.stack(data_to_rerun[k], axis=0))
    dict_all_to_device(data_to_rerun, _device)
    # The outputs are only visualised, so no autograd graph is needed for this forward pass
    with torch.no_grad():
        pred_transforms, endpoints = model(data_to_rerun, _args.num_reg_iter)

    summary_metrics = summarize_metrics(all_val_metrics_np)
    losses_by_iteration = torch.stack([mean_val_losses['{}_{}'.format(_args.loss_type, k)]
                                       for k in range(_args.num_reg_iter)]).cpu().numpy()
    print_metrics(_logger, summary_metrics, losses_by_iteration, 'Validation results')

    save_summaries(summary_writer, data=data_to_rerun, predicted=pred_transforms, endpoints=endpoints,
                   losses=mean_val_losses, metrics=summary_metrics, step=step)

    score = -summary_metrics['chamfer_dist']
    return score


def run(train_set, val_set):
    """Main train/val loop

    Builds the model, data loaders, optimizer, summary writers and checkpoint manager,
    optionally resumes from `_args.resume`, then trains for `_args.epochs` epochs with
    periodic tensorboard summaries and validation runs (each validation also saves a
    checkpoint).

    Args:
        train_set: Training dataset
        val_set: Validation dataset
    """

    _logger.debug('Trainer (PID=%d), %s', os.getpid(), _args)

    model = get_model(_args)
    model.to(_device)
    global_step = 0

    # dataloaders
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=_args.train_batch_size, shuffle=True, num_workers=_args.num_workers)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=_args.val_batch_size, shuffle=False, num_workers=_args.num_workers)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=_args.lr)

    # Summary writer and Checkpoint manager
    train_writer = SummaryWriter(os.path.join(_log_path, 'train'), flush_secs=10)
    val_writer = SummaryWriter(os.path.join(_log_path, 'val'), flush_secs=10)
    saver = CheckPointManager(os.path.join(_log_path, 'ckpt', 'model'), keep_checkpoint_every_n_hours=0.5)
    if _args.resume is not None:
        # Resuming continues the step count from the value returned by the checkpoint loader
        global_step = saver.load(_args.resume, model, optimizer)

    # trainings
    torch.autograd.set_detect_anomaly(_args.debug)
    model.train()

    # Negative intervals are given in epochs: convert them to a number of steps.
    # Note that this overwrites the values on the shared `_args` namespace.
    steps_per_epoch = len(train_loader)
    if _args.summary_every < 0:
        _args.summary_every = abs(_args.summary_every) * steps_per_epoch
    if _args.validate_every < 0:
        _args.validate_every = abs(_args.validate_every) * steps_per_epoch

    for epoch in range(0, _args.epochs):
        _logger.info('Begin epoch {} (steps {} - {})'.format(epoch, global_step, global_step + len(train_loader)))
        tbar = tqdm(total=len(train_loader), ncols=100)
        for train_data in train_loader:
            global_step += 1

            optimizer.zero_grad()

            # Forward through neural network
            dict_all_to_device(train_data, _device)
            pred_transforms, endpoints = model(train_data, _args.num_train_reg_iter)  # Use less iter during training

            # Compute loss, and optimize
            train_losses = compute_losses(train_data, pred_transforms, endpoints,
                                          loss_type=_args.loss_type, reduction='mean')
            if _args.debug:
                # In debug mode the backward pass runs inside the TorchDebugger context
                with TorchDebugger():
                    train_losses['total'].backward()
            else:
                train_losses['total'].backward()
            optimizer.step()

            tbar.set_description('Loss:{:.3g}'.format(train_losses['total']))
            tbar.update(1)

            if global_step % _args.summary_every == 0:  # Save tensorboard logs
                save_summaries(train_writer, data=train_data, predicted=pred_transforms, endpoints=endpoints,
                               losses=train_losses, step=global_step)

            if global_step % _args.validate_every == 0:  # Validation loop. Also saves checkpoints
                # Switch to eval mode for validation and back to train mode afterwards
                model.eval()
                val_score = validate(val_loader, model, val_writer, global_step)
                saver.save(model, optimizer, step=global_step, score=val_score)
                model.train()

        tbar.close()

    _logger.info('Ending training. Number of steps = {}.'.format(global_step))


# Script entry point: start training only when this module is executed directly
if __name__ == '__main__':
    main()
