In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from rbm.definitions import DATASET_DIR, OUTPUT_DIR
from rbm.models import TMCRBM

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float

# Seed the random number generators so that Restart & Run All reproduces the
# same training run (some CUDA kernels may still be non-deterministic).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train a TMC-RBM on the 1d2c artificial dataset using the TMC1D algorithm

In [2]:
# The genfromtxt result is transposed before use
# (presumably visible units x samples after the transpose — TODO confirm).
raw_spins = np.genfromtxt(DATASET_DIR.joinpath('data_1d_2c_balanced.dat')).T
# The data is binary with values in {-1, 1}; fail early (e.g. on NaNs from
# malformed lines) instead of silently training on corrupted inputs.
assert np.isin(raw_spins, (-1, 1)).all(), "expected spins in {-1, 1}"
data = torch.tensor(raw_spins, dtype=dtype, device=device)
# Rescale {-1, 1} -> {0, 1}.
data = (data + 1) / 2

The cell below defines all the hyper-parameters needed to initialize the TMC-RBM.

In [3]:
# --- Model size ---
Nv = 1000  # visible units
Nh = 100   # hidden units

# --- Optimisation ---
lr = 0.01           # learning rate
mb_s = 4000         # mini-batch size
NGibbs = 10         # MCMC iterations per update
UpdCentered = True  # whether to use the centered gradient update

# --- TMC sampling ---
N = 20000           # constraint on the Gaussian bath
nb_point = 250      # number of constraint points
nb_chain = 10       # number of chains per constraint point
it_mean = 5         # steps used for the temporal mean
# Index of the direction carrying the constraints (0-based: the first direction is 0).
direction = 0
# When True the projection uses PCA, otherwise the SVD of the weight matrix.
PCA = True
# Distance added beyond the extrema of the data projection when discretizing.
border_length = 0.2
# Keep the chains permanent (False) or reset them (True).
ResetChain = False

# --- PCD ---
# Number of persistent Markov chains for PCD. Unused in this notebook, but it
# must be set if PCD is to be used alongside TMC.
num_pcd = 200


In [4]:
# Gather every hyper-parameter from the previous cell and build the TMC-RBM.
tmc_kwargs = dict(
    num_visible=Nv,
    num_hidden=Nh,
    device=device,
    lr=lr,
    gibbs_steps=NGibbs,
    UpdCentered=UpdCentered,
    mb_s=mb_s,
    direction=direction,
    PCA=PCA,
    num_pcd=num_pcd,
    nb_point=nb_point,
    nb_chain=nb_chain,
    border_length=border_length,
    N=N,
    it_mean=it_mean,
    ResetPermChainBatch=ResetChain,
)
RBM_TMC = TMCRBM(**tmc_kwargs)

In [5]:
# Stamp added to every filename saved during training.
RBM_TMC.file_stamp = 'Demo1d2cB'

# Total number of training epochs.
ep_max = 500

# We define here the timesteps where we will save the RBM during training
fq_msr_RBM = 20
base = 1.7
v = np.array([0,1],dtype=int)
allm = np.append(np.array(0),base**np.array(list(range(30))))
for k in range(30):
    for m in allm:
        v = np.append(v,int(base**k)+int(m)) 
v = np.array(list(set(v)))
v = np.sort(v)
# Hand both save schedules to the model: the log-spaced times in `v`, and
# regular checkpoints every `fq_msr_RBM` epochs from epoch 1 up to ep_max - 1.
RBM_TMC.list_save_time = v
checkpoint_epochs = np.arange(1, ep_max, fq_msr_RBM)
RBM_TMC.list_save_rbm = checkpoint_epochs

# Initialise the visible biases from the training data.
RBM_TMC.SetVisBias(data)

In [6]:
# Epoch frequency on which to generate figures in the output folder.
# Useful to monitor the learning.
fq_fig = 20

# savefig fails when the target folder is missing, so create it up front
# (assumes OUTPUT_DIR is a pathlib.Path, as its use of `.joinpath` suggests).
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


def _right_singular_vectors(A):
    """Return V of the reduced SVD ``A = U diag(S) V^T``.

    Replaces the deprecated ``torch.svd(A)[2]``. A is real here, so the
    transpose of ``Vh`` equals its conjugate transpose.
    """
    _, _, Vh = torch.linalg.svd(A, full_matrices=False)
    return Vh.T


def _project_data(V):
    """Project the rows of ``data.T`` on the columns of ``V``, scaled by 1/sqrt(Nv), on the CPU."""
    return torch.mm(data.T, V).cpu() / (RBM_TMC.Nv ** 0.5)


def _save_figure(fig, kind, epoch):
    """Write ``fig`` to ``OUTPUT_DIR/<file_stamp>_fig_<kind>_<epoch>.png`` and close it."""
    fig.savefig(OUTPUT_DIR.joinpath(f"{RBM_TMC.file_stamp}_fig_{kind}_{epoch}.png"))
    # Closing prevents open figures from piling up over the training loop.
    plt.close(fig)


if RBM_TMC.PCA:
    # The PCA basis depends only on the data, so it is computed once.
    V = _right_singular_vectors(data.T)
    Xsc = _project_data(V)

for t in range(ep_max):
    RBM_TMC.fit(data, ep_max=1)

    if t % fq_fig != 0:
        continue

    if not RBM_TMC.PCA:
        # The basis follows the weight matrix, so refresh it at every snapshot.
        V = _right_singular_vectors(RBM_TMC.W)
        Xsc = _project_data(V)

        fig, ax = plt.subplots()
        ax.hist(Xsc[:, direction].numpy(), bins=100, density=True)
        ax.set(xlabel=f"projection on direction {direction}", ylabel="density")
        _save_figure(fig, "scat", t)

    # Overlay p_m on the data histogram along the constrained `direction`
    # (this used to hard-code column 0, which only matches when direction == 0).
    fig, ax = plt.subplots()
    ax.plot(RBM_TMC.w_hat_b.cpu()[1:], RBM_TMC.p_m.cpu())
    ax.hist(Xsc[:, direction].numpy(), bins=100, density=True)
    ax.set(xlabel=f"projection on direction {direction}", ylabel="p_m / density")
    _save_figure(fig, "proba", t)

    fig, ax = plt.subplots()
    ax.plot(RBM_TMC.w_hat_b.cpu()[1:], RBM_TMC.Ω)
    ax.set(xlabel="w_hat_b", ylabel="Ω")
    _save_figure(fig, "pot", t)

IT 0
Saving nb_upd=0
Saving nb_upd=1
Saving nb_upd=2
model updates saved at /home/nbereux/rbm/model/AllParametersDemo1d2cB.h5
model saved at /home/nbereux/rbm/model/RBMDemo1d2cB.h5
IT 1
Saving nb_upd=3
Saving nb_upd=4
Saving nb_upd=5
model updates saved at /home/nbereux/rbm/model/AllParametersDemo1d2cB.h5
model saved at /home/nbereux/rbm/model/RBMDemo1d2cB.h5
IT 2
Saving nb_upd=6
Saving nb_upd=8
model updates saved at /home/nbereux/rbm/model/AllParametersDemo1d2cB.h5
model saved at /home/nbereux/rbm/model/RBMDemo1d2cB.h5
IT 3
Saving nb_upd=9
Saving nb_upd=10
model updates saved at /home/nbereux/rbm/model/AllParametersDemo1d2cB.h5
model saved at /home/nbereux/rbm/model/RBMDemo1d2cB.h5
IT 4
Saving nb_upd=12
Saving nb_upd=14
model updates saved at /home/nbereux/rbm/model/AllParametersDemo1d2cB.h5
model saved at /home/nbereux/rbm/model/RBMDemo1d2cB.h5
IT 5
Saving nb_upd=15
Saving nb_upd=16
model updates saved at /home/nbereux/rbm/model/AllParametersDemo1d2cB.h5
model saved at /home/nbereux