# 1. Data preparation

In [1]:
# Re-import edited project modules (data.*, config.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

from data.datasets import TrainingPicassoDataset
from config.datasets import dataset_configs

# Forwarded to TrainingPicassoDataset; its meaning/units are defined in
# data.datasets (presumably the axial range of the PSF stacks) -- TODO confirm.
z_range = 1000
# Key into config.datasets.dataset_configs; its 'training' entry configures the dataset.
dataset = 'picasso_test'
# Building the dataset loads and preprocesses the data (presumably the source of
# the "Loaded spots / Aligning / Adding noise / Standardising" log below).
train_dataset = TrainingPicassoDataset(dataset_configs[dataset]['training'], z_range)

import matplotlib.pyplot as plt
# Default figure size for the plots in this notebook.
plt.rcParams['figure.figsize'] = [15, 5]



Loaded spots...


  0%|          | 0/93 [00:00<?, ?it/s]

Aligning (94, 41, 31, 31) psfs...


100%|██████████| 93/93 [00:18<00:00,  5.11it/s]


Prepared stacks...
masking
Adding noise... 30
0.0 2317.5955918324116
195 2110
195 1648
Standardising using
 	mean: 0.003320395030787893
 	std 0.0007166993002643155


In [2]:
import numpy as np
from skimage.transform import resize

# Resample every split's image stack (train, test, ...) to the 64x64x1 input
# size used by the ViT below. The stack stored at data[key][0][0] is replaced
# in place, so re-running this cell resamples the already-resized images.
for split_key in train_dataset.data.keys():
    split = train_dataset.data[split_key]
    split[0][0] = np.stack(
        [resize(frame, (64, 64, 1), anti_aliasing=True) for frame in split[0][0]]
    )
    print(split[0][0].shape)

(77345, 64, 64, 1)
(713, 64, 64, 1)
(357, 64, 64, 1)


In [3]:
# Image stacks live at data[split][0][0]; the regression targets at data[split][1].
train_split = train_dataset.data['train']
test_split = train_dataset.data['test']
x_train, y_train = train_split[0][0], train_split[1]
x_test, y_test = test_split[0][0], test_split[1]


In [4]:
import tensorflow as tf
# Hand Keras tf.Tensors (instead of NumPy arrays) for both the training and test split.
x_train, y_train, x_test, y_test = (
    tf.convert_to_tensor(array) for array in (x_train, y_train, x_test, y_test)
)

In [5]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

In [None]:
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tqdm.keras import TqdmCallback


"""
## Configure the hyperparameters
"""

# --- Optimiser ---------------------------------------------------------------
learning_rate = 0.01
weight_decay = 1e-4
batch_size = 256  # 2**8
# Effectively an upper bound: EarlyStopping in run_experiment ends training.
num_epochs = 5000

# --- Patch tokenisation ------------------------------------------------------
# Input images are 64x64x1 (cell 2 resizes them). 64 is not a multiple of 6,
# so "VALID" patch extraction tiles only a 60x60 region: 10 x 10 = 100 patches
# of 6*6*1 = 36 values each; the last 4 rows/columns are never seen.
image_size = 64
patch_size = 6
num_patches = (image_size // patch_size) ** 2

# --- Transformer encoder -----------------------------------------------------
projection_dim = 64
num_heads = 4
transformer_layers = 8
# Hidden sizes of each block's MLP. The last entry must equal projection_dim
# because the MLP output is added back onto the residual stream.
transformer_units = [2 * projection_dim, projection_dim]

# --- Regression head ---------------------------------------------------------
mlp_head_units = [2048, 1024]  # hidden Dense sizes before the output unit
num_classes = 1  # single linear output unit (trained with MSE, not a classifier)

"""
## Preprocess the inputs (normalisation and resizing; no random augmentation is applied)
"""

# Preprocessing baked into the model. Despite its name there is no random
# augmentation here: only feature-wise normalisation (statistics adapted on
# x_train) followed by a resize to image_size x image_size, which should be a
# no-op for the 64x64 inputs prepared above.
_normalizer = layers.Normalization()
data_augmentation = keras.Sequential(
    [_normalizer, layers.Resizing(image_size, image_size)],
    name="data_augmentation",
)
# NOTE(review): adapt() sees all of x_train, including the last 10% that
# model.fit(validation_split=0.1) later holds out as validation data.
_normalizer.adapt(x_train)


"""
## Implement multilayer perceptron (MLP)
"""


def mlp(x, hidden_units, dropout_rate):
    """Stack GELU Dense layers, each followed by Dropout, on top of ``x``.

    Args:
        x: input Keras tensor.
        hidden_units: iterable of layer widths, applied in order.
        dropout_rate: dropout rate used after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer, or ``x`` itself when
        ``hidden_units`` is empty.
    """
    h = x
    for n_units in hidden_units:
        projected = layers.Dense(n_units, activation=tf.nn.gelu)(h)
        h = layers.Dropout(dropout_rate)(projected)
    return h


"""
## Implement patch creation as a layer
"""


class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Patches are extracted with stride == size and "VALID" padding, so trailing
    rows/columns that do not fill a whole patch are discarded.

    Input:  4-D images (batch, H, W, C).
    Output: (batch, (H // patch_size) * (W // patch_size), patch_size**2 * C).
    """

    def __init__(self, patch_size, **kwargs):
        # Forward base-layer kwargs (name, dtype, trainable, ...) so the layer
        # can be re-created from its config when a saved model is loaded.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch = tf.shape(images)[0]
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,  # stride == size -> non-overlapping patches
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static last dimension: patch_size * patch_size * channels.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch, -1, patch_dims])

    def get_config(self):
        # Required to serialise/reload models that contain this custom layer.
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


"""
Let's display patches for a sample image
"""

import matplotlib.pyplot as plt

# Pick a random training image. x_train holds standardised float intensities
# (see the "Standardising" log and the float resize above), so draw it on its
# native float range with a grey colormap: casting to uint8 would truncate the
# fractional values and wrap negatives, producing a meaningless picture.
image = x_train[np.random.randint(x_train.shape[0])]
image_2d = np.squeeze(image.numpy(), axis=-1)
vmin, vmax = float(image_2d.min()), float(image_2d.max())

plt.figure(figsize=(4, 4))
plt.imshow(image_2d, cmap="gray", vmin=vmin, vmax=vmax)
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size))
    # Shared vmin/vmax so every patch uses the same contrast as the full image.
    ax.imshow(patch_img.numpy(), cmap="gray", vmin=vmin, vmax=vmax)
    ax.axis("off")

"""
## Implement the patch encoding layer
The `PatchEncoder` layer will linearly transform a patch by projecting it into a
vector of size `projection_dim`. In addition, it adds a learnable position
embedding to the projected vector.
"""


class PatchEncoder(layers.Layer):
    """Project each patch to ``projection_dim`` and add a learned position embedding.

    The input sequence length must equal ``num_patches``: one embedding row is
    looked up per position 0..num_patches-1 and broadcast over the batch.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Forward base-layer kwargs (name, dtype, ...) for config round-trips.
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # (num_patches, projection_dim) position table, broadcast over the batch.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Required to serialise/reload models that contain this custom layer.
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config


"""
## Build the ViT model
The ViT model consists of multiple Transformer blocks,
which use the `layers.MultiHeadAttention` layer as a self-attention mechanism
applied to the sequence of patches. The Transformer blocks produce a
`[batch_size, num_patches, projection_dim]` tensor, which is processed by an
MLP head ending in a single linear output unit. In this notebook that unit
regresses one value per image (the model is trained with a mean-squared-error
loss), so there is no softmax and no class probabilities.
Unlike the technique described in the [paper](https://arxiv.org/abs/2010.11929),
which prepends a learnable embedding to the sequence of encoded patches to serve
as the image representation, all the outputs of the final Transformer block are
reshaped with `layers.Flatten()` and used as the image
representation input to the regression head.
Note that the `layers.GlobalAveragePooling1D` layer
could also be used instead to aggregate the outputs of the Transformer block,
especially when the number of patches and the projection dimensions are large.
"""


def create_vit_classifier():
    """Assemble the ViT model and return it as an uncompiled ``keras.Model``.

    Pipeline: preprocessing -> patch extraction -> projection + position
    embedding -> ``transformer_layers`` pre-norm Transformer blocks ->
    LayerNorm + Flatten -> dropout + MLP head -> ``num_classes`` linear outputs.
    With ``num_classes == 1`` and the MSE loss used in ``run_experiment`` this
    is a regression model: the final Dense output is a raw prediction, not a
    logit. All sizes come from the module-level hyperparameters.
    """

    def transformer_block(stream):
        # Pre-norm multi-head self-attention with a residual connection ...
        normed = layers.LayerNormalization(epsilon=1e-6)(stream)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed, normed)
        stream = layers.Add()([attended, stream])
        # ... followed by a pre-norm GELU MLP, also residual.
        normed = layers.LayerNormalization(epsilon=1e-6)(stream)
        transformed = mlp(normed, hidden_units=transformer_units, dropout_rate=0.1)
        return layers.Add()([transformed, stream])

    image_in = layers.Input(shape=(image_size, image_size, 1))
    preprocessed = data_augmentation(image_in)
    patches = Patches(patch_size)(preprocessed)
    tokens = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        tokens = transformer_block(tokens)

    # Every token is kept: Flatten gives (batch, num_patches * projection_dim).
    features = layers.LayerNormalization(epsilon=1e-6)(tokens)
    features = layers.Flatten()(features)
    features = layers.Dropout(0.5)(features)
    features = mlp(features, hidden_units=mlp_head_units, dropout_rate=0.5)
    prediction = layers.Dense(num_classes)(features)
    return keras.Model(inputs=image_in, outputs=prediction)


"""
## Compile, train, and evaluate the model
"""


def run_experiment(model):
    """Compile, train and test-evaluate ``model`` on the module-level data.

    Uses the global hyperparameters (learning_rate, weight_decay, batch_size,
    num_epochs) and tensors (x_train/y_train for fitting, x_test/y_test for
    the final evaluation). The loss is MSE and RMSE is tracked as a metric;
    Keras holds out 10% of the training tensors as validation data.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError(name="RMSE")],
    )

    callbacks = [
        # Divide the learning rate by 10 when the training loss stalls.
        ReduceLROnPlateau(
            monitor='loss', factor=0.1, patience=25, verbose=True,
            mode='min', min_delta=1, cooldown=25, min_lr=1e-7,
        ),
        # Stop on, and roll back to, the best *validation* RMSE. Training RMSE
        # keeps falling while the model overfits (see the log: train RMSE ~60
        # vs val RMSE ~140), so restoring the "best" training-RMSE weights would
        # simply keep the latest, overfit weights.
        EarlyStopping(
            monitor='val_RMSE', mode='min', patience=100, verbose=False,
            min_delta=1, restore_best_weights=True,
        ),
    ]
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        # NOTE(review): Keras uses the *last* 10% of the samples (taken before
        # shuffling) as validation data. If the training set is ordered (e.g.
        # by PSF or z), this split is not representative -- consider a
        # dedicated validation split instead.
        validation_split=0.1,
        callbacks=callbacks,
        verbose=2,
        shuffle=True,
    )

    # No ModelCheckpoint is used: EarlyStopping(restore_best_weights=True)
    # reloads the best weights when it stops training, so there is no
    # checkpoint file to load here.
    # compile() above yields exactly two results: 'loss' (MSE) and 'RMSE'.
    test_metrics = model.evaluate(x_test, y_test, batch_size=batch_size, return_dict=True)
    print(f"Test MSE: {test_metrics['loss']:.4f}")
    print(f"Test RMSE: {test_metrics['RMSE']:.4f}")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)

Image size: 64 X 64
Patch size: 6 X 6
Patches per image: 100
Elements per patch: 36
Epoch 1/5000
272/272 - 25s - loss: 147354.9844 - RMSE: 383.8685 - val_loss: 148320.1719 - val_RMSE: 385.1236 - lr: 0.0100 - 25s/epoch - 94ms/step
Epoch 2/5000
272/272 - 19s - loss: 132656.5156 - RMSE: 364.2204 - val_loss: 98483.2344 - val_RMSE: 313.8204 - lr: 0.0100 - 19s/epoch - 71ms/step
Epoch 3/5000
272/272 - 19s - loss: 129624.8359 - RMSE: 360.0345 - val_loss: 95570.6484 - val_RMSE: 309.1450 - lr: 0.0100 - 19s/epoch - 71ms/step
Epoch 4/5000
272/272 - 19s - loss: 130118.5078 - RMSE: 360.7194 - val_loss: 120856.3516 - val_RMSE: 347.6440 - lr: 0.0100 - 19s/epoch - 71ms/step
Epoch 5/5000
272/272 - 19s - loss: 132404.1719 - RMSE: 363.8738 - val_loss: 85589.1797 - val_RMSE: 292.5563 - lr: 0.0100 - 19s/epoch - 71ms/step
Epoch 6/5000
272/272 - 19s - loss: 171906.3750 - RMSE: 414.6159 - val_loss: 111564.2109 - val_RMSE: 334.0123 - lr: 0.0100 - 19s/epoch - 71ms/step
Epoch 7/5000
272/272 - 19s - loss: 146594.7

Epoch 56/5000
272/272 - 19s - loss: 54509.5742 - RMSE: 233.4729 - val_loss: 30655.9336 - val_RMSE: 175.0884 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 57/5000
272/272 - 19s - loss: 51876.8125 - RMSE: 227.7648 - val_loss: 31380.2441 - val_RMSE: 177.1447 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 58/5000
272/272 - 19s - loss: 49013.8125 - RMSE: 221.3906 - val_loss: 29506.3145 - val_RMSE: 171.7740 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 59/5000
272/272 - 19s - loss: 47988.8828 - RMSE: 219.0636 - val_loss: 32053.9941 - val_RMSE: 179.0363 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 60/5000
272/272 - 19s - loss: 46026.2148 - RMSE: 214.5372 - val_loss: 28476.0645 - val_RMSE: 168.7485 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 61/5000
272/272 - 19s - loss: 42535.9336 - RMSE: 206.2424 - val_loss: 28897.3105 - val_RMSE: 169.9921 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 62/5000
272/272 - 19s - loss: 40554.2734 - RMSE: 201.3809 - val_loss: 30510.2383 - val_RMSE: 174.671

272/272 - 19s - loss: 8698.0693 - RMSE: 93.2634 - val_loss: 22012.1328 - val_RMSE: 148.3649 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 112/5000
272/272 - 19s - loss: 8575.6182 - RMSE: 92.6046 - val_loss: 21176.7227 - val_RMSE: 145.5222 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 113/5000
272/272 - 19s - loss: 8458.0254 - RMSE: 91.9675 - val_loss: 21242.3613 - val_RMSE: 145.7476 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 114/5000
272/272 - 19s - loss: 8244.6953 - RMSE: 90.8003 - val_loss: 21685.8555 - val_RMSE: 147.2612 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 115/5000
272/272 - 19s - loss: 8186.5654 - RMSE: 90.4796 - val_loss: 21153.3008 - val_RMSE: 145.4417 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 116/5000
272/272 - 19s - loss: 8248.0586 - RMSE: 90.8188 - val_loss: 21184.4102 - val_RMSE: 145.5486 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 117/5000
272/272 - 19s - loss: 8302.5547 - RMSE: 91.1184 - val_loss: 22323.7070 - val_RMSE: 149.4112 - lr: 1.0000e-03 - 1

Epoch 167/5000
272/272 - 19s - loss: 5637.9551 - RMSE: 75.0863 - val_loss: 20976.5703 - val_RMSE: 144.8329 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 168/5000
272/272 - 19s - loss: 5780.0483 - RMSE: 76.0266 - val_loss: 20882.2188 - val_RMSE: 144.5068 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 169/5000
272/272 - 19s - loss: 5726.9795 - RMSE: 75.6768 - val_loss: 21005.8398 - val_RMSE: 144.9339 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 170/5000
272/272 - 19s - loss: 5684.0571 - RMSE: 75.3927 - val_loss: 20567.3496 - val_RMSE: 143.4132 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 171/5000
272/272 - 19s - loss: 5647.8013 - RMSE: 75.1519 - val_loss: 20568.6406 - val_RMSE: 143.4177 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 172/5000
272/272 - 19s - loss: 5464.6733 - RMSE: 73.9234 - val_loss: 19024.6309 - val_RMSE: 137.9298 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 173/5000
272/272 - 19s - loss: 5430.0708 - RMSE: 73.6890 - val_loss: 21667.2109 - val_RMSE: 147.1979 - lr:

Epoch 223/5000
272/272 - 19s - loss: 5079.3418 - RMSE: 71.2695 - val_loss: 20028.1797 - val_RMSE: 141.5210 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 224/5000
272/272 - 19s - loss: 5049.5547 - RMSE: 71.0602 - val_loss: 21084.7441 - val_RMSE: 145.2059 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 225/5000
272/272 - 19s - loss: 5062.9370 - RMSE: 71.1543 - val_loss: 19764.4492 - val_RMSE: 140.5861 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 226/5000
272/272 - 19s - loss: 5070.6694 - RMSE: 71.2086 - val_loss: 19239.6836 - val_RMSE: 138.7072 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 227/5000
272/272 - 19s - loss: 4902.0137 - RMSE: 70.0144 - val_loss: 19513.0938 - val_RMSE: 139.6893 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 228/5000
272/272 - 19s - loss: 5003.3286 - RMSE: 70.7342 - val_loss: 19286.7324 - val_RMSE: 138.8767 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 229/5000
272/272 - 19s - loss: 4900.1025 - RMSE: 70.0007 - val_loss: 21407.8145 - val_RMSE: 146.3141 - lr:

Epoch 279/5000
272/272 - 19s - loss: 4710.1733 - RMSE: 68.6307 - val_loss: 19589.7480 - val_RMSE: 139.9634 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 280/5000
272/272 - 19s - loss: 4804.6431 - RMSE: 69.3155 - val_loss: 21101.8027 - val_RMSE: 145.2646 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 281/5000
272/272 - 19s - loss: 4670.8989 - RMSE: 68.3440 - val_loss: 20813.7734 - val_RMSE: 144.2698 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 282/5000
272/272 - 19s - loss: 4852.9189 - RMSE: 69.6629 - val_loss: 20423.4238 - val_RMSE: 142.9106 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 283/5000
272/272 - 19s - loss: 4837.6323 - RMSE: 69.5531 - val_loss: 21348.1426 - val_RMSE: 146.1100 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 284/5000
272/272 - 19s - loss: 4738.8677 - RMSE: 68.8394 - val_loss: 20825.6543 - val_RMSE: 144.3110 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 285/5000
272/272 - 19s - loss: 4652.7520 - RMSE: 68.2111 - val_loss: 22267.2832 - val_RMSE: 149.2223 - lr:

Epoch 335/5000
272/272 - 19s - loss: 4747.6758 - RMSE: 68.9034 - val_loss: 20963.7461 - val_RMSE: 144.7886 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 336/5000
272/272 - 19s - loss: 4612.7314 - RMSE: 67.9171 - val_loss: 21005.4805 - val_RMSE: 144.9327 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 337/5000
272/272 - 19s - loss: 4685.7524 - RMSE: 68.4526 - val_loss: 20232.7168 - val_RMSE: 142.2417 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 338/5000
272/272 - 19s - loss: 4610.1968 - RMSE: 67.8984 - val_loss: 19824.7637 - val_RMSE: 140.8004 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 339/5000
272/272 - 19s - loss: 4523.2095 - RMSE: 67.2548 - val_loss: 18153.6934 - val_RMSE: 134.7356 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 340/5000
272/272 - 19s - loss: 4623.4644 - RMSE: 67.9961 - val_loss: 20079.4375 - val_RMSE: 141.7019 - lr: 1.0000e-03 - 19s/epoch - 71ms/step
Epoch 341/5000
272/272 - 19s - loss: 4673.3564 - RMSE: 68.3619 - val_loss: 20058.7324 - val_RMSE: 141.6288 - lr:

Epoch 390/5000
272/272 - 19s - loss: 3907.3872 - RMSE: 62.5091 - val_loss: 20422.9297 - val_RMSE: 142.9088 - lr: 1.0000e-04 - 19s/epoch - 71ms/step
Epoch 391/5000
272/272 - 19s - loss: 3899.5327 - RMSE: 62.4462 - val_loss: 20492.2539 - val_RMSE: 143.1512 - lr: 1.0000e-04 - 19s/epoch - 71ms/step
Epoch 392/5000
272/272 - 19s - loss: 3973.3274 - RMSE: 63.0343 - val_loss: 21012.8613 - val_RMSE: 144.9581 - lr: 1.0000e-04 - 19s/epoch - 71ms/step
Epoch 393/5000
272/272 - 19s - loss: 3947.2747 - RMSE: 62.8273 - val_loss: 20453.5527 - val_RMSE: 143.0159 - lr: 1.0000e-04 - 19s/epoch - 71ms/step
Epoch 394/5000
272/272 - 19s - loss: 4014.4604 - RMSE: 63.3598 - val_loss: 20160.1465 - val_RMSE: 141.9864 - lr: 1.0000e-04 - 19s/epoch - 71ms/step
Epoch 395/5000
272/272 - 19s - loss: 3943.4390 - RMSE: 62.7968 - val_loss: 20533.2129 - val_RMSE: 143.2941 - lr: 1.0000e-04 - 19s/epoch - 71ms/step
Epoch 396/5000
