### 1. Import

In [46]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

### 2. Dataset

In [47]:
train_dataset = datasets.MNIST(root='./data', train=True,  transform=transforms.ToTensor(), download=True)
train_dataset, val_dataset = random_split(train_dataset, [50000, 10000])
test_dataset  = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)

### 3. Dataloader

In [49]:
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_loader   = DataLoader(dataset=val_dataset,   batch_size=128, shuffle=False)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=128, shuffle=False)

In [20]:
x = torch.randn(128, 1, 28, 28)
conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
x = conv1(x)
print(x.shape)

torch.Size([128, 6, 24, 24])


In [34]:
x = torch.randn(128, 6, 24, 24)
torch.flatten(x, 1).shape

torch.Size([128, 3456])

### 3. CNNを定義

In [58]:
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=256, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))  # 16, 5, 5
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)  # 256 -> 64
        x = self.fc2(x)  # 64 -> 10
        return x

In [57]:
model = Net()
images, labels = next(iter(train_loader))
preds = model(images)
preds.shape

torch.Size([128, 256])


torch.Size([128, 10])

### 4. しょぼしょぼCNNを学習させてみる

In [59]:
from torch.optim import SGD
from sklearn.metrics import accuracy_score

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.01)

model.zero_grad()

EPOCHS = 10

for epoch in range(EPOCHS):
    train_loss = []
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        
        preds = model(images)
        
        loss = criterion(preds, labels)
        train_loss.append(loss.item())
        
        loss.backward()
        optimizer.step()
    
    val_loss = []
    acc_list = []
    for i, (images, labels) in enumerate(val_loader):
        preds = model(images)
        loss = criterion(preds, labels)
        val_loss.append(loss.item())
        
        acc = accuracy_score(labels, preds.argmax(dim=1))
        acc_list.append(acc)
        
    print(f'Epoch: {epoch+1}, '\
          + f'Train Loss: {sum(train_loss)/len(train_loss):.4f}, '\
          + f'Val Loss: {sum(val_loss)/len(val_loss):.4f}, '
          + f'Val Acc: {sum(acc_list)/len(acc_list):.4f}')

Epoch: 1, Train Loss: 1.7913, Val Loss: 0.6478, Val Acc: 0.8174
Epoch: 2, Train Loss: 0.4280, Val Loss: 0.3253, Val Acc: 0.9042
Epoch: 3, Train Loss: 0.2880, Val Loss: 0.2415, Val Acc: 0.9295
Epoch: 4, Train Loss: 0.2316, Val Loss: 0.2150, Val Acc: 0.9355
Epoch: 5, Train Loss: 0.1965, Val Loss: 0.1861, Val Acc: 0.9442
Epoch: 6, Train Loss: 0.1712, Val Loss: 0.1601, Val Acc: 0.9532
Epoch: 7, Train Loss: 0.1535, Val Loss: 0.1426, Val Acc: 0.9592
Epoch: 8, Train Loss: 0.1401, Val Loss: 0.1346, Val Acc: 0.9616
Epoch: 9, Train Loss: 0.1295, Val Loss: 0.1306, Val Acc: 0.9621
Epoch: 10, Train Loss: 0.1212, Val Loss: 0.1166, Val Acc: 0.9669


In [None]:
"""
Epoch: 1, Train Loss: 0.6841, Val Loss: 0.3671, Val Acc: 0.8958
Epoch: 2, Train Loss: 0.3554, Val Loss: 0.3280, Val Acc: 0.9065
Epoch: 3, Train Loss: 0.3311, Val Loss: 0.3205, Val Acc: 0.9069
Epoch: 4, Train Loss: 0.3189, Val Loss: 0.3071, Val Acc: 0.9104
Epoch: 5, Train Loss: 0.3111, Val Loss: 0.3044, Val Acc: 0.9127
Epoch: 6, Train Loss: 0.3056, Val Loss: 0.2959, Val Acc: 0.9133
Epoch: 7, Train Loss: 0.3011, Val Loss: 0.2945, Val Acc: 0.9117
Epoch: 8, Train Loss: 0.2968, Val Loss: 0.2902, Val Acc: 0.9158
Epoch: 9, Train Loss: 0.2935, Val Loss: 0.2965, Val Acc: 0.9126
Epoch: 10, Train Loss: 0.2911, Val Loss: 0.2948, Val Acc: 0.9129
"""