In [2]:
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.v2 as v2
from torchvision.models import GoogLeNet_Weights

from data_prep import DermNet, get_dataloaders
from googlenet_scalp import GoogLeNet_Scalp
from model_trainer import train_model

# --- Hyperparameters -----------------------------------------------------------
batch_size = 64
lr = 1e-4            # initial Adam learning rate
eps = 1e-4           # Adam epsilon (PyTorch's default is 1e-8)
weight_decay = 1e-4  # Adam weight decay (L2 penalty)
step_size = 7        # StepLR: number of scheduler steps between LR decays
gamma = 0.1          # StepLR: multiplicative LR decay factor
num_epochs = 25

# --- Reproducibility -------------------------------------------------------------
seed = 42
# Seeds torch's global RNG on all devices so stochastic parts of the run can be repeated.
# NOTE(review): cudnn.benchmark below can still make GPU runs nondeterministic.
torch.manual_seed(seed)

# --- Performance -----------------------------------------------------------------
# Never request more DataLoader workers than the machine has CPUs.
num_workers = min(12, os.cpu_count() or 1)
# Pinned host memory only speeds up host->GPU copies; skip it when there is no GPU.
pin_memory = torch.cuda.is_available()
# Let cuDNN autotune conv algorithms (faster for fixed input sizes, not deterministic).
torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Preprocessing & augmentation ------------------------------------------------
# `transform`: the preset inference preprocessing shipped with the default GoogLeNet weights.
# `augmenter`: a RandAugment policy that is handed to train_model further down.
# NOTE(review): how train_model applies `augmenter` is not visible here. RandAugment's
# colour ops expect un-normalized images, so confirm it runs before normalization.
# Reference: https://sebastianraschka.com/blog/2023/data-augmentation-pytorch.html
transform, augmenter = GoogLeNet_Weights.DEFAULT.transforms(), v2.RandAugment()

# --- Dataset & loaders -----------------------------------------------------------
# NOTE(review): `transform` is given to DermNet and again to get_dataloaders; what the
# latter does with it is not visible here — confirm it is not applied twice.
dataset = DermNet(transform=transform)
num_classes = len(dataset.classes)  # number of target classes reported by the dataset

train_loader, val_loader = get_dataloaders(
    dataset=dataset,
    transform=transform,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
# Phase name -> loader, in the form train_model receives it.
dataloaders = dict(train=train_loader, val=val_loader)

# --- Model, loss, optimizer, schedule --------------------------------------------
checkpoint_path = Path('weights') / 'model_checkpoint.pt'
resume_training = False  # set True to start from the weights stored at checkpoint_path

model_ft = GoogLeNet_Scalp(device=device, num_classes=num_classes)
if resume_training:
    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
    model_ft.load_state_dict(torch.load(checkpoint_path, map_location=device))

criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=lr, eps=eps, weight_decay=weight_decay)
# Multiplies the LR by `gamma` every `step_size` scheduler steps.
# NOTE(review): assumes train_model calls lr_scheduler.step() once per epoch — confirm.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=step_size, gamma=gamma)

# --- Train -----------------------------------------------------------------------
model_ft = train_model(model_ft, dataloaders, criterion, optimizer_ft, lr_scheduler,
                       num_epochs=num_epochs, device=device, augmenter=augmenter)

# --- Save ------------------------------------------------------------------------
# torch.save does not create missing parent directories.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model_ft.state_dict(), checkpoint_path)

Epoch 1/25 | Learning Rate 0.0001
----------
train Loss: 0.0614 Acc: 0.9797
val Loss: 0.4648 Acc: 0.8566

Epoch 2/25 | Learning Rate 0.0001
----------
train Loss: 0.0575 Acc: 0.9825
val Loss: 0.0355 Acc: 0.9890

Epoch 3/25 | Learning Rate 0.0001
----------
train Loss: 0.0347 Acc: 0.9899
val Loss: 0.0745 Acc: 0.9779

Epoch 4/25 | Learning Rate 0.0001
----------
train Loss: 0.0274 Acc: 0.9899
val Loss: 0.0452 Acc: 0.9926

Epoch 5/25 | Learning Rate 0.0001
----------
train Loss: 0.0416 Acc: 0.9871
val Loss: 0.2762 Acc: 0.9449

Epoch 6/25 | Learning Rate 0.0001
----------
train Loss: 0.0207 Acc: 0.9954
val Loss: 0.0368 Acc: 0.9890

Epoch 7/25 | Learning Rate 0.0001
----------
train Loss: 0.0428 Acc: 0.9853
val Loss: 0.1189 Acc: 0.9559

Epoch 8/25 | Learning Rate 1e-05
----------
train Loss: 0.0225 Acc: 0.9908
val Loss: 0.0576 Acc: 0.9853

Epoch 9/25 | Learning Rate 1e-05
----------
train Loss: 0.0224 Acc: 0.9926
val Loss: 0.2341 Acc: 0.9118

Epoch 10/25 | Learning Rate 1e-05
----------
tra