In [1]:
import numpy as np

from util import *
from rbm import RestrictedBoltzmannMachine
from dbn import DeepBeliefNet

def make_one_hot(index, n_classes=10):
    """Return a ``(1, n_classes)`` float array that is 1 at column ``index`` and 0 elsewhere.

    Used below as the label vector passed to ``DeepBeliefNet.generate``.
    """
    one_hot = np.zeros(shape=(1, n_classes))
    one_hot[0, index] = 1
    return one_hot


def iterations_per_epoch(n_samples, batch_size):
    """Return the number of mini-batch iterations used for one pass over the data.

    Keeps the original formula ``int(1 + n_samples / batch_size)``; e.g. 60000
    images in batches of 20 give 3001 iterations.
    """
    return int(1 + n_samples / batch_size)


def sweep_rbm_hidden_sizes(train_imgs, image_size, hidden_sizes, batch_size,
                           n_iterations, n_labels=10):
    """Train a fresh bottom RBM with CD-1 for every hidden-layer size in ``hidden_sizes``.

    Each RBM is discarded after training. The sweep exists for the training
    progress that ``cd1`` prints for each size (``recon_loss`` in the cell output).

    Args:
        train_imgs: training images passed straight to ``rbm.cd1``.
        image_size: ``[rows, cols]``; the visible layer has ``rows * cols`` units.
        hidden_sizes: iterable of hidden-layer sizes to try.
        batch_size: mini-batch size given to each RBM.
        n_iterations: number of CD-1 iterations per RBM.
        n_labels: number of label classes given to each RBM.
    """
    n_visible = image_size[0] * image_size[1]
    for n_hidden in hidden_sizes:
        print("Nr of hidden nodes: %3d" % (n_hidden))
        rbm = RestrictedBoltzmannMachine(ndim_visible=n_visible,
                                         ndim_hidden=n_hidden,
                                         is_bottom=True,
                                         image_size=image_size,
                                         is_top=False,
                                         n_labels=n_labels,
                                         batch_size=batch_size)
        rbm.cd1(visible_trainset=train_imgs, n_iterations=n_iterations)


def recognize_train_and_test(dbn, train_imgs, train_lbls, test_imgs, test_lbls):
    """Run ``dbn.recognize`` on the training set, then on the test set."""
    dbn.recognize(train_imgs, train_lbls)
    dbn.recognize(test_imgs, test_lbls)


def generate_all_digits(dbn, name, n_labels=10):
    """Call ``dbn.generate`` once per label ``0..n_labels-1`` with a one-hot label vector.

    Args:
        dbn: the network to sample from.
        name: forwarded unchanged as ``dbn.generate(..., name=name)``.
        n_labels: number of label classes (one ``generate`` call each).
    """
    for digit in range(n_labels):
        dbn.generate(make_one_hot(digit, n_labels), name=name)


if __name__ == "__main__":
    #################
    BATCHSIZE = 20          # mini-batch size for the stand-alone RBM sweep
    NHIDDEN = 500           # largest hidden-layer size tried in the RBM sweep
    N_TRAIN = 60000         # number of MNIST training images to load
    N_TEST = 10000          # number of MNIST test images to load
    N_LABELS = 10           # digit classes 0..9
    DBN_BATCHSIZE = 10      # NOTE(review): differs from BATCHSIZE -- confirm this is intended
    DBN_ITERATIONS = 2000   # iterations for both greedy pre-training and wake-sleep fine-tuning
    SEED = 0
    #################
    # NOTE(review): presumably rbm/dbn sample from numpy's global RNG; seeding
    # makes Restart & Run All reproducible -- confirm against rbm.py / dbn.py.
    np.random.seed(SEED)

    image_size = [28, 28]
    train_imgs, train_lbls, test_imgs, test_lbls = read_mnist(dim=image_size, n_train=N_TRAIN, n_test=N_TEST)

    # --- restricted Boltzmann machine: sweep over hidden-layer sizes ---

    print("\nStarting a Restricted Boltzmann Machine..")

    # Hidden sizes start at NHIDDEN and step down by 50 while still > 150.
    # N_TRAIN is used instead of a second hard-coded 60000 so the iteration
    # count follows the amount of data actually requested from read_mnist.
    sweep_rbm_hidden_sizes(train_imgs, image_size,
                           hidden_sizes=range(NHIDDEN, 150, -50),
                           batch_size=BATCHSIZE,
                           n_iterations=iterations_per_epoch(N_TRAIN, BATCHSIZE),
                           n_labels=N_LABELS)

    # --- deep belief net ---

    print("\nStarting a Deep Belief Net..")

    dbn = DeepBeliefNet(sizes={"vis": image_size[0] * image_size[1], "hid": 500, "pen": 500, "top": 2000, "lbl": N_LABELS},
                        image_size=image_size,
                        n_labels=N_LABELS,
                        batch_size=DBN_BATCHSIZE)

    # --- greedy layer-wise training ---

    dbn.train_greedylayerwise(vis_trainset=train_imgs, lbl_trainset=train_lbls, n_iterations=DBN_ITERATIONS)
    recognize_train_and_test(dbn, train_imgs, train_lbls, test_imgs, test_lbls)
    generate_all_digits(dbn, name="rbms", n_labels=N_LABELS)

    # --- fine-tune wake-sleep training ---

    dbn.train_wakesleep_finetune(vis_trainset=train_imgs, lbl_trainset=train_lbls, n_iterations=DBN_ITERATIONS)
    recognize_train_and_test(dbn, train_imgs, train_lbls, test_imgs, test_lbls)
    generate_all_digits(dbn, name="dbn", n_labels=N_LABELS)


Starting a Restricted Boltzmann Machine..
Nr of hidden nodes: 500
learning CD1
iteration=      0 recon_loss=2862.6237
iteration=    500 recon_loss=2205.2623
iteration=   1000 recon_loss=2195.1678
iteration=   1500 recon_loss=2178.3266
iteration=   2000 recon_loss=2174.8500
iteration=   2500 recon_loss=2156.0377
iteration=   3000 recon_loss=2166.1440
Nr of hidden nodes: 450
learning CD1
iteration=      0 recon_loss=2914.3551
iteration=    500 recon_loss=2239.8317
iteration=   1000 recon_loss=2248.6674
iteration=   1500 recon_loss=2257.6324
iteration=   2000 recon_loss=2233.1649
iteration=   2500 recon_loss=2237.8355
iteration=   3000 recon_loss=2205.1063
Nr of hidden nodes: 400
learning CD1
iteration=      0 recon_loss=2983.6743
iteration=    500 recon_loss=2180.4314
iteration=   1000 recon_loss=2169.0675
iteration=   1500 recon_loss=2178.3294
iteration=   2000 recon_loss=2158.4541
iteration=   2500 recon_loss=2178.2912
iteration=   3000 recon_loss=2182.8170
Nr of hidden nodes: 350
lea