In [10]:
import torch
import torchvision

class mnist_classifier(torch.nn.Module):
    def __init__(self, batch_size = 256, num_classes = 10):
        super(mnist_classifier, self).__init__()
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.layer0 = torch.nn.Sequential(
                                           torch.nn.Conv2d(1, 16, 3),
                                           torch.nn.ReLU(),
                                           torch.nn.Conv2d(16, 32, 3),
                                           torch.nn.ReLU(),
                                           torch.nn.MaxPool2d(2,2),
                                           torch.nn.Conv2d(32, 64, 3),
                                           torch.nn.ReLU(),
                                           torch.nn.MaxPool2d(2,2)
                                                                
        )

        self.flatten = torch.nn.Sequential(
                                            torch.nn.Linear(64*5*5, 100),
                                            torch.nn.ReLU(),
                                            torch.nn.Linear(100, self.num_classes)
        )

    def forward(self, inputs):
        out = self.layer0(inputs)
        out = out.reshape(self.batch_size,-1)
        return self.flatten(out)


#하이퍼 파라미터 
BATCH_SIZE = 256
LEARNING_RATE = 0.001
NUM_EPOCHS = 10


#augmentation & 데이터 로드 
transforms = torchvision.transforms.Compose([
                                             torchvision.transforms.ToTensor()
])
train_datagen = torchvision.datasets.MNIST(root = "./data", train = True, transform = transforms, target_transform=None, download=True)
test_datagen = torchvision.datasets.MNIST(root = "./data", train = False, transform = transforms, target_transform=None, download = True)
train_loader = torch.utils.data.DataLoader(train_datagen, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4, drop_last = True)
test_loader = torch.utils.data.DataLoader(test_datagen, batch_size = BATCH_SIZE, shuffle = False, num_workers = 4, drop_last= True)



#loss function & optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = mnist_classifier().to(device)
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

#train
for epoch in range(NUM_EPOCHS):
    for i , (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        logits = model.forward(image)
        loss = loss_func(logits, label)
        loss.backward()
        optimizer.step()

        if i & 1000 ==0:
            print("loss : {}".format(loss))


#test
total = 0
correct = 0
with torch.no_grad():
    for image, label in test_loader:
        image = image.to(device)
        label = label.to(device)

        logits = model.forward(image)
        _,  y_pred=torch.max(logits, 1)
        
        total += y_pred.shape[0]
        correct += (label == y_pred).sum().float()


print("accuracy : {}%".format(100 * correct / total))
        
 





loss : 2.3034162521362305
loss : 2.288954257965088
loss : 2.2750258445739746
loss : 2.2612826824188232
loss : 2.238125801086426
loss : 2.188901901245117
loss : 2.1385960578918457
loss : 2.1020667552948
loss : 0.9626778960227966
loss : 0.8317955136299133
loss : 0.7459065914154053
loss : 0.7567909955978394
loss : 0.6365780830383301
loss : 0.5527166724205017
loss : 0.6185261011123657
loss : 0.5654776096343994
loss : 0.09019298106431961
loss : 0.12619976699352264
loss : 0.08495225757360458
loss : 0.06955520808696747
loss : 0.08536488562822342
loss : 0.10468530654907227
loss : 0.09387943148612976
loss : 0.05416139215230942
loss : 0.08150587230920792
loss : 0.058031968772411346
loss : 0.09686620533466339
loss : 0.043970827013254166
loss : 0.05175198242068291
loss : 0.07467679679393768
loss : 0.08703519403934479
loss : 0.05318763107061386
loss : 0.017171800136566162
loss : 0.05447135120630264
loss : 0.04932793229818344
loss : 0.06138625368475914
loss : 0.03853132203221321
loss : 0.01437142491