In [1]:
from src.network_description import NeuralLayer, LinearLayer, InteractionLink, GraphDescription
from src.JAX.RBM.constructor import RBM

import numpy as np
import jax.numpy as jnp
import jax

from timeit import timeit

In [2]:
# Gaussian layers v1 / v2 (16 x 16 = 256 units each); these are the arrays passed as
# v1= / v2= to the propagation, CD-k and test calls further down.
layer1 = NeuralLayer(name="v1", units_number=16 * 16, type="Gaussian", priority=0)
layer2 = NeuralLayer(name="v2", units_number=16 * 16, type="Gaussian", priority=0)

# Binary layers h1..h4 interleaved with the two LinearLayer nodes F1 / F2.
# Statement order is kept as declared in case constructors have order-dependent side effects.
layer3 = NeuralLayer(name="h1", units_number=100, type="Binary", priority=1)
layer4 = LinearLayer(name="F1", units_number=32)
layer5 = NeuralLayer(name="h2", units_number=300, type="Binary", priority=1)
layer6 = LinearLayer(name="F2", units_number=64)
layer7 = NeuralLayer(name="h3", units_number=133, type="Binary", priority=2)
layer8 = NeuralLayer(name="h4", units_number=144, type="Binary", priority=1)

# NOTE(review): whether GraphDescription depends on this list order is not visible here — confirm.
layers = [
    layer1, layer2, layer3, layer4,
    layer5, layer6, layer7, layer8,
]

In [3]:
# Pairwise interactions between the layers defined above. Links built without an explicit
# `bidirectional` argument use InteractionLink's default for that flag.
# Construction order is preserved as written.
link1 = InteractionLink(layer1, layer4, bidirectional=False)   # v1 -> F1
link2 = InteractionLink(layer2, layer4)                        # v2 -- F1
link3 = InteractionLink(layer3, layer4)                        # h1 -- F1
link4 = InteractionLink(layer1, layer2, bidirectional=False)   # v1 -> v2
link5 = InteractionLink(layer1, layer5, bidirectional=False)   # v1 -> h2
link6 = InteractionLink(layer1, layer6, bidirectional=False)   # v1 -> F2
link7 = InteractionLink(layer5, layer6)                        # h2 -- F2
link8 = InteractionLink(layer3, layer6)                        # h1 -- F2
link9 = InteractionLink(layer3, layer7)                        # h1 -- h3
link10 = InteractionLink(layer3, layer8, bidirectional=False)  # h1 -> h4


links = [
    link1, link2, link3, link4, link5,
    link6, link7, link8, link9, link10,
]

In [4]:
# Combine the `layers` and `links` lists into one graph description; later cells iterate it via `my_graph.nodes`.
my_graph = GraphDescription(layers, links)

In [5]:
# Build the JAX RBM from the graph description; the remaining cells benchmark its energy,
# energy-gradient, propagation, positive/negative-state, CD-k and test methods.
my_rbm = RBM(my_graph)

In [6]:
# Single-sample benchmark inputs: one 1-D vector of values in [0, 1) per NeuralLayer node.
# - A seeded Generator makes the inputs reproducible across Restart & Run All.
# - np.random.random already samples from [0, 1), so the former np.clip(..., 0, 1) was a no-op.
# - An exact type check keeps the original "NeuralLayer only" selection (LinearLayer nodes get no input)
#   without comparing class names as strings.
placeholder_rng = np.random.default_rng(0)
placeholders = {
    layer.name: placeholder_rng.random((layer.units_number,))
    for layer in my_graph.nodes
    if type(layer) is NeuralLayer
}

In [7]:
# Batched benchmark inputs: a (batch_size, units_number) array of values in [0, 1) per NeuralLayer node.
# Seeded for reproducibility; no clip needed because Generator.random already returns values in [0, 1).
batch_size = 128
batched_placeholder_rng = np.random.default_rng(1)
batched_placeholders = {
    layer.name: batched_placeholder_rng.random((batch_size, layer.units_number))
    for layer in my_graph.nodes
    if type(layer) is NeuralLayer
}

In [8]:
%timeit my_rbm._get_energy(*list(placeholders.values()), weights_dict=my_rbm.weights_dict, biases_dict=my_rbm.biases_dict, sigmas_dict=my_rbm.sigmas_dict)

4.83 ms ± 650 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit my_rbm.get_energy(*list(placeholders.values()))

The slowest run took 13.76 times longer than the fastest. This could mean that an intermediate result is being cached.
931 µs ± 1.15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%timeit my_rbm._batched_get_energy(*list(batched_placeholders.values()), weights_dict=my_rbm.weights_dict, biases_dict=my_rbm.biases_dict, sigmas_dict=my_rbm.sigmas_dict)

43.1 ms ± 2.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
%timeit my_rbm.batched_get_energy(*list(batched_placeholders.values()))

483 µs ± 77 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
%timeit my_rbm._get_energy_grad(*list(placeholders.values()), my_rbm.weights_dict, my_rbm.biases_dict, my_rbm.sigmas_dict)

70.2 ms ± 3.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%timeit my_rbm._batched_get_energy_grad(*list(batched_placeholders.values()), my_rbm.weights_dict, my_rbm.biases_dict, my_rbm.sigmas_dict)

90.6 ms ± 1.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
%timeit my_rbm.get_energy_grad(*list(placeholders.values()))

408 µs ± 191 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%timeit my_rbm.batched_get_energy_grad(*list(batched_placeholders.values()))

615 µs ± 115 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%timeit my_rbm.bottom_up_propagation(v1=placeholders['v1'], v2=placeholders['v2'], verbose=False)

1.38 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
%timeit my_rbm.get_positive_states(v1=placeholders['v1'], v2=placeholders['v2'])

1.41 ms ± 155 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
%timeit my_rbm.get_negative_states(h3=placeholders['h3'], v1=placeholders['v1'])

2.89 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%timeit my_rbm.batched_bottom_up_propagation(batched_placeholders["v1"], batched_placeholders["v2"])

3.36 ms ± 65.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [23]:
%timeit my_rbm.batched_top_down_propagation(batched_placeholders['h3'], batched_placeholders['v1'])

3.08 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
%timeit my_rbm.batched_get_positive_states(batched_placeholders['v1'], batched_placeholders['v2'])

3.22 ms ± 68.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [27]:
%timeit my_rbm.batched_get_negative_states(batched_placeholders['h3'], batched_placeholders['v1'])

5.73 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
%timeit my_rbm.batched_get_energy(*[batched_placeholders[k]for k in ['v1', 'v2', 'h1', 'h2', 'h3', 'h4']])

472 µs ± 48.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [30]:
%timeit my_rbm.batched_get_energy(*[batched_placeholders[k] for k in ['v1', 'v2', 'h1', 'h2', 'h3', 'h4']])

413 µs ± 37 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [31]:
%timeit my_rbm.batched_get_energy_grad(*[batched_placeholders[k] for k in ['v1', 'v2', 'h1', 'h2', 'h3', 'h4']])

550 µs ± 56.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [32]:
%timeit my_rbm.CD_k(v1=placeholders['v1'], v2=placeholders['v2'], K=5)

3.81 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [39]:
my_rbm.batched_CD_k?

Signature:
my_rbm.batched_CD_k(
    v1,
    v2,
    K=5,
    Nb_stabilization_steps=5,
    espsilon_stabilization=1e-10,
)
Call signature: my_rbm.batched_CD_k(*args, **kwargs)
Type:           PjitFunction
String form:    <PjitFunction of <function _create_compute_grads_fns.<locals>.batched_CD_k_fn at 0x7f93e1ead430>>
File:           ~/Workspace/Pytorch_JAX/Boltzmann_Machines/src/JAX/RBM/constructor_helpers.py
Docstring:
** Batch

In [40]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=1)

8.87 ms ± 80.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [41]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=2)

8.73 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [42]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=5)

7.17 ms ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [43]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=10)

8.82 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=1, Nb_stabilization_steps=1)

8.77 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [45]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=1, Nb_stabilization_steps=10)

8.82 ms ± 165 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [46]:
%timeit my_rbm.batched_CD_k(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], K=10, Nb_stabilization_steps=10)

8.86 ms ± 194 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
my_rbm.test?

Signature:
my_rbm.test(
    *args,
    K=5,
    Nb_stabilization_steps=5,
    espsilon_stabilization=1e-10,
    **kwargs,
)
Call signature: my_rbm.test(*args, **kwargs)
Type:           PjitFunction
String form:    <PjitFunction of <function _create_train_and_test_fns.<locals>.test_fn at 0x7f93e1eb1700>>
File:           ~/Workspace/Pytorch_JAX/Boltzmann_Machines/src/JAX/RBM/constructor_helpers.py

In [35]:
def mse(ground_truths, predictions, axis=1):
    """Mean over samples of the per-sample root-mean-square error, one value per output pair.

    Despite the name this is not a plain mean squared error: squared differences are
    averaged along ``axis``, square-rooted (a per-sample RMSE), and that result is then
    averaged over everything that remains.

    Args:
        ground_truths: iterable of mappings, each holding a ``'probabilities'`` array.
        predictions: iterable of mappings whose ``'probabilities'`` array matches the shape
            of the paired ground truth. Pairs are formed with ``zip``, so surplus items in
            the longer iterable are ignored.
        axis: axis reduced before the square root. Defaults to 1, presumably the unit axis
            of a (batch, units) array (TODO confirm); pass -1 to also accept 1-D arrays.

    Returns:
        list of scalar JAX arrays, one per (ground_truth, prediction) pair; empty for empty input.
    """
    return [
        jnp.mean(jnp.sqrt(jnp.mean(jnp.square(gt['probabilities'] - pred['probabilities']), axis)))
        for gt, pred in zip(ground_truths, predictions)
    ]

In [36]:
def sse(ground_truths, predictions, axis=1):
    """Mean over samples of the per-sample Euclidean (L2) error norm, one value per output pair.

    Despite the name this is not a sum of squared errors: squared differences are summed
    along ``axis`` and square-rooted (the L2 norm of each sample's error), and that result
    is then averaged over everything that remains.

    Args:
        ground_truths: iterable of mappings, each holding a ``'probabilities'`` array.
        predictions: iterable of mappings whose ``'probabilities'`` array matches the shape
            of the paired ground truth. Pairs are formed with ``zip``, so surplus items in
            the longer iterable are ignored.
        axis: axis reduced before the square root. Defaults to 1, presumably the unit axis
            of a (batch, units) array (TODO confirm); pass -1 to also accept 1-D arrays.

    Returns:
        list of scalar JAX arrays, one per (ground_truth, prediction) pair; empty for empty input.
    """
    return [
        jnp.mean(jnp.sqrt(jnp.sum(jnp.square(gt['probabilities'] - pred['probabilities']), axis)))
        for gt, pred in zip(ground_truths, predictions)
    ]

In [48]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=jax.jit(sse))

2.15 s ± 16.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [49]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=sse)

4.69 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [50]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=sse, K=1, Nb_stabilization_steps=1)

5.8 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [51]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=sse, K=10, Nb_stabilization_steps=1)

5.8 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [52]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=sse, K=1, Nb_stabilization_steps=10)

5.7 ms ± 115 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=sse, K=5, Nb_stabilization_steps=10)

5.9 ms ± 105 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [54]:
%timeit my_rbm.test(v1=batched_placeholders['v1'], v2=batched_placeholders['v2'], error_fn=sse, K=10, Nb_stabilization_steps=5)

5.78 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
