In [None]:
%matplotlib notebook

import multiprocessing
import numpy as np

import sklearn.model_selection

from recova.clustering import IdentityClusteringAlgorithm
from recova.covariance import CensiCovarianceComputationAlgorithm, SamplingCovarianceComputationAlgorithm
from recova.registration_result_database import RegistrationPairDatabase
from recova.util import kullback_leibler, bat_distance

# Split into training and validation datasets

In [None]:
# Load every registration pair from the dataset and hold out a fraction of them
# for validation. The split is seeded so the validation set (and every loss
# computed from it below) is reproducible across kernel restarts.
# NOTE(review): absolute local path; adjust DATASET_DIR for your machine.
DATASET_DIR = '/home/dlandry/dataset/normal-estimates/'
VALIDATION_FRACTION = 0.3
SPLIT_SEED = 42

db = RegistrationPairDatabase(DATASET_DIR)
pairs = db.registration_pairs()

pairs_training, pairs_validation = sklearn.model_selection.train_test_split(
    pairs, test_size=VALIDATION_FRACTION, random_state=SPLIT_SEED)

# Compute reference covariances for the validation set

In [None]:
# Reference covariance for each validation pair, estimated with
# SamplingCovarianceComputationAlgorithm on top of IdentityClusteringAlgorithm.
clustering_algo = IdentityClusteringAlgorithm()
reference_algo = SamplingCovarianceComputationAlgorithm(clustering_algo)

# Use up to 7 worker processes, but never more than the machine has CPUs,
# so smaller machines are not oversubscribed.
# Pool.map preserves input order: entry i corresponds to pairs_validation[i].
with multiprocessing.Pool(min(7, multiprocessing.cpu_count())) as pool:
    reference_covariances = pool.map(reference_algo.compute, pairs_validation)

In [None]:
# Stack the per-pair reference covariances into one ndarray, one pair per
# entry along axis 0 (shape of each entry: TODO confirm, presumably a square matrix).
reference_covariances = np.array(list(reference_covariances))

# Compute Censi covariances

In [None]:
# Censi covariance estimate for each validation pair.
censi_algo = CensiCovarianceComputationAlgorithm()

# Use up to 7 worker processes, but never more than the machine has CPUs,
# so smaller machines are not oversubscribed.
# Pool.map preserves input order: entry i corresponds to pairs_validation[i].
with multiprocessing.Pool(min(7, multiprocessing.cpu_count())) as pool:
    censi_covariances = pool.map(censi_algo.compute, pairs_validation)



In [None]:
# Per-pair KL loss between the Censi estimate and the sampled reference.
# Both argument orders are computed because KL divergence is asymmetric.
# NOTE(review): confirm which argument kullback_leibler treats as the reference distribution.
# The two lists must be aligned pair-for-pair; fail loudly rather than silently truncate.
assert len(censi_covariances) == len(reference_covariances), \
    'censi and reference covariances are not aligned'

losses = [kullback_leibler(censi, reference)
          for censi, reference in zip(censi_covariances, reference_covariances)]
losses2 = [kullback_leibler(reference, censi)
           for censi, reference in zip(censi_covariances, reference_covariances)]

In [None]:
# Per-pair kullback_leibler(censi, reference), one value per validation pair.
losses

In [None]:
# Per-pair kullback_leibler(reference, censi): the reversed argument order of `losses`.
losses2

In [None]:
# Per-pair bat_distance between the sampled reference and the Censi estimate.
# NOTE(review): presumably a Bhattacharyya distance between the two covariances — confirm in recova.util.
# The two lists must be aligned pair-for-pair; fail loudly rather than silently truncate.
assert len(censi_covariances) == len(reference_covariances), \
    'censi and reference covariances are not aligned'

bat_distances = [bat_distance(reference, censi)
                 for reference, censi in zip(reference_covariances, censi_covariances)]

In [None]:
# Per-pair bat_distance(reference, censi), one value per validation pair.
bat_distances