In [1]:
import json
import os
import pickle
import yaml

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
import tensorflow as tf

from utils.models import vae
from utils.train import callbacks as cb
from utils.datasets import get_dataset_df
from utils.misc import log_config

In [2]:
def configure_saving(model_name, config=None):
    """Create the run's save directory and write the configuration into it as JSON.

    Args:
        model_name: Sub-directory name created under ``config['root_save_dir']``.
        config: Configuration dict to save. Defaults to the notebook-global
            ``config`` (the original behavior), so existing calls are unchanged.

    Returns:
        str: Path of the save directory.

    Raises:
        ValueError: If the directory already exists and the user does not
            explicitly confirm (``y``/``yes``, case-insensitive) that it may be reused.
    """
    if config is None:
        config = globals()['config']

    save_dir = os.path.join(config['root_save_dir'], model_name)

    try:
        os.makedirs(save_dir, exist_ok=False)
    except FileExistsError:
        # Default answer is "no": reusing the directory overwrites its
        # config.json and any files later saved there, so require an explicit yes.
        answer = input('save_dir already exists, continue? (y/N)  >> ')
        if answer.strip().lower() not in ('y', 'yes'):
            raise ValueError(f'Aborted: save_dir {save_dir!r} already exists')

    with open(os.path.join(save_dir, 'config.json'), 'w') as file:
        json.dump(config, file)

    return save_dir

In [3]:
def load_dataset():
    """Build the training/validation input pipelines and a visualization batch.

    The image index from ``get_dataset_df(..., mode='encoder')`` is shuffled,
    restricted to rows whose ``split`` column is ``'train'``, and divided into
    'training' and 'validation' subsets via ``ImageDataGenerator``'s
    ``validation_split``. Pixels are rescaled by 1/255 and clipped to [0, 1];
    the iterators' labels are discarded, so only image batches are produced.

    Side effects: stores the per-subset batch counts in
    ``config['steps_per_epoch']`` and ``config['val_steps_per_epoch']`` and
    prints them.

    Returns:
        tuple: ``(train_gen, val_gen, test_batch)`` -- two endless Python
        generators of float32 image batches, and one batch taken from the
        validation pipeline (used for visualizations).
    """
    def _endless_images(flow):
        """Return a zero-arg callable that yields image batches from `flow` forever."""
        # `flow` is bound here, when the factory is called. A lambda closing over
        # the loop variable would be evaluated lazily (when the dataset is
        # iterated) and would see only the last iterator -- the validation one.
        def generate():
            while True:
                yield flow.next()[0]
        return generate

    def _as_generator(ds):
        """Yield from one persistent iterator over `ds` instead of a new iterator per step."""
        yield from ds

    df = get_dataset_df(config['dataset_config'], config['random_seed'], mode='encoder')
    # No random_state is passed, so pandas draws from NumPy's global RNG; the
    # order depends on np.random.seed having been called beforehand.
    df = df.sample(frac=1).reset_index(drop=True)
    train_df = df[df['split'] == 'train']  # same frame for both subsets

    image_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=config['validation_split'])

    datasets = []
    for subset in ['training', 'validation']:
        # NOTE(review): shuffle=False gives the same batch order every epoch;
        # consider shuffle=True for the training subset.
        datagen = image_datagen.flow_from_dataframe(
            train_df,
            shuffle=False,
            seed=config['random_seed'],
            target_size=config['image_shape'][:2],
            batch_size=config['batch_size'],
            subset=subset
        )
        dataset = tf.data.Dataset.from_generator(
            _endless_images(datagen),
            output_types='float32', output_shapes=[None] * 4
        )
        dataset = dataset.map(lambda x: tf.clip_by_value(x, 0, 1), num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.prefetch(config['prefetch'])
        datasets.append(dataset)

        if subset == 'training':
            config['steps_per_epoch'] = len(datagen)
            print('Training steps per epoch:', config['steps_per_epoch'])
        else:
            config['val_steps_per_epoch'] = len(datagen)
            print('Validation steps per epoch:', config['val_steps_per_epoch'])

    # One batch for generating visualizations (from validation set)
    test_batch = next(iter(datasets[1].take(1)))

    return _as_generator(datasets[0]), _as_generator(datasets[1]), test_batch

In [4]:
# Experiment configuration, with the epoch count overridden for this run.
with open('config/vae_config.yaml') as file:
    config = yaml.safe_load(file)

config.update(epochs=100)
log_config(config, 20)

# Seed TensorFlow's and NumPy's global random generators from the config.
tf.random.set_seed(config['random_seed'])
np.random.seed(config['random_seed'])

random_seed          42
epochs               100
batch_size           256
patience             None
prefetch             8
gpu_used             ['GPU:0', 'GPU:1', 'GPU:2', 'GPU:3']
root_save_dir        trained_models/vaes
lr                   0.0004
image_shape          [224, 224, 3]
validation_split     0.1
dataset_config       {'split_file_path': 'datasets/tissue_classification/fold_test.csv', 'dataset_dir': 'datasets/tissue_classification/dataset_encoder'}
latent_dim           512



In [5]:
# Derive the run name from the epoch override so the directory name cannot
# drift from config['epochs'] (currently 100 -> 'vae_100').
save_dir = configure_saving(model_name=f"vae_{config['epochs']}")

In [6]:
# Training stream, validation stream, and a fixed validation batch for visualizations.
[dataset, dataset_val, test_batch] = load_dataset()

Found 142201 validated image filenames belonging to 1 classes.
Training steps per epoch: 556
Found 15800 validated image filenames belonging to 1 classes.
Validation steps per epoch: 62


In [8]:
strategy = tf.distribute.MirroredStrategy(devices=config['gpu_used'])
print('Number of devices:', strategy.num_replicas_in_sync)

# The optimizer and model are created inside the strategy scope so their
# variables are managed by MirroredStrategy across config['gpu_used'].
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=config['lr'])
    # Convolutional variational autoencoder
    model = vae.CVAE(latent_dim=config['latent_dim'], image_shape=config['image_shape'])
    model.compile(optimizer=optimizer)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Number of devices: 4
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:loc

In [9]:
# Keep only the weights with the lowest validation loss seen so far.
mc = ModelCheckpoint(
    filepath=os.path.join(save_dir, 'model.h5'),
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Create visualizations (original, reconstructed, generated)
vc = cb.VAECheckpoint(model=model, model_save_dir=save_dir, latent_dim=config['latent_dim'], test_batch=test_batch)

# Multiply the learning rate by 0.1 after 5 epochs without val_loss improvement, floor 1e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss', mode='min', factor=0.1, patience=5, min_lr=1e-5, verbose=1,
)

callbacks = [mc, vc, reduce_lr]

In [None]:
# The inputs are endless generators, so epoch and validation lengths come from
# the step counts that load_dataset stored in config.
history = model.fit(
    x=dataset,
    validation_data=dataset_val,
    epochs=config['epochs'],
    steps_per_epoch=config['steps_per_epoch'],
    validation_steps=config['val_steps_per_epoch'],
    callbacks=callbacks,
)

Epoch 1/100
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:batch_all_reduce: 52 all-reduces with algorithm = nccl, num_packs = 1


In [None]:
# Persist the metrics dict returned by fit() (history.history) next to the weights.
with open(os.path.join(save_dir, 'history.pickle'), 'wb') as file:
    file.write(pickle.dumps(history.history))

In [None]:
# Restore the best checkpoint (lowest val_loss, written by ModelCheckpoint),
# then export only the encoder's weights.
model.load_weights(filepath=os.path.join(save_dir, 'model.h5'))
model.encoder.save_weights(filepath=os.path.join(save_dir, 'encoder.h5'))

In [15]:
# Print the decoder's layer-by-layer architecture (output shapes and parameter counts).
model.decoder.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
projector (Sequential)       (None, 7, 7, 256)         6485248   
_________________________________________________________________
conv_block_0 (Sequential)    (None, 14, 14, 256)       591104    
_________________________________________________________________
conv_block_1 (Sequential)    (None, 28, 28, 128)       295552    
_________________________________________________________________
conv_block_2 (Sequential)    (None, 56, 56, 64)        74048     
_________________________________________________________________
conv_block_3 (Sequential)    (None, 112, 112, 32)      18592     
_________________________________________________________________
conv_block_4 (Sequential)    (None, 224, 224, 16)      4688      
_________________________________________________________________
conv_output (Conv2D)         (None, 224, 224, 3)      