In [17]:
import sklearn.neighbors
import sklearn.gaussian_process
import numpy as np
sklearn.neighbors.ball_tree.VALID_METRICS.append("KernelDistance")

In [107]:
class KernelDistance(sklearn.neighbors.dist_metrics.PyFuncDistance):
    def __init__(self, kernel):
        self.kernel = kernel
        super().__init__(self.dist)
    def dist(self, x1, x2):
        return self.rdist_to_dist(self.rdist(x1,x2))
    def rdist(self, x1, x2):
        x1 = np.atleast_2d(x1)
        x2 = np.atleast_2d(x2)
        rah = self.kernel.diag(x1).reshape((-1,1)) - 2*self.kernel(x1,x2) + self.kernel.diag(x2).reshape((1,-1))
        return rah
    def rdist_to_dist(self, rdist):
        return np.sqrt(rdist)
    def dist_to_rdist(self, dist):
        return dist ** 2

In [None]:
class TruncatedRBFKernel(sklearn.gaussian_process.kernels.RBF):
    def __init__(self, 
                 length_scale=1.0, length_scale_bounds=(1e-05, 100000.0),
                 a = -np.inf, a_bounds=(-np.inf, np.inf),
                 b = np.inf, b_bounds = (-np.inf, np.inf)                
                ):
        super().__init__(length_scale=length_scale, length_scale_bounds=length_scale_bounds)
        

In [None]:
class TruncatedConstantKernel(sklearn.gaussian_process.kernels.ConstantKernel):
    def __init__(self,
                constant_value=1.0, constant_value_bounds=(1e-05, 100000.0),
                a = -np.inf, a_bounds=(-np.inf, np.inf),
                b = np.inf, b_bounds = (-np.inf, np.inf)                
                ):

In [108]:
dm = KernelDistance(sklearn.gaussian_process.kernels.RBF())

In [109]:
dm.dist(np.array([[5]]),np.array([[7]]))

array([[ 1.31503971]])

In [110]:
x = np.array([[1,2],[3,4],[5,6]])

In [111]:
a = sklearn.neighbors.BallTree(x, metric = dm)

In [112]:
dm.pairwise(x)

array([[ 0.        ,  1.4012026 ,  1.41421348],
       [ 1.4012026 ,  0.        ,  1.4012026 ],
       [ 1.41421348,  1.4012026 ,  0.        ]])

In [113]:
a.query(np.array([[1.5,2.5]]), k=2)

querying!
4
6
7
(array([[ 0.66513039,  1.33761039]]), array([[0, 1]]))


(array([[ 0.81555526,  1.15655107]]), array([[0, 1]]))

In [114]:
dm.dist(np.array([[1.5,2.5]]),x)

array([[ 0.66513039,  1.33761039,  1.41421018]])

In [115]:
dm.dist(x,np.array([[1.5,2.5],[2.5,3.5]]))

array([[ 0.66513039,  1.33761039],
       [ 1.33761039,  0.66513039],
       [ 1.41421018,  1.41284787]])

In [116]:
dm.kernel(x)

array([[  1.00000000e+00,   1.83156389e-02,   1.12535175e-07],
       [  1.83156389e-02,   1.00000000e+00,   1.83156389e-02],
       [  1.12535175e-07,   1.83156389e-02,   1.00000000e+00]])

In [117]:
dm.kernel.diag(x).reshape((-1,1)) + dm.kernel(x,np.array([[1.5,2.5],[2.5,3.5]]))

array([[ 1.77880078,  1.10539922],
       [ 1.10539922,  1.77880078],
       [ 1.00000479,  1.00193045]])