In [2]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.v2 as v2
from torchvision.models import GoogLeNet_Weights

from data_prep import DermNet, get_dataloaders
from googlenet_scalp import GoogLeNet_Scalp
from utils import train_model

# hyperparameters
batch_size = 64
lr = 1e-3          # initial Adam learning rate
eps = 1e-4         # Adam epsilon
# Weight decay is disabled for this run. (An earlier `weight_decay = 1e-4` line was
# immediately overwritten, so 0 is the value the optimizer has been receiving.)
weight_decay = 0
step_size = 7      # StepLR: decay the learning rate every `step_size` epochs ...
gamma = 0.1        # ... by multiplying it by `gamma`
num_epochs = 25

# reproducibility
# Seeds torch's global RNG (CPU and CUDA). NOTE(review): this presumably also fixes
# the train/val split if get_dataloaders draws it from the global torch RNG — confirm.
# A fixed split matters here because training resumes from a checkpoint; a fresh
# random split each run would put previously-trained images into the val set.
# cudnn.benchmark below still allows nondeterministic kernels.
seed = 42
torch.manual_seed(seed)

# optimizations
num_workers = 12
pin_memory = True
torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing: the transform preset bundled with the pretrained GoogLeNet weights.
transform = GoogLeNet_Weights.DEFAULT.transforms()

# Augmentation policy, handed to train_model below.
# Reference: https://sebastianraschka.com/blog/2023/data-augmentation-pytorch.html
# NOTE(review): if train_model applies this *after* `transform` (i.e. on normalized
# float tensors), RandAugment ops such as posterize assume [0, 1] / uint8 inputs and
# will clip the normalized values — confirm the ordering inside train_model.
augmenter = v2.RandAugment()

# Dataset, number of target classes, and the train/val loaders keyed by phase.
dataset = DermNet(transform=transform)
num_classes = len(dataset.classes)
train_loader, val_loader = get_dataloaders(
    dataset=dataset,
    transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
dataloaders = dict(train=train_loader, val=val_loader)

# setup model
model_ft = GoogLeNet_Scalp(device=device, num_classes=num_classes)

# Checkpoint to resume from (another option: 'weights/googlenet_scalp_99.pt').
# Set to None to start from the model's initial weights.
resume_path = 'weights/model_checkpoint.pt'
# Where the trained weights are written. This is the same file as `resume_path`, so
# each run continues from the previous run's weights and then replaces them.
save_path = 'weights/model_checkpoint.pt'

if resume_path is not None and os.path.exists(resume_path):
    # map_location: a checkpoint saved from GPU tensors still loads on a CPU-only box.
    # weights_only: only tensors/containers are unpickled, not arbitrary objects.
    state_dict = torch.load(resume_path, map_location=device, weights_only=True)
    model_ft.load_state_dict(state_dict)
    print(f'Resumed weights from {resume_path}')
elif resume_path is not None:
    # Keep a fresh-kernel "Run All" working instead of crashing on a missing file.
    print(f'No checkpoint found at {resume_path}; training from initial weights')

# NOTE(review): only model weights are checkpointed, so every resumed run restarts
# Adam's moment estimates and the StepLR schedule from `lr`.
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=lr, eps=eps, weight_decay=weight_decay)
# Multiplies the learning rate by `gamma` every `step_size` scheduler steps
# (the log shows train_model stepping it once per epoch).
lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=step_size, gamma=gamma)

model_ft = train_model(model_ft, dataloaders, criterion, optimizer_ft, lr_scheduler,
                        num_epochs=num_epochs, device=device, augmenter=augmenter)

# save model (create the target directory if this is a fresh checkout)
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
torch.save(model_ft.state_dict(), save_path)

Epoch 1/25 | Learning Rate 0.001
----------
train Loss: 0.0125 Acc: 0.9525
val Loss: 0.1512 Acc: 0.7324

Epoch 2/25 | Learning Rate 0.001
----------
train Loss: 0.0195 Acc: 0.9410
val Loss: 0.0582 Acc: 0.8521

Epoch 3/25 | Learning Rate 0.001
----------
train Loss: 0.0124 Acc: 0.9481
val Loss: 0.0194 Acc: 0.9366

Epoch 4/25 | Learning Rate 0.001
----------
train Loss: 0.0084 Acc: 0.9710
val Loss: 0.0519 Acc: 0.8873

Epoch 5/25 | Learning Rate 0.001
----------
train Loss: 0.0149 Acc: 0.9569
val Loss: 0.0705 Acc: 0.7887

Epoch 6/25 | Learning Rate 0.001
----------
train Loss: 0.0114 Acc: 0.9551
val Loss: 0.0474 Acc: 0.8944

Epoch 7/25 | Learning Rate 0.001
----------
train Loss: 0.0070 Acc: 0.9727
val Loss: 0.0394 Acc: 0.9225

Epoch 8/25 | Learning Rate 0.0001
----------
train Loss: 0.0053 Acc: 0.9806
val Loss: 0.0395 Acc: 0.8873

Epoch 9/25 | Learning Rate 0.0001
----------
train Loss: 0.0034 Acc: 0.9842
val Loss: 0.0631 Acc: 0.8204

Epoch 10/25 | Learning Rate 0.0001
----------
train L