In [1]:
import torchvision
import torch
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Root of the local project checkout; the three artefact paths below all hang off it.
_PROJECT_ROOT = r"C:\Users\offco\Documents\Dev_Projects\BookCoverClassifier"

# Image folder handed to torchvision's ImageFolder.
DATA_PATH = _PROJECT_ROOT + r"\Datasets\Augment"
# Where the best-scoring training checkpoint is written and later read back.
CHECKPOINT_PATH = _PROJECT_ROOT + r"\Models\PyTorch ResNet50 Model\model_best_checkpoint.pth.tar"
# Destination of the ONNX export of the trained network.
ONNX_SAVE_PATH = _PROJECT_ROOT + r"\Models\Quantized Model\pytorch_resnet50.onnx"

# Size of the classifier head, and the number of passes over the training split.
NUM_CLASSES, TRAIN_EPOCHS = 199, 20

In [3]:
# Per-channel statistics used to normalise the ToTensor output.
# NOTE(review): presumably measured on this dataset -- confirm where they came from.
mean = [0.3882, 0.3525, 0.3215]
std = [0.3120, 0.2901, 0.2704]

# Fixed seed so the train/test partition is the same on every kernel restart.
# Without it, re-running this cell reshuffles the split and a previously saved
# checkpoint would be evaluated on images it was trained on.
SPLIT_SEED = 42

# Convert each image to a float tensor, then standardise every channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
])

data = torchvision.datasets.ImageFolder(root=DATA_PATH, transform=transform)

# 85 % of the samples go to training; the remainder is held out for evaluation.
# NOTE(review): DATA_PATH points at an "Augment" folder. If it contains augmented
# copies of the same covers, a random per-image split can put variants of one
# source image on both sides of the split -- confirm how the folder was built.
train_size = int(0.85 * len(data))
test_size = len(data) - train_size
train_data, test_data = torch.utils.data.random_split(
    data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [4]:
# Settings common to both loaders; only the shuffling differs between them.
_loader_kwargs = dict(batch_size=128, num_workers=4, pin_memory=True)

# Training batches are reshuffled every epoch; evaluation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_kwargs)

In [5]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

In [6]:
def set_device():
    """Return the device to compute on: the first CUDA GPU when one is available, else the CPU."""
    return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def evaluate_model(model, test_loader):
    """Compute the top-1 accuracy of ``model`` over every batch of ``test_loader``.

    The model is switched to evaluation mode and is left in that mode on return.
    A one-line summary is printed.

    Args:
        model: Network to evaluate; called as ``model(images)`` and expected to
            return one row of class scores per sample.
        test_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` when the loader
        yields no samples.
    """
    model.eval()

    # Follow the model's own device instead of guessing from CUDA availability,
    # so a CPU-resident model still works on a machine that has a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else set_device()

    predicted_correctly_on_epoch = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            # Highest-scoring class per sample.
            predicted = model(images).argmax(dim=1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()

    # Guard against an empty loader rather than dividing by zero.
    epoch_acc = 100.0 * (predicted_correctly_on_epoch / total) if total else 0.0
    print(f"Testing dataset: {int(predicted_correctly_on_epoch)} out of {total} correct. Accuracy: {round(epoch_acc, 3)}")

    return epoch_acc

def save_checkpoint(model, epoch, optimizer, best_acc):
    """Persist the current training state to ``CHECKPOINT_PATH``.

    The file is written to a temporary sibling first and then moved into
    place, so an interrupted save cannot destroy the previous checkpoint.

    Args:
        model: Network whose ``state_dict()`` is stored under ``'model'``.
        epoch: Zero-based epoch index; stored one-based under ``'epoch'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        best_acc: Accuracy value stored under ``'best accuracy'``.
    """
    state = {
        'epoch': epoch+1,
        'model': model.state_dict(),
        'best accuracy': best_acc,
        'optimizer': optimizer.state_dict()
    }

    # Create the target folder on first use so a long run does not fail at save time.
    checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Write next to the destination, then swap the finished file into place.
    tmp_path = CHECKPOINT_PATH + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, CHECKPOINT_PATH)

In [7]:
def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train ``model`` for ``n_epochs`` epochs, evaluating after each one.

    After every epoch the model is scored with ``evaluate_model`` on
    ``test_loader``; whenever that score beats the best seen so far, the state
    is written out through ``save_checkpoint``. Progress is printed per epoch.

    Args:
        model: Network to optimise; assumed to already be on its target device.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        test_loader: Iterable of batches passed to ``evaluate_model``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the parameters of ``model``.
        n_epochs: Number of passes over ``train_loader``.
    """
    # Follow the model's own device instead of guessing from CUDA availability.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else set_device()
    best_acc = 0

    for epoch in tqdm(range(n_epochs)):
        print(f"Epoch: {epoch+1}")
        model.train()  # evaluate_model leaves the network in eval mode
        running_loss = 0.0
        running_correct = 0
        total = 0
        num_batches = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)
            # Counted by hand so loaders without __len__ work too.
            num_batches += 1

            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            running_loss += loss.item()
            # Highest-scoring class per sample, detached from the autograd graph.
            predicted = outputs.detach().argmax(dim=1)
            running_correct += (labels == predicted).sum().item()

        # Mean of the per-batch losses; guarded so an empty loader cannot divide by zero.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = 100.00 * (running_correct / total) if total else 0.0

        print(f"Training dataset: {int(running_correct)} out of {total} correct. Accuracy: {round(epoch_acc, 3)}. Loss: {round(epoch_loss, 3)}")

        test_acc = evaluate_model(model, test_loader)

        # Keep only the best-scoring state on disk.
        if test_acc > best_acc:
            best_acc = test_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")

In [8]:
# Resolve the compute device up front so the model can be moved in one go.
device = set_device()

# ResNet-50 built without pretrained weights; its final fully connected layer
# is swapped for one that outputs NUM_CLASSES scores.
# NOTE(review): newer torchvision releases prefer `weights=None` over
# `pretrained=False` -- confirm the installed version before switching.
resnet50 = models.resnet50(pretrained=False)
num_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(in_features=num_features, out_features=NUM_CLASSES)
resnet50 = resnet50.to(device)

# Cross-entropy objective, optimised by SGD with momentum and weight decay.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=resnet50.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.003,
)

In [9]:
# Full training run. The recorded output below shows about 27 hours for 20
# epochs; the best-scoring epoch is checkpointed to CHECKPOINT_PATH.
train_nn(
    model=resnet50,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=loss_func,
    optimizer=optimizer,
    n_epochs=TRAIN_EPOCHS,
)

  0%|                                                                                           | 0/20 [00:00<?, ?it/s]

Epoch: 1
Training dataset: 662632 out of 1493921 correct. Accuracy: 44.355. Loss: 2.671
Testing dataset: 193069 out of 263634 correct. Accuracy: 73.234


  5%|███▊                                                                        | 1/20 [1:22:28<26:06:54, 4948.15s/it]

Epoch: 2
Training dataset: 1283963 out of 1493921 correct. Accuracy: 85.946. Loss: 0.715
Testing dataset: 215398 out of 263634 correct. Accuracy: 81.703


 10%|███████▌                                                                    | 2/20 [2:44:28<24:39:27, 4931.51s/it]

Epoch: 3
Training dataset: 1332505 out of 1493921 correct. Accuracy: 89.195. Loss: 0.579
Testing dataset: 235399 out of 263634 correct. Accuracy: 89.29


 15%|███████████▍                                                                | 3/20 [4:06:23<23:15:12, 4924.27s/it]

Epoch: 4
Training dataset: 1336228 out of 1493921 correct. Accuracy: 89.444. Loss: 0.562
Testing dataset: 244562 out of 263634 correct. Accuracy: 92.766


 20%|███████████████▏                                                            | 4/20 [5:28:08<21:51:08, 4916.76s/it]

Epoch: 5
Training dataset: 1335845 out of 1493921 correct. Accuracy: 89.419. Loss: 0.557


 25%|███████████████████                                                         | 5/20 [6:50:50<20:33:13, 4932.88s/it]

Testing dataset: 201014 out of 263634 correct. Accuracy: 76.247
Epoch: 6
Training dataset: 1341500 out of 1493921 correct. Accuracy: 89.797. Loss: 0.532


 30%|██████████████████████▊                                                     | 6/20 [8:11:53<19:05:28, 4909.20s/it]

Testing dataset: 190004 out of 263634 correct. Accuracy: 72.071
Epoch: 7
Training dataset: 1345356 out of 1493921 correct. Accuracy: 90.055. Loss: 0.517


 35%|██████████████████████████▌                                                 | 7/20 [9:32:49<17:39:52, 4891.76s/it]

Testing dataset: 203986 out of 263634 correct. Accuracy: 77.375
Epoch: 8
Training dataset: 1345070 out of 1493921 correct. Accuracy: 90.036. Loss: 0.517


 40%|██████████████████████████████                                             | 8/20 [10:53:50<16:16:24, 4882.02s/it]

Testing dataset: 233718 out of 263634 correct. Accuracy: 88.652
Epoch: 9
Training dataset: 1349249 out of 1493921 correct. Accuracy: 90.316. Loss: 0.502


 45%|█████████████████████████████████▊                                         | 9/20 [12:14:44<14:53:25, 4873.24s/it]

Testing dataset: 230158 out of 263634 correct. Accuracy: 87.302
Epoch: 10
Training dataset: 1350703 out of 1493921 correct. Accuracy: 90.413. Loss: 0.497


 50%|█████████████████████████████████████                                     | 10/20 [13:35:45<13:31:35, 4869.51s/it]

Testing dataset: 205953 out of 263634 correct. Accuracy: 78.121
Epoch: 11
Training dataset: 1350025 out of 1493921 correct. Accuracy: 90.368. Loss: 0.498


 55%|████████████████████████████████████████▋                                 | 11/20 [14:56:47<12:10:05, 4867.28s/it]

Testing dataset: 227891 out of 263634 correct. Accuracy: 86.442
Epoch: 12
Training dataset: 1348758 out of 1493921 correct. Accuracy: 90.283. Loss: 0.502


 60%|████████████████████████████████████████████▍                             | 12/20 [16:17:43<10:48:30, 4863.84s/it]

Testing dataset: 237703 out of 263634 correct. Accuracy: 90.164
Epoch: 13
Training dataset: 1351146 out of 1493921 correct. Accuracy: 90.443. Loss: 0.495


 65%|████████████████████████████████████████████████▊                          | 13/20 [17:38:44<9:27:20, 4862.89s/it]

Testing dataset: 224057 out of 263634 correct. Accuracy: 84.988
Epoch: 14
Training dataset: 1349979 out of 1493921 correct. Accuracy: 90.365. Loss: 0.499


 70%|████████████████████████████████████████████████████▌                      | 14/20 [18:59:51<8:06:24, 4864.14s/it]

Testing dataset: 216127 out of 263634 correct. Accuracy: 81.98
Epoch: 15
Training dataset: 1351964 out of 1493921 correct. Accuracy: 90.498. Loss: 0.491


 75%|████████████████████████████████████████████████████████▎                  | 15/20 [20:20:44<6:45:03, 4860.74s/it]

Testing dataset: 234823 out of 263634 correct. Accuracy: 89.072
Epoch: 16
Training dataset: 1350668 out of 1493921 correct. Accuracy: 90.411. Loss: 0.494


 80%|████████████████████████████████████████████████████████████               | 16/20 [21:41:35<5:23:51, 4857.79s/it]

Testing dataset: 228689 out of 263634 correct. Accuracy: 86.745
Epoch: 17
Training dataset: 1350906 out of 1493921 correct. Accuracy: 90.427. Loss: 0.495


 85%|███████████████████████████████████████████████████████████████▊           | 17/20 [23:02:40<4:02:59, 4859.93s/it]

Testing dataset: 233720 out of 263634 correct. Accuracy: 88.653
Epoch: 18
Training dataset: 1351027 out of 1493921 correct. Accuracy: 90.435. Loss: 0.493


 90%|███████████████████████████████████████████████████████████████████▌       | 18/20 [24:23:41<2:42:00, 4860.44s/it]

Testing dataset: 228331 out of 263634 correct. Accuracy: 86.609
Epoch: 19
Training dataset: 1350808 out of 1493921 correct. Accuracy: 90.42. Loss: 0.495


 95%|███████████████████████████████████████████████████████████████████████▎   | 19/20 [25:44:31<1:20:57, 4857.25s/it]

Testing dataset: 225900 out of 263634 correct. Accuracy: 85.687
Epoch: 20
Training dataset: 1351383 out of 1493921 correct. Accuracy: 90.459. Loss: 0.491


100%|█████████████████████████████████████████████████████████████████████████████| 20/20 [27:05:28<00:00, 4876.44s/it]

Testing dataset: 242646 out of 263634 correct. Accuracy: 92.039
Finished





ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [11]:
# Map every stored tensor to the CPU so the checkpoint also loads on a machine
# without CUDA; the model rebuilt for export lives on the CPU anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints you created yourself.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
# 'epoch' was stored one-based by save_checkpoint.
print(f"Best Epoch: {checkpoint['epoch']}, Best Accuracy: {round(checkpoint['best accuracy'], 3)}")

Best Epoch: 4, Best Accuracy: 92.766


In [12]:
# Rebuild the architecture used for training and restore the best weights.
resnet50 = models.resnet50(pretrained=False)
num_features = resnet50.fc.in_features
resnet50.fc = nn.Linear(num_features, NUM_CLASSES)
resnet50.load_state_dict(checkpoint['model'])

# A freshly constructed module starts in training mode. Switch to inference
# mode explicitly so BatchNorm uses its running statistics in the exported
# graph, instead of relying on the exporter's default.
resnet50.eval()

# Example input for the export: one 3-channel 256x256 image.
# NOTE(review): no dynamic_axes are declared, so the exported graph presumably
# keeps this fixed input shape (batch size 1) -- confirm that is what the
# downstream consumer expects.
dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float32)
torch.onnx.export(
    resnet50, 
    dummy_input, 
    ONNX_SAVE_PATH,
    verbose=False,
    export_params=True,
    opset_version=11,
    input_names = ['image'],
    output_names = ['pred']
)