In [13]:
import torch

# Fix the global RNG so weight init and DataLoader shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, else CPU.
# (Checking only MPS would silently fall back to CPU on a CUDA machine.)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

'mps'

In [14]:
from torchvision import datasets, transforms
import torch

DATA_ROOT = "../../data"  # relative to this notebook; MNIST is downloaded here if missing
BATCH_SIZE = 64

to_tensor = transforms.ToTensor()  # PIL image -> float tensor

train_ds = datasets.MNIST(
    root=DATA_ROOT, train=True, transform=to_tensor, download=True
)
test_ds = datasets.MNIST(
    root=DATA_ROOT, train=False, transform=to_tensor, download=True
)

# Shuffle only the training set; evaluation metrics do not depend on sample order.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

## CNN classification model
* Conv layer를 이용한 feature extraction
* feature extraction이 완료된 feature map을 이용해 추론하는 과정을 FCL(fully connected layer)로 구현

In [15]:
from torch import nn


class MNIST_CNN(nn.Module):
    """Small CNN classifier for single-channel 28x28 MNIST digits.

    Two convolutional stages (conv -> relu -> conv -> relu -> 2x2 max-pool)
    extract features, and each pool halves the spatial size (28 -> 14 -> 7).
    The resulting 64 x 7 x 7 feature map goes through dropout, is flattened,
    and is classified by a three-layer fully connected head into 10 raw logits.
    No softmax is applied.
    """

    def __init__(self):
        super().__init__()

        # ---- stage 1: 1 -> 32 channels; padding=1 keeps H/W, pool halves them ----
        # conv_1 weight shape: (32, 1, 3, 3) for the single gray-scale channel
        self.conv_1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu_1 = nn.ReLU()
        self.conv_2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.relu_2 = nn.ReLU()
        self.max_pool_1 = nn.MaxPool2d(kernel_size=2)

        # ---- stage 2: 32 -> 64 channels ----
        self.conv_3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu_3 = nn.ReLU()
        self.conv_4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu_4 = nn.ReLU()
        self.max_pool_2 = nn.MaxPool2d(kernel_size=2)

        self.dropout = nn.Dropout(0.25)

        # ---- classifier head: (N, 64, 7, 7) -> (N, 3136) -> 256 -> 128 -> 10 ----
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(64 * 7 * 7, 256)
        self.relu_5 = nn.ReLU()
        self.linear_2 = nn.Linear(256, 128)
        self.relu_6 = nn.ReLU()
        self.out_layer = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (N, 10) for images shaped (N, 1, 28, 28)."""
        # Layers in application order; this matches the registration order above.
        stages = (
            self.conv_1, self.relu_1, self.conv_2, self.relu_2, self.max_pool_1,
            self.conv_3, self.relu_3, self.conv_4, self.relu_4, self.max_pool_2,
            self.dropout,
            self.flatten,
            self.linear_1, self.relu_5,
            self.linear_2, self.relu_6,
            self.out_layer,
        )
        for stage in stages:
            x = stage(x)
        return x

In [16]:
# Build the network and move its parameters to the selected device (Module.to returns self).
model = MNIST_CNN().to(device)
model

MNIST_CNN(
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (conv_2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_2): ReLU()
  (max_pool_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_3): ReLU()
  (conv_4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_4): ReLU()
  (max_pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.25, inplace=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_1): Linear(in_features=3136, out_features=256, bias=True)
  (relu_5): ReLU()
  (linear_2): Linear(in_features=256, out_features=128, bias=True)
  (relu_6): ReLU()
  (out_layer): Linear(in_features=128, out_features=10, bias=True)
)

In [17]:
from torchinfo import summary

# Summarise on the device the model actually lives on; a hard-coded "mps"
# fails on CUDA or CPU-only machines.
summary(model, input_size=(64, 1, 28, 28), device=device)

Layer (type:depth-idx)                   Output Shape              Param #
MNIST_CNN                                [64, 10]                  --
├─Conv2d: 1-1                            [64, 32, 28, 28]          320
├─ReLU: 1-2                              [64, 32, 28, 28]          --
├─Conv2d: 1-3                            [64, 32, 28, 28]          9,248
├─ReLU: 1-4                              [64, 32, 28, 28]          --
├─MaxPool2d: 1-5                         [64, 32, 14, 14]          --
├─Conv2d: 1-6                            [64, 64, 14, 14]          18,496
├─ReLU: 1-7                              [64, 64, 14, 14]          --
├─Conv2d: 1-8                            [64, 64, 14, 14]          36,928
├─ReLU: 1-9                              [64, 64, 14, 14]          --
├─MaxPool2d: 1-10                        [64, 64, 7, 7]            --
├─Dropout: 1-11                          [64, 64, 7, 7]            --
├─Flatten: 1-12                          [64, 3136]                --
├─L

## Training

In [19]:
# The model outputs raw logits (no softmax), which is what CrossEntropyLoss expects.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [20]:
def train_loop(dataloader, model, loss_fn, optimizer, device=None, log_interval=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: iterable of ``(X, y)`` batches. ``len(dataloader.dataset)``
            is used as the total sample count in the progress messages.
        model: module to train. It is put into training mode first.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter (CPU if it has none). This avoids relying
            on a global ``device`` variable.
        log_interval: print the loss every ``log_interval`` batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    size = len(dataloader.dataset)
    seen = 0  # samples processed so far, including the current batch

    model.train()  # training mode (e.g. dropout active)

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        loss.backward()  # back-propagation
        optimizer.step()  # update parameters

        # Count the current batch too; `batch * len(X)` lagged by one batch.
        seen += len(X)
        if batch % log_interval == 0:
            print(f"Train Loss : {loss.item():>7f} [ {seen:>5d} / {size:>5d} ]")

In [22]:
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)

    test_loss, correct = 0, 0
    model.eval()

    # initialize : required_grad=False
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    num_batches = len(dataloader)

    test_loss /= num_batches
    correct /= size

    print(
        f"Test Error : \n Accuracy : {(100*correct):>0.1f}%, Avg Loss : {test_loss:>8f}\n"
    )

In [None]:
epochs = 10

# Alternate one training pass and one evaluation pass per epoch (1-based for display).
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n........................")
    train_loop(train_dl, model, loss_fn, optimizer)
    test_loop(test_dl, model, loss_fn)

print("Experiment Successful")