Merge pull request #2025 from shak360/sv-cosine-fix
[WIP] Fix small issue with layers._cosine_dist
Bharath Ramsundar committed Jul 22, 2020
2 parents a48fbdb + 7ef4f57 commit 1b7d83b
Showing 3 changed files with 108 additions and 10 deletions.
93 changes: 83 additions & 10 deletions deepchem/models/layers.py
@@ -470,19 +470,92 @@ def call(self, inputs):
return h, [h, c]


def _cosine_dist(x, y):
"""Computes the inner product (cosine distance) between two tensors.
def cosine_dist(x, y):
"""Computes the inner product (cosine similarity) between two tensors.

This assumes that the two input tensors contain rows of vectors where
each column represents a different feature. The output tensor will have
elements that represent the inner product between pairs of normalized vectors
in the rows of `x` and `y`. The two tensors need to have the same number of
columns, because one cannot take the dot product between vectors of different
lengths. For example, in sentence similarity and sentence classification tasks,
the number of columns is the embedding size. In these tasks, the rows of the
input tensors would be different test vectors or sentences. The input tensors
themselves could be different batches. Vectors or tensors of all 0s should be
avoided, since a zero vector has no direction and its cosine similarity is
undefined.

Method
------
The vectors in the input tensors are first l2-normalized such that each vector
has length or magnitude of 1. The inner product (dot product) is then taken
between corresponding pairs of row vectors in the input tensors and returned.

Examples
--------
The cosine similarity between two equivalent vectors will be 1. The cosine
similarity between two equivalent tensors (tensors where all the elements are
the same) will be a tensor of 1s. In this scenario, if the input tensors `x` and
`y` are each of shape `(n, p)`, where every element of `x` and `y` is the same,
then the output tensor will be of shape `(n, n)` with 1 in every entry.

>>> import tensorflow as tf
>>> import deepchem.models.layers as layers
>>> x = tf.ones((6, 4), dtype=tf.dtypes.float32, name=None)
>>> y_same = tf.ones((6, 4), dtype=tf.dtypes.float32, name=None)
>>> cos_sim_same = layers.cosine_dist(x, y_same)

`x` and `y_same` are the same tensor (equivalent at every element, in this
case 1). As such, the pairwise inner product of the rows in `x` and `y_same`
will always be 1, and the output tensor will be of shape `(6, 6)`.

>>> diff = cos_sim_same - tf.ones((6, 6), dtype=tf.dtypes.float32, name=None)
>>> tf.reduce_sum(diff) == 0 # True
<tf.Tensor: shape=(), dtype=bool, numpy=True>
>>> cos_sim_same.shape
TensorShape([6, 6])

The cosine similarity between two orthogonal vectors will be 0 (by definition).
If every row in `x` is orthogonal to every row in `y`, then the output will be a
tensor of 0s. In the following example, each row in the tensor `x1` is orthogonal
to each row in `x2` because they are disjoint halves of an identity matrix.

>>> identity_tensor = tf.eye(512, dtype=tf.dtypes.float32)
>>> x1 = identity_tensor[0:256,:]
>>> x2 = identity_tensor[256:512,:]
>>> cos_sim_orth = layers.cosine_dist(x1, x2)

Each row in `x1` is orthogonal to each row in `x2`. As such, the pairwise inner
product of the rows in `x1` and `x2` will always be 0. Furthermore, because the
input tensors are both of shape `(256, 512)`, the output tensor will be of
shape `(256, 256)`.

>>> tf.reduce_sum(cos_sim_orth) == 0 # True
<tf.Tensor: shape=(), dtype=bool, numpy=True>
>>> cos_sim_orth.shape
TensorShape([256, 256])

Parameters
----------
x: tf.Tensor
  Input Tensor of shape `(n, p)`, that is, `n` rows by `p` columns.
  Note that `n` need not equal `m` (the number of rows in `y`).
y: tf.Tensor
  Input Tensor of shape `(m, p)`, that is, `m` rows by `p` columns.
  Note that `m` need not equal `n` (the number of rows in `x`).

Returns
-------
tf.Tensor
  Returns a tensor of shape `(n, m)`, that is, `n` rows by `m` columns.
  Each `(i, j)`-th entry of this output tensor is the inner product between
  the l2-normalized `i`-th row of the input tensor `x` and the
  l2-normalized `j`-th row of the input tensor `y`.
"""
denom = (backend.sqrt(backend.sum(tf.square(x)) * backend.sum(tf.square(y))) +
backend.epsilon())
return backend.dot(x, tf.transpose(y)) / denom
x_norm = tf.math.l2_normalize(x, axis=1)
y_norm = tf.math.l2_normalize(y, axis=1)
return backend.dot(x_norm, tf.transpose(y_norm))
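
Why the fix matters: the old implementation divided every dot product by a
single global denominator built from the sums of squares of all of `x` and all
of `y`, so each entry depended on every row of both tensors. Cosine similarity
instead requires normalizing each row to unit length. A minimal NumPy sketch of
the difference (illustrative only, not part of this commit; assumes `numpy` is
installed):

import numpy as np

x = np.array([[3.0, 4.0], [1.0, 0.0]])
y = np.array([[3.0, 4.0], [0.0, 2.0]])

# Old behavior: one global denominator shared by every entry.
old = x @ y.T / (np.sqrt((x**2).sum() * (y**2).sum()) + 1e-7)

# Fixed behavior: l2-normalize each row, then take pairwise dot products.
x_n = x / np.linalg.norm(x, axis=1, keepdims=True)
y_n = y / np.linalg.norm(y, axis=1, keepdims=True)
new = x_n @ y_n.T

print(new[0, 0])  # 1.0 -- identical rows have cosine similarity 1
print(old[0, 0])  # ~0.91 -- the global denominator understates it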


class AttnLSTMEmbedding(tf.keras.layers.Layer):
@@ -572,7 +645,7 @@ def call(self, inputs):
for d in range(self.max_depth):
# Process using attention
# Eqn (4), appendix A.1 of Matching Networks paper
e = _cosine_dist(x + q, xp)
e = cosine_dist(x + q, xp)
a = tf.nn.softmax(e)
r = backend.dot(a, xp)
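
For orientation, this attention step can be read shape-wise as follows; a
standalone sketch, where the variable names and shapes are illustrative rather
than taken from this commit:

import tensorflow as tf
from deepchem.models.layers import cosine_dist

# Hypothetical shapes: 6 test embeddings, 10 support embeddings, 4 features.
test_states = tf.random.normal((6, 4))  # plays the role of x + q
xp = tf.random.normal((10, 4))          # support-set embeddings

e = cosine_dist(test_states, xp)        # (6, 10) pairwise similarities
a = tf.nn.softmax(e)                    # attention weights; rows sum to 1
r = tf.matmul(a, xp)                    # (6, 4) weighted support readout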

@@ -674,13 +747,13 @@ def call(self, inputs):

for d in range(self.max_depth):
# Process support xp using attention
e = _cosine_dist(z + q, xp)
e = cosine_dist(z + q, xp)
a = tf.nn.softmax(e)
# Get linear combination of support set
r = backend.dot(a, xp)

# Process test x using attention
x_e = _cosine_dist(x + p, z)
x_e = cosine_dist(x + p, z)
x_a = tf.nn.softmax(x_e)
s = backend.dot(x_a, z)
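
Shape-wise, this dual-attention step attends twice: once over the support set
with the evolving support states, and once over those states with the test
embeddings. A hypothetical standalone sketch (names and shapes illustrative):

import tensorflow as tf
from deepchem.models.layers import cosine_dist

z = tf.random.normal((10, 4))   # evolving support states
q = tf.random.normal((10, 4))   # support query adjustment
xp = tf.random.normal((10, 4))  # support-set embeddings
x = tf.random.normal((6, 4))    # test-set embeddings
p = tf.random.normal((6, 4))    # test query adjustment

# Support side: attend over the support set itself.
r = tf.matmul(tf.nn.softmax(cosine_dist(z + q, xp)), xp)  # (10, 4)
# Test side: attend over the updated support states.
s = tf.matmul(tf.nn.softmax(cosine_dist(x + p, z)), z)    # (6, 4)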

23 changes: 23 additions & 0 deletions deepchem/models/tests/test_layers.py
@@ -5,6 +5,29 @@
from tensorflow.python.framework import test_util


def test_cosine_dist():
"""Test invoking cosine_dist."""
x = tf.ones((5, 4), dtype=tf.dtypes.float32, name=None)
y_same = tf.ones((5, 4), dtype=tf.dtypes.float32, name=None)
  # x and y_same are the same tensor (equivalent at every element),
  # so the pairwise inner product of their rows will always be 1.
  # The output tensor will be of shape (5, 5).
cos_sim_same = layers.cosine_dist(x, y_same)
diff = cos_sim_same - tf.ones((5, 5), dtype=tf.dtypes.float32, name=None)
assert tf.reduce_sum(diff) == 0 # True

identity_tensor = tf.eye(
512, dtype=tf.dtypes.float32) # identity matrix of shape (512,512)
x1 = identity_tensor[0:256, :]
x2 = identity_tensor[256:512, :]
  # each row in x1 is orthogonal to each row in x2,
  # so the pairwise inner product of their rows will always be 0.
  # The output tensor will be of shape (256, 256).
cos_sim_orth = layers.cosine_dist(x1, x2)
assert tf.reduce_sum(cos_sim_orth) == 0 # True
assert all([cos_sim_orth.shape[dim] == 256 for dim in range(2)]) # True
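
One possible additional check (not part of this commit) would compare
cosine_dist against a plain NumPy reference on random inputs; a sketch assuming
numpy is importable in this test module:

def test_cosine_dist_matches_numpy_reference():
  """Hypothetical extra test: cosine_dist vs. a NumPy reference."""
  import numpy as np
  rng = np.random.RandomState(0)
  a = rng.randn(5, 4).astype(np.float32)
  b = rng.randn(7, 4).astype(np.float32)
  # Reference: l2-normalize rows, then take all pairwise dot products.
  a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
  b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
  expected = a_n @ b_n.T
  actual = layers.cosine_dist(tf.constant(a), tf.constant(b)).numpy()
  assert actual.shape == (5, 7)
  assert np.allclose(actual, expected, atol=1e-5)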


def test_highway():
"""Test invoking Highway."""
width = 5
2 changes: 2 additions & 0 deletions docs/layers.rst
@@ -99,3 +99,5 @@ another tensor. DeepChem maintains an extensive collection of layers which perform

.. autoclass:: deepchem.models.layers.SetGather
:members:

.. autofunction:: deepchem.models.layers.cosine_dist
