In [1]:
# -- MNISTとクラスの手書き文字を使用したニューラルネットワーク --

import torchvision.datasets as datasets
import torchvision.transforms as transforms

data_root = "./data"
data_root_mnist = f"{data_root}/mnist"

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda x: x.view(-1))
])

mnist_train_set = datasets.MNIST(
    root = data_root_mnist,
    train = True,
    download = True,
    transform = transform
)

mnist_test_set = datasets.MNIST(
    root = data_root_mnist,
    train = False,
    download = True,
    transform = transform
)

print(f"MNISTの学習データ件数: {len(mnist_train_set)}")
print(f"MNISTのテストデータ件数: {len(mnist_test_set)}")

MNISTの学習データ件数: 60000
MNISTのテストデータ件数: 10000


In [2]:
import shutil
import os
from glob import glob
import random
from PIL import Image

zip_root_hand_digits = "./hand_digits.zip"
data_root_hand_digits = f"{data_root}/hand_digits"
data_root_hand_digits_train = f"{data_root_hand_digits}/train"
data_root_hand_digits_test = f"{data_root_hand_digits}/test"
data_root_hand_digits_tmp = f"{data_root_hand_digits}/tmp"

shutil.unpack_archive(zip_root_hand_digits, data_root_hand_digits_tmp)

for n in range(10):
    no_dir = f"{data_root_hand_digits_train}/{n}"
    os.makedirs(no_dir, exist_ok=True)

for n in range(10):
    no_dir = f"{data_root_hand_digits_test}/{n}"
    os.makedirs(no_dir, exist_ok=True)

my_hand_digits_images = glob(f"{data_root_hand_digits_tmp}/*.png")

for file in my_hand_digits_images:
    with Image.open(file) as img:
        for angle in range(-45, 46, 5):
            if angle:
                img_r = img.rotate(angle, fillcolor=0xffffff)
                img_r.save(file.replace('.png', f'_A{angle}.png'))

my_hand_digits_images = glob(f'{data_root_hand_digits_tmp}/**/*.png', recursive=True)

print("クラスの手書き文字の画像ファイル数:", len(my_hand_digits_images))
random.seed(0)
random.shuffle(my_hand_digits_images)

train_size = int(0.8 * len(my_hand_digits_images))
test_size = len(my_hand_digits_images) - train_size

for file in my_hand_digits_images[:train_size]:
    file_name = os.path.basename(file)
    label = file_name[0]
    dest_dir = f"{data_root_hand_digits_train}/{label}"
    shutil.move(file, f"{dest_dir}/{file_name}")

for file in my_hand_digits_images[train_size:]:
    file_name = os.path.basename(file)
    label = file_name[0]
    dest_dir = f"{data_root_hand_digits_test}/{label}"
    shutil.move(file, f"{dest_dir}/{file_name}")

shutil.rmtree(data_root_hand_digits_tmp)

クラスの手書き文字の画像ファイル数: 18962


In [5]:
hand_digits_train_set = datasets.ImageFolder(
    root = data_root_hand_digits_train,
    transform = transform
)

hand_digits_test_set = datasets.ImageFolder(
    root = data_root_hand_digits_test,
    transform = transform
)

In [6]:
from torch.utils.data import ConcatDataset

train_set = ConcatDataset([mnist_train_set, hand_digits_train_set])
test_set = ConcatDataset([mnist_test_set, hand_digits_test_set])

print(f"統合した学習データ件数: {len(train_set)}")
print(f"統合したテストデータ件数: {len(test_set)}")

統合した学習データ件数: 75169
統合したテストデータ件数: 13793


In [7]:
from torch.utils.data import DataLoader

batch_size = 100

train_loader = DataLoader(
    train_set, 
    batch_size = batch_size,
    shuffle = True
)

test_loader = DataLoader(
    test_set,  
    batch_size = batch_size,
    shuffle = True
)

In [8]:
import torch.nn as nn

class Net2(nn.Module):
    def __init__(self, n_input, n_output, n_hidden):
        super().__init__()

        self.l1 = nn.Linear(n_input, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_hidden)
        self.l3 = nn.Linear(n_hidden, n_output)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.l1(x)
        x2 = self.relu(x1)
        x3 = self.l2(x2)
        x4 = self.relu(x3)
        x5 = self.l3(x4)
        return x5

In [9]:
import torch
import numpy as np

inputs, labels = next(iter(train_loader))

n_input = inputs[0].shape[0]
n_output = 10
n_hidden = 128

torch.manual_seed(123)
torch.cuda.manual_seed(123)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms = True

net = Net2(
    n_input = n_input,
    n_output = n_output,
    n_hidden = n_hidden
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net.to(device)

leaning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=leaning_rate)
num_epochs = 50
history = np.zeros((0, 5))

In [10]:
from tqdm import tqdm

for epoch in range(num_epochs):
    n_train_acc, n_val_acc = 0, 0
    train_loss, val_loss = 0, 0
    n_train, n_val = 0, 0

    for inputs, labels in tqdm(train_loader):
        train_batch_size = len(labels)
        n_train += train_batch_size

        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)

        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        predicted = torch.max(outputs, 1)[1]

        train_loss += loss.item() * train_batch_size
        n_train_acc += (predicted == labels).sum().item()

    for inputs_test, labels_test in test_loader:
        test_batch_size = len(labels_test)
        n_val += test_batch_size

        inputs_test = inputs_test.to(device)
        labels_test = labels_test.to(device)

        outputs_test = net(inputs_test)

        loss_test = criterion(outputs_test, labels_test)

        predicted_test = torch.max(outputs_test, 1)[1]

        val_loss += loss_test.item() * test_batch_size
        n_val_acc += (predicted_test == labels_test).sum().item()

    train_acc = n_train_acc / n_train
    val_acc = n_val_acc / n_val
    ave_train_loss = train_loss / n_train
    ave_val_loss = val_loss / n_val

    print (f'Epoch [{epoch + 1}/{num_epochs}], loss: {ave_train_loss:.5f} acc: {train_acc:.5f} val_loss: {ave_val_loss:.5f}, val_acc: {val_acc:.5f}')

    item = np.array([epoch + 1, ave_train_loss, train_acc, ave_val_loss, val_acc])
    history = np.vstack((history, item))

100%|██████████| 752/752 [00:37<00:00, 20.05it/s]


Epoch [1/50], loss: 1.49602 acc: 0.55490 val_loss: 0.99942, val_acc: 0.70130


100%|██████████| 752/752 [01:16<00:00,  9.88it/s]


Epoch [2/50], loss: 0.77708 acc: 0.77986 val_loss: 0.78891, val_acc: 0.79381


100%|██████████| 752/752 [01:43<00:00,  7.23it/s]


Epoch [3/50], loss: 0.61690 acc: 0.82815 val_loss: 0.63017, val_acc: 0.82020


100%|██████████| 752/752 [00:53<00:00, 13.97it/s]


Epoch [4/50], loss: 0.51578 acc: 0.85145 val_loss: 0.53612, val_acc: 0.84775


100%|██████████| 752/752 [00:50<00:00, 14.83it/s]


Epoch [5/50], loss: 0.46174 acc: 0.86582 val_loss: 0.49221, val_acc: 0.85833


100%|██████████| 752/752 [00:51<00:00, 14.72it/s]


Epoch [6/50], loss: 0.42820 acc: 0.87442 val_loss: 0.46139, val_acc: 0.86442


100%|██████████| 752/752 [01:25<00:00,  8.84it/s]


Epoch [7/50], loss: 0.40222 acc: 0.88188 val_loss: 0.43982, val_acc: 0.86812


100%|██████████| 752/752 [01:42<00:00,  7.31it/s]


Epoch [8/50], loss: 0.37792 acc: 0.88897 val_loss: 0.40526, val_acc: 0.87849


100%|██████████| 752/752 [01:35<00:00,  7.87it/s]


Epoch [9/50], loss: 0.35446 acc: 0.89564 val_loss: 0.38674, val_acc: 0.88494


100%|██████████| 752/752 [01:36<00:00,  7.82it/s]


Epoch [10/50], loss: 0.33229 acc: 0.90211 val_loss: 0.36560, val_acc: 0.88994


100%|██████████| 752/752 [01:33<00:00,  8.07it/s]


Epoch [11/50], loss: 0.31188 acc: 0.90833 val_loss: 0.33824, val_acc: 0.89995


100%|██████████| 752/752 [01:34<00:00,  7.98it/s]


Epoch [12/50], loss: 0.29191 acc: 0.91341 val_loss: 0.32609, val_acc: 0.90154


100%|██████████| 752/752 [01:38<00:00,  7.63it/s]


Epoch [13/50], loss: 0.27430 acc: 0.91943 val_loss: 0.30549, val_acc: 0.90850


100%|██████████| 752/752 [01:36<00:00,  7.78it/s]


Epoch [14/50], loss: 0.25776 acc: 0.92486 val_loss: 0.28755, val_acc: 0.91423


100%|██████████| 752/752 [01:29<00:00,  8.37it/s]


Epoch [15/50], loss: 0.24283 acc: 0.92944 val_loss: 0.27704, val_acc: 0.91633


100%|██████████| 752/752 [00:58<00:00, 12.84it/s]


Epoch [16/50], loss: 0.22872 acc: 0.93306 val_loss: 0.26362, val_acc: 0.92170


100%|██████████| 752/752 [00:48<00:00, 15.66it/s]


Epoch [17/50], loss: 0.21640 acc: 0.93709 val_loss: 0.24991, val_acc: 0.92583


100%|██████████| 752/752 [00:53<00:00, 13.96it/s]


Epoch [18/50], loss: 0.20511 acc: 0.94069 val_loss: 0.23366, val_acc: 0.93076


100%|██████████| 752/752 [00:51<00:00, 14.48it/s]


Epoch [19/50], loss: 0.19395 acc: 0.94327 val_loss: 0.22517, val_acc: 0.93185


100%|██████████| 752/752 [00:52<00:00, 14.27it/s]


Epoch [20/50], loss: 0.18374 acc: 0.94628 val_loss: 0.21877, val_acc: 0.93330


100%|██████████| 752/752 [00:55<00:00, 13.59it/s]


Epoch [21/50], loss: 0.17463 acc: 0.94911 val_loss: 0.21212, val_acc: 0.93721


100%|██████████| 752/752 [01:28<00:00,  8.47it/s]


Epoch [22/50], loss: 0.16645 acc: 0.95179 val_loss: 0.19626, val_acc: 0.93903


100%|██████████| 752/752 [01:26<00:00,  8.69it/s]


Epoch [23/50], loss: 0.15859 acc: 0.95394 val_loss: 0.19056, val_acc: 0.94164


100%|██████████| 752/752 [00:55<00:00, 13.58it/s]


Epoch [24/50], loss: 0.15080 acc: 0.95636 val_loss: 0.18332, val_acc: 0.94374


100%|██████████| 752/752 [00:47<00:00, 15.69it/s]


Epoch [25/50], loss: 0.14425 acc: 0.95807 val_loss: 0.17504, val_acc: 0.94570


100%|██████████| 752/752 [00:57<00:00, 13.09it/s]


Epoch [26/50], loss: 0.13792 acc: 0.96025 val_loss: 0.17288, val_acc: 0.94867


100%|██████████| 752/752 [00:53<00:00, 13.93it/s]


Epoch [27/50], loss: 0.13166 acc: 0.96286 val_loss: 0.16443, val_acc: 0.94889


100%|██████████| 752/752 [00:53<00:00, 14.18it/s]


Epoch [28/50], loss: 0.12660 acc: 0.96367 val_loss: 0.16057, val_acc: 0.94816


100%|██████████| 752/752 [00:49<00:00, 15.04it/s]


Epoch [29/50], loss: 0.12080 acc: 0.96546 val_loss: 0.16280, val_acc: 0.94823


100%|██████████| 752/752 [00:51<00:00, 14.74it/s]


Epoch [30/50], loss: 0.11591 acc: 0.96718 val_loss: 0.15055, val_acc: 0.95280


100%|██████████| 752/752 [00:51<00:00, 14.63it/s]


Epoch [31/50], loss: 0.11129 acc: 0.96866 val_loss: 0.14501, val_acc: 0.95519


100%|██████████| 752/752 [00:50<00:00, 14.82it/s]


Epoch [32/50], loss: 0.10741 acc: 0.96988 val_loss: 0.14849, val_acc: 0.95287


100%|██████████| 752/752 [00:53<00:00, 14.04it/s]


Epoch [33/50], loss: 0.10338 acc: 0.97096 val_loss: 0.13903, val_acc: 0.95679


100%|██████████| 752/752 [00:52<00:00, 14.29it/s]


Epoch [34/50], loss: 0.09943 acc: 0.97234 val_loss: 0.13446, val_acc: 0.95889


100%|██████████| 752/752 [00:51<00:00, 14.63it/s]


Epoch [35/50], loss: 0.09570 acc: 0.97363 val_loss: 0.13213, val_acc: 0.95860


100%|██████████| 752/752 [00:45<00:00, 16.55it/s]


Epoch [36/50], loss: 0.09238 acc: 0.97482 val_loss: 0.12833, val_acc: 0.95976


 54%|█████▎    | 404/752 [00:26<00:22, 15.26it/s]


KeyboardInterrupt: 