In [1]:
import torchvision
import torch
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# All paths hang off one project root so the notebook can run on another machine by
# setting the BOOKCOVER_ROOT environment variable instead of editing code.
# Without the variable, the original Windows location is used unchanged.
PROJECT_ROOT = os.environ.get(
    "BOOKCOVER_ROOT", r"C:\Users\offco\Documents\Dev_Projects\BookCoverClassifier"
)
DATA_PATH = os.path.join(PROJECT_ROOT, "Datasets", "Augment")
CHECKPOINT_PATH = os.path.join(PROJECT_ROOT, "Models", "PyTorch ResNet18 Model")
ONNX_SAVE_PATH = os.path.join(PROJECT_ROOT, "Models", "Quantized Model", "pytorch_resnet18.onnx")
NUM_CLASSES = 199  # size of the classifier head; presumably the number of class folders in DATA_PATH
TRAIN_EPOCHS = 150

In [3]:
# Per-channel RGB normalisation statistics (hard-coded; how they were computed is not shown here).
mean = [0.3882, 0.3525, 0.3215]
std = [0.3120, 0.2901, 0.2704]

# No Resize/Crop step: assumes every image in DATA_PATH already has the same size
# (the ONNX export later traces a 256x256 input) — TODO confirm, otherwise batching fails.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

data = torchvision.datasets.ImageFolder(root=DATA_PATH, transform=transform)

# Labels >= NUM_CLASSES would make CrossEntropyLoss fail partway through training; fail fast instead.
if len(data.classes) > NUM_CLASSES:
    raise ValueError(
        f"Dataset has {len(data.classes)} classes but NUM_CLASSES is {NUM_CLASSES}"
    )

# 85/15 train/test split. The fixed seed makes the split reproducible, so the test
# accuracy stored in a checkpoint can be re-measured on exactly the same images.
# NOTE(review): DATA_PATH is an "Augment" folder; if it holds several augmented copies of
# each source cover, a per-image random split puts near-duplicates in both train and test
# and inflates test accuracy — consider splitting by source image instead.
SPLIT_SEED = 42
train_size = int(0.85 * len(data))
test_size = len(data) - train_size
train_data, test_data = torch.utils.data.random_split(
    data, [train_size, test_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [4]:
# Settings shared by both loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_kwargs)

In [5]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

In [6]:
def set_device():
    """Return ``torch.device('cuda:0')`` when CUDA is available, else ``torch.device('cpu')``."""
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

def evaluate_model(model, test_loader, device=None):
    """Compute and print the top-1 accuracy (percent) of ``model`` on ``test_loader``.

    Args:
        model: classifier whose output has one score per class along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device to run on. Defaults to the device of the model's first
            parameter (CPU for a parameter-less model), so inputs always land
            on the same device as the model.

    Returns:
        Accuracy as a float percentage; 0.0 when the loader yields no samples.

    Side effects: prints a summary line and leaves ``model`` in eval mode.
    """
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    predicted_correctly_on_epoch = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            predicted = model(images).argmax(dim=1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    epoch_acc = 100.0 * (predicted_correctly_on_epoch / total) if total else 0.0
    print(f"Testing dataset: {int(predicted_correctly_on_epoch)} out of {total} correct. Accuracy: {round(epoch_acc, 3)}")

    return epoch_acc

def save_checkpoint(model, epoch, optimizer, best_acc, checkpoint_dir=None,
                    filename='model_best_checkpoint.pth.tar'):
    """Save model/optimizer state plus bookkeeping to ``checkpoint_dir/filename``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model'``.
        epoch: 0-based epoch index; stored as ``epoch + 1`` (the 1-based number
            the training loop prints).
        optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        best_acc: accuracy stored under ``'best accuracy'``.
        checkpoint_dir: target directory; defaults to the global ``CHECKPOINT_PATH``.
            Created if missing, so a save after a long epoch doesn't fail.
        filename: checkpoint file name inside ``checkpoint_dir``.

    Returns:
        The full path of the written file.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_PATH

    state = {
        'epoch': epoch+1,
        'model': model.state_dict(),
        'best accuracy': best_acc,
        'optimizer': optimizer.state_dict()
    }

    os.makedirs(checkpoint_dir, exist_ok=True)
    # os.path.join instead of a "\m" f-string: that escape is invalid and Windows-only.
    path = os.path.join(checkpoint_dir, filename)
    torch.save(state, path)
    return path

In [7]:
def _train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``; return ``(correct, total, loss_sum)``.

    ``loss_sum`` accumulates each batch loss multiplied by its batch size, so
    ``loss_sum / total`` is a per-sample mean when the criterion averages over the
    batch (the default for ``nn.CrossEntropyLoss``). The smaller last batch is then
    not over-weighted.
    """
    model.train()
    correct = 0
    total = 0
    loss_sum = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total += batch_size
        loss_sum += loss.item() * batch_size
        correct += (outputs.detach().argmax(dim=1) == labels).sum().item()

    return correct, total, loss_sum


def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs, scheduler=None):
    """Train ``model`` for ``n_epochs``, evaluating and checkpointing after each epoch.

    Each epoch it trains on ``train_loader`` and prints training accuracy and mean
    loss. It then evaluates on ``test_loader`` via ``evaluate_model``. Whenever test
    accuracy beats the best seen so far, it saves a checkpoint via ``save_checkpoint``.

    Args:
        model: module already placed on its target device; inputs are moved to the
            device of its first parameter.
        train_loader / test_loader: iterables of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        n_epochs: number of epochs to run.
        scheduler: optional LR scheduler, stepped once per epoch after evaluation.
            A ``ReduceLROnPlateau`` is stepped with the test accuracy, so build it
            with ``mode='max'``. ``None`` (default) keeps the learning rate constant.

    Returns:
        The same ``model`` object, trained in place.
    """
    device = next(model.parameters()).device
    best_acc = 0

    for epoch in tqdm(range(n_epochs)):
        print(f"Epoch: {epoch+1}")
        running_correct, total, running_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Guard against an empty training loader.
        epoch_loss = running_loss / total if total else 0.0
        epoch_acc = 100.00 * (running_correct / total) if total else 0.0

        print(f"Training dataset: {int(running_correct)} out of {total} correct. Accuracy: {round(epoch_acc, 3)}. Loss: {round(epoch_loss, 3)}")

        test_acc = evaluate_model(model, test_loader)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(test_acc)
            else:
                scheduler.step()

        if test_acc > best_acc:
            best_acc = test_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")
    return model

In [8]:
# Untrained ResNet-18 whose final fully-connected layer is replaced by a NUM_CLASSES-way head.
device = set_device()

resnet18 = models.resnet18(pretrained=False)
num_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(in_features=num_features, out_features=NUM_CLASSES)
resnet18 = resnet18.to(device)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet18.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.003,
)

In [9]:
train_nn(
    model=resnet18,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=loss_func,
    optimizer=optimizer,
    n_epochs=TRAIN_EPOCHS,
)

  0%|                                                                                          | 0/150 [00:00<?, ?it/s]

Epoch: 1
Training dataset: 409349 out of 1493921 correct. Accuracy: 27.401. Loss: 3.527


  1%|▌                                                                           | 1/150 [43:27<107:54:44, 2607.28s/it]

Testing dataset: 102862 out of 263634 correct. Accuracy: 39.017
Epoch: 2
Training dataset: 753682 out of 1493921 correct. Accuracy: 50.45. Loss: 2.285


  1%|▉                                                                         | 2/150 [1:26:32<106:39:33, 2594.41s/it]

Testing dataset: 128570 out of 263634 correct. Accuracy: 48.768
Epoch: 3
Training dataset: 844824 out of 1493921 correct. Accuracy: 56.551. Loss: 2.002


  2%|█▍                                                                        | 3/150 [2:09:33<105:41:28, 2588.35s/it]

Testing dataset: 139384 out of 263634 correct. Accuracy: 52.87
Epoch: 4
Training dataset: 875492 out of 1493921 correct. Accuracy: 58.604. Loss: 1.906


  3%|█▉                                                                        | 4/150 [2:52:35<104:51:46, 2585.66s/it]

Testing dataset: 129796 out of 263634 correct. Accuracy: 49.233
Epoch: 5
Training dataset: 891178 out of 1493921 correct. Accuracy: 59.654. Loss: 1.858


  3%|██▍                                                                       | 5/150 [3:35:42<104:10:05, 2586.24s/it]

Testing dataset: 141636 out of 263634 correct. Accuracy: 53.724
Epoch: 6
Training dataset: 901260 out of 1493921 correct. Accuracy: 60.328. Loss: 1.826


  4%|██▉                                                                       | 6/150 [4:18:48<103:26:48, 2586.17s/it]

Testing dataset: 141664 out of 263634 correct. Accuracy: 53.735
Epoch: 7
Training dataset: 908587 out of 1493921 correct. Accuracy: 60.819. Loss: 1.807


  5%|███▍                                                                      | 7/150 [5:01:46<102:37:31, 2583.57s/it]

Testing dataset: 139803 out of 263634 correct. Accuracy: 53.029
Epoch: 8
Training dataset: 910758 out of 1493921 correct. Accuracy: 60.964. Loss: 1.796


  5%|███▉                                                                      | 8/150 [5:44:45<101:50:48, 2582.03s/it]

Testing dataset: 140745 out of 263634 correct. Accuracy: 53.387
Epoch: 9
Training dataset: 915529 out of 1493921 correct. Accuracy: 61.284. Loss: 1.784


  6%|████▍                                                                     | 9/150 [6:27:45<101:06:02, 2581.29s/it]

Testing dataset: 142029 out of 263634 correct. Accuracy: 53.874
Epoch: 10
Training dataset: 918979 out of 1493921 correct. Accuracy: 61.515. Loss: 1.773


  7%|████▊                                                                    | 10/150 [7:10:49<100:25:28, 2582.35s/it]

Testing dataset: 146721 out of 263634 correct. Accuracy: 55.653
Epoch: 11
Training dataset: 922664 out of 1493921 correct. Accuracy: 61.761. Loss: 1.764


  7%|█████▍                                                                    | 11/150 [7:53:44<99:37:02, 2580.02s/it]

Testing dataset: 152250 out of 263634 correct. Accuracy: 57.751
Epoch: 12
Training dataset: 924478 out of 1493921 correct. Accuracy: 61.883. Loss: 1.758


  8%|█████▉                                                                    | 12/150 [8:36:40<98:51:11, 2578.78s/it]

Testing dataset: 153349 out of 263634 correct. Accuracy: 58.167
Epoch: 13
Training dataset: 926288 out of 1493921 correct. Accuracy: 62.004. Loss: 1.752


  9%|██████▍                                                                   | 13/150 [9:19:36<98:05:50, 2577.74s/it]

Testing dataset: 158366 out of 263634 correct. Accuracy: 60.07
Epoch: 14
Training dataset: 928102 out of 1493921 correct. Accuracy: 62.125. Loss: 1.747


  9%|██████▊                                                                  | 14/150 [10:02:37<97:25:33, 2578.92s/it]

Testing dataset: 150250 out of 263634 correct. Accuracy: 56.992
Epoch: 15
Training dataset: 927897 out of 1493921 correct. Accuracy: 62.112. Loss: 1.747


 10%|███████▎                                                                 | 15/150 [10:45:36<96:42:44, 2579.00s/it]

Testing dataset: 145829 out of 263634 correct. Accuracy: 55.315
Epoch: 16
Training dataset: 928644 out of 1493921 correct. Accuracy: 62.162. Loss: 1.745


 11%|███████▊                                                                 | 16/150 [11:28:39<96:02:00, 2580.00s/it]

Testing dataset: 158180 out of 263634 correct. Accuracy: 60.0
Epoch: 17
Training dataset: 928968 out of 1493921 correct. Accuracy: 62.183. Loss: 1.745


 11%|████████▎                                                                | 17/150 [12:11:39<95:19:15, 2580.12s/it]

Testing dataset: 147949 out of 263634 correct. Accuracy: 56.119
Epoch: 18
Training dataset: 930560 out of 1493921 correct. Accuracy: 62.29. Loss: 1.739


 12%|████████▊                                                                | 18/150 [12:54:41<94:37:08, 2580.52s/it]

Testing dataset: 151370 out of 263634 correct. Accuracy: 57.417
Epoch: 19
Training dataset: 930462 out of 1493921 correct. Accuracy: 62.283. Loss: 1.739


 13%|█████████▏                                                               | 19/150 [13:37:40<93:53:43, 2580.33s/it]

Testing dataset: 142150 out of 263634 correct. Accuracy: 53.919
Epoch: 20
Training dataset: 931299 out of 1493921 correct. Accuracy: 62.339. Loss: 1.738


 13%|█████████▋                                                               | 20/150 [14:20:42<93:11:11, 2580.55s/it]

Testing dataset: 123529 out of 263634 correct. Accuracy: 46.856
Epoch: 21
Training dataset: 932613 out of 1493921 correct. Accuracy: 62.427. Loss: 1.732


 14%|██████████▏                                                              | 21/150 [15:03:42<92:28:22, 2580.64s/it]

Testing dataset: 153231 out of 263634 correct. Accuracy: 58.123
Epoch: 22
Training dataset: 933814 out of 1493921 correct. Accuracy: 62.508. Loss: 1.73


 15%|██████████▋                                                              | 22/150 [15:46:41<91:43:52, 2579.94s/it]

Testing dataset: 155867 out of 263634 correct. Accuracy: 59.122
Epoch: 23
Training dataset: 933548 out of 1493921 correct. Accuracy: 62.49. Loss: 1.73


 15%|███████████▏                                                             | 23/150 [16:29:37<90:58:29, 2578.81s/it]

Testing dataset: 158095 out of 263634 correct. Accuracy: 59.968
Epoch: 24
Training dataset: 933522 out of 1493921 correct. Accuracy: 62.488. Loss: 1.731


 16%|███████████▋                                                             | 24/150 [17:12:32<90:13:25, 2577.82s/it]

Testing dataset: 157275 out of 263634 correct. Accuracy: 59.657
Epoch: 25
Training dataset: 933850 out of 1493921 correct. Accuracy: 62.51. Loss: 1.729


 17%|████████████▏                                                            | 25/150 [17:55:30<89:30:36, 2577.89s/it]

Testing dataset: 140534 out of 263634 correct. Accuracy: 53.306
Epoch: 26
Training dataset: 934105 out of 1493921 correct. Accuracy: 62.527. Loss: 1.728


 17%|████████████▋                                                            | 26/150 [18:38:30<88:48:47, 2578.44s/it]

Testing dataset: 155524 out of 263634 correct. Accuracy: 58.992
Epoch: 27
Training dataset: 934695 out of 1493921 correct. Accuracy: 62.567. Loss: 1.725


 18%|█████████████▏                                                           | 27/150 [19:21:35<88:09:52, 2580.42s/it]

Testing dataset: 150507 out of 263634 correct. Accuracy: 57.089
Epoch: 28
Training dataset: 936125 out of 1493921 correct. Accuracy: 62.662. Loss: 1.723


 19%|█████████████▋                                                           | 28/150 [20:04:33<87:25:22, 2579.69s/it]

Testing dataset: 147022 out of 263634 correct. Accuracy: 55.767
Epoch: 29
Training dataset: 936652 out of 1493921 correct. Accuracy: 62.698. Loss: 1.722


 19%|██████████████                                                           | 29/150 [20:47:25<86:37:24, 2577.23s/it]

Testing dataset: 146153 out of 263634 correct. Accuracy: 55.438
Epoch: 30
Training dataset: 935339 out of 1493921 correct. Accuracy: 62.61. Loss: 1.723


 20%|██████████████▌                                                          | 30/150 [21:30:27<85:57:15, 2578.63s/it]

Testing dataset: 146924 out of 263634 correct. Accuracy: 55.73
Epoch: 31
Training dataset: 935726 out of 1493921 correct. Accuracy: 62.636. Loss: 1.725


 21%|███████████████                                                          | 31/150 [22:13:24<85:13:41, 2578.33s/it]

Testing dataset: 148430 out of 263634 correct. Accuracy: 56.302
Epoch: 32
Training dataset: 936655 out of 1493921 correct. Accuracy: 62.698. Loss: 1.722


 21%|███████████████▌                                                         | 32/150 [22:56:22<84:30:30, 2578.23s/it]

Testing dataset: 147922 out of 263634 correct. Accuracy: 56.109
Epoch: 33
Training dataset: 937433 out of 1493921 correct. Accuracy: 62.75. Loss: 1.72


 22%|████████████████                                                         | 33/150 [23:39:23<83:49:19, 2579.14s/it]

Testing dataset: 148123 out of 263634 correct. Accuracy: 56.185
Epoch: 34


 22%|████████████████                                                         | 33/150 [24:02:34<85:14:34, 2622.86s/it]


KeyboardInterrupt: 

In [10]:
# Load the best checkpoint onto the CPU so this cell also works on a machine without CUDA.
# os.path.join replaces the "\m" f-string, whose escape sequence is invalid and Windows-only.
checkpoint = torch.load(
    os.path.join(CHECKPOINT_PATH, 'model_best_checkpoint.pth.tar'), map_location='cpu'
)
# 'epoch' is stored as the 1-based epoch number printed during training.
print(f"Best Epoch: {checkpoint['epoch']}, Best Accuracy: {round(checkpoint['best accuracy'], 3)}")

Best Epoch: 13, Best Accuracy: 60.07


In [11]:
# Rebuild the training architecture and load the best checkpoint's weights into it.
resnet18 = models.resnet18(pretrained=False)
num_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_features, NUM_CLASSES)
resnet18.load_state_dict(checkpoint['model'])
# Export in inference mode (BatchNorm uses running statistics); made explicit rather
# than relying on the exporter's default mode.
resnet18.eval()

# The graph is traced with a fixed 1x3x256x256 float32 input; assumes the training
# images are 256x256 RGB — TODO confirm.
# NOTE(review): despite the "Quantized Model" folder, this exports a float32 model;
# no quantization happens in this cell.
dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float32)

# Make sure the output folder exists before writing the .onnx file.
onnx_dir = os.path.dirname(ONNX_SAVE_PATH)
if onnx_dir:
    os.makedirs(onnx_dir, exist_ok=True)

torch.onnx.export(
    resnet18, 
    dummy_input, 
    ONNX_SAVE_PATH,
    verbose=False,
    export_params=True,
    opset_version=11,
    input_names = ['image'],
    output_names = ['pred']
)