In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from vit_keras import vit, utils, visualize

from models.clf.vit import VIT

In [None]:
# Notebook-wide settings: input image side length (pixels), ViT patch side
# length, and number of output classes.
img_size, patch_size, n_classes = 768, 16, 1000

In [None]:
# Hyper-parameters of the ViT-L/16 variant plus the location of its local
# weight file.
config = dict(
    dropout=0.1,
    mlp_dim=4096,
    num_heads=16,
    num_layers=24,
    hidden_size=1024,
    name="vit-l16",
    pretrained="weights/vit_l16_imagenet21k_imagenet2012.h5",
)

In [None]:
# Disabled alternative: build the project-local `VIT` model from the notebook
# settings (`img_size`, `patch_size`, `n_classes`) and the `config`
# hyper-parameters. Left commented out; the notebook builds `model` with
# `vit_keras` instead.
# model = VIT(
#     image_size=img_size, 
#     patch_size=patch_size, 
#     num_classes=n_classes, 
#     num_layers=config["num_layers"], 
#     hidden_size=config["hidden_size"], 
#     mlp_dim=config["mlp_dim"], 
#     num_heads=config["num_heads"], 
#     name=config["name"], 
#     dropout=config["dropout"]
# )

In [None]:
# Label lookup used later to turn a predicted class index into a name.
classes = utils.get_imagenet_classes()

# Keyword arguments for the vit_keras ViT-L/16 builder: pretrained weights
# are requested for both the backbone and the classification head.
vit_l16_options = dict(
    image_size=img_size,
    activation='sigmoid',
    pretrained=True,
    include_top=True,
    pretrained_top=True,
    weights="imagenet21k+imagenet2012",
)
model = vit.vit_l16(**vit_l16_options)


```text
Model: "VIT-L_16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 384, 384, 3)]     0         
_________________________________________________________________
embedding (Conv2D)           (None, 24, 24, 1024)      787456    
_________________________________________________________________
reshape (Reshape)            (None, 576, 1024)         0         
_________________________________________________________________
class_token (ClassToken)     (None, 577, 1024)         1024      
_________________________________________________________________
Transformer/posembed_input ( (None, 577, 1024)         590848    
_________________________________________________________________
Transformer/encoderblock_0 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_1 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_2 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_3 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_4 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_5 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_6 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_7 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_8 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_9 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_10  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_11  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_12  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_13  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_14  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_15  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_16  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_17  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_18  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_19  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_20  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_21  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_22  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_23  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoder_norm (La (None, 577, 1024)         2048      
_________________________________________________________________
ExtractToken (Lambda)        (None, 1024)              0         
_________________________________________________________________
head (Dense)                 (None, 1000)              1025000   
=================================================================
Total params: 304,715,752
Trainable params: 304,715,752
Non-trainable params: 0
_________________________________________________________________

```


```text
Model: "vit-l16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 384, 384, 3)]     0         
_________________________________________________________________
embedding (Conv2D)           (None, 24, 24, 1024)      787456    
_________________________________________________________________
reshape (Reshape)            (None, 576, 1024)         0         
_________________________________________________________________
class_token (ClassToken)     (None, 577, 1024)         1024      
_________________________________________________________________
Transformer/posembed_input ( (None, 577, 1024)         590848    
_________________________________________________________________
Transformer/encoderblock_0 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_1 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_2 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_3 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_4 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_5 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_6 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_7 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_8 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_9 ( ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_10  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_11  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_12  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_13  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_14  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_15  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_16  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_17  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_18  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_19  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_20  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_21  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_22  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoderblock_23  ((None, 577, 1024), (None 12596224  
_________________________________________________________________
Transformer/encoder_norm (La (None, 577, 1024)         2048      
_________________________________________________________________
ExtractToken (Lambda)        (None, 1024)              0         
_________________________________________________________________
head (Dense)                 (None, 1000)              1025000   
=================================================================
Total params: 304,715,752
Trainable params: 304,715,752
Non-trainable params: 0
_________________________________________________________________
```

In [None]:
# Print the layer-by-layer summary (output shapes and parameter counts) of `model`.
model.summary()

In [None]:
# Persist the weights of the in-memory `model` to the file recorded in
# config["pretrained"] ("weights/vit_l16_imagenet21k_imagenet2012.h5"), so the
# path is defined in exactly one place.
weights_path = config["pretrained"]

# To restore previously saved weights instead of writing them:
# model.load_weights(weights_path)

# Create the target directory first so the save does not fail when the
# directory is missing (e.g. on a fresh checkout).
os.makedirs(os.path.dirname(weights_path) or ".", exist_ok=True)
model.save_weights(weights_path)

In [None]:
# Classify a sample image fetched from a URL and print the top-scoring label.
url = 'https://upload.wikimedia.org/wikipedia/commons/d/d7/Granny_smith_and_cross_section.jpg'
image = utils.read(url, img_size)

# Preprocess, then add a leading batch axis of size 1 for `model.predict`.
preprocessed = vit.preprocess_inputs(image)
X = preprocessed.reshape(1, img_size, img_size, 3)

y = model.predict(X)
best_idx = y[0].argmax()
print(classes[best_idx]) # Granny smith

In [None]:
# Get an image and compute the attention map
url = 'https://upload.wikimedia.org/wikipedia/commons/b/bc/Free%21_%283987584939%29.jpg'
# Read at `img_size`, the notebook-wide resolution the model was built with.
# (`image_size` is only a keyword of `vit.vit_l16`; it is not a variable in
# this notebook, so using it here raises NameError on a fresh kernel.)
image = utils.read(url, img_size)
attention_map = visualize.attention_map(model=model, image=image)

# Classify the same image: preprocess, add a batch axis, take the top score.
batch = vit.preprocess_inputs(image)[np.newaxis]
scores = model.predict(batch)[0]
print('Prediction:', classes[scores.argmax()])
# Prediction: Eskimo dog, husky

In [None]:
# Show the input image and its attention map side by side, without axes.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16,8))
for axis, picture, caption in (
    (ax1, image, 'Original'),
    (ax2, attention_map, 'Attention Map'),
):
    axis.axis('off')
    axis.set_title(caption)
    # Assign the return value so the cell does not echo the AxesImage repr.
    _ = axis.imshow(picture)