In [None]:
# Install / upgrade KerasCV and TensorFlow in the notebook environment
# (-q: quiet output, -U: upgrade to the newest available version).
!pip install -q -U keras_cv
!pip install -q -U tensorflow 

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/634.9 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m286.7/634.9 KB[0m [31m8.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m634.9/634.9 KB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import tensorflow as tf 

from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder

You do not have Waymo Open Dataset installed, so KerasCV Waymo metrics are not available.


## Basic Cluster Setup
* Install Runhouse and latest SkyPilot version
* Set up LambdaLabs or BYO credentials
* Instantiate and launch cluster

In [None]:
# Install Runhouse from PyPI and SkyPilot directly from its GitHub repository.
!pip install runhouse
!pip install git+https://github.com/skypilot-org/skypilot.git

In [None]:
import runhouse as rh

INFO | 2023-03-10 04:26:46,264 | No auth token provided, so not using RNS API to save and load configs


### Option 1: On Demand Cluster that spins up/down for you

In [None]:
# Check which cloud credentials SkyPilot can find; for every cloud that is not
# enabled it prints the commands needed to set it up.
# Skip this cell if you are bringing your own cluster.
!sky check

In [None]:
# For Lambda Labs
# First get your API key from https://cloud.lambdalabs.com/api-keys
# and create the file lambda_keys with the following line
# api_key = [YOUR API KEY]

!mkdir ~/.lambda_cloud/
!mv lambda_keys ~/.lambda_cloud/lambda_keys
!sky check

[33mSkyPilot collects usage data to improve its services. `setup` and `run` commands are not collected to ensure privacy.
Usage logging can be disabled by setting the environment variable SKYPILOT_DISABLE_USAGE_COLLECTION=1.[0m
Checking credentials to enable clouds for SkyPilot.
  [31m[1mAWS: disabled[0m          
    Reason: AWS credentials are not set. Run the following commands:
      $ pip install boto3
      $ aws configure
    For more info: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-quickstart.html
  [31m[1mAzure: disabled[0m          
    Reason: ~/.azure/msal_token_cache.json does not exist. Run the following commands:
      $ az login
      $ az account set -s <subscription_id>
    For more info: https://docs.microsoft.com/en-us/cli/azure/get-started-with-azure-cli
  [31m[1mGCP: disabled[0m          
    Reason: GCP tools are not installed or credentials are not set. Run the following commands:
      $ pip install google-api-python-client
      

In [None]:
# Launch on-demand Lambda cluster: define a cluster named 'rh-a100' with instance
# type 'A100:1' on the 'lambda' provider, then bring it up.
# NOTE(review): `up_if_not` presumably does nothing when the cluster is already up — confirm.
gpu = rh.cluster(name='rh-a100', instance_type='A100:1', provider='lambda')
gpu.up_if_not()

# set amount of time (min) of inactivity to shut down cluster, or -1 to keep up indefinitely (Default: 30 min)
# With -1 the cluster has to be terminated explicitly when you are done with it.
gpu.autostop_mins = -1

### Option 2: Bring-your-own cluster, by passing in IPs and SSH creds

In [None]:
# Uncomment for bring-your-own cluster. This can be a cluster spun up by Lambda Labs
# gpu = rh.cluster(name='byo-lambda', ips=['<ip_address>'],
#                  ssh_creds={'ssh_user':'ubuntu', 'ssh_private_key': '~/.ssh/id_rsa'})

INFO | 2023-03-09 14:48:24,824 | Running command on byo-lambda: ray start --head
INFO | 2023-03-09 14:48:29,587 | Running command on byo-lambda: mkdir -p ~/.rh; touch ~/.rh/cluster_config.yaml; echo '{"name": "~/byo-lambda", "resource_type": "cluster", "resource_subtype": "Cluster", "ips": ["132.145.193.245"], "ssh_creds": {"ssh_user": "ubuntu", "ssh_private_key": "~/.ssh/id_rsa"}}' > ~/.rh/cluster_config.yaml


### Set up Tensorflow with GPU Support

Install TensorFlow on the cluster and check that it can see the GPU.

In [None]:
# Shell steps executed on the cluster, in order:
#   1. install the CUDA toolkit and cuDNN into the conda environment,
#   2. create the conda activation-hook directory,
#   3. write a hook that appends the conda lib directory to LD_LIBRARY_PATH,
#   4. install TensorFlow with pip. pip must be invoked as a module
#      (`python3 -m pip`); `python3 pip ...` would try to run a script named "pip".
command = "conda install -y -c conda-forge cudatoolkit=11.2.2 cudnn=8.1.0; \
            mkdir -p $CONDA_PREFIX/etc/conda/activate.d; \
            echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/' > $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh; \
            python3 -m pip install tensorflow"
gpu.run([command])
gpu.restart_grpc_server()  # restart server to load env variables set above

In [None]:
# Sanity check: run Python on the cluster and print the GPU devices TensorFlow reports.
gpu.run(['python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices(\'GPU\'))"'])

INFO | 2023-03-10 04:36:46,362 | Running command on rh-a100: python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
2023-03-10 04:36:47.825877: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-10 04:36:48.543604: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/ubuntu/miniconda3/lib/:/home/ubuntu/miniconda3/lib/
2023-03-10 04:36:48.543662: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: ca

[(0,
  '')]

## Dreambooth Setup

### Download the instance and class images

In [None]:
# Download the instance-image and class-image archives and extract them
# (untar=True). The value displayed by this cell is the return value of the
# second call, i.e. the location of the class images.
tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz",
    untar=True
)
tf.keras.utils.get_file(
    origin="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz",
    untar=True
)

Downloading data from https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/instance-images.tar.gz
Downloading data from https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/class-images.tar.gz


'/root/.keras/datasets/class-images'

In [None]:
# Home-relative locations of the downloaded images. The same '~'-relative path
# is used as the source on this machine and as the destination on the cluster.
instance_images_root = '~/.keras/datasets/instance-images'
class_images_root = '~/.keras/datasets/class-images'

# sync images to the cluster using Runhouse
rh.folder(path=instance_images_root).to(system=gpu, path=instance_images_root)
rh.folder(path=class_images_root).to(system=gpu, path=class_images_root)

INFO | 2023-03-10 04:36:54,461 | Creating new file folder: /root/.keras/datasets/instance-images
INFO | 2023-03-10 04:36:54,474 | Copying folder from file:///root/.keras/datasets/instance-images to: rh-a100, with path: ~/.keras/datasets/instance-images
INFO | 2023-03-10 04:36:54,476 | Creating new ssh folder: .keras/datasets/instance-images
INFO | 2023-03-10 04:36:54,536 | Opening SSH connection to 150.136.66.243, port 22
INFO | 2023-03-10 04:36:54,563 | [conn=0] Connected to SSH server at 150.136.66.243, port 22
INFO | 2023-03-10 04:36:54,565 | [conn=0]   Local address: 172.28.0.12, port 59042
INFO | 2023-03-10 04:36:54,566 | [conn=0]   Peer address: 150.136.66.243, port 22
INFO | 2023-03-10 04:36:54,675 | [conn=0] Beginning auth for user ubuntu
INFO | 2023-03-10 04:36:54,776 | [conn=0] Auth for user ubuntu succeeded
INFO | 2023-03-10 04:36:54,780 | [conn=0, chan=0] Requesting new SSH session
INFO | 2023-03-10 04:36:55,181 | [conn=0, chan=0]   Subsystem: sftp
INFO | 2023-03-10 04:36:5

<runhouse.rns.folders.folder.Folder at 0x7f91943f9190>

In [None]:
def get_image_paths(folder):
    """Return the paths of all entries directly inside ``folder``, sorted by name.

    Args:
        folder: Directory to list. A leading ``~`` is expanded to the home
            directory of the machine the function runs on.

    Returns:
        A list of ``str`` paths, one per directory entry (no filtering by file
        type is done), each prefixed with the expanded folder path. The list is
        empty when the folder is empty.

    Raises:
        FileNotFoundError: If the expanded folder does not exist.
    """
    # Imports stay inside the function so that it is self-contained when it is
    # written out and sent to the cluster.
    from pathlib import Path
    import os

    abs_folder = Path(folder).expanduser()
    # os.listdir gives entries in arbitrary order; sort them so repeated runs
    # (and different machines) produce the same list.
    return [os.path.join(abs_folder, name) for name in sorted(os.listdir(abs_folder))]

# Get image paths on the cluster: wrap the function with Runhouse and send it to
# `gpu`, so that calling `get_image_paths_gpu` lists the cluster's filesystem.
get_image_paths_gpu = rh.function(fn=get_image_paths).to(system=gpu)

INFO | 2023-03-10 04:36:59,351 | Writing out function function to /content/get_image_paths_fn.py as functions serialized in notebooks are brittle. Please make sure the function does not rely on any local variables, including imports (which should be moved inside the function body).
INFO | 2023-03-10 04:36:59,356 | Setting up Function on cluster.
INFO | 2023-03-10 04:36:59,359 | Creating new file folder: /content
INFO | 2023-03-10 04:36:59,362 | Copying local package content to cluster <rh-a100>
INFO | 2023-03-10 04:36:59,365 | Creating new ssh folder: content
INFO | 2023-03-10 04:37:02,209 | Installing packages on cluster rh-a100: ['./']
INFO | 2023-03-10 04:37:02,267 | Function setup complete.


In [None]:
# Both calls execute on the cluster, so the returned paths refer to files on
# the cluster, not on this machine.
instance_image_paths = get_image_paths_gpu(instance_images_root)
class_image_paths = get_image_paths_gpu(class_images_root)

# Show the first few paths of each kind as a quick check.
print(instance_image_paths[:5])
print(class_image_paths[:5])

INFO | 2023-03-10 04:37:05,146 | Running get_image_paths via gRPC
INFO | 2023-03-10 04:37:05,199 | Time to send message: 0.05 seconds
INFO | 2023-03-10 04:37:05,200 | Running get_image_paths via gRPC
INFO | 2023-03-10 04:37:05,306 | Time to send message: 0.1 seconds
['/home/ubuntu/.keras/datasets/instance-images/alvan-nee-bQaAJCbNq3g-unsplash.jpeg', '/home/ubuntu/.keras/datasets/instance-images/alvan-nee-eoqnr8ikwFE-unsplash.jpeg', '/home/ubuntu/.keras/datasets/instance-images/alvan-nee-9M0tSjb-cpA-unsplash.jpeg', '/home/ubuntu/.keras/datasets/instance-images/alvan-nee-brFsZ7qszSY-unsplash.jpeg', '/home/ubuntu/.keras/datasets/instance-images/alvan-nee-Id1DBHv4fbg-unsplash.jpeg']
['/home/ubuntu/.keras/datasets/class-images/cae1100cdc58a2436697ba178cd3deaed0b43064.jpg', '/home/ubuntu/.keras/datasets/class-images/9c54d0af0a22d05914b5894b55817fa33eac80d4.jpg', '/home/ubuntu/.keras/datasets/class-images/27b9d79bdc218d483c365e40961b77106d986815.jpg', '/home/ubuntu/.keras/datasets/class-image

### Prepare captions

In [None]:
# Cycle through the instance images so that there is exactly one instance image
# for every class image.
num_instance_images = len(instance_image_paths)
new_instance_image_paths = [
    instance_image_paths[position % num_instance_images]
    for position in range(len(class_image_paths))
]

# Every image of a given kind shares one prompt / caption, repeated per image.
unique_id = "sks"
class_label = "dog"

instance_prompt = f"a photo of {unique_id} {class_label}"
class_prompt = f"a photo of {class_label}"

instance_prompts = [instance_prompt for _ in new_instance_image_paths]
class_prompts = [class_prompt for _ in class_image_paths]

In [None]:
# tokenize the text
import numpy as np
import itertools

padding_token = 49407
max_prompt_length = 77
tokenizer = SimpleTokenizer()

def process_text(caption):
    """Encode ``caption`` with the tokenizer and right-pad it with ``padding_token``.

    Padding is added until the sequence has ``max_prompt_length`` entries; a
    caption that already encodes to that many tokens or more gets no padding.
    Returns the token ids as a NumPy array.
    """
    token_ids = tokenizer.encode(caption)
    pad_count = max_prompt_length - len(token_ids)
    return np.array(token_ids + [padding_token] * pad_count)

# One row of token ids per prompt: instance prompts first, class prompts after.
# np.empty uses its default dtype here, so the ids are stored as floats.
all_prompts = list(itertools.chain(instance_prompts, class_prompts))
tokenized_texts = np.empty((len(all_prompts), max_prompt_length))
for row, prompt in enumerate(all_prompts):
    tokenized_texts[row] = process_text(prompt)

Downloading data from https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true


In [None]:
# Computing the text embeddings once, ahead of training, saves memory during
# training because the text encoder does not have to be kept around.
# This function is meant to run on a machine with a GPU.
def encode_text(tokenized_texts, max_prompt_length):
    """Embed tokenized prompts with the Stable Diffusion text encoder.

    Args:
        tokenized_texts: Array-like of token ids, one row per prompt.
        max_prompt_length: Prompt length used to build the text encoder and
            the position ids ``0 .. max_prompt_length - 1``.

    Returns:
        The text encoder output converted to a NumPy array.

    Raises:
        IndexError: If TensorFlow reports no logical GPU device.
    """
    import tensorflow as tf
    from keras_cv.models.stable_diffusion.text_encoder import TextEncoder

    position_ids = tf.convert_to_tensor([list(range(max_prompt_length))], dtype=tf.int32)
    encoder = TextEncoder(max_prompt_length)

    first_gpu = tf.config.list_logical_devices("GPU")[0]

    # Pin the forward pass to the first GPU.
    with tf.device(first_gpu.name):
        token_tensor = tf.convert_to_tensor(tokenized_texts)
        embeddings = encoder([token_tensor, position_ids], training=False)
        embedded_text = embeddings.numpy()

    # Drop the reference to the encoder before returning.
    del encoder
    return embedded_text

In [None]:
# send function to be run on GPU defined above; the packages listed in `reqs`
# are installed on the cluster as part of the function setup.
encode_text_gpu = rh.function(fn=encode_text, system=gpu, reqs=['tensorflow', 'keras_cv', 'imutils', 'opencv-python'])
# Runs on the cluster; the embeddings come back to the notebook as `embedded_text`.
embedded_text = encode_text_gpu(tokenized_texts, max_prompt_length)

INFO | 2023-03-10 05:03:54,468 | Writing out function function to /content/encode_text_fn.py as functions serialized in notebooks are brittle. Please make sure the function does not rely on any local variables, including imports (which should be moved inside the function body).
INFO | 2023-03-10 05:03:54,475 | Setting up Function on cluster.
INFO | 2023-03-10 05:03:54,479 | Creating new file folder: /content
INFO | 2023-03-10 05:03:54,485 | Copying local package content to cluster <rh-a100>
INFO | 2023-03-10 05:03:54,487 | Creating new ssh folder: content
INFO | 2023-03-10 05:03:55,490 | Installing packages on cluster rh-a100: ['tensorflow', 'keras_cv', 'imutils', 'opencv-python', './']
INFO | 2023-03-10 05:04:01,325 | Function setup complete.
INFO | 2023-03-10 05:04:01,327 | Running encode_text via gRPC
INFO | 2023-03-10 05:04:17,059 | Time to send message: 15.73 seconds


In [None]:
# First axis: one entry per prompt (instance prompts followed by class prompts);
# second axis: max_prompt_length. The last axis is presumably the text-embedding
# width — confirm against the TextEncoder.
embedded_text.shape

(400, 77, 768)

### Prepare the images

In [None]:
def assemble_dataset(instance_paths, class_paths, embedded_texts, save_path, batch_size=1):
    """Build the paired instance/class training dataset and save it to disk.

    Two `tf.data` pipelines are built, one for the instance images and one for
    the class images, and zipped so that every element is a pair
    ``(instance_batch, class_batch)`` of dictionaries. The zipped dataset is
    written with `tf.data.Dataset.save`.

    Args:
        instance_paths: Paths of the instance images.
        class_paths: Paths of the class images.
        embedded_texts: Text embeddings; the first ``len(instance_paths)`` rows
            are paired with the instance images and the remaining rows with
            the class images.
        save_path: Where to save the dataset; a leading ``~`` is expanded.
        batch_size: Batch size applied to each of the two pipelines.

    Returns:
        The expanded save path as a string.
    """
    # Imports are local so that the function is self-contained when it is sent
    # to the cluster.
    import keras_cv
    import tensorflow as tf
    import os  # NOTE(review): `os` appears to be unused in this function.
    from pathlib import Path

    resolution = 512
    auto = tf.data.AUTOTUNE

    # Batch-level augmentation: center crop to resolution x resolution, random
    # flip, then rescale with x / 127.5 - 1 (which maps 0..255 onto -1..1).
    augmenter = keras_cv.layers.Augmenter(
        layers=[
            keras_cv.layers.CenterCrop(resolution, resolution),
            keras_cv.layers.RandomFlip(),
            tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
        ]
    )


    def process_image(image_path, tokenized_text):
        # Read and decode one image as 3 channels, resize it to
        # (resolution, resolution), and pass the second element through
        # untouched. Despite its name, `tokenized_text` receives a slice of
        # `embedded_texts` here.
        # NOTE(review): the image paths listed earlier end in .jpeg/.jpg —
        # confirm that decode_png is the intended decoder for them.
        image = tf.io.read_file(image_path)
        image = tf.io.decode_png(image, 3)
        image = tf.image.resize(image, (resolution, resolution))
        return image, tokenized_text


    def apply_augmentation(image_batch, embedded_tokens):
        # Augment the image batch only; the embeddings are passed through.
        return augmenter(image_batch), embedded_tokens


    def prepare_dict(instance_only=True):
        # Returns a mapping function that packs a batch into a dictionary whose
        # keys are prefixed with "instance_" or "class_".
        def fn(image_batch, embedded_tokens):
            if instance_only:
                batch_dict = {
                    "instance_images": image_batch,
                    "instance_embedded_texts": embedded_tokens,
                }
                return batch_dict
            else:
                batch_dict = {
                    "class_images": image_batch,
                    "class_embedded_texts": embedded_tokens,
                }
                return batch_dict
        return fn


    def assemble(image_paths, embedded_texts, instance_only, batch_size):  
        # Pipeline order: load -> shuffle (buffer of 5, reshuffled on every
        # iteration) -> batch -> augment -> pack into a dictionary.
        dataset = tf.data.Dataset.from_tensor_slices(
            (image_paths, embedded_texts)
        )
        dataset = dataset.map(process_image, num_parallel_calls=auto)
        dataset = dataset.shuffle(5, reshuffle_each_iteration=True)
        dataset = dataset.batch(batch_size)
        dataset = dataset.map(apply_augmentation, num_parallel_calls=auto)

        prepare_dict_fn = prepare_dict(instance_only=instance_only)
        dataset = dataset.map(prepare_dict_fn, num_parallel_calls=auto)
        return dataset
    
    # Split the embeddings by position: instance rows first, class rows after.
    instance_dataset = assemble(instance_paths, embedded_texts[:len(instance_paths)], True, batch_size)
    class_dataset = assemble(class_paths, embedded_texts[len(instance_paths):], False, batch_size)
    train_dataset = tf.data.Dataset.zip((instance_dataset, class_dataset))

    # Expand '~' on the machine that runs this function and save the dataset there.
    abs_path = str(Path(save_path).expanduser())
    tf.data.Dataset.save(train_dataset, abs_path)
    return abs_path

In [None]:
# Send `assemble_dataset` to the cluster and build the dataset there. The call
# returns the '~'-expanded save path on the cluster, which is kept in
# `train_dataset_path`.
assemble_dataset_gpu = rh.function(fn=assemble_dataset).to(system=gpu)
save_data_path = '~/.keras/datasets/train_dataset'
train_dataset_path = assemble_dataset_gpu(new_instance_image_paths, class_image_paths, embedded_text, save_data_path)

INFO | 2023-03-10 04:38:14,844 | Writing out function function to /content/assemble_dataset_fn.py as functions serialized in notebooks are brittle. Please make sure the function does not rely on any local variables, including imports (which should be moved inside the function body).
INFO | 2023-03-10 04:38:14,853 | Setting up Function on cluster.
INFO | 2023-03-10 04:38:14,857 | Creating new file folder: /content
INFO | 2023-03-10 04:38:14,859 | Copying local package content to cluster <rh-a100>
INFO | 2023-03-10 04:38:14,863 | Creating new ssh folder: content
INFO | 2023-03-10 04:38:15,211 | Installing packages on cluster rh-a100: ['./']
INFO | 2023-03-10 04:38:15,271 | Function setup complete.
INFO | 2023-03-10 04:38:15,275 | Running assemble_dataset via gRPC
INFO | 2023-03-10 04:38:29,864 | Time to send message: 14.48 seconds


## Dreambooth Training

In [None]:
# To be run on GPU
def train_dreambooth(resolution, max_prompt_length, use_mp, opt_args, dataset_path, ckpt_path,
                     max_train_steps=800):
    """Fine-tune the Stable Diffusion diffusion model with DreamBooth.

    Training runs on the first logical GPU reported by TensorFlow. All imports
    are local so that the function is self-contained when it is sent to the
    cluster.

    Args:
        resolution: Passed as both image dimensions to `ImageEncoder` and
            `DiffusionModel`.
        max_prompt_length: Prompt length passed to `DiffusionModel`.
        use_mp: Whether to train with mixed precision. Selects the global Keras
            dtype policy and enables loss scaling in the training step.
        opt_args: Dict with the AdamW settings 'lr', 'weight_decay', 'beta_1',
            'beta_2' and 'epsilon'.
        dataset_path: Path of a dataset written with `tf.data.Dataset.save`.
            Every element must be a pair ``(instance_batch, class_batch)`` of
            dicts with the keys "instance_images" / "instance_embedded_texts"
            and "class_images" / "class_embedded_texts".
        ckpt_path: File to which the diffusion model weights are checkpointed.
            A leading ``~`` is expanded and the parent directory is created.
        max_train_steps: Target number of update steps; the number of epochs is
            ``ceil(max_train_steps / steps_per_epoch)``.

    Returns:
        The absolute checkpoint path.

    Raises:
        ValueError: If the loaded dataset has no known, positive cardinality.
        IndexError: If TensorFlow reports no logical GPU device.
    """
    import math
    import os
    import tensorflow as tf
    import tensorflow.experimental.numpy as tnp

    from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
    from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
    from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler

    # Follow `use_mp` instead of always enabling mixed precision. The policy is
    # set explicitly in both cases because it is global to the process, and the
    # process that runs this function may outlive a single call.
    tf.keras.mixed_precision.set_global_policy("mixed_float16" if use_mp else "float32")

    class DreamBoothTrainer(tf.keras.Model):
        """Keras model wrapper that implements the DreamBooth training step.

        Only the diffusion model is trained; the VAE is frozen.
        """

        # Reference: https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
        def __init__(
            self,
            diffusion_model,
            vae,
            noise_scheduler,
            use_mixed_precision=False,
            prior_loss_weight=1.0,
            max_grad_norm=1.0,
            **kwargs
        ):
            super().__init__(**kwargs)

            self.diffusion_model = diffusion_model
            self.vae = vae
            self.noise_scheduler = noise_scheduler
            self.prior_loss_weight = prior_loss_weight
            self.max_grad_norm = max_grad_norm

            self.use_mixed_precision = use_mixed_precision
            self.vae.trainable = False

        def train_step(self, inputs):
            """Run one update on a pair (instance_batch, class_batch)."""
            instance_batch = inputs[0]
            class_batch = inputs[1]

            instance_images = instance_batch["instance_images"]
            instance_embedded_text = instance_batch["instance_embedded_texts"]
            class_images = class_batch["class_images"]
            class_embedded_text = class_batch["class_embedded_texts"]

            # Instance samples first, class samples second: `compute_loss`
            # relies on this order when it splits the batch in two.
            images = tf.concat([instance_images, class_images], 0)
            embedded_texts = tf.concat([instance_embedded_text, class_embedded_text], 0)
            batch_size = tf.shape(images)[0]

            with tf.GradientTape() as tape:
                # Project image into the latent space and sample from it.
                latents = self.sample_from_encoder_outputs(self.vae(images, training=False))
                # Know more about the magic number here:
                # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents.
                noise = tf.random.normal(tf.shape(latents))

                # Sample a random timestep for each image.
                timesteps = tnp.random.randint(
                    0, self.noise_scheduler.train_timesteps, (batch_size,)
                )

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process).
                noisy_latents = self.noise_scheduler.add_noise(
                    tf.cast(latents, noise.dtype), noise, timesteps
                )

                # Get the target for loss depending on the prediction type
                # just the sampled noise for now.
                target = noise  # noise_schedule.predict_epsilon == True

                # Predict the noise residual and compute loss.
                # `fn_output_signature` replaces the deprecated `dtype` argument.
                timestep_embedding = tf.map_fn(
                    self.get_timestep_embedding,
                    timesteps,
                    fn_output_signature=tf.float32,
                )
                model_pred = self.diffusion_model(
                    [noisy_latents, timestep_embedding, embedded_texts], training=True
                )
                loss = self.compute_loss(target, model_pred)
                if self.use_mixed_precision:
                    loss = self.optimizer.get_scaled_loss(loss)

            # Update parameters of the diffusion model.
            trainable_vars = self.diffusion_model.trainable_variables
            gradients = tape.gradient(loss, trainable_vars)
            if self.use_mixed_precision:
                gradients = self.optimizer.get_unscaled_gradients(gradients)
            # Each gradient is clipped on its own to `max_grad_norm`.
            gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))

            return {m.name: m.result() for m in self.metrics}

        def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
            """Return the sinusoidal embedding (cosines, then sines) of one timestep."""
            half = dim // 2
            log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
            freqs = tf.math.exp(
                -log_max_period * tf.range(0, half, dtype=tf.float32) / half
            )
            args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
            embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
            return embedding

        def sample_from_encoder_outputs(self, outputs):
            """Sample from the Gaussian whose mean and log-variance are in `outputs`."""
            # The last axis holds the mean in its first half and the
            # log-variance in its second half.
            mean, logvar = tf.split(outputs, 2, axis=-1)
            logvar = tf.clip_by_value(logvar, -30.0, 20.0)
            std = tf.exp(0.5 * logvar)
            sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
            return mean + std * sample

        def compute_loss(self, target, model_pred):
            """Return instance loss + `prior_loss_weight` * prior (class) loss."""
            # Chunk the noise and model_pred into two parts and compute the loss
            # on each part separately.
            # Since the first half of the inputs has instance samples and the second half
            # has class samples, we do the chunking accordingly.
            model_pred, model_pred_prior = tf.split(model_pred, num_or_size_splits=2, axis=0)
            target, target_prior = tf.split(target, num_or_size_splits=2, axis=0)

            # Compute instance loss.
            loss = self.compiled_loss(target, model_pred)

            # Compute prior loss.
            prior_loss = self.compiled_loss(target_prior, model_pred_prior)

            # Add the prior loss to the instance loss.
            loss = loss + self.prior_loss_weight * prior_loss
            return loss

        def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
            """Save only the weights of `diffusion_model`."""
            # Overriding this method will allow us to use the `ModelCheckpoint`
            # callback directly with this trainer class. In this case, it will
            # only checkpoint the `diffusion_model` since that's what we're training
            # during fine-tuning.
            self.diffusion_model.save_weights(
                filepath=filepath,
                overwrite=overwrite,
                save_format=save_format,
                options=options,
            )

    image_encoder = ImageEncoder(resolution, resolution)
    diffusion_model = DiffusionModel(resolution, resolution, max_prompt_length)
    optimizer = tf.keras.optimizers.experimental.AdamW(
        learning_rate=opt_args['lr'],
        weight_decay=opt_args['weight_decay'],
        beta_1=opt_args['beta_1'],
        beta_2=opt_args['beta_2'],
        epsilon=opt_args['epsilon'],
    )

    dreambooth_trainer = DreamBoothTrainer(
        diffusion_model=diffusion_model,
        # Use the image encoder up to (and including) its second-to-last layer.
        vae=tf.keras.Model(
                image_encoder.input,
                image_encoder.layers[-2].output,
            ),
        noise_scheduler=NoiseScheduler(),
        use_mixed_precision=use_mp,
    )
    dreambooth_trainer.compile(optimizer=optimizer, loss="mse")

    train_dataset = tf.data.Dataset.load(dataset_path)
    num_update_steps_per_epoch = int(train_dataset.cardinality().numpy())
    if num_update_steps_per_epoch <= 0:
        # Negative values are TensorFlow's markers for infinite / unknown
        # cardinality; they (or an empty dataset) would otherwise lead to a
        # meaningless epoch count.
        raise ValueError(
            f"Cannot derive the number of epochs: dataset at {dataset_path!r} "
            f"has cardinality {num_update_steps_per_epoch}."
        )
    epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    print(f"Training for {epochs} epochs.")

    # Resolve the checkpoint location on the machine running this function and
    # make sure its directory exists before the first checkpoint is written.
    ckpt_path = os.path.abspath(os.path.expanduser(ckpt_path))
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
        ckpt_path,
        save_weights_only=True,
        monitor="loss",
        mode="min",
    )
    gpus = tf.config.list_logical_devices("GPU")

    # Ensure the computation takes place on a GPU.
    with tf.device(gpus[0].name):
        dreambooth_trainer.fit(train_dataset, epochs=epochs, callbacks=[ckpt_callback])
    return ckpt_path

In [None]:
use_mp = True # Set it to False if you're not using a GPU with tensor cores.
resolution = 512
# NOTE(review): `model_save_path` is not used by any of the cells shown here.
model_save_path = '~/.keras/models/dreambooth_trainer'

# These hyperparameters come from this tutorial by Hugging Face:
# https://github.com/huggingface/diffusers/tree/main/examples/dreambooth
optimizer_params = {
    'lr': 5e-6,
    'beta_1': 0.9,
    'beta_2': 0.999,
    # Must be a plain float; a trailing comma would turn it into the tuple `(1e-2,)`.
    'weight_decay': 1e-2,
    'epsilon': 1e-08,
}
# Relative path: it is resolved on the cluster, in the working directory of the
# process that runs `train_dreambooth`.
ckpt_path = "dreambooth-unet.h5"

# set up libdevice.10.bc to be discoverable by tensorflow
gpu.run(['cp /usr/lib/cuda/nvvm/libdevice/libdevice.10.bc .'])
train_dreambooth_gpu = rh.function(fn=train_dreambooth, system=gpu)
ckpt_path_gpu = train_dreambooth_gpu(resolution, max_prompt_length, use_mp, optimizer_params, train_dataset_path, ckpt_path)

INFO | 2023-03-10 05:41:23,101 | Writing out function function to /content/train_dreambooth_fn.py as functions serialized in notebooks are brittle. Please make sure the function does not rely on any local variables, including imports (which should be moved inside the function body).
INFO | 2023-03-10 05:41:23,113 | Setting up Function on cluster.
INFO | 2023-03-10 05:41:23,116 | Creating new file folder: /content
INFO | 2023-03-10 05:41:23,119 | Copying local package content to cluster <rh-a100>
INFO | 2023-03-10 05:41:23,123 | Creating new ssh folder: content
INFO | 2023-03-10 05:41:23,397 | Installing packages on cluster rh-a100: ['./']
INFO | 2023-03-10 05:41:23,550 | Function setup complete.
INFO | 2023-03-10 05:41:23,556 | Running train_dreambooth via gRPC
INFO | 2023-03-10 05:49:35,949 | Time to send message: 492.39 seconds


In [None]:
# Absolute path of the checkpoint on the cluster, as returned by `train_dreambooth`.
ckpt_path_gpu

'/home/ubuntu/dreambooth-unet.h5'

In [None]:
# To terminate the instance from Colab, uncomment the line below. Alternatively, terminate it manually on the Lambda Labs website.
# !sky down rh-a100