# 1.5 「ファインチューニング」で精度向上を実現する方法

- 本ファイルでは、学習済みのVGGモデルを使用し、ファインチューニングでアリとハチの画像を分類するモデルを学習します



# 学習目標

1.	PyTorchでGPUを使用する実装コードを書けるようになる
2.	最適化手法の設定において、層ごとに異なる学習率を設定したファインチューニングを実装できるようになる
3.	学習したネットワークを保存・ロードできるようになる



# 事前準備

- 1.4節で解説したAWS EC2 のGPUインスタンスを使用します


In [1]:
# パッケージのimport
import random
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm

In [2]:
print(torch.__version__, torch.cuda.is_available())

2.0.0+cu117 True


In [3]:
# 乱数のシードを設定
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)

In [4]:
size = 256
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

batch_size = 32

In [5]:
OUTPUT_FEATURES = 2

MODEL_CATEGORY = 'EfficientNet'

OPTIMIZER_TYPE = 'Adam'
LEARNING_RATIO = 0.001
MOMENTUM = 0.0
WEIGHT_DECAY = 1e-6

NUM_EPOCHS = 30

# MLflow

In [6]:
import mlflow

In [7]:
# experimentの作成(読み込み)
experiment_id = mlflow.set_experiment(MODEL_CATEGORY)  # experimentの設定. 無ければ新規に作成.
print(experiment_id.experiment_id)

978472247857902971


# DatasetとDataLoaderを作成

In [8]:
# 1.3節で作成したクラスを同じフォルダにあるmake_dataset_dataloader.pyに記載して使用
from utils.dataloader_image_classification import (
    HymenopteraDataset,
    ImageTransform,
    make_datapath_list,
)

# アリとハチの画像へのファイルパスのリストを作成する
train_list = make_datapath_list(phase="train")
val_list = make_datapath_list(phase="val")

# Datasetを作成する
train_dataset = HymenopteraDataset(
    file_list=train_list, transform=ImageTransform(size, mean, std), phase="train"
)

val_dataset = HymenopteraDataset(
    file_list=val_list, transform=ImageTransform(size, mean, std), phase="val"
)


# DataLoaderを作成する
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

# 辞書オブジェクトにまとめる
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

./data/hymenoptera_data/train/**/*.jpg
./data/hymenoptera_data/val/**/*.jpg


# モデルを学習させる関数を作成

In [9]:
def plot_history(name, history):
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True)
    axes[0].plot(
        range(len(history["train"]["loss"])),
        history["train"]["loss"],
        "r-o",
        label=f"{name}-train",
    )
    axes[0].plot(
        range(len(history["val"]["loss"])),
        history["val"]["loss"],
        "b--s",
        label=f"{name}-val",
    )
    axes[0].set_xlabel("Epochs", size=14)
    axes[0].set_ylabel("Loss", size=14)
    axes[0].tick_params(labelsize=12)
    axes[0].grid()
    axes[0].legend()

    axes[1].plot(
        range(len(history["train"]["acc"])),
        history["train"]["acc"],
        "r-o",
        label=f"{name}-train",
    )
    axes[1].plot(
        range(len(history["val"]["acc"])),
        history["val"]["acc"],
        "b--s",
        label=f"{name}-val",
    )
    axes[1].set_xlabel("Epochs", size=14)
    axes[1].set_ylabel("Accuracy", size=14)
    axes[1].tick_params(labelsize=12)
    axes[1].grid()
    axes[1].legend()
    plt.suptitle(f"{name}", size=16)
    # plt.show()
    return fig

In [10]:
def train_model(name, net, dataloaders_dict, criterion, optimizer, num_epochs):
    # 初期設定
    # GPUが使えるかを確認
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用デバイス：", device)

    # ネットワークをGPUへ
    net.to(device)

    # ネットワークがある程度固定であれば、高速化させる
    torch.backends.cudnn.benchmark = True

    # epochのループ
    history = {}
    history["train"] = {}
    history["val"] = {}
    history["train"]["loss"] = []
    history["train"]["acc"] = []
    history["val"]["loss"] = []
    history["val"]["acc"] = []

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        # print("-------------")

        # epochごとの訓練と検証のループ
        for phase in ["train", "val"]:
            if phase == "train":
                net.train()  # モデルを訓練モードに
            else:
                net.eval()  # モデルを検証モードに

            epoch_loss = 0.0  # epochの損失和
            epoch_corrects = 0  # epochの正解数

            # # 未学習時の検証性能を確かめるため、epoch=0の訓練は省略
            # if (epoch == 0) and (phase == "train"):
            #     continue

            # データローダーからミニバッチを取り出すループ
            for inputs, labels in tqdm(dataloaders_dict[phase]):
                # GPUが使えるならGPUにデータを送る
                inputs = inputs.to(device)
                labels = labels.to(device)

                # optimizerを初期化
                optimizer.zero_grad()

                # 順伝搬（forward）計算
                with torch.set_grad_enabled(phase == "train"):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)  # 損失を計算
                    _, preds = torch.max(outputs, 1)  # ラベルを予測

                    # 訓練時はバックプロパゲーション
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                    # 結果の計算
                    epoch_loss += loss.item() * inputs.size(0)  # lossの合計を更新
                    # 正解数の合計を更新
                    epoch_corrects += torch.sum(preds == labels.data)

            # epochごとのlossと正解率を表示
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.cpu().double() / len(dataloaders_dict[phase].dataset)

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))
            history[phase]["loss"].append(epoch_loss)
            history[phase]["acc"].append(epoch_acc)

            # MLflow: track metrics
            mlflow.log_metric(f"{phase}_loss", epoch_loss, step=epoch)  # Train Loss
            mlflow.log_metric(f"{phase}_acc", epoch_acc, step=epoch)  # Train Loss

    return history

# 損失関数を定義

In [11]:
# 損失関数の設定
criterion = nn.CrossEntropyLoss()

# ネットワークモデルの作成

In [None]:
def replace_last_layer_efficientnet(_net, output_features):
    last_in_features = _net.classifier[-1].in_features
    _net.classifier[-1] = nn.Linear(in_features=last_in_features, out_features=output_features)
    return _net

def params_in_last_layers(_net):
    params_to_update = []
    for param in _net.fc.parameters():
        param.requires_grad = True
        params_to_update.append(param)
    return params_to_update

In [9]:
efficientnet_models = {
    "efficientnet_b0": {'model': models.efficientnet_b0, 'weights': models.EfficientNet_B0_Weights.DEFAULT, 'layers': None},
    "efficientnet_b2": {'model': models.efficientnet_b2, 'weights': models.EfficientNet_B2_Weights.DEFAULT, 'layers': None},
    "efficientnet_b4": {'model': models.efficientnet_b4, 'weights': models.EfficientNet_B4_Weights.DEFAULT, 'layers': None},
    "efficientnet_b6": {"model": models.efficientnet_b6, 'weights': models.EfficientNet_B6_Weights.DEFAULT, 'layers': None},
}

In [17]:
mlflow.end_run()

In [None]:
resnet_results = {}
for name, model_dict in resnet_models.items():
    model = model_dict["model"]
    weights = model_dict["weights"]

    # MLflow: runの作成
    now = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{name}_{OPTIMIZER_TYPE}_LR{LEARNING_RATIO}_EPOCH{NUM_EPOCHS}_{now}"
    mlflow_run = mlflow.start_run(
        experiment_id=experiment_id.experiment_id,  # set_experimentの返り値を入れる.
        run_name=run_name,  # run_nameに、作成時刻を用いるようにした.
    )

    # MLflow: track params
    mlflow.log_params(
        {
            "Model": name,
            "Layers": model_dict["layers"],
            "Learning_ratio": LEARNING_RATIO,  # Learning ratio
            "Num_epochs": NUM_EPOCHS,  # num of Epochs
            "Optimizer": OPTIMIZER_TYPE,  # optimizer
        }
    )

    _net = model(weights=weights)
    _net = replace_last_layer_efficientnet(_net, OUTPUT_FEATURES)
    _net.train()
    print("Model:", name, last_in_features)

    # 学習させるパラメータ以外は勾配計算をなくし、変化しないように設定
    for param in _net.parameters():
        param.requires_grad = False

    params_to_update = params_in_last_layers(_net)

    if OPTIMIZER_TYPE=='SGD':
        optimizer = optim.SGD(
            params_to_update,
            lr=LEARNING_RATIO,
            momentum=MOMENTUM,
        )
        mlflow.log_params(
            {
                "Momentum": MOMENTUM,
            }
        )
    elif OPTIMIZER_TYPE=='Adam':
        optimizer = optim.Adam(
            params_to_update,
            lr=LEARNING_RATIO,
            weight_decay=WEIGHT_DECAY,
        )
        mlflow.log_params(
            {
                "Weight_decay": WEIGHT_DECAY,
            }
        )

    history = train_model(
        name, _net, dataloaders_dict, criterion, optimizer, num_epochs=NUM_EPOCHS
    )
    resnet_results[name] = history

    # # MLflow: track figure
    # figure = plot_history(name, history)
    # mlflow.log_figure(figure, "history.png")

    # MLflowL end run
    mlflow.end_run()

Model: resnet18 512
使用デバイス： cuda:0
Epoch 1/30


100%|██████████| 8/8 [00:03<00:00,  2.55it/s]


train Loss: 0.7300 Acc: 0.5473


100%|██████████| 5/5 [00:02<00:00,  2.47it/s]


val Loss: 0.6059 Acc: 0.6471
Epoch 2/30


100%|██████████| 8/8 [00:03<00:00,  2.50it/s]


train Loss: 0.5405 Acc: 0.7037


100%|██████████| 5/5 [00:01<00:00,  2.50it/s]


val Loss: 0.4744 Acc: 0.8235
Epoch 3/30


100%|██████████| 8/8 [00:03<00:00,  2.59it/s]


train Loss: 0.4130 Acc: 0.8765


100%|██████████| 5/5 [00:02<00:00,  2.32it/s]


val Loss: 0.3882 Acc: 0.9020
Epoch 4/30


100%|██████████| 8/8 [00:03<00:00,  2.49it/s]


train Loss: 0.3609 Acc: 0.8971


100%|██████████| 5/5 [00:02<00:00,  2.50it/s]


val Loss: 0.3195 Acc: 0.9281
Epoch 5/30


100%|██████████| 8/8 [00:02<00:00,  2.67it/s]


train Loss: 0.2922 Acc: 0.9177


100%|██████████| 5/5 [00:02<00:00,  2.38it/s]


val Loss: 0.2804 Acc: 0.9412
Epoch 6/30


100%|██████████| 8/8 [00:03<00:00,  2.44it/s]


train Loss: 0.2595 Acc: 0.9342


100%|██████████| 5/5 [00:02<00:00,  2.27it/s]


val Loss: 0.2614 Acc: 0.9477
Epoch 7/30


100%|██████████| 8/8 [00:03<00:00,  2.52it/s]


train Loss: 0.2508 Acc: 0.9342


100%|██████████| 5/5 [00:02<00:00,  2.49it/s]


val Loss: 0.2359 Acc: 0.9542
Epoch 8/30


100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


train Loss: 0.2211 Acc: 0.9424


100%|██████████| 5/5 [00:01<00:00,  2.78it/s]


val Loss: 0.2270 Acc: 0.9477
Epoch 9/30


100%|██████████| 8/8 [00:03<00:00,  2.49it/s]


train Loss: 0.2029 Acc: 0.9465


100%|██████████| 5/5 [00:02<00:00,  2.39it/s]


val Loss: 0.2114 Acc: 0.9477
Epoch 10/30


100%|██████████| 8/8 [00:03<00:00,  2.53it/s]


train Loss: 0.1720 Acc: 0.9753


100%|██████████| 5/5 [00:01<00:00,  2.57it/s]


val Loss: 0.2029 Acc: 0.9477
Epoch 11/30


100%|██████████| 8/8 [00:03<00:00,  2.42it/s]


train Loss: 0.1952 Acc: 0.9342


100%|██████████| 5/5 [00:02<00:00,  2.27it/s]


val Loss: 0.2072 Acc: 0.9477
Epoch 12/30


100%|██████████| 8/8 [00:03<00:00,  2.52it/s]


train Loss: 0.1658 Acc: 0.9712


100%|██████████| 5/5 [00:01<00:00,  2.64it/s]


val Loss: 0.1888 Acc: 0.9542
Epoch 13/30


100%|██████████| 8/8 [00:02<00:00,  2.77it/s]


train Loss: 0.1877 Acc: 0.9424


100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


val Loss: 0.1847 Acc: 0.9477
Epoch 14/30


100%|██████████| 8/8 [00:03<00:00,  2.53it/s]


train Loss: 0.1618 Acc: 0.9588


100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


val Loss: 0.1862 Acc: 0.9477
Epoch 15/30


100%|██████████| 8/8 [00:02<00:00,  2.79it/s]


train Loss: 0.1663 Acc: 0.9342


100%|██████████| 5/5 [00:01<00:00,  2.74it/s]


val Loss: 0.1763 Acc: 0.9542
Epoch 16/30


100%|██████████| 8/8 [00:03<00:00,  2.51it/s]


train Loss: 0.1507 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.39it/s]


val Loss: 0.1738 Acc: 0.9477
Epoch 17/30


100%|██████████| 8/8 [00:03<00:00,  2.60it/s]


train Loss: 0.1239 Acc: 0.9753


100%|██████████| 5/5 [00:01<00:00,  2.61it/s]


val Loss: 0.1682 Acc: 0.9608
Epoch 18/30


100%|██████████| 8/8 [00:03<00:00,  2.47it/s]


train Loss: 0.1365 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.43it/s]


val Loss: 0.1723 Acc: 0.9477
Epoch 19/30


100%|██████████| 8/8 [00:03<00:00,  2.60it/s]


train Loss: 0.1222 Acc: 0.9753


100%|██████████| 5/5 [00:01<00:00,  2.50it/s]


val Loss: 0.1723 Acc: 0.9477
Epoch 20/30


100%|██████████| 8/8 [00:03<00:00,  2.59it/s]


train Loss: 0.1306 Acc: 0.9547


100%|██████████| 5/5 [00:02<00:00,  2.29it/s]


val Loss: 0.1647 Acc: 0.9542
Epoch 21/30


100%|██████████| 8/8 [00:03<00:00,  2.66it/s]


train Loss: 0.1014 Acc: 0.9794


100%|██████████| 5/5 [00:02<00:00,  2.47it/s]


val Loss: 0.1620 Acc: 0.9542
Epoch 22/30


100%|██████████| 8/8 [00:03<00:00,  2.66it/s]


train Loss: 0.1242 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.22it/s]


val Loss: 0.1615 Acc: 0.9477
Epoch 23/30


100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


train Loss: 0.1273 Acc: 0.9506


100%|██████████| 5/5 [00:02<00:00,  2.34it/s]


val Loss: 0.1616 Acc: 0.9477
Epoch 24/30


100%|██████████| 8/8 [00:03<00:00,  2.52it/s]


train Loss: 0.1100 Acc: 0.9712


100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


val Loss: 0.1564 Acc: 0.9542
Epoch 25/30


100%|██████████| 8/8 [00:02<00:00,  2.67it/s]


train Loss: 0.1239 Acc: 0.9630


100%|██████████| 5/5 [00:02<00:00,  2.43it/s]


val Loss: 0.1579 Acc: 0.9542
Epoch 26/30


100%|██████████| 8/8 [00:03<00:00,  2.55it/s]


train Loss: 0.1202 Acc: 0.9630


100%|██████████| 5/5 [00:01<00:00,  2.52it/s]


val Loss: 0.1561 Acc: 0.9542
Epoch 27/30


100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


train Loss: 0.1304 Acc: 0.9547


100%|██████████| 5/5 [00:02<00:00,  2.33it/s]


val Loss: 0.1512 Acc: 0.9542
Epoch 28/30


100%|██████████| 8/8 [00:03<00:00,  2.47it/s]


train Loss: 0.1197 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.35it/s]


val Loss: 0.1789 Acc: 0.9346
Epoch 29/30


100%|██████████| 8/8 [00:03<00:00,  2.58it/s]


train Loss: 0.1251 Acc: 0.9506


100%|██████████| 5/5 [00:02<00:00,  2.48it/s]


val Loss: 0.1532 Acc: 0.9542
Epoch 30/30


100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


train Loss: 0.1011 Acc: 0.9794


100%|██████████| 5/5 [00:01<00:00,  2.63it/s]


val Loss: 0.1574 Acc: 0.9412
Model: resnet34 512
使用デバイス： cuda:0
Epoch 1/30


100%|██████████| 8/8 [00:03<00:00,  2.05it/s]


train Loss: 0.5762 Acc: 0.7325


100%|██████████| 5/5 [00:02<00:00,  2.36it/s]


val Loss: 0.4110 Acc: 0.9085
Epoch 2/30


100%|██████████| 8/8 [00:03<00:00,  2.19it/s]


train Loss: 0.3937 Acc: 0.8765


100%|██████████| 5/5 [00:02<00:00,  2.08it/s]


val Loss: 0.3028 Acc: 0.9085
Epoch 3/30


100%|██████████| 8/8 [00:03<00:00,  2.05it/s]


train Loss: 0.2839 Acc: 0.9424


100%|██████████| 5/5 [00:02<00:00,  2.09it/s]


val Loss: 0.2548 Acc: 0.9477
Epoch 4/30


100%|██████████| 8/8 [00:03<00:00,  2.22it/s]


train Loss: 0.2685 Acc: 0.9342


100%|██████████| 5/5 [00:02<00:00,  2.16it/s]


val Loss: 0.2151 Acc: 0.9412
Epoch 5/30


100%|██████████| 8/8 [00:03<00:00,  2.18it/s]


train Loss: 0.2271 Acc: 0.9383


100%|██████████| 5/5 [00:02<00:00,  2.15it/s]


val Loss: 0.1956 Acc: 0.9412
Epoch 6/30


100%|██████████| 8/8 [00:03<00:00,  2.20it/s]


train Loss: 0.1750 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.15it/s]


val Loss: 0.1939 Acc: 0.9412
Epoch 7/30


100%|██████████| 8/8 [00:03<00:00,  2.22it/s]


train Loss: 0.1976 Acc: 0.9424


100%|██████████| 5/5 [00:02<00:00,  2.08it/s]


val Loss: 0.1716 Acc: 0.9542
Epoch 8/30


100%|██████████| 8/8 [00:03<00:00,  2.24it/s]


train Loss: 0.1555 Acc: 0.9588


100%|██████████| 5/5 [00:02<00:00,  1.97it/s]


val Loss: 0.1671 Acc: 0.9346
Epoch 9/30


100%|██████████| 8/8 [00:03<00:00,  2.25it/s]


train Loss: 0.1463 Acc: 0.9712


100%|██████████| 5/5 [00:02<00:00,  2.11it/s]


val Loss: 0.1591 Acc: 0.9477
Epoch 10/30


100%|██████████| 8/8 [00:03<00:00,  2.15it/s]


train Loss: 0.1219 Acc: 0.9794


100%|██████████| 5/5 [00:02<00:00,  2.27it/s]


val Loss: 0.1523 Acc: 0.9477
Epoch 11/30


100%|██████████| 8/8 [00:03<00:00,  2.35it/s]


train Loss: 0.1162 Acc: 0.9794


100%|██████████| 5/5 [00:02<00:00,  2.09it/s]


val Loss: 0.1483 Acc: 0.9477
Epoch 12/30


100%|██████████| 8/8 [00:03<00:00,  2.50it/s]


train Loss: 0.1322 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.17it/s]


val Loss: 0.1537 Acc: 0.9477
Epoch 13/30


100%|██████████| 8/8 [00:03<00:00,  2.07it/s]


train Loss: 0.1293 Acc: 0.9671


100%|██████████| 5/5 [00:02<00:00,  2.07it/s]


val Loss: 0.1414 Acc: 0.9673
Epoch 14/30


100%|██████████| 8/8 [00:03<00:00,  2.06it/s]


train Loss: 0.1505 Acc: 0.9465


100%|██████████| 5/5 [00:02<00:00,  1.94it/s]


val Loss: 0.1560 Acc: 0.9216
Epoch 15/30


100%|██████████| 8/8 [00:03<00:00,  2.21it/s]


train Loss: 0.1079 Acc: 0.9712


100%|██████████| 5/5 [00:02<00:00,  2.02it/s]


val Loss: 0.1356 Acc: 0.9739
Epoch 16/30


100%|██████████| 8/8 [00:03<00:00,  2.15it/s]


train Loss: 0.1047 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.28it/s]


val Loss: 0.1421 Acc: 0.9477
Epoch 17/30


100%|██████████| 8/8 [00:03<00:00,  2.36it/s]


train Loss: 0.1098 Acc: 0.9753


100%|██████████| 5/5 [00:02<00:00,  2.11it/s]


val Loss: 0.1368 Acc: 0.9477
Epoch 18/30


100%|██████████| 8/8 [00:03<00:00,  2.16it/s]


train Loss: 0.1000 Acc: 0.9671


100%|██████████| 5/5 [00:02<00:00,  2.10it/s]


val Loss: 0.1333 Acc: 0.9477
Epoch 19/30


100%|██████████| 8/8 [00:03<00:00,  2.17it/s]


train Loss: 0.1010 Acc: 0.9794


100%|██████████| 5/5 [00:02<00:00,  2.09it/s]


val Loss: 0.1234 Acc: 0.9542
Epoch 20/30


100%|██████████| 8/8 [00:03<00:00,  2.13it/s]


train Loss: 0.0828 Acc: 0.9877


100%|██████████| 5/5 [00:02<00:00,  2.23it/s]


val Loss: 0.1230 Acc: 0.9542
Epoch 21/30


100%|██████████| 8/8 [00:03<00:00,  2.30it/s]


train Loss: 0.0984 Acc: 0.9630


100%|██████████| 5/5 [00:02<00:00,  2.16it/s]


val Loss: 0.1345 Acc: 0.9477
Epoch 22/30


100%|██████████| 8/8 [00:03<00:00,  2.19it/s]


train Loss: 0.0857 Acc: 0.9918


100%|██████████| 5/5 [00:02<00:00,  2.14it/s]


val Loss: 0.1219 Acc: 0.9608
Epoch 23/30


100%|██████████| 8/8 [00:03<00:00,  2.30it/s]


train Loss: 0.0910 Acc: 0.9794


100%|██████████| 5/5 [00:02<00:00,  1.98it/s]


val Loss: 0.1184 Acc: 0.9608
Epoch 24/30


 62%|██████▎   | 5/8 [00:02<00:01,  2.16it/s]

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 4), tight_layout=True)
for model in resnet_results.keys():
    # for phase in resnet_results[model].keys():
    for phase in ["val"]:
        axes[0].plot(
            range(len(resnet_results[model][phase]["loss"])),
            resnet_results[model][phase]["loss"],
            # plot_style[phase],
            label=f"{model}-{phase}",
        )
axes[0].set_xlabel("Epochs", size=14)
axes[0].set_ylabel("Loss", size=14)
axes[0].tick_params(labelsize=12)
axes[0].grid()
axes[0].legend()

for model in resnet_results.keys():
    # for phase in resnet_results[model].keys():
    for phase in ["val"]:
        axes[1].plot(
            range(len(resnet_results[model][phase]["acc"])),
            resnet_results[model][phase]["acc"],
            # plot_style[phase],
            label=f"{model}-{phase}",
        )
axes[1].set_xlabel("Epochs", size=14)
axes[1].set_ylabel("Accuracy", size=14)
axes[1].tick_params(labelsize=12)
axes[1].grid()
axes[1].legend()
plt.show()

# 学習したネットワークを保存・ロード

In [None]:
# # PyTorchのネットワークパラメータの保存
# save_path = "./weights_fine_tuning.pth"
# torch.save(net.state_dict(), save_path)

In [None]:
# # PyTorchのネットワークパラメータのロード
# load_path = "./weights_fine_tuning.pth"
# load_weights = torch.load(load_path)
# net.load_state_dict(load_weights)

# # GPU上で保存された重みをCPU上でロードする場合
# load_weights = torch.load(load_path, map_location={"cuda:0": "cpu"})
# net.load_state_dict(load_weights)

以上