In [1]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms

  Referenced from: <F0D48035-EF9E-3141-9F63-566920E60D7C> /Users/bahk_insung/miniconda3/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <44B645FB-F027-3EE5-86D7-DBF8E2FC6264> /Users/bahk_insung/miniconda3/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Pick the best available accelerator instead of hard-coding 'mps', so the
# notebook also runs on CUDA machines and plain CPUs (same choice as before on Apple Silicon).
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

DATA_DIR = '../data/'
BATCH_SIZE = 50

# MNIST training split; download=True fetches it into DATA_DIR if it is not already there.
dataset = torchvision.datasets.MNIST(DATA_DIR, download=True, train=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [7]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis.

    ``(batch, d1, d2, ...) -> (batch, d1 * d2 * ...)``; same result as
    ``nn.Flatten()`` with its default ``start_dim=1``.
    """

    def forward(self, x):
        batch_size = x.shape[0]
        # reshape rather than view: view raises on non-contiguous tensors (e.g. after
        # transpose/permute), while reshape still returns a view whenever it can.
        return x.reshape(batch_size, -1)

class Deflatten(nn.Module):
    """Reshape flat ``(batch, k * s * s)`` features into a ``(batch, k, s, s)`` map.

    Inverse of ``Flatten`` for square feature maps: the spatial side ``s`` is
    inferred from the feature count, so only the channel count is configured.

    Args:
        k: number of output channels.

    Raises:
        RuntimeError: if the feature count is not ``k`` times a perfect square.
    """

    def __init__(self, k):
        super(Deflatten, self).__init__()
        self.k = k

    def forward(self, x):
        batch_size, num_features = x.shape[0], x.shape[1]
        # Exact integer square root; avoids float rounding of ``** 0.5``.
        side = math.isqrt(num_features // self.k)
        if self.k * side * side != num_features:
            raise RuntimeError(
                f"Deflatten(k={self.k}) cannot reshape {num_features} features "
                f"into {self.k} square channels"
            )
        return x.reshape(batch_size, self.k, side, side)

In [10]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder sized for single-channel 28x28 inputs.

    Encoder: three 3x3 convs (spatial 28 -> 13 -> 6 -> 4), flatten, then a linear
    projection to ``latent_dim`` features. The decoder mirrors it with transposed
    convs (4 -> 6 -> 13 -> 28) and a final sigmoid, so reconstructions lie in [0, 1].

    Args:
        k: base channel width; the conv stack uses k, 2k and 4k channels.
        latent_dim: size of the bottleneck code returned as ``encoded``.

    The defaults (k=16, latent_dim=10) reproduce the original architecture exactly,
    including state_dict keys.
    """

    def __init__(self, k=16, latent_dim=10):
        super(AutoEncoder, self).__init__()
        # Side of the last encoder feature map for a 28x28 input:
        # 28 -(3, s2)-> 13 -(3, s2)-> 6 -(3, s1)-> 4.
        side = 4
        flat_features = 4 * k * side * side  # 1024 for the default k=16

        self.encoder = nn.Sequential(
            nn.Conv2d(1, k, 3, stride=2),           nn.ReLU(),
            nn.Conv2d(k, 2 * k, 3, stride=2),       nn.ReLU(),
            nn.Conv2d(2 * k, 4 * k, 3, stride=1),   nn.ReLU(),
            nn.Flatten(),
            # NOTE(review): this ReLU forces the latent code to be non-negative and can
            # leave latent units stuck at zero; kept to preserve the original model.
            nn.Linear(flat_features, latent_dim),   nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_features),  nn.ReLU(),
            # Mirror of nn.Flatten above: (batch, 4k*side*side) -> (batch, 4k, side, side).
            nn.Unflatten(1, (4 * k, side, side)),
            nn.ConvTranspose2d(4 * k, 2 * k, 3, stride=1), nn.ReLU(),
            nn.ConvTranspose2d(2 * k, k, 3, stride=2),     nn.ReLU(),
            # 13 -> 27 with stride 2; output_padding=1 brings it to 28 to match the input.
            nn.ConvTranspose2d(k, 1, 3, stride=2, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)``: the latent code and the reconstruction of ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [11]:
# Seed torch's global RNG before building the model so the weight initialisation is
# reproducible; the shuffling DataLoader (no explicit generator) also draws from it.
SEED = 0
torch.manual_seed(SEED)

model = AutoEncoder().to(device)
criterion = nn.MSELoss()
critention = criterion  # original (misspelled) name, kept so cells using it still run
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [13]:
NUM_EPOCHS = 51

# Explicit training mode. The model currently has no dropout/batch-norm layers,
# so this only matters if such layers are added later.
model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for data in trainloader:
        # Only the images are used: an autoencoder's target is its own input.
        inputs = data[0].to(device)
        optimizer.zero_grad()

        _, outputs = model(inputs)
        # nn.MSELoss takes (prediction, target) in that order.
        loss = critention(outputs, inputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    cost = running_loss / len(trainloader)
    print("[%d] loss : %.6f" % (epoch + 1, cost))

[1] loss : 0.053370
[2] loss : 0.031154
[3] loss : 0.025750
[4] loss : 0.023207
[5] loss : 0.021826
[6] loss : 0.020950
[7] loss : 0.020365
[8] loss : 0.019942
[9] loss : 0.019616
[10] loss : 0.019364
[11] loss : 0.019129
[12] loss : 0.018941
[13] loss : 0.018777
[14] loss : 0.018638
[15] loss : 0.018503
[16] loss : 0.018384
[17] loss : 0.018261
[18] loss : 0.018157
[19] loss : 0.018077
[20] loss : 0.017979
[21] loss : 0.017900
[22] loss : 0.017857
[23] loss : 0.017746
[24] loss : 0.017683
[25] loss : 0.017636
[26] loss : 0.017562
[27] loss : 0.017516
[28] loss : 0.017448
[29] loss : 0.017405
[30] loss : 0.017369
