Add losses for VAE
hsjang001205 committed Sep 24, 2020
1 parent a81fe7f commit 5b9ac4c
Showing 2 changed files with 87 additions and 44 deletions.
deepchem/models/losses.py (24 changes: 14 additions & 10 deletions)
```diff
@@ -238,23 +238,23 @@ class VAE_ELBO(Loss):
   tensor([0.7017, 0.7624], dtype=torch.float64)
   """
 
-  def _compute_tf_loss(self, logvar, mu, x, reconstruction_x, kl_scale = 1):
+  def _compute_tf_loss(self, logvar, mu, x, reconstruction_x, kl_scale=1):
     import tensorflow as tf
     x, reconstruction_x = _make_tf_shapes_consistent(x, reconstruction_x)
     x, reconstruction_x = _ensure_float(x, reconstruction_x)
     BCE = tf.keras.losses.binary_crossentropy(x, reconstruction_x)
     KLD = VAE_KLDivergence()._compute_tf_loss(logvar, mu)
-    return BCE + kl_scale*KLD
+    return BCE + kl_scale * KLD
 
   def _create_pytorch_loss(self):
     import torch
     bce = torch.nn.BCELoss(reduction='none')
 
-    def loss(logvar, mu, x, reconstruction_x, kl_scale = 1):
+    def loss(logvar, mu, x, reconstruction_x, kl_scale=1):
       x, reconstruction_x = _make_pytorch_shapes_consistent(x, reconstruction_x)
       BCE = torch.mean(bce(reconstruction_x, x), dim=-1)
       KLD = (VAE_KLDivergence()._create_pytorch_loss())(logvar, mu)
-      return BCE + kl_scale*KLD
+      return BCE + kl_scale * KLD
 
     return loss
```
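Both branches return one ELBO value per batch element: the per-sample mean binary cross-entropy of the reconstruction plus `kl_scale` times the KL term. A minimal usage sketch of the PyTorch path, with illustrative tensors borrowed from the tests below (the sketch itself is not part of the diff):

```python
# Hypothetical sketch: drive VAE_ELBO's PyTorch loss directly.
# logvar/mu are (batch, latent); x/reconstruction_x are (batch, features).
import torch
from deepchem.models.losses import VAE_ELBO

loss_fn = VAE_ELBO()._create_pytorch_loss()
logvar = torch.tensor([[1.0, 1.3], [0.6, 1.2]])
mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
x = torch.tensor([[0.9, 0.4, 0.8], [0.3, 0.0, 1.0]])
reconstruction_x = torch.tensor([[0.8, 0.3, 0.7], [0.2, 0.0, 0.9]])

# One value per batch element; kl_scale (default 1) weights the KL term.
elbo = loss_fn(logvar, mu, x, reconstruction_x)
print(elbo)  # approx. tensor([0.7017, 0.7624]), per the class docstring
```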

```diff
@@ -291,14 +291,18 @@ def _compute_tf_loss(self, logvar, mu):
     import tensorflow as tf
     logvar, mu = _make_tf_shapes_consistent(logvar, mu)
     logvar, mu = _ensure_float(logvar, mu)
-    return 0.5 * tf.reduce_mean(tf.square(mu) + tf.square(logvar) - tf.math.log(1e-20 + tf.square(logvar)) - 1,-1)
+    return 0.5 * tf.reduce_mean(
+        tf.square(mu) + tf.square(logvar) -
+        tf.math.log(1e-20 + tf.square(logvar)) - 1, -1)
 
   def _create_pytorch_loss(self):
     import torch
 
     def loss(logvar, mu):
       logvar, mu = _make_pytorch_shapes_consistent(logvar, mu)
-      return 0.5 * torch.mean(torch.square(mu) + torch.square(logvar) - torch.log(1e-20 + torch.square(logvar)) - 1,-1)
+      return 0.5 * torch.mean(
+          torch.square(mu) + torch.square(logvar) -
+          torch.log(1e-20 + torch.square(logvar)) - 1, -1)
 
     return loss
```
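For reference, this is the closed-form KL divergence from a diagonal Gaussian posterior to the standard normal prior, with the 1e-20 offset guarding the logarithm near zero. Note that the implementation squares the `logvar` argument rather than exponentiating it, so the argument is effectively treated as the standard deviation sigma:

```latex
% Per-dimension KL divergence to the standard normal prior; the code above
% averages this over the latent dimensions (with sigma supplied via `logvar`).
D_{\mathrm{KL}}\left(\mathcal{N}(\mu,\sigma^{2})\,\|\,\mathcal{N}(0,1)\right)
  = \frac{1}{2}\left(\mu^{2} + \sigma^{2} - \ln\sigma^{2} - 1\right)
```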

```diff
@@ -333,17 +337,17 @@ def _compute_tf_loss(self, inputs):
     import tensorflow as tf
     #extended one of probabilites to binary distribution
     if inputs.shape[-1] == 1:
-      inputs = tf.concat([inputs,1-inputs], axis = -1)
-    return tf.reduce_mean(-inputs*tf.math.log(1e-20+inputs), -1)
+      inputs = tf.concat([inputs, 1 - inputs], axis=-1)
+    return tf.reduce_mean(-inputs * tf.math.log(1e-20 + inputs), -1)
 
   def _create_pytorch_loss(self):
     import torch
 
     def loss(inputs):
       #extended one of probabilites to binary distribution
       if inputs.shape[-1] == 1:
-        inputs = torch.cat((inputs,1-inputs), dim = -1)
-      return torch.mean(-inputs*torch.log(1e-20+inputs), -1)
+        inputs = torch.cat((inputs, 1 - inputs), dim=-1)
+      return torch.mean(-inputs * torch.log(1e-20 + inputs), -1)
 
     return loss
```
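The quantity implemented here is the Shannon entropy of the predicted distribution, with two wrinkles visible in the code: a single-column input p is first expanded to the two-class distribution (p, 1 - p), and the result is the mean over classes rather than the sum:

```latex
% Shannon entropy of a discrete distribution p (natural logarithm); the code
% above returns the mean over the last axis, i.e. H(p)/n for n classes.
H(p) = -\sum_{i} p_{i} \ln p_{i}
```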

deepchem/models/tests/test_losses.py (107 changes: 73 additions & 34 deletions)
```diff
@@ -202,70 +202,109 @@ def test_sparse_softmax_cross_entropy_pytorch(self):
   def test_VAE_ELBO_tf(self):
     """."""
     loss = losses.VAE_ELBO()
-    logvar = tf.constant([[1.0,1.3],[0.6,1.2]])
-    mu = tf.constant([[0.2,0.7],[1.2,0.4]])
-    x = tf.constant([[0.9,0.4,0.8],[0.3,0,1]])
-    reconstruction_x = tf.constant([[0.8,0.3,0.7],[0.2,0,0.9]])
+    logvar = tf.constant([[1.0, 1.3], [0.6, 1.2]])
+    mu = tf.constant([[0.2, 0.7], [1.2, 0.4]])
+    x = tf.constant([[0.9, 0.4, 0.8], [0.3, 0, 1]])
+    reconstruction_x = tf.constant([[0.8, 0.3, 0.7], [0.2, 0, 0.9]])
     result = loss._compute_tf_loss(logvar, mu, x, reconstruction_x).numpy()
-    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1])
-                -np.mean(np.array([0.9,0.4,0.8])*np.log([0.8,0.3,0.7])+np.array([0.1,0.6,0.2])*np.log([0.2,0.7,0.3])),
-                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])
-                -np.mean(np.array([0.3,0,1])*np.log([0.2,1e-20,0.9])+np.array([0.7,1,0])*np.log([0.8,1,0.1]))]
+    expected = [
+        0.5 * np.mean([
+            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
+            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
+        ]) - np.mean(
+            np.array([0.9, 0.4, 0.8]) * np.log([0.8, 0.3, 0.7]) +
+            np.array([0.1, 0.6, 0.2]) * np.log([0.2, 0.7, 0.3])),
+        0.5 * np.mean([
+            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
+            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
+        ]) - np.mean(
+            np.array([0.3, 0, 1]) * np.log([0.2, 1e-20, 0.9]) +
+            np.array([0.7, 1, 0]) * np.log([0.8, 1, 0.1]))
+    ]
     assert np.allclose(expected, result)
```
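A quick sanity check on the first `expected` entry (illustrative, not part of the commit): it is the sample-0 KL term plus the sample-0 mean binary cross-entropy, and it reproduces the 0.7017 shown in the VAE_ELBO docstring:

```python
# Sanity check for expected[0]: KL + mean BCE for the first sample, with the
# `logvar` argument treated as sigma (hence the squares: 0.04 = 0.2**2, etc.).
import numpy as np

kl = 0.5 * np.mean([0.2**2 + 1.0**2 - np.log(1e-20 + 1.0**2) - 1,
                    0.7**2 + 1.3**2 - np.log(1e-20 + 1.3**2) - 1])
x = np.array([0.9, 0.4, 0.8])
recon = np.array([0.8, 0.3, 0.7])
bce = -np.mean(x * np.log(recon) + (1 - x) * np.log(1 - recon))
print(kl + bce)  # ~0.7017, matching the docstring example in losses.py
```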

```diff
   @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
   def test_VAE_ELBO_pytorch(self):
     """."""
     loss = losses.VAE_ELBO()
-    logvar = torch.tensor([[1.0,1.3],[0.6,1.2]])
-    mu = torch.tensor([[0.2,0.7],[1.2,0.4]])
-    x = torch.tensor([[0.9,0.4,0.8],[0.3,0,1]])
-    reconstruction_x = torch.tensor([[0.8,0.3,0.7],[0.2,0,0.9]])
-    result = loss._create_pytorch_loss()(logvar, mu, x, reconstruction_x).numpy()
-    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1])
-                -np.mean(np.array([0.9,0.4,0.8])*np.log([0.8,0.3,0.7])+np.array([0.1,0.6,0.2])*np.log([0.2,0.7,0.3])),
-                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])
-                -np.mean(np.array([0.3,0,1])*np.log([0.2,1e-20,0.9])+np.array([0.7,1,0])*np.log([0.8,1,0.1]))]
+    logvar = torch.tensor([[1.0, 1.3], [0.6, 1.2]])
+    mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
+    x = torch.tensor([[0.9, 0.4, 0.8], [0.3, 0, 1]])
+    reconstruction_x = torch.tensor([[0.8, 0.3, 0.7], [0.2, 0, 0.9]])
+    result = loss._create_pytorch_loss()(logvar, mu, x,
+                                         reconstruction_x).numpy()
+    expected = [
+        0.5 * np.mean([
+            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
+            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
+        ]) - np.mean(
+            np.array([0.9, 0.4, 0.8]) * np.log([0.8, 0.3, 0.7]) +
+            np.array([0.1, 0.6, 0.2]) * np.log([0.2, 0.7, 0.3])),
+        0.5 * np.mean([
+            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
+            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
+        ]) - np.mean(
+            np.array([0.3, 0, 1]) * np.log([0.2, 1e-20, 0.9]) +
+            np.array([0.7, 1, 0]) * np.log([0.8, 1, 0.1]))
+    ]
     assert np.allclose(expected, result)
```
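One detail worth noting in the PyTorch path, sketched here with illustrative tensors (not from the commit): `torch.nn.BCELoss(reduction='none')` keeps the elementwise losses, so the mean over `dim=-1` yields one BCE value per batch element, mirroring the TF branch's `binary_crossentropy`:

```python
# Shape sketch (illustrative tensors): reduction='none' keeps the (batch,
# features) elementwise losses; the mean over dim=-1 gives one value per row.
import torch

bce = torch.nn.BCELoss(reduction='none')
pred = torch.tensor([[0.8, 0.3, 0.7], [0.2, 0.5, 0.9]])
target = torch.tensor([[0.9, 0.4, 0.8], [0.3, 0.0, 1.0]])
print(bce(pred, target).shape)               # torch.Size([2, 3])
print(bce(pred, target).mean(dim=-1).shape)  # torch.Size([2])
```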

```diff
   @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
   def test_VAE_KLDivergence_tf(self):
     """."""
     loss = losses.VAE_KLDivergence()
-    logvar = tf.constant([[1.0,1.3],[0.6,1.2]])
-    mu = tf.constant([[0.2,0.7],[1.2,0.4]])
+    logvar = tf.constant([[1.0, 1.3], [0.6, 1.2]])
+    mu = tf.constant([[0.2, 0.7], [1.2, 0.4]])
     result = loss._compute_tf_loss(logvar, mu).numpy()
-    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1]),
-                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])]
+    expected = [
+        0.5 * np.mean([
+            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
+            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
+        ]), 0.5 * np.mean([
+            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
+            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
+        ])
+    ]
     assert np.allclose(expected, result)
 
   @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
   def test_VAE_KLDivergence_pytorch(self):
     """."""
     loss = losses.VAE_KLDivergence()
-    logvar = torch.tensor([[1.0,1.3],[0.6,1.2]])
-    mu = torch.tensor([[0.2,0.7],[1.2,0.4]])
+    logvar = torch.tensor([[1.0, 1.3], [0.6, 1.2]])
+    mu = torch.tensor([[0.2, 0.7], [1.2, 0.4]])
     result = loss._create_pytorch_loss()(logvar, mu).numpy()
-    expected = [0.5 * np.mean([0.04+1.0-np.log(1e-20+1.0)-1, 0.49+1.69 - np.log(1e-20 +1.69) - 1]),
-                0.5 * np.mean([1.44+0.36-np.log(1e-20+0.36)-1, 0.16+1.44 - np.log(1e-20 +1.44) - 1])]
+    expected = [
+        0.5 * np.mean([
+            0.04 + 1.0 - np.log(1e-20 + 1.0) - 1,
+            0.49 + 1.69 - np.log(1e-20 + 1.69) - 1
+        ]), 0.5 * np.mean([
+            1.44 + 0.36 - np.log(1e-20 + 0.36) - 1,
+            0.16 + 1.44 - np.log(1e-20 + 1.44) - 1
+        ])
+    ]
     assert np.allclose(expected, result)
```

```diff
   @unittest.skipIf(not has_tensorflow, 'TensorFlow is not installed')
   def test_ShannonEntropy_tf(self):
     """."""
     loss = losses.ShannonEntropy()
-    inputs = tf.constant([[0.7,0.3],[0.9,0.1]])
+    inputs = tf.constant([[0.7, 0.3], [0.9, 0.1]])
     result = loss._compute_tf_loss(inputs).numpy()
-    expected = [-np.mean([0.7*np.log(0.7),0.3*np.log(0.3)]),
-                -np.mean([0.9*np.log(0.9),0.1*np.log(0.1)])]
+    expected = [
+        -np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]),
+        -np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)])
+    ]
     assert np.allclose(expected, result)
 
   @unittest.skipIf(not has_pytorch, 'PyTorch is not installed')
   def test_ShannonEntropy_pytorch(self):
     """."""
     loss = losses.ShannonEntropy()
-    inputs = torch.tensor([[0.7,0.3],[0.9,0.1]])
+    inputs = torch.tensor([[0.7, 0.3], [0.9, 0.1]])
     result = loss._create_pytorch_loss()(inputs).numpy()
-    expected = [-np.mean([0.7*np.log(0.7),0.3*np.log(0.3)]),
-                -np.mean([0.9*np.log(0.9),0.1*np.log(0.1)])]
-    assert np.allclose(expected, result)
+    expected = [
+        -np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]),
+        -np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)])
+    ]
+    assert np.allclose(expected, result)
```
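As a final spot check (illustrative, not part of the commit), the ShannonEntropy expectations evaluate to roughly 0.3054 and 0.1625 (in nats, averaged over the two classes):

```python
# Numeric spot check for the ShannonEntropy expectations (values in nats,
# averaged over the two classes rather than summed).
import numpy as np

print(-np.mean([0.7 * np.log(0.7), 0.3 * np.log(0.3)]))  # ~0.3054
print(-np.mean([0.9 * np.log(0.9), 0.1 * np.log(0.1)]))  # ~0.1625
```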
