In [73]:
import tensorflow as tf

# List the GPUs TensorFlow can see. An empty list (as in the recorded output)
# means every op below runs on the CPU.
available_gpus = tf.config.list_physical_devices('GPU')
available_gpus

[]

In [74]:
import tensorflow as tf

class PatchConverter(tf.keras.layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: side length (in pixels) of each square patch; also used as
            the stride, so patches do not overlap.

    call(images) expects a 4-D [batch, height, width, channels] tensor and
    returns [batch, patches_per_image, patch_size * patch_size * channels].
    With "VALID" padding, trailing rows/columns that do not fill a whole patch
    are dropped.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # Window and stride are identical, so each pixel lands in exactly one patch.
        window = [1, self.patch_size, self.patch_size, 1]
        extracted = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Depth of `extracted` is one flattened patch, e.g. 16 * 16 * 3 = 768.
        flat_patch_len = extracted.shape[-1]
        n_images = tf.shape(images)[0]
        # Collapse the (rows, cols) patch grid into one sequence axis.
        return tf.reshape(extracted, [n_images, -1, flat_patch_len])

In [75]:
class PatchEmbeddingWithClassToken(tf.keras.layers.Layer):
    """Linearly project flattened patches and prepend a learnable [class] token.

    Args:
        embedding_dim: width of the projected patch embeddings and of the
            [class] token.

    call(patches) presumably receives PatchConverter output,
    [batch, num_patches, flattened_patch_dim], and returns
    [batch, num_patches + 1, embedding_dim] with the [class] token at index 0.
    """

    def __init__(self, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        # A single Dense layer does the patch -> embedding projection.
        self.embedding_layer = tf.keras.layers.Dense(embedding_dim)
        # One shared token; it is replicated across the batch in call().
        self.class_token = self.add_weight(
            name="class_token",
            shape=(1, 1, embedding_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, patches):
        projected = self.embedding_layer(patches)
        n_images = tf.shape(projected)[0]
        # Repeat the (1, 1, D) token once per image -> (batch, 1, D).
        cls_tokens = tf.tile(self.class_token, [n_images, 1, 1])
        return tf.concat([cls_tokens, projected], axis=1)

In [76]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Add a learnable positional embedding to every token in the sequence.

    Args:
        num_patches: sequence length the table is sized for. In this notebook
            it includes the [class] token (patches per image + 1).
        embedding_dim: token embedding width.
    """

    def __init__(self, num_patches, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        # Leading axis of size 1 broadcasts the same table over the batch.
        self.position_embeddings = self.add_weight(
            name="positional_embeddings",
            shape=(1, num_patches, embedding_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, patch_embeddings):
        return self.position_embeddings + patch_embeddings

In [77]:
class EncoderLayer(tf.keras.layers.Layer):
    """Pre-norm Transformer encoder block, as used in ViT.

    Data flow (residual connections around both sublayers):
        x = inputs + Dropout(MHSA(LayerNorm(inputs)))
        y = x + Dropout(MLP(LayerNorm(x)))
    where MLP = Dense(mlp_dim, gelu) -> Dropout -> Dense(embedding_dim).

    Args:
        num_heads: number of attention heads.
        embedding_dim: token width; each head gets embedding_dim // num_heads
            (e.g. 768 // 12 = 64).
        mlp_dim: hidden width of the MLP sublayer.
        dropout_rate: rate for attention dropout, the MLP's inner dropout and
            the dropout applied to each sublayer output.
    """

    def __init__(self, num_heads, embedding_dim, mlp_dim, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.attn_layer = tf.keras.layers.MultiHeadAttention(  # key_dim = 768 / 12 = 64
            num_heads=num_heads, key_dim=embedding_dim // num_heads, dropout=dropout_rate
        )
        self.norm1 = tf.keras.layers.LayerNormalization()
        self.norm2 = tf.keras.layers.LayerNormalization()
        self.mlp = tf.keras.Sequential([
            tf.keras.layers.Dense(mlp_dim, activation="gelu"),  # the ViT paper uses GELU
            tf.keras.layers.Dropout(dropout_rate),
            tf.keras.layers.Dense(embedding_dim)
        ])
        # Stateless, so one instance can serve both sublayers.
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, inputs, training=None):
        # `training=None` (not False) lets Keras fill in the training flag
        # propagated from the enclosing model call (e.g. True inside fit()).
        # A hard False default would override it and disable dropout.

        # Multi-head self-attention sublayer.
        attn_in = self.norm1(inputs)
        attn_out = self.attn_layer(attn_in, value=attn_in, training=training)
        attn_out = self.dropout(attn_out, training=training)
        # Plain addition; avoids constructing a new Add layer on every call.
        x = attn_out + inputs

        # MLP sublayer.
        mlp_out = self.norm2(x)
        mlp_out = self.mlp(mlp_out, training=training)
        mlp_out = self.dropout(mlp_out, training=training)
        return mlp_out + x

In [78]:
class ClassificationHead(tf.keras.layers.Layer):
    """Map the (normalized) [class] token to softmax class probabilities.

    Args:
        num_classes: number of units in the softmax output layer.
        mlp_dim: if truthy, a hidden Dense(mlp_dim, gelu) layer is placed in
            front of the output layer; if None/0, the head is just the output
            layer.
    """

    def __init__(self, num_classes, mlp_dim=None, **kwargs):
        super().__init__(**kwargs)
        # Created before the optional hidden layer; keep this order so the
        # auto-generated layer names stay the same.
        softmax_layer = tf.keras.layers.Dense(num_classes, activation="softmax")
        self.mlp = (
            tf.keras.Sequential([  # dropout can be inserted here if the head overfits
                tf.keras.layers.Dense(mlp_dim, activation="gelu"),
                softmax_layer,
            ])
            if mlp_dim
            else softmax_layer
        )

    def call(self, class_token):
        probabilities = self.mlp(class_token)
        return probabilities

In [79]:
class ViT(tf.keras.Model):
    """Vision Transformer (ViT) image classifier.

    Pipeline:
    1. Split images into non-overlapping patches.
    2. Project the patches linearly and prepend a learnable [class] token.
    3. Add learnable positional embeddings.
    4. Run a stack of pre-norm encoder layers.
    5. Apply LayerNorm to the [class] token.
    6. Pass it through the classification head (softmax probabilities).

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width are each floor-divided by patch_size to size the patch grid.
        patch_size: side length of each square patch.
        num_classes: number of output classes.
        embedding_dim: token embedding width.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        mlp_dim: hidden width of each encoder layer's MLP.
        dropout_rate: dropout rate used inside the encoder layers.
        head_mlp_dim: hidden width of the classification head's GELU dense
            layer. Defaults to 2048; pass None or 0 for a head made of only
            the softmax output layer (e.g. 512 for the scaled-down config).
    """

    def __init__(
        self,
        input_shape,
        patch_size,
        num_classes,
        embedding_dim,
        num_heads,
        num_layers,
        mlp_dim,
        dropout_rate,
        head_mlp_dim=2048,
        **kwargs
    ):
        super().__init__(**kwargs)
        # "VALID" patch extraction yields (H // p) * (W // p) patches per image.
        patches_per_image = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
        self.patch_converter = PatchConverter(patch_size)
        self.patch_embedding_class_token = PatchEmbeddingWithClassToken(embedding_dim)
        self.positional_embedding = PositionalEmbedding(
            num_patches=patches_per_image + 1,  # +1 includes [class] embedding
            embedding_dim=embedding_dim
        )
        self.encoder_stack = [
            EncoderLayer(num_heads, embedding_dim, mlp_dim, dropout_rate)
            for _ in range(num_layers)
        ]
        self.class_token_norm = tf.keras.layers.LayerNormalization()
        self.classification_head = ClassificationHead(num_classes, head_mlp_dim)

    def call(self, inputs, training=None):
        """Return class probabilities of shape [batch, num_classes].

        `training` must be declared here: Keras 3's fit() only passes
        training=True to models whose call() accepts it. Without it, the
        encoder layers' dropout would stay off during training.
        """
        # Extract the patches and create the patch embeddings w/ class token
        patches = self.patch_converter(inputs)
        embeddings = self.patch_embedding_class_token(patches)

        # Add positional embeddings
        embeddings = self.positional_embedding(embeddings)

        # Encoder stack (training flag forwarded so dropout follows fit/evaluate)
        x = embeddings
        for layer in self.encoder_stack:
            x = layer(x, training=training)

        # Extract and normalize the [class] token (sequence position 0)
        class_token = x[:, 0]
        norm_class_token = self.class_token_norm(class_token)

        # Classification head
        return self.classification_head(norm_class_token)

In [80]:
def test_transformer_layer_output_shapes(
    single_batch,
    input_shape,
    patch_size,
    embedding_dim,
    num_heads,
    mlp_dim,
    dropout_rate,
    num_classes
):
    """Smoke-test the output shape of each ViT building block on one batch.

    Fresh, untrained layers are built and `single_batch` is pushed through
    them in model order. The stages are PatchConverter, then
    PatchEmbeddingWithClassToken, then PositionalEmbedding, then a single
    EncoderLayer. The last stage is LayerNorm plus ClassificationHead on the
    [class] token. The output shape is asserted after every stage.

    Args:
        single_batch: image batch, assumed [batch, height, width, channels].
        input_shape: (height, width, ...) the model is configured for.
        patch_size, embedding_dim, num_heads, mlp_dim, dropout_rate,
        num_classes: the same hyperparameters passed to ViT.

    Raises:
        AssertionError: if any stage produces an unexpected shape.
    """
    batch_size = single_batch.shape[0]
    channels = single_batch.shape[-1]
    patches_per_image = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    total_patches = patches_per_image + 1  # +1 for the [class] token
    # Width of one flattened patch, e.g. 16 * 16 * 3 = 768. This is unrelated
    # to embedding_dim; the two only coincide for the ViT-Base config.
    flat_patch_dim = patch_size * patch_size * channels

    # Patch Conversion
    patch_converter = PatchConverter(patch_size)
    patches = patch_converter(single_batch)
    assert patches.shape == (batch_size, patches_per_image, flat_patch_dim)

    # Patch Embeddings + Class Token
    patch_embedding_with_class_token = PatchEmbeddingWithClassToken(embedding_dim)
    embeddings = patch_embedding_with_class_token(patches)
    assert embeddings.shape == (batch_size, total_patches, embedding_dim)

    # Positional Embeddings
    positional_embedding = PositionalEmbedding(total_patches, embedding_dim)
    embeddings = positional_embedding(embeddings)
    assert embeddings.shape == (batch_size, total_patches, embedding_dim)

    # Single Encoder Layer
    single_encoder = EncoderLayer(num_heads, embedding_dim, mlp_dim, dropout_rate)
    encoder_output = single_encoder(embeddings)
    assert encoder_output.shape == (batch_size, total_patches, embedding_dim)

    # Class Token Norm & Classification Head
    class_token_norm = tf.keras.layers.LayerNormalization()
    classification_head = ClassificationHead(num_classes)
    class_token = encoder_output[:, 0]
    norm_class_token = class_token_norm(class_token)
    class_probas = classification_head(norm_class_token)
    assert class_probas.shape == (batch_size, num_classes)

# Data Preprocessing for iNat2017

See [/dataset-utils](./dataset-utils/README.md) for details on the preprocessing approach for this dataset; no augmentation has been applied. Configuration used (inspired in part by values from the original ViT paper):

Model Config:
- Image size: `224x224`
- Patch size: `16x16`
- Embedding dimension: `768`
- Number of heads: `12` (64-dim embedding per head)
- Encoder stack: `12`
- MLP dimension: `3072`
- Dropout rate: `0.1`
- Classification head (dense layer): `2048` units (with gelu)
- Classification head (output layer): `5089` classes (with softmax)

Training Config:
- Batch size: `32`
- Base learning rate: `3e-3` (linear warmup for 10%, then cosine decay)
- Optimizer: `AdamW`
- Weight decay: `0.1`
- Epochs: `30`
- Shuffle buffer size: `50000`

Roughly 97.8M trainable parameters.

Additional:
- SageMaker instance size for training: `G5` or `P3` family of instances.
- SageMaker storage container: `20GB` (~18GB of processed training and validation data)

---

The following is a scaled-down version of the above configuration. It can still achieve meaningful pretraining results, and it costs much less in cloud compute.

Model Config:
- Embedding dimension: `256`
- Number of heads: `4` (64-dim embedding per head)
- Encoder stack: `4`
- MLP dimension: `1024`
- Classification head (dense layer): `512` units

Training Config:
- Batch size: `16`
- Base learning rate: `1e-4`
- Epochs: `8`
- Shuffle buffer size: `10000`
- Dataset subset: `20%`

Additional:
- SageMaker instance size for training: `G4dn` instances.

In [14]:
# validation set: copy the preprocessed validation TFRecord shards from S3 into
# ./train_val_images-processed/val2017/ (the pipeline below expects 96 shards there).
# NOTE(review): presumably needs AWS CLI credentials with read access to the bucket.
!aws s3 cp s3://inat17-train-val-records/train_val_images-processed/val2017/ ./train_val_images-processed/val2017/ --recursive

download: s3://inat17-train-val-records/train_val_images-processed/val2017/inat17_batch-1-of-96.tfrecord to train_val_images-processed/val2017/inat17_batch-1-of-96.tfrecord


In [15]:
# training set: copy the preprocessed training TFRecord shards from S3 into
# ./train_val_images-processed/train2017/ (the pipeline below expects 580 shards there).
# NOTE(review): presumably needs AWS CLI credentials with read access to the bucket.
!aws s3 cp s3://inat17-train-val-records/train_val_images-processed/train2017/ ./train_val_images-processed/train2017/ --recursive

download: s3://inat17-train-val-records/train_val_images-processed/train2017/inat17_batch-1-of-580.tfrecord to train_val_images-processed/train2017/inat17_batch-1-of-580.tfrecord


In [81]:
# Schema of each serialized tf.train.Example: encoded image bytes + integer class id.
feature_description = dict(
    image=tf.io.FixedLenFeature([], tf.string),
    label=tf.io.FixedLenFeature([], tf.int64),
)

def normalize_image(image):
    """Cast to float32 and rescale pixel values to [-1, 1].

    Assumes 8-bit pixel values in [0, 255], as produced by decode_jpeg.
    """
    scaled = tf.cast(image, tf.float32) / 127.5  # [0, 255] -> [0, 2]
    return scaled - 1.0

def parse_example_safely(serialized_example):
    """Parse one serialized Example into a normalized (image, label) pair.

    Decodes the JPEG bytes under "image" as a 3-channel image, rescales it
    with normalize_image, and returns it with the int64 "label".

    The except branch only helps when this is called eagerly. Under
    Dataset.map the function is traced into a graph, so decode/parse errors
    are raised later, while the iterator runs, and never reach this
    try/except. Drop bad records at the dataset level instead
    (e.g. Dataset.ignore_errors()).
    """
    try:
        features = tf.io.parse_single_example(serialized_example, feature_description)
        pixels = normalize_image(tf.image.decode_jpeg(features["image"], channels=3))
        return pixels, features["label"]
    except tf.errors.InvalidArgumentError as e:
        tf.print(f"Error parsing example: {e}")
        return None, None
        
# create training and validation sets
train_set_filepaths = [f"./train_val_images-processed/train2017/inat17_batch-{i}-of-580.tfrecord" for i in range(1, 581)]
valid_set_filepaths = [f"./train_val_images-processed/val2017/inat17_batch-{i}-of-96.tfrecord" for i in range(1, 97)]

dataset_buffer_size = 16 * 1024 * 1024  # 16 MiB read buffer per GZIP TFRecord reader
train_set = tf.data.TFRecordDataset(
    train_set_filepaths,
    compression_type="GZIP",
    buffer_size=dataset_buffer_size
)  # 579,184 examples

valid_set = tf.data.TFRecordDataset(
    valid_set_filepaths,
    compression_type="GZIP",
    buffer_size=dataset_buffer_size
)  # 95,986 examples

# shuffle, map, batch and prefetch
batch_size = 64  # NOTE(review): the training config notes list a batch size of 32; confirm which is intended
buffer_size = 50000
train_set = (
    train_set
    # Shuffle the compact serialized records *before* decoding. A buffer of
    # 50,000 decoded 224x224x3 float32 images would need roughly 30 GB of RAM.
    .shuffle(buffer_size)
    .map(parse_example_safely, num_parallel_calls=tf.data.AUTOTUNE)
    # The mapped function runs as a traced graph, so a Python `is not None`
    # filter would be decided once at trace time and never drop anything.
    # Skip records that fail to parse/decode at run time instead, with a log.
    .ignore_errors(log_warning=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

valid_set = (
    valid_set
    .map(parse_example_safely, num_parallel_calls=tf.data.AUTOTUNE)
    .ignore_errors(log_warning=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Pull one batch from the training pipeline and shape-check every ViT component on it.
for image_batch, label_batch in train_set.take(1):
    # Use the loop variables: `image`/`label` only exist inside parse_example_safely,
    # so referencing them here raises NameError on a fresh kernel.
    print(f"Shape of images in first batch: {image_batch.shape}, with labels: {label_batch.shape}")
    test_transformer_layer_output_shapes(
        single_batch=image_batch,
        input_shape=(224, 224, 3),
        patch_size=16,
        embedding_dim=768,
        num_heads=12,
        mlp_dim=3072,
        dropout_rate=0.1,
        num_classes=5089,
    )
    print("Tests passed: test_transformer_layer_output_shapes")

Shape of images in first batch: (64, 224, 224, 3), with labels: (64,)
Tests passed: test_transformer_layer_output_shapes


In [96]:
# prep callbacks
# Save weights only when val_loss reaches a new minimum; the file name carries
# the 1-based epoch number (e.g. ..._epoch-01.weights.h5).
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints/vit_inat17_sagemaker/vit_inat17_epoch-{epoch:02d}.weights.h5",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

train_set_size = 579184
epochs = 30
# batch() keeps the final partial batch, so an epoch has ceil(N / B) steps.
steps_per_epoch = -(-train_set_size // batch_size)
total_steps = steps_per_epoch * epochs
warmup_steps = total_steps // 10  # linear warmup over the first 10% of training

# Learning-rate schedule from the training config: ramp linearly from 0 to the
# base LR (3e-3) over `warmup_steps`, then cosine-decay to 0 over the rest.
# Warmup helps keep early ViT/AdamW training stable at this base LR.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=total_steps - warmup_steps,
    alpha=0.0,
    warmup_target=3e-3,
    warmup_steps=warmup_steps,
)

# create ViT
vit_pretraining_model = ViT(
    input_shape=(224, 224, 3),
    patch_size=16,
    num_classes=5089,
    embedding_dim=768,
    num_heads=12,
    num_layers=12,
    mlp_dim=3072,
    dropout_rate=0.1,
)

# compile
vit_pretraining_model.compile(
    optimizer=tf.keras.optimizers.AdamW(  # decoupled weight decay
        learning_rate=lr_schedule,  # reduce base LR or weight decay if training is unstable
        weight_decay=0.1
    ),
    # The model outputs softmax probabilities, so from_logits stays False (default).
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

In [83]:
import time

# perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
start_time = time.perf_counter()

# train; early stopping may end the run before `epochs` epochs
history = vit_pretraining_model.fit(
    train_set,
    validation_data=valid_set,
    epochs=epochs,
    callbacks=[checkpoint_callback, early_stopping_callback],
    verbose=1
)

end_time = time.perf_counter()
total_time = end_time - start_time
# Report the epochs that actually ran, not the configured maximum.
epochs_run = len(history.epoch)
print(f"Total training time for {epochs_run} epochs: {total_time:.2f} seconds ({total_time / 60:.2f} minutes)")

     16/Unknown [1m858s[0m 53s/step - accuracy: 0.0035 - loss: 8.7261




Epoch 1: val_loss improved from inf to 8.76769, saving model to checkpoints/vit_inat17_sagemaker/vit_inat17_epoch-01.weights.h5
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1029s[0m 64s/step - accuracy: 0.0036 - loss: 8.7212 - val_accuracy: 0.0010 - val_loss: 8.7677
Total training time for 1 epochs: 1028.95 seconds (17.15 minutes)
