In [26]:
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix, kron, identity
from scipy.linalg import expm

In [120]:
X = csr_matrix(([1, 1], ([0, 1], [1, 0])))
Z = csr_matrix(([1, -1], ([0, 1], [0, 1])))

def σ(i, N, A):
    if i != 0 and i != N - 1:
        return kron(kron(identity(2 ** i), A), identity(2 ** (N - i - 1)))
    if i == 0:
        return kron(A, identity(2 ** (N - 1)))
    if i == N - 1:
        return kron(identity(2 ** (N - 1)), A)

def compute_H(a, b, W, Γ):
    n_visible = len(a)
    n_hidden = len(b)
    N = n_visible + n_hidden
    H = csr_matrix((2 ** N, 2 ** N), dtype=np.float64)
    for i in range(n_visible):
        H -= a[i] * σ(i, N, Z) + Γ * σ(i, N, X)
    for j in range(n_hidden):
        H -= b[j] * σ(n_visible + j, N, Z) + Γ * σ(n_visible + j, N, X)
    for i in range(n_visible):
        for j in range(n_hidden):
            H -= W[i, j] * σ(i, N, Z) * σ(n_visible + j, N, Z)
            
    return H

In [135]:
n_visible = 6
n_hidden = 6
a = np.random.normal(0, 0.01, n_visible)
b = np.random.normal(0, 0.01, n_hidden)
W = np.random.normal(0, 0.01, (n_visible, n_hidden))
m = np.max([np.abs(a).max(), np.abs(b).max(), np.abs(W).max()])
a /= m
b /= m
W /= m
Γ = 1

In [164]:
from time import time
t = time()
H = compute_H(a, b, W, Γ).toarray()
ρ = expm(-H)
ρ /= np.trace(ρ)
print(time() - t)

6.285596132278442


In [171]:
H_diag = np.diag(H)
ρ_diag = np.diag(ρ)
sorted_indices = H_diag.argsort()
print(ρ_diag.argmax())
print(H_diag.argmin())

3763
3763


In [172]:
for i in range(len(H_diag)):
    i = sorted_indices[i]
    print(H_diag[i], ρ_diag[i])

-8.336970190056306 0.01560022718282043
-8.171207517011126 0.014035300747425069
-7.478870475250008 0.009546293250116574
-7.199158493396859 0.006692298961639161
-7.058858651303219 0.005362684130981819
-7.051595233622321 0.007656669442213688
-6.874250446084109 0.006024981288602924
-6.781156909919129 0.004900591728729423
-6.771965983140585 0.004295181130233355
-6.7604831034804524 0.006115993984488773
-6.7150266679808235 0.0040323757828279756
-6.674530078814985 0.005098344976729401
-6.638274682902199 0.004084261367055986
-6.5290971958307855 0.004149409836766756
-6.496781828392211 0.004336338726669334
-6.473678397463961 0.005004215566303377
-6.472421060951548 0.004017743065335864
-6.465597509950475 0.003328563028805135
-6.41751232819363 0.0042113310857203315
-6.333207861852766 0.005019282092500214
-6.321689022661935 0.004224108023451611
-6.320548480779934 0.0035844788810293748
-6.302751229768721 0.003740670162350025
-6.2783310176388705 0.00319273249004908
-6.232217397141863 0.003669837670773

In [125]:
N = 2
for i in range(2 ** N):
    print(bin(i)[2:].split(""))

ValueError: empty separator