In [3]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader

In [5]:
class RBM(nn.Module):
    def __init__(self, visible_size, hidden_size):
        super(RBM,self).__init__()
        self.W = nn.Parameter(torch.randn(visible_size, hidden_size))
        self.v_bias = nn.Parameter(torch.randn(visible_size))
        self.h_bias = nn.Parameter(torch.randn(hidden_size))
        
    def forward(self,x):
        # 은닉층의 확률값
        hidden_prob = torch.sigmoid(torch.matmul(x, self.W) + self.h_bias)  # torch.matmul = 행렬곱
        # 확률값 > 함수화
        hidden_state = torch.bernoulli(hidden_prob)
        # 가시층의 확률값
        visible_prob = torch.sigmoid(torch.matmul(hidden_state, torch.transpose(self.W, 0, 1))+self.v_bias) 
        # torch.transpose = 전치행렬, (0,1) = 출력값 크기 지정
        return visible_prob, hidden_state

In [11]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

visible_size = 28*28 # MNIST 이미지 크기 28*28
hidden_size = 256
rbm = RBM(visible_size, hidden_size)

criterion = nn.BCELoss()
optimizer = torch.optim.SGD(rbm.parameters(), lr=0.01)

In [12]:
for epoch in range(10):
    for images, _ in train_loader:
        # 입력 데이터 이진화
        inputs = images.view(-1, visible_size)
        
        visible_prob, _ = rbm(inputs)
        loss = criterion(visible_prob, inputs)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch: {epoch+1}, Loss: {loss:.4f}')
    
    # 가중치 이미지 저장
    #vutils.save_image(rbm.W.view(hidden_size, 1, 28, 28), f'weights_epoch_{epoch+1}.png', normalize=True)
    
    # 입력 이미지 & 재출력 이미지 저장
    inputs_display = inputs.view(-1,1,28,28)  # 앞서 데이터를 넣을때 1차원으로 넣었는데 그걸다시 풀어서 되돌림
    outputs_display = visible_prob.view(-1,1,28,28)
    comparison = torch.cat([inputs_display, outputs_display], dim=3)
    #vutils.save_image(comparison, f'reconstruction_epoch_{epoch+1}.png', normalize=True)

Epoch: 1, Loss: 3.9591
Epoch: 2, Loss: 1.2509
Epoch: 3, Loss: -1.1453
Epoch: 4, Loss: -3.6860
Epoch: 5, Loss: -5.7566
Epoch: 6, Loss: -7.9294
Epoch: 7, Loss: -9.5972
Epoch: 8, Loss: -11.1037
Epoch: 9, Loss: -12.8195
Epoch: 10, Loss: -13.9716
