In [7]:
import os
import time
import copy
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from tqdm.auto import tqdm

# Reproducibility: seed torch's RNGs so data shuffling and the random weight
# initialisation below are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

# Set device to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load dataset. The root can be overridden via the BREAST_DENSITY_DATA_DIR
# environment variable instead of editing this hard-coded, user-specific path.
data_dir = os.environ.get(
    "BREAST_DENSITY_DATA_DIR",
    "/home/kevinluo/breast_density_classification/datasets",
)
SPLITS = ["train", "valid", "test"]
IMAGE_SIZE = (960, 960)
BATCH_SIZE = 8
NUM_WORKERS = 4

# NOTE(review): these are the ImageNet channel statistics; with a randomly
# initialised backbone (see below) they are only a generic default.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), transform=transform)
    for x in SPLITS
}

# Only the training split needs shuffling; accuracy and the confusion matrix
# do not depend on evaluation order, and a fixed order keeps valid/test
# passes reproducible.
dataloaders = {
    x: DataLoader(
        image_datasets[x],
        batch_size=BATCH_SIZE,
        shuffle=(x == "train"),
        num_workers=NUM_WORKERS,
    )
    for x in SPLITS
}

dataset_sizes = {x: len(image_datasets[x]) for x in SPLITS}
class_names = image_datasets["train"].classes

# ImageFolder derives label indices from each split's sub-folder names, so all
# splits must contain the same class folders or labels silently disagree.
for split in SPLITS:
    assert image_datasets[split].classes == class_names, (
        "class folders of '{}' differ from 'train'".format(split)
    )

# ResNet-101 with randomly initialised weights (pretrained=False): the network
# is trained from scratch.
# NOTE(review): the original comment said "pre-trained" and the transform uses
# ImageNet statistics -- confirm whether ImageNet weights were intended.
model = models.resnet101(pretrained=False)
num_ftrs = model.fc.in_features

# Replace the default classification head with one output per dataset class.
model.fc = nn.Linear(num_ftrs, len(class_names))
model = model.to(device)

# Set the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


### Training with MLflow tracking

In [None]:
import mlflow
import mlflow.pytorch

def train_model(model, criterion, optimizer, num_epochs):
    """Train ``model`` with a train/valid loop and track the run in MLflow.

    Each epoch runs a training pass over ``dataloaders["train"]`` followed by an
    evaluation pass over ``dataloaders["valid"]``. It reads the notebook globals
    ``dataloaders``, ``dataset_sizes`` and ``device``. Per-phase loss and
    accuracy are printed and logged to MLflow with ``step=epoch``. The weights
    with the best validation accuracy are restored into ``model`` and logged as
    the ``best_model`` artifact.

    Args:
        model: network to train; modified in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train+valid epochs to run.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    # Initialise the MLflow experiment.
    mlflow.set_experiment("Resnet101model_baseline")

    # Start MLflow tracking; the run is closed when the `with` block exits.
    with mlflow.start_run():
        mlflow.log_param("num_epochs", num_epochs)

        for epoch in range(num_epochs):
            print("Epoch {}/{}".format(epoch + 1, num_epochs))
            print("-" * 10)

            for phase in ["train", "valid"]:
                if phase == "train":
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0

                for inputs, labels in tqdm(dataloaders[phase]):
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    # Track gradients only in the training phase.
                    with torch.set_grad_enabled(phase == "train"):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if phase == "train":
                            loss.backward()
                            optimizer.step()

                    # Assumes a mean-reduced loss (nn.CrossEntropyLoss default);
                    # weighting by batch size makes the epoch loss a per-sample mean.
                    running_loss += loss.item() * inputs.size(0)
                    # .item() keeps the counter a plain int, so the accuracy below is
                    # a Python float that MLflow accepts (not a device tensor).
                    running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

                # Log MLflow metrics with an explicit step so each epoch gets its
                # own point in the metric history.
                mlflow.log_metric("{}_loss".format(phase), epoch_loss, step=epoch)
                mlflow.log_metric("{}_acc".format(phase), epoch_acc, step=epoch)

                if phase == "valid" and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
        print("Best validation accuracy: {:.4f}".format(best_acc))

        model.load_state_dict(best_model_wts)
        mlflow.log_metric("best_valid_acc", best_acc)

        # Save the best model as an MLflow artifact.
        mlflow.pytorch.log_model(model, "best_model")

    return model


trained_model = train_model(model, criterion, optimizer, num_epochs=200)


In [8]:
def train_model(model, criterion, optimizer, num_epochs):
    """Train ``model`` with a train/valid loop, without experiment tracking.

    NOTE(review): this redefinition shadows the MLflow-tracked ``train_model``
    from the previous cell; consider keeping a single definition.

    Each epoch runs a training pass over ``dataloaders["train"]`` followed by an
    evaluation pass over ``dataloaders["valid"]``. It reads the notebook globals
    ``dataloaders``, ``dataset_sizes`` and ``device``. The weights with the best
    validation accuracy are restored into ``model`` before returning.

    Args:
        model: network to train; modified in place.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train+valid epochs to run.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch + 1, num_epochs))
        print("-" * 10)

        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only in the training phase.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Assumes a mean-reduced loss (nn.CrossEntropyLoss default);
                # weighting by batch size makes the epoch loss a per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                # .item() keeps the counter (and the accuracy) plain Python numbers.
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            if phase == "valid" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best validation accuracy: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)
    return model


# NOTE(review): this trains the same global `model`/`optimizer` as the MLflow
# cell; if that cell was run first, training continues from its weights.
trained_model = train_model(model, criterion, optimizer, num_epochs=200)


Epoch 1/200
----------


100%|██████████| 526/526 [08:05<00:00,  1.08it/s]


train Loss: 1.4048 Acc: 0.3928


100%|██████████| 54/54 [00:18<00:00,  2.99it/s]


valid Loss: 1.1398 Acc: 0.5163
Epoch 2/200
----------


100%|██████████| 526/526 [08:48<00:00,  1.00s/it]


train Loss: 1.2216 Acc: 0.4601


100%|██████████| 54/54 [00:18<00:00,  2.93it/s]


valid Loss: 1.2776 Acc: 0.4558
Epoch 3/200
----------


100%|██████████| 526/526 [08:55<00:00,  1.02s/it]


train Loss: 1.1502 Acc: 0.4826


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 1.1734 Acc: 0.5279
Epoch 4/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 1.1242 Acc: 0.5005


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 1.0855 Acc: 0.5605
Epoch 5/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 1.0992 Acc: 0.4969


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 1.2186 Acc: 0.5209
Epoch 6/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 1.0918 Acc: 0.5131


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 1.3713 Acc: 0.4977
Epoch 7/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 1.0620 Acc: 0.5228


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 1.6137 Acc: 0.4930
Epoch 8/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 1.0411 Acc: 0.5338


100%|██████████| 54/54 [00:19<00:00,  2.83it/s]


valid Loss: 1.0644 Acc: 0.5791
Epoch 9/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.02s/it]


train Loss: 0.9929 Acc: 0.5495


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 1.2520 Acc: 0.5000
Epoch 10/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.9931 Acc: 0.5514


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 1.2164 Acc: 0.4698
Epoch 11/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.9834 Acc: 0.5554


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 1.1222 Acc: 0.5558
Epoch 12/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.9888 Acc: 0.5516


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 1.1649 Acc: 0.5674
Epoch 13/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.9608 Acc: 0.5694


100%|██████████| 54/54 [00:19<00:00,  2.83it/s]


valid Loss: 1.2216 Acc: 0.4977
Epoch 14/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.9459 Acc: 0.5854


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 0.9774 Acc: 0.5628
Epoch 15/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.9222 Acc: 0.5823


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 3.3131 Acc: 0.2581
Epoch 16/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.9059 Acc: 0.5987


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 1.7023 Acc: 0.4721
Epoch 17/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.8844 Acc: 0.6063


100%|██████████| 54/54 [00:19<00:00,  2.77it/s]


valid Loss: 3.5245 Acc: 0.4023
Epoch 18/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.8804 Acc: 0.6165


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 4.8170 Acc: 0.2256
Epoch 19/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.8556 Acc: 0.6215


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 4.1026 Acc: 0.4233
Epoch 20/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.8489 Acc: 0.6293


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 0.9817 Acc: 0.5907
Epoch 21/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.8427 Acc: 0.6282


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 1.0112 Acc: 0.6256
Epoch 22/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.8229 Acc: 0.6391


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 3.4077 Acc: 0.4209
Epoch 23/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.02s/it]


train Loss: 0.8474 Acc: 0.6343


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 1.4749 Acc: 0.4721
Epoch 24/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.8165 Acc: 0.6434


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 3.0493 Acc: 0.4372
Epoch 25/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.8112 Acc: 0.6481


100%|██████████| 54/54 [00:19<00:00,  2.83it/s]


valid Loss: 1.9031 Acc: 0.3326
Epoch 26/200
----------


100%|██████████| 526/526 [08:56<00:00,  1.02s/it]


train Loss: 0.7915 Acc: 0.6643


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 1.6016 Acc: 0.3419
Epoch 27/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.7962 Acc: 0.6641


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 2.3244 Acc: 0.4512
Epoch 28/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.7676 Acc: 0.6757


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 3.1319 Acc: 0.2698
Epoch 29/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.7724 Acc: 0.6695


100%|██████████| 54/54 [00:18<00:00,  2.91it/s]


valid Loss: 3.6645 Acc: 0.2628
Epoch 30/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.7639 Acc: 0.6726


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 2.3686 Acc: 0.4395
Epoch 31/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.7297 Acc: 0.6852


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 3.3512 Acc: 0.4674
Epoch 32/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.7195 Acc: 0.6857


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 2.3463 Acc: 0.3721
Epoch 33/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.7077 Acc: 0.6866


100%|██████████| 54/54 [00:18<00:00,  2.89it/s]


valid Loss: 1.5959 Acc: 0.4791
Epoch 34/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.7299 Acc: 0.6874


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 2.9491 Acc: 0.4674
Epoch 35/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.7167 Acc: 0.6988


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 2.2432 Acc: 0.5000
Epoch 36/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.6748 Acc: 0.7204


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 2.5342 Acc: 0.2930
Epoch 37/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.6738 Acc: 0.7118


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 5.9969 Acc: 0.2558
Epoch 38/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.6348 Acc: 0.7275


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 1.8416 Acc: 0.4744
Epoch 39/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.02s/it]


train Loss: 0.6123 Acc: 0.7463


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 2.0792 Acc: 0.5256
Epoch 40/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.6099 Acc: 0.7451


100%|██████████| 54/54 [00:19<00:00,  2.81it/s]


valid Loss: 1.1041 Acc: 0.6070
Epoch 41/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.5987 Acc: 0.7506


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 7.2978 Acc: 0.2651
Epoch 42/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.5900 Acc: 0.7568


100%|██████████| 54/54 [00:19<00:00,  2.83it/s]


valid Loss: 2.9097 Acc: 0.5279
Epoch 43/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.5780 Acc: 0.7513


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 3.0911 Acc: 0.4977
Epoch 44/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.5403 Acc: 0.7765


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 2.7804 Acc: 0.4070
Epoch 45/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.5147 Acc: 0.7870


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 2.2429 Acc: 0.4884
Epoch 46/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.5025 Acc: 0.7991


100%|██████████| 54/54 [00:19<00:00,  2.70it/s]


valid Loss: 1.9084 Acc: 0.5186
Epoch 47/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.4717 Acc: 0.8103


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 4.1657 Acc: 0.3093
Epoch 48/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.4962 Acc: 0.7991


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 3.1281 Acc: 0.4000
Epoch 49/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.4815 Acc: 0.8086


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 1.9517 Acc: 0.5116
Epoch 50/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.4435 Acc: 0.8224


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 4.3445 Acc: 0.3093
Epoch 51/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.4167 Acc: 0.8331


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 9.6912 Acc: 0.4163
Epoch 52/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.3893 Acc: 0.8459


100%|██████████| 54/54 [00:19<00:00,  2.83it/s]


valid Loss: 11.5999 Acc: 0.2395
Epoch 53/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.3700 Acc: 0.8497


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 2.3419 Acc: 0.4837
Epoch 54/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.3620 Acc: 0.8688


100%|██████████| 54/54 [00:18<00:00,  2.91it/s]


valid Loss: 13.3437 Acc: 0.3744
Epoch 55/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.3332 Acc: 0.8733


100%|██████████| 54/54 [00:18<00:00,  2.89it/s]


valid Loss: 10.2615 Acc: 0.3744
Epoch 56/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.3057 Acc: 0.8785


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 16.0229 Acc: 0.2209
Epoch 57/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.3098 Acc: 0.8761


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 1.6720 Acc: 0.5651
Epoch 58/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.2681 Acc: 0.8992


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 2.0590 Acc: 0.5209
Epoch 59/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.2695 Acc: 0.8990


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 11.8953 Acc: 0.1953
Epoch 60/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.2503 Acc: 0.9073


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 2.0305 Acc: 0.4977
Epoch 61/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.2422 Acc: 0.9085


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 12.9644 Acc: 0.1977
Epoch 62/200
----------


100%|██████████| 526/526 [08:56<00:00,  1.02s/it]


train Loss: 0.2659 Acc: 0.8990


100%|██████████| 54/54 [00:19<00:00,  2.80it/s]


valid Loss: 1.6763 Acc: 0.5860
Epoch 63/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.2399 Acc: 0.9177


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 1.9064 Acc: 0.5628
Epoch 64/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.1906 Acc: 0.9284


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 2.5902 Acc: 0.5302
Epoch 65/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.1819 Acc: 0.9353


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 5.3097 Acc: 0.2953
Epoch 66/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.1293 Acc: 0.9517


100%|██████████| 54/54 [00:18<00:00,  2.88it/s]


valid Loss: 2.7129 Acc: 0.5279
Epoch 67/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.1832 Acc: 0.9344


100%|██████████| 54/54 [00:19<00:00,  2.81it/s]


valid Loss: 18.7246 Acc: 0.2163
Epoch 68/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.02s/it]


train Loss: 0.2254 Acc: 0.9208


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 3.3761 Acc: 0.3977
Epoch 69/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.1256 Acc: 0.9574


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 1.8863 Acc: 0.5674
Epoch 70/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.1099 Acc: 0.9631


100%|██████████| 54/54 [00:18<00:00,  2.89it/s]


valid Loss: 7.8868 Acc: 0.3558
Epoch 71/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.1175 Acc: 0.9579


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 8.4634 Acc: 0.3535
Epoch 72/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.1013 Acc: 0.9662


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 15.7285 Acc: 0.3209
Epoch 73/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.1204 Acc: 0.9555


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 1.9662 Acc: 0.5558
Epoch 74/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.1111 Acc: 0.9603


100%|██████████| 54/54 [00:18<00:00,  2.87it/s]


valid Loss: 7.4035 Acc: 0.2837
Epoch 75/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.02s/it]


train Loss: 0.0939 Acc: 0.9700


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 7.7326 Acc: 0.3209
Epoch 76/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0779 Acc: 0.9750


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 4.4956 Acc: 0.3930
Epoch 77/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0822 Acc: 0.9710


100%|██████████| 54/54 [00:19<00:00,  2.83it/s]


valid Loss: 4.0316 Acc: 0.3930
Epoch 78/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.0560 Acc: 0.9798


100%|██████████| 54/54 [00:18<00:00,  2.84it/s]


valid Loss: 8.2279 Acc: 0.2535
Epoch 79/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.0700 Acc: 0.9779


100%|██████████| 54/54 [00:19<00:00,  2.80it/s]


valid Loss: 2.6783 Acc: 0.5558
Epoch 80/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0694 Acc: 0.9753


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 14.3784 Acc: 0.4326
Epoch 81/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.0488 Acc: 0.9834


100%|██████████| 54/54 [00:19<00:00,  2.82it/s]


valid Loss: 2.3628 Acc: 0.6186
Epoch 82/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0276 Acc: 0.9903


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 2.9898 Acc: 0.5395
Epoch 83/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0352 Acc: 0.9893


100%|██████████| 54/54 [00:19<00:00,  2.80it/s]


valid Loss: 3.4440 Acc: 0.4419
Epoch 84/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.0158 Acc: 0.9950


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 14.0772 Acc: 0.3302
Epoch 85/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.0221 Acc: 0.9950


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 8.3880 Acc: 0.3093
Epoch 86/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.0389 Acc: 0.9872


100%|██████████| 54/54 [00:19<00:00,  2.80it/s]


valid Loss: 18.3479 Acc: 0.2721
Epoch 87/200
----------


100%|██████████| 526/526 [08:56<00:00,  1.02s/it]


train Loss: 0.0374 Acc: 0.9883


100%|██████████| 54/54 [00:18<00:00,  2.86it/s]


valid Loss: 3.9346 Acc: 0.4372
Epoch 88/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0211 Acc: 0.9929


100%|██████████| 54/54 [00:19<00:00,  2.78it/s]


valid Loss: 3.9824 Acc: 0.4372
Epoch 89/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.0158 Acc: 0.9948


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 2.1633 Acc: 0.6163
Epoch 90/200
----------


100%|██████████| 526/526 [08:57<00:00,  1.02s/it]


train Loss: 0.0146 Acc: 0.9960


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 2.0207 Acc: 0.6116
Epoch 91/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0162 Acc: 0.9955


100%|██████████| 54/54 [00:18<00:00,  2.89it/s]


valid Loss: 4.6763 Acc: 0.3884
Epoch 92/200
----------


100%|██████████| 526/526 [08:59<00:00,  1.03s/it]


train Loss: 0.0194 Acc: 0.9950


100%|██████████| 54/54 [00:19<00:00,  2.84it/s]


valid Loss: 17.8778 Acc: 0.3209
Epoch 93/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.0346 Acc: 0.9886


100%|██████████| 54/54 [00:18<00:00,  2.91it/s]


valid Loss: 12.2102 Acc: 0.3721
Epoch 94/200
----------


100%|██████████| 526/526 [09:00<00:00,  1.03s/it]


train Loss: 0.0619 Acc: 0.9793


100%|██████████| 54/54 [00:18<00:00,  2.85it/s]


valid Loss: 2.0998 Acc: 0.5837
Epoch 95/200
----------


100%|██████████| 526/526 [08:58<00:00,  1.02s/it]


train Loss: 0.0461 Acc: 0.9843


100%|██████████| 54/54 [00:18<00:00,  2.90it/s]


valid Loss: 3.9038 Acc: 0.5209
Epoch 96/200
----------


 95%|█████████▌| 502/526 [08:33<00:24,  1.02s/it]

### Test the model on the test set

In [None]:
# Evaluate the best-validation model on the held-out test split.
# Set eval mode explicitly rather than relying on the mode training left behind
# (e.g. after an interrupted training run).
trained_model.eval()

test_corrects = 0
predictions = []
ground_truth = []

with torch.no_grad():
    for inputs, labels in tqdm(dataloaders["test"]):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = trained_model(inputs)
        _, preds = torch.max(outputs, 1)

        # Plain int counter: works even if the test loader yields no batches.
        test_corrects += (preds == labels).sum().item()
        predictions.extend(preds.cpu().numpy())
        ground_truth.extend(labels.cpu().numpy())

test_acc = test_corrects / dataset_sizes["test"]
print("Test accuracy: {:.4f}".format(test_acc))

### Metrics and confusion-matrix plot

In [None]:
# Pin the label set to every class index so the matrix and report always have
# one row/column per class in `class_names`, even when a class never occurs in
# this test run. Without `labels=`, classification_report raises ValueError if
# the observed class count differs from len(target_names).
label_ids = list(range(len(class_names)))

cm = confusion_matrix(ground_truth, predictions, labels=label_ids)
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    square=True,
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")
ax.set_title("Confusion Matrix")
plt.show()

print(classification_report(ground_truth, predictions, labels=label_ids, target_names=class_names))