In [None]:
import os
from pathlib import Path

import keras
import ujson
from flytekit.contrib.notebook import python_notebook
from flytekit.contrib.notebook import record_outputs
from flytekit.sdk.tasks import inputs, outputs
from flytekit.sdk.types import Types
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.layers import Dense
from keras.models import Model
from keras.optimizers import Adam

from models.classifier.resnet50.constants import CHECKPOINT_FILE_NAME
from models.classifier.resnet50.constants import FINAL_FILE_NAME


In [None]:
def print_dir(directory, logger):
    """Log the contents of every directory found by walking ``directory``.

    One INFO record is emitted per directory visited by ``os.walk`` (the root
    included), containing the directory path, the list of its sub-directory
    names and the list of its file names.

    Args:
        directory: Root path to walk.
        logger: Object exposing an ``info(msg, *args)`` method; assumed to
            follow the standard-library ``logging`` calling convention.
    """
    for root, subdirs, files in os.walk(directory):
        # The first argument to info() is the %-style template and the rest
        # are its arguments. Passing the three values directly would use the
        # directory path as the template and fail to format the record.
        logger.info("%s %s %s", root, subdirs, files)


In [4]:
# Training function
def train_resnet50_model(
    train_directory,
    validation_directory,
    output_model_folder,
    logger,
    patience,
    epochs,
    batch_size,
    size,
    weights,
):
    """Fine-tune a ResNet50 classifier on images read from class sub-folders.

    The ResNet50 layers are frozen and a new softmax ``Dense`` layer, sized to
    the number of classes found in ``train_directory``, is trained on top of
    the layer that feeds the network's original classification head. The best
    checkpoint and the final model are written into ``output_model_folder``.

    Args:
        train_directory: Directory passed to ``flow_from_directory`` for the
            training images.
        validation_directory: Directory passed to ``flow_from_directory`` for
            the validation images.
        output_model_folder: Folder receiving ``CHECKPOINT_FILE_NAME`` and
            ``FINAL_FILE_NAME``.
        logger: Object exposing a ``logging``-style ``info`` method.
        patience: ``patience`` value given to ``EarlyStopping``.
        epochs: Maximum number of epochs to train.
        batch_size: Batch size for both generators.
        size: ``target_size`` given to both generators.
        weights: ``weights`` argument forwarded to ``ResNet50``.

    Raises:
        Exception: If either directory yields zero batches.
    """
    logger.info(
        f"Train Resnet 50 called with Train: {train_directory}, Validation: {validation_directory}"
    )
    print_dir(train_directory, logger)
    print_dir(validation_directory, logger)

    # Data generator for the training images (no augmentation configured).
    gen = keras.preprocessing.image.ImageDataGenerator()

    # Data generator for the validation images, with random flips enabled.
    # NOTE(review): augmentation is usually applied to the training generator,
    # not the validation one - confirm this is intentional. Left unchanged.
    val_gen = keras.preprocessing.image.ImageDataGenerator(
        horizontal_flip=True, vertical_flip=True
    )

    # Organizing the training images into batches
    batches = gen.flow_from_directory(
        train_directory,
        target_size=size,
        class_mode="categorical",
        shuffle=True,
        batch_size=batch_size,
    )

    num_train_steps = len(batches)
    if not num_train_steps:
        raise Exception("No training batches")
    logger.info("num_train_steps = %s", num_train_steps)

    # Organizing the validation images into batches
    val_batches = val_gen.flow_from_directory(
        validation_directory,
        target_size=size,
        class_mode="categorical",
        shuffle=True,
        batch_size=batch_size,
    )

    num_valid_steps = len(val_batches)
    if not num_valid_steps:
        raise Exception("No validation batches.")
    logger.info("num_valid_steps = %s", num_valid_steps)

    # Class names ordered by the index the training generator assigned them,
    # so that classes[i] is the label for output unit i.
    class_indices = batches.class_indices
    classes = sorted(class_indices, key=class_indices.get)

    # Picking the predefined ResNet50 as our model, and initialize it with a weight file
    model = keras.applications.resnet50.ResNet50(weights=weights)

    # Since we don't have much training data, we want to leverage the features
    # learned from a larger dataset. Freeze every pre-trained layer so only the
    # new head created below is trained.
    for layer in model.layers:
        layer.trainable = False

    # Replace the network's original classification head with one sized for
    # our classes. The output of the layer feeding the original head is taken
    # explicitly: calling model.layers.pop() is not reliable, because on Keras
    # versions where `layers` returns a copy it removes nothing and the new
    # layer would be stacked on top of the old softmax output.
    features = model.layers[-2].output
    x = Dense(len(classes), activation="softmax")(features)

    finetuned_model = Model(inputs=model.input, outputs=x)

    # Compile the model with an optimizer, a loss function, and a list of metrics of choice
    finetuned_model.compile(
        optimizer=Adam(lr=0.00001),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Keep the label names on the model object.
    finetuned_model.classes = classes

    # Setting early stopping thresholds to reduce training time
    early_stopping = EarlyStopping(patience=patience)

    # Checkpoint the current best model
    checkpointer = ModelCheckpoint(
        os.path.join(output_model_folder, CHECKPOINT_FILE_NAME),
        verbose=1,
        save_best_only=True,
    )

    # Train it
    finetuned_model.fit_generator(
        batches,
        steps_per_epoch=num_train_steps,
        epochs=epochs,
        callbacks=[early_stopping, checkpointer],
        validation_data=val_batches,
        validation_steps=num_valid_steps,
    )

    finetuned_model.save(os.path.join(output_model_folder, FINAL_FILE_NAME))


Using TensorFlow backend.


In [1]:
# The main training notebook task
def train_on_datasets(
    wf_params,
    train_zips,
    validation_zips,
    model_config_string,
    model_output_path,
    model_blobs,
    model_files_names,
):
    """Download the datasets, fine-tune ResNet50 and publish the model files.

    Working folders are created under ``wf_params.working_directory``, the
    training and validation datasets named in the model config are downloaded
    and arranged into them, ``train_resnet50_model`` is run with the default
    hyper-parameters, and the files it wrote are collected as blobs.

    NOTE(review): ``boto3``, ``flatten_session_sub_path_stream_tuple``,
    ``download_and_arrange_datasets_for_resnet``,
    ``blobs_from_folder_recursive`` and the ``DEFAULT_*`` constants are not
    imported or defined in the cells visible here - confirm they are provided
    elsewhere before running.

    Args:
        wf_params: Workflow parameters; ``working_directory`` and ``logging``
            are read from it.
        train_zips: Blobs forwarded to the download helper for training data.
        validation_zips: Blobs forwarded to the download helper for
            validation data.
        model_config_string: JSON document; its ``train_datasets`` and
            ``validation_datasets`` entries select the data streams.
        model_output_path: Currently unused; only the disabled upload code at
            the end of the function refers to it.
        model_blobs: Output object; ``set`` is called with the model blobs.
        model_files_names: Output object; ``set`` is called with the names of
            the model files.

    Returns:
        The ``(blobs, files_names_list)`` pair produced by
        ``blobs_from_folder_recursive``, so callers may unpack the result.
    """

    def _make_workdir(name):
        # Reserve a named path in the task's working directory and create it;
        # exist_ok=False makes a pre-existing folder an error.
        path = wf_params.working_directory.get_named_tempfile(name)
        Path(path).mkdir(0o777, parents=True, exist_ok=False)
        return path

    metadata_folder = _make_workdir("metadata")
    zips_folder = _make_workdir("zips")
    train_dataset = _make_workdir("train")
    validation_dataset = _make_workdir("validate")
    output_folder = _make_workdir("output")

    model_config = ujson.loads(model_config_string)
    train_streams = flatten_session_sub_path_stream_tuple(
        model_config.get("train_datasets", {})
    )
    validation_streams = flatten_session_sub_path_stream_tuple(
        model_config.get("validation_datasets", {})
    )

    # One S3 client is shared by both downloads.
    s3_client = boto3.client("s3")
    download_and_arrange_datasets_for_resnet(
        data_streams=train_streams,
        zips_folder=zips_folder,
        blobs=train_zips,
        out_path=train_dataset,
        s3_client=s3_client,
        tmp_metadata_folder=metadata_folder,
    )
    download_and_arrange_datasets_for_resnet(
        data_streams=validation_streams,
        zips_folder=zips_folder,
        blobs=validation_zips,
        out_path=validation_dataset,
        s3_client=s3_client,
        tmp_metadata_folder=metadata_folder,
    )

    # TODO: read overrides for some of these values from the model_config.json
    train_resnet50_model(
        train_dataset,
        validation_dataset,
        output_folder,
        logger=wf_params.logging,
        patience=DEFAULT_PATIENCE,
        size=DEFAULT_IMG_SIZE,
        batch_size=DEFAULT_BATCH_SIZE,
        epochs=DEFAULT_EPOCHS,
        weights=DEFAULT_WEIGHTS,
    )

    # save results to Workflow output
    blobs, files_names_list = blobs_from_folder_recursive(output_folder)
    model_blobs.set(blobs)
    model_files_names.set(files_names_list)

    # Disabled: also copy the results to model_output_path.
    #
    # # write results to storage path also
    # for file in files_names_list:
    #     location = model_output_path + file
    #     out_blob = Types.Blob.create_at_known_location(location)
    #
    #     with out_blob as out_writer:
    #         with open(output_folder + "/" + file, mode="rb") as in_reader:
    #             out_writer.write(in_reader.read())
    #
    # # keep the model_config with the trained model
    # location = model_output_path + MODEL_CONFIG_FILE_NAME
    # out_blob = Types.Blob.create_at_known_location(location)
    # with out_blob as out_writer:
    #     out_writer.write((model_config_string).encode("utf-8"))
    #
    # # write metadata to track what execution this was done by
    # location = model_output_path + MODEL_GENERATED_BY_FILE_NAME
    # out_blob = Types.Blob.create_at_known_location(location)
    # with out_blob as out_writer:
    #     out_writer.write((f"workflow_id: {wf_params.execution_id}").encode("utf-8"))

    return blobs, files_names_list

In [None]:
# NOTE(review): train_on_datasets declares seven required positional parameters
# but is called here with none - confirm how the task inputs are meant to be
# supplied (e.g. injected notebook parameters) before relying on this cell.
model_blobs, model_files_names = train_on_datasets()

In [None]:
# Record the two training results under their output names using flytekit's
# record_outputs (imported with the other dependencies in the first cell).
record_outputs({
    "model_blobs": model_blobs,
    "model_files_names": model_files_names,
})