In [86]:
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import torch

In [87]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [88]:
# Config
batch_size = 100
n_iters = 5000
num_epochs = None


In [89]:
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [90]:
num_epochs = int(n_iters / (len(train_dataset) / batch_size))

In [91]:
samples = iter(train_dataloader)
samples._next_data()[0].shape

torch.Size([100, 1, 28, 28])

In [92]:
class FullyConnectedAutoEncode(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(FullyConnectedAutoEncode, self).__init__()
        # Encoder: Affine Function
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Decoder: Affine Function
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        # Encoder: affine function
        out = self.fc1(x)
        # Encoder: non-linear function
        out = F.leaky_relu(out)

        # Decoder: affine function
        out = self.fc2(out)
        # Decoder: non-linear function
        out = torch.sigmoid(out)

        return out
    
input_dim = 28*28
hidden_dim = int(input_dim * 1.5)
output_dim = input_dim
learning_rate = 1e-3

model = FullyConnectedAutoEncode(input_dim, hidden_dim, output_dim).to(device=device)
criterion = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [93]:
idx = 0
# Dropout for creating noisy images
# by dropping out pixel with a 50% probability
dropout = nn.Dropout(0.5)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_dataloader):
        images = images.reshape(-1, 28*28)
        noisy_image = dropout(torch.ones(images.shape)) * images
        noisy_image = noisy_image.to(device)
        images = images.to(device)

        optim.zero_grad()

        output = model(noisy_image)

        loss = criterion(images, output)

        loss.backward()

        optim.step()

        idx += 1

        if idx % 500 == 0:
            with torch.no_grad():
                total_test_samples = 0
                total_test_loss = 0
                for images_, labels_ in test_dataloader:
                    
                    # Noisy images
                    images_ = images_
                    noisy_image = dropout(torch.ones(images_.shape)) * images_
                    images_ = images_.to(device)
                    noisy_image = noisy_image.to(device)
                    
                    # Forward pass only to get logits/output
                    outputs = model(noisy_image.view(-1, 28*28))

                    # Test loss
                    test_loss = criterion(outputs, images_.view(-1, 28*28))

                    # Total number of labels
                    total_test_samples += labels.size(0)

                    # Total test loss
                    total_test_loss += test_loss

                mean_test_loss = total_test_loss / total_test_samples

                # Print Loss
                print(f'Iteration: {idx}. Average Test Loss: {mean_test_loss.item()}.')

Iteration: 500. Average Test Loss: 0.00016181077808141708.
Iteration: 1000. Average Test Loss: 0.00014505648869089782.
Iteration: 1500. Average Test Loss: 0.00013259751722216606.
Iteration: 2000. Average Test Loss: 0.00012482349120546132.
Iteration: 2500. Average Test Loss: 0.0001206278902827762.
Iteration: 3000. Average Test Loss: 0.00011856437777169049.
Iteration: 3500. Average Test Loss: 0.00011562420695554465.
Iteration: 4000. Average Test Loss: 0.0001144960115198046.
Iteration: 4500. Average Test Loss: 0.0001131142198573798.
