In [None]:
import numpy as np

def train_gaussian_crbm(CRBMConfig):
    """Train a conditional RBM with real-valued visibles by contrastive divergence.

    Every epoch performs one full-batch update over all usable frames: a
    positive phase on the data, ``cdSteps`` rounds of mean-field
    reconstruction followed by hidden resampling, and a momentum update of
    all parameters.

    Args:
        CRBMConfig: Mapping with the following keys.
            ``'data'``: 2-D array; its second dimension is the number of
                visible units. Rows are addressed through the cumulative
                lengths of ``'classes'``, so the sequences are expected to
                be stacked row-wise in the same order.
            ``'classes'``: sequence of arrays; only ``shape[0]`` (the number
                of frames of each sequence) is used.
            ``'order'``: number of past frames the model conditions on.
                ``0`` is accepted and trains without any history terms.
            ``'numhid'``: number of hidden units.
            ``'numepochs'``: number of full-batch updates.
            ``'cdSteps'``: number of contrastive-divergence steps (>= 1).
            ``'gsd'``: scalar that divides the visibles on the way up and
                multiplies the top-down signal (presumably the standard
                deviation of the Gaussian visibles -- confirm with caller).

    Returns:
        The same ``CRBMConfig`` object, with ``CRBMConfig['model']`` set to a
        dict holding the trained ``'w'``, ``'bi'``, ``'bj'``, ``'A'``, ``'B'``.

    Raises:
        ValueError: If ``cdSteps`` is smaller than 1, or if no sequence is
            longer than ``order`` (there would be nothing to train on).

    Notes:
        Random numbers come from the global ``np.random`` state, so calling
        ``np.random.seed`` beforehand makes a run reproducible.
    """
    data = CRBMConfig['data']
    numdims = data.shape[1]
    order = CRBMConfig['order']
    numhid = CRBMConfig['numhid']
    numepochs = CRBMConfig['numepochs']
    cd_steps = CRBMConfig['cdSteps']
    gsd = CRBMConfig['gsd']
    classes = CRBMConfig['classes']

    # Without at least one CD step there is no reconstruction to build the
    # negative statistics from.
    if cd_steps < 1:
        raise ValueError(f"cdSteps must be at least 1, got {cd_steps!r}")

    # Learning rates and parameters
    epsilonw = epsilonbi = epsilonbj = epsilonA = epsilonB = 1e-6
    wdecay = 0.0002
    mom = 0.9

    # Initialize weights (the order of the random draws is significant for
    # seeded reproducibility).
    w = 0.0001 * np.random.randn(numhid, numdims)            # hidden x visible
    bi = 0.0001 * np.random.randn(numdims, 1)                # visible biases
    bj = -1 + 0.0001 * np.random.randn(numhid, 1)            # hidden biases
    A = 0.0001 * np.random.randn(numdims, numdims, order)    # past -> visible
    B = 0.0001 * np.random.randn(numhid, numdims, order)     # past -> hidden

    # Momentum buffers
    wupdate = np.zeros_like(w)
    biupdate = np.zeros_like(bi)
    bjupdate = np.zeros_like(bj)
    Aupdate = np.zeros_like(A)
    Bupdate = np.zeros_like(B)

    # Prepare training indices: within each sequence skip the first `order`
    # frames, because they do not have a complete history.
    class_lengths = [cls.shape[0] for cls in classes]
    class_ends = np.cumsum(class_lengths)
    ranges = [
        np.arange(end - length + order, end)
        for end, length in zip(class_ends, class_lengths)
    ]
    INDICES = np.concatenate(ranges)

    numcases = len(INDICES)
    if numcases == 0:
        # Dividing the gradients by zero cases would fill the model with NaN.
        raise ValueError(
            "no training cases: every sequence in 'classes' must be longer "
            f"than order={order!r}"
        )

    # data_seq[:, :, 0] is the current frame, data_seq[:, :, k] the frame
    # k steps earlier.
    data_seq = np.zeros((numcases, numdims, order + 1))
    data_seq[:, :, 0] = data[INDICES, :]
    for hh in range(order):
        data_seq[:, :, hh + 1] = data[INDICES - hh - 1, :]
    visible = data_seq[:, :, 0]

    # Training loop
    for epoch in range(numepochs):
        errsum = 0

        # Positive phase: contributions of the history to the visible and
        # hidden biases. Accumulating into zero arrays (instead of sum() over
        # a generator) keeps these proper arrays even when order == 0.
        bistar = np.zeros((numdims, numcases))
        bjstar = np.zeros((numhid, numcases))
        for hh in range(order):
            bistar += A[:, :, hh] @ data_seq[:, :, hh + 1].T
        for hh in range(order):
            bjstar += B[:, :, hh] @ data_seq[:, :, hh + 1].T

        bottomup = w @ visible.T
        eta = (bottomup / gsd) + bj + bjstar
        hposteriors = 1 / (1 + np.exp(-eta))
        hidstates = (hposteriors.T > np.random.rand(numcases, numhid)).astype(float)

        # Gradients from the data; hidden statistics use the sampled states.
        pos_resid = (visible.T - bi - bistar) / gsd**2
        wgrad = hidstates.T @ (visible / gsd)
        bigrad = np.sum(pos_resid, axis=1, keepdims=True)
        bjgrad = np.sum(hidstates, axis=0, keepdims=True).T
        Agrad = np.empty_like(A)
        Bgrad = np.empty_like(B)
        for hh in range(order):
            Agrad[:, :, hh] = pos_resid @ data_seq[:, :, hh + 1]
            Bgrad[:, :, hh] = hidstates.T @ data_seq[:, :, hh + 1]

        # Negative phase (CD): mean-field reconstruction of the visibles,
        # then re-sample the hidden units.
        for cdn in range(cd_steps):
            topdown = gsd * (hidstates @ w)
            negdata = topdown + bi.T + bistar.T
            eta = (w @ (negdata / gsd).T) + bj + bjstar
            hposteriors = 1 / (1 + np.exp(-eta))
            if cdn == 0:
                # Reported error is the squared one-step reconstruction error.
                errsum += np.sum((visible - negdata) ** 2)
            hidstates = (hposteriors.T > np.random.rand(numcases, numhid)).astype(float)

        # Negative gradients; hidden statistics use the probabilities of the
        # last CD step rather than sampled states.
        neg_resid = (negdata.T - bi - bistar) / gsd**2
        negwgrad = hposteriors @ (negdata / gsd)
        negbigrad = np.sum(neg_resid, axis=1, keepdims=True)
        negbjgrad = np.sum(hposteriors, axis=1, keepdims=True)
        negAgrad = np.empty_like(A)
        negBgrad = np.empty_like(B)
        for hh in range(order):
            negAgrad[:, :, hh] = neg_resid @ data_seq[:, :, hh + 1]
            negBgrad[:, :, hh] = hposteriors @ data_seq[:, :, hh + 1]

        # Momentum is switched on only from epoch index 6 onwards.
        momentum = mom if epoch > 5 else 0

        # Update weights; weight decay is applied to w, A and B but not to
        # the biases.
        wupdate = momentum * wupdate + epsilonw * ((wgrad - negwgrad) / numcases - wdecay * w)
        biupdate = momentum * biupdate + (epsilonbi / numcases) * (bigrad - negbigrad)
        bjupdate = momentum * bjupdate + (epsilonbj / numcases) * (bjgrad - negbjgrad)
        Aupdate = momentum * Aupdate + epsilonA * ((Agrad - negAgrad) / numcases - wdecay * A)
        Bupdate = momentum * Bupdate + epsilonB * ((Bgrad - negBgrad) / numcases - wdecay * B)

        # Apply updates
        w += wupdate
        bi += biupdate
        bj += bjupdate
        A += Aupdate
        B += Bupdate

        if epoch % 2 == 0:
            print(f"Epoch {epoch:4d} Error {errsum:6.1f}")

    CRBMConfig['model'] = {'w': w, 'bi': bi, 'bj': bj, 'A': A, 'B': B}
    return CRBMConfig
