In [None]:
import torchvision.transforms as transforms

# Per-channel mean / std of ImageNet, the statistics the torchvision
# pretrained weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image (train and validation alike).
transform = transforms.Compose(
    [
        transforms.Resize(256),        # int size: resize the shorter side to 256 px
        transforms.CenterCrop(224),    # central 224x224 crop
        transforms.ToTensor(),         # PIL image -> float tensor in [0, 1]
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
from torchvision.datasets import ImageFolder

# Image dataset: ImageFolder treats each sub-directory of `data_folder` as one
# class and applies `transform` to every loaded image.
data_folder = "../img"
data = ImageFolder(transform=transform, root=data_folder)

In [None]:
from torch.utils.data import random_split

import torch

# 80/20 train/validation split. A fixed seed keeps the split identical across
# Restart & Run All; without it each re-run validates (and selects the "best"
# epoch) on a different subset of images.
SPLIT_SEED = 42

train_size = int(0.8 * len(data))
val_size = len(data) - train_size
data_size = {"train": train_size, "val": val_size}
data_train, data_val = random_split(
    data,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
from torch.utils.data import DataLoader

batch_size = 128

# Both loaders share the batch size; only the training set is shuffled,
# since validation order does not affect the averaged metrics.
loader_kwargs = {"batch_size": batch_size}
train_loader = DataLoader(data_train, shuffle=True, **loader_kwargs)
val_loader = DataLoader(data_val, shuffle=False, **loader_kwargs)

In [None]:
import torchvision.models as models

# ImageNet-pretrained ResNet-18. torchvision >= 0.13 replaced the deprecated
# `pretrained=True` flag with the `weights=` enum; ResNet18_Weights.DEFAULT is
# the same IMAGENET1K_V1 checkpoint. Fall back to the old flag on older
# torchvision releases that lack the enum.
weights_enum = getattr(models, "ResNet18_Weights", None)
if weights_enum is not None:
    net = models.resnet18(weights=weights_enum.DEFAULT)
else:
    net = models.resnet18(pretrained=True)
print(net)

In [None]:
import torch.nn as nn

# Feature extraction: freeze every pretrained parameter so only the new
# classification head below is trained.
for param in net.parameters():
    param.requires_grad = False

# Replace the ImageNet head. Size it from the backbone's feature width and the
# number of class folders ImageFolder found, instead of a hard-coded 512 -> 3,
# so the head always matches the dataset. The freshly created layer's
# parameters have requires_grad=True.
num_classes = len(data.classes)
net.fc = nn.Linear(net.fc.in_features, num_classes)

# For GPU training: net.cuda()  (and move each batch with .cuda() in later cells)
print(net)

In [None]:
import copy
from torch import optim

import torch

loss_fnc = nn.CrossEntropyLoss()
# Only the unfrozen head requires gradients; give Adam just those parameters.
optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad])

# Per-epoch average of the per-batch losses (the "test" lists hold the
# validation split).
record_loss_train = []
record_loss_test = []

n_epochs = 20
# Start from +inf and snapshot the initial weights, so `best_model` is always
# defined. The previous fixed threshold (0.16) left it undefined whenever
# validation loss never dropped below that value.
best_loss_test = float("inf")
best_model = copy.deepcopy(net.state_dict())

for i in range(n_epochs):
    # NOTE(review): net.train() also switches the frozen backbone's BatchNorm
    # layers (torchvision ResNet uses them) to training mode, so their running
    # statistics keep updating even though their weights are frozen — confirm
    # this is intended.
    net.train()
    loss_train = 0.0
    for x, t in train_loader:
        # x, t = x.cuda(), t.cuda()
        y = net(x)
        loss = loss_fnc(y, t)
        loss_train += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # len(loader) is the number of batches; max() guards an empty loader.
    loss_train /= max(len(train_loader), 1)
    record_loss_train.append(loss_train)

    net.eval()
    loss_test = 0.0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for x, t in val_loader:
            # x, t = x.cuda(), t.cuda()
            y = net(x)
            loss_test += loss_fnc(y, t).item()
    loss_test /= max(len(val_loader), 1)
    record_loss_test.append(loss_test)

    if loss_test < best_loss_test:
        best_loss_test = loss_test
        best_model = copy.deepcopy(net.state_dict())

    print("Epoch:", i, "Loss_Train:", loss_train, "Loss_Test:", loss_test)

# Restore the lowest-validation-loss weights so the evaluation and save cells
# below use the best checkpoint instead of whatever the last epoch produced.
net.load_state_dict(best_model)

In [None]:
import matplotlib.pyplot as plt

# Learning curves. Both series are the per-epoch average of per-batch
# cross-entropy losses, not an error rate; the second series comes from the
# validation split (val_loader).
fig, ax = plt.subplots()
ax.plot(range(len(record_loss_train)), record_loss_train, label="Train")
ax.plot(range(len(record_loss_test)), record_loss_test, label="Validation")

ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training and validation loss")
ax.legend()
plt.show()

In [None]:
import torch

# Classification accuracy of the current `net` on the validation split.
correct = 0
total = 0
net.eval()  # evaluation mode
with torch.no_grad():  # inference only: skip building the autograd graph
    for x, t in val_loader:
        # x, t = x.cuda(), t.cuda()  # GPU support
        y = net(x)
        correct += (y.argmax(1) == t).sum().item()
        total += len(x)

if total == 0:
    raise ValueError("validation set is empty; cannot compute accuracy")
print("正解率:", str(correct/total*100) + "%")

In [None]:
import torch

from pathlib import Path

# torch.save does not create missing directories, so create ./model first.
model_path = Path("./model/resnet18.pth")
model_path.parent.mkdir(parents=True, exist_ok=True)
# Saves the whole pickled module (its class must be importable when loading).
# NOTE(review): torch.save(net.state_dict(), ...) is the more portable format,
# but switching would change what existing loaders of this file receive.
torch.save(net, str(model_path))