The following jupyter notebook was created using the following two websites:
- https://rumn.medium.com/custom-pytorch-image-classifier-from-scratch-d7b3c50f9fbe
- https://github.com/lettuceDestroyer/image_classifier

# Imports

In [41]:
import glob
import os
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.io import read_image, ImageReadMode

# Variables

In [42]:
ROOT_FOLDER_PATH = "C:\\Users\\tobil\\Downloads\\image-taker"
NUMBER_OF_LABELS = 5

# Datasets and Dataloaders

In [43]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.5)
])

In [44]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform):
        self.transform = transform
        
        self.image_paths = []
        for ext in ['png', 'jpg']:
            self.image_paths += glob.glob(os.path.join(root_dir, '*', f'*.{ext}'))
        class_set = set()
        for path in self.image_paths:
            class_set.add(os.path.basename(os.path.dirname(path)))
        self.class_lbl = { cls: i for i, cls in enumerate(sorted(list(class_set)))}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = read_image(self.image_paths[idx], ImageReadMode.RGB).float()
        cls = os.path.basename(os.path.dirname(self.image_paths[idx]))
        label = self.class_lbl[cls]

        return self.transform(img), torch.tensor(label)

In [45]:
#dataset = CustomDataset(ROOT_FOLDER_PATH, transform)
dataset = CustomDataset(ROOT_FOLDER_PATH, transform)

In [46]:
splits = [0.8, 0.1, 0.1]
split_sizes = []
for sp in splits[:-1]:
    split_sizes.append(int(sp * len(dataset)))
split_sizes.append(len(dataset) - sum(split_sizes))

train_set, test_set, val_set = torch.utils.data.random_split(dataset, split_sizes)

In [47]:
dataloaders = {
    "train": DataLoader(train_set, batch_size=8, shuffle=True),
    "test": DataLoader(test_set, batch_size=8, shuffle=False),
    "val": DataLoader(val_set, batch_size=8, shuffle=False)
}

# Model Definition

In [48]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [49]:
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model.fc = torch.nn.Linear(2048, NUMBER_OF_LABELS)
model.to(device)
pass

In [50]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training

In [51]:
EPOCHS = 10
NUM_CLASSES = len(dataset.class_lbl)

In [52]:
dataset.class_lbl

{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4}

In [53]:
metrics = {
    'train': {'loss': [], 'accuracy': []},
    'val': {'loss': [], 'accuracy': []},
}

In [None]:
for epoch in range(EPOCHS):
  ep_metrics = {
    'train': {'loss': 0, 'accuracy': 0, 'count': 0},
    'val': {'loss': 0, 'accuracy': 0, 'count': 0},
  }

  print(f'Epoch {epoch}')

  for phase in ['train', 'val']:
    print(f'-------- {phase} --------')
    for images, labels in tqdm(dataloaders[phase]):
      optimizer.zero_grad()

      with torch.set_grad_enabled(phase == 'train'):
        output = model(images.to(device))
        ohe_label = torch.nn.functional.one_hot(labels,
                                                num_classes=NUM_CLASSES)

        loss = criterion(output, ohe_label.float().to(device))

        correct_preds = labels.to(device) == torch.argmax(output, dim=1)
        accuracy = (correct_preds).sum()/len(labels)

      if phase == 'train':
        loss.backward()
        optimizer.step()

      ep_metrics[phase]['loss'] += loss.item()
      ep_metrics[phase]['accuracy'] += accuracy.item()
      ep_metrics[phase]['count'] += 1
  
    ep_loss = ep_metrics[phase]['loss']/ep_metrics[phase]['count']
    ep_accuracy = ep_metrics[phase]['accuracy']/ep_metrics[phase]['count']

    print(f'Loss: {ep_loss}, Accuracy: {ep_accuracy}\n')

    metrics[phase]['loss'].append(ep_loss)
    metrics[phase]['accuracy'].append(ep_accuracy)

Epoch 0
-------- train --------


  0%|          | 0/235 [00:00<?, ?it/s]

100%|██████████| 235/235 [03:52<00:00,  1.01it/s]


Loss: 0.16536697372596, Accuracy: 0.951063829787234

-------- val --------


100%|██████████| 30/30 [00:11<00:00,  2.65it/s]


Loss: 0.027967976910683017, Accuracy: 0.9958333333333333

Epoch 1
-------- train --------


100%|██████████| 235/235 [02:41<00:00,  1.46it/s]


Loss: 0.02039459517220669, Accuracy: 0.9962765957446809

-------- val --------


100%|██████████| 30/30 [00:06<00:00,  4.79it/s]


Loss: 0.0031370956741739063, Accuracy: 1.0

Epoch 2
-------- train --------


100%|██████████| 235/235 [02:15<00:00,  1.73it/s]


Loss: 0.003807809290524443, Accuracy: 1.0

-------- val --------


100%|██████████| 30/30 [00:06<00:00,  4.85it/s]


Loss: 0.0011924368731949168, Accuracy: 1.0

Epoch 3
-------- train --------


100%|██████████| 235/235 [02:21<00:00,  1.66it/s]


Loss: 0.0024165184292484054, Accuracy: 1.0

-------- val --------


100%|██████████| 30/30 [00:06<00:00,  4.95it/s]


Loss: 0.0014564336500673865, Accuracy: 1.0

Epoch 4
-------- train --------


100%|██████████| 235/235 [02:16<00:00,  1.72it/s]


Loss: 0.01637178419620428, Accuracy: 0.9946808510638298

-------- val --------


100%|██████████| 30/30 [00:06<00:00,  4.79it/s]


Loss: 0.004081093804173482, Accuracy: 1.0

Epoch 5
-------- train --------


100%|██████████| 235/235 [02:44<00:00,  1.43it/s]


Loss: 0.017130229577355265, Accuracy: 0.9957446808510638

-------- val --------


100%|██████████| 30/30 [00:08<00:00,  3.40it/s]


Loss: 0.043943894107360396, Accuracy: 0.9958333333333333

Epoch 6
-------- train --------


100%|██████████| 235/235 [19:15<00:00,  4.92s/it]   


Loss: 0.0313941594577974, Accuracy: 0.9909574468085106

-------- val --------


100%|██████████| 30/30 [00:08<00:00,  3.59it/s]


Loss: 0.001308706108344874, Accuracy: 1.0

Epoch 7
-------- train --------


100%|██████████| 235/235 [02:56<00:00,  1.33it/s]


Loss: 0.009299867067107078, Accuracy: 0.9984042553191489

-------- val --------


100%|██████████| 30/30 [00:08<00:00,  3.52it/s]


Loss: 0.002471204300915512, Accuracy: 1.0

Epoch 8
-------- train --------


 75%|███████▌  | 177/235 [02:06<00:35,  1.64it/s]

In [None]:
for phase in metrics:
    for metric in metrics[phase]:
        metric_data = metrics[phase][metric]
        plt.plot(range(len(metric_data)), metric_data)
        plt.xlabel('Epoch')
        plt.ylabel(f'{phase} {metric}')
        plt.show()

# Testing

In [None]:
preds = []
actual = []

tot_loss = tot_acc = count = 0

for images, labels in tqdm(dataloaders['test']):
    with torch.set_grad_enabled(False):
        output = model(images.to(device))
        ohe_label = torch.nn.functional.one_hot(labels, num_classes=NUM_CLASSES)
        out_labels = torch.argmax(output, dim=1)


        tot_loss += criterion(output, ohe_label.float().to(device))
        tot_acc += (labels.to(device) == out_labels).sum()/len(labels)
        count += 1

    preds += out_labels.tolist()
    actual += labels.tolist()

print(f"Test Loss: {tot_loss / count}, Test Accuracy: {tot_acc / count}")

In [None]:
class_labels = sorted(dataset.class_lbl.keys())

cm = confusion_matrix(actual, preds)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)

disp.plot()
plt.show()

In [None]:
cm_np = np.array(cm)
stats = pd.DataFrame(index=class_labels)
stats['Precision'] = [cm_np[i, i]/np.sum(cm_np[:, i]) for i in range(len(cm_np))]
stats['Recall'] = [cm_np[i, i]/np.sum(cm_np[i, :]) for i in range(len(cm_np))]

In [None]:
stats

In [None]:
example_inputs = (torch.randn(1, 1, 32, 32),)
onnx_program = torch.onnx.export(model, example_inputs, dynamo=True)
onnx_program.save(os.path.join(ROOT_FOLDER_PATH, "image_classifier_model.onnx"))