In [2]:

# ============================================================================
# Install required libraries
# ============================================================================
#@title Install required libraries { display-mode: "form" }
# ============================================================================
# Install required libraries
# ============================================================================
#@title Install required libraries { display-mode: "form" }
!pip install -q tensorflow-datasets tensorflow-hub matplotlib

# ============================================================================
# Imports and global settings
# ============================================================================
#@title Imports and helpers { display-mode: "form" }
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

print('TensorFlow version:', tf.__version__)

# Reproducibility (best-effort)
SEED = 42
tf.random.set_seed(SEED)
np.random.seed(SEED)


TensorFlow version: 2.19.0


In [5]:
# ============================================================================
# Parameters
# ============================================================================
#@title Training parameters { display-mode: "form" }
BATCH_SIZE = 1  # CycleGAN typically uses batch size 1
IMG_SIZE = 256
EPOCHS = 10  # exactly 25 epochs as requested
LAMBDA_CYCLE = 10.0
LAMBDA_ID = 0.5 * LAMBDA_CYCLE  # identity loss multiplier (common practice)
LR = 2e-4
BETA_1 = 0.5
MAX_TRAIN_IMAGES = 200  # use first 200 images of each domain for faster runs

UNET_SAVE_PATH = 'unet_horse2zebra_generator.keras'
RESNET_SAVE_PATH = 'resnet_horse2zebra_generator.keras'


In [6]:


# ============================================================================
# Data Loading & Preprocessing (Horse2Zebra only)
# ============================================================================
#@title Load and preprocess Horse2Zebra dataset { display-mode: "form" }
print('\nLoading Horse2Zebra dataset...')

dataset, info = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

print('TrainA (horses):', info.splits['trainA'].num_examples)
print('TrainB (zebras):', info.splits['trainB'].num_examples)

# Preprocessing: resize to 256x256 and normalize to [-1,1]
@tf.function
def preprocess_train(image):
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = (tf.cast(image, tf.float32) / 127.5) - 1.0
    return image

@tf.function
def preprocess_test(image):
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = (tf.cast(image, tf.float32) / 127.5) - 1.0
    return image

# Take small subset for training to keep runtime reasonable in Colab
train_horses = train_horses.map(lambda x, y: tf.identity(x)).map(lambda x: preprocess_train(x)).take(MAX_TRAIN_IMAGES).shuffle(1000).batch(BATCH_SIZE)
train_zebras = train_zebras.map(lambda x, y: tf.identity(x)).map(lambda x: preprocess_train(x)).take(MAX_TRAIN_IMAGES).shuffle(1000).batch(BATCH_SIZE)

# Test sets for visualization (we will use test_horses)
test_horses = test_horses.map(lambda x, y: tf.identity(x)).map(lambda x: preprocess_test(x)).batch(1)

print('Data pipelines ready.')



Loading Horse2Zebra dataset...




Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/cycle_gan/horse2zebra/3.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/4 [00:00<?, ? splits/s]

Generating trainA examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cycle_gan/horse2zebra/incomplete.3M69DP_3.0.0/cycle_gan-trainA.tfrecord*..…

Generating trainB examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cycle_gan/horse2zebra/incomplete.3M69DP_3.0.0/cycle_gan-trainB.tfrecord*..…

Generating testA examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cycle_gan/horse2zebra/incomplete.3M69DP_3.0.0/cycle_gan-testA.tfrecord*...…

Generating testB examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cycle_gan/horse2zebra/incomplete.3M69DP_3.0.0/cycle_gan-testB.tfrecord*...…

Dataset cycle_gan downloaded and prepared to /root/tensorflow_datasets/cycle_gan/horse2zebra/3.0.0. Subsequent calls will reuse this data.
TrainA (horses): 1067
TrainB (zebras): 1334
Data pipelines ready.


In [7]:

# ============================================================================
# Model components: InstanceNorm, UNet generator, ResNet generator, PatchGAN
# ============================================================================
#@title Model definitions (UNet, ResNet, PatchGAN) { display-mode: "form" }

class InstanceNormalization(tf.keras.layers.Layer):
    """Simple Instance Normalization implementation."""
    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name='scale', shape=input_shape[-1:], initializer='random_normal', trainable=True)
        self.offset = self.add_weight(name='offset', shape=input_shape[-1:], initializer='zeros', trainable=True)

    def call(self, x):
        mean, var = tf.nn.moments(x, axes=[1,2], keepdims=True)
        inv = tf.math.rsqrt(var + self.epsilon)
        normalized = (x - mean) * inv
        return self.scale * normalized + self.offset


In [11]:

# UNet generator (encoder-decoder with skip connections)
def build_unet_generator(input_shape=(256, 256, 3), ngf=64):
    """Builds a UNet-style generator for 256x256 images."""

    initializer = tf.random_normal_initializer(0., 0.02)
    inputs = tf.keras.Input(shape=input_shape)

    # --- Encoder (Downsampling) ---
    down_stack = [
        # (filters, apply_batchnorm)
        (ngf, False),   # (128x128x64)
        (ngf*2, True),  # (64x64x128)
        (ngf*4, True),  # (32x32x256)
        (ngf*8, True),  # (16x16x512)
        (ngf*8, True),  # (8x8x512)
        (ngf*8, True),  # (4x4x512)
        (ngf*8, True),  # (2x2x512)
        (ngf*8, True),  # (1x1x512)
    ]

    x = inputs
    skips = []
    for filters, bn in down_stack:
        x = tf.keras.layers.Conv2D(filters, 4, strides=2, padding='same',
                                   kernel_initializer=initializer, use_bias=not bn)(x)
        if bn:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.LeakyReLU(0.2)(x)
        skips.append(x)

    # --- Decoder (Upsampling) ---
    up_stack = [
        (ngf*8, True),  # (2x2x512)
        (ngf*8, True),  # (4x4x512)
        (ngf*8, True),  # (8x8x512)
        (ngf*8, False), # (16x16x512)
        (ngf*4, False), # (32x32x256)
        (ngf*2, False), # (64x64x128)
        (ngf, False),   # (128x128x64)
    ]

    skips = list(reversed(skips[:-1]))  # skip connections except the last (bottleneck)
    for (filters, bn), skip in zip(up_stack, skips):
        x = tf.keras.layers.Conv2DTranspose(filters, 4, strides=2, padding='same',
                                            kernel_initializer=initializer, use_bias=not bn)(x)
        if bn:
            x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    # --- Output layer ---
    x = tf.keras.layers.Conv2DTranspose(3, 4, strides=2, padding='same',
                                        kernel_initializer=initializer,
                                        activation='tanh')(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name='UNet_Generator')


In [9]:

# ResNet generator (CycleGAN style with residual blocks)
def resnet_block(x_in, filters, use_instnorm=True):
    x = tf.keras.layers.Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(x_in)
    if use_instnorm:
        x = InstanceNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.Conv2D(filters, 3, padding='same', kernel_initializer='he_normal')(x)
    if use_instnorm:
        x = InstanceNormalization()(x)
    return tf.keras.layers.Add()([x_in, x])

def build_resnet_generator(input_shape=(IMG_SIZE, IMG_SIZE, 3), ngf=64, n_blocks=9):
    inputs = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(ngf, 7, padding='same', kernel_initializer='he_normal')(inputs)
    x = InstanceNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

    # Downsampling
    n_down = 2
    for i in range(n_down):
        mult = 2**i
        x = tf.keras.layers.Conv2D(ngf * mult * 2, 3, strides=2, padding='same', kernel_initializer='he_normal')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.Activation('relu')(x)

    # Residual blocks
    mult = 2**n_down
    for i in range(n_blocks):
        x = resnet_block(x, ngf*mult)

    # Upsampling
    for i in range(n_down):
        mult = 2**(n_down - i)
        x = tf.keras.layers.Conv2DTranspose(int(ngf * mult / 2), 3, strides=2, padding='same', kernel_initializer='he_normal')(x)
        x = InstanceNormalization()(x)
        x = tf.keras.layers.Activation('relu')(x)

    x = tf.keras.layers.Conv2D(3, 7, padding='same', kernel_initializer='he_normal', activation='tanh')(x)
    return tf.keras.Model(inputs=inputs, outputs=x, name='ResNet_Generator')


In [12]:

# PatchGAN discriminator (70x70)
def build_patchgan_discriminator(input_shape=(IMG_SIZE, IMG_SIZE, 3), ndf=64):
    inp = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(ndf, 4, strides=2, padding='same')(inp)
    x = tf.keras.layers.LeakyReLU(0.2)(x)

    x = tf.keras.layers.Conv2D(ndf*2, 4, strides=2, padding='same')(x)
    x = InstanceNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)

    x = tf.keras.layers.Conv2D(ndf*4, 4, strides=2, padding='same')(x)
    x = InstanceNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)

    x = tf.keras.layers.Conv2D(ndf*8, 4, strides=1, padding='same')(x)
    x = InstanceNormalization()(x)
    x = tf.keras.layers.LeakyReLU(0.2)(x)

    x = tf.keras.layers.Conv2D(1, 4, strides=1, padding='same')(x)
    return tf.keras.Model(inputs=inp, outputs=x, name='PatchGAN_Discriminator')

# Instantiate models (UNet generators and discriminators)
print('\nBuilding model instances...')
G_unet_A2B = build_unet_generator()  # horse -> zebra
G_unet_B2A = build_unet_generator()  # zebra -> horse (separate instance)
G_resnet = build_resnet_generator()
D_A = build_patchgan_discriminator()  # discriminator for horses
D_B = build_patchgan_discriminator()  # discriminator for zebras

print('Models built.')



Building model instances...
Models built.


In [13]:

# ============================================================================
# Losses and optimizers
# ============================================================================
#@title Loss functions and optimizers { display-mode: "form" }
mse = tf.keras.losses.MeanSquaredError()
mae = tf.keras.losses.MeanAbsoluteError()

def generator_loss_lsgan(fake_logits):
    return mse(tf.ones_like(fake_logits), fake_logits)

def discriminator_loss_lsgan(real_logits, fake_logits):
    real_loss = mse(tf.ones_like(real_logits), real_logits)
    fake_loss = mse(tf.zeros_like(fake_logits), fake_logits)
    return 0.5 * (real_loss + fake_loss)

def cycle_loss(real_image, cycled_image):
    return mae(real_image, cycled_image) * LAMBDA_CYCLE

def identity_loss(real_image, same_image):
    return mae(real_image, same_image) * LAMBDA_ID

# Optimizers
G_A_optimizer = tf.keras.optimizers.Adam(LR, beta_1=BETA_1)
G_B_optimizer = tf.keras.optimizers.Adam(LR, beta_1=BETA_1)
D_A_optimizer = tf.keras.optimizers.Adam(LR, beta_1=BETA_1)
D_B_optimizer = tf.keras.optimizers.Adam(LR, beta_1=BETA_1)


In [14]:

# ============================================================================
# Training loop (CycleGAN with UNet generators) - train both directions
# ============================================================================
#@title Training loop: Train CycleGAN with UNet generators for 25 epochs { display-mode: "form" }

@tf.function
def train_step(real_A, real_B):
    # real_A: horse, real_B: zebra
    with tf.GradientTape(persistent=True) as tape:
        # Generators forward
        fake_B = G_unet_A2B(real_A, training=True)
        cycled_A = G_unet_B2A(fake_B, training=True)

        fake_A = G_unet_B2A(real_B, training=True)
        cycled_B = G_unet_A2B(fake_A, training=True)

        # Identity
        same_A = G_unet_B2A(real_A, training=True)
        same_B = G_unet_A2B(real_B, training=True)

        # Discriminator outputs
        disc_real_A = D_A(real_A, training=True)
        disc_real_B = D_B(real_B, training=True)
        disc_fake_A = D_A(fake_A, training=True)
        disc_fake_B = D_B(fake_B, training=True)

        # Generator adversarial losses
        G_A_adv_loss = generator_loss_lsgan(disc_fake_B)
        G_B_adv_loss = generator_loss_lsgan(disc_fake_A)

        # Cycle & identity losses
        total_cycle_loss = cycle_loss(real_A, cycled_A) + cycle_loss(real_B, cycled_B)
        id_loss_A = identity_loss(real_A, same_A)
        id_loss_B = identity_loss(real_B, same_B)

        # Total generator losses
        G_A_loss = G_A_adv_loss + total_cycle_loss + id_loss_B + id_loss_A
        G_B_loss = G_B_adv_loss + total_cycle_loss + id_loss_A + id_loss_B

        # Discriminator losses
        D_A_loss = discriminator_loss_lsgan(disc_real_A, disc_fake_A)
        D_B_loss = discriminator_loss_lsgan(disc_real_B, disc_fake_B)

    # Compute gradients
    G_A_grads = tape.gradient(G_A_loss, G_unet_A2B.trainable_variables)
    G_B_grads = tape.gradient(G_B_loss, G_unet_B2A.trainable_variables)
    D_A_grads = tape.gradient(D_A_loss, D_A.trainable_variables)
    D_B_grads = tape.gradient(D_B_loss, D_B.trainable_variables)

    # Apply gradients
    G_A_optimizer.apply_gradients(zip(G_A_grads, G_unet_A2B.trainable_variables))
    G_B_optimizer.apply_gradients(zip(G_B_grads, G_unet_B2A.trainable_variables))
    D_A_optimizer.apply_gradients(zip(D_A_grads, D_A.trainable_variables))
    D_B_optimizer.apply_gradients(zip(D_B_grads, D_B.trainable_variables))

    return {
        'G_A_loss': G_A_loss, 'G_B_loss': G_B_loss,
        'D_A_loss': D_A_loss, 'D_B_loss': D_B_loss
    }


In [None]:

# Training driver
print('\nStarting training for', EPOCHS, 'epochs using', MAX_TRAIN_IMAGES, 'images from each domain...')
start_time = time.time()
train_dataset = tf.data.Dataset.zip((train_horses, train_zebras))

for epoch in range(1, EPOCHS + 1):
    print(f'Epoch {epoch}/{EPOCHS}')
    epoch_losses = []
    for step, (real_A, real_B) in enumerate(train_dataset):
        losses = train_step(real_A, real_B)
        epoch_losses.append({k: float(v) for k, v in losses.items()})
        if step % 50 == 0:
            print(f'  step {step} - G_A_loss: {losses["G_A_loss"]:.4f}, D_A_loss: {losses["D_A_loss"]:.4f}')
    if epoch_losses:
        avg = {k: np.mean([x[k] for x in epoch_losses]) for k in epoch_losses[0].keys()}
        print(' Epoch avg losses:', {k: float(v) for k, v in avg.items()})

print('Training completed in %.2f seconds' % (time.time() - start_time))

# G_unet_A2B is our trained horse->zebra generator
G_unet_trained = G_unet_A2B



Starting training for 25 epochs using 200 images from each domain...
Epoch 1/25
  step 0 - G_A_loss: 10.6946, D_A_loss: 0.5350
  step 50 - G_A_loss: 6.5551, D_A_loss: 0.2858
  step 100 - G_A_loss: 5.9493, D_A_loss: 0.2767
  step 150 - G_A_loss: 4.8231, D_A_loss: 0.2464
 Epoch avg losses: {'G_A_loss': 6.145173474550247, 'G_B_loss': 6.129569165706634, 'D_A_loss': 0.2517614558339119, 'D_B_loss': 0.23538692854344845}
Epoch 2/25
  step 0 - G_A_loss: 4.0668, D_A_loss: 0.1926
  step 50 - G_A_loss: 3.7719, D_A_loss: 0.2011
  step 100 - G_A_loss: 4.5257, D_A_loss: 0.2105
  step 150 - G_A_loss: 6.1011, D_A_loss: 0.1138
 Epoch avg losses: {'G_A_loss': 4.482466560602188, 'G_B_loss': 4.402515996694564, 'D_A_loss': 0.2157214138843119, 'D_B_loss': 0.17918569393455983}
Epoch 3/25
  step 0 - G_A_loss: 3.7721, D_A_loss: 0.1200
  step 50 - G_A_loss: 4.1633, D_A_loss: 0.2479
  step 100 - G_A_loss: 3.5451, D_A_loss: 0.1285
  step 150 - G_A_loss: 3.5351, D_A_loss: 0.3755
 Epoch avg losses: {'G_A_loss': 3.9

In [1]:

# ============================================================================
# ResNet generator: try to load TF-Hub pre-trained model for horse2zebra; fallback to local ResNet
# ============================================================================
#@title Load ResNet generator from TF-Hub (if available) or use fallback { display-mode: "form" }

resnet_loaded_from_hub = False
hub_urls = [
    'https://tfhub.dev/google/cyclegan/horse2zebra/1',
]
G_resnet_loaded = None
for url in hub_urls:
    try:
        print('Attempting to load TF-Hub module:', url)
        hub_mod = hub.load(url)
        # Try to wrap serving_default if available
        if hasattr(hub_mod, 'signatures') and 'serving_default' in hub_mod.signatures:
            serv = hub_mod.signatures['serving_default']
            def hub_forward(x):
                out = serv(tf.cast(x, tf.float32))
                return list(out.values())[0]
            inp = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
            out = tf.keras.layers.Lambda(lambda t: hub_forward(t))(inp)
            G_resnet_loaded = tf.keras.Model(inputs=inp, outputs=out, name='ResNet_Hub_Generator')
            resnet_loaded_from_hub = True
            print('Wrapped TF-Hub module as Keras model.')
            break
        else:
            print('TF-Hub module loaded but no serving_default signature to wrap.')
    except Exception as e:
        print('TF-Hub load failed for', url, '-', str(e))


Attempting to load TF-Hub module: https://tfhub.dev/google/cyclegan/horse2zebra/1
TF-Hub load failed for https://tfhub.dev/google/cyclegan/horse2zebra/1 - name 'hub' is not defined


In [None]:

if resnet_loaded_from_hub and G_resnet_loaded is not None:
    G_resnet_trained = G_resnet_loaded
else:
    print('No TF-Hub ResNet generator found; using local ResNet implementation (random init).')
    G_resnet_trained = G_resnet


In [None]:

# ============================================================================
# Visualization: compare outputs on Horse test images (2x2 grid)
# ============================================================================
#@title Visualization: compare UNet vs ResNet outputs on a Horse test image { display-mode: "form" }
# Pick a single horse test image
horse_iter = iter(test_horses)
try:
    horse_img = next(horse_iter)
except StopIteration:
    raise RuntimeError('No horse test images found in dataset')

input_image = horse_img  # shape (1, H, W, 3), [-1,1]

# Run both generators
unet_out = G_unet_trained(input_image, training=False)
resnet_out = G_resnet_trained(input_image, training=False)

# Convert to [0,1] for plotting
def to_display(x):
    x = (x + 1.0) / 2.0
    x = tf.clip_by_value(x, 0.0, 1.0)
    return x

input_disp = to_display(input_image[0]).numpy()
unet_disp = to_display(unet_out[0]).numpy()
resnet_disp = to_display(resnet_out[0]).numpy()

fig, axes = plt.subplots(2,2, figsize=(10,10))
axes[0,0].imshow(input_disp)
axes[0,0].set_title('Input Horse Image')
axes[0,0].axis('off')

axes[0,1].imshow(unet_disp)
axes[0,1].set_title('Output Zebra by UNet')
axes[0,1].axis('off')

axes[1,0].imshow(input_disp)
axes[1,0].set_title('Input Horse Image')
axes[1,0].axis('off')

axes[1,1].imshow(resnet_disp)
if resnet_loaded_from_hub:
    axes[1,1].set_title('Output Zebra by ResNet (TF-Hub)')
else:
    axes[1,1].set_title('Output Zebra by ResNet (Fallback)')
axes[1,1].axis('off')

plt.tight_layout()
plt.show()


In [None]:

# ============================================================================
# Save models to disk and verify loading
# ============================================================================
#@title Save models and reload to verify { display-mode: "form" }
print('\nSaving generators...')

# Save UNet generator
try:
    G_unet_trained.save(UNET_SAVE_PATH)
    print('Saved UNet generator to', UNET_SAVE_PATH)
except Exception as e:
    print('Could not save full UNet model, saving weights only:', e)
    G_unet_trained.save_weights(UNET_SAVE_PATH + '.weights')
    print('Saved UNet weights to', UNET_SAVE_PATH + '.weights')

# Save ResNet generator
try:
    G_resnet_trained.save(RESNET_SAVE_PATH)
    print('Saved ResNet generator to', RESNET_SAVE_PATH)
except Exception as e:
    print('Could not save full ResNet model, saving weights only:', e)
    G_resnet_trained.save_weights(RESNET_SAVE_PATH + '.weights')
    print('Saved ResNet weights to', RESNET_SAVE_PATH + '.weights')

# Attempt reload
print('\nReloading saved models...')
loaded_unet = None
loaded_resnet = None
try:
    loaded_unet = tf.keras.models.load_model(UNET_SAVE_PATH, custom_objects={'InstanceNormalization': InstanceNormalization})
    print('Successfully loaded UNet from', UNET_SAVE_PATH)
except Exception as e:
    print('Failed to load UNet full model:', e)
    try:
        fresh_unet = build_unet_generator()
        fresh_unet.load_weights(UNET_SAVE_PATH + '.weights')
        loaded_unet = fresh_unet
        print('Loaded UNet weights into fresh model from', UNET_SAVE_PATH + '.weights')
    except Exception as e2:
        print('Failed to load UNet weights fallback:', e2)

try:
    loaded_resnet = tf.keras.models.load_model(RESNET_SAVE_PATH, custom_objects={'InstanceNormalization': InstanceNormalization})
    print('Successfully loaded ResNet from', RESNET_SAVE_PATH)
except Exception as e:
    print('Failed to load ResNet full model:', e)
    try:
        fresh_resnet = build_resnet_generator()
        fresh_resnet.load_weights(RESNET_SAVE_PATH + '.weights')
        loaded_resnet = fresh_resnet
        print('Loaded ResNet weights into fresh model from', RESNET_SAVE_PATH + '.weights')
    except Exception as e2:
        print('Failed to load ResNet weights fallback:', e2)

print('\nDone. If you see success messages above, saving and loading succeeded.')

# ============================================================================
# Notes & caveats
# ============================================================================
#@title Notes { display-mode: "form" }
# - This notebook trains a CycleGAN with UNet generators on Horse2Zebra using the first 200 images of each domain for 25 epochs.
# - It attempts to load a TF-Hub ResNet generator for Horse2Zebra; if unavailable, it uses the locally implemented ResNet (random init).
# - The plotted comparison uses an actual horse test image (horse->zebra translation) as requested.
# - Training for high-quality results typically requires more data, compute, and tuning; this script is tuned to be runnable in Colab within limited time by using a subset.

# End of notebook
