In [1]:
import numpy as np

In [2]:
class RBM:
    def __init__(self, a, b, W):
        self.a = a.copy()
        self.b = b.copy()
        self.W = W.copy()
        self.m, self.n = W.shape
        self.eps = 1e-4
        
    def set_train_data(self, train_data):
        self.train_data = train_data
    
    def run(self):
        for _ in range(1000000):
            idx = np.random.choice(self.train_data.shape[0])
            v = self.train_data[idx]
            dW, da, db = self.contrastive_divergence(v)
            self.W += dW
            self.a += da
            self.b += db
    
    def E(self, v, h):
        return -self.a@v-self.b@h-v@self.W@h
    
    def logistic(self, x):
        return 1/(1+np.exp(-x))
    
    def P(self, of, x):
        if of == "v":
            return self.logistic(self.a + self.W@x)
        elif of == "h":
            return self.logistic(self.b + self.W.T@x)
        else:
            pass
        
    def contrastive_divergence(self, v):
        h = (np.random.random(self.n) > self.P("h", v)).astype(float)
        v_hat = (np.random.random(self.m) > self.P("v", h)).astype(float)
        h_hat = (np.random.random(self.n) > self.P("h", v_hat)).astype(float)
        positive_gradient = np.outer(v, h)
        negative_gradient = np.outer(v_hat, h_hat)
        dW = self.eps*(positive_gradient - negative_gradient)
        da = self.eps*(v-v_hat)
        db = self.eps*(h-h_hat)
        return dW, da, db

In [3]:
m = 3
n = 4
a = np.random.random(m)
b = np.random.random(n)
W = np.random.random([m, n])
print("a: ", a);print()
print("b: ", b);print()
print("W: ", W);print()

a:  [0.08148159 0.42050556 0.64306478]

b:  [0.6327993  0.04476601 0.62810601 0.79006776]

W:  [[0.38442171 0.59714095 0.69381665 0.21918121]
 [0.04177668 0.27251822 0.93705166 0.92497118]
 [0.57520265 0.13598719 0.18152514 0.04186073]]



In [4]:
train_data = (np.random.random([100, m]) > 0.5).astype(float)
model = RBM(a, b, W)
model.set_train_data(train_data)

In [5]:
model.run()
print(model.a)
print(model.b)
print(model.W)

[2098.31648156 2100.19650553 2798.14006489]
[-1827.62520066 -1828.25223395 -1827.81989395 -1827.4459322 ]
[[ 818.4904217   818.80314094  818.56081664  818.2891812 ]
 [1006.65777667 1007.0285182  1006.77905164 1006.59897116]
 [1011.59920264 1011.93098717 1011.67352513 1011.49886071]]


In [6]:
model.run()
print(model.a)
print(model.b)
print(model.W)

[4198.66148199 4199.90350596 5595.80406546]
[-3656.95220098 -3657.56823427 -3657.12589427 -3656.69493252]
[[1638.13842168 1638.46214092 1638.22981662 1638.01518118]
 [2016.22577664 2016.60151818 2016.32605162 2016.17197114]
 [2021.36320261 2021.70098715 2021.4795251  2021.33586069]]
