We use a notebook rather than a .py file so we can keep datasets in memory between runs. This is useful for testing different network parameters, architectures, etc.

## 1. Prepare datasets:

In [None]:
from torch.utils.data import random_split, ConcatDataset
import sys

# Make the project's own modules (the src package) importable
sys.path.insert(0, './age-gender-cnn')
from src.train import train_model
from src.face_dataset import *
from src.label_funcs import *
import src.preprocessor as preprocessor
from src.augmentations import get_augs
from src.ds_transforms import *
from src.networks import *
import src.tests as tests


# --- Dataset configuration -------------------------------------------------
ds_size = 1000  # number of files to use; ~2m to process 1000 files
processor = preprocessor.process(crop='mid', size=168)
transform = alexnet_transform(168)
print_errors = False  # switch to True when debugging

# Directory holding the IMDB crop images (machine-specific absolute path).
imdb_crop_dir = 'C:\\Users\\jckpn\\Documents\\YEAR 3 PROJECT\\implementation\\source\\other\\imdb_crop'

# Constructor arguments shared by every dataset built in this cell.
common_ds_args = dict(
    transform=transform,
    processor=processor,
    ds_size=ds_size,
    print_errors=print_errors,
)

# Age variant of the same dataset, currently disabled:
# imdb_age_ds = MemoryDataset(
#     imdb_crop_dir, label_func=age_label_all, **common_ds_args)

imdb_gender_ds = MemoryDataset(
    imdb_crop_dir, label_func=binary_gender_label, **common_ds_args)

# --- Choose which dataset feeds training/validation (edit as needed) -------
train_val_set = imdb_gender_ds

# --- Random train/validation split -----------------------------------------
val_split_ratio = 0.2
n_total = len(train_val_set)
val_size = int(val_split_ratio * n_total)  # fractional part is truncated
train_size = n_total - val_size            # remainder goes to training
train_set, val_set = random_split(train_val_set, [train_size, val_size])
print(f'Split dataset into {len(train_set)} training and {len(val_set)} validation examples')

# Intra-dataset testing for now: the test set is the validation set itself.
test_set = val_set

## 2. Train models:

In [None]:
from torch.nn import CrossEntropyLoss, MSELoss
from torchvision import transforms


# `torch` is referenced below (torch.optim.Adam) but no import in this
# notebook binds that name directly; import it explicitly instead of relying
# on one of the `from src.* import *` wildcard imports happening to expose it.
import torch

# Value passed as `image_resize` to training and to both evaluation calls.
# Kept in one place so the three call sites cannot drift apart.
# NOTE(review): the datasets are prepared with size=168 and
# alexnet_transform(168) while 224 is used here -- confirm this is intended.
IMAGE_RESIZE = 224

model = train_model(
    # Two output classes; presumably matches the binary gender labels used
    # for the dataset -- verify if the label function changes.
    model=AlexNet(num_classes=2, pretrained=False),

    # Data
    train_set=train_set,
    val_set=val_set,
    image_resize=IMAGE_RESIZE,
    aug_transform=get_augs(),

    # Optimisation
    optim_fn=torch.optim.Adam,
    learning_rate=0.0001,

    # Stopping criteria (exact semantics are defined by train_model)
    patience=5,
    max_epochs=30,

    model_save_dir='./models')

# Evaluate the trained model on the held-out set with the same resize value.
tests.class_accuracy(model, test_set, image_resize=IMAGE_RESIZE)
tests.confusion_matrix(model, test_set, image_resize=IMAGE_RESIZE)