In [1]:
import torchvision
import torch
import torchvision.transforms as transforms
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# Notebook configuration: filesystem locations and training hyper-parameters.
# NOTE(review): these are absolute, machine-specific Windows paths; the notebook
# will not run elsewhere without editing them.

# Root folder handed to torchvision's ImageFolder when the dataset is built.
DATA_PATH = r"C:\Users\offco\Documents\Dev_Projects\BookCoverClassifier_Compact\Datasets\Augment"
# File written by save_checkpoint() and read back before the ONNX export.
CHECKPOINT_PATH = r"C:\Users\offco\Documents\Dev_Projects\BookCoverClassifier_Compact\Models\PyTorch ResNet18 Model\model_best_checkpoint.pth.tar"
# Destination file for torch.onnx.export().
ONNX_SAVE_PATH = r"C:\Users\offco\Documents\Dev_Projects\BookCoverClassifier_Compact\Models\Quantized Model\pytorch_resnet18.onnx"
# Output width of the replacement fully connected layer on the ResNet-18.
NUM_CLASSES = 12
# Number of epochs passed to train_nn().
TRAIN_EPOCHS = 30

In [3]:
# Per-channel normalisation statistics applied after ToTensor().
# NOTE(review): presumably measured on this dataset -- confirm the provenance.
mean = [0.3852, 0.3434, 0.3142]
std = [0.3134, 0.2889, 0.2685]

# Convert each image to a float tensor, then standardise every channel.
# Normalize accepts plain sequences, so no tensor wrapping is needed.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# ImageFolder derives the class labels from the sub-directory names of DATA_PATH.
data = torchvision.datasets.ImageFolder(root=DATA_PATH, transform=transform)

# 85 % / 15 % train/test split. The generator is seeded so the partition is the
# same after every kernel restart; without it the split (and therefore every
# reported accuracy and the "best" checkpoint) changes from run to run.
# NOTE(review): DATA_PATH points at a folder named "Augment"; if it holds
# augmented copies of the same covers, a random split can put variants of one
# image into both subsets and inflate the test accuracy -- confirm.
SPLIT_SEED = 42
train_size = int(0.85 * len(data))
test_size = len(data) - train_size
train_data, test_data = torch.utils.data.random_split(
    data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [4]:
# Settings shared by both loaders; only the shuffling differs between them.
_loader_kwargs = dict(batch_size=32, num_workers=4, pin_memory=True)

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_kwargs)

In [5]:
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

In [6]:
def set_device():
    """Choose the compute device for this notebook.

    Returns:
        torch.device: ``cuda:0`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)

def evaluate_model(model, test_loader):
    """Compute the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left in it) and the whole pass runs
    with gradient tracking disabled. A one-line summary is printed.

    Args:
        model: ``torch.nn.Module`` whose output has the class scores along
            dimension 1.
        test_loader: iterable yielding ``(images, labels)`` batches.

    Returns:
        float: accuracy in percent (0-100). ``0.0`` is returned when the loader
        yields no samples instead of raising ``ZeroDivisionError``.
    """
    model.eval()

    # Evaluate on the device the model actually lives on. Asking set_device()
    # unconditionally breaks when a CPU-resident model (e.g. one freshly loaded
    # from a checkpoint) is evaluated on a machine that also has CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else set_device()

    predicted_correctly_on_epoch = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            # Index of the highest score per sample is the predicted class.
            predicted = model(images).argmax(dim=1)
            predicted_correctly_on_epoch += (predicted == labels).sum().item()

    # Guard against an empty loader so the caller gets a number, not a crash.
    epoch_acc = 100.0 * (predicted_correctly_on_epoch / total) if total else 0.0
    print(f"Testing dataset: {int(predicted_correctly_on_epoch)} out of {total} correct. Accuracy: {round(epoch_acc, 3)}")

    return epoch_acc

def save_checkpoint(model, epoch, optimizer, best_acc):
    """Serialise a training snapshot to ``CHECKPOINT_PATH`` with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model'``.
        epoch: zero-based epoch index; stored one-based under ``'epoch'``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        best_acc: accuracy value stored under ``'best accuracy'``.
    """
    completed_epochs = epoch + 1
    model_weights = model.state_dict()
    optimizer_state = optimizer.state_dict()

    torch.save(
        {
            'epoch': completed_epochs,
            'model': model_weights,
            'best accuracy': best_acc,
            'optimizer': optimizer_state,
        },
        CHECKPOINT_PATH,
    )

In [7]:
def train_nn(model, train_loader, test_loader, criterion, optimizer, n_epochs):
    """Train ``model`` for ``n_epochs`` epochs and checkpoint the best result.

    Every epoch performs one optimisation pass over ``train_loader``, prints the
    training accuracy and the mean per-batch loss, then scores the model on
    ``test_loader`` through ``evaluate_model``. Whenever that test accuracy is
    strictly higher than the best seen so far, ``save_checkpoint`` is called.

    Args:
        model: network to optimise; expected to already be on the device
            returned by ``set_device()``.
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` batches for evaluation.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping ``model``'s parameters.
        n_epochs: number of epochs to run.
    """
    device = set_device()
    best_acc = 0

    for epoch in tqdm(range(n_epochs)):
        print(f"Epoch: {epoch+1}")
        model.train()

        # Per-epoch accumulators.
        loss_sum = 0.0
        running_correct = 0.0
        total = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            total += labels.size(0)

            optimizer.zero_grad()

            outputs = model(images)
            # Predicted class = index of the largest score; .data keeps this
            # bookkeeping out of the autograd graph.
            predicted = torch.max(outputs.data, 1)[1]

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            running_correct += (labels == predicted).sum().item()

        # Loss is averaged over batches, accuracy over individual samples.
        epoch_loss = loss_sum / len(train_loader)
        epoch_acc = 100.00 * (running_correct / total)

        print(f"Training dataset: {int(running_correct)} out of {total} correct. Accuracy: {round(epoch_acc, 3)}. Loss: {round(epoch_loss, 3)}")

        test_acc = evaluate_model(model, test_loader)

        # Only a strict improvement replaces the stored checkpoint.
        if test_acc > best_acc:
            best_acc = test_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")

In [8]:
# Pick the compute device up front so the network can be placed on it.
device = set_device()

# ResNet-18 without pretrained weights; its final fully connected layer is
# replaced by one that emits NUM_CLASSES scores.
resnet18 = models.resnet18(pretrained=False)
num_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(in_features=num_features, out_features=NUM_CLASSES)
# nn.Module.to() moves the parameters in place and returns the same module.
resnet18.to(device)

# Cross-entropy loss optimised with SGD (momentum plus L2 weight decay).
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet18.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.003,
)

In [9]:
# Run the full training loop for TRAIN_EPOCHS epochs. train_nn() saves a
# checkpoint to CHECKPOINT_PATH each time the test accuracy beats the best so far.
train_nn(resnet18, train_loader, test_loader, loss_func, optimizer, n_epochs=TRAIN_EPOCHS)

  0%|                                                                                           | 0/30 [00:00<?, ?it/s]

Epoch: 1
Training dataset: 4195 out of 15385 correct. Accuracy: 27.267. Loss: 2.178


  3%|██▊                                                                                | 1/30 [00:43<20:51, 43.16s/it]

Testing dataset: 993 out of 2715 correct. Accuracy: 36.575
Epoch: 2
Training dataset: 6777 out of 15385 correct. Accuracy: 44.049. Loss: 1.64


  7%|█████▌                                                                             | 2/30 [01:22<19:08, 41.04s/it]

Testing dataset: 1375 out of 2715 correct. Accuracy: 50.645
Epoch: 3
Training dataset: 8747 out of 15385 correct. Accuracy: 56.854. Loss: 1.278


 10%|████████▎                                                                          | 3/30 [02:02<18:13, 40.50s/it]

Testing dataset: 1516 out of 2715 correct. Accuracy: 55.838
Epoch: 4
Training dataset: 10544 out of 15385 correct. Accuracy: 68.534. Loss: 0.946


 13%|███████████                                                                        | 4/30 [02:42<17:27, 40.29s/it]

Testing dataset: 1836 out of 2715 correct. Accuracy: 67.624
Epoch: 5
Training dataset: 11855 out of 15385 correct. Accuracy: 77.056. Loss: 0.7


 17%|█████████████▊                                                                     | 5/30 [03:22<16:42, 40.11s/it]

Testing dataset: 2076 out of 2715 correct. Accuracy: 76.464
Epoch: 6
Training dataset: 12781 out of 15385 correct. Accuracy: 83.074. Loss: 0.526


 20%|████████████████▌                                                                  | 6/30 [04:02<16:01, 40.06s/it]

Testing dataset: 2269 out of 2715 correct. Accuracy: 83.573
Epoch: 7
Training dataset: 13305 out of 15385 correct. Accuracy: 86.48. Loss: 0.426


 23%|███████████████████▎                                                               | 7/30 [04:41<15:18, 39.93s/it]

Testing dataset: 2213 out of 2715 correct. Accuracy: 81.51
Epoch: 8
Training dataset: 13526 out of 15385 correct. Accuracy: 87.917. Loss: 0.384


 27%|██████████████████████▏                                                            | 8/30 [05:21<14:36, 39.85s/it]

Testing dataset: 2401 out of 2715 correct. Accuracy: 88.435
Epoch: 9
Training dataset: 14027 out of 15385 correct. Accuracy: 91.173. Loss: 0.29


 30%|████████████████████████▉                                                          | 9/30 [06:01<13:54, 39.75s/it]

Testing dataset: 2393 out of 2715 correct. Accuracy: 88.14
Epoch: 10
Training dataset: 13779 out of 15385 correct. Accuracy: 89.561. Loss: 0.339


 33%|███████████████████████████▎                                                      | 10/30 [06:40<13:14, 39.70s/it]

Testing dataset: 2336 out of 2715 correct. Accuracy: 86.041
Epoch: 11
Training dataset: 14256 out of 15385 correct. Accuracy: 92.662. Loss: 0.249


 37%|██████████████████████████████                                                    | 11/30 [07:20<12:34, 39.73s/it]

Testing dataset: 2338 out of 2715 correct. Accuracy: 86.114
Epoch: 12
Training dataset: 14132 out of 15385 correct. Accuracy: 91.856. Loss: 0.274


 40%|████████████████████████████████▊                                                 | 12/30 [08:00<11:55, 39.74s/it]

Testing dataset: 2385 out of 2715 correct. Accuracy: 87.845
Epoch: 13
Training dataset: 14162 out of 15385 correct. Accuracy: 92.051. Loss: 0.262


 43%|███████████████████████████████████▌                                              | 13/30 [08:40<11:16, 39.81s/it]

Testing dataset: 2526 out of 2715 correct. Accuracy: 93.039
Epoch: 14
Training dataset: 14220 out of 15385 correct. Accuracy: 92.428. Loss: 0.25


 47%|██████████████████████████████████████▎                                           | 14/30 [09:19<10:34, 39.69s/it]

Testing dataset: 2414 out of 2715 correct. Accuracy: 88.913
Epoch: 15
Training dataset: 14430 out of 15385 correct. Accuracy: 93.793. Loss: 0.216


 50%|█████████████████████████████████████████                                         | 15/30 [09:59<09:55, 39.72s/it]

Testing dataset: 2281 out of 2715 correct. Accuracy: 84.015
Epoch: 16
Training dataset: 14077 out of 15385 correct. Accuracy: 91.498. Loss: 0.278


 53%|███████████████████████████████████████████▋                                      | 16/30 [10:39<09:16, 39.73s/it]

Testing dataset: 2153 out of 2715 correct. Accuracy: 79.3
Epoch: 17
Training dataset: 14467 out of 15385 correct. Accuracy: 94.033. Loss: 0.208


 57%|██████████████████████████████████████████████▍                                   | 17/30 [11:19<08:38, 39.87s/it]

Testing dataset: 2425 out of 2715 correct. Accuracy: 89.319
Epoch: 18
Training dataset: 14328 out of 15385 correct. Accuracy: 93.13. Loss: 0.233


 60%|█████████████████████████████████████████████████▏                                | 18/30 [11:59<07:58, 39.86s/it]

Testing dataset: 2003 out of 2715 correct. Accuracy: 73.775
Epoch: 19
Training dataset: 14343 out of 15385 correct. Accuracy: 93.227. Loss: 0.233


 63%|███████████████████████████████████████████████████▉                              | 19/30 [12:39<07:18, 39.83s/it]

Testing dataset: 2517 out of 2715 correct. Accuracy: 92.707
Epoch: 20
Training dataset: 14587 out of 15385 correct. Accuracy: 94.813. Loss: 0.183


 67%|██████████████████████████████████████████████████████▋                           | 20/30 [13:19<06:38, 39.87s/it]

Testing dataset: 2454 out of 2715 correct. Accuracy: 90.387
Epoch: 21
Training dataset: 14332 out of 15385 correct. Accuracy: 93.156. Loss: 0.241


 70%|█████████████████████████████████████████████████████████▍                        | 21/30 [13:58<05:58, 39.84s/it]

Testing dataset: 2042 out of 2715 correct. Accuracy: 75.212
Epoch: 22
Training dataset: 14588 out of 15385 correct. Accuracy: 94.82. Loss: 0.183


 73%|████████████████████████████████████████████████████████████▏                     | 22/30 [14:38<05:18, 39.83s/it]

Testing dataset: 2437 out of 2715 correct. Accuracy: 89.761
Epoch: 23
Training dataset: 14486 out of 15385 correct. Accuracy: 94.157. Loss: 0.204


 77%|██████████████████████████████████████████████████████████████▊                   | 23/30 [15:18<04:38, 39.79s/it]

Testing dataset: 2536 out of 2715 correct. Accuracy: 93.407
Epoch: 24
Training dataset: 14455 out of 15385 correct. Accuracy: 93.955. Loss: 0.216


 80%|█████████████████████████████████████████████████████████████████▌                | 24/30 [15:58<03:58, 39.77s/it]

Testing dataset: 2364 out of 2715 correct. Accuracy: 87.072
Epoch: 25
Training dataset: 14638 out of 15385 correct. Accuracy: 95.145. Loss: 0.175


 83%|████████████████████████████████████████████████████████████████████▎             | 25/30 [16:37<03:18, 39.76s/it]

Testing dataset: 2332 out of 2715 correct. Accuracy: 85.893
Epoch: 26
Training dataset: 14443 out of 15385 correct. Accuracy: 93.877. Loss: 0.216


 87%|███████████████████████████████████████████████████████████████████████           | 26/30 [17:17<02:38, 39.72s/it]

Testing dataset: 1756 out of 2715 correct. Accuracy: 64.678
Epoch: 27
Training dataset: 14484 out of 15385 correct. Accuracy: 94.144. Loss: 0.207


 90%|█████████████████████████████████████████████████████████████████████████▊        | 27/30 [17:57<01:59, 39.73s/it]

Testing dataset: 2329 out of 2715 correct. Accuracy: 85.783
Epoch: 28
Training dataset: 14866 out of 15385 correct. Accuracy: 96.627. Loss: 0.118


 93%|████████████████████████████████████████████████████████████████████████████▌     | 28/30 [18:36<01:19, 39.64s/it]

Testing dataset: 2532 out of 2715 correct. Accuracy: 93.26
Epoch: 29
Training dataset: 14372 out of 15385 correct. Accuracy: 93.416. Loss: 0.234


 97%|███████████████████████████████████████████████████████████████████████████████▎  | 29/30 [19:16<00:39, 39.67s/it]

Testing dataset: 2107 out of 2715 correct. Accuracy: 77.606
Epoch: 30
Training dataset: 14413 out of 15385 correct. Accuracy: 93.682. Loss: 0.219


100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [19:56<00:00, 39.87s/it]

Testing dataset: 2517 out of 2715 correct. Accuracy: 92.707
Finished





In [10]:
# Reload the best snapshot written by save_checkpoint(). map_location="cpu"
# remaps any CUDA tensors in the file, so the checkpoint also loads on a machine
# without a GPU; the export below runs on a CPU model anyway.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
print(f"Best Epoch: {checkpoint['epoch']}, Best Accuracy: {round(checkpoint['best accuracy'], 3)}")

Best Epoch: 23, Best Accuracy: 93.407


In [11]:
# Rebuild the architecture used for training and restore the best weights from
# the `checkpoint` dict loaded in the previous cell.
resnet18 = models.resnet18(pretrained=False)
num_features = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_features, NUM_CLASSES)
resnet18.load_state_dict(checkpoint['model'])

# A freshly constructed module is in training mode. Switch to inference mode
# explicitly so mode-dependent layers (e.g. batch norm) are traced with their
# inference behaviour, rather than relying on the exporter's defaults.
resnet18.eval()

# Example input that fixes the exported graph's input shape: one 3-channel
# 256x256 float32 image.
# NOTE(review): assumes the training images are 256x256 -- confirm.
dummy_input = torch.randn(1, 3, 256, 256, dtype=torch.float32)
torch.onnx.export(
    resnet18,
    dummy_input,
    ONNX_SAVE_PATH,
    verbose=False,
    export_params=True,
    opset_version=11,
    input_names=['image'],
    output_names=['pred'],
)