# Representation Similarity in TensorFlow

Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

Document author: Youngbin-Ro (youngbin_ro@korea.ac.kr)

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

## 1. Implementation in TensorFlow

In [2]:
def gram_linear(x):
    """Build the linear-kernel Gram matrix of a feature matrix.

    Args:
        x: Feature matrix of shape (num_examples, num_features).

    Returns:
        The (num_examples, num_examples) matrix whose (i, j) entry is the
        inner product of rows i and j of `x`.
    """
    features_t = x.T
    return x.dot(features_t)


def gram_linear_tf(x):
    """TensorFlow counterpart of `gram_linear`.

    Args:
        x: A num_examples x num_features tensor of features.

    Returns:
        The tensor `x @ transpose(x)`, i.e. the linear-kernel Gram matrix.
    """
    x_transposed = tf.transpose(x)
    return tf.matmul(x, x_transposed)

In [3]:
# Sanity check: the TensorFlow linear Gram matrix must match the NumPy one.
np.random.seed(1337)

X = np.random.randn(100, 10)
X_tf = tf.constant(X)

gram_X = gram_linear(X)

# The context manager installs the session as the default one (so `.eval()`
# works) and closes it even if evaluation raises, so no session is leaked
# into later cells.
with tf.Session() as sess:
    gram_X_tf = gram_linear_tf(X_tf).eval()

np.testing.assert_almost_equal(gram_X, gram_X_tf)

-------------------

In [4]:
def gram_rbf(x, threshold=1.0):
    """Build the RBF-kernel Gram matrix of a feature matrix.

    The kernel bandwidth is derived from the data: the median of all pairwise
    squared Euclidean distances is scaled by `threshold ** 2`.

    Args:
        x: Feature matrix of shape (num_examples, num_features).
        threshold: Fraction of the median Euclidean distance to use as the RBF
            kernel bandwidth. (This is the heuristic used in the paper; other
            ways of choosing the bandwidth were not tried.)

    Returns:
        A num_examples x num_examples Gram matrix of examples.
    """
    inner = x.dot(x.T)
    row_sq_norms = np.diag(inner)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, evaluated for every pair.
    as_column = row_sq_norms[:, None]
    as_row = row_sq_norms[None, :]
    sq_distances = -2 * inner + as_column + as_row
    scale = 2 * threshold ** 2 * np.median(sq_distances)
    return np.exp(-sq_distances / scale)


def gram_rbf_tf(x, threshold=1.0):
    """TensorFlow counterpart of `gram_rbf`.

    Args:
        x: A num_examples x num_features tensor of features.
        threshold: Fraction of the median Euclidean distance to use as the RBF
            kernel bandwidth.

    Returns:
        A num_examples x num_examples tensor holding the RBF Gram matrix.
    """
    inner = tf.matmul(x, tf.transpose(x))
    row_sq_norms = tf.matrix_diag_part(inner)
    as_column = tf.reshape(row_sq_norms, [-1, 1])
    as_row = tf.reshape(row_sq_norms, [1, -1])
    sq_distances = -2 * inner + as_column + as_row

    # The median is taken as the mean of the 'lower' and 'higher' 50th
    # percentiles of the squared distances.
    lower_mid = tfp.stats.percentile(sq_distances, 50., interpolation='lower')
    upper_mid = tfp.stats.percentile(sq_distances, 50., interpolation='higher')
    sq_median_distance = (lower_mid + upper_mid) / 2.

    scale = 2 * threshold ** 2 * sq_median_distance
    return tf.math.exp(-sq_distances / scale)

In [5]:
# Sanity check: the TensorFlow RBF Gram matrix must match the NumPy one.
np.random.seed(1337)

X = np.random.randn(100, 10)
X_tf = tf.constant(X)

gram_rbf_X = gram_rbf(X)

# The context manager installs the session as the default one (so `.eval()`
# works) and closes it even if evaluation raises, so no session is leaked
# into later cells.
with tf.Session() as sess:
    gram_rbf_X_tf = gram_rbf_tf(X_tf).eval()

np.testing.assert_almost_equal(gram_rbf_X, gram_rbf_X_tf)

W0818 03:30:50.901623 140462725011264 deprecation.py:323] From /home/youngbin/anaconda3/envs/sle/lib/python3.6/site-packages/tensorflow_probability/python/stats/quantiles.py:608: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


--------------------------

In [6]:
def center_gram(gram, unbiased=False):
    """Center a symmetric Gram matrix.

    This is equivalent to centering the (possibly infinite-dimensional) features
    induced by the kernel before computing the Gram matrix.

    Args:
        gram: A num_examples x num_examples symmetric matrix. Inputs whose dtype
            is not floating/complex (e.g. integer Gram matrices) are promoted to
            float64, because the centering subtracts fractional means in place.
        unbiased: Whether to adjust the Gram matrix in order to compute an unbiased
            estimate of HSIC. Note that this estimator may be negative.

    Returns:
        A symmetric matrix with centered columns and rows. The input array is
        never modified; a copy is centered and returned.

    Raises:
        ValueError: If `gram` is not symmetric.
    """
    if not np.allclose(gram, gram.T):
        raise ValueError('Input must be a symmetric matrix.')

    if np.issubdtype(gram.dtype, np.inexact):
        gram = gram.copy()
    else:
        # In-place subtraction of float64 means from an integer array raises a
        # casting error, so work on a float64 copy instead.
        gram = gram.astype(np.float64)

    if unbiased:
        # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.
        # L. (2014). Partial distance correlation with methods for dissimilarities.
        # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically
        # stable than the alternative from Song et al. (2007).
        n = gram.shape[0]
        np.fill_diagonal(gram, 0)
        means = np.sum(gram, 0, dtype=np.float64) / (n - 2)
        means -= np.sum(means) / (2 * (n - 1))
        gram -= means[:, None]
        gram -= means[None, :]
        # The subtraction above touches the diagonal again; reset it to zero.
        np.fill_diagonal(gram, 0)
    else:
        means = np.mean(gram, 0, dtype=np.float64)
        # Halve the grand mean so that subtracting `means` along both axes
        # removes it exactly once overall.
        means -= np.mean(means) / 2
        gram -= means[:, None]
        gram -= means[None, :]

    return gram

        
def center_gram_tf(gram, unbiased=False):
    """TensorFlow counterpart of `center_gram`.

    Args:
        gram: A num_examples x num_examples Gram-matrix tensor. Unlike the
            NumPy version, symmetry is not checked here. When `unbiased` is
            True the first dimension must be statically known.
        unbiased: Whether to adjust the Gram matrix in order to compute an
            unbiased estimate of HSIC.

    Returns:
        A tensor with the same shape and dtype as `gram`, centered along rows
        and columns.
    """
    gram = tf.identity(gram)
    if unbiased:
        n = gram.get_shape().as_list()[0]
        # Build the zero diagonal in the Gram matrix's own dtype; a hard-coded
        # float64 diagonal is rejected for e.g. float32 inputs.
        zero_diag = tf.zeros(gram.shape[0:-1], gram.dtype)
        gram = tf.matrix_set_diag(gram, zero_diag)
        means = tf.math.reduce_sum(gram, 0) / (n - 2)
        means -= tf.math.reduce_sum(means) / (2 * (n - 1))
        gram -= tf.reshape(means, [-1, 1])
        gram -= tf.reshape(means, [1, -1])
        # The subtraction above touches the diagonal again; reset it to zero.
        gram = tf.matrix_set_diag(gram, zero_diag)
    else:
        means = tf.math.reduce_mean(gram, 0)
        means -= tf.math.reduce_mean(means) / 2
        gram -= tf.reshape(means, [-1, 1])
        gram -= tf.reshape(means, [1, -1])
    return gram

In [7]:
# Sanity check: TensorFlow Gram centering must match the NumPy version.
tf.reset_default_graph()
np.random.seed(1337)

X = np.random.randn(100, 10)
X_tf = tf.constant(X)

X_gram = gram_linear(X)
X_gram_tf = gram_linear_tf(X_tf)

center_gram_X = center_gram(X_gram)

# The context manager installs the session as the default one (so `.eval()`
# works) and closes it even if evaluation raises, so the next cell's
# `tf.reset_default_graph()` never runs with a leaked session.
with tf.Session() as sess:
    center_gram_X_tf = center_gram_tf(X_gram_tf).eval()

np.testing.assert_almost_equal(center_gram_X, center_gram_X_tf)

-----------------------

In [8]:
def cka(gram_x, gram_y, debiased=False):
    """Compute centered kernel alignment (CKA) between two Gram matrices.

    Args:
        gram_x: A num_examples x num_examples Gram matrix.
        gram_y: A num_examples x num_examples Gram matrix.
        debiased: Use unbiased estimator of HSIC. CKA may still be biased.

    Returns:
        The value of CKA between X and Y.
    """
    centered_x = center_gram(gram_x, unbiased=debiased)
    centered_y = center_gram(gram_y, unbiased=debiased)

    # Inner product of the two centered Gram matrices viewed as flat vectors.
    # To obtain HSIC this would be divided by (n-1)**2 (biased variant) or
    # n*(n-3) (unbiased variant), but that factor cancels in the CKA ratio.
    flat_x = centered_x.ravel()
    flat_y = centered_y.ravel()
    scaled_hsic = flat_x.dot(flat_y)

    norm_x = np.linalg.norm(centered_x)
    norm_y = np.linalg.norm(centered_y)
    return scaled_hsic / (norm_x * norm_y)

def cka_tf(gram_x, gram_y, debiased=False):
    """TensorFlow counterpart of `cka` (Centered Kernel Alignment).

    Args:
        gram_x: A num_examples x num_examples Gram-matrix tensor.
        gram_y: A num_examples x num_examples Gram-matrix tensor.
        debiased: Forwarded to `center_gram_tf` as `unbiased`.

    Returns:
        A scalar tensor with the value of CKA between X and Y.
    """
    centered_x = center_gram_tf(gram_x, unbiased=debiased)
    centered_y = center_gram_tf(gram_y, unbiased=debiased)

    # Inner product of the two centered Gram matrices viewed as flat vectors.
    flat_x = tf.reshape(centered_x, [-1])
    flat_y = tf.reshape(centered_y, [-1])
    scaled_hsic = tf.tensordot(flat_x, flat_y, axes=1)

    # Matrix norm over the last two axes of each centered Gram matrix.
    norm_x = tf.norm(centered_x, axis=[-2, -1])
    norm_y = tf.norm(centered_y, axis=[-2, -1])
    return scaled_hsic / (norm_x * norm_y)

In [9]:
# Sanity check: TensorFlow CKA must match the NumPy version.
tf.reset_default_graph()
np.random.seed(1337)

X = np.random.randn(100, 10)
Y = np.random.randn(100, 10) + X
X_tf = tf.constant(X)
Y_tf = tf.constant(Y)

CKA = cka(gram_linear(X), gram_linear(Y))

# The context manager installs the session as the default one (so `.eval()`
# works) and closes it even if evaluation raises, so the next cell's
# `tf.reset_default_graph()` never runs with a leaked session.
with tf.Session() as sess:
    CKA_tf = cka_tf(gram_linear_tf(X_tf), gram_linear_tf(Y_tf)).eval()

np.testing.assert_almost_equal(CKA, CKA_tf)

-------------------

In [10]:
def _debiased_dot_product_similarity_helper(
    xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n):
    """Helper for computing debiased dot product similarity (i.e. linear HSIC)."""
    # This formula can be derived by manipulating the unbiased estimator from
    # Song et al. (2007).
    return (xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)
            + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))

def feature_space_linear_cka(features_x, features_y, debiased=False):
    """Compute linear-kernel CKA directly from feature matrices.

    Working in feature space is typically faster than building the Gram
    matrices when there are fewer features than examples.

    Args:
        features_x: A num_examples x num_features matrix of features.
        features_y: A num_examples x num_features matrix of features.
        debiased: Use unbiased estimator of dot product similarity. CKA may still be
            biased. Note that this estimator may be negative.

    Returns:
        The value of CKA between X and Y.
    """
    # Remove the per-feature mean from each representation.
    centered_x = features_x - np.mean(features_x, 0, keepdims=True)
    centered_y = features_y - np.mean(features_y, 0, keepdims=True)

    cross_products = centered_x.T.dot(centered_y)
    dot_product_similarity = np.linalg.norm(cross_products) ** 2
    normalization_x = np.linalg.norm(centered_x.T.dot(centered_x))
    normalization_y = np.linalg.norm(centered_y.T.dot(centered_y))

    if not debiased:
        return dot_product_similarity / (normalization_x * normalization_y)

    n = centered_x.shape[0]
    # Per-example squared norms; equivalent to np.sum(features ** 2, 1) but
    # avoids an intermediate array.
    sum_squared_rows_x = np.einsum('ij,ij->i', centered_x, centered_x)
    sum_squared_rows_y = np.einsum('ij,ij->i', centered_y, centered_y)
    squared_norm_x = np.sum(sum_squared_rows_x)
    squared_norm_y = np.sum(sum_squared_rows_y)

    dot_product_similarity = _debiased_dot_product_similarity_helper(
        dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
        squared_norm_x, squared_norm_y, n)
    debiased_sq_norm_x = _debiased_dot_product_similarity_helper(
        normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,
        squared_norm_x, squared_norm_x, n)
    debiased_sq_norm_y = _debiased_dot_product_similarity_helper(
        normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,
        squared_norm_y, squared_norm_y, n)
    normalization_x = np.sqrt(debiased_sq_norm_x)
    normalization_y = np.sqrt(debiased_sq_norm_y)

    return dot_product_similarity / (normalization_x * normalization_y)


def _debiased_dot_product_similarity_helper_tf(
    xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n):
    """TensorFlow counterpart of `_debiased_dot_product_similarity_helper`.

    Args:
        xty: The uncorrected similarity value.
        sum_squared_rows_x: 1-D tensor of per-example squared row norms of X.
        sum_squared_rows_y: 1-D tensor of per-example squared row norms of Y.
        squared_norm_x: Sum of `sum_squared_rows_x`.
        squared_norm_y: Sum of `sum_squared_rows_y`.
        n: Number of examples.

    Returns:
        The corrected similarity value.
    """
    row_term = n / (n - 2.) * tf.tensordot(sum_squared_rows_x, sum_squared_rows_y, axes=1)
    norm_term = squared_norm_x * squared_norm_y / ((n - 1) * (n - 2))
    return xty - row_term + norm_term

def feature_space_linear_cka_tf(features_x, features_y, debiased=False):
    """TensorFlow counterpart of `feature_space_linear_cka`.

    Args:
        features_x: A num_examples x num_features tensor of features.
        features_y: A num_examples x num_features tensor of features.
        debiased: Use the debiased dot product similarity. The number of
            examples is then read from the static shape of `features_x`.

    Returns:
        A scalar tensor with the value of CKA between X and Y.
    """
    # Remove the per-feature mean from each representation.
    centered_x = features_x - tf.math.reduce_mean(features_x, 0, keepdims=True)
    centered_y = features_y - tf.math.reduce_mean(features_y, 0, keepdims=True)

    cross_products = tf.matmul(tf.transpose(centered_x), centered_y)
    self_products_x = tf.matmul(tf.transpose(centered_x), centered_x)
    self_products_y = tf.matmul(tf.transpose(centered_y), centered_y)

    dot_product_similarity = tf.pow(tf.norm(cross_products), 2)
    normalization_x = tf.norm(self_products_x, axis=[-2, -1])
    normalization_y = tf.norm(self_products_y, axis=[-2, -1])

    if not debiased:
        return dot_product_similarity / (normalization_x * normalization_y)

    n = centered_x.get_shape().as_list()[0]
    # Per-example squared norms of the centered features.
    sum_squared_rows_x = tf.einsum('ij,ij->i', centered_x, centered_x)
    sum_squared_rows_y = tf.einsum('ij,ij->i', centered_y, centered_y)
    squared_norm_x = tf.math.reduce_sum(sum_squared_rows_x)
    squared_norm_y = tf.math.reduce_sum(sum_squared_rows_y)

    dot_product_similarity = _debiased_dot_product_similarity_helper_tf(
        dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,
        squared_norm_x, squared_norm_y, n)
    debiased_sq_norm_x = _debiased_dot_product_similarity_helper_tf(
        tf.pow(normalization_x, 2), sum_squared_rows_x, sum_squared_rows_x,
        squared_norm_x, squared_norm_x, n)
    debiased_sq_norm_y = _debiased_dot_product_similarity_helper_tf(
        tf.pow(normalization_y, 2), sum_squared_rows_y, sum_squared_rows_y,
        squared_norm_y, squared_norm_y, n)
    normalization_x = tf.math.sqrt(debiased_sq_norm_x)
    normalization_y = tf.math.sqrt(debiased_sq_norm_y)

    return dot_product_similarity / (normalization_x * normalization_y)

In [11]:
# Sanity check: debiased feature-space CKA in TensorFlow must match NumPy.
tf.reset_default_graph()
np.random.seed(1337)

X = np.random.randn(100, 10)
Y = np.random.randn(100, 10) + X
X_tf = tf.constant(X)
Y_tf = tf.constant(Y)

cka_from_features_debiased = feature_space_linear_cka(X, Y, debiased=True)

# The context manager installs the session as the default one (so `.eval()`
# works) and closes it even if evaluation raises, so the next cell's
# `tf.reset_default_graph()` never runs with a leaked session.
with tf.Session() as sess:
    cka_from_features_debiased_tf = feature_space_linear_cka_tf(X_tf, Y_tf, debiased=True).eval()

np.testing.assert_almost_equal(cka_from_features_debiased, cka_from_features_debiased_tf)

----------------

In [12]:
def cca(features_x, features_y):
    """Compute the mean squared CCA correlation (R^2_{CCA}).

    Args:
        features_x: A num_examples x num_features matrix of features.
        features_y: A num_examples x num_features matrix of features.

    Returns:
        The mean squared CCA correlations between X and Y.
    """
    # Orthonormal bases of the two column spaces (SVD with
    # full_matrices=False would work as well).
    basis_x, _ = np.linalg.qr(features_x)
    basis_y, _ = np.linalg.qr(features_y)

    overlap = basis_x.T.dot(basis_y)
    num_dims = min(features_x.shape[1], features_y.shape[1])
    return np.linalg.norm(overlap) ** 2 / num_dims


def cca_tf(features_x, features_y):
    """TensorFlow counterpart of `cca` (Canonical Correlation Analysis).

    Args:
        features_x: A num_examples x num_features tensor of features. The
            feature dimension must be statically known.
        features_y: A num_examples x num_features tensor of features. The
            feature dimension must be statically known.

    Returns:
        A scalar tensor with the mean squared CCA correlation between X and Y,
        in the dtype of the input features.
    """
    qx, _ = tf.linalg.qr(features_x)
    qy, _ = tf.linalg.qr(features_y)
    dimx = features_x.get_shape().as_list()[1]
    dimy = features_y.get_shape().as_list()[1]

    squared_norm = tf.pow(tf.norm(tf.matmul(tf.transpose(qx), qy)), 2)
    # Both dimensions are static Python ints, so take the minimum in Python and
    # cast it to the numerator's dtype; a hard-coded float64 denominator breaks
    # the division for e.g. float32 features.
    denominator = tf.dtypes.cast(min(dimx, dimy), squared_norm.dtype)
    return squared_norm / denominator

In [13]:
# Sanity check: TensorFlow CCA must match the NumPy version.
tf.reset_default_graph()
np.random.seed(1337)

X = np.random.randn(100, 10)
Y = np.random.randn(100, 10) + X
X_tf = tf.constant(X)
Y_tf = tf.constant(Y)

CCA = cca(X, Y)

# The context manager installs the session as the default one (so `.eval()`
# works) and closes it even if evaluation raises, so the next cell's
# `tf.reset_default_graph()` never runs with a leaked session.
with tf.Session() as sess:
    CCA_tf = cca_tf(X_tf, Y_tf).eval()

np.testing.assert_almost_equal(CCA, CCA_tf)

---------------

## 2. Tutorial

In [14]:
# Tutorial setup: build two related random representations. Y is X plus
# independent Gaussian noise, so the two are correlated but not identical.
np.random.seed(1337)
X = np.random.randn(100, 10)
Y = np.random.randn(100, 10) + X

# Start from a fresh graph. The session created here is not closed in this
# cell; the tutorial cells below call `.eval()` without passing a session and
# therefore rely on it still being open.
tf.reset_default_graph()
sess = tf.InteractiveSession()

# Constant tensors wrapping the NumPy arrays above.
X_tf = tf.constant(X)
Y_tf = tf.constant(Y)

In [15]:
# Linear CKA computed two ways: from the num_examples x num_examples Gram
# matrices, and directly from the feature matrices.
cka_from_examples = cka_tf(gram_linear_tf(X_tf), gram_linear_tf(Y_tf)).eval()
cka_from_features = feature_space_linear_cka_tf(X_tf, Y_tf).eval()

print('Linear CKA from Examples: {:.5f}'.format(cka_from_examples))
print('Linear CKA from Features: {:.5f}'.format(cka_from_features))
# Both routes must give the same value.
np.testing.assert_almost_equal(cka_from_examples, cka_from_features)

Linear CKA from Examples: 0.55761
Linear CKA from Features: 0.55761


In [16]:
# CKA with an RBF kernel; 0.5 is passed as the `threshold` argument of
# `gram_rbf_tf` for both representations.
rbf_cka = cka_tf(gram_rbf_tf(X_tf, 0.5), gram_rbf_tf(Y_tf, 0.5)).eval()
print('RBF CKA: {:.5f}'.format(rbf_cka))

RBF CKA: 0.65483


In [17]:
# Debiased linear CKA computed two ways: from the Gram matrices, and directly
# from the feature matrices.
cka_from_examples_debiased = cka_tf(gram_linear_tf(X_tf), gram_linear_tf(Y_tf), debiased=True).eval()
cka_from_features_debiased = feature_space_linear_cka_tf(X_tf, Y_tf, debiased=True).eval()

print('Linear CKA from Examples (Debiased): {:.5f}'.format(cka_from_examples_debiased))
print('Linear CKA from Features (Debiased): {:.5f}'.format(cka_from_features_debiased))
# Both routes must give the same value.
np.testing.assert_almost_equal(cka_from_examples_debiased, cka_from_features_debiased)

Linear CKA from Examples (Debiased): 0.51346
Linear CKA from Features (Debiased): 0.51346


In [18]:
# A random (almost surely invertible, non-orthogonal) linear transform, and an
# orthogonal one taken from the eigenvectors of the symmetric matrix
# transform^T transform.
transform = np.random.randn(10, 10)
transform_tf = tf.constant(transform)
_, orthogonal_transform_tf = tf.linalg.eigh(gram_linear_tf(tf.transpose(transform_tf)))

# Evaluate the untransformed baselines once and reuse them in every check,
# instead of rebuilding and re-running the same sub-graph for each assertion.
cka_baseline = feature_space_linear_cka_tf(X_tf, Y_tf).eval()
cca_baseline = cca_tf(X_tf, Y_tf).eval()

# Among invertible linear transforms of the features, CKA is invariant to
# orthogonal ones but not to arbitrary ones.
np.testing.assert_almost_equal(
    cka_baseline,
    feature_space_linear_cka_tf(tf.matmul(X_tf, orthogonal_transform_tf), Y_tf).eval())
np.testing.assert_(not np.isclose(
    cka_baseline,
    feature_space_linear_cka_tf(tf.matmul(X_tf, transform_tf), Y_tf).eval()))

# CCA is invariant to any invertible linear transform.
np.testing.assert_almost_equal(
    cca_baseline,
    cca_tf(tf.matmul(X_tf, orthogonal_transform_tf), Y_tf).eval())
np.testing.assert_almost_equal(
    cca_baseline,
    cca_tf(tf.matmul(X_tf, transform_tf), Y_tf).eval())

# Both CCA and CKA are invariant to isotropic scaling.
np.testing.assert_almost_equal(cca_baseline, cca_tf(X_tf * 1.337, Y_tf).eval())
np.testing.assert_almost_equal(
    cka_baseline,
    feature_space_linear_cka_tf(X_tf * 1.337, Y_tf).eval())