In [None]:
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
plt.style.use('ggplot')
plt.rc('xtick',labelsize=16)
plt.rc('ytick',labelsize=16)

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras import mixed_precision
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.optimizers.schedules import PolynomialDecay, PiecewiseConstantDecay
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from models.clf.vit import VIT, TransformerBlock
from vit_keras import vit, layers, utils
from utils.train_utils import TrainAccumilatorCLF

# Start from a clean Keras global state.
K.clear_session()

# Turn on on-demand GPU memory allocation for every visible GPU.
# Indexing physical_devices[0] raised IndexError on machines without a GPU and,
# on multi-GPU machines, left the remaining devices with a different
# memory-growth setting, which TensorFlow rejects at runtime initialisation.
physical_devices = tf.config.experimental.list_physical_devices("GPU")
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

def enable_amp():
    """Set the global Keras mixed-precision policy to ``mixed_float16``."""
    mixed_precision.set_global_policy("mixed_float16")

print("Tensorflow version: ", tf.__version__)
print(physical_devices,"\n")
# Called here, before any model is constructed further down in the notebook.
enable_amp()

In [None]:
def attention_map(model, image):
    """Get an attention map for an image and model using the technique
    described in Appendix D.7 in the paper (unofficial).

    The attention weights of every ``TransformerBlock`` layer are averaged over
    heads, combined with an identity matrix, multiplied together from the last
    block down to the first, and the row of token 0 is turned into a 2-D mask
    that is resized to the image and multiplied into it.

    Args:
        model: A ViT model. Every layer that is a ``TransformerBlock`` must
            expose its attention weights as the second element of its output.
        image: An image for which we will compute the attention map.
            Assumed to be an (H, W, C) array -- TODO confirm with callers.

    Returns:
        ``mask * image`` cast to ``uint8``, where ``mask`` is scaled so that
        its maximum is 1 and has the image's height and width.

    Raises:
        ValueError: if ``model`` contains no ``TransformerBlock`` layer.
    """
    size = model.input_shape[1]

    # Prepare the input
    X = vit.preprocess_inputs(cv2.resize(image, (size, size)))[np.newaxis, :]  # type: ignore

    # Get the attention weights from each transformer.
    outputs = [
        l.output[1] for l in model.layers if isinstance(l, TransformerBlock)
    ]
    if not outputs:
        raise ValueError(
            "attention_map: model has no TransformerBlock layers to read attention weights from"
        )
    weights = np.array(
        tf.keras.models.Model(inputs=model.inputs, outputs=outputs).predict(X)
    )
    num_layers = weights.shape[0]
    num_heads = weights.shape[2]

    # Derive the patch-grid size from the attention matrices themselves (their
    # last axis is grid_size ** 2 + 1, as required by the reshape below) rather
    # than from a hard-coded layer index, which only held for one specific
    # layer ordering of the model.
    grid_size = int(np.sqrt(weights.shape[-1] - 1))
    reshaped = weights.reshape(
        (num_layers, num_heads, grid_size ** 2 + 1, grid_size ** 2 + 1)
    )

    # From Appendix D.6 in the paper ...
    # Average the attention weights across all heads.
    reshaped = reshaped.mean(axis=1)

    # From Section 3 in https://arxiv.org/pdf/2005.00928.pdf ...
    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    reshaped = reshaped + np.eye(reshaped.shape[1])
    reshaped = reshaped / reshaped.sum(axis=(1, 2))[:, np.newaxis, np.newaxis]

    # Recursively multiply the weight matrices, last block first.
    v = reshaped[-1]
    for n in range(1, len(reshaped)):
        v = np.matmul(v, reshaped[-1 - n])

    # Attention from the output token to the input space: row 0, skipping
    # column 0, laid out on the patch grid.
    mask = v[0, 1:].reshape(grid_size, grid_size)
    # cv2.resize takes (width, height); the trailing axis lets the mask
    # broadcast against the image's last dimension.
    mask = cv2.resize(mask / mask.max(), (image.shape[1], image.shape[0]))[
        ..., np.newaxis
    ]
    return (mask * image).astype("uint8")

In [None]:
class ImageNetLoader():
    """Turns ``{'image', 'label'}`` datapoints into ``(image, one-hot label)`` pairs.

    ``load_image_train`` applies random augmentation before resizing and
    normalising; ``load_image_test`` only resizes and normalises.
    """

    def __init__(self, img_height, img_width, n_classes):
        """Store the output image size, the number of classes and the
        per-channel statistics used by ``normalize``."""
        # Presumably the usual ImageNet RGB mean / std -- verify before changing.
        self.MEAN = np.array([0.485, 0.456, 0.406])
        self.STD = np.array([0.229, 0.224, 0.225])
        self.img_width = img_width
        self.img_height = img_height
        self.n_classes = n_classes

    @tf.function
    def random_crop(self, image):
        """Randomly crop ``image`` to a randomly chosen fraction of its size.

        One of eight scale factors between 0.5625 and 1.0 is drawn and applied
        to both height and width; the crop always keeps 3 channels.
        """
        scale_choices = tf.convert_to_tensor(
            np.array([0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375, 1.0])
        )
        choice = tf.random.uniform(shape=[], minval=0, maxval=8, dtype=tf.int32)
        factor = tf.cast(scale_choices[choice], tf.float32)

        dims = tf.cast(tf.shape(image), tf.float32)
        crop_h = tf.cast(dims[0] * factor, tf.int32)
        crop_w = tf.cast(dims[1] * factor, tf.int32)
        return tf.image.random_crop(image, size=[crop_h, crop_w, 3])

    @tf.function
    def normalize(self, image):
        """Scale ``image`` by 1/255, subtract ``MEAN`` and divide by ``STD``."""
        scaled = image / 255.0
        centered = scaled - self.MEAN
        return centered / self.STD

    def _one_hot_label(self, datapoint):
        # One-hot encode the datapoint's integer label over n_classes.
        return tf.one_hot(tf.cast(datapoint['label'], tf.int32), self.n_classes)

    def _resize_and_normalize(self, image):
        # Bilinear resize to the configured size, then float32 normalisation.
        resized = tf.image.resize(image, (self.img_height, self.img_width), method='bilinear')
        return self.normalize(tf.cast(resized, tf.float32))

    @tf.function
    def load_image_train(self, datapoint):
        """Build an augmented training example from ``datapoint``.

        Each augmentation below is gated by its own uniform draw, so each is
        applied in roughly half of the calls, independently of the others.

        Returns:
            ``(image, label)``: the normalised image and the one-hot label.
        """
        image = datapoint['image']
        label = self._one_hot_label(datapoint)

        # Horizontal flip.
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_left_right(image)
        # Brightness jitter.
        if tf.random.uniform(()) > 0.5:
            image = tf.image.random_brightness(image, 0.1)
        # Saturation jitter.
        if tf.random.uniform(()) > 0.5:
            image = tf.image.random_saturation(image, 0.7, 1.3)
        # Contrast jitter.
        if tf.random.uniform(()) > 0.5:
            image = tf.image.random_contrast(image, 0.7, 1.3)
        # Hue jitter.
        if tf.random.uniform(()) > 0.5:
            image = tf.image.random_hue(image, 0.1)

        image = self.random_crop(image)
        return self._resize_and_normalize(image), label

    def load_image_test(self, datapoint):
        """Build an evaluation example from ``datapoint`` without augmentation.

        Returns:
            ``(image, label)``: the normalised image and the one-hot label.
        """
        label = self._one_hot_label(datapoint)
        return self._resize_and_normalize(datapoint['image']), label

In [None]:
# Input resolution, patch size and label space shared by the data pipeline
# and the model built further down.
n_classes = 1000
patch_size = 16
img_size = 768
img_height = img_width = img_size  # square inputs

# Class-index -> name lookup used when displaying labels and predictions.
classes = utils.get_imagenet_classes()

pipeline = ImageNetLoader(img_height=img_height, img_width=img_width, n_classes=n_classes)

In [None]:
# Load every split of the dataset plus its metadata; input files are shuffled.
dataset, info = tfds.load(
    'imagenet2012:5.0.0',
    data_dir='/workspace/tensorflow_datasets/',
    shuffle_files=True,
    with_info=True,
)

In [None]:
# Attach the preprocessing functions: augmented for train, plain for validation.
train = dataset['train'].map(
    pipeline.load_image_train,
    num_parallel_calls=tf.data.AUTOTUNE,
)
valid = dataset['validation'].map(
    pipeline.load_image_test,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Split sizes as reported by the dataset metadata.
TRAIN_LENGTH = info.splits['train'].num_examples
VALID_LENGTH = info.splits['validation'].num_examples

# Per-step batch size, number of accumulation steps and shuffle buffer size.
BATCH_SIZE, ACCUM_STEPS, BUFFER_SIZE = 2, 128, 512
# Effective batch size once gradients from ACCUM_STEPS batches are combined.
ADJ_BATCH_SIZE = BATCH_SIZE * ACCUM_STEPS
print("Effective batch size: {}".format(ADJ_BATCH_SIZE))

In [None]:
def display_img(img, true_label, pred_label=None):
    """Show a single image with its label(s) in the title.

    Args:
        img: Image array/tensor accepted by
            ``tf.keras.preprocessing.image.array_to_img``.
        true_label: Ground-truth label shown in the title.
        pred_label: Optional predicted label. It used to be accepted but
            silently ignored; when given it is now appended to the title.
            With the default ``None`` the title is unchanged.
    """
    title = "True label: {}".format(true_label)
    if pred_label is not None:
        title += " | Predicted label: {}".format(pred_label)

    plt.figure(figsize=(6,6), dpi=120)
    plt.title(title, fontsize=12)
    plt.imshow(tf.keras.preprocessing.image.array_to_img(img))
    plt.axis('off')
    plt.show()

In [None]:
# Walk through the first four preprocessed training examples and keep the
# last one as the sample to inspect.
for image, label in train.take(4):
    sample_image = image
    sample_label = label

# print(sample_image.shape, sample_label.shape)
sample_class_name = classes[tf.argmax(sample_label).numpy()]
display_img(img=sample_image, true_label=sample_class_name)

In [None]:
# Training stream: shuffle, batch, repeat indefinitely, then prefetch.
train_dataset = (
    train
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

# Validation stream: a single batched pass, no shuffling.
valid_dataset = valid.batch(BATCH_SIZE)

In [None]:
# Architecture hyper-parameters and pretrained checkpoint path for the model
# constructed below.
config = dict(
    dropout=0.1,
    mlp_dim=4096,
    num_heads=16,
    num_layers=24,
    hidden_size=1024,
    name="vit-l16",
    pretrained="weights/vit_l16_imagenet21k_imagenet2012.h5",
)

In [None]:
# Architecture arguments come straight from `config`; image size, patch size
# and class count come from the notebook-level settings.
vit_kwargs = {
    key: config[key]
    for key in ("num_layers", "hidden_size", "mlp_dim", "num_heads", "name", "dropout")
}
model = VIT(
    image_size=img_size,
    patch_size=patch_size,
    num_classes=n_classes,
    **vit_kwargs
)

In [None]:
# Load the checkpoint named in the config into the model built above.
pretrained_weights_path = config["pretrained"]
model.load_weights(pretrained_weights_path)

In [None]:
def display_attn(img):
    """Print the model's prediction for ``img`` and plot it beside its attention map.

    Uses the notebook-level ``model`` and ``classes``.

    Args:
        img: A single image tensor (it must support ``.numpy()`` and indexing
            with ``tf.newaxis``).
    """
    # NOTE(review): attention_map applies its own vit.preprocess_inputs and
    # casts its result to uint8; if `img` is already normalised by the data
    # pipeline the overlay may not be meaningful -- confirm the expected range.
    overlay = attention_map(model=model, image=img.numpy())

    batch = img[tf.newaxis, ...]
    predicted_index = model.predict(batch)[0].argmax()
    print('Prediction:', classes[predicted_index])

    fig, (ax_original, ax_attention) = plt.subplots(ncols=2, figsize=(16,8))
    panels = (
        (ax_original, 'Original', img),
        (ax_attention, 'Attention Map', overlay),
    )
    for axis, heading, picture in panels:
        axis.axis('off')
        axis.set_title(heading)
        axis.imshow(tf.keras.preprocessing.image.array_to_img(picture))
    plt.show()

In [None]:
# display_attn(img=sample_image)

In [None]:
# Reset Keras' global state before setting up training.
tf.keras.backend.clear_session()

In [None]:
# Checkpoint file for this run, named after the model.
MODEL_PATH = "".join(("weights/", model.name, ".h5"))

In [None]:
EPOCHS = 10

# Training steps are counted in effective (accumulated) batches, validation
# steps in plain batches.
STEPS_PER_EPOCH, VALIDATION_STEPS = (
    TRAIN_LENGTH // ADJ_BATCH_SIZE,
    VALID_LENGTH // BATCH_SIZE,
)
# Total number of effective-batch steps over the whole run.
DECAY_STEPS = EPOCHS * STEPS_PER_EPOCH
print("Decay steps: {}".format(DECAY_STEPS))

In [None]:
# Epochs (relative to the epoch training resumes from) at which the learning
# rate should drop.
CURR_EPOCH = 0
E1 = 30 - CURR_EPOCH
E2 = 60 - CURR_EPOCH
E3 = 90 - CURR_EPOCH

# Convert epochs to schedule steps. STEPS_PER_EPOCH is computed from
# ADJ_BATCH_SIZE (= BATCH_SIZE * ACCUM_STEPS), so gradient accumulation is
# already accounted for; dividing by ACCUM_STEPS a second time placed the
# boundaries ACCUM_STEPS times too early (a fraction of the first epoch).
# NOTE(review): assumes the optimizer advances once per accumulated batch --
# confirm against TrainAccumilatorCLF.
S1 = STEPS_PER_EPOCH * E1
S2 = STEPS_PER_EPOCH * E2
S3 = STEPS_PER_EPOCH * E3

print("--- LR decay --- \nstep {}: {} \nstep {}: {} \nstep {}: {}".format(S1, 1e-2, S2, 1e-3, S3, 1e-4))

In [None]:
# Step-wise schedule: the first value applies up to S1, then one value per
# interval between consecutive boundaries.
lr_boundaries = [S1, S2, S3]
lr_values = [0.1, 0.01, 0.001, 0.0001]
learning_rate_fn = PiecewiseConstantDecay(boundaries=lr_boundaries, values=lr_values)

In [None]:
# Alternative kept for reference: drive SGD with the piecewise schedule.
# opt = SGD(learning_rate=learning_rate_fn, momentum=0.9)
opt = SGD(momentum=0.9, learning_rate=1e-3)

# The optimizer is wrapped for loss scaling, matching the mixed-precision
# policy enabled at the top of the notebook.
scaled_opt = mixed_precision.LossScaleOptimizer(opt)
clf_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

trainer = TrainAccumilatorCLF(
    model=model,
    optimizer=scaled_opt,
    loss_fn=clf_loss,
    n_classes=n_classes,
    reduce_lr_on_plateau=None,
    accum_steps=ACCUM_STEPS,
)

In [None]:
# Run training; checkpoints go to MODEL_PATH.
# NOTE(review): train_dataset repeats indefinitely and no step count
# (e.g. STEPS_PER_EPOCH) is passed here -- confirm the trainer bounds each
# epoch on its own.
fit_kwargs = dict(
    epochs=EPOCHS,
    train_dataset=train_dataset,
    test_dataset=valid_dataset,
    weights_path=MODEL_PATH,
)
results = trainer.fit(**fit_kwargs)

In [None]:
# NOTE(review): this rebinds `results`, discarding the value returned by
# trainer.fit in the previous cell. Whether model.history is populated after
# training through TrainAccumilatorCLF is not visible here -- confirm, or drop
# this cell if trainer.fit already returns the history to plot.
results = model.history

In [None]:
def plot_history(results, model, fine=None):
    """Plot the learning curves stored in ``results.history`` and save the figure.

    One panel is drawn per metric (loss, accuracy, IoU coefficient) that is
    actually present in the history; the matching ``val_*`` curve is added
    when available. Metrics that were not recorded (e.g. ``iou_coef`` for a
    classifier) are skipped instead of raising ``KeyError``.

    Args:
        results: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences.
        model: Object whose ``name`` is used in titles and in the file name.
        fine: Selects the output file name: truthy ->
            ``plots/<name>_learning_curves.png``, falsy ->
            ``plots/<name>_learning_curves_coarse.png``. When ``None`` (the
            default) a notebook-level ``fine`` variable is used if one exists,
            otherwise ``True``. Previously the function read an undefined
            global and failed with ``NameError``.

    Raises:
        KeyError: if none of the known metrics is present in the history.
    """
    import os  # local import: only needed to create the output directory

    if fine is None:
        fine = globals().get("fine", True)

    history = results.history
    # (history key, training label, validation label, title prefix)
    panel_specs = [
        ('loss', 'Training loss', 'Validation loss', "Loss: "),
        ('accuracy', 'Training accuracy', 'Validation accuracy', 'Accuracy: '),
        ('iou_coef', 'IoU coefficient', 'Validation IoU coefficient', 'IoU Coefficient: '),
    ]
    panels = [spec for spec in panel_specs if spec[0] in history]
    if not panels:
        raise KeyError("history contains none of: loss, accuracy, iou_coef")

    plt.figure(figsize=(15,7))
    for position, (metric, train_label, val_label, title_prefix) in enumerate(panels, start=1):
        # With all three metrics present this is the original 1x3 layout.
        plt.subplot(1, len(panels), position)
        plt.plot(history[metric], 'r', label=train_label)
        val_metric = 'val_' + metric
        if val_metric in history:
            plt.plot(history[val_metric], 'b', label=val_label)
        plt.title(title_prefix + model.name, fontsize=16)
        plt.xlabel('Epoch', fontsize=16)
        plt.legend(prop={'size': 14})

    # savefig does not create missing directories.
    os.makedirs("plots", exist_ok=True)
    if fine:
        plt.savefig("plots/"+model.name+"_learning_curves.png")
    else:
        plt.savefig("plots/"+model.name+"_learning_curves_coarse.png")
    plt.show()

In [None]:
# Draw and save the learning curves for the run above.
plot_history(results=results, model=model)