# やること
ブルーベリーの姿を分類する

In [1]:
import torch

print(torch.cuda.is_available())


True


In [2]:
# セットアップ
from __future__ import print_function

import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from pathlib import Path
import seaborn as sns
import timm
from pprint import pprint

import copy
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
model_names = timm.list_models(pretrained=True)
l_in = [s for s in model_names if 'inception' in s]
pprint(l_in)

['inception_next_base.sail_in1k',
 'inception_next_base.sail_in1k_384',
 'inception_next_small.sail_in1k',
 'inception_next_tiny.sail_in1k',
 'inception_resnet_v2.tf_ens_adv_in1k',
 'inception_resnet_v2.tf_in1k',
 'inception_v3.gluon_in1k',
 'inception_v3.tf_adv_in1k',
 'inception_v3.tf_in1k',
 'inception_v3.tv_in1k',
 'inception_v4.tf_in1k']


# EfficientNet

In [3]:
# Training settings
epochs = 100
lr = 3e-5
gamma = 0.7
seed = 42

In [4]:
# シードの設定
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(seed)

In [5]:
# 学習用データセットの設定
device = 'cuda'

parents_dir='./classification-data-TreePower-all-remove-bg2/'

train_dataset_dir = Path(parents_dir+'train')
val_dataset_dir = Path(parents_dir+'val')
test_dataset_dir = Path(parents_dir+'test')
print(train_dataset_dir)

classification-data-TreePower-all-remove-bg2/train


## 普通に学習と実行

In [18]:
# transform
# ・画像のサイズを224x224にリサイズ
# ・左右反転によるData Augmentation
# ・Tensor型へデータ変更
# ・正規化

train_transforms = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.RandomAffine(degrees=(-10,10),scale=(0.8, 1.2),translate=(0.2, 0.3)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

test_transforms = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

In [19]:
# データセットのロード
train_data = datasets.ImageFolder(train_dataset_dir,train_transforms)
valid_data = datasets.ImageFolder(val_dataset_dir, val_transforms)
test_data = datasets.ImageFolder(test_dataset_dir, test_transforms)

train_loader = DataLoader(dataset = train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(dataset = valid_data, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset = test_data, batch_size=256, shuffle=False)

In [55]:
# モデルをcheck
model_names = timm.list_models(pretrained=True)
l_in = [s for s in model_names if 'inception' in s]
pprint(l_in)

['inception_next_base.sail_in1k',
 'inception_next_base.sail_in1k_384',
 'inception_next_small.sail_in1k',
 'inception_next_tiny.sail_in1k',
 'inception_resnet_v2.tf_ens_adv_in1k',
 'inception_resnet_v2.tf_in1k',
 'inception_v3.gluon_in1k',
 'inception_v3.tf_adv_in1k',
 'inception_v3.tf_in1k',
 'inception_v3.tv_in1k',
 'inception_v4.tf_in1k']


In [20]:


# model = timm.create_model('tf_efficientnetv2_s.in21k_ft_in1k', pretrained=True, num_classes=3)
# model = timm.create_model('tf_efficientnetv2_m.in21k_ft_in1k', pretrained=True, num_classes=3)
# model = timm.create_model('tf_efficientnetv2_xl.in21k_ft_in1k', pretrained=True, num_classes=3)


model = timm.create_model('inception_v4.tf_in1k', pretrained=True, num_classes=3)

# model = timm.create_model('mobilenetv3_large_100.miil_in21k_ft_in1k', pretrained=True, num_classes=3)

# model = timm.create_model('resnetv2_50x1_bit.goog_in21k_ft_in1k', pretrained=True, num_classes=3)


model = model.to(device)

In [21]:
# EfficientNetによる学習

# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [15]:
best_loss = None

# Accuracy計算用の関数
def calculate_accuracy(output, target):
    output = (torch.sigmoid(output) >= 0.5)
    target = (target == 1.0)
    accuracy = torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
    return accuracy

train_acc_list = []
val_acc_list = []
train_loss_list = []
val_loss_list = []

for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    train_acc_list.append(epoch_accuracy)
    val_acc_list.append(epoch_val_accuracy)
    train_loss_list.append(epoch_loss)
    val_loss_list.append(epoch_val_loss)

    if (best_loss is None) or (best_loss > val_loss):
        best_loss = val_loss
        model_path = './classification-weights/all/inceptionv4-bg2_100epoch2.pth'
        torch.save(model.state_dict(), model_path)

    print()


# 学習結果の可視化-------------------------------------------------------------------------
device2 = torch.device('cpu')

train_acc = []
train_loss = []
val_acc = []
val_loss = []

for i in range(epochs):
    train_acc2 = train_acc_list[i].to(device2)
    train_acc3 = train_acc2.clone().numpy()
    train_acc.append(train_acc3)

    train_loss2 = train_loss_list[i].to(device2)
    train_loss3 = train_loss2.clone().detach().numpy()
    train_loss.append(train_loss3)

    val_acc2 = val_acc_list[i].to(device2)
    val_acc3 = val_acc2.clone().numpy()
    val_acc.append(val_acc3)

    val_loss2 = val_loss_list[i].to(device2)
    val_loss3 = val_loss2.clone().numpy()
    val_loss.append(val_loss3)

#取得したデータをグラフ化する
sns.set()
num_epochs = epochs

fig = plt.subplots(figsize=(12, 4), dpi=80)

ax1 = plt.subplot(1,2,1)
ax1.plot(range(num_epochs), train_acc, c='b', label='train acc')
ax1.plot(range(num_epochs), val_acc, c='r', label='val acc')
ax1.set_xlabel('epoch', fontsize='12')
ax1.set_ylabel('accuracy', fontsize='12')
ax1.set_title('training and val acc', fontsize='14')
ax1.legend(fontsize='12')

ax2 = plt.subplot(1,2,2)
ax2.plot(range(num_epochs), train_loss, c='b', label='train loss')
ax2.plot(range(num_epochs), val_loss, c='r', label='val loss')
ax2.set_xlabel('epoch', fontsize='12')
ax2.set_ylabel('loss', fontsize='12')
ax2.set_title('training and val loss', fontsize='14')
ax2.legend(fontsize='12')
plt.show()


100%|██████████| 38/38 [00:46<00:00,  1.22s/it]


Epoch : 1 - loss : 0.0792 - acc: 0.9720 - val_loss : 1.7635 - val_acc: 0.7400




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 2 - loss : 0.0391 - acc: 0.9868 - val_loss : 1.6389 - val_acc: 0.7150




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 3 - loss : 0.0290 - acc: 0.9934 - val_loss : 1.7798 - val_acc: 0.7350




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 4 - loss : 0.0654 - acc: 0.9786 - val_loss : 0.9031 - val_acc: 0.7725




100%|██████████| 38/38 [00:45<00:00,  1.20s/it]


Epoch : 5 - loss : 0.0615 - acc: 0.9819 - val_loss : 1.2205 - val_acc: 0.7150




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 6 - loss : 0.0530 - acc: 0.9901 - val_loss : 1.0773 - val_acc: 0.7000




100%|██████████| 38/38 [00:45<00:00,  1.19s/it]


Epoch : 7 - loss : 0.2027 - acc: 0.9276 - val_loss : 1.4083 - val_acc: 0.7400




100%|██████████| 38/38 [00:45<00:00,  1.19s/it]


Epoch : 8 - loss : 0.0938 - acc: 0.9753 - val_loss : 1.5845 - val_acc: 0.7225




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 9 - loss : 0.0451 - acc: 0.9868 - val_loss : 1.4624 - val_acc: 0.7350




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 10 - loss : 0.0325 - acc: 0.9868 - val_loss : 1.9323 - val_acc: 0.7475




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 11 - loss : 0.0473 - acc: 0.9868 - val_loss : 1.3157 - val_acc: 0.6900




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 12 - loss : 0.0228 - acc: 0.9901 - val_loss : 1.2879 - val_acc: 0.7075




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 13 - loss : 0.0214 - acc: 0.9951 - val_loss : 1.8688 - val_acc: 0.6700




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 14 - loss : 0.0179 - acc: 0.9918 - val_loss : 1.9118 - val_acc: 0.6575




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 15 - loss : 0.0172 - acc: 0.9918 - val_loss : 2.6428 - val_acc: 0.6750




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 16 - loss : 0.0188 - acc: 0.9951 - val_loss : 2.2362 - val_acc: 0.7025




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 17 - loss : 0.0042 - acc: 1.0000 - val_loss : 2.2073 - val_acc: 0.7150




100%|██████████| 38/38 [00:44<00:00,  1.17s/it]


Epoch : 18 - loss : 0.0823 - acc: 0.9704 - val_loss : 1.1227 - val_acc: 0.6600




100%|██████████| 38/38 [00:45<00:00,  1.19s/it]


Epoch : 19 - loss : 0.0572 - acc: 0.9786 - val_loss : 1.8800 - val_acc: 0.7550




100%|██████████| 38/38 [00:45<00:00,  1.19s/it]


Epoch : 20 - loss : 0.0251 - acc: 0.9868 - val_loss : 1.8684 - val_acc: 0.7025




100%|██████████| 38/38 [00:44<00:00,  1.18s/it]


Epoch : 21 - loss : 0.0774 - acc: 0.9786 - val_loss : 1.2925 - val_acc: 0.7525




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 22 - loss : 0.0215 - acc: 0.9951 - val_loss : 1.8002 - val_acc: 0.7350




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 23 - loss : 0.0686 - acc: 0.9803 - val_loss : 1.0913 - val_acc: 0.7125




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 24 - loss : 0.0308 - acc: 0.9918 - val_loss : 1.4999 - val_acc: 0.7600




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 25 - loss : 0.0763 - acc: 0.9753 - val_loss : 0.8237 - val_acc: 0.7800




100%|██████████| 38/38 [00:41<00:00,  1.08s/it]


Epoch : 26 - loss : 0.0132 - acc: 0.9951 - val_loss : 2.5604 - val_acc: 0.5675




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 27 - loss : 0.0799 - acc: 0.9819 - val_loss : 1.2495 - val_acc: 0.7725




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 28 - loss : 0.0332 - acc: 0.9901 - val_loss : 0.9220 - val_acc: 0.7575




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 29 - loss : 0.0158 - acc: 0.9934 - val_loss : 1.9794 - val_acc: 0.6950




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 30 - loss : 0.0065 - acc: 0.9967 - val_loss : 1.4430 - val_acc: 0.7850




100%|██████████| 38/38 [00:41<00:00,  1.08s/it]


Epoch : 31 - loss : 0.0465 - acc: 0.9901 - val_loss : 1.2722 - val_acc: 0.7075




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 32 - loss : 0.0392 - acc: 0.9885 - val_loss : 1.0573 - val_acc: 0.7525




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 33 - loss : 0.0289 - acc: 0.9918 - val_loss : 0.8804 - val_acc: 0.7800




100%|██████████| 38/38 [00:41<00:00,  1.10s/it]


Epoch : 34 - loss : 0.0140 - acc: 0.9951 - val_loss : 1.6647 - val_acc: 0.7200




100%|██████████| 38/38 [00:41<00:00,  1.10s/it]


Epoch : 35 - loss : 0.0034 - acc: 1.0000 - val_loss : 1.7745 - val_acc: 0.7025




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 36 - loss : 0.0705 - acc: 0.9770 - val_loss : 1.1007 - val_acc: 0.7525




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 37 - loss : 0.0302 - acc: 0.9918 - val_loss : 1.7195 - val_acc: 0.6825




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 38 - loss : 0.0263 - acc: 0.9901 - val_loss : 1.2786 - val_acc: 0.7850




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 39 - loss : 0.0094 - acc: 0.9967 - val_loss : 1.6454 - val_acc: 0.6825




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 40 - loss : 0.0362 - acc: 0.9819 - val_loss : 1.8491 - val_acc: 0.7075




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 41 - loss : 0.0567 - acc: 0.9852 - val_loss : 1.2146 - val_acc: 0.7550




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 42 - loss : 0.0030 - acc: 0.9984 - val_loss : 1.6613 - val_acc: 0.7475




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 43 - loss : 0.0090 - acc: 0.9967 - val_loss : 1.4289 - val_acc: 0.7300




100%|██████████| 38/38 [00:41<00:00,  1.08s/it]


Epoch : 44 - loss : 0.0173 - acc: 0.9951 - val_loss : 3.3106 - val_acc: 0.6575




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 45 - loss : 0.0638 - acc: 0.9803 - val_loss : 1.0670 - val_acc: 0.7475




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 46 - loss : 0.0152 - acc: 0.9934 - val_loss : 1.3731 - val_acc: 0.7150




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 47 - loss : 0.0515 - acc: 0.9885 - val_loss : 1.2459 - val_acc: 0.6825




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 48 - loss : 0.0453 - acc: 0.9885 - val_loss : 0.9882 - val_acc: 0.7675




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 49 - loss : 0.0208 - acc: 0.9934 - val_loss : 2.7721 - val_acc: 0.6575




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 50 - loss : 0.0336 - acc: 0.9918 - val_loss : 2.0976 - val_acc: 0.6775




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 51 - loss : 0.0017 - acc: 1.0000 - val_loss : 1.6809 - val_acc: 0.7150




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 52 - loss : 0.0005 - acc: 1.0000 - val_loss : 1.9122 - val_acc: 0.7125




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 53 - loss : 0.0010 - acc: 1.0000 - val_loss : 2.2055 - val_acc: 0.7025




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 54 - loss : 0.0614 - acc: 0.9868 - val_loss : 1.3886 - val_acc: 0.6500




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 55 - loss : 0.0212 - acc: 0.9918 - val_loss : 1.4157 - val_acc: 0.6800




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 56 - loss : 0.0208 - acc: 0.9934 - val_loss : 1.5527 - val_acc: 0.7525




100%|██████████| 38/38 [00:41<00:00,  1.10s/it]


Epoch : 57 - loss : 0.0195 - acc: 0.9934 - val_loss : 1.5273 - val_acc: 0.7150




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 58 - loss : 0.0250 - acc: 0.9951 - val_loss : 1.4243 - val_acc: 0.8000




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 59 - loss : 0.0057 - acc: 0.9984 - val_loss : 1.7715 - val_acc: 0.7400




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 60 - loss : 0.0235 - acc: 0.9901 - val_loss : 1.1551 - val_acc: 0.7350




100%|██████████| 38/38 [00:41<00:00,  1.08s/it]


Epoch : 61 - loss : 0.0070 - acc: 0.9984 - val_loss : 1.4331 - val_acc: 0.7550




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 62 - loss : 0.0427 - acc: 0.9885 - val_loss : 1.2036 - val_acc: 0.7100




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 63 - loss : 0.0226 - acc: 0.9918 - val_loss : 2.0019 - val_acc: 0.7525




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 64 - loss : 0.0535 - acc: 0.9852 - val_loss : 1.5255 - val_acc: 0.6450




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 65 - loss : 0.0223 - acc: 0.9918 - val_loss : 1.2186 - val_acc: 0.7525




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 66 - loss : 0.0098 - acc: 0.9967 - val_loss : 1.7862 - val_acc: 0.7475




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 67 - loss : 0.0003 - acc: 1.0000 - val_loss : 2.3016 - val_acc: 0.7450




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 68 - loss : 0.0003 - acc: 1.0000 - val_loss : 2.0529 - val_acc: 0.7700




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 69 - loss : 0.0070 - acc: 0.9967 - val_loss : 2.7028 - val_acc: 0.6575




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 70 - loss : 0.0078 - acc: 0.9984 - val_loss : 1.6447 - val_acc: 0.7475




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 71 - loss : 0.0070 - acc: 0.9967 - val_loss : 3.4460 - val_acc: 0.6975




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 72 - loss : 0.0232 - acc: 0.9951 - val_loss : 1.6773 - val_acc: 0.7100




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 73 - loss : 0.0081 - acc: 0.9951 - val_loss : 1.9814 - val_acc: 0.7600




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 74 - loss : 0.0618 - acc: 0.9836 - val_loss : 1.5739 - val_acc: 0.6575




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 75 - loss : 0.0224 - acc: 0.9934 - val_loss : 1.3453 - val_acc: 0.7975




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 76 - loss : 0.0044 - acc: 0.9984 - val_loss : 2.7823 - val_acc: 0.6975




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 77 - loss : 0.1005 - acc: 0.9638 - val_loss : 1.1785 - val_acc: 0.7600




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 78 - loss : 0.0303 - acc: 0.9918 - val_loss : 0.8825 - val_acc: 0.7975




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 79 - loss : 0.0076 - acc: 0.9967 - val_loss : 1.3822 - val_acc: 0.8050




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 80 - loss : 0.0079 - acc: 0.9967 - val_loss : 1.8754 - val_acc: 0.7250




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 81 - loss : 0.0743 - acc: 0.9770 - val_loss : 1.4065 - val_acc: 0.7300




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 82 - loss : 0.0257 - acc: 0.9918 - val_loss : 1.0758 - val_acc: 0.7675




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 83 - loss : 0.0018 - acc: 1.0000 - val_loss : 1.4360 - val_acc: 0.7525




100%|██████████| 38/38 [00:42<00:00,  1.12s/it]


Epoch : 84 - loss : 0.0158 - acc: 0.9918 - val_loss : 1.4277 - val_acc: 0.7650




100%|██████████| 38/38 [00:42<00:00,  1.12s/it]


Epoch : 85 - loss : 0.0217 - acc: 0.9934 - val_loss : 1.2374 - val_acc: 0.8300




100%|██████████| 38/38 [00:41<00:00,  1.10s/it]


Epoch : 86 - loss : 0.0162 - acc: 0.9918 - val_loss : 1.3095 - val_acc: 0.7575




100%|██████████| 38/38 [00:42<00:00,  1.11s/it]


Epoch : 87 - loss : 0.0105 - acc: 0.9951 - val_loss : 2.0085 - val_acc: 0.7700




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 88 - loss : 0.0401 - acc: 0.9868 - val_loss : 1.2987 - val_acc: 0.7725




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 89 - loss : 0.0236 - acc: 0.9901 - val_loss : 1.8665 - val_acc: 0.7025




100%|██████████| 38/38 [00:41<00:00,  1.09s/it]


Epoch : 90 - loss : 0.0151 - acc: 0.9967 - val_loss : 1.4299 - val_acc: 0.7975




100%|██████████| 38/38 [00:42<00:00,  1.13s/it]


Epoch : 91 - loss : 0.0038 - acc: 0.9967 - val_loss : 1.7387 - val_acc: 0.7675




100%|██████████| 38/38 [00:42<00:00,  1.11s/it]


Epoch : 92 - loss : 0.0021 - acc: 1.0000 - val_loss : 1.7184 - val_acc: 0.7550




100%|██████████| 38/38 [00:41<00:00,  1.10s/it]


Epoch : 93 - loss : 0.0048 - acc: 0.9984 - val_loss : 1.6448 - val_acc: 0.7425




100%|██████████| 38/38 [00:42<00:00,  1.11s/it]


Epoch : 94 - loss : 0.0381 - acc: 0.9934 - val_loss : 2.8501 - val_acc: 0.6350




 16%|█▌        | 6/38 [00:13<01:26,  2.70s/it]

: 

In [30]:
model.eval()  # モデルを評価モードにする
model_path = './classification-weights/all/inception_v4_100epoch_.pth'
model.load_state_dict(torch.load(model_path))

loss_sum = 0
correct = 0

predictions = []  # 予測結果を格納するリスト

with torch.no_grad():
    for data, labels in test_loader:

        # GPUが使えるならGPUにデータを送る
        data = data.to(device)
        labels = labels.to(device)

        # ニューラルネットワークの処理を実施
        outputs = model(data)

        # 損失(出力とラベルとの誤差)の計算
        loss_sum += criterion(outputs, labels)

        # 正解の値を取得
        pred = outputs.argmax(1)
        predictions.extend(pred.cpu().numpy())  # 予測結果をリストに追加

        # 正解数をカウント
        correct += pred.eq(labels.view_as(pred)).sum().item()

# 画像ファイル名、正解ラベル、予測結果を出力
file_names = [Path(test_loader.dataset.samples[i][0]).name for i in range(len(test_loader.dataset.samples))]

result_df = pd.DataFrame({
    'File Name': file_names,
    'True Label': test_loader.dataset.targets,
    'Predicted Label': predictions
})


pd.set_option('display.max_rows', 1000)

print(f"\nLoss: {loss_sum.item() / len(test_loader)}, Accuracy: {100*correct/len(test_data)}% ({correct}/{len(test_data)})")


Loss: 0.8454700708389282, Accuracy: 71.71717171717172% (71/99)


In [31]:
import pandas as pd
from sklearn.metrics import accuracy_score,confusion_matrix

# True Labelごとの精度を計算
unique_labels = result_df['True Label'].unique()
accuracy_per_label = []


for label in unique_labels:
    subset_df = result_df[result_df['True Label'] == label]
    accuracy = accuracy_score(subset_df['True Label'], subset_df['Predicted Label'])
    accuracy_per_label.append({'True Label': label, 'Accuracy': accuracy})

# 結果をDataFrameに変換
accuracy_df = pd.DataFrame(accuracy_per_label)

# 結果の表示
print(accuracy_df)


   True Label  Accuracy
0           0  0.764706
1           1  0.673469
2           2  0.757576


In [32]:
print(confusion_matrix(result_df['True Label'], result_df['Predicted Label']))

[[13  4  0]
 [ 4 33 12]
 [ 0  8 25]]


In [34]:
result_df

Unnamed: 0,File Name,True Label,Predicted Label
0,east00058.JPG,0,1
1,north00028.JPG,0,0
2,north0003.JPG,0,0
3,north00075.JPG,0,0
4,north00079.JPG,0,0
5,north0009.JPG,0,0
6,north0061.JPG,0,1
7,south00031.JPG,0,1
8,south00041.JPG,0,0
9,south00075.JPG,0,0


## 検出結果の可視化

- pytorch-grad-camを使用
    - https://www.tcom242242.net/entry/python-basic/pytorch/grad-cam-pytorch/

In [16]:
# チュートリアル
import warnings
warnings.filterwarnings('ignore')
from torchvision import models
import numpy as np
import cv2
import requests
from pytorch_grad_cam import GradCAM,GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, \
    deprocess_image, \
    preprocess_image
from PIL import Image
import datetime

# Grad-CAMのための準備
# target_layers = [model.conv_head]
# target_layers = [model.features[-1].branch3[-1].conv]
target_layers = [model.stages[-1].blocks[-1].conv3]

# camplus = GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())

# サンプル画像を取得
data_iter = iter(test_loader)
images, labels = next(data_iter)

print(len(images))



file_paths = [path for (path, _) in test_data.imgs[:len(images)]]  # データセットの最初の画像数分のファイルパス
print(file_paths)

model.eval()
model_path = './classification-weights/all/inceptionv4-bg2_100epoch2.pth'
model.load_state_dict(torch.load(model_path))

weight_name=model_path.split('/')[-1]

print(weight_name)

# 出力画像の保存先ディレクトリ
current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
save_dir = f'./output_images/all/{current_time+weight_name}/'
os.makedirs(save_dir, exist_ok=True)

for i,image_url in enumerate(file_paths):
    img = np.array(Image.open(image_url))
    img = cv2.resize(img, (224, 224))
    img = np.float32(img) / 255
    input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to(device)

    target_label = labels[i].unsqueeze(0).to(device)

    # The target for the CAM is the Bear category.
    # As usual for classication, the target is the logit output
    # before softmax, for that category.
    targets = [ClassifierOutputTarget(target_label)]
    with GradCAMPlusPlus(model=model, target_layers=target_layers) as cam:
        grayscale_cams = cam(input_tensor=input_tensor, targets=targets)
        cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)
    cam = np.uint8(255*grayscale_cams[0, :])
    cam = cv2.merge([cam, cam])
    visualization = np.hstack((np.uint8(255*img) , cam_image))

    # 画像の保存（ファイル名は元のファイル名を使用）
    image_filename = os.path.basename(image_url)  # ファイルのベース名（拡張子付き）を取得
    save_path = os.path.join(save_dir, image_filename + '.png')
    plt.imsave(save_path, visualization)
    print(f"Saved image {i} to {save_path}")

99
['classification-data-TreePower-all-remove-bg2/test/1/east00058.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/north00028.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/north0003.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/north00075.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/north00079.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/north0009.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/north0061.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/south00031.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/south00041.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/south00075.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/south00079.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/south0043.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/top0003.JPG', 'classification-data-TreePower-all-remove-bg2/test/1/top0017.JPG', 'classification-data-TreePo