In [48]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms

In [49]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG (all devices) so weight initialisation and DataLoader
# shuffling repeat across Restart & Run All. Bitwise-identical GPU runs may
# additionally need deterministic cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

In [50]:
# Convert each PIL image to a float tensor (C, H, W) scaled to [0, 1].
transform = transforms.ToTensor()

# A single root for both splits: torchvision's MNIST downloads every archive file
# (train and test) into `root`, so separate roots would fetch the dataset twice.
DATA_ROOT = "./mnist"
BATCH_SIZE = 64

train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle only the training split; evaluation order does not affect accuracy.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [51]:
class CNN(nn.Module):
    """Small convolutional classifier sized for single-channel 28x28 (MNIST) images.

    Architecture: [conv3x3 -> ReLU -> avgpool2x2] x 2 -> flatten -> linear.
    For a 28x28 input the spatial size goes 28 -> 26 -> 13 -> 11 -> 5, so the
    linear layer receives 32 channels * 5 * 5 = 800 features.

    Args:
        num_classes: number of output logits (default 10, matching the original).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        # ReLU and AvgPool2d hold no parameters, so one instance is shared by both stages.
        self.activation = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # 32 channels x 5 x 5 spatial positions after the second pool (28x28 input).
        self.linear = nn.Linear(32 * 5 * 5, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, 1, 28, 28) image tensor to (batch, num_classes) raw logits."""
        x = self.pool(self.activation(self.conv1(x)))
        x = self.pool(self.activation(self.conv2(x)))
        x = self.flatten(x)
        return self.linear(x)

In [52]:
# Adam step size used by the training cell.
LEARNING_RATE = 1e-3

# Build the network on the selected device, then wire up its optimiser and loss.
model = CNN().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Expects raw logits from the model and integer class labels.
criterion = nn.CrossEntropyLoss()

In [54]:
# NOTE: this cell continues from the current `model`/`optimizer` state. Re-running it
# without first re-running the model-construction cell trains the already-trained
# weights further. Execution counts jump from In [52] to In [54], so the recorded
# losses below may come from such a second run; confirm before quoting them.
NUM_EPOCHS = 10


def train_one_epoch(net, loader, loss_fn, opt, dev):
    """Run one optimisation pass over `loader`; return the per-sample mean training loss.

    `loss_fn` is assumed to return the batch mean (nn.CrossEntropyLoss's default
    reduction), so each batch loss is re-weighted by its size. Without this, a
    smaller final batch would skew the epoch average. Returns NaN for an empty loader.
    """
    net.train()  # re-enable training mode each epoch in case evaluation ran in between
    loss_sum = 0.0
    seen = 0
    for images, labels in loader:
        images = images.to(dev)
        labels = labels.to(dev)
        loss = loss_fn(net(images), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        loss_sum += loss.item() * labels.size(0)
        seen += labels.size(0)
    return loss_sum / seen if seen else float("nan")


@torch.no_grad()
def evaluate(net, loader, dev):
    """Return (correct, total): top-1 correct predictions and sample count over `loader`."""
    net.eval()
    n_correct = 0
    n_total = 0
    for images, labels in loader:
        images = images.to(dev)
        labels = labels.to(dev)
        preds = net(images).argmax(dim=1)
        n_correct += (preds == labels).sum().item()
        n_total += labels.size(0)
    return n_correct, n_total


for _ in range(NUM_EPOCHS):
    epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Train Loss : {epoch_loss:.6f}")

correct, total = evaluate(model, test_loader, device)
print(f"ACC : {correct/total * 100:.2f}%")

Train Loss : 0.036151
Train Loss : 0.034331
Train Loss : 0.032326
Train Loss : 0.030116
Train Loss : 0.027974
Train Loss : 0.026903
Train Loss : 0.025486
Train Loss : 0.023618
Train Loss : 0.022714
Train Loss : 0.021117
ACC : 98.85%
