In [1]:
import os
from os.path import join as fullfile

from matplotlib import pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torchvision.datasets import ImageFolder

from pytorchtools import EarlyStopping

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [2]:
# %% Datasets and transforms

def _grayscale_pipeline():
    """Build a fresh Grayscale -> ToTensor -> Normalize(0, 1) transform pipeline."""
    # NOTE(review): Normalize is given mean 0 and std 1, i.e. (x - 0) / 1, which
    # leaves the tensor values unchanged -- confirm whether dataset statistics
    # were meant to be used here.
    steps = [
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(0, 1),
    ]
    return transforms.Compose(steps)


# One sub-directory of `datasets/` per split.
train_path, val_path, test_path = (
    fullfile('datasets', split) for split in ('train', 'val', 'test')
)

# Training and val/test pipelines are currently the same steps, but each gets
# its own Compose object so they can diverge later (e.g. train-only augmentation).
train_transform = _grayscale_pipeline()
val_test_transform = _grayscale_pipeline()

train_dataset = ImageFolder(train_path, train_transform)
val_dataset = ImageFolder(val_path, val_test_transform)
test_dataset = ImageFolder(test_path, val_test_transform)

In [3]:
# Show the class-name -> integer-label mapping of the training dataset; the
# number of entries is the number of outputs the classifier head needs.
print(train_dataset.class_to_idx)

{'bacterial': 0, 'normal': 1, 'virus': 2}


In [4]:
# %% Model setting

# MobileNetV2 with randomly initialised weights. The default call is used
# instead of `pretrained=False`: newer torchvision releases warn that the
# `pretrained` keyword is deprecated, while the no-argument default yields an
# untrained network on both old and new releases.
net = models.mobilenet_v2()

# The datasets are converted to single-channel grayscale, so the first
# convolution is replaced by one that accepts 1 input channel.
net.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)

# Replace the classification layer with a 3-way output (bacterial / normal /
# virus). The input width is read from the layer being replaced rather than
# hard-coded.
net.classifier[1] = nn.Linear(net.classifier[1].in_features, 3, bias=True)

net = net.to(device)

In [5]:
# %% Loss and optimizer

# Cross-entropy between the network outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters.
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# Alternative optimizer kept for reference:
# optimizer = optim.SGD(net.parameters(), lr = 0.005, momentum=0.9, weight_decay = 0.0001)

# Two schedulers are attached to the same optimizer.
# NOTE(review): the training cell visible in this notebook only steps
# `scheduler_steplr`; `scheduler_reducelr` is constructed but not stepped there.
scheduler_steplr = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.1,
)
scheduler_reducelr = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    patience=3,
    min_lr=1e-5,
)

In [6]:
# %% Train settings

n_epoch = 20
mini_batch_size = 64

# EarlyStopping is handed a checkpoint file under `run/`. Create that directory
# up front so a fresh checkout (Restart & Run All) does not fail when the first
# checkpoint is saved to a directory that does not exist yet.
checkpoint_path = fullfile('run', 'checkpoint.pt')
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# patience=5: number of non-improving validation checks tolerated by the helper.
early_stopping = EarlyStopping(patience=5, verbose=True, path=checkpoint_path)

In [7]:
# %% Data loaders

# Only the training split is reshuffled; validation and test keep dataset order.
# All three loaders share the same mini-batch size.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=mini_batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=mini_batch_size,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=mini_batch_size,
    shuffle=False,
)

In [9]:
def train_on_one_batch(net, data, criterion, device, optimizer=None):
    """Run one optimisation step on a single mini-batch.

    Parameters
    ----------
    net : torch.nn.Module
        Model to update; it is switched to training mode before the forward pass.
    data : tuple
        ``(inputs, labels)`` pair as yielded by the data loader.
    criterion : callable
        Loss function, called as ``criterion(outputs, labels)``.
    device : torch.device
        Device the batch is moved to before the forward pass.
    optimizer : torch.optim.Optimizer, optional
        Optimizer used for ``zero_grad()`` / ``step()``. When omitted, the
        notebook-level ``optimizer`` variable is looked up so that existing
        call sites that do not pass one keep working.

    Returns
    -------
    tuple
        ``(net, loss)``: the same model object and the loss tensor of this
        batch (call ``.item()`` on it to get a Python float).

    Raises
    ------
    NameError
        If no optimizer is passed and no module-level ``optimizer`` exists.
    """
    if optimizer is None:
        # Backward-compatible fallback: the parameter shadows the global name,
        # so the module-level optimizer has to be fetched explicitly.
        try:
            optimizer = globals()["optimizer"]
        except KeyError:
            raise NameError("name 'optimizer' is not defined") from None

    net.train()

    inputs, labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    # Clear gradients left over from the previous step before back-propagating.
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    return net, loss


def validate_on_epoch_end(net, val_loader, device, criterion):
    """Evaluate ``net`` on every batch of ``val_loader`` and return the mean batch loss.

    The model is put into evaluation mode and gradients are disabled for the
    whole pass. The result is the plain (unweighted) ``np.mean`` of the
    per-batch loss values, so every batch counts equally regardless of its size.
    """
    net.eval()
    per_batch = []

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            predictions = net(batch_inputs)
            per_batch.append(criterion(predictions, batch_labels).item())

    return np.mean(per_batch)
        

# Main training loop: up to n_epoch passes over train_loader, each followed by
# a full validation pass, an early-stopping check and a StepLR update.
# NOTE(review): scheduler_reducelr is created in the optimizer cell but is never
# stepped in this loop; only scheduler_steplr changes the learning rate here.
# NOTE(review): in the saved output the EarlyStopping counter runs past its
# patience (6 and 7 out of 5) without "Early stopping" being printed, and the
# checkpoint message precedes the epoch line. The output appears to come from
# an earlier version of this cell -- re-run to refresh it.
for epoch in range(n_epoch):

    # train: one optimizer step per mini-batch (train_on_one_batch uses the
    # notebook-level `optimizer`), collecting each batch's loss as a float
    train_batches_loss = []
    for i, data in enumerate(train_loader, start=0):        
        net, train_batch_loss = train_on_one_batch(net, data, criterion, device)
        train_batches_loss.append(train_batch_loss.item())
    # Unweighted mean over batches (a smaller last batch counts as much as a full one).
    train_loss = np.mean(train_batches_loss)

    # validation: mean per-batch loss over the whole validation loader
    val_loss = validate_on_epoch_end(net, val_loader, device, criterion)

    print(f"epoch = {epoch}, train loss = {train_loss}, validation loss = {val_loss}")

    # Hand the validation loss to the pytorchtools EarlyStopping helper, then
    # stop as soon as it sets its `early_stop` flag.
    # (presumably it also checkpoints the model on improvement -- see the
    # "Saving model ..." message in the output; confirm against pytorchtools)
    early_stopping(val_loss, net)
    if early_stopping.early_stop:
        print("Early stopping")
        break    

    # update learning rate: StepLR is stepped once per completed epoch; this is
    # skipped on the epoch where the loop breaks above
    scheduler_steplr.step()


print("Finished training")

Validation loss decreased (inf --> 1.227626).  Saving model ...
epoch = 0, train loss = 0.8494431268085133, validation loss = 1.2276261821389198
EarlyStopping counter: 1 out of 5
epoch = 1, train loss = 0.7020497918128967, validation loss = 1.5722232311964035
EarlyStopping counter: 2 out of 5
epoch = 2, train loss = 0.6276898655024442, validation loss = 1.6557928621768951
EarlyStopping counter: 3 out of 5
epoch = 3, train loss = 0.5546027530323375, validation loss = 1.816786840558052
EarlyStopping counter: 4 out of 5
epoch = 4, train loss = 0.47412819212133234, validation loss = 2.027424693107605
EarlyStopping counter: 5 out of 5
epoch = 5, train loss = 0.36209767244078894, validation loss = 2.519050285220146
EarlyStopping counter: 6 out of 5
epoch = 6, train loss = 0.2526032342152162, validation loss = 2.448147814720869
EarlyStopping counter: 7 out of 5
epoch = 7, train loss = 0.13625753670930862, validation loss = 1.6542234271764755
Validation loss decreased (1.227626 --> 1.203541). 