# Masked image modeling with Autoencoders

**Authors:** Lennart Seeger, Aritra Roy Gosthipaty, Sayak Paul<br>
**Date created:** 2021/12/20<br>
**Last modified:** 2023/03/24<br>

In [None]:
import numpy as np
import os
import sys
from tensorflow.keras import layers
import tensorflow_addons as tfa
from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.applications.resnet import ResNet50
import keras

sys.path.insert(1, '../src')
%load_ext autoreload
%autoreload 2

from data.datasets import get_mlrsnet, get_denmark
from models.mae import prepare_data, get_test_augmentation_model, Patches, PatchEncoder, get_train_augmentation_model, get_test_augmentation_model, create_encoder, create_decoder, MaskedAutoencoder, mlp, TrainMonitor
from model_utility.learning_rate_scheduler import WarmUpCosine
from supportive.evaluate import evaluate_extractor

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [None]:
# Data: `image_size` is the single source of truth for the image edge length.
# It is passed to the loader here and reused below for the patch count and
# the model input shape, so the two can never drift apart.
image_size = 64
x_test, y_test = get_denmark(image_size=image_size)
# NOTE(review): presumably an array of training images matching `image_size`
# -- confirm against the script that produced this file.
x_train = np.load("../data/avg_std30.npy")

# Training
buffer_size = 1024  # shuffle buffer used when building the train dataset
batch_size = 512
auto = tf.data.AUTOTUNE
num_classes = 25
epochs = 50
# The training dataset is repeated during `fit`, so the epoch length has to be
# given explicitly. Clamp to 1 so that a training set smaller than one batch
# does not yield zero steps (and therefore an empty learning-rate schedule).
steps_per_epoch = max(1, len(x_train) // batch_size)

# Optimizer hyperparameters: peak learning rate of the schedule and the
# weight-decay coefficient passed to the optimizer.
learning_rate = 1e-3
weight_decay = 1e-4

# Patching and masking: images are cut into a square grid of
# `patch_size` x `patch_size` patches.
patch_size = 4
mask_proportion = 0.75
patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2
input_shape = (image_size, image_size, 3)

# Encoder / decoder hyperparameters.
layer_norm_eps = 1e-6

# Embedding widths.
enc_projection_dim, dec_projection_dim = 64, 32

# Attention heads and number of transformer layers on each side.
enc_num_heads, dec_num_heads = 2, 1
enc_layers = 2
dec_layers = 1  # the decoder is configured smaller than the encoder

# Hidden sizes of the transformer layers: twice the projection width,
# then back down to the projection width.
enc_transformer_units = [enc_projection_dim * 2, enc_projection_dim]
dec_transformer_units = [dec_projection_dim * 2, dec_projection_dim]

In [None]:
# Build the input pipelines. Only the training stream is shuffled; both are
# batched and prefetched.
train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size)
    .batch(batch_size)
    .prefetch(auto)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices(x_test)
    .batch(batch_size)
    .prefetch(auto)
)

In [None]:
# Take the first batch from the training pipeline and run it through the
# training-time augmentation.
image_batch = next(iter(train_ds))
augmentation_model = get_train_augmentation_model(input_shape, image_size)
augmented_images = augmentation_model(image_batch)

# Split every augmented image into patches.
patch_layer = Patches(patch_size)
patches = patch_layer(images=augmented_images)

# `show_patched_image` returns the index of the image it displayed, so the
# same sample can be reused below.
random_index = patch_layer.show_patched_image(
    images=augmented_images,
    patches=patches,
)

# Reassemble that sample from its patches and display the result.
image = patch_layer.reconstruct_from_patch(patches[random_index])
plt.imshow(image)
plt.axis("off")
plt.show()

In [None]:
# Patch encoder: embeds the patches and splits them into masked / unmasked sets.
patch_encoder = PatchEncoder(patch_size, enc_projection_dim, mask_proportion)

# Embeddings, positions and the index sets produced for the current patches.
(unmasked_embeddings, masked_embeddings, unmasked_positions,
 mask_indices, unmask_indices) = patch_encoder(patches=patches)

# Build a masked version of one sample; the index of that sample is returned
# alongside it.
new_patch, random_index = patch_encoder.generate_masked_image(patches, unmask_indices)

# Side-by-side view: masked reconstruction on the left, original on the right.
plt.figure(figsize=(10, 10))
panels = (
    ("Masked", patch_layer.reconstruct_from_patch(new_patch)),
    ("Original", augmented_images[random_index]),
)
for position, (title, img) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(keras.utils.array_to_img(img))
    plt.axis("off")
    plt.title(title)
plt.show()

## Model initialization

In [None]:
# Fresh instances of every component, independent of the ones used for the
# visualisations above.
train_augmentation_model = get_train_augmentation_model(input_shape, image_size)
test_augmentation_model = get_test_augmentation_model(image_size)
patch_layer = Patches(patch_size)
patch_encoder = PatchEncoder(patch_size, enc_projection_dim, mask_proportion)
encoder = create_encoder(
    enc_num_heads,
    enc_layers,
    enc_projection_dim,
    layer_norm_eps,
    enc_transformer_units,
)
decoder = create_decoder(
    dec_layers,
    dec_num_heads,
    image_size,
    layer_norm_eps,
    dec_projection_dim,
    dec_transformer_units,
    num_patches,
    enc_projection_dim,
)

# Assemble the masked autoencoder from its parts.
mae_components = {
    "train_augmentation_model": train_augmentation_model,
    "test_augmentation_model": test_augmentation_model,
    "patch_layer": patch_layer,
    "patch_encoder": patch_encoder,
    "encoder": encoder,
    "decoder": decoder,
}
mae_model = MaskedAutoencoder(**mae_components)

### Learning rate scheduler for Optimizer

In [None]:
# The schedule spans the whole run; the first 15% of the steps are warm-up.
total_steps = steps_per_epoch * epochs
warmup_epoch_percentage = 0.15
warmup_steps = int(total_steps * warmup_epoch_percentage)

scheduled_lrs = WarmUpCosine(
    learning_rate_base=learning_rate,
    total_steps=total_steps,
    warmup_learning_rate=0.0,
    warmup_steps=warmup_steps,
)

# Evaluate the schedule at every step and plot the resulting curve.
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
lrs = list(map(scheduled_lrs, range(total_steps)))
ax.plot(lrs)
ax.set_xlabel("Step", fontsize=14)
ax.set_ylabel("LR", fontsize=14)
plt.show()

# To keep the figure on disk, call: fig.savefig("learning_rate_schedule")

## Model compilation and training

In [None]:
# Callbacks used during pretraining. The monitor is given the test images;
# presumably it acts every `epoch_interval` epochs -- see models.mae.
monitor = TrainMonitor(epoch_interval=5, test_images=x_test)
train_callbacks = [monitor]

In [None]:
# AdamW optimizer whose learning rate follows the warm-up / cosine schedule.
optimizer = tfa.optimizers.AdamW(
    learning_rate=scheduled_lrs,
    weight_decay=weight_decay,
)

# Compile and pretrain the model.
mae_model.compile(
    optimizer=optimizer,
    loss=keras.losses.MeanSquaredError(),
    metrics=["mae"],
)

# The training dataset is repeated indefinitely, so `steps_per_epoch` is what
# defines the length of an epoch. `train_callbacks` is already a list and is
# passed as-is: `fit` documents a flat list of callbacks, not a nested one.
history = mae_model.fit(
    train_ds.repeat(),
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=train_callbacks,
)

# Measure its performance.
loss, mae = mae_model.evaluate(test_ds)
print(f"Loss: {loss:.2f}")
print(f"MAE: {mae:.2f}")

In [None]:
# Re-run the evaluation on the test pipeline and print the raw,
# unrounded values, one per line.
loss, mae = mae_model.evaluate(test_ds)
print(loss, mae, sep="\n")

In [None]:
# Plot the training curves recorded by `fit`. Both series are error measures
# (the compiled loss and the "mae" metric), so the title and axis label say
# so instead of "accuracy".
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['mae'], label='mae')
plt.title('training loss and MAE')
plt.ylabel('value')
plt.xlabel('epoch')
plt.legend(loc='upper left')
plt.show()

In [None]:
# Pull the trained building blocks back out of the autoencoder so they can be
# reused in the feature extractor.

# Augmentation models.
train_augmentation_model, test_augmentation_model = (
    mae_model.train_augmentation_model,
    mae_model.test_augmentation_model,
)

# Patching layers.
patch_layer, patch_encoder = mae_model.patch_layer, mae_model.patch_encoder
# Flip the encoder's `downstream` flag; this mutates the layer held by
# `mae_model` as well. NOTE(review): presumably this turns masking off for
# feature extraction -- confirm in models.mae.PatchEncoder.
patch_encoder.downstream = True

# Encoder.
encoder = mae_model.encoder

In [None]:
# Rebuild the test pipeline through `prepare_data`, this time with labels.
# This replaces the image-only `test_ds` created earlier.
test_ds = prepare_data(
    image_size=image_size,
    images=x_test,
    labels=y_test,
    is_train=False,
    buffer_size=buffer_size,
    batch_size=batch_size,
    auto=auto,
)

In [None]:
# Write the trained encoder to disk, then read it straight back so the
# extractor below is built from the saved artefact.
path = "../model/mae/model"
encoder.save(filepath=path)
model_loaded = keras.models.load_model(filepath=path)

In [None]:
# Feature extractor: test-time augmentation -> patching -> patch encoding ->
# reloaded encoder -> batch norm -> average over the sequence axis.
extractor_layers = [
    layers.Input(shape=(image_size, image_size, 3)),
    get_test_augmentation_model(image_size),
    patch_layer,
    patch_encoder,
    model_loaded,
    layers.BatchNormalization(),
    layers.GlobalAveragePooling1D(),
]
extractor_model = keras.Sequential(extractor_layers, name="extraction_model")

In [None]:
# Baseline: ImageNet-pretrained ResNet50 without its classification head,
# globally average-pooled.
# NOTE(review): x_test is passed to ResNet50 as-is; confirm whether it should
# first go through keras.applications.resnet.preprocess_input.
model_resnet = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(image_size, image_size, 3),
    pooling="avg",
)

# Score both feature extractors with the same evaluation and settings.
for label, predict_fn in (
    ("extractor_model: ", extractor_model.predict),
    ("model_resnet: ", model_resnet.predict),
):
    print(label, evaluate_extractor(predict_fn, x_test, y_test, neighbors=10))

In [None]:
# Append one record with the run configuration and its scores to the results
# log.

# Run the evaluation before touching the file: if it fails, nothing is
# written, instead of leaving a record that stops half-way through.
neighbor_accuracy = evaluate_extractor(extractor_model.predict, x_test, y_test, neighbors=10)

# (label, value) pairs in the order they appear in the log.
record = [
    ("learning_rate", learning_rate),
    ("batch_size", batch_size),
    ("epochs", epochs),
    ("patch_size", patch_size),
    ("mask_proportion", mask_proportion),
    ("enc_projection_dim", enc_projection_dim),
    ("dec_projection_dim", dec_projection_dim),
    ("optimizer", optimizer),
    ("enc_num_heads", enc_num_heads),
    ("enc_layers", enc_layers),
    ("dec_num_heads", dec_num_heads),
    ("dec_layers", dec_layers),
    ("neighbor_accuracy", neighbor_accuracy),
    ("test_loss", loss),
    ("test_mae", mae),
]

# Three blank lines and a separator rule, then one "label: value" line per
# entry. There is no trailing newline; the next record starts with its own.
entry = (
    '\n\n\n'
    + '------------------------------------------'
    + "".join(f"\n{key}: {value!s}" for key, value in record)
)

with open("../model/mae/results.txt", 'a') as file:
    file.write(entry)