In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file found under the Kaggle input directory, one full path per line.
input_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for path in input_paths:
    print(path)

# Any results you write to the current directory are saved as output.

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

# --- Run configuration -------------------------------------------------------
EPOCH=10             # number of full passes over the training set
BATCH_SIZE=64        # samples per optimisation step
LR=0.0005            # learning rate handed to the optimiser
DOWNLOAD_MNIST=True  # let torchvision fetch MNIST into ./mnist/ when needed
N_TEST_IMG=5         # how many digits are shown in the reconstruction preview
SEED=0               # seed for PyTorch's global random number generator

# Weight initialisation and DataLoader shuffling both draw from PyTorch's global
# RNG; seeding it makes a Restart & Run All produce the same numbers each time.
torch.manual_seed(SEED)

# MNIST training split. ToTensor turns each image into a float tensor scaled
# to the [0, 1] range.
train_data=torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)

In [None]:
# Preview the first training digit with its label as the title.
# `data` / `targets` are the current attribute names; the older `train_data` /
# `train_labels` aliases are deprecated and emit a warning on every access.
plt.imshow(train_data.data[0], cmap='gray')
plt.title('%i' % int(train_data.targets[0]))
plt.show()

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder with a mirrored encoder/decoder pair.

    The encoder is a stack of ``Linear`` layers separated by ``Tanh`` and ends
    in a plain ``Linear`` layer that produces the latent code. The decoder
    mirrors the same widths in reverse order and ends in ``Sigmoid``.

    With the default arguments the layer stack is
    784 -> 128 -> 32 -> 8 -> 3 -> 8 -> 32 -> 128 -> 784.

    Args:
        input_dim: Number of features in each flattened input sample.
        hidden_dims: Widths of the hidden layers between the input and the
            latent code, listed in encoder order. May be empty.
        latent_dim: Size of the latent code returned by the encoder.
    """

    def __init__(self, input_dim=28*28, hidden_dims=(128, 32, 8), latent_dim=3):
        super(AutoEncoder, self).__init__()
        widths=[input_dim, *hidden_dims, latent_dim]
        # Encoder is built before the decoder so parameters are created in the
        # same order as the layer listing above.
        self.encoder=self._build_stack(widths)
        self.decoder=self._build_stack(widths[::-1], final_activation=nn.Sigmoid())

    @staticmethod
    def _build_stack(widths, final_activation=None):
        """Chain Linear layers through ``widths`` with Tanh between them."""
        layers=[]
        last_index=len(widths)-2
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            # No Tanh after the final Linear: the stack ends either linearly
            # (encoder) or with the supplied final activation (decoder).
            if index<last_index:
                layers.append(nn.Tanh())
        if final_activation is not None:
            layers.append(final_activation)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and reconstruct it from the latent code.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tuple ``(encode, decode)``: the latent code and the reconstruction.
        """
        encode=self.encoder(x)
        decode=self.decoder(encode)
        return encode, decode
    
# Instantiate the model and print its module structure as a quick sanity check.
autoencoder=AutoEncoder()
print(autoencoder)

In [None]:
# Mini-batch iterator over the training set, with shuffling enabled.
train_loader=Data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# Adam updates every autoencoder parameter; reconstruction quality is scored as
# the mean squared error between the decoder output and the input pixels.
optimizer=torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func=nn.MSELoss()

# Two rows of panels: original digits on top, current reconstructions below.
f, a=plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()

# Fixed preview batch: the first N_TEST_IMG stored digits, flattened to 784
# values and divided by 255 (presumably to match the [0, 1] range that ToTensor
# gives the training batches). `data` / `targets` replace the deprecated
# `train_data` / `train_labels` aliases.
view_data=train_data.data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):
    a[0][i].imshow(view_data[i].reshape(28, 28).numpy(), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())
    a[0][i].title.set_text('%i'%train_data.targets[i].item())

for epoch in range(EPOCH):
    # Labels are not needed: the autoencoder is trained to reproduce its input.
    for step, (x, _) in enumerate(train_loader):
        # Flatten every image in the batch to a 784-element vector.
        x=x.view(-1, 28*28)

        # encode: (batch, 3), decode: (batch, 784)
        encode, decode=autoencoder(x)
        loss=loss_func(decode, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step%100==0:
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())
            # The preview pass is display-only, so skip building an autograd graph.
            with torch.no_grad():
                _, decode=autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(decode[i].reshape(28, 28).numpy(), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())

plt.ioff()
plt.show()

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Encode the first 200 training digits and place each one in the 3-D latent
# space, drawn as its label on a background colour chosen by the label.
# `data` / `targets` replace the deprecated `train_data` / `train_labels` aliases.
view_data=train_data.data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.
# Inference only: no gradients are needed for plotting.
with torch.no_grad():
    encode, _=autoencoder(view_data)

fig=plt.figure(2)
# Create the 3-D axes through the figure: recent Matplotlib releases no longer
# attach a directly constructed Axes3D(fig) to the figure, which leaves the
# plot empty.
ax=fig.add_subplot(projection='3d')

print(encode)
X, Y, Z=encode[:,0].numpy(), encode[:,1].numpy(), encode[:,2].numpy()

values=train_data.targets[:200].numpy()

for x, y, z, s in zip(X, Y, Z, values):
    # Spread the labels 0..9 across the 256-entry rainbow colormap.
    c=cm.rainbow(int(255*s/9))
    ax.text(x, y, z, str(s), backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
# Previously this was a second set_ylim call, which overwrote the Y range with
# the Z range and left the Z axis at its default limits.
ax.set_zlim(Z.min(), Z.max())
plt.show()