In [None]:
import numpy as np
from datetime import datetime 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
random_seed = 42
lr = 0.001
batch_size = 32
n_epochs = 15

num_class = 10

In [80]:
kwargs = {'num_workers': 2, 'pin_memory': True} 
mnist_train = datasets.MNIST(root='MNIST_data/', 
                          train=True, 
                          transform=transforms.ToTensor(), 
                          download=True)

mnist_test = datasets.MNIST(root='MNIST_data/', 
                         train=False, # 
                         transform=transforms.ToTensor(), 
                         download=True)

train_loader = DataLoader(dataset=mnist_train, 
                          batch_size=batch_size, 
                          shuffle=True, **kwargs)

test_loader = DataLoader(dataset=mnist_test, 
                          batch_size=batch_size, 
                          shuffle=False, **kwargs)

In [81]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.dens1 = torch.nn.Linear(in_features=120, out_features=84)
        self.dens2 = torch.nn.Linear(in_features=84, out_features=10)
    def forward(self,x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(-1, self.num_flat_features(x))

        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.dens1(x))
        x = self.dens2(x)
        return x
    def num_flat_features(self, x):
        #x.size() return (256, 16, 5, 5)，size의 값은 (16, 5, 5)，256은 batch_size
        size = x.size()[1:]       
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = MyModel()
print(net)

MyModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (dens1): Linear(in_features=120, out_features=84, bias=True)
  (dens2): Linear(in_features=84, out_features=10, bias=True)
)


In [82]:
model.parameters()

<generator object Module.parameters at 0x7f0673cda5d0>

In [83]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

In [89]:
for _epoch in range(epoch):
    for idx, (train_x, train_label) in enumerate(train_loader):
        label_np = np.zeros((train_label.shape[0], 10))
        optimizer.zero_grad()
        predict_y = model(train_x.float())
        _error = criterion(predict_y, train_label.long())
        if idx % 10 == 0:
            print('idx: {}, _error: {}'.format(idx, _error))
        _error.backward()
        optimizer.step()

    correct = 0
    _sum = 0

    for idx, (test_x, test_label) in enumerate(test_loader):
        predict_y = model(test_x.float()).detach()
        predict_ys = np.argmax(predict_y, axis=-1)
        label_np = test_label.numpy()
        _ = predict_ys == test_label
        correct += np.sum(_.numpy(), axis=-1)
        _sum += _.shape[0]

    print('accuracy: {:.2f}'.format(correct / _sum))

size torch.Size([32, 400])
idx: 0, _error: 0.00840521976351738
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
idx: 10, _error: 0.07281673699617386
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
idx: 20, _error: 0.23840974271297455
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
size torch.Size([32, 400])
idx: 30, _error: 0.04577323794364929
size torch.Size(