In [None]:
import torch
import torch.nn as nn

In [None]:
from tqdm.auto import tqdm

In [None]:
class NADE(nn.Module):
    """Neural Autoregressive Distribution Estimator for binary vectors.

    Factorises p(x) = prod_i p(x_i = 1 | x_<i). The hidden pre-activation is
    built up incrementally:

        a_0     = b_hidden
        a_{i+1} = a_i + W_hidden[:, i] * x_i
        p(x_i = 1 | x_<i) = sigmoid(W_in[i] . sigmoid(a_i) + b_in[i])

    Args:
        input_dim: number of binary features D.
        hidden_size: number of hidden units H.
    """

    def __init__(self, input_dim, hidden_size):
        super(NADE, self).__init__()
        self.input_dim = input_dim
        self.hidden_size = hidden_size

        # Output side: row i of W_in (D, H) together with b_in[i] gives the logit of x_i.
        # NOTE(review): unit-variance randn init lets the accumulated hidden
        # pre-activations grow with D and may saturate the sigmoids; a smaller
        # init scale could train better -- confirm empirically.
        self.W_in = nn.Parameter(torch.randn(input_dim, hidden_size))
        self.b_in = nn.Parameter(torch.randn(input_dim))

        # Input side: column i of W_hidden (H, D) is added to the hidden
        # pre-activation once x_i is known; b_hidden (H,) is its starting value.
        self.W_hidden = nn.Parameter(torch.randn(hidden_size, input_dim))
        self.b_hidden = nn.Parameter(torch.randn(hidden_size))

    def forward(self, x):
        """Return the conditionals p(x_i = 1 | x_<i) for every feature.

        Args:
            x: (batch, num_features) tensor of 0/1 values, with
                num_features <= input_dim.

        Returns:
            (batch, num_features) tensor of probabilities on x's device.
        """
        batch_size, num_features = x.shape
        # One hidden pre-activation row per sample: shape (batch, hidden).
        a = self.b_hidden.expand(batch_size, self.hidden_size)
        # Allocate on x's device so training does not round-trip through the CPU.
        probs = torch.zeros(batch_size, num_features, device=x.device, dtype=self.b_in.dtype)

        for i in range(num_features):
            h = torch.sigmoid(a)  # (batch, hidden)
            probs[:, i] = torch.sigmoid(h @ self.W_in[i] + self.b_in[i])
            # Per-sample update a[b] += x[b, i] * W_hidden[:, i]; broadcasting
            # (batch, 1) * (hidden,) keeps the (batch, hidden) layout intact.
            a = a + x[:, i:i + 1] * self.W_hidden[:, i]

        return probs

    def sample(self, device, shape=(28, 28)):
        """Draw one sample from the model by ancestral sampling.

        Args:
            device: device on which to sample; the model's parameters are
                expected to live on it as well.
            shape: shape of the returned tensor; must contain input_dim
                elements. Defaults to (28, 28), i.e. one MNIST-sized image.

        Returns:
            CPU tensor of the requested shape holding 0/1 values.
        """
        with torch.no_grad():
            preds = torch.zeros(1, self.input_dim, device=device, dtype=self.b_in.dtype)
            # p(x_i | x_<i) depends only on already-sampled features, so a single
            # incremental pass is enough; calling forward() once per feature
            # would be quadratic in input_dim.
            a = self.b_hidden.unsqueeze(0)  # (1, hidden)
            for i in tqdm(range(self.input_dim)):
                p = torch.sigmoid(torch.sigmoid(a) @ self.W_in[i] + self.b_in[i])
                preds[:, i] = torch.bernoulli(p)
                a = a + preds[:, i:i + 1] * self.W_hidden[:, i]

            return torch.reshape(preds.cpu(), shape)

In [None]:
# Seed torch's global RNG so weight init, data shuffling and sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Model dimensions: a flattened 28x28 MNIST image has 784 pixels.
in_size = 28 * 28
hidden_size = 500

In [None]:
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Binarise pixels at 0.5 so every input is a {0, 1} value the Bernoulli model can fit.
transform = transforms.Compose([
    transforms.ToTensor(),
    lambda x: (x >= 0.5).float(),
])

original_train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
original_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

filter_digit = 3


def select_digit(dataset, digit):
    """Return the (image, label) pairs of an MNIST dataset whose label equals `digit`.

    Matching indices are found on the `targets` label tensor first, so only the
    selected images are decoded and transformed rather than the whole dataset.
    """
    indices = (dataset.targets == digit).nonzero().flatten().tolist()
    return [dataset[idx] for idx in indices]


filtered_train_dataset = select_digit(original_train_dataset, filter_digit)
filtered_test_dataset = select_digit(original_test_dataset, filter_digit)

# Train and test splits are pooled into one training set (no held-out split remains).
filtered_dataset = filtered_train_dataset + filtered_test_dataset
assert len(filtered_dataset) > 0, f'no images with label {filter_digit}'

data_loader = DataLoader(filtered_dataset, batch_size=64, shuffle=True, num_workers=4)

In [None]:
# Prefer the first CUDA GPU, then Apple's Metal backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
device

In [None]:
# Instantiate the model and move its parameters to the selected device.
nade = NADE(input_dim=in_size, hidden_size=hidden_size)
nade = nade.to(device)

In [None]:
criterion = nn.BCELoss().to(device)
optim = optim.Adam(nade.parameters(), lr=0.001)

In [None]:
from timeit import default_timer as timer

epochs = 30
for epoch in range(epochs):
    t0 = timer()
    epoch_loss = 0
    for batch, (data, _) in enumerate(data_loader):
        data = data.view(in_size, -1).to(device)

        optim.zero_grad()
        output = nade(data)
        loss = criterion(output.to(device), data)
        epoch_loss += loss.item()
        loss.backward()
        optim.step()

    t1 = timer()
    print(f'Epoch: {epoch + 1}, Loss: {epoch_loss/len(data_loader)}, Time taken: {(t1 - t0):.2f} s')

In [None]:
import os

# NOTE(review): presumably a workaround for the "multiple OpenMP runtimes
# loaded" abort; it suppresses the check rather than fixing the duplicate
# library, and likely only helps if set before that runtime initialises --
# consider moving it to the first cell.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [None]:
torch.save(nade, 'nade.pt')

In [None]:
nade = torch.load('nade.pt')

In [None]:
# Draw one binary sample (reshaped to 28x28 by NADE.sample) from the trained model.
img = nade.sample(device=device)

In [None]:
import matplotlib.pyplot as plt

In [None]:
def show_sample(img, title=None):
    """Display a single image in grayscale with the axes hidden.

    Args:
        img: image tensor/array; singleton dimensions are squeezed away before
            plotting. Values are assumed to lie in [0, 1] (e.g. 0/1 NADE samples).
        title: optional title drawn above the image.
    """
    fig, ax = plt.subplots()
    # Pin the colour range: with autoscaling, a constant image (all 0s or all
    # 1s) collapses to a single colour regardless of its actual value.
    ax.imshow(img.squeeze(), cmap='gray', vmin=0, vmax=1)
    if title is not None:
        ax.set_title(title)
    ax.axis('off')
    plt.show()

In [None]:
# Render the drawn sample.
show_sample(img=img)