In [None]:
# Import packages and define common functions. Do not modify.
import os
import numpy as np
import random
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from torchvision.utils import make_grid , save_image

import torch.optim as optim

%matplotlib inline
import matplotlib.pyplot as plt

def show_and_save(img, file_name):
    """Display an image tensor inline and save it as ``./<file_name>.png``.

    Args:
        img (Tensor): A 3-D image tensor whose first axis is moved last before
            plotting (i.e. a channel-first layout such as the output of
            ``make_grid``).
        file_name (Str): The destination file name, without the ``.png``
            extension; the file is written relative to the working directory.
    """
    # Tensor.numpy() raises for tensors that track gradients or live on a
    # non-CPU device, so strip autograd history and move to host memory first.
    npimg = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))
    f = "./%s.png" % file_name
    plt.imshow(npimg, cmap='gray')
    plt.show()
    plt.imsave(f, npimg)

def train(model, train_loader, n_epochs=20, lr=0.01):
    """Fit a generative model with Adam on flattened 784-feature batches.

    Args:
        model: Module exposing ``get_loss(batch)``; the returned loss must be a
            single-element tensor (``.item()`` and ``.backward()`` are called on it).
        train_loader (DataLoader): Yields ``(data, target)`` pairs. The target
            is ignored and ``data`` is reshaped to ``(-1, 784)``.
        n_epochs (int, optional): Number of passes over the loader. Defaults to 20.
        lr (Float, optional): Adam learning rate. Defaults to 0.01.
    Returns:
        The same model object, after training.
    """
    optimizer = optim.Adam(model.parameters(), lr)
    model.train()
    for epoch in tqdm(range(n_epochs)):
        epoch_losses = []
        for data, _target in train_loader:
            flat_batch = data.view(-1, 784)
            batch_loss = model.get_loss(flat_batch)
            epoch_losses.append(batch_loss.item())
            # Standard step: clear stale gradients, backpropagate, update.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        mean_loss = np.mean(epoch_losses)
        print('Epoch %d\t Loss=%.4f' % (epoch, mean_loss))
    return model

def train_vae2(model, train_loader, n_epochs=20, lr=0.01):
    """Fit a generative model with Adam, feeding batches in their loader shape.

    Unlike ``train``, the batch is handed to ``model.get_loss`` without being
    flattened.

    Args:
        model: Module exposing ``get_loss(batch)``; the returned loss must be a
            single-element tensor (``.item()`` and ``.backward()`` are called on it).
        train_loader (DataLoader): Yields ``(data, target)`` pairs; the target
            is ignored.
        n_epochs (int, optional): Number of passes over the loader. Defaults to 20.
        lr (Float, optional): Adam learning rate. Defaults to 0.01.
    Returns:
        The same model object, after training.
    """
    optimizer = optim.Adam(model.parameters(), lr)
    model.train()
    for epoch in tqdm(range(n_epochs)):
        epoch_losses = []
        for data, _target in train_loader:
            batch_loss = model.get_loss(data)
            epoch_losses.append(batch_loss.item())
            # Standard step: clear stale gradients, backpropagate, update.
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        mean_loss = np.mean(epoch_losses)
        print('Epoch %d\t Loss=%.4f' % (epoch, mean_loss))
    return model

# Fix the seed of every RNG used in this notebook (NumPy, PyTorch and Python's
# `random`) so that repeated runs draw the same random numbers.
seed = 2025
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)

In [None]:
class RBM(nn.Module):
    """Restricted Boltzmann Machine with Bernoulli-sampled visible and hidden units.

    Args:
        n_vis (int, optional): The size of visible layer. Defaults to 784.
        n_hid (int, optional): The size of hidden layer. Defaults to 128.
        k (int, optional): The number of Gibbs sampling steps. Defaults to 1.
    """

    def __init__(self, n_vis=784, n_hid=128, k=1):
        """Create a RBM whose parameters are drawn from a standard normal."""
        super(RBM, self).__init__()
        self.v = nn.Parameter(torch.randn(1, n_vis))      # visible bias
        self.h = nn.Parameter(torch.randn(1, n_hid))      # hidden bias
        self.W = nn.Parameter(torch.randn(n_hid, n_vis))  # coupling weights
        self.k = k

    def visible_to_hidden(self, v):
        r"""Conditional sampling a hidden variable given a visible variable.
        Args:
            v (Tensor): The visible variable.
        Returns:
            Tensor: The hidden variable, a Bernoulli sample drawn with
            probability ``sigmoid(v W^T + h)``.
        """
        p = torch.sigmoid(F.linear(v, self.W, self.h))
        return p.bernoulli()

    def hidden_to_visible(self, h):
        r"""Conditional sampling a visible variable given a hidden variable.
        Args:
            h (Tensor): The hidden variable.
        Returns:
            Tensor: The visible variable, a Bernoulli sample drawn with
            probability ``sigmoid(h W + v)``.
        """
        p = torch.sigmoid(F.linear(h, self.W.t(), self.v))
        return p.bernoulli()

    def free_energy(self, v, reduction='mean'):
        r"""Free energy function.
        .. math::
            \begin{align}
                F(x) &= -\log \sum_h \exp (-E(x, h)) \\
                &= -a^\top x - \sum_j \log (1 + \exp(W^{\top}_jx + b_j))\,.
            \end{align}
        Args:
            v (Tensor): The visible variable.
            reduction (str, optional): ``'mean'`` averages the energy over all
                samples and returns a scalar; ``'none'`` returns one energy per
                sample. Defaults to ``'mean'``.
        Returns:
            energy (FloatTensor): The free energy value.
        Raises:
            ValueError: If ``reduction`` is neither ``'none'`` nor ``'mean'``.
        """
        v_term = torch.matmul(v, self.v.t())
        # softplus(z) = log(1 + exp(z)), summed over the hidden units.
        h_term = torch.sum(F.softplus(F.linear(v, self.W, self.h)), dim=-1, keepdim=True)
        energy = -v_term - h_term
        if reduction == 'none':
            return energy.squeeze(-1)
        elif reduction == 'mean':
            return torch.mean(energy)
        # Fail loudly instead of silently returning None for a mistyped option.
        raise ValueError(
            "reduction must be 'none' or 'mean', got %r" % (reduction,)
        )

    def forward(self, v):
        r"""Compute the real and generated examples.

        Runs ``k`` steps of block Gibbs sampling. Each step starts from the
        visible sample produced by the previous step (the first step starts
        from ``v``), so ``k > 1`` yields a longer chain. With ``k = 0`` the
        input itself is returned as the generated example.

        Args:
            v (Tensor): The visible variable.
        Returns:
            (Tensor, Tensor): The real and generated variables.
        """
        v_gibb = v
        for _ in range(self.k):
            h = self.visible_to_hidden(v_gibb)
            v_gibb = self.hidden_to_visible(h)
        return v, v_gibb

    def get_loss(self, inputs):
        r"""Compute the loss for training the model.

        The loss is the mean free energy of the inputs minus the mean free
        energy of the Gibbs samples generated from them.

        Args:
            inputs (Tensor): The visible variable.
        Returns:
            Tensor: Loss.
        """
        v, v_gibb = self.forward(inputs)
        loss = self.free_energy(v) - self.free_energy(v_gibb)
        return loss

    @torch.no_grad()
    def pseudo_likelihood(self, v):
        """DO NOT MODIFY THIS FUNCTION"""
        # Randomly corrupt one feature in each sample in v.
        ind = (np.arange(v.shape[0]), np.random.randint(0, v.shape[1], v.shape[0]))
        v_ = v.clone()
        v_[ind] = 1 - v_[ind]
        fe = self.free_energy(v, reduction='none')
        fe_ = self.free_energy(v_, reduction='none')
        m = torch.nn.LogSigmoid()
        score = v.shape[1] * m(fe_ - fe)
        return score


In [None]:
# DO NOT MODIFY
# MNIST training split stored under ./data, with download requested. Each image
# is converted with ToTensor and then binarized by the lambda: every pixel
# strictly greater than 0 becomes 1.0, all others become 0.0.
train_dataset = datasets.MNIST('./data',
    train=True,
    download = True,
    transform = transforms.Compose(
        [transforms.ToTensor(), lambda x: (x > 0).float()]
    )
)
# MNIST test split with the same binarizing transform. No `download` argument
# is passed here.
test_dataset = datasets.MNIST('./data',
    train=False,
    transform = transforms.Compose(
        [transforms.ToTensor(), lambda x: (x > 0).float()]
    )
)

# Hyperparameters shared by the cells below.
batch_size = 128  # batch size of both loaders created below
n_hid = 128       # hidden-layer size passed to RBM
n_vis = 784       # visible-layer size passed to RBM
n_epochs = 20     # number of epochs passed to train()
lr = 0.01         # learning rate passed to train()
rbm_ckpt_fn = 'model_rbm_seed2025.pt'  # file the trained RBM state dict is saved to

# No `shuffle` argument is passed, so both loaders use DataLoader's default ordering.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

In [None]:
# You can try different k for the report, but use k=1 when submitting the checkpoint
# Do not modify other parts of this cell
if not os.path.exists(rbm_ckpt_fn):
    model_rbm = RBM(n_vis=n_vis, n_hid=n_hid, k=1)
    model_rbm = train(model_rbm, train_loader, n_epochs=n_epochs, lr=lr)
    # save model, do not change the filename.
    torch.save(model_rbm.state_dict(), rbm_ckpt_fn)
else:
    # A checkpoint from an earlier run already exists, so training is skipped.
    # Rebuild the model and restore its weights; otherwise `model_rbm` would be
    # undefined on a fresh kernel and the cells that use it would raise NameError.
    model_rbm = RBM(n_vis=n_vis, n_hid=n_hid, k=1)
    model_rbm.load_state_dict(torch.load(rbm_ckpt_fn, map_location='cpu'))


In [None]:
# Visualize reconstructed samples: run the first 64 test images through the RBM
# and plot/save the Gibbs samples it returns as one image grid.
model_rbm.eval()
vis_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64)
first_batch = next(iter(vis_loader))
images = first_batch[0]  # element 0 of the batch holds the images
v, v_gibbs = model_rbm(images.view(-1, 784))
# Reshape the flat samples back to 1x28x28 images; .data drops autograd tracking.
fake_grid = make_grid(v_gibbs.view(64, 1, 28, 28).data)
show_and_save(fake_grid, 'rbm_fake')