In [11]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import time
import numpy as np
from sklearn import metrics
from sklearn import linear_model
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [31]:
transform = transforms.Compose([transforms.ToTensor(),])
                                #transforms.Normalize((0.1307,), (0.3081,))])

trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

testset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset,  batch_size=4, shuffle=True)

In [34]:
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12, 3), 
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28*28)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available(): print(torch.cuda.get_device_name(0))

In [36]:
autoencoder = AutoEncoder().to(device)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)
criterion = nn.MSELoss()

for epoch in range(2):
    tmp=0
    for step, (x, label) in enumerate(train_loader):
        x, label = x.to(device),label.to(device)
        encoded, decoded = autoencoder(x.view(-1, 28*28))
        loss = criterion(decoded, x.view(-1, 28*28))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tmp+=loss.item()
        if step%1000 == 0:
            print(step,len(train_loader),tmp)
            tmp=0

0 15000 0.08855892717838287
1000 15000 61.06475191004574
2000 15000 57.60065796226263
3000 15000 56.73878778889775
4000 15000 55.63073338754475
5000 15000 55.42377556115389
6000 15000 54.45968600735068
7000 15000 53.89133690856397
8000 15000 53.57426704093814
9000 15000 54.8458361774683


KeyboardInterrupt: 

In [37]:
class AutoEncoder_Conv(nn.Module):
    def __init__(self):
        super(AutoEncoder_Conv, self).__init__()
        self.conv1 = nn.Conv2d(1,4,3)
        self.conv2 = nn.Conv2d(4,10,3)
        self.fc1 = nn.Linear(10*5*5, 60)
        self.fc2 = nn.Linear(60,10)
        
        self.fc3 = nn.Linear(10, 60)
        self.fc4 = nn.Linear(60,10*5*5)
        self.conv3 = nn.ConvTranspose2d(10,4,3)
        self.conv4 = nn.ConvTranspose2d(4,1,3)
        
        self.dropout = nn.Dropout(0.2)
        self.unpool = nn.MaxUnpool2d(2, stride=2)
        
    def forward(self, x):
        x = self.dropout(F.relu(self.conv1(x))) #4,26,26
        x,indices1 = F.max_pool2d(x, 2, 2, return_indices=True) #4,13,13
        x = self.dropout(F.relu(self.conv2(x))) #10,11,11
        x,indices2 = F.max_pool2d(x, 2, 2, return_indices=True) #10,5,5
        x = x.view(-1,10*5*5)
        x = self.dropout(F.relu(self.fc1(x)))
        encoded = self.fc2(x)
        
        x = self.dropout(F.relu(self.fc3(encoded)))
        x = self.dropout(F.relu(self.fc4(x)))
        x = x.view(-1,10,5,5)
        x = self.unpool(x, indices=indices2, output_size=torch.Size([-1, 10, 11, 11]))
        x = self.dropout(F.relu(self.conv3(x)))
        x = self.unpool(x, indices=indices1, output_size=torch.Size([-1, 4, 26, 26]))
        decoded = self.conv4(x)
        return encoded, decoded

In [38]:
autoencoder = AutoEncoder_Conv().to(device)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)
criterion = nn.MSELoss()

for epoch in range(2):
    tmp = 0
    for step, (x, label) in enumerate(train_loader):
        x, label = x.to(device),label.to(device)
        encoded, decoded = autoencoder(x)
        loss = criterion(decoded, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tmp+=loss.item()
        if step%1000 == 999:
            print(step,len(train_loader),tmp)
            tmp=0

999 15000 67.65954528748989
1999 15000 58.875234646722674
2999 15000 57.91450025886297
3999 15000 57.47817151993513
4999 15000 56.860047267749906
5999 15000 56.84721294417977
6999 15000 56.442729430273175
7999 15000 56.616437405347824
8999 15000 56.3183146007359
9999 15000 55.729653012007475
10999 15000 55.68821342848241
11999 15000 55.8713449370116
12999 15000 55.79573033004999
13999 15000 55.85951363667846
14999 15000 55.45056273974478
999 15000 55.87870620936155
1999 15000 55.57280600257218
2999 15000 55.129319755360484
3999 15000 55.18465956300497
4999 15000 55.86387385055423
5999 15000 55.35528535209596
6999 15000 55.965941574424505


KeyboardInterrupt: 

In [29]:
autoencoder = AutoEncoder_Conv()
trainset[0][0]

tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0