# Setup

### Imports / Constants / Functions

In [None]:
# Quiet TensorFlow's native logging: level '2' filters out INFO and WARNING
# messages. This cell runs before the cell that imports tensorflow, so the
# variable is already in the environment when TensorFlow is loaded.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [None]:
import tensorflow as tf


import pathlib
# from typing import Any,Union

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Height and width (in pixels) passed as `image_size` when the image
# directory is loaded.
IMG_HEIGHT  = 228
IMG_WIDTH   = 228

In [None]:
def _show_image_row(images, n, as_uint8):
    """Draw the first `n` entries of `images` side by side in a single row.

    Args:
        images: Indexable collection of images; every entry is converted with
            `np.array` before being handed to `plt.imshow`.
        n: Number of leading images to draw (`images[0]` .. `images[n - 1]`).
        as_uint8: When True, cast each image to `uint8` before drawing.
    """
    plt.figure(figsize=(20, 5))
    for i in range(n):
        ax = plt.subplot(1, n, i + 1)
        pixels = np.array(images[i])
        if as_uint8:
            pixels = pixels.astype('uint8')
        # NOTE(review): matplotlib documents vmin/vmax as ignored for RGB(A)
        # input, so `vmax=1` presumably only matters for 2-D images — confirm.
        plt.imshow(pixels, vmax=1)
        # Hide the tick marks and labels on both axes.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()


def display_image(x, n):
    """Show the first `n` images of `x` in one row, cast to `uint8`.

    Args:
        x: Indexable collection of images.
        n: Number of leading images to show.
    """
    _show_image_row(x, n, as_uint8=True)


def display_image2(x, n):
    """Show the first `n` images of `x` in one row without any dtype cast.

    Args:
        x: Indexable collection of images.
        n: Number of leading images to show.
    """
    _show_image_row(x, n, as_uint8=False)

### Load Data

In [None]:
# Number of images per batch requested from the directory loader.
BATCH_SIZE = 64

In [None]:
# Load the unseen images from disk, resized to (IMG_HEIGHT, IMG_WIDTH) and
# grouped into batches of BATCH_SIZE. The loader yields (images, labels)
# pairs; the labels are not used in this notebook.
unseen_batches = tf.keras.utils.image_dataset_from_directory(
    'data/0_data_unseen/',
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE)

# Drop the labels and join all batches into a single array along axis 0.
# NOTE(review): no `shuffle`/`seed` is passed to the loader, so the image
# order presumably differs between runs — confirm whether that matters.
dataset = np.concatenate([images for images, _ in unseen_batches])

In [None]:
# Preview up to four loaded images. display_image indexes x[0] .. x[n - 1],
# so the count is capped to avoid an IndexError when fewer than four images
# were loaded.
display_image(dataset, min(4, len(dataset)))

# Binary Classifier (Deliverable 1)

In [None]:
# Import Classifier Model
# Loads the saved binary classifier that is used below to decide which
# images are photos.
# NOTE(review): the path contains ':' characters, which are not valid in
# Windows file names — confirm the platforms this notebook must run on.
bin_classifier_path = 'models/bin_10.19.2022_15:29:58_0.93%'
bin_classifier = tf.keras.models.load_model(bin_classifier_path)

In [None]:
# Score every loaded image with the binary classifier.
predictions = bin_classifier.predict(dataset)

# The filter below needs exactly one score per image; fail loudly otherwise
# instead of silently pairing images with the wrong scores.
scores = np.asarray(predictions)
if scores.size != len(dataset):
    raise ValueError(
        f"expected one prediction per image, got shape {scores.shape} "
        f"for {len(dataset)} images")

# An image is kept as a photo when its score rounds to 1. Rounding the whole
# array at once avoids int() on a per-image array, a conversion that NumPy
# deprecates for arrays with ndim > 0.
is_photo = np.round(scores).reshape(-1) == 1
photos = [image for image, keep in zip(dataset, is_photo) if keep]

In [None]:
# Show every image that the classifier kept, all in a single row.
display_image(photos, len(photos))

# Autoencoder (Deliverable 2)

In [None]:
# Prepare Dataset
# Stack the kept photos into one float32 array divided by 255, then wrap it
# in a tf.data pipeline that yields one image per batch.
# NOTE(review): this rebinds `photos`, so re-running this cell on its own
# divides by 255 a second time; re-run the classifier cell first.
photos = np.array(photos).astype('float32') / 255.
photos_dataset = tf.data.Dataset.from_tensor_slices(photos).batch(1)

In [None]:
# Import Autoencoder Model
def ssim_accuracy(y_true, y_pred):
    """Return the SSIM between `y_true` and `y_pred`.

    The maximum pixel value passed to `tf.image.ssim` is 1.0.
    """
    return tf.image.ssim(y_true, y_pred, max_val=1.0)


# `ssim_accuracy` is handed to the loader through `custom_objects`;
# presumably the saved model references it by that name — confirm.
auto_encoder_path = 'models/autoenc_10.19.2022_17:39:08'
auto_encoder = tf.keras.models.load_model(
    auto_encoder_path,
    custom_objects={"ssim_accuracy": ssim_accuracy},
)

In [None]:
# Run every photo through the autoencoder and keep the model's outputs.
encoded_photos_dataset = auto_encoder.predict(photos_dataset)

In [None]:
# Show every autoencoder output in a single row; display_image2 draws the
# values as they are, without the uint8 cast that display_image applies.
display_image2(encoded_photos_dataset, len(encoded_photos_dataset))

# Captioning (Deliverable 3)

In [None]:
# Import Captioning Models (decoder and encoder)
cap_decoder_path = 'models/cap_decoder10.19.2022_17:58:23'
decoder = tf.keras.models.load_model(cap_decoder_path)
cap_encoder_path = 'models/cap_encoder10.19.2022_17:58:23'
encoder = tf.keras.models.load_model(cap_encoder_path)

In [None]:
# Echo the loaded decoder object as this cell's output.
decoder

In [None]:
# Maximum number of words generated for one caption.
max_length = 7
# Length of the flattened attention vector stored for each generated word.
attention_features_shape = 64

def evaluate(image):
    """Generate a caption for one image and record the attention per word.

    NOTE(review): `image_features_extract_model`, `word_to_index` and
    `index_to_word` are not defined in the visible part of this notebook;
    they must exist before this function is called.

    Args:
        image: A single image. It is given a leading batch axis and passed to
            `image_features_extract_model`.

    Returns:
        A tuple `(result, attention_plot)`. `result` is the list of generated
        words, including '<end>' when it was produced. `attention_plot` is an
        array with one row of length `attention_features_shape` per entry of
        `result`.
    """
    attention_plot = np.zeros((max_length, attention_features_shape))

    # The decoder takes the previous hidden state on every step, so one must
    # exist before the first call; without this line the loop fails with
    # UnboundLocalError. This matches the later `evaluate` cell.
    # NOTE(review): assumes the loaded decoder exposes `reset_state` — confirm.
    hidden = decoder.reset_state(batch_size=1)

    temp_input = tf.expand_dims(image, 0)
    img_tensor_val = image_features_extract_model(temp_input)
    # Collapse the middle axes so the features become (batch, -1, channels).
    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0],
                                                 -1,
                                                 img_tensor_val.shape[3]))

    features = encoder(img_tensor_val)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for i in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input,
                                                         features,
                                                         hidden)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        # The next word is sampled, so repeated calls may give different captions.
        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        if predicted_word == '<end>':
            break

        # Feed the sampled word back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 0)

    # Trim once for every exit path. Previously the trim only ran when the
    # loop used all `max_length` steps, where it changes nothing.
    return result, attention_plot[:len(result), :]

In [None]:
# Caption the first kept photo and keep the attention recorded for each word.
result, attention_plot = evaluate(photos[0])

In [None]:
# Overlay the attention recorded for each generated word on the captioned photo.
fig = plt.figure(figsize=(10, 10))

len_result = len(result)
# Side length of the square subplot grid, never smaller than 2.
grid_size = max(int(np.ceil(len_result / 2)), 2)

for i, word in enumerate(result):
    # Lay the attention vector out as an 8 x 8 grid.
    temp_att = np.resize(attention_plot[i], (8, 8))
    ax = fig.add_subplot(grid_size, grid_size, i + 1)
    ax.set_title(word)
    img = ax.imshow(photos[0])
    # Stretch the grid over the photo's extent and blend it on top.
    ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())

plt.tight_layout()
plt.show()

In [None]:
# Directory holding the captioning checkpoints.
checkpoint_path = "./checkpoints/3_captioning_train"
# Only the encoder and decoder are tracked. `optimizer` is never defined in
# this notebook, so passing it here raised NameError on a fresh kernel, and
# generating captions only needs the two models.
# NOTE(review): if training is resumed from this notebook, define the
# optimizer first and add `optimizer=optimizer` back.
ckpt = tf.train.Checkpoint(encoder=encoder,
                           decoder=decoder)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

In [None]:
# Restore the newest checkpoint in checkpoint_path, if there is one.
latest_checkpoint = ckpt_manager.latest_checkpoint
start_epoch = 0
if latest_checkpoint:
    # The number after the last '-' in the checkpoint name becomes start_epoch.
    start_epoch = int(latest_checkpoint.split('-')[-1])
    ckpt.restore(latest_checkpoint)

In [None]:
def evaluate(image):
    """Generate a caption for one image and record the attention per word.

    NOTE(review): this redefines the `evaluate` from the earlier cell. That
    version passes `image` straight to the feature extractor, whereas this
    one first runs it through `load_image`. `load_image`,
    `image_features_extract_model`, `word_to_index` and `index_to_word` are
    not defined in the visible part of this notebook.

    Args:
        image: Value passed to `load_image`; the first element of what
            `load_image` returns is used as the image
            (presumably `image` is a file path — confirm).

    Returns:
        A tuple `(result, attention_plot)`. `result` is the list of generated
        words, including '<end>' when it was produced. `attention_plot` is an
        array with one row of length `attention_features_shape` per entry of
        `result`.
    """
    attention_plot = np.zeros((max_length, attention_features_shape))

    # Initial hidden state for a batch of one image.
    hidden = decoder.reset_state(batch_size=1)

    temp_input = tf.expand_dims(load_image(image)[0], 0)
    img_tensor_val = image_features_extract_model(temp_input)
    # Collapse the middle axes so the features become (batch, -1, channels).
    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0],
                                                 -1,
                                                 img_tensor_val.shape[3]))

    features = encoder(img_tensor_val)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for i in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input,
                                                         features,
                                                         hidden)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        # The next word is sampled, so repeated calls may give different captions.
        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        if predicted_word == '<end>':
            break

        # Feed the sampled word back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 0)

    # Trim once for every exit path. Previously the trim only ran when the
    # loop used all `max_length` steps, where it changes nothing.
    return result, attention_plot[:len(result), :]