In [2]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [3]:
transform = transforms.ToTensor()
mnist_data = datasets.MNIST(root= '.', train= True, download= True, transform = transform)

data_loader = torch.utils.data.DataLoader(dataset= mnist_data,batch_size= 64, shuffle= True)

In [4]:
data_iter = iter(data_loader)
images, labels = next(data_iter)
print(torch.min(images), torch.max(images)) 

tensor(0.) tensor(1.)


In [5]:
class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128), # N, 784 -> N, 128
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12, 3) # N, 3  
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(3, 12), # N, 3 
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28*28),
            nn.Sigmoid() # last layer  
        )
        
    def forward(self, X):

        encoder = self.encoder(X)
        decoder = self.decoder(encoder)

        return decoder


In [6]:
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3, weight_decay= 1e-5)


In [None]:
# training loop:
num_epoch = 10
output = []
for epoch in range(num_epoch):
    for (img, _) in data_loader:
        img = img.reshape(-1, 28*28)
        recon = model(img)
        loss = criterion(recon, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f} ')
    output.append((epoch, img, recon))


Epoch: 1, Loss: 0.0509 
Epoch: 2, Loss: 0.0401 
Epoch: 3, Loss: 0.0409 
Epoch: 4, Loss: 0.0433 
Epoch: 5, Loss: 0.0327 
Epoch: 6, Loss: 0.0387 
Epoch: 7, Loss: 0.0365 
Epoch: 8, Loss: 0.0339 
Epoch: 9, Loss: 0.0421 
Epoch: 10, Loss: 0.0350 


: 

In [None]:
for k in range(0, num_epoch, 4):
    plt.figure(figsize= (9,2))
    plt.gray()
    imgs = output[k][1].detach().numpy()
    recon = output[k][2].detach().numpy()
    for i, items in enumerate(imgs):
        if i >= 9: break
        plt.subplot(2, 9, i+1)
        item = items.reshape(-1, 28,28)
        plt.imshow(item[0])

    for i, item in enumerate(recon):
        if i>= 9: break
        plt.subplot(2, 9, 9+i+1)
        item = items.reshape(-1, 28,28)
        plt.imshow(item[0])
