## implementation stollen from https://blog.paperspace.com/beginners-guide-to-boltzmann-machines-pytorch/

In [1]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image
%matplotlib inline
import matplotlib.pyplot as plt

In [301]:
def truth_func(x, y):
    #return np.logical_xor(x, y)
    return x*y
    #return np.logical_or(x, y)
    #return np.logical_not(x)

def gen_xor_training(N):
    data   = np.random.randint(0, 2, (N, 2))
    target = truth_func(data[:, 0:1], data[:, 1:2])
    return data,target

def xor_to_flat(data, target):
    concat = np.concatenate((data, target), axis=1).astype(np.single) #.flatten()
    #concat = 2*concat - 1
    return concat

def flat_to_xor(data):
    #data = data.reshape(-1, 3)
    #data = (data + 1)/2
    return data[:, 0:2], data[:, 2:3]

def xor_sucess_rate(data, target):
    ground_truth = truth_func(data[:, 0:1], data[:, 1:2])
    return 1 - np.mean(np.abs(target - ground_truth))

In [302]:
# this is a descrete RBM
class RBM(nn.Module):
   def __init__(self,
               n_vis,
               n_hin,
               k):
        super(RBM, self).__init__()
        self.W = nn.Parameter(torch.randn(n_hin,n_vis)*1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k
    
   def sample_from_p(self,p):
       return torch.relu(torch.sign(p - Variable(torch.rand(p.size()))))

   def v_to_h(self,v):
        p_h = F.sigmoid(F.linear(v,self.W,self.h_bias))
        sample_h = self.sample_from_p(p_h)
        return p_h,sample_h
    
   def h_to_v(self,h):
        p_v = F.sigmoid(F.linear(h,self.W.t(),self.v_bias))
        sample_v = self.sample_from_p(p_v)
        return p_v,sample_v

   def forward(self,v):
        pre_h1,h1 = self.v_to_h(v)
        
        h_ = h1
        for _ in range(self.k):
            pre_v_,v_ = self.h_to_v(h_)
            pre_h_,h_ = self.v_to_h(v_)
        
        return v,v_

   def free_energy(self,v):
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v, self.W, self.h_bias)
        hidden_term = wx_b.exp().add(1).log().sum(1)
        return (-hidden_term - vbias_term).mean()

In [304]:
batch_size = 200
rbm = RBM(3, 5, k=5)
train_op = optim.SGD(rbm.parameters(), 0.1)

losses = []
sucess_rates = []

for epoch in range(20):
    loss_ = []
    sucess_rate_ = []
    for i in range(500):
        data,target = gen_xor_training(batch_size)
        data_var = Variable(torch.from_numpy(xor_to_flat(data, target).astype(np.double)))
        
        v,v1 = rbm(data_var)
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        #loss = torch.abs(v - v1).mean()
        
        loss_.append(loss.data)
        train_op.zero_grad()
        loss.backward()
        train_op.step()

        sucess_rate = xor_sucess_rate(*flat_to_xor(v1.detach().numpy()))
        sucess_rate_.append(sucess_rate)

    losses.append(np.mean(loss_))
    sucess_rates.append(np.mean(sucess_rate_))
    print("Training loss and sucess rate for {} epoch: {}, {}%".format(epoch, losses[-1], int(sucess_rates[-1]*100)))
    
print("\nv_bias:", rbm.v_bias.detach().numpy())
print("h_bias:", rbm.h_bias.detach().numpy())
print("W:\n", rbm.W.detach().numpy())

Training loss and sucess rate for 0 epoch: -0.009890688599752406, 63%


KeyboardInterrupt: 

In [None]:
#print(v.reshape(batch_size, 3)[:20, :])
print(v1.detach().numpy().reshape(batch_size, 3)[:20, :])

In [305]:
# Continuous RBM
class RBM_C(nn.Module):
    def __init__(self,
                n_vis,
                n_hin,
                k):
        torch.set_default_dtype(torch.double)
        super(RBM_C, self).__init__()
        
        self.W      = nn.Parameter(torch.randn(n_hin, n_vis)*1e-2)
        self.v_bias = nn.Parameter(torch.zeros(n_vis))
        self.h_bias = nn.Parameter(torch.zeros(n_hin))
        self.k = k
        
        self.eps=1e-2

    def v_to_h(self,v):
        p_h  = F.linear(v, self.W, self.h_bias)
        p_h_ = p_h + torch.sign(p_h)*self.eps
        p_h_[p_h_ == 0] = self.eps

        lim_up = 1 - torch.exp(-p_h_)
        lim_up = torch.clamp(lim_up, min=-1 + 0.1)
        
        y = Variable(torch.rand(p_h_.size()))*lim_up
        h = torch.log(1 - y).div(-p_h_)

        return h
    
    def h_to_v(self, h):
        p_v  = F.linear(h, self.W.t(), self.v_bias)
        p_v_ = p_v + torch.sign(p_v)*self.eps
        p_v_[p_v_ == 0] = self.eps

        lim_up = 1 - torch.exp(-p_v_)
        lim_up = torch.clamp(lim_up, min=-1 + 0.1)
        
        y = Variable(torch.rand(p_v_.size()))*lim_up
        v = torch.log(1 - y).div(-p_v_)

        return v

    def forward(self,v):
        h1 = self.v_to_h(v)
        
        h_ = h1
        for i in range(self.k):
            v_ = self.h_to_v(h_)
            h_ = self.v_to_h(v_)
        
        return v,v_
    
    def free_energy(self,v):
        vbias_term = v.mv(self.v_bias)
        wx_b = F.linear(v, self.W, self.h_bias)
        hidden_term = wx_b.exp().add(1).log().sum(1)
        return (-hidden_term - vbias_term).mean()

In [309]:
batch_size = 200
rbm = RBM_C(3, 5, k=5)
train_op = optim.SGD(rbm.parameters(), 0.01)

losses = []
sucess_rates = []

for epoch in range(20):
    loss_ = []
    sucess_rate_ = []
    for i in range(200):
        data,target = gen_xor_training(batch_size)
        data_var = Variable(torch.from_numpy(xor_to_flat(data, target).astype(np.double)))
        
        v,v1 = rbm(data_var)
        loss = rbm.free_energy(v) - rbm.free_energy(v1)
        
        loss_.append(loss.data)
        train_op.zero_grad()
        loss.backward()
        train_op.step()

        sucess_rate = xor_sucess_rate(*flat_to_xor(v1.detach().numpy()))
        sucess_rate_.append(sucess_rate)

    losses.append(np.mean(loss_))
    sucess_rates.append(np.mean(sucess_rate_))
    print("Training loss and sucess rate for {} epoch: {}, {}%".format(epoch, losses[-1], int(sucess_rates[-1]*100)))
    
print("\nv_bias:", rbm.v_bias.detach().numpy())
print("h_bias:", rbm.h_bias.detach().numpy())
print("W:\n", rbm.W.detach().numpy())

Training loss and sucess rate for 0 epoch: -0.12758637770081757, 62%
Training loss and sucess rate for 1 epoch: -0.17742360799709353, 63%
Training loss and sucess rate for 2 epoch: -0.17964864919233375, 62%
Training loss and sucess rate for 3 epoch: -0.19802568584431401, 62%
Training loss and sucess rate for 4 epoch: -0.22849368493296013, 61%
Training loss and sucess rate for 5 epoch: -0.2940543604560367, 60%
Training loss and sucess rate for 6 epoch: -0.5785320498036836, 57%
Training loss and sucess rate for 7 epoch: -1.507186619984551, 54%
Training loss and sucess rate for 8 epoch: -3.427162593724862, 52%
Training loss and sucess rate for 9 epoch: -5.97065926604967, 50%
Training loss and sucess rate for 10 epoch: -8.938049953934128, 49%
Training loss and sucess rate for 11 epoch: -11.736259608300935, 48%
Training loss and sucess rate for 12 epoch: -14.797369136603125, 49%
Training loss and sucess rate for 13 epoch: -17.860300767517757, 49%
Training loss and sucess rate for 14 epoch: 

In [310]:
print(v1.detach().numpy().reshape(batch_size, 3)[:20, :])

[[2.34381249e-02 9.98845079e-01 7.25670195e-01]
 [6.89076978e-02 3.20080812e-01 4.52135134e-01]
 [3.60916064e-03 6.44501470e-01 4.53499518e-01]
 [2.17591723e-02 4.46516540e-01 5.27059602e-01]
 [2.88576909e-02 9.51125585e-01 5.88168672e-01]
 [4.08472139e-02 2.75099894e-01 8.12144877e-01]
 [2.88160936e-03 9.49134635e-01 4.21803481e-01]
 [2.24940596e-02 5.94745548e-01 2.53419331e-01]
 [6.53212788e-04 8.35104956e-01 1.12892377e-01]
 [1.51665856e-02 1.47207501e-01 2.10933987e-01]
 [8.51292044e-03 8.37180824e-02 9.67250632e-01]
 [1.93526801e-03 8.10765649e-01 6.00704767e-01]
 [7.97965054e-02 3.31486453e-02 8.19776312e-01]
 [1.12509342e-01 6.93695094e-01 5.18540718e-01]
 [5.43909451e-05 9.18096211e-01 5.28081426e-01]
 [7.97623108e-03 3.34002242e-01 3.68821547e-01]
 [8.57434785e-03 1.64113698e-01 9.10566478e-01]
 [3.16243466e-03 5.94430701e-01 5.29796614e-01]
 [3.53399191e-02 3.10819375e-01 7.93509706e-01]
 [3.39267104e-02 5.91438615e-01 3.71228191e-01]]
