In [139]:
import tensorflow as tf
from tensorflow.keras.applications.resnet_v2 import preprocess_input
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import data
from tensorflow.keras import optimizers

import sklearn
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import OrdinalEncoder
from collections import Counter
import pickle
from sklearn.preprocessing import OneHotEncoder

In [140]:
# Re-import the project-local `data` module so that edits made to data.py
# take effect without restarting the kernel.
from importlib import reload
reload(data)

<module 'data' from '/Users/deepakduggirala/Documents/project/siamese/data.py'>

In [2]:
# Root directory of the image dataset. It is passed to
# data.get_zoo_elephants_images_and_labels and used to build the cache file paths.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_dir = '/Users/deepakduggirala/Documents/Elephants-dataset-cropped-png-1024/'

In [3]:
# Initial hyper-parameter set.
# NOTE(review): a later cell rebinds `params` with different values.
params = dict(
    image_size=256,         # side length images are resized to
    resize_pad=False,       # False -> plain resize, True -> resize_with_pad
    dense_l2_reg_c=0.0001,
    embedding_size=17,
    lr=0.001,
)

# Cache file locations for the train and validation splits, placed inside data_dir.
# NOTE(review): `cache_files` is not referenced by any other visible cell.
cache_files = dict(
    train=str(Path(data_dir) / 'train.cache'),
    val=str(Path(data_dir) / 'val.cache'),
)

In [141]:
# Collect the image file paths and their labels from data_dir using the
# project-local `data` helper.
# NOTE(review): presumably two parallel sequences (one label per path) — confirm in data.py.
image_paths, image_labels = data.get_zoo_elephants_images_and_labels(data_dir)

In [222]:
def shuffle(n):
    """Return the integers 0 .. n-1 in random order as an int32 array.

    The order is drawn from NumPy's global random state, so seed it with
    np.random.seed beforehand when a reproducible permutation is needed.
    """
    permuted = np.arange(n, dtype=np.int32)
    np.random.shuffle(permuted)  # shuffles in place
    return permuted

In [223]:
def get_support_and_query_sets(image_paths, image_labels, n_support, seed=99):
    """Split every class into a support set and a query set.

    For each distinct label the images of that class are randomly permuted;
    the first `n_support` of them go to the support set and the remainder to
    the query set.

    Args:
        image_paths: sequence of image paths, parallel to `image_labels`.
        image_labels: sequence of class labels, one per image.
        n_support: number of images per class placed in the support set.
            Classes with fewer images contribute all of them to the support
            set and none to the query set.
        seed: seed of the private random generator used for the permutations.
            The same seed always yields the same split.

    Returns:
        Tuple `(support_images, support_labels, query_images, query_labels)`
        of Python lists, grouped by class in order of first appearance of
        each label in `image_labels`.
    """
    # Convert once instead of once per class inside the loop.
    paths_arr = np.array(image_paths)
    labels_arr = np.array(image_labels)

    # A private generator keeps the split reproducible without reseeding the
    # global NumPy random state as a side effect of calling this function.
    rng = np.random.RandomState(seed)

    support_images, support_labels = [], []
    query_images, query_labels = [], []

    for c, count in Counter(image_labels).items():
        # Random order of positions *within* this class.
        order = np.arange(count, dtype=np.int32)
        rng.shuffle(order)
        s_idxs, q_idxs = order[:n_support], order[n_support:]

        mask = labels_arr == c
        c_image_paths = paths_arr[mask]
        c_image_labels = labels_arr[mask]

        support_images.extend(c_image_paths[s_idxs])
        support_labels.extend(c_image_labels[s_idxs])
        query_images.extend(c_image_paths[q_idxs])
        query_labels.extend(c_image_labels[q_idxs])

    return support_images, support_labels, query_images, query_labels

In [220]:
def get_preds(model, support_ds):
    """Run `model` over a dataset and L2-normalise each output row.

    Args:
        model: object exposing `predict(dataset, verbose=...)` that returns a
            2-D array with one row per example.
        support_ds: dataset passed straight through to `model.predict`.

    Returns:
        Array of the same shape as the predictions, each row divided by its
        Euclidean norm. A row of all zeros yields NaNs, as no epsilon is added.
    """
    # Use the model that was passed in rather than a global one.
    preds = model.predict(support_ds, verbose=True)
    return preds / np.linalg.norm(preds, axis=1, keepdims=True)

In [263]:
def get_support_class_means(preds, categories, support_labels):
    """Average the embeddings of each class.

    Args:
        preds: 2-D array of embeddings, one row per support example.
        categories: sequence of class labels; row `i` of the result belongs
            to `categories[i]`.
        support_labels: label of each row of `preds`, in the same order.

    Returns:
        float32 array of shape `(len(categories), preds.shape[1])` holding the
        mean embedding of every class. A category with no matching rows
        produces a row of NaNs.
    """
    preds = np.asarray(preds)
    labels_arr = np.array(support_labels)  # converted once, not per class

    # Size the result from the inputs instead of assuming 17 classes x 2048 dims.
    class_means = np.zeros((len(categories), preds.shape[1]))
    for i, c in enumerate(categories):
        class_means[i, :] = np.mean(preds[labels_arr == c, :], axis=0)
    return class_means.astype(np.float32)

In [226]:
# Build the per-class support/query split from the full image list.
# NOTE(review): `n_support` is not assigned in any visible cell — it must
# already exist in the kernel for this cell to run.
support_images, support_labels, query_images, query_labels = get_support_and_query_sets(
    image_paths, image_labels, n_support)

In [230]:
# Hyper-parameters used from here on; this rebinds the `params` defined in an
# earlier cell (different lr and dense_l2_reg_c, plus batch sizes and decay settings).
params = {
    "image_size": 256,
    "resize_pad": False,
    "batch_size": {
        "support": 32,
        "query": 32,
    },
    "lr": 0.00005,
    "decay_steps": 13,
    "decay_rate": 0.96,
    "dense_l2_reg_c": 0.01,
    "embedding_size": 17,
}

In [266]:
# ImageNet-pretrained ResNet50V2 without its classification head; with
# pooling='avg' it outputs one 2048-d vector per image (see model.summary()).
base_model = tf.keras.applications.ResNet50V2(include_top=False, weights="imagenet", input_shape=(
        params['image_size'], params['image_size'], 3), pooling='avg')

# Freeze the backbone: only the custom head below is trainable.
base_model.trainable = False


inputs = tf.keras.Input(shape=(params['image_size'], params['image_size'], 3))


# training=False is passed explicitly to the backbone call.
x = base_model(inputs, training=False)

# dense1 = tf.keras.layers.Dense(units=17, activation='softmax', name='dense1')(x)
# Head initialised from the transposed class-mean matrix W.
# NOTE(review): `CustomLayer` and `W` are defined in other cells, which must
# have been executed before this one.
custom_layer = CustomLayer(w_init=W.T, units=17, input_dim=2048)(x)
model = tf.keras.Model(inputs, custom_layer)

In [269]:
# Show the name of the model's second layer (index 1).
model.layers[1].name

'resnet50v2'

In [267]:
# Print the layer-by-layer overview with output shapes and parameter counts.
model.summary()

Model: "model_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_20 (InputLayer)       [(None, 256, 256, 3)]     0         
                                                                 
 resnet50v2 (Functional)     (None, 2048)              23564800  
                                                                 
 custom_layer_2 (CustomLayer  (None, 17)               34833     
 )                                                               
                                                                 
Total params: 23,599,633
Trainable params: 34,833
Non-trainable params: 23,564,800
_________________________________________________________________


In [245]:
# One-hot encoder fitted on the support labels; after fitting,
# enc.categories_[0] holds the class list used to order the class means.
# NOTE(review): newer scikit-learn releases rename `sparse` to `sparse_output`
# — confirm against the installed version.
# NOTE(review): `enc` is also assigned in another cell, fitted on all image labels.
enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
enc.fit(np.array(support_labels).reshape(-1,1))

OneHotEncoder(handle_unknown='ignore', sparse=False)

In [264]:
# Dataset over the support images, without augmentation or caching.
# shuffle=False keeps the dataset in input order, which get_support_class_means
# relies on when it masks `preds` by `support_labels`.
# NOTE(review): `get_dataset` is defined in another cell that must be run first.
support_ds, _, _ = get_dataset(support_images, support_labels,
                                       params,
                                       augment=False,
                                       cache_file=None,
                                       shuffle=False,
                                       batch_size=params['batch_size']['support'])

# L2-normalised backbone embeddings of the support set.
preds = get_preds(base_model, support_ds)
# One mean embedding per category; its transpose initialises CustomLayer's weights.
W = get_support_class_means(preds, enc.categories_[0], support_labels)



In [None]:
# One-hot encode the query labels with the fitted encoder.
# NOTE(review): `ls` is not used by any other visible cell.
ls = enc.transform(np.array(query_labels).reshape(-1,1))

In [265]:
# Check the dtype of the class-mean matrix (get_support_class_means casts to float32).
W.dtype

dtype('float32')

In [98]:
def preprocess_image(image, image_size, augment=True, model_preprocess=True):
    """Optionally augment an image, then optionally apply the model preprocessing.

    Args:
        image: image tensor to transform.
        image_size: not used by this function; kept in the signature for callers.
        augment: when truthy, apply a random horizontal flip, random saturation
            (factor between 0.75 and 1.25) and random hue (max delta 0.05).
        model_preprocess: when truthy, pass the result through `preprocess_input`.

    Returns:
        The transformed image tensor.
    """
    if augment:
        # Brightness, contrast and JPEG-quality augmentations are currently disabled.
        flipped = tf.image.random_flip_left_right(image)
        saturated = tf.image.random_saturation(flipped, 0.75, 1.25)
        image = tf.image.random_hue(saturated, 0.05)
    return preprocess_input(image) if model_preprocess else image


def parse_image_function(image_path, image_size, resize_pad=False):
    """Read an image file and resize it to a square of side `image_size`.

    Args:
        image_path: path of the image file to read.
        image_size: target height and width in pixels.
        resize_pad: when truthy, use `tf.image.resize_with_pad`;
            otherwise use `tf.image.resize`.

    Returns:
        The decoded 3-channel image resized to `image_size` x `image_size`.
        No augmentation or model preprocessing is applied here.
    """
    raw_bytes = tf.io.read_file(image_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    if resize_pad:
        return tf.image.resize_with_pad(decoded, target_height=image_size, target_width=image_size)
    return tf.image.resize(decoded, [image_size, image_size])


def get_dataset(image_paths, image_labels, params,
                augment=None, cache_file=None, model_preprocess=True,
                shuffle=True, batch_size=32):
    """Build a tf.data pipeline of (image, label) pairs.

    Pipeline order: load and resize -> optional cache -> augment/preprocess ->
    optional shuffle -> optional batch + prefetch. Caching sits before the
    augmentation step, so only the decoded, resized images are cached.

    Args:
        image_paths: image file paths, parallel to `image_labels`.
        image_labels: one label per image.
        params: dict providing 'image_size' and 'resize_pad'.
        augment: forwarded to `preprocess_image`.
        cache_file: if truthy, file name passed to `Dataset.cache`.
        model_preprocess: forwarded to `preprocess_image`.
        shuffle: if truthy, shuffle with a buffer covering the whole dataset.
        batch_size: if truthy, batch with this size and prefetch.

    Returns:
        Tuple `(dataset, number_of_examples, image_labels)`.
    """
    n_examples = len(image_labels)
    autotune = tf.data.AUTOTUNE

    def _load(path, label):
        # Decode and resize only.
        image = parse_image_function(
            path, params['image_size'], resize_pad=params['resize_pad'])
        return image, label

    def _transform(image, label):
        image = preprocess_image(
            image, params['image_size'], augment=augment, model_preprocess=model_preprocess)
        return image, label

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, image_labels))
    dataset = dataset.map(_load, num_parallel_calls=autotune)

    if cache_file:
        dataset = dataset.cache(cache_file)

    dataset = dataset.map(_transform, num_parallel_calls=autotune)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=n_examples)

    if batch_size:
        dataset = dataset.batch(batch_size).prefetch(autotune)

    return dataset, n_examples, image_labels

In [97]:
# Fit a one-hot encoder on all image labels and encode them in one step.
# NOTE(review): this rebinds `enc`, which another cell fits on the support
# labels only; which encoder is live depends on cell execution order.
enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
labels = enc.fit_transform(np.array(image_labels).reshape(-1,1))

In [99]:
# Single stratified split with 20% of the samples held out; random_state fixes the split.
# The former `sss.get_n_splits(...)` call was dropped: its return value was
# discarded, so it had no effect.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=99)
train_index, test_index = next(sss.split(image_paths, labels))

In [130]:
def my_loss_fn(y_true, y_pred, C=0.1):
    """Categorical cross-entropy plus a weighted entropy term on the predictions.

    Args:
        y_true: one-hot targets.
        y_pred: predicted class probabilities (the model ends in a softmax).
        C: weight of the entropy term.

    Returns:
        Per-example loss `cross_entropy + C * entropy(y_pred)`, where the
        entropy is summed over axis 1.
    """
    cross_entropy_loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    # Clip inside the log only: a probability of exactly 0 would otherwise give
    # 0 * log(0) = 0 * -inf = NaN and poison the whole batch loss.
    safe_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1.0)
    entropy_loss = -tf.reduce_sum(y_pred * tf.math.log(safe_pred), 1)
    return cross_entropy_loss + C * entropy_loss

In [131]:
# Compile with Adam and the custom cross-entropy + entropy loss.
# NOTE(review): the learning rate is hard-coded to 0.001 here, while the later
# `params` cell sets params['lr'] to 0.00005 — confirm which one is intended.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001), 
    loss=my_loss_fn,
    metrics=['accuracy'])

In [132]:
# Smoke-test run: one epoch on a single element (take(1), presumably one batch)
# of the training and validation datasets.
# NOTE(review): `train_ds` and `val_ds` are not defined in any visible cell.
model.fit(train_ds.take(1), epochs=1, validation_data=val_ds.take(1))



<keras.callbacks.History at 0x16675bd60>

In [258]:
class CustomLayer(tf.keras.layers.Layer):
    """Softmax head whose weight matrix starts from caller-supplied values.

    The forward pass L2-normalises the incoming embeddings, multiplies them by
    `w`, L2-normalises the resulting scores, adds the bias `b` and applies a
    softmax over the last axis.

    Args:
        w_init: initial weight matrix of shape `(input_dim, units)`.
        units: number of output classes; also the length of the bias vector.
        input_dim: size of the incoming embedding vectors.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).

    Raises:
        ValueError: if the shape of `w_init` is not `(input_dim, units)`.
    """

    def __init__(self, w_init, units=17, input_dim=2048, **kwargs):
        super().__init__(**kwargs)

        # Cast so the matmul with float32 embeddings cannot fail on a dtype mismatch.
        w_init = tf.cast(w_init, tf.float32)
        # `units` and `input_dim` were previously never checked against the
        # weights, so a mismatch only surfaced later (or silently broadcast).
        if w_init.shape.as_list() != [input_dim, units]:
            raise ValueError(
                "w_init must have shape (input_dim, units) = "
                f"({input_dim}, {units}), got {tuple(w_init.shape.as_list())}"
            )

        self.w = tf.Variable(
            initial_value=w_init,
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        # inputs: (batch, input_dim); w: (input_dim, units)
        embeddings = tf.math.l2_normalize(inputs, axis=1, epsilon=1e-10)
        t = tf.matmul(embeddings, self.w)                     # (batch, units)
        # Re-normalise the score vector before the bias is added.
        t = tf.math.l2_normalize(t, axis=1, epsilon=1e-10)
        z = t + self.b
        return tf.keras.activations.softmax(z, axis=-1)
