In [1]:
from torchvision import transforms

import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import torchvision
import torch

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

  Referenced from: <F0D48035-EF9E-3141-9F63-566920E60D7C> /Users/bahk_insung/miniconda3/lib/python3.10/site-packages/torchvision/image.so
  Expected in:     <44B645FB-F027-3EE5-86D7-DBF8E2FC6264> /Users/bahk_insung/miniconda3/lib/python3.10/site-packages/torch/lib/libtorch_cpu.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Pick the best available accelerator: Apple MPS, then CUDA, then CPU, so the
# notebook also runs on machines without an Apple-silicon GPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed torch's global RNG so weight init and DataLoader shuffling are repeatable.
SEED = 0
torch.manual_seed(SEED)

DATA_DIR = '../data/'
BATCH_SIZE = 50

train_dataset = torchvision.datasets.MNIST(DATA_DIR, download=True, train=True, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = torchvision.datasets.MNIST(DATA_DIR, download=True, train=False, transform=transforms.ToTensor())
# Evaluation metrics do not depend on order, so the test set is not shuffled.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [3]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single axis.

    Uses ``Tensor.view``, so the input must be viewable with the new shape
    (e.g. contiguous); otherwise ``view`` raises.
    """

    def forward(self, x):
        n = x.shape[0]
        flat = x.view(n, -1)
        return flat

class Deflatten(nn.Module):
    """Reshape a flat ``(N, F)`` tensor into square feature maps ``(N, k, s, s)``.

    The side ``s`` is inferred as ``sqrt(F / k)``; ``F`` must equal ``k * s * s``.

    Args:
        k: number of output channels.

    Raises:
        ValueError: if ``F`` is not divisible by ``k`` or ``F / k`` is not a
            perfect square (previously this surfaced as a cryptic ``view`` error).
    """

    def __init__(self, k):
        super(Deflatten, self).__init__()
        self.k = k

    def forward(self, x):
        batch, features = x.shape[0], x.shape[1]
        if features % self.k:
            raise ValueError(
                f"Deflatten: {features} features are not divisible by k={self.k}")
        per_channel = features // self.k
        # round() guards against a float sqrt landing just below an integer.
        side = int(round(per_channel ** 0.5))
        if side * side != per_channel:
            raise ValueError(
                f"Deflatten: {per_channel} values per channel is not a perfect square")
        return x.view(batch, self.k, side, side)

In [4]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for 1x28x28 images (MNIST).

    Encoder spatial sizes are 28 -> 13 -> 6 -> 2 -> 1, so the flattened encoder
    output has ``8 * k`` features, which are projected to a ``latent_dim`` code.
    The decoder mirrors this (1 -> 2 -> 6 -> 13 -> 28) so that the reconstruction
    has the same shape as the input. (The previous version used
    ``nn.Linear(2048, ...)`` and produced 48x48 outputs, which do not match
    28x28 inputs.)

    Args:
        k: base channel width (default 16).
        latent_dim: size of the bottleneck code (default 20).

    ``forward(x)`` returns ``(encoded, decoded)``: ``encoded`` is
    ``(N, latent_dim)`` and ``decoded`` has the shape of ``x`` with values in
    ``(0, 1)`` (sigmoid output).
    """

    def __init__(self, k=16, latent_dim=20):
        super(AutoEncoder, self).__init__()
        # 8*k channels at 1x1 spatial resolution after the last encoder conv.
        flat_dim = 8 * k

        self.encoder = nn.Sequential(
            nn.Conv2d(    1,     k, 3, stride=2),   nn.ReLU(),   # 28 -> 13
            nn.Conv2d(    k, 2 * k, 3, stride=2),   nn.ReLU(),   # 13 -> 6
            nn.Conv2d(2 * k, 4 * k, 3, stride=2),   nn.ReLU(),   # 6 -> 2
            nn.Conv2d(4 * k, 8 * k, 2, stride=1),   nn.ReLU(),   # 2 -> 1
            nn.Flatten(),
            nn.Linear(flat_dim, latent_dim),        nn.ReLU()
            )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_dim),        nn.ReLU(),
            nn.Unflatten(1, (8 * k, 1, 1)),                      # -> (N, 8k, 1, 1)
            nn.ConvTranspose2d(8 * k, 4 * k, 2, stride=1),                    nn.ReLU(),  # 1 -> 2
            # output_padding=1 recovers the 6 the encoder floored down to 2.
            nn.ConvTranspose2d(4 * k, 2 * k, 3, stride=2, output_padding=1),  nn.ReLU(),  # 2 -> 6
            nn.ConvTranspose2d(2 * k,     k, 3, stride=2),                    nn.ReLU(),  # 6 -> 13
            nn.ConvTranspose2d(    k,     1, 3, stride=2, output_padding=1),              # 13 -> 28
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [5]:
# Build the autoencoder on the selected device, with a mean-squared-error
# reconstruction loss and the Adam optimizer.
model = AutoEncoder()
model = model.to(device)
critention = nn.MSELoss()  # (sic) the training loop refers to this exact name
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [6]:
# Train the autoencoder to reconstruct its input; the digit labels are ignored.
N_EPOCHS = 51

model.train()
for epoch in range(N_EPOCHS):
    running_loss = 0.0
    for inputs, _ in trainloader:
        inputs = inputs.to(device)
        optimizer.zero_grad()

        _, outputs = model(inputs)
        # nn.MSELoss takes (prediction, target).
        loss = critention(outputs, inputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    cost = running_loss / len(trainloader)
    print("[%d] loss : %.6f" % (epoch + 1, cost))

: 

: 

In [None]:
# Evaluate on the test set. The mean reconstruction MSE is the quantity the model
# was trained to minimise; the latent codes are also kept for the report below.
model.eval()
latents, label_batches = [], []
recon_loss = 0.0

with torch.no_grad():
    for images, labels in testloader:
        inputs = images.to(device)
        encoded, decoded = model(inputs)
        recon_loss += critention(decoded, inputs).item()
        latents.append(encoded.cpu())
        label_batches.append(labels)

print("test reconstruction loss : %.6f" % (recon_loss / len(testloader)))

# NOTE(review): the encoder is trained without labels, so the arg-max over its
# latent code is NOT a digit prediction. Latent index i has no reason to
# correspond to digit i, so treat this report as a rough diagnostic only. A
# meaningful evaluation would cluster the codes or train a classifier on them.
prediction = torch.cat(latents).argmax(dim=1).numpy()
actual = torch.cat(label_batches).numpy()
print(classification_report(actual, prediction))

In [None]:
# Confusion matrix of true labels vs. latent arg-max. Labels are passed
# explicitly so tick labels stay correct even if some indices never occur.
cm_labels = np.union1d(actual, prediction)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    confusion_matrix(y_true=actual, y_pred=prediction, labels=cm_labels),
    annot=True,
    cmap='Blues',
    fmt='d',
    xticklabels=cm_labels,
    yticklabels=cm_labels,
    ax=ax,
)
ax.set(title='Confusion matrix (test set)', xlabel='Predicted (latent arg-max)', ylabel='True label')
plt.show()