In [None]:
%load_ext pycodestyle_magic

In [None]:
%pycodestyle_on

In [None]:
import os
import shutil
import tensorflow as tf

from pathlib import Path

from google.cloud import storage

In [None]:
# When using Colab, you need to allow it to access your resources on GCP.
# Outside Colab the `google.colab` import fails and authentication is
# skipped. Any other error (e.g. a failed login) is raised instead of
# being silently swallowed.
try:
    from google.colab import auth
except ImportError:
    # Not running in Colab: no interactive authentication is done here.
    pass
else:
    auth.authenticate_user()

# Fill in the project-id of your GCP project and the name of
# the GCS bucket where you want to store the checkpoints.
PROJECT = "my_gcp_project"
BUCKET = "my_gcp_bucket"

# Checkpoint helper functions

This section contains the helper functions for storing and loading TensorFlow/Keras checkpoints on Google Cloud Storage. You can use these functions as a starting point and modify them for your use case.

The only dependency of this boilerplate is the Python client library for Google Cloud Storage; its documentation is available here:

https://googleapis.dev/python/storage/latest/index.html

## Helper functions for local file system manipulation

Helper functions for local file system manipulation. TensorFlow does not support saving a checkpoint in an archived format such as `zip` or `tar`. Archiving checkpoints before sending them to GCS can help reduce the required storage and keeps the file structure on GCS simpler. Archiving checkpoints is optional.

In [None]:
def list_files(directory: str, recursive: bool = True):
    """Lazily list the paths of all regular files inside a directory.

    Args:
        directory (str): Directory whose files are listed.
        recursive (bool, optional): When True, files in nested
        subdirectories are included as well. Defaults to True.

    Returns:
        Generator[str]: File paths as strings; directories are skipped.
    """
    root = Path(directory)
    walk = root.rglob if recursive else root.glob
    return (str(entry) for entry in walk('*') if entry.is_file())


def clear_directory(directory: str):
    """Empty a directory, creating it if it does not exist yet.

    All files and subdirectories inside ``directory`` are deleted and an
    empty directory is left in its place.

    Args:
        directory (str): Path to the directory that will be cleared.
    """
    # A missing directory (e.g. before the first checkpoint on a fresh
    # runtime) is treated as already empty instead of raising.
    if os.path.isdir(directory):
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)


def _archive_base_name(output_file_path: str, archive_format: str) -> str:
    """Strip the extension shutil.make_archive appends for archive_format."""
    for name, extensions, _description in shutil.get_unpack_formats():
        if name != archive_format:
            continue
        # Try longer extensions first so a multi-part extension such as
        # '.tar.gz' is removed as a whole.
        for extension in sorted(extensions, key=len, reverse=True):
            if output_file_path.endswith(extension):
                return output_file_path[:-len(extension)]
    # Unknown format or extension: drop only the last suffix.
    return os.path.splitext(output_file_path)[0]


def zip_directory(directory: str,
                  output_file_path: str,
                  archive_format: str = 'zip'):
    """Archive a directory with all its files and subdirectories.

    Args:
        directory (str): Directory that will be archived.
        output_file_path (str): Path of the archive to create, including
        its extension (e.g. 'checkpoint.zip' or 'checkpoint.tar.gz').
        archive_format (str, optional): Any format accepted by
        shutil.make_archive ('zip', 'tar', 'gztar', ...).
        Defaults to 'zip'.

    Returns:
        str: Path of the archive that was written.
    """
    # shutil.make_archive adds the extension itself, so it has to be
    # stripped from the requested path first.
    base_name = _archive_base_name(output_file_path, archive_format)
    return shutil.make_archive(base_name, archive_format, directory)

## Helper functions for moving files/directories on GCS

Helper functions to move a file or directory on GCS, so that a new file or directory can take its place without losing the previous one. This is useful for always keeping the latest version of a file or directory in a fixed location.

In [None]:
def move_gcs_file(old_file_path: str,
                  new_file_path: str,
                  bucket: storage.bucket.Bucket):
    """Move a file on GCS to a new location.

    Nothing happens when no blob exists at ``old_file_path``. This makes it
    safe to call before the first checkpoint has been written.

    Args:
        old_file_path (str): Path to the original file location.
        new_file_path (str): Path to the new location of the file.
        bucket (bucket): GCS bucket that contains the file.
    """
    blob = bucket.blob(old_file_path)
    if blob.exists():
        bucket.rename_blob(blob, new_file_path)


def move_gcs_directory(old_directory_path: str,
                       new_directory_path: str,
                       bucket: storage.bucket.Bucket):
    """Move a directory on GCS (every blob below it) to a new location.

    Only blobs inside ``old_directory_path`` are moved. A sibling such as
    'ckpt-2/...' is not picked up when moving 'ckpt'. Only the leading
    directory part of each blob name is rewritten.

    Args:
        old_directory_path (str): Path to the original directory location.
        new_directory_path (str): Path to the new location of the directory.
        bucket (bucket): GCS bucket that contains the directory.
    """
    old_prefix = old_directory_path
    if old_prefix and not old_prefix.endswith('/'):
        old_prefix += '/'
    new_prefix = new_directory_path
    if new_prefix and not new_prefix.endswith('/'):
        new_prefix += '/'

    # Materialise the listing first: renaming while paging through the
    # results could re-list blobs whose new name still matches old_prefix.
    blobs = list(bucket.list_blobs(prefix=old_prefix))
    for blob in blobs:
        relative_name = blob.name[len(old_prefix):]
        bucket.rename_blob(blob, new_prefix + relative_name)

## Helper functions to download files/directories from GCS

Helper functions to download files, directories and archived files from GCS to the local file system; these can be used to restore a checkpoint from GCS.

In [None]:
def download_gcs_file(gcs_file_path: str,
                      local_file_path: str,
                      bucket: storage.bucket.Bucket):
    """Fetch a single blob from GCS and write it to a local path.

    Args:
        gcs_file_path (str): Name of the blob inside ``bucket``.
        local_file_path (str): Destination path on the local file system.
        bucket (bucket): GCS bucket holding the blob.
    """
    bucket.blob(gcs_file_path).download_to_filename(local_file_path)


def download_gcs_directory(gcs_directory_path: str,
                           local_directory_path: str,
                           bucket: storage.bucket.Bucket):
    """Download a directory with all its subdirectories from GCS.

    The directory structure below ``gcs_directory_path`` is recreated
    inside ``local_directory_path``.

    Args:
        gcs_directory_path (str): Path to the directory on GCS.
        local_directory_path (str): Local path where the
        directory will be stored.
        bucket (bucket): GCS bucket where the directory is stored.
    """
    prefix = gcs_directory_path
    if prefix and not prefix.endswith('/'):
        # Restrict the listing to the directory itself, so e.g.
        # 'checkpoint-2' does not also match 'checkpoint-20/...' or
        # 'checkpoint-2.zip'.
        prefix += '/'
    local_root = Path(local_directory_path)

    for blob in bucket.list_blobs(prefix=prefix):
        # Only the leading directory part of the name is remapped.
        relative_path = blob.name[len(prefix):]
        # Names ending in '/' are directory placeholders (e.g. created via
        # the Cloud Console); they cannot be written to a file path.
        if not relative_path or relative_path.endswith('/'):
            continue
        local_file_path = local_root / relative_path
        # Recreate the directory structure if necessary.
        local_file_path.parent.mkdir(parents=True, exist_ok=True)
        blob.download_to_filename(str(local_file_path))


def download_gcs_archive(gcs_file_path: str,
                         local_file_path: str,
                         unpack_directory: str,
                         bucket: storage.bucket.Bucket):
    """Fetch an archive from GCS and extract it locally.

    Args:
        gcs_file_path (str): Name of the archive blob inside ``bucket``.
        local_file_path (str): Where the downloaded archive is written;
        shutil.unpack_archive infers the format from its extension.
        unpack_directory (str): Directory the archive contents are
        extracted into.
        bucket (bucket): GCS bucket holding the archive.
    """
    archive_path = local_file_path
    download_gcs_file(gcs_file_path=gcs_file_path,
                      local_file_path=archive_path,
                      bucket=bucket)
    shutil.unpack_archive(filename=archive_path,
                          extract_dir=unpack_directory)

## Helper functions to send files/directories to GCS

Helper functions to store local files, directories and archived files on GCS; these can be used to back up a checkpoint.

In [None]:
def write_file_to_gcs(local_file_path: str,
                      gcs_file_path: str,
                      bucket: storage.bucket.Bucket):
    """Upload one local file to GCS.

    Uploading onto an existing blob is governed by the bucket's
    “versioning” and “lifecycle” policies; without such policies the
    existing contents are overwritten.
    (https://googleapis.dev/python/storage/latest/blobs.html)

    Args:
        local_file_path (str): Path of the local file to upload.
        gcs_file_path (str): Blob name under which the file is stored.
        bucket (bucket): GCS bucket that receives the file.
    """
    source_path = Path(local_file_path).absolute()
    destination = bucket.blob(gcs_file_path)
    destination.upload_from_filename(source_path)


def write_directory_to_gcs(local_directory_path: str,
                           gcs_directory_path: str,
                           bucket: storage.bucket.Bucket,
                           recursive: bool = True):
    """Write a directory (optionally all its subdirectories) to GCS.

    Each file keeps its path relative to ``local_directory_path`` below
    ``gcs_directory_path``. The effect of uploading to an existing blob
    depends on the “versioning” and “lifecycle” policies defined on the
    blob’s bucket. In the absence of those policies, upload will overwrite
    any existing contents.
    (https://googleapis.dev/python/storage/latest/blobs.html)

    Args:
        local_directory_path (str): Local path to the directory
        that will be sent to GCS.
        gcs_directory_path (str): Path on the GCS bucket where
        the directory will be stored.
        bucket (bucket): GCS bucket where the directory will be stored.
        recursive (bool, optional): Recursively search
        subdirectories. Defaults to True.
    """
    local_root = Path(local_directory_path)
    gcs_prefix = gcs_directory_path
    if gcs_prefix and not gcs_prefix.endswith('/'):
        gcs_prefix += '/'

    for local_file_path in list_files(local_directory_path, recursive):
        # Map by path structure instead of string replacement. Path drops
        # a leading './' or trailing '/' from the argument, and replace()
        # would also rewrite matches deeper inside the path.
        relative_path = Path(local_file_path).relative_to(local_root)
        gcs_file_path = gcs_prefix + relative_path.as_posix()
        blob = bucket.blob(gcs_file_path)
        blob.upload_from_filename(local_file_path)

In [None]:
def save_tf_model_archive(model_state: object, local_checkpoint: str):
    """Store the model state as an archived checkpoint at a local path.

    This is a placeholder that you must implement for your use case, as
    there are many different ways to serialise a model. One possible
    approach (not provided here): save the model into a temporary
    directory, then archive that directory to ``local_checkpoint`` with
    ``zip_directory``.

    Args:
        model_state (object): Object holding everything that must be
        checkpointed (e.g. a tf.keras.Model).
        local_checkpoint (str): Local path where the archived checkpoint
        must be written; save_tf_checkpoint uploads this file to GCS.

    Raises:
        NotImplementedError: Always, until an implementation is provided.
    """
    raise NotImplementedError(
        "save_tf_model_archive must be implemented to write the "
        f"checkpoint archive to {local_checkpoint!r}."
    )


def save_tf_checkpoint(model_state: object,
                       local_directory: str,
                       local_checkpoint: str,
                       gcs_latest_checkpoint: str,
                       gcs_checkpoint_store: str,
                       bucket: storage.bucket.Bucket):
    """Send a checkpoint in archived form to GCS.

    Moves the current latest checkpoint on GCS to a different, specified
    location, then uploads the new checkpoint as the latest one. Use your
    own method (save_tf_model_archive) to store the model state in an
    archive.

    Args:
        model_state (object): Object that contains all
        checkpoint data that needs to be saved.
        local_directory (str): Local path to the directory
        where the checkpoint will be saved.
        local_checkpoint (str): Local path where
        the checkpoint will be saved.
        gcs_latest_checkpoint (str): Path where the latest
        checkpoint is saved on GCS.
        gcs_checkpoint_store (str): Path where the current
        last checkpoint on GCS will be moved to.
        bucket (bucket): GCS bucket where the checkpoint will be stored.
    """
    # Build the new checkpoint locally *before* touching GCS. If saving
    # fails, the current latest checkpoint on GCS stays in place instead of
    # having been moved away with nothing uploaded in its place.
    clear_directory(local_directory)
    save_tf_model_archive(model_state, local_checkpoint)

    # Keep the previous latest checkpoint under gcs_checkpoint_store, then
    # upload the new archive to the fixed "latest" location.
    move_gcs_file(gcs_latest_checkpoint, gcs_checkpoint_store, bucket)
    write_file_to_gcs(local_checkpoint, gcs_latest_checkpoint, bucket)


## TensorFlow example functions to load a checkpoint

In [None]:
def load_tf_checkpoint_zip(gcs_latest_checkpoint: str,
                           arhived_checkpoint: str,
                           local_checkpoint: str,
                           bucket: storage.bucket.Bucket):
    """Load an archived checkpoint from GCS.

    The archive is downloaded to ``arhived_checkpoint`` and unpacked into
    ``local_checkpoint``. The model is then loaded from that directory.

    Args:
        gcs_latest_checkpoint (str): Path to the archived checkpoint on GCS.
        arhived_checkpoint (str): Local path to save the archived checkpoint
        (the misspelled name is kept so keyword callers keep working).
        local_checkpoint (str): Local directory the archive is unpacked
        into and the model is loaded from.
        bucket (bucket): GCS bucket that contains the checkpoint.

    Returns:
        tf.keras.Model: Loaded model, ready to continue training.
    """
    download_gcs_archive(
        gcs_latest_checkpoint, arhived_checkpoint,
        local_checkpoint, bucket)
    model = tf.keras.models.load_model(local_checkpoint)
    return model


def load_tf_checkpoint_directory(gcs_latest_checkpoint: str,
                                 local_checkpoint: str,
                                 bucket: storage.bucket.Bucket):
    """Load a checkpoint that is saved as a directory on GCS.

    Args:
        gcs_latest_checkpoint (str): Path to the checkpoint directory on GCS.
        local_checkpoint (str): Path where the checkpoint will be saved
        locally and loaded from.
        bucket (bucket): GCS bucket that contains the checkpoint.

    Returns:
        tf.keras.Model: Loaded model, ready to continue training.
    """
    # Download into local_checkpoint, the same directory load_model reads
    # from below.
    download_gcs_directory(
        gcs_latest_checkpoint,
        local_checkpoint,
        bucket)
    model = tf.keras.models.load_model(local_checkpoint)
    return model

# Training example

We will train a simple MLP classifier on the MNIST dataset. We start from a given checkpoint on GCS and continue training. After each epoch, a checkpoint is created and sent to GCS. You can use this example as a guideline for structuring your own code around this boilerplate.

Load the dataset.

In [None]:
# Fetch the MNIST train/test splits via Keras.
mnist_dataset = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist_dataset.load_data()

# Divide by 255 so pixel values (presumably uint8 in [0, 255]) end up
# in [0, 1].
x_train_normalized, x_test_normalized = x_train / 255, x_test / 255

Create a model and optimizer.

In [None]:
# Small MLP: flatten each sample, two ReLU Dense layers of 128 units with
# L2 kernel regularisation, and a 10-unit softmax output layer.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=x_train_normalized.shape[1:]),
    tf.keras.layers.Dense(
        units=128,
        activation=tf.keras.activations.relu,
        kernel_regularizer=tf.keras.regularizers.l2(0.002)
    ),
    tf.keras.layers.Dense(
        units=128,
        activation=tf.keras.activations.relu,
        kernel_regularizer=tf.keras.regularizers.l2(0.002)
    ),
    tf.keras.layers.Dense(
        units=10,
        activation=tf.keras.activations.softmax
    ),
])

adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Integer class labels -> sparse categorical cross-entropy.
model.compile(
    optimizer=adam_optimizer,
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy']
)

Train the model and save intermediate checkpoints to GCS.

In [None]:
# Use the PROJECT / BUCKET configured at the top of the notebook instead of
# hardcoding a specific project and bucket here.
storage_client = storage.Client(project=PROJECT)
bucket = storage_client.get_bucket(BUCKET)

# Local directory checkpoints are downloaded into (and cleared before
# every save); the archive itself is written to local_checkpoint.
local_directory = "checkpoint-buffer"
gcs_directory = "tf-mnist-checkpoints"
checkpoint_name = "latest-checkpoint.zip"
local_checkpoint = checkpoint_name
gcs_latest_checkpoint = f"{gcs_directory}/{checkpoint_name}"

# Make sure the local buffer exists so later cells work on a fresh runtime.
os.makedirs(local_directory, exist_ok=True)

Load a checkpoint from GCS.

In [None]:
# Resume training from a checkpoint directory on GCS when one exists.
# Otherwise keep the freshly built model from the cell above, so the
# notebook also runs end-to-end against an empty bucket.
resume_checkpoint = f"{gcs_directory}/checkpoint-2"
# The trailing '/' keeps e.g. 'checkpoint-2.zip' from counting as a match.
first_blob = next(
    iter(bucket.list_blobs(prefix=f"{resume_checkpoint}/", max_results=1)),
    None,
)
if first_blob is not None:
    download_gcs_directory(resume_checkpoint, local_directory, bucket)
    model = tf.keras.models.load_model(local_directory)

In [None]:
epochs = 5
for epoch in range(epochs):
    training_history = model.fit(
        x_train_normalized,
        y_train,
        epochs=1,
        validation_data=(x_test_normalized, y_test),
    )

    # Save the model. The new state becomes the latest checkpoint on GCS,
    # and the checkpoint that was "latest" before this save is kept as
    # checkpoint-{epoch}.zip.
    # NOTE: save_tf_model_archive must be implemented first; as shipped it
    # raises NotImplementedError.
    save_tf_checkpoint(
        model,
        local_directory,
        local_checkpoint,
        gcs_latest_checkpoint,
        f"{gcs_directory}/checkpoint-{epoch}.zip",
        bucket
    )