# In [None]:
import torch
from torch import nn
from torch.nn import functional as F


class BoltzmannMachine(nn.Module):
    """Fully connected Boltzmann machine over visible and hidden units.

    Visible and hidden units are concatenated into a single state vector
    ``x = [v, h]`` of length ``n_visible + n_hidden``. Every pair of units
    is coupled through the square weight matrix ``W``, and ``b`` holds one
    bias per unit.

    Args:
        n_visible: number of visible units.
        n_hidden: number of hidden units.
    """

    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        total = n_visible + n_hidden
        # Standard-normal initialisation. W is neither symmetrised nor given
        # a zero diagonal here.
        self.W = nn.Parameter(torch.randn(total, total))
        self.b = nn.Parameter(torch.randn(total))

    def energy(self, v, h):
        """Return the energy ``E(x) = -b.x - x^T W x`` with ``x = [v, h]``.

        Args:
            v: visible states, shape ``(n_visible,)`` or ``(batch, n_visible)``.
            h: hidden states, shape ``(n_hidden,)`` or ``(batch, n_hidden)``.
                Must have the same leading (batch) shape as ``v``.

        Returns:
            A 0-d tensor for unbatched input, otherwise a ``(batch,)`` tensor
            holding one energy per sample.
        """
        # Join along the last (unit) axis so 1-D and batched inputs both work.
        v_h = torch.cat([v, h], -1)
        self_energy = -torch.matmul(v_h, self.b)
        # Per-sample quadratic form x_k^T W x_k. Computing v_h @ W @ v_h.T
        # would instead give a (batch, batch) matrix of cross-sample terms.
        # Only the symmetric part (W + W^T) / 2 affects this value.
        cross_energy = -((v_h @ self.W) * v_h).sum(-1)
        return self_energy + cross_energy

    def forward(self, v):
        """Not implemented yet; returns ``None``."""
        return None


class RBM(nn.Module):
    """Restricted Boltzmann machine with a bipartite visible/hidden graph.

    The energy is ``E(v, h) = -v.b - h.c - v^T W h``. There are no
    visible-visible or hidden-hidden couplings.

    Args:
        n_visible: number of visible units.
        n_hidden: number of hidden units.
    """

    def __init__(self, n_visible, n_hidden):
        super().__init__()
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        # Standard-normal initialisation of all parameters.
        self.W = nn.Parameter(torch.randn(n_visible, n_hidden))  # visible x hidden couplings
        self.b = nn.Parameter(torch.randn(n_visible))  # visible biases
        self.c = nn.Parameter(torch.randn(n_hidden))  # hidden biases

    def energy(self, v, h):
        """Return ``E(v, h) = -v.b - h.c - v^T W h``.

        Args:
            v: visible states, shape ``(n_visible,)`` or ``(batch, n_visible)``.
            h: hidden states, shape ``(n_hidden,)`` or ``(batch, n_hidden)``.
                Must have the same leading (batch) shape as ``v``.

        Returns:
            A 0-d tensor for unbatched input, otherwise a ``(batch,)`` tensor
            holding one energy per sample.
        """
        v_self_energy = -torch.matmul(v, self.b)
        h_self_energy = -torch.matmul(h, self.c)
        # Per-sample bilinear term v_k^T W h_k. The previous v @ (W @ h.T)
        # produced a (batch, batch) matrix for batched input.
        cross_energy = -(torch.matmul(v, self.W) * h).sum(-1)
        return v_self_energy + h_self_energy + cross_energy

    def forward(self, v):
        """Run one deterministic up-down pass (no sampling).

        Args:
            v: visible states, shape ``(..., n_visible)``.

        Returns:
            ``(h, v_recon)`` where ``h = sigmoid(v W + c)`` and
            ``v_recon = sigmoid(h W^T + b)``. For binary units these are the
            conditionals ``p(h=1 | v)`` and ``p(v=1 | h)``.
        """
        # F.linear(x, A, bias) computes x @ A.T + bias, so passing W.T gives
        # v @ W + c.
        h = torch.sigmoid(F.linear(v, self.W.t(), self.c))
        return h, torch.sigmoid(F.linear(h, self.W, self.b))
