In [None]:
import os

from utils import load_dataset
from constants import *
import constants
import torch
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from utils import get_train_split_sizes, set_parameter_requires_grad, show_history
from training import train_model, visualize_model, test

In [None]:
# Seed torch's global RNG (all devices) before any data loading or model
# construction so that Restart & Run All reproduces the same run: the new
# classifier head's initialisation and any torch-driven shuffling / train-val
# splitting (presumably inside load_dataset — confirm) draw from this RNG.
# Note: this does not make cuDNN kernels deterministic.
SEED = 42
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise the CPU. `gpu` is
# also passed as `pin_memory` to load_dataset.
gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if gpu else "cpu")

In [None]:
# Build the data loaders and datasets; pinned host memory is requested only
# when a GPU is present.
dataloaders, image_datasets = load_dataset(
    IMG_SIZE,
    train_dir,
    test_dir,
    batch_size=BATCH_SIZE,
    val_size=VAL_SIZE,
    pin_memory=gpu,
)

# Number of samples in each split, keyed by split name; the train/val sizes
# come from the utils helper given the same VAL_SIZE used above.
train_size, val_size = get_train_split_sizes(image_datasets, VAL_SIZE)
dataset_sizes = dict(
    test=len(image_datasets['test']),
    train=train_size,
    val=val_size,
)

In [None]:
# Report how many samples ended up in each split.
print('Dataset sizes: ' + str(dataset_sizes))

In [None]:
# Maximum epochs for the two training phases: feature extraction (first value)
# and fine-tuning (second value).
TRAIN_EPOCHS, FINETUNE_EPOCHS = 5, 15

# ResNet

In [None]:
# Identifier for this model: the first element of the RESNET enum member's value.
model_name = Models.RESNET.value[0]
# Base path for saved weights; later cells append a suffix per training phase.
model_path = MODELS_PATH(model_name)
# Resolve the factory through the module instead of the bare name. This cell
# rebinds `plots_dir` to the resulting directory (later cells pass it to
# os.path.join), so calling the bare name again on a re-run would raise
# TypeError (the name would no longer refer to a callable).
plots_dir = constants.plots_dir(model_name)

In [None]:
# Pretrained ResNet-18 from torchvision; remember the width of its classifier
# input before touching anything.
model_ft = models.resnet18(pretrained=True)
head_in_features = model_ft.fc.in_features

# Freeze the pretrained weights. NOTE(review): assumes the True flag means
# "feature extraction / freeze" — confirm in utils.set_parameter_requires_grad.
model_ft = set_parameter_requires_grad(model_ft, True)

# Replace the classifier with a fresh linear layer producing `class_no` outputs.
# It is created after the freeze, so the freeze above does not apply to it.
model_ft.fc = nn.Linear(head_in_features, class_no)
model_ft = model_ft.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()

# SGD is handed every parameter of the network, frozen ones included;
# parameters that receive no gradient are skipped by the update, so only the
# trainable ones (the new head, during feature extraction) actually move.
optimizer_ft = optim.SGD(params=model_ft.parameters(), lr=1e-3, momentum=0.9)

# Multiply the learning rate by 0.1 every 7 scheduler steps (presumably one
# step per epoch inside train_model — confirm).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=7, gamma=0.1)

In [None]:
# Baseline evaluation before any training: the classifier head is still freshly
# initialised here. Presumably evaluates on the 'test' split — confirm in training.test.
test(model_ft, dataloaders, device, criterion)

## Feature extraction

In [None]:
# Phase 1 — feature extraction: train with the pretrained backbone (presumably)
# frozen, using the optimizer and scheduler built above. early_stopping_ep=3 is
# presumably the patience in epochs — confirm in training.train_model.
model_ft, tr_history = train_model(
    model_ft,
    criterion,
    optimizer_ft,
    exp_lr_scheduler,
    dataloaders,
    device,
    dataset_sizes,
    num_epochs=TRAIN_EPOCHS,
    early_stopping_ep=3,
)

In [None]:
# Evaluate the model returned by the feature-extraction phase.
test(model_ft, dataloaders, device, criterion)

In [None]:
# Checkpoint the feature-extraction weights (state_dict only, not the module).
extracted_weights_path = model_path + '_extracted'
torch.save(model_ft.state_dict(), extracted_weights_path)

In [None]:
# Plot the feature-extraction history; the path argument is presumably where
# show_history writes the figure.
training_plot_path = os.path.join(plots_dir, 'training_history.jpg')
show_history(tr_history, 'Training history', training_plot_path)

## Finetuning

In [None]:
# Unfreeze the backbone for fine-tuning.
# NOTE(review): assumes set_parameter_requires_grad(model, False) re-enables
# requires_grad on every parameter. If the helper only disables gradients when
# its flag is True and otherwise does nothing, this call is a no-op and the next
# phase would still train only the head — confirm in utils.
model_ft = set_parameter_requires_grad(model_ft, False)

In [None]:
# Phase 2 — fine-tuning with a fresh optimizer and LR schedule (same
# hyperparameters as phase 1). Reusing optimizer_ft / exp_lr_scheduler would
# carry over the scheduler's step counter and SGD momentum buffers. The StepLR
# decay would then fire at a point that depends on how many steps phase 1 took,
# including any early stop, instead of 7 steps into this phase. A fresh
# optimizer is also bound to the parameters of the model_ft returned by phase 1,
# whether or not train_model returns the same object.
finetune_optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
finetune_scheduler = lr_scheduler.StepLR(finetune_optimizer, step_size=7, gamma=0.1)

# early_stopping_ep=3 is presumably the patience in epochs — confirm in training.train_model.
model_ft, ft_history = train_model(model_ft, criterion, finetune_optimizer, finetune_scheduler,
                                   dataloaders, device, dataset_sizes,
                                   num_epochs=FINETUNE_EPOCHS, early_stopping_ep=3)

In [None]:
# Evaluate the model returned by the fine-tuning phase.
test(model_ft, dataloaders, device, criterion)

In [None]:
# Checkpoint the fine-tuned weights (state_dict only, not the module).
finetuned_weights_path = model_path + '_finetuned'
torch.save(model_ft.state_dict(), finetuned_weights_path)

In [None]:
# Plot the fine-tuning history; the path argument is presumably where
# show_history writes the figure.
finetuning_plot_path = os.path.join(plots_dir, 'finetuning_history.jpg')
show_history(ft_history, 'Finetuning history', finetuning_plot_path)

In [None]:
# Join the per-metric histories of both phases into one history spanning every
# epoch, keyed like tr_history. The original list-style `[key: ...]` was a
# SyntaxError; this is the intended dict comprehension.
# NOTE(review): assumes both phases record the same keys and that each value is
# a list, so `+` concatenates (NumPy arrays would add element-wise instead)
# — confirm against training.train_model.
full_history = {key: tr_history[key] + ft_history[key] for key in tr_history}

In [None]:
# Plot the combined history of both phases; the path argument is presumably
# where show_history writes the figure.
full_plot_path = os.path.join(plots_dir, 'full_history.jpg')
show_history(full_history, 'Whole training history', full_plot_path)