Skip to content

Commit

Permalink
Fix derivatives in PyRbm
Browse files Browse the repository at this point in the history
  • Loading branch information
gcarleo committed Jun 26, 2019
1 parent 5d27b3a commit 3909267
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions Examples/CustomMachine/rbm.py
Expand Up @@ -107,6 +107,7 @@ def _set_parameters(self, p):
if self._b is not None:
self._b[:] = p[i : i + self._b.size]
i += self._b.size

self._w[:] = p[i : i + self._w.size].reshape(self._w.shape, order="C")

def log_val(self, x):
Expand Down Expand Up @@ -145,8 +146,9 @@ def der_log(self, x):
i += self._b.size

out = grad[i : i + self._w.size]
out.shape = (x.size, tanh_stuff.size)
_np.outer(x, tanh_stuff, out=out)
out.shape = (tanh_stuff.size, x.size)
_np.outer(tanh_stuff, x, out=out)

return grad

def _is_holomorphic(self):
Expand Down

0 comments on commit 3909267

Please sign in to comment.