Hinton の "A Practical Guide to Training Restricted Boltzmann Machines" を読んだので、それを踏まえて実装を変更した。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

In [6]:
class RBM_hinton(nn.Module):
    def __init__(self, n_vis, n_hid, k=1):
        """
        a: Bias term of visible unit.
        b: Bias term of hidden unit.
        W: Weight parameter of RBM.
        k: The number of gibbs sampling in k-CD method.
        """
        super().__init__()
        self.a = nn.Parameter(torch.randn(1, n_vis))
        self.b = nn.Parameter(torch.randn(1, n_vis))
        self.W = nn.Parameter(torch.randn(n_hid, n_vis))
        self.k = k
    
    def encode(self, v, binarize=True):
        # Conditional sampling of h
        p = torch.sigmoid(F.linear(v, self.W, self.b))
        return p.bernoulli() if binarize else p
    
    def decode(self, h, binarize=False):
        # Conditional sampling of v
        p = torch.sigmoid(F.linear(h, self.W.t(), self.a))
        return p.bernoulli() if binarize else p
    
    def gibbs_sampling(self, v):
        for _ in range(self.k - 1): # When k=1 (CD1 method), this loop is skipped.
            h = self.encode(v) # Binary
            v = self.decode(h) # Real
        h = self.decode(v, binarize=False) # Real
        v = self.encode(h.bernoulli()) # Real
        return v, h
    
    def forward(self, v): # Reconstruction of a image
        v, h = self.gibbs_sampling(v)
        return v
        
    def free_energy(self, v):
        v_term = torch.matmul(v, self.a.t())
        w_x_h = F.linear(v, self.W, self.b)
        h_term = torch.sum(F.softplus(w_x_h), dim=1)
        return -h_term - v_term
    
    def loss(self, v):
        
        
        