# Diary for CRBM implementation



This notebook shows the parts from `crbm.py` with some details

In [987]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np 
import pandas as pd
import numexpr as ne
import sklearn
from sklearn import preprocessing

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




##### read data from  `../Datasets/motion.mat`

More data from human motion captures can be found here:

http://people.csail.mit.edu/ehsu/work/sig05stf/

In [1002]:
from scipy.io import loadmat  # this is the SciPy module that loads mat-files
data = loadmat('../Datasets/motion.mat')

In [1003]:
data.keys()

dict_keys(['__header__', '__version__', '__globals__', 'skel', 'Motion'])

In [1004]:
X1 = data["Motion"][0][0]
X2 = data["Motion"][0][1]
X3 = data["Motion"][0][2]

In [1005]:
X1.shape, X2.shape, X2.shape

((1750, 108), (1040, 108), (1040, 108))

Several features are 0

In [1006]:
#(X1 - np.min(X1,0)) / (np.max(X1,0) - np.min(X1,0))* (np.min(X1,0) != 0)

In [1007]:
X1[:,3].min(), X1[:,3].max(), X1.shape

(-1049.559326171875, 490.09881591796881, (1750, 108))

In [1008]:
n_features = X1.shape[1]
for f in range(n_features):
    max_val, min_val =  X1[:, f].max(), X1[:, f].min()
    if (max_val - min_val) != 0:
        X1[:, f] = ( X1[:, f]  - min_val)  / (max_val - min_val)
    else:
        #print(f, max_val, max_val)
        X1[:, f] = ( X1[:, f]  - min_val) # / (max_val - min_val)


In [1009]:
X1.min(), X1.max()

(0.0, 1.0)

### CRBM class

In [1010]:
a = 10
b=2
np.zeros([a,b])

array([[ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.]])

In [1011]:
class CRBM:
    def __init__(self, n_vis, n_hid, n_cond, seed=42, sigma=0.3, monitor_time=True):

        self.previous_xneg = None
        np.random.seed(seed)

        W = np.random.normal(0, sigma, [n_hid, n_vis])   # vis to hid
        A = np.random.normal(0, sigma, [n_vis, n_vis * n_cond])  # cond to vis
        B = np.random.normal(0, sigma, [n_hid, n_vis * n_cond])  # cond to hid

        v_bias = np.zeros([n_vis, 1]) 
        h_bias = np.zeros([n_hid, 1])

        dy_v_bias = np.zeros([n_vis, 1])
        dy_h_bias = np.zeros([n_hid, 1])

        self.W = np.array(W, dtype='float32')
        self.A = np.array(A, dtype='float32')
        self.B = np.array(B, dtype='float32')
        self.v_bias = v_bias
        self.h_bias = h_bias
        self.dy_v_bias = dy_v_bias
        self.dy_h_bias = dy_h_bias
        
        self.n_vis = n_vis
        self.n_hid = n_hid
        self.n_his = n_cond
        
        self.num_epochs_trained = 0
        self.lr = 0
        self.monitor_time = monitor_time

In [1012]:
crbm = CRBM(n_vis=108, n_hid=256, n_cond=20, seed=123, sigma = 0.3)

In [1013]:
crbm.W.shape, crbm.A.shape, crbm.B.shape

((256, 108), (108, 2160), (256, 2160))

### Auxiliary functions

In [1014]:

def sig(v):
    return ne.evaluate("1/(1 + exp(-v))")


def split_vis(crbm: CRBM, vis: np.ndarray):
    n_his = vis.shape[0]
    cond = vis[0:(n_his-1), :].T
    x = vis[[n_his-1],:].T
    
    assert  crbm.n_vis == x.shape[0] and crbm.n_vis == cond.shape[0], \
            "crbm.n_vis = {}, is different from x.shape[0] = {} or cond.shape[0] = {}".format(crbm.n_vis,
                                                                                                  x.shape[0],
                                                                                                  cond.shape[0])
    return x, cond


def dynamic_biases_up(crbm: CRBM, cond: np.ndarray):
    crbm.dy_v_bias = np.dot(crbm.A, cond) + crbm.v_bias 
    crbm.dy_h_bias = np.dot(crbm.B, cond) + crbm.h_bias
        
        
def hid_means(crbm: CRBM, vis: np.ndarray):
    p = np.dot(crbm.W, vis) + crbm.dy_h_bias
    return sig(p)
    
    
def vis_means(crbm: CRBM, hid: np.ndarray):   
    p = np.dot(crbm.W.T, hid) + crbm.dy_v_bias
    return sig(p)


In [1015]:
X = X1[0:21,:]
X.shape, crbm.n_his

((21, 108), 20)

In [1016]:
vis, cond = split_vis(crbm, X)
vis.shape, cond.shape

((108, 1), (108, 20))

### Compute gradients

```
function gibbs(rbm::AbstractRBM, vis::Mat; n_times=1)
    v_pos = vis
    h_pos = sample_hiddens(rbm, v_pos)
    v_neg = sample_visibles(rbm, h_pos)
    h_neg = sample_hiddens(rbm, v_neg)
    for i=1:n_times-1
        v_neg = sample_visibles(rbm, h_neg)
        h_neg = sample_hiddens(rbm, v_neg)
    end
    return v_pos, h_pos, v_neg, h_neg
end
```

In [1017]:

def sample_hiddens(crbm: CRBM, v: np.ndarray, cond: np.ndarray):
    h_mean = sig( np.dot(crbm.W, v) +  np.dot(crbm.B, cond) + crbm.h_bias)
    h_sample = h_mean > np.random.random(h_mean.shape).astype(np.float32)
    return h_sample, h_mean

def sample_visibles(crbm: CRBM, h: np.ndarray, cond: np.ndarray):
    """
    Notice we don't sample or put the sigmoid here since visible units are Gaussian
    """
    v_mean = np.dot(crbm.W.T, h) + np.dot(crbm.A, cond) + crbm.v_bias  
    return v_mean


In [1018]:
def CDK(crbm, vis,cond, K=1):
    v_pos = vis
    h_pos, h_pos_p = sample_hiddens(crbm, v_pos, cond)
    v_neg          = sample_visibles(crbm, h_pos, cond)
    h_neg, h_neg_p = sample_hiddens(crbm, v_neg, cond)

    for i in range(K-1):
        v_neg           = sample_visibles(crbm, h_neg, cond)
        h_neg, h_neg_p  = sample_hiddens(crbm, v_neg, cond)
    
    return v_pos, h_pos_p , v_neg, h_neg_p

In [1019]:
def compute_gradient(crbm, X):
    """
    Computes an approximated gradient of the likelihod (for a given minibatch X) with
    respect to the parameters. 
    """
    vis, cond = split_vis(crbm, X)
    cond = np.array([cond.flatten()]).T
    
    v_pos, h_pos, v_neg, h_neg = CDK(crbm, vis, cond)
    n_obs = vis.shape[1]
    
    # for a sigle observation:  dW = h * v^T - h_hat * v_hat^T
    dW = ( np.dot(h_pos, v_pos.T) - np.dot(h_neg, v_neg.T) ) * (1./n_obs)
    dA = ( np.dot(v_pos, cond.T)  - np.dot(v_neg, cond.T)  ) * (1./n_obs)
    dB = ( np.dot(h_pos, cond.T)  - np.dot(h_neg, cond.T)  ) * (1./n_obs) 
    
    dv_bias = np.mean(v_pos - v_neg, axis=1, keepdims=True)
    dh_bias = np.mean(h_pos - h_neg, axis=1, keepdims=True)
    #print("n_obs:", n_obs)
    
    rec_error = np.linalg.norm(v_pos - v_neg)
    #print( np.sqrt(np.sum((v_pos - v_neg)**2)))
    
    return dW, dA, dB, dv_bias, dh_bias, rec_error

In [1020]:
X = X1[0:21,:]

In [1028]:
X.shape, crbm.n_his

((21, 108), 20)

In [1029]:
# Notice that the history is converted to a "long column vector" concatenating
# all the rows of the n_his vectors into a single vector of `n_vis * n_his` elements.
# This is done by `cond = np.array([cond.flatten()]).T`

dW, dA, dB, dv_bias, dh_bias, rec_error = compute_gradient(crbm, X)

In [1030]:
X.shape, rec_error

((21, 108), 60.495788990641621)

### SGD 

In [1031]:
def update_weights_sgd(crbm, grads, learning_rate):
    
    dW, dA, dB, dv_bias, dh_bias = grads #rec_error = compute_gradient(crbm, X)
    
    crbm.W += dW * learning_rate
    crbm.A += dA * learning_rate
    crbm.B += dB * learning_rate
    
    crbm.v_bias += dv_bias * learning_rate
    crbm.h_bias += dh_bias * learning_rate


In [1032]:
dW, dA, dB, dv_bias, dh_bias, err = compute_gradient(crbm, X)
grads  = (dW, dA, dB, dv_bias, dh_bias)

In [1033]:
update_weights_sgd(crbm, grads,  0.0001)

In [1034]:
err

61.046388310193812

### Apply momentum: (TODO)

### Get slice of data

Given a timeseries where column `k` corresponds to a feature vector for the measurements of the timeseries at time `k`, we would like to take a slice of `n_his` values to feed the CRBM with a visible vector and a history.

In [1042]:
X.shape

(21, 108)

In [1043]:
def get_slice_at_position_k(X, k, n_his):
    """
    Returns a slice of shape  `(n_his + 1)` with the last column beeing the visible
    vector at the current time step `k`.
    """
    assert k > n_his, "Position k = {} is lower than n_his = {}".format(k, n_his)
    assert k <= X.shape[1], "Position k = {} is bigger than number of timesteps of X.shape[1] = {}".format(k, X.shape[0])
    return X[:, (k-(n_his+1)):k]

In [1045]:
X_tr = X1.T
print("X_tr shape: ", X_tr.shape, "\nslice shape:", get_slice_at_position_k(X_tr, 520, crbm.n_his).shape)

X_tr shape:  (108, 1750) 
slice shape: (108, 21)


### Train a single epoch 

In [1052]:
X_tr = X1.T
X_tr.shape, X_tr.shape[1],  crbm.n_vis, crbm.n_hid, crbm.n_his

((108, 1750), 1750, 108, 256, 20)

In [1054]:
for k in range(crbm.n_his+1, X_tr.shape[1]+1):
    
    X_curr = get_slice_at_position_k(X_tr, k, crbm.n_his)
    
    dW, dA, dB, dv_bias, dh_bias, rec_error = compute_gradient(crbm, X_curr.T)
    grads = (dW, dA, dB, dv_bias, dh_bias)
    update_weights_sgd(crbm, grads,  0.0001)
    
    print("rec error: ", rec_error)


rec error:  55.3881925577
rec error:  54.5098963454
rec error:  53.0580362437
rec error:  52.7970383676
rec error:  48.7119458165
rec error:  45.398930616
rec error:  43.7750633226
rec error:  42.350177711
rec error:  40.8486111396
rec error:  38.9026811986
rec error:  37.5854977154
rec error:  36.9971463117
rec error:  35.4418268555
rec error:  36.624974821
rec error:  36.8489407378
rec error:  33.0767250295
rec error:  32.4724252107
rec error:  33.2869235718
rec error:  30.7587802753
rec error:  29.8077637433
rec error:  27.2474106795
rec error:  27.3793390679
rec error:  29.934613424
rec error:  26.6409417788
rec error:  24.9038900217
rec error:  28.1521323029
rec error:  25.1077743603
rec error:  24.2203133445
rec error:  22.3123602974
rec error:  26.705000174
rec error:  22.7454360293
rec error:  26.5909217002
rec error:  21.2504182855
rec error:  22.7226010035
rec error:  20.1392513212
rec error:  23.347041038
rec error:  22.3475621581
rec error:  23.4646106842
rec error:  21.538

rec error:  15.9155815951
rec error:  15.2308923908
rec error:  16.1441118328
rec error:  15.4568322686
rec error:  16.8126340912
rec error:  15.966857211
rec error:  17.5378711108
rec error:  15.1442356416
rec error:  16.1201754937
rec error:  18.0380987181
rec error:  15.5891210513
rec error:  17.1702547809
rec error:  16.8229022489
rec error:  17.8729127745
rec error:  17.3289966555
rec error:  19.8665520596
rec error:  16.543596357
rec error:  17.6497085093
rec error:  16.8500435614
rec error:  15.604233681
rec error:  18.6218287279
rec error:  17.0551566517
rec error:  15.7736823147
rec error:  17.5583865305
rec error:  15.7980046502
rec error:  16.0992529124
rec error:  16.3608378631
rec error:  16.0845357232
rec error:  16.2460918956
rec error:  16.2959240209
rec error:  18.186491573
rec error:  17.2826246458
rec error:  16.5738269909
rec error:  13.8235569435
rec error:  14.493625988
rec error:  15.9877586064
rec error:  14.4872729004
rec error:  14.0067906502
rec error:  14.26

rec error:  15.8078989709
rec error:  16.5687507319
rec error:  17.6472454267
rec error:  16.7849804982
rec error:  15.8093342656
rec error:  16.0190399148
rec error:  15.6072910381
rec error:  15.8144892033
rec error:  17.7927121636
rec error:  14.781316472
rec error:  15.5913493957
rec error:  15.8604273395
rec error:  16.1406175381
rec error:  17.6011338272
rec error:  16.8525598897
rec error:  18.542709275
rec error:  17.4880567478
rec error:  19.0581304659
rec error:  19.2088301268
rec error:  18.3671882558
rec error:  19.4292836219
rec error:  17.4179063693
rec error:  15.9947500313
rec error:  18.747692415
rec error:  15.7876179383
rec error:  16.6389809263
rec error:  17.1333022488
rec error:  16.5227422881
rec error:  17.8159277399
rec error:  18.6199260337
rec error:  16.9968611741
rec error:  17.9731015521
rec error:  18.6741606335
rec error:  17.6471492674
rec error:  16.755181639
rec error:  19.1001313451
rec error:  18.1623589608
rec error:  19.333450793
rec error:  16.23

rec error:  10.9879658827
rec error:  12.4137921283
rec error:  14.2842034764
rec error:  12.2567332555
rec error:  10.9444346435
rec error:  11.8164666461
rec error:  11.2287556117
rec error:  11.2604511525
rec error:  12.5296452265
rec error:  12.216942218
rec error:  12.2028586016
rec error:  12.5092362942
rec error:  13.2927220073
rec error:  12.7339803394
rec error:  10.7009269388
rec error:  12.6825190876
rec error:  12.367105144
rec error:  12.0735787516
rec error:  13.2113824472
rec error:  10.757432086
rec error:  12.3398146325
rec error:  11.7752101797
rec error:  11.6984307746
rec error:  13.305891233
rec error:  13.3987056887
rec error:  11.4332236758
rec error:  11.6528676828
rec error:  11.116378094
rec error:  10.8391745344
rec error:  12.7834093361
rec error:  11.8115675812
rec error:  13.0366886338
rec error:  14.4755139102
rec error:  13.2676663604
rec error:  12.8676754444
rec error:  12.0266565848
rec error:  12.689604191
rec error:  11.8993315555
rec error:  11.231

rec error:  12.1390525249
rec error:  10.3353124862
rec error:  11.4054735136
rec error:  11.057142929
rec error:  11.0307897327
rec error:  10.9153602188
rec error:  12.9838398165
rec error:  11.1388653055
rec error:  10.9127346154
rec error:  11.8282054842
rec error:  11.9555131862
rec error:  11.7157683278
rec error:  11.7179215006
rec error:  11.8646491568
rec error:  12.8910280489
rec error:  12.6708455989
rec error:  11.1493731648
rec error:  12.1762587914
rec error:  11.3857608568
rec error:  11.9146745701
rec error:  12.5509983343
rec error:  13.4214772076
rec error:  13.2177132336
rec error:  12.1081458464
rec error:  13.3161591927
rec error:  12.9287967231
rec error:  11.7072253755
rec error:  12.7152849565
rec error:  11.3336662244
rec error:  12.3721036748
rec error:  12.3939824637
rec error:  11.0180132407
rec error:  14.2167561847
rec error:  12.5059130926
rec error:  11.3044701954
rec error:  11.4395834831
rec error:  12.1600785696
rec error:  11.08494394
rec error:  11.

rec error:  10.7948029045
rec error:  9.29470624246
rec error:  9.19351854491
rec error:  9.56003411943
rec error:  9.33404355688
rec error:  8.72303806242
rec error:  9.53652764709
rec error:  9.55443174038
rec error:  9.9655399403
rec error:  9.38410431749
rec error:  9.90368955991
rec error:  9.9491073092
rec error:  10.0060861031
rec error:  9.84323137136
rec error:  10.1253011987
rec error:  9.86556426619
rec error:  10.0611589212
rec error:  9.7321904803
rec error:  8.78089465833
rec error:  8.93651378438
rec error:  9.43192799423
rec error:  8.58140143331
rec error:  9.80392519288
rec error:  9.62341936892
rec error:  9.05613794263
rec error:  8.67460384481
rec error:  9.1315509041
rec error:  10.4312828985
rec error:  8.35334823947
rec error:  8.7958200908
rec error:  10.6625490565
rec error:  8.98049063189
rec error:  8.65807508804
rec error:  9.96006160115
rec error:  11.3022162087
rec error:  9.18500583367
rec error:  9.52514837402
rec error:  9.14442857363
rec error:  9.593

### Make predictions with the model

Prepare an example that trains with several data and predict feature values

### Plot predictions