In [33]:
# 1
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import time
import argparse
import numpy as np

# from tqdm import tqdm
from tqdm.notebook import tqdm_notebook

matplotlib.style.use("ggplot")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [34]:
# 2
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
val_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

In [36]:
train_dataset = datasets.ImageFolder(
    root=r'C:\chatbot\python\pythorch\data\archive\train',
    transform=train_transform
)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True,
)
val_dataset = datasets.ImageFolder(
    root=r'C:\chatbot\python\pythorch\data\archive\test',
    transform=val_transform
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=32, shuffle=False,
)

In [37]:
# 4
def resnet50(pretrained=True, requires_grad=False):
    model = models.resnet50(progress=True, pretrained=pretrained)
    if not requires_grad:
        for param in model.parameters():
            param.requires_grad = False
    elif requires_grad:
        for param in model.parameters():
            param.requires_grad = True
    model.fc = nn.Linear(2048, 2)
    return model

In [38]:
# 5 learning rate scheduler
class LRScheduler:
    def __init__(self, optimizer, patience=5, min_lr=1e-6, factor=0.5):
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor
        self.lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            patience=self.patience,
            min_lr=self.min_lr,
            factor=self.factor,
            verbose=True,
        )

    def __call__(self, val_loss):
        self.lr_scheduler.step(val_loss)

In [39]:
# 6 early stopping
class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0, path="data/checkpoint.pt"):
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f}")
            print(f"--> {val_loss:.6f}).  Saving model ...")
            torch.save(model.state_dict(), self.path)
            self.val_loss_min = val_loss

In [28]:
# 7 parser
# parser = argparse.ArgumentParser()
# parser.add_argument(
#     "--lr-scheduler", dest="lr_scheduler", action="store_true", default=False
# )
# parser.add_argument(
#     "--early-stopping", dest="early_stopping", action="store_true", default=False
# )
# args = vars(parser.parse_args())

usage: ipykernel_launcher.py [-h] [--lr-scheduler] [--early-stopping]
ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme="hmac-sha256" --Session.key=b"02f3bd53-3e81-45ef-b701-4a0d6cc3cf26" --shell=9002 --transport="tcp" --iopub=9004 --f=c:\Users\admin\AppData\Roaming\jupyter\runtime\kernel-v2-8036ZX39BA86bUsl.json


SystemExit: 2

In [40]:
# 8
print(f"Computation device: {device}\n")
model = models.resnet50(pretrained=True).to(device)
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters")
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters\n")

Computation device: cpu





25,557,032 total parameters
25,557,032 training parameters



In [41]:
# 8.1 lr scheduler
lr = 0.001
epoch_num = 10
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

loss_plot_name = 'loss'
acc_plot_name = 'accuracy'
model_name = 'resnet50'


lr_scheduler_bool = False
early_stopping_bool = False
if lr_scheduler_bool:
    print('info: init lr scheduler')
    lr_scheduler = LRScheduler(optimizer)
    loss_plot_name = 'loss_lr_scheduler'
    acc_plot_name = 'accuracy_lr_scheduler'
    model_name = 'resnet50_lr_scheduler'
if early_stopping_bool:
    print('info: init early stopping')
    early_stopping = EarlyStopping()
    loss_plot_name = 'loss_early_stopping'
    acc_plot_name = 'accuracy_early_stopping'
    model_name = 'resnet50_early_stopping'

In [45]:
# 9 training
def training(model, train_dataloader, train_dataset, optimizer, criterion):
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    total = 0
    prog_bar = tqdm_notebook(enumerate(train_dataloader), total=int(len(train_dataset)/train_dataloader.batch_size))

    for i, data in prog_bar:
        counter += 1
        data, target = data[0].to(device), data[1].to(device)
        total += target.size(0)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        train_running_loss += loss.item()
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == target).sum().item()
        loss.backward()
        optimizer.step()

    train_loss = train_running_loss / counter
    train_accuracy = 100.0 * train_running_correct / total
    return train_loss, train_accuracy

In [46]:
# 10 validate
def validate(model, test_dataloader, val_dataset, criterion):
    print('VALIDATION')
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    counter = 0
    total = 0
    prog_bar = tqdm_notebook(enumerate(test_dataloader), total=int(len(val_dataset)/test_dataloader.batch_size))

    with torch.no_grad():
        for i, data in prog_bar:
            counter += 1
            data, target = data[0].to(device), data[1].to(device)
            total += target.size(0)
            outputs = model(data)
            loss = criterion(outputs, target)

            val_running_loss += loss.item()
            _, preds = torch.max(outputs.data, 1)
            val_running_correct += (preds == target).sum().item()

        val_loss = val_running_loss / counter
        val_accuracy = 100.0 * val_running_correct / total
        return val_loss, val_accuracy

In [50]:
# 11 model training
train_loss, train_accuracy = [], []

val_loss, val_accuracy = [], []

start = time.time()
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_accuracy = training(
        model, train_dataloader, train_dataset, optimizer, criterion
    )
    val_epoch_loss, val_epoch_accuracy = validate(
        model, val_dataloader, val_dataset, criterion
    )
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    if lr_scheduler_bool:
        lr_scheduler(val_epoch_loss)
    if early_stopping_bool:
        early_stopping(val_epoch_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    print(f"Train Loss: {train_epoch_loss:.4f}, Train Acc: {train_epoch_accuracy:.2f}")
    print(f'Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_accuracy:.2f}')
end = time.time()
print(f"Training time: {(end-start)/0:.3f} minutes")

Epoch 1 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.5702, Train Acc: 74.10
Val Loss: 1.3116, Val Acc: 62.20
Epoch 2 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.4310, Train Acc: 81.33
Val Loss: 0.5829, Val Acc: 76.80
Epoch 3 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.3457, Train Acc: 85.34
Val Loss: 0.8282, Val Acc: 68.60
Epoch 4 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.3357, Train Acc: 85.94
Val Loss: 0.6573, Val Acc: 74.20
Epoch 5 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.2454, Train Acc: 90.56
Val Loss: 0.6881, Val Acc: 70.00
Epoch 6 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.2258, Train Acc: 90.96
Val Loss: 1.5646, Val Acc: 63.40
Epoch 7 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.1782, Train Acc: 93.78
Val Loss: 0.6170, Val Acc: 75.20
Epoch 8 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

VALIDATION


  0%|          | 0/15 [00:00<?, ?it/s]

Train Loss: 0.1466, Train Acc: 95.38
Val Loss: 0.6422, Val Acc: 77.60
Epoch 9 of 100


  0%|          | 0/15 [00:00<?, ?it/s]

In [None]:
print('Saving loss and accuracy plots...')
plt.figure(figsize=(10, 7))
plt.plot(train_accuracy, color='green', label='train accuracy')
plt.plot(val_accuracy, color='blue', label='validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.savefig(f"../chap08/img/{acc_plot_name}.png")
plt.show()
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.savefig(f"../chap08/img/{loss_plot_name}.png")
plt.show()

print('Saving model...')
torch.save(model.state_dict(), f"../chap08/img/{model_name}.pth")
print('TRAINING COMPLETE')