In [1]:
# Automatically reload edited project modules (e.g. `ppnet_tf`) before executing each cell.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
# Make the parent directory importable so project modules such as `ppnet_tf` resolve.
module_path = os.path.abspath(os.pardir)
_parent_already_on_path = module_path in sys.path
if not _parent_already_on_path:
    sys.path.append(module_path)

In [3]:
import numpy as np
import tensorflow as tf

from ppnet_tf import *

In [4]:
# Raw 20-class image folder; not used by the active data pipeline.
data_dir = '../../data/CUB_200_2011/images_20_classes/'

# Pre-cropped CUB-200 splits (the directory name indicates an augmented training set).
_cub_cropped_root = '../../data/cub200_cropped/'
train_dir = _cub_cropped_root + 'train_cropped_augmented/'
test_dir = _cub_cropped_root + 'test_cropped/'

In [5]:
def _load_split(directory, shuffle):
    """Load one image split as batches of (image, one-hot label) pairs resized to 224x224."""
    return tf.keras.utils.image_dataset_from_directory(
        directory,
        label_mode='categorical',
        image_size=(224, 224),
        batch_size=16,
        shuffle=shuffle,
    )


train_ds = _load_split(train_dir, shuffle=True)
val_ds = _load_split(test_dir, shuffle=False)  # fixed order for evaluation

# Alternative (disabled): load `data_dir` with an 80/20 random split instead, i.e.
# `validation_split=0.2`, `subset='training'` / `subset='validation'`, `seed=42`,
# `batch_size=32`, same `label_mode` and `image_size`.

Found 17426 files belonging to 20 classes.
Found 515 files belonging to 20 classes.


In [6]:
# Per-channel normalization statistics read by `_normalize` / `preprocess_image`.
# NOTE(review): these match the widely used ImageNet mean/std; the `.map(preprocess_image)`
# calls are commented out and the model applies Keras `preprocess_input` instead — confirm
# whether these are still needed.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def _normalize(image):
    """Normalize the image to zero mean and unit variance, per channel.

    Subtracts the module-level ``mean`` and divides by ``std``, both reshaped to
    ``[1, 1, 3]`` so they broadcast over the trailing channel axis.

    Args:
        image: Float image tensor (or array) whose last axis holds 3 channels.

    Returns:
        A new normalized tensor. ``image`` itself is not modified.
    """
    offset = tf.constant(mean, shape=[1, 1, 3])
    scale = tf.constant(std, shape=[1, 1, 3])
    # Out-of-place arithmetic: `image -= offset` / `image /= scale` would mutate a
    # NumPy input array in place (and fail outright for integer arrays).
    return (image - offset) / scale

def preprocess_image(image, label):
    """Scale an image to ``[0, 1]``, normalize it per channel, and pass the label through.

    Intended for ``tf.data.Dataset.map`` over ``(image, label)`` pairs.

    Args:
        image: Image tensor with 3 trailing channels; assumes pixel values in
            ``[0, 255]`` (as produced by ``image_dataset_from_directory``).
        label: Label tensor, returned unchanged.

    Returns:
        ``(normalized_image, label)``, where ``normalized_image`` is ``tf.float32``.
    """
    # `tf.to_float` does not exist in TensorFlow 2.x (this notebook uses TF2 APIs such as
    # `tf.keras.utils.image_dataset_from_directory`); `tf.cast` is the supported equivalent.
    image = tf.cast(image, tf.float32) / 255.0
    # Reuse the shared helper so the mean/std normalization is defined in one place.
    return _normalize(image), label

# train_ds = train_ds.map(preprocess_image)
# val_ds = val_ds.map(preprocess_image)

In [7]:
# normalization = tf.keras.layers.Rescaling(1./255)
# train_ds = train_ds.map(lambda x, y: (normalization(x), y))
# val_ds = val_ds.map(lambda x, y: (normalization(x), y))

In [8]:
# ImageNet-pretrained backbones without their classification heads. Both are built here;
# `conv_features` selects one based on config['backbone'].
# (Alternative: tf.keras.models.load_model('../pretrained/keras/resnet50.h5', compile=False).)
_backbone_kwargs = dict(include_top=False, input_shape=(224, 224, 3), weights='imagenet')
resnet50 = tf.keras.applications.resnet_v2.ResNet50V2(**_backbone_kwargs)
densenet121 = tf.keras.applications.densenet.DenseNet121(**_backbone_kwargs)

# To freeze the backbone: `for layer in resnet50.layers: layer.trainable = False`.

# Plain classification-head layers.
# NOTE(review): neither layer is referenced by any visible cell.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(units=20, activation='softmax')

In [9]:
def conv_features(config):
    """Build the convolutional feature extractor used by the prototype network.

    The selected pretrained backbone (module-level ``resnet50`` or ``densenet121``) is
    followed by two 1x1 convolutions: ReLU, then sigmoid. Both project to
    ``config['prototype_shape'][3]`` channels.

    Args:
        config: Mapping with at least ``'backbone'`` (``'resnet'`` or ``'densenet'``)
            and ``'prototype_shape'`` (index 3 gives the projected channel count).

    Returns:
        A ``tf.keras.Model`` mapping ``(224, 224, 3)`` inputs to the projected feature map.

    Raises:
        ValueError: If ``config['backbone']`` is neither ``'resnet'`` nor ``'densenet'``.
    """
    backbone_name = config['backbone']
    if backbone_name == 'resnet':
        backbone = resnet50
    elif backbone_name == 'densenet':
        backbone = densenet121
    else:
        # Without this guard the feature tensor would be unbound and the failure would
        # surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported backbone {backbone_name!r}; expected 'resnet' or 'densenet'."
        )

    proto_channels = config['prototype_shape'][3]

    inputs = tf.keras.Input(shape=(224, 224, 3))
    x = backbone(inputs)
    x = tf.keras.layers.Conv2D(
        filters=proto_channels,
        kernel_size=1,
        kernel_initializer=tf.keras.initializers.GlorotUniform(),
        activation='relu',
    )(x)
    x = tf.keras.layers.Conv2D(
        filters=proto_channels,
        kernel_size=1,
        kernel_initializer=tf.keras.initializers.GlorotUniform(),
        activation='sigmoid',
    )(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

In [10]:
# Shared hyper-parameters read by the model-building cells below.
config = dict(
    img_size=224,
    # Index 0 = number of prototypes (see get_proto_class_idx);
    # index 3 = prototype channel depth (see conv_features).
    prototype_shape=[200, 1, 1, 128],
    # Matches the ResNet50V2 output shown in the feature-extractor summary.
    # NOTE(review): presumably differs for the DenseNet backbone — confirm before switching.
    feature_shape=[7, 7, 2048],
    num_classes=20,
    backbone='resnet',
)

In [11]:
def get_proto_class_idx(cfg):
    """Build the one-hot prototype-to-class assignment matrix.

    Prototypes are assigned to classes in contiguous blocks: with
    ``p = num_prototypes // num_classes``, prototypes ``[k * p, (k + 1) * p)``
    belong to class ``k``.

    Args:
        cfg: Mapping with ``'num_classes'`` and ``'prototype_shape'``
            (index 0 is the total number of prototypes).

    Returns:
        A ``tf.float32`` tensor of shape ``(num_prototypes, num_classes)`` with exactly
        one 1.0 per row, in the column of that prototype's class.

    Raises:
        ValueError: If ``num_classes`` is not positive, or ``num_prototypes`` is not a
            multiple of ``num_classes``. Otherwise the block assignment would run past
            the last class.
    """
    num_classes = cfg['num_classes']
    num_prototypes = cfg['prototype_shape'][0]

    if num_classes <= 0:
        raise ValueError(f"num_classes must be positive, got {num_classes}.")
    if num_prototypes % num_classes != 0:
        raise ValueError(
            f"Number of prototypes ({num_prototypes}) must be a multiple of "
            f"num_classes ({num_classes})."
        )

    num_prototype_per_classes = num_prototypes // num_classes
    # Class id of every prototype: [0]*p + [1]*p + ... + [C-1]*p. This builds the matrix in
    # one op instead of one tensor_scatter_nd_update per prototype; each of those copied
    # the whole matrix.
    class_ids = tf.repeat(tf.range(num_classes), repeats=num_prototype_per_classes)
    return tf.one_hot(class_ids, depth=num_classes, dtype=tf.float32)

In [12]:
# Derive the prototype/class layout from the shared `config` rather than re-typing
# `prototype_shape` and `num_classes`, so it cannot drift out of sync with the prototype
# layers, which are also built from `config`.
proto_class_idx = get_proto_class_idx(config)

In [13]:
# Build the backbone + 1x1 add-on feature extractor for the configured backbone and print its layers.
feature_layers = conv_features(config)
feature_layers.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 resnet50v2 (Functional)     (None, 7, 7, 2048)        23564800  
                                                                 
 conv2d (Conv2D)             (None, 7, 7, 128)         262272    
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 128)         16512     
                                                                 
Total params: 23843584 (90.96 MB)
Trainable params: 23798144 (90.78 MB)
Non-trainable params: 45440 (177.50 KB)
_________________________________________________________________


In [14]:
# Each Keras application expects the input scaling its ImageNet weights were trained with.
# Dispatch on the same config key that `conv_features` uses, so the densenet backbone
# does not receive ResNetV2-style preprocessing.
backbone_preprocess = {
    'resnet': tf.keras.applications.resnet_v2.preprocess_input,
    'densenet': tf.keras.applications.densenet.preprocess_input,
}[config['backbone']]

inputs = tf.keras.Input(shape=(224, 224, 3))
x = backbone_preprocess(inputs)
cnn_features = feature_layers(x)

# Prototype layers from ppnet_tf (semantics inferred from their names): distances to each
# prototype, pooled to a per-prototype minimum, then converted to similarity scores.
distances = L2Convolution(config=config)(cnn_features)
min_distances = MinDistancePooling(config=config)(distances)
prototype_activations = Distance2Similarity(config=config)(min_distances)

# Named 'logits' so compile() can reference it. Note that the softmax makes these
# class probabilities rather than raw logits.
logits = tf.keras.layers.Dense(
    name='logits',
    units=config['num_classes'],
    activation='softmax',
)(prototype_activations)

# Training model: class probabilities plus the min distances consumed by the prototype loss.
model = tf.keras.Model(inputs=inputs, outputs=[logits, min_distances])

# Inspection model exposing every intermediate tensor for debugging.
test_model = tf.keras.Model(
    inputs=model.inputs,
    outputs=[cnn_features, distances, prototype_activations, min_distances, logits],
)

In [15]:
# Print the layer-by-layer structure of the training model.
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 tf.math.truediv (TFOpLambd  (None, 224, 224, 3)       0         
 a)                                                              
                                                                 
 tf.math.subtract (TFOpLamb  (None, 224, 224, 3)       0         
 da)                                                             
                                                                 
 model (Functional)          (None, 7, 7, 128)         23843584  
                                                                 
 l2_convolution (L2Convolut  (None, 7, 7, 200)         51200     
 ion)                                                            
                                                           

In [16]:
optimizer = tf.keras.optimizers.legacy.Adam()
# NOTE(review): `cross_entropy_loss` is not defined in this notebook (its local definition
# below is commented out); presumably it comes from `from ppnet_tf import *` — confirm.
# cross_entropy_loss = tf.keras.losses.CategoricalCrossentropy()
ppart_loss = proto_part_loss(cfg=config, proto_class_id=proto_class_idx)

# Losses are passed positionally, in the order of `model`'s outputs [logits, min_distances].
# Keying them by the auto-generated layer name 'min_distance_pooling' breaks once the model
# cell is re-run: Keras then suffixes the name (cf. input_3/input_4, model_1 in the
# summaries), and compile() no longer finds a matching output.
model.compile(
    optimizer=optimizer,
    loss=[cross_entropy_loss, ppart_loss],
    metrics={'logits': 'accuracy'},
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    verbose=2,
)

Epoch 1/10
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[6 6 6 6 6 6 6 6 6 6 6 6 6 6 6 6]
[7 6 7 7 7 7 7 6 7 6 6 6 7 6 6 6]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 7 7 7 7 7 7]
[7 7 7 7 7 7 7 7 7 7 

KeyboardInterrupt: 

In [None]:
# Grab one (images, one-hot labels) batch for manual inspection. The images are raw pixels;
# `test_model` applies the backbone preprocessing internally.
sample_inputs, sample_labels = next(iter(train_ds.take(1)))

In [None]:
# Peek at the raw batch: pixel values are unscaled (no preprocessing applied yet).
sample_inputs

<tf.Tensor: shape=(32, 224, 224, 3), dtype=float32, numpy=
array([[[[111.50446 , 140.50447 ,  74.50446 ],
         [111.632835, 140.63283 ,  74.632835],
         [113.239174, 142.23918 ,  76.239174],
         ...,
         [121.36574 , 144.36574 , 100.36574 ],
         [122.      , 145.      , 101.      ],
         [122.20944 , 145.20944 , 101.20944 ]],

        [[116.783485, 144.2701  ,  77.51339 ],
         [117.29241 , 144.77902 ,  78.022316],
         [118.6317  , 146.11832 ,  79.3616  ],
         ...,
         [122.7567  , 145.7567  , 100.2433  ],
         [122.7567  , 145.7567  , 100.98651 ],
         [123.58705 , 146.58705 , 102.58705 ]],

        [[120.52232 , 147.52232 ,  80.52232 ],
         [120.89834 , 147.89833 ,  80.89834 ],
         [122.109375, 149.10938 ,  82.109375],
         ...,
         [124.26116 , 147.52232 , 101.      ],
         [124.38941 , 147.52232 , 101.98218 ],
         [125.13582 , 148.13582 , 103.613495]],

        ...,

        [[110.      , 150.      ,

In [None]:
# Run the inspection model; `outputs` is the list
# [cnn_features, distances, prototype_activations, min_distances, logits].
outputs = test_model(sample_inputs)