In [None]:
# Mount Google Drive so the project folder under "My Drive" is readable from the Colab VM.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import os

# Root of the project folder on the mounted Drive. Every data/model path in this notebook is
# built by plain string concatenation (path_prefix + "data/..."), so the prefix must end with "/".
# Set the ZSL_PATH_PREFIX environment variable to run from another location (e.g. outside Colab).
path_prefix = os.path.join(os.environ.get("ZSL_PATH_PREFIX", "/content/drive/My Drive/zero_shot/"), "")

In [None]:
import numpy as np

import os
from collections import Counter

from sklearn.preprocessing import LabelEncoder, normalize
from sklearn.neighbors import KDTree
import tensorflow as tf
from keras.models import Sequential, Model, model_from_json
from keras.layers import Dense, Dropout, Flatten
from keras.layers import BatchNormalization
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.applications.vgg16 import VGG16


import matplotlib.pyplot as plt


# Seed both NumPy (used for the zero-shot class split) and TensorFlow (weight initialisation,
# dataset shuffling) so that Restart & Run All reproduces the same split and training run.
np.random.seed(1234)
tf.random.set_seed(1234)
WORD2VECPATH    = path_prefix + "data/class_vectors.npy"  # NOTE(review): not referenced in the visible cells
MODELPATH       = path_prefix + "DatasetA_train_20180813/model/"  # directory for saved model files
MODELNAME = "modelV3"  # file stem for <MODELNAME>.json / <MODELNAME>.h5

In [None]:
# Learning-rate schedules for tf.keras.callbacks.LearningRateScheduler: each takes
# (epoch, lr) and returns the learning rate to use for that epoch.

def scheduler0(epoch, lr):
    """Constant schedule: keep the current learning rate."""
    return lr


def make_exp_scheduler(rate, warmup_epochs=20):
    """Return a schedule that leaves lr unchanged while epoch < `warmup_epochs` and otherwise
    returns lr * exp(-rate).

    The callback feeds the optimizer's current lr back in every epoch, so after warm-up the
    decay compounds epoch over epoch.
    """
    decay_factor = float(np.exp(-rate))

    def scheduler(epoch, lr):
        if epoch < warmup_epochs:
            return lr
        return lr * decay_factor

    return scheduler


# Named schedules selected by train_model's `decay` argument.
scheduler1 = make_exp_scheduler(0.1)   # 'exp-0.1'
scheduler2 = make_exp_scheduler(0.05)  # 'exp-0.05'
scheduler3 = make_exp_scheduler(0.01)  # 'exp-0.01'


In [None]:
def save_keras_model(model, model_path, model_name=None):
    """Save a Keras model as <model_name>.json (architecture) plus <model_name>.h5 (weights).

    Args:
        model: Keras model; only its `to_json()` and `save_weights()` methods are used.
        model_path: target directory, created if missing. A trailing separator is optional.
        model_name: file stem for both files; defaults to the global MODELNAME.
    """
    if model_name is None:
        model_name = MODELNAME
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(model_path, exist_ok=True)

    with open(os.path.join(model_path, model_name + ".json"), "w") as json_file:
        json_file.write(model.to_json())

    # serialize weights to HDF5
    model.save_weights(os.path.join(model_path, model_name + ".h5"))
    print("-> zsl model is saved.")

def load_keras_model(model_path, model_name=None):
    """Rebuild a model written by `save_keras_model`.

    The architecture is read from <model_name>.json and the weights from <model_name>.h5.

    Args:
        model_path: directory holding both files. A trailing separator is optional.
        model_name: file stem; defaults to the global MODELNAME.

    Returns:
        The reconstructed (uncompiled) Keras model.
    """
    if model_name is None:
        model_name = MODELNAME
    with open(os.path.join(model_path, model_name + ".json"), 'r') as json_file:
        loaded_model = model_from_json(json_file.read())

    # load weights into new model
    loaded_model.load_weights(os.path.join(model_path, model_name + ".h5"))
    return loaded_model

def build_model(num_attr=None, num_class=None, class_kernel_initializer='glorot_uniform'):
    """Build the image classifier whose penultimate layer is later extracted as the ZSL embedding.

    Layers:
        VGG16 conv base (include_top=False, 224x224x3 input, every layer frozen)
        -> Flatten -> Dense(512, relu) -> Dense(num_attr, relu)
        -> Dense(num_class, softmax, trainable=False).

    Args:
        num_attr: width of the attribute layer; defaults to the global NUM_ATTR.
        num_class: number of softmax units (training classes); defaults to the global NUM_CLASS.
        class_kernel_initializer: initializer for the frozen softmax kernel, of shape
            (num_attr, num_class). The default 'glorot_uniform' is Keras' Dense default, i.e. what
            this function always used. Passing the training-class word vectors here (what the old
            commented-out `custom_kernel_init` hinted at) is what ties the attribute layer to the
            word-vector space that the evaluation KDTree searches.

    Returns:
        An uncompiled `Sequential` model.
    """
    # NOTE(review): with the default random, frozen softmax kernel nothing pushes the
    # num_attr-dim output towards the class word vectors; confirm this is intended.
    # NOTE(review): images reach VGG16 as raw 0-255 RGB; vgg16.preprocess_input is never
    # applied in this notebook. TODO confirm the pretrained features are meant to be used this way.
    if num_attr is None:
        num_attr = NUM_ATTR
    if num_class is None:
        num_class = NUM_CLASS

    model = Sequential()
    base_model = VGG16(include_top=False, input_shape=(224, 224, 3))
    for layer in base_model.layers:
        layer.trainable = False  # only the dense head is trained
    model.add(base_model)
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(num_attr, activation='relu'))
    model.add(Dense(num_class, activation='softmax', trainable=False,
                    kernel_initializer=class_kernel_initializer))
    return model


def train_model(model, train_ds, val_ds, lr_value, decay):
    """Compile and fit `model`, then reload the weights with the lowest validation loss.

    Args:
        model: uncompiled Keras model ending in a softmax layer (see build_model).
        train_ds, val_ds: datasets yielding (images, one-hot labels), as built with
            label_mode='categorical' in the hyperparameter cell.
        lr_value: initial Adam learning rate.
        decay: schedule name. 'None' (or None) keeps the rate constant. 'exp-0.1', 'exp-0.05' and
            'exp-0.01' keep it for 20 epochs, then multiply it by exp(-rate) every epoch.

    Returns:
        The History object returned by `model.fit`.

    Raises:
        ValueError: if `decay` is not one of the names above. Previously this surfaced as an
            UnboundLocalError on `reduce_lr`.
    """
    schedules = {
        'None': scheduler0,
        'exp-0.1': scheduler1,
        'exp-0.05': scheduler2,
        'exp-0.01': scheduler3,
    }
    if decay is None:
        decay = 'None'
    if decay not in schedules:
        raise ValueError("unknown decay %r; expected one of %s" % (decay, sorted(schedules)))
    reduce_lr = tf.keras.callbacks.LearningRateScheduler(schedules[decay])

    # `lr=` is deprecated and rejected by recent Keras optimizers; `learning_rate=` is the supported name.
    adam = Adam(learning_rate=lr_value)
    model.compile(loss      = 'categorical_crossentropy',
                  optimizer = adam,
                  metrics   = ['categorical_accuracy', 'top_k_categorical_accuracy'])

    #Add early stopping
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20)

    # Checkpoint the best full-model weights to a dedicated file. <MODELNAME>.h5 in MODELPATH is where
    # save_keras_model later writes the truncated ZSL model, so sharing that name clobbered the checkpoint.
    # The directory is created here because h5 files cannot be written into a missing folder.
    os.makedirs(MODELPATH, exist_ok=True)
    checkpoint_path = os.path.join(MODELPATH, MODELNAME + "_best.weights.h5")
    mc = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_loss', mode='min',
                                            save_best_only=True, save_weights_only=True, verbose=1)

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCH,
        callbacks=[early_stop, mc, reduce_lr]
    )

    # fit() leaves the weights of the *last* epoch in `model`; restore the best checkpoint so the
    # ZSL model extracted afterwards comes from the epoch with the lowest validation loss.
    if os.path.exists(checkpoint_path):
        model.load_weights(checkpoint_path)

    # early_stop.stopped_epoch stays 0 when training runs all EPOCH epochs, so count the epochs instead.
    epochs_run = len(history.history.get('loss', []))
    print("model training is completed at epoch " + str(epochs_run))
    return history

In [None]:
# Split the label list into seen (training) and held-out zero-shot classes.
# Each line of label_list.txt is tab-separated; later cells use column 0 as the class id and
# column 1 as the class name (NOTE(review): confirm against the file).
with open(path_prefix + 'DatasetA_train_20180813/label_list.txt', 'r') as infile:
    name_classes = [line.strip().split('\t') for line in infile]

# Take 30 random classes as the zero-shot set (reproducible through np.random.seed in the setup cell).
N_ZSL_CLASSES = 30
indexes = np.arange(0, len(name_classes))
zsl_indexes = np.random.choice(indexes, size=N_ZSL_CLASSES, replace=False)
zsl_classes = np.array(name_classes)[zsl_indexes]

# Select the remaining classes by index. `row in zsl_classes` on an ndarray is an element-wise test
# that matches when ANY single field equals a field of any zsl row, not true row membership.
zsl_index_set = set(zsl_indexes.tolist())
train_classes = [obj for i, obj in enumerate(name_classes) if i not in zsl_index_set]

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint

In [None]:
# The training split (train_ds) is built in the "SET HYPERPARAMETERS" cell below, where data_dir
# and BATCH_SIZE are defined. Building it here raised NameError on a fresh kernel.


In [None]:
# The validation split (val_ds) is built in the "SET HYPERPARAMETERS" cell below, where data_dir
# and BATCH_SIZE are defined. Building it here raised NameError on a fresh kernel.


In [None]:
# val_ds only exists after the "SET HYPERPARAMETERS" cell below has run; inspect it there
# (e.g. val_ds.class_names) rather than here, where it is undefined on a fresh kernel.

In [None]:
# ---------------------------------------------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------------------------------------------- #
# SET HYPERPARAMETERS
# (`global` statements are no-ops at notebook top level, so none are needed.)

NUM_ATTR = 300      # width of the attribute layer used as the ZSL embedding (presumably the word-vector size)
BATCH_SIZE = 128
EPOCH = 500         # upper bound; EarlyStopping (patience 20) in train_model may stop earlier
LR_VAL = 5e-4       # other values tried: 1e-3, 1e-4, 5e-5
decay = 'exp-0.01'  # one of 'None', 'exp-0.1', 'exp-0.05', 'exp-0.01' (see train_model)


#DATA
data_dir = path_prefix + 'data/ordered_data/training'
# Same seed and validation_split in both calls so "training" and "validation" are complementary subsets.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir, validation_split=0.2, subset="training", image_size=(224,224), seed=1234, batch_size=BATCH_SIZE, label_mode='categorical')
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir, validation_split=0.2, image_size=(224,224), subset="validation", seed=1234, batch_size=BATCH_SIZE, label_mode='categorical')

# The softmax head needs one unit per class folder (the one-hot width of label_mode='categorical'),
# so derive it from the data instead of hard-coding it.
NUM_CLASS = len(train_ds.class_names)

### Hyperparameters tuning ###
# To tune, wrap the TRAINING and SAVE sections in e.g. `for LR_VAL in (1e-3, 1e-4, 5e-5):`
# and give each run its own MODELNAME (see below).

# ---------------------------------------------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------------------------------------------- #
# TRAINING PHASE

model = build_model()
model.summary()
history = train_model(model, train_ds, val_ds, LR_VAL, decay)
# ---------------------------------------------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------------------------------------------- #
# CREATE AND SAVE ZSL MODEL

# Drop the frozen softmax head: the ZSL model maps an image to the output of the penultimate
# (NUM_ATTR-wide) layer.
inp         = model.input
out         = model.layers[-2].output
zsl_model   = Model(inp, out)
zsl_model.summary()  # summary() prints by itself and returns None
# One file set per run: MODELNAME = "V3_lr" + str(LR_VAL) + "_" + str(decay)
save_keras_model(zsl_model, model_path=MODELPATH)

In [None]:
# ZERO-SHOT EVALUATION
# Embed one batch of held-out-class images with the saved ZSL model and look up the nearest class
# word vectors. A prediction is a top-k hit when the true class name is among the k nearest.
zsl_dir = path_prefix + 'data/ordered_data/zeroshot'

zsl_ds = tf.keras.preprocessing.image_dataset_from_directory(zsl_dir, seed=1234, image_size=(224,224))

print(zsl_ds.class_names)
zsl_labels = zsl_ds.class_names  # folder names; the integer labels yielded by zsl_ds index this list

# Class id -> class name for the zero-shot classes chosen in the split cell.
# NOTE(review): assumes the zeroshot/ folders are named after exactly these ids (same seed as the
# split); a folder outside zsl_classes raises KeyError below.
zsl_class_dict = {label_class[0]: label_class[1] for label_class in zsl_classes}

print(zsl_class_dict)

# Evaluate on a single batch only.
zsl_ds = zsl_ds.take(1)

zsl_model = load_keras_model(model_path=MODELPATH)

with open(path_prefix + 'DatasetA_train_20180813/class_wordembeddings.txt', 'r') as infile:
    class_wordembeddings = [line.strip().split(' ') for line in infile]

with open(path_prefix + 'DatasetA_train_20180813/label_list.txt', 'r') as infile:
    name_classes = [line.strip().split('\t') for line in infile]

# Order the embedding rows ("<name> v1 v2 ...") like label_list.txt; names without an embedding are skipped.
sorted_embedings = [embed for x in name_classes for embed in class_wordembeddings if embed[0] == x[1]]

embedding_rows = np.array(sorted_embedings)
# np.float was removed in NumPy 1.24; the builtin float selects the same float64 dtype.
vectors = np.asarray(embedding_rows[:, 1:], dtype=float)
classnames = list(embedding_rows[:, 0])

tree = KDTree(vectors)

top5, top3, top1 = 0, 0, 0
n_evaluated = 0

print(name_classes)
print(zsl_classes)

for batch_images, batch_labels in zsl_ds:
    # Predict on the very batch being plotted. The dataset reshuffles on every iteration, so a
    # separate predict(zsl_ds) call saw different images than this loop.
    pred_zsl = zsl_model.predict(batch_images)
    batch_images = batch_images.numpy()
    batch_labels = batch_labels.numpy()
    print(batch_images.shape)
    print(batch_labels)
    _, nearest = tree.query(pred_zsl, k=5)  # indices of the 5 nearest class vectors per image
    for i in range(len(batch_images)):  # a batch may hold fewer than the default 32 images
        pred_labels = [classnames[index] for index in nearest[i]]
        label_id = zsl_labels[batch_labels[i]]
        actual_label = zsl_class_dict[label_id]
        print(label_id, actual_label)
        top1 += actual_label in pred_labels[:1]
        top3 += actual_label in pred_labels[:3]
        top5 += actual_label in pred_labels[:5]
        n_evaluated += 1
        fig, ax = plt.subplots()
        ax.imshow(batch_images[i].astype("uint8"))
        ax.set_title("predicted labels for " + str(i) + ": " + str(pred_labels) + " actual label: " + actual_label)
        plt.show()
        plt.close(fig)  # free each figure; many open figures accumulate memory

if n_evaluated:
    print("top-1: %.3f  top-3: %.3f  top-5: %.3f  (%d images)"
          % (top1 / n_evaluated, top3 / n_evaluated, top5 / n_evaluated, n_evaluated))

