In [None]:
import torch
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
# Configuration for the data pipeline; tweak here rather than inside the calls.
SEED = 0
DATA_ROOT = ".tmp/torch-data/"
N_TRAIN = 20000      # only the first 20,000 MNIST training images are used, to keep training quick
BATCH_SIZE = 500

# Seed before anything random happens (model init, DataLoader shuffling).
torch.manual_seed(SEED)

# torchvision's ToTensor turns each PIL image into a float tensor scaled to [0, 1].
# Named explicitly so later cells cannot silently shadow it with a one-letter variable.
to_tensor = transforms.ToTensor()

training_data = Subset(
    datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=to_tensor),
    range(N_TRAIN),
)

dl_train = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class autoencoder(nn.Module):
    """Fully connected autoencoder with a 3-unit bottleneck.

    ``encoder`` maps inputs of width 784 through 128 -> 64 -> 24 -> 3 units,
    with ReLU between layers and no activation on the 3-D code.
    ``decoder`` mirrors it (3 -> 24 -> 64 -> 128 -> 784), again with ReLU
    between layers, and ends in a sigmoid so every output value lies in [0, 1].
    """

    # Encoder widths from input to bottleneck; the decoder walks them in reverse.
    _WIDTHS = (784, 128, 64, 24, 3)

    @staticmethod
    def _stack(widths, final_activation=None):
        """Chain Linear layers over consecutive width pairs, ReLU in between.

        ``final_activation``, when given, is appended after the last Linear.
        Linear layers are created strictly in order, so parameter
        initialisation draws from the RNG in the same sequence as a
        hand-written nn.Sequential would.
        """
        layers = []
        pairs = list(zip(widths[:-1], widths[1:]))
        for position, (fan_in, fan_out) in enumerate(pairs):
            if position > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        if final_activation is not None:
            layers.append(final_activation)
        return nn.Sequential(*layers)

    def __init__(self):
        super().__init__()
        self.encoder = self._stack(self._WIDTHS)
        self.decoder = self._stack(self._WIDTHS[::-1], nn.Sigmoid())

    def forward(self, x):
        """Encode ``x`` to its 3-D code, then decode it back to width 784."""
        return self.decoder(self.encoder(x))

In [None]:
# Train the autoencoder to reproduce its own flattened input.
model = autoencoder()
opt = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

l_train = []  # one entry per epoch: mean of that epoch's per-batch MSE losses

for epoch in tqdm(range(30)):
    batch_losses = []

    for images, _ in dl_train:
        flat = images.view(-1, 784)  # input and reconstruction target are the same tensor
        opt.zero_grad()
        loss = criterion(model(flat), flat)
        loss.backward()
        opt.step()
        batch_losses.append(loss.item())

    epoch_loss = sum(batch_losses) / len(dl_train)
    l_train.append(epoch_loss)

    print(epoch, epoch_loss)

In [None]:
# Learning curve: one point per epoch, each the mean per-batch MSE from the training cell.
fig, ax = plt.subplots()
ax.plot(l_train)
ax.set(title="Autoencoder training loss", xlabel="epoch", ylabel="mean batch MSE")
ax.grid(True)
plt.show()

In [None]:
# Print the 3-D encoder output (latent code) and label for 10 random training images.
dl = DataLoader(training_data, batch_size=1, shuffle=True)

with torch.no_grad():
    samples = iter(dl)
    for _ in range(10):
        batch, label = next(samples)
        # Latent code of a single image, followed by its label tensor.
        output = model.encoder(batch.view(-1, 784))
        print(output, label)

In [None]:
# Create random representation and show the reconstruction.

with torch.no_grad():
    # Uniform sample in [-2.5, 2.5) per latent dimension (torch.rand is half-open on [0, 1)).
    # Named `latent` so it does not overwrite the `t` transform defined at the top.
    latent = (torch.rand(3) - 0.5) * 5
    print(latent)
    # Decode to 784 values and show them as a 28x28 grayscale image.
    output = model.decoder(latent)
    img = output.view(28, 28)
    plt.imshow(img, cmap="gray")
    plt.show()

In [None]:
# Draw one random training image and display it.

# Use the builtin next(): DataLoader iterators on current PyTorch have no `.next()` method,
# so the previous `iter(dl).next()` raised AttributeError.
img, _ = next(iter(dl))
print(img.shape)
# view(28, 28) needs exactly 784 elements, i.e. the single image of a batch_size=1 loader.
img = img.view(28, 28)
plt.imshow(img, cmap="gray")
plt.show()

In [None]:
# Add some noise to the image.

# Corrupt the image with i.i.d. Gaussian pixel noise of std 0.1. Values are not
# clipped, so some pixels may leave the original intensity range.
noise = img + 0.1 * torch.randn_like(img)
plt.imshow(noise, cmap="gray")
plt.show()

In [None]:
# Compute reconstruction of noisy input.

with torch.no_grad():
    # Flatten the noisy image, run encoder+decoder, and reshape back for display.
    flat_noisy = noise.view(-1, 784)
    output = model(flat_noisy)
    reconstruction = output.view(28, 28)
    plt.imshow(reconstruction, cmap="gray")
    plt.show()

In [None]:
# Use the encoder for dimensionality reduction: map a random batch of training images to 3-D codes.

dl = DataLoader(training_data, batch_size=5000, shuffle=True)

with torch.no_grad():
    # Builtin next(): DataLoader iterators no longer expose a `.next()` method.
    batch, labels = next(iter(dl))
    # 3-D latent code for each of the images in the batch; consumed by the scatter plot.
    output = model.encoder(batch.view(-1, 784))

In [None]:
%matplotlib notebook

from mpl_toolkits import mplot3d  # noqa: F401 -- registers the "3d" projection on older Matplotlib

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(projection="3d")

# One scatter call per digit label so each class gets its own colour and legend entry;
# without labels the ten colours cannot be matched to digits.
for digit in range(10):
    selected = output[labels == digit, :]
    ax.scatter3D(selected[:, 0], selected[:, 1], selected[:, 2], label=str(digit))

ax.set(title="Encoder embedding of MNIST training images", xlabel="z0", ylabel="z1", zlabel="z2")
ax.legend(title="digit")
plt.show()