# Nearest Neighbor Search using NNCLR

**Authors:** Lennart Seeger, Rishit Dagli<br>
**Date created:** 2021/09/13<br>
**Last modified:** 2023/03/24<br>

In [None]:
import os
import sys
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from __future__ import print_function
from tensorflow.keras import models, layers, Input
from keras.applications.resnet import ResNet50
import tensorflow_addons as tfa
from tensorflow import keras

# Put the project's ../src directory on the module search path (at index 1)
# so the project-local packages imported below can be resolved.
sys.path.insert(1, '../src')
# IPython autoreload, mode 2: re-import modules before each cell runs, so
# edits to the project sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Project-local building blocks: the NNCLR model, dataset loaders, the
# learning-rate schedule and the extractor evaluation helper.
from models.nnclr import NNCLR, get_augmenter, get_encoder
from data.datasets import get_mlrsnet, get_denmark
from model_utility.learning_rate_scheduler import WarmUpCosine
from supportive.evaluate import evaluate_extractor
# Hide all CUDA devices from this process (forces CPU execution).
# NOTE(review): this is set after tensorflow has been imported — confirm it
# still takes effect in this environment, or move it above the import.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

In [None]:
# Dataset hyperparameters
image_size = 64
image_channels = 3
batch_size = 512
# Training array loaded from disk in one go.
x_train = np.load("../data/avg_std30.npy")
# Number of whole batches per epoch; a trailing partial batch is not counted.
steps_per_epoch = len(x_train) // batch_size
if steps_per_epoch == 0:
    # Fail early: with zero steps there is nothing to train on and the
    # step counts derived from this value would all be zero as well.
    raise ValueError(
        f"x_train holds {len(x_train)} samples, "
        f"fewer than one batch of {batch_size}"
    )

# optimizer
learning_rate = 0.001
weight_decay = 0.0001

# Algorithm hyperparameters
num_epochs = 50
width = 2048
temperature = 0.1
queue_size = 1024  # needs to be higher than batch_size
# Built from image_channels so the channel count is defined in one place.
input_shape = (image_size, image_size, image_channels)

# Augmentation settings: one keyword bundle per augmenter. Key order is kept
# stable because the dicts are also rendered with str() for logging.
contrastive_augmentation = dict(
    brightness=0.5,
    name="contrastive_augmenter",
    scale=(0.2, 1.0),
)
classification_augmentation = dict(
    brightness=0.2,
    name="classification_augmenter",
    scale=(0.5, 1.0),
)
# Evaluation images and labels, requested at the training resolution.
x_test, y_test = get_denmark(image_size=image_size)

# Encoder implementation: ImageNet-initialised ResNet50 without its top
# classification layers, fed with the shared input_shape defined above.
baseModel = ResNet50(
    weights='imagenet',
    include_top=False,
    input_tensor=Input(shape=input_shape),
)

In [None]:
# dataset handling
# Every training sample gets the placeholder label -1. np.full allocates the
# array directly (no intermediate Python list) and keeps an integer dtype
# even when x_train is empty.
y_train = np.full(len(x_train), -1)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))

In [None]:
# Learning-rate schedule spanning the whole run; the first 15% of all
# optimizer steps are handed to the schedule as warm-up steps.
warmup_epoch_percentage = 0.15
total_steps = num_epochs * steps_per_epoch
warmup_steps = int(warmup_epoch_percentage * total_steps)
scheduled_lrs = WarmUpCosine(
    warmup_steps=warmup_steps,
    warmup_learning_rate=0.0,
    total_steps=total_steps,
    learning_rate_base=learning_rate,
)

# Evaluate the schedule at every step and plot it as a visual sanity check.
lrs = list(map(scheduled_lrs, range(total_steps)))
plt.plot(lrs)
plt.ylabel("LR", fontsize=14)
plt.xlabel("Step", fontsize=14)
plt.show()

In [None]:
# Contrastive pretraining
# AdamW driven by the scheduled learning rate.
optimizer = tfa.optimizers.AdamW(
    weight_decay=weight_decay,
    learning_rate=scheduled_lrs,
)

# The contrastive augmentation settings are forwarded as keyword arguments.
pretraining_model = NNCLR(
    baseModel=baseModel,
    image_size=image_size,
    width=width,
    queue_size=queue_size,
    temperature=temperature,
    **contrastive_augmentation,
)
pretraining_model.compile(
    probe_optimizer=tf.keras.optimizers.Adam(),
    contrastive_optimizer=optimizer,
    run_eagerly=True,
)
# The batched dataset repeats without end, so steps_per_epoch is what
# delimits an epoch. No validation data is used during pretraining.
pretraining_history = pretraining_model.fit(
    train_ds.batch(batch_size, drop_remainder=True).repeat(),
    steps_per_epoch=steps_per_epoch,
    epochs=num_epochs,
    validation_data=None,
)

In [None]:
# save model
# Only the encoder is written to disk; it is read straight back so the
# following cells work with the persisted copy.
path = "../model/nnclr/model"
pretraining_model.encoder.save(filepath=path)
model_loaded = keras.models.load_model(filepath=path)

In [None]:
# build extractor
# Feature-extraction pipeline: input -> classification augmenter -> reloaded encoder.
extractor_model = keras.Sequential(name="extraction_model")
extractor_model.add(layers.Input((image_size, image_size, 3)))
extractor_model.add(pretraining_model.classification_augmenter)
extractor_model.add(model_loaded)

In [None]:
# Baseline for comparison: ImageNet ResNet50 with average pooling on top,
# scored with the same evaluation helper and neighbour count as the extractor.
model_resnet = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(image_size, image_size, 3),
    pooling="avg",
)
print(x_test.shape, y_test.shape)
print(
    "extractor_model: ",
    evaluate_extractor(extractor_model.predict, x_test, y_test, neighbors=10),
)
print(
    "model_resnet: ",
    evaluate_extractor(model_resnet.predict, x_test, y_test, neighbors=10),
)

In [None]:
# Append one record (score, hyperparameters, backbone summary) to the results log.
# Everything that can fail or take long is computed before the file is opened,
# so an error cannot leave a half-written record behind.
extractor_score = evaluate_extractor(extractor_model.predict, x_test, y_test, neighbors=10)

# Collect the backbone summary line by line instead of printing it.
stringlist = []
baseModel.summary(print_fn=stringlist.append)
short_model_summary = "\n".join(stringlist)

# The leading empty entry yields the newline that separates this record from
# the previous content of the file.
record_lines = [
    "",
    "------------------------------------------",
    f"extractor_model: : {extractor_score!s}",
    f"batch_size: {batch_size!s}",
    f"num_epochs: {num_epochs!s}",
    f"steps_per_epoch: {steps_per_epoch!s}",
    f"optimizer: {optimizer!s}",
    f"temperature: {temperature!s}",
    f"contrastive_augmentation: {contrastive_augmentation!s}",
    f"classification_augmentation: {classification_augmentation!s}",
    f"queue_size: {queue_size!s}",
    short_model_summary,
]

# Explicit encoding keeps the log independent of the platform default.
with open('../model/nnclr/results.txt', 'a', encoding="utf-8") as file:
    file.write("\n".join(record_lines))