In [1]:
import numpy as np
import pickle
import json
import os
import yaml

# from silence_tensorflow import silence_tensorflow
# silence_tensorflow()
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import tensorflow as tf

from utils.datasets import get_generators, create_classifier_dataset
from utils.misc import log_config
from utils.train.callbacks import Logger
from utils.train.visualization import analyze_history
from utils.train.classifier import load_model
from config.datasets_config import DATASETS_CONFIG

In [2]:
def configure_saving():
    """Create the save directory for the current run and record it in ``config``.

    Reads ``config['root_save_dir']`` and ``config['model_name']`` from the
    notebook-level ``config`` dict, joins them into a path, and stores that
    path under ``config['save_dir']``. If the directory already exists, the
    user is asked interactively whether to continue.

    Raises:
        ValueError: If the directory already exists and the user does not
            confirm that the run should continue.
    """
    # Generate save directory and store in config
    save_dir = os.path.join(config['root_save_dir'], config['model_name'])
    config['save_dir'] = save_dir

    # Create save directory (if it does not exist)
    try:
        os.makedirs(save_dir, exist_ok=False)
    except FileExistsError:
        answer = input('save_dir already exists, continue? (Y/n)  >> ')
        # The prompt advertises "Y" as the default choice, so Enter and any
        # casing of y/yes count as confirmation; everything else aborts.
        if answer.strip().lower() not in ('', 'y', 'yes'):
            raise ValueError(
                f'Aborted: save_dir {save_dir!r} already exists'
            ) from None

In [3]:
def load_datasets():
    """Build the train, validation and test datasets for the current ``config``.

    Uses the notebook-level ``config`` and ``dataset_config`` dicts. As side
    effects it copies the split settings into ``dataset_config`` and stores
    ``classes``, ``num_classes``, ``class_weight`` and ``steps`` in ``config``.

    Returns:
        list: Three batched, prefetched datasets in train, val, test order.
    """
    # Forward the split settings to the dataset configuration
    for split_key in ('train_split', 'validation_split'):
        dataset_config[split_key] = config[split_key]

    # Sample-level generators; batching is applied on the datasets below
    train_gen, val_gen, test_gen = get_generators(
        ['train', 'val', 'test'],
        config['image_shape'],
        batch_size=1,  # batched later
        random_seed=config['random_seed'],
        dataset_config=dataset_config
    )
    class_names = list(train_gen.class_indices.keys())
    config['classes'] = class_names
    config['num_classes'] = len(class_names)

    # Optional per-class weights, re-keyed from raw label -> group -> class index
    index_weights = None
    if config['use_class_weight']:
        weight_path = os.path.join(dataset_config['dataset_dir'], 'class_weight.json')
        with open(weight_path, 'r') as weight_file:
            raw_weights = json.load(weight_file)
        groups = dataset_config['groups']

        group_weights = {}
        for label, weight in raw_weights.items():
            # Labels that have no group mapping are dropped
            if label in groups.keys():
                group_weights[groups[label]] = weight

        index_weights = {}
        for group_name, weight in group_weights.items():
            index_weights[train_gen.class_indices[group_name]] = weight
        print('Using class weights:', index_weights)
    config['class_weight'] = index_weights

    # One batched + prefetched dataset, and one step count, per split
    built, step_counts = [], []
    for generator in (train_gen, val_gen, test_gen):
        dataset = create_classifier_dataset(generator, config['image_shape'], len(class_names))
        dataset = dataset.batch(config['batch_size']).prefetch(config['prefetch'])

        # Generators use batch_size=1, so len() is divided by the real batch size
        step_counts.append(len(generator) // config['batch_size'])
        built.append(dataset)
    config['steps'] = step_counts

    return built

In [5]:
# Read the classifier settings, then pick the dataset configuration they name
with open('config/classifier_config.yaml') as file:
    config = yaml.safe_load(file)
dataset_config = DATASETS_CONFIG[config['dataset_type']]

# Only NumPy is seeded here. tf.random.set_seed(config['random_seed']) is
# deliberately left out: per the original note it messes things up in
# encoder training.
np.random.seed(config['random_seed'])

# VAE-specific defaults, applied only when the YAML file selects a VAE
if config['model_type'] == 'vae':
    config.update(latent_dim=512, head_lr=1e-3, encoder_lr=1e-3)

# Barlow Twins baseline training setup
config.update({
    'model_type': 'resnet50',
    'root_save_dir': 'trained_models/classifiers/lamb_100_4096_256_imagenet_lr',
    # doesn't actually matter since the projector head is removed!
    'projector_dim': 4096,
})

In [None]:
# Train one classifier per learning rate and save the results of every run.
# To sweep the fraction of training data instead, loop over
#   frac in [0.05, 0.1, 0.2, 0.4, 0.6, 0.8]
# and assign it to config['train_split'] below.

for lr in [0.005, 0.01, 0.02, 0.04, 0.08, 0.1]:
    # !!! Change weights path here
    for weights_path in ['trained_models/encoders/lamb_100_4096_256_imagenet/resnet.h5']:
        config['encoder_weights_path'] = weights_path

        if weights_path is None:
            # Training parameters for supervised models
            model_type = 'supervised'
            config['optimizer'] = 'adam'
            config['lr_scheduler'] = 'plateau'
            config['head_lr'] = 5e-3
            config['encoder_lr'] = 5e-3
        else:
            # Training parameters for semi-supervised models
            model_type = 'barlow'
            config['optimizer'] = 'sgdw'
            config['lr_scheduler'] = 'cosine'
            config['head_lr'] = 0.03
            config['encoder_lr'] = 0.03

        # Hyperparameter(s) to be fine-tuned. The learning rates chosen by the
        # branch above are overridden by the value being swept.
        config['train_split'] = 0.2
        config['head_lr'] = lr
        config['encoder_lr'] = lr
        # `model_type` here is only the run-name prefix; it is distinct from
        # config['model_type'].
        config['model_name'] = f'{model_type}_{lr}'

        # Sets config['save_dir']; prompts if the directory already exists
        configure_saving()

        # Load dataset and model
        datasets = load_datasets()
        model = load_model(config_dict=config)

        # Create training callbacks
        callbacks = []
        if config['patience'] is not None:
            es = EarlyStopping(monitor='val_auc', mode='max', verbose=1, patience=config['patience'])
            callbacks.append(es)

        if config['lr_scheduler'] == 'plateau':
            # mode must be explicit: Keras' 'auto' mode only infers "max" for
            # metric names containing 'acc', so 'val_auc' would otherwise be
            # treated as a quantity to minimise and the learning rate would be
            # cut while AUC is still improving.
            reduce = ReduceLROnPlateau(
                monitor='val_auc',
                mode='max',
                factor=0.2,
                patience=3,
                verbose=1
            )
            callbacks.append(reduce)

        # Keep only the weights of the epoch with the best validation AUC
        mc = ModelCheckpoint(
            os.path.join(config['save_dir'], 'classifier.h5'),
            monitor='val_auc',
            mode='max',
            verbose=1,
            save_best_only=True, save_weights_only=True
        )
        callbacks.append(mc)

        callbacks.append(Logger())

        # Print and save the configuration
        log_config(config, save_config=True)

        # Train the model. config['steps'] holds the train/val/test step
        # counts, in that order, as filled in by load_datasets().
        history = model.fit(
            datasets[0],
            epochs=config['epochs'],
            steps_per_epoch=config['steps'][0],
            validation_steps=config['steps'][1],
            validation_data=datasets[1],
            callbacks=callbacks,
            class_weight=config['class_weight']
        )

        # Save the training history
        with open(os.path.join(config['save_dir'], 'history.pickle'), 'wb') as f:
            pickle.dump(history.history, f)

        # Load best model, save encoder weights (separately), and evaluate model
        model.load_weights(os.path.join(config['save_dir'], 'classifier.h5'))
        # NOTE(review): assumes layer index 1 of the classifier is the encoder;
        # confirm against load_model.
        model.layers[1].save_weights(os.path.join(config['save_dir'], 'encoder.h5'))
        model.evaluate(datasets[2], steps=config['steps'][2])

Found 23676 validated image filenames belonging to 6 classes.
Found 23677 validated image filenames belonging to 6 classes.
Found 20661 validated image filenames belonging to 6 classes.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of devices: 4
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/devic

Epoch 5/30
Epoch 00005: val_auc improved from 0.92785 to 0.93922, saving model to trained_models/classifiers/lamb_100_4096_256_imagenet_lr/barlow_0.005/classifier.h5
Epoch 6/30
Epoch 00006: val_auc improved from 0.93922 to 0.94615, saving model to trained_models/classifiers/lamb_100_4096_256_imagenet_lr/barlow_0.005/classifier.h5
Epoch 7/30
Epoch 00007: val_auc improved from 0.94615 to 0.95178, saving model to trained_models/classifiers/lamb_100_4096_256_imagenet_lr/barlow_0.005/classifier.h5
Epoch 8/30
Epoch 00008: val_auc improved from 0.95178 to 0.95544, saving model to trained_models/classifiers/lamb_100_4096_256_imagenet_lr/barlow_0.005/classifier.h5
Epoch 9/30
Epoch 00009: val_auc improved from 0.95544 to 0.95719, saving model to trained_models/classifiers/lamb_100_4096_256_imagenet_lr/barlow_0.005/classifier.h5
Epoch 10/30
Epoch 00010: val_auc improved from 0.95719 to 0.95820, saving model to trained_models/classifiers/lamb_100_4096_256_imagenet_lr/barlow_0.005/classifier.h5
Epo