In [21]:
import tensorflow as tf
# Idea (not implemented in this notebook): plot attention maps to see which
# parts of an image the model focuses on during captioning.

# numpy for sampling helpers; scikit-learn is used only for train_test_split
import numpy as np
import os
import json
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split


from models import CNN_Encoder, RNN_Decoder, image_features_extract_model

In [22]:
# Model hyper-parameters. These presumably have to match the training run that
# produced the checkpoint restored below, otherwise weights will not line up.
embedding_dim = 768  # keep different from `units` (original note: "should not be equal units")
units = 512
vocab_size = 14_500 + 1  # +1 presumably for the reserved index 0 — confirm

# Locations of the trained artefacts.
CHECKPOINT_FOLDER = "./checkpoint_dis/augmented_attention"
TOKENIZER_FOLDER = './tokenizer/'

In [23]:
# Instantiate the encoder/decoder pair with the hyper-parameters above; their
# sizes presumably must match the checkpoint that is restored later.
encoder = CNN_Encoder(embedding_dim)
decoder = RNN_Decoder(embedding_dim, units, vocab_size)

In [24]:
# One Checkpoint object tracks both networks; the manager discovers the
# checkpoints already saved in CHECKPOINT_FOLDER (and would keep at most 5 if
# new ones were saved from this notebook).
checkpoint_path = CHECKPOINT_FOLDER
ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder)
ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    directory=checkpoint_path,
    max_to_keep=5,
)

In [25]:
# Checkpoints currently managed on disk, oldest first.
list(ckpt_manager.checkpoints)

['./checkpoint_dis/augmented_attention/ckpt-6',
 './checkpoint_dis/augmented_attention/ckpt-7',
 './checkpoint_dis/augmented_attention/ckpt-8',
 './checkpoint_dis/augmented_attention/ckpt-9',
 './checkpoint_dis/augmented_attention/ckpt-10']

### Choose which checkpoint to load

In [26]:
## Load one of the managed checkpoints.
# `n` counts back from the newest checkpoint: n=1 is the latest, n=2 the one
# before it, and so on.
n = 4

available_checkpoints = ckpt_manager.checkpoints
if not 1 <= n <= len(available_checkpoints):
    raise ValueError(
        f"n must be between 1 and {len(available_checkpoints)}, got {n}"
    )
checkpoint = available_checkpoints[-n]

# Checkpoint files are named '<prefix>-<number>'; keep the number for reference.
start_epoch = int(checkpoint.split('-')[-1])
print(f'load from {checkpoint}')
# Only encoder/decoder are tracked here; expect_partial() silences warnings about
# checkpoint objects (presumably optimizer state) that this notebook does not restore.
# Assigning the status also keeps its repr out of the cell output.
status = ckpt.restore(checkpoint).expect_partial()

load from ./checkpoint_dis/augmented_attention/ckpt-7


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fd0b411ad30>

In [27]:
# Restore the fitted Keras tokenizer. tokenizer_from_json expects a JSON string,
# so the file presumably holds tokenizer.to_json() dumped once more as JSON —
# json.load therefore yields that string. A dedicated name avoids clobbering
# `data`, which the next cell uses for the caption DataFrame.
tokenizer_path = os.path.join(TOKENIZER_FOLDER, 'tokenizer.json')
with open(tokenizer_path, encoding='utf-8') as f:
    tokenizer_json = json.load(f)
tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(tokenizer_json)

In [28]:
# Load the cleaned caption table; the `folder` and `caption` columns become
# two parallel lists in the same row order.
data = pd.read_csv('clean_cpations_n_files.csv')
disney_images = data['folder'].tolist()
disney_captions = data['caption'].tolist()

In [29]:
# Hold out 10% of the (image, caption) pairs for validation; the fixed
# random_state reproduces the same split on every run.
(img_name_train, img_name_val,
 cap_train, cap_val) = train_test_split(disney_images, disney_captions,
                                        test_size=0.1, random_state=42)

In [33]:
def load_image(image_path, augment=False):
    """Read a JPEG file and preprocess it for the image feature extractor.

    Args:
        image_path: path of the image file (Python string or string tensor).
        augment: if True, apply a random left-right flip (a training-time
            augmentation). Defaults to False so evaluation is deterministic:
            the same file always produces the same tensor, which keeps
            captions generated at different temperatures comparable.

    Returns:
        A float tensor of shape (299, 299, 3), scaled by the InceptionV3
        ``preprocess_input`` (to the [-1, 1] range).
    """
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (299, 299))
    if augment:
        img = tf.image.random_flip_left_right(img)
    return tf.keras.applications.inception_v3.preprocess_input(img)


def evaluate_temp(image, t, max_len=None, sample_every=2):
    """Caption one image, mixing temperature sampling with greedy decoding.

    At decoding step ``i`` the next token is drawn with
    ``tf.random.categorical(logits / t)`` when ``i % sample_every == 0`` and
    taken as the arg-max otherwise. The default ``sample_every=2`` alternates
    sampling and arg-max, as this notebook originally did. A larger value makes
    decoding more greedy; ``sample_every=None`` gives pure greedy decoding.

    Relies on the notebook globals ``encoder``, ``decoder``, ``tokenizer``,
    ``image_features_extract_model`` and ``load_image``.

    Args:
        image: path of the image file, passed to ``load_image``.
        t: sampling temperature (> 0). Values below 1 sharpen the sampled
            distribution; values above 1 flatten it.
        max_len: maximum number of decoding steps. Defaults to the notebook
            global ``max_length``, read at call time.
        sample_every: sampling period as described above, or None.

    Returns:
        ``(tokens, avg_log_prob)``: the generated tokens (including ``'<end>'``
        when it is produced) and the mean log-probability of those tokens under
        the untempered model distribution. Unlike averaged raw logits, this is
        comparable across steps and captions. Returns ``([], 0.0)`` when no
        decoding step is run (``max_len <= 0``).

    Raises:
        ValueError: if ``t <= 0``, or ``sample_every`` is not None and < 1.
    """
    if t <= 0:
        raise ValueError(f"temperature t must be positive, got {t}")
    if sample_every is not None and sample_every < 1:
        raise ValueError(f"sample_every must be >= 1 or None, got {sample_every}")
    limit = max_length if max_len is None else max_len

    hidden = decoder.reset_state(batch_size=1)

    temp_input = tf.expand_dims(load_image(image), 0)
    img_tensor_val_one, img_tensor_val_two = image_features_extract_model(temp_input)
    # Flatten the spatial grid of each (assumed 4-D) feature map:
    # (batch, H, W, C) -> (batch, H*W, C).
    img_tensor_val_one = tf.reshape(img_tensor_val_one, (img_tensor_val_one.shape[0], -1, img_tensor_val_one.shape[3]))
    img_tensor_val_two = tf.reshape(img_tensor_val_two, (img_tensor_val_two.shape[0], -1, img_tensor_val_two.shape[3]))
    features_one, features_two = encoder(img_tensor_val_one, img_tensor_val_two)

    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)
    result = []
    score = 0.0

    for i in range(limit):
        # `predictions` are logits: tf.random.categorical consumes them directly.
        predictions, hidden, _ = decoder(dec_input, features_one, features_two, hidden)

        sample_now = sample_every is not None and i % sample_every == 0
        if sample_now:
            predicted_id = tf.random.categorical(predictions / t, 1)[0][0].numpy()
        else:
            predicted_id = np.argmax(predictions[0].numpy())

        # Score every emitted token (including '<end>') so numerator and
        # denominator of the average cover the same tokens.
        score += float(tf.nn.log_softmax(predictions[0])[predicted_id])

        word = tokenizer.index_word[predicted_id]
        result.append(word)
        if word == '<end>':
            break

        dec_input = tf.expand_dims([predicted_id], 0)

    return result, (score / len(result) if result else 0.0)

In [31]:
# Rewrite 'features' -> 'disney_old/disney_img' in every validation path
# (presumably mapping feature-file locations to the raw image folder).
# NOTE(review): `img_name_path` is not used by the cells shown here — the
# evaluation cell passes `img_name_val` entries straight to evaluate_temp and
# Image.open; confirm which list is intended.
img_name_path = list(map(lambda path: path.replace('features', 'disney_old/disney_img'), img_name_val))

In [None]:
#### Captions on the validation set
# Pick one validation image, print its reference caption, then decode it at
# several sampling temperatures and show the image.

SEED = 42
tf.random.set_seed(SEED)           # repeatable tf.random.categorical sampling
rng = np.random.default_rng(SEED)  # repeatable choice of validation image

max_length = 27  # maximum number of decoding steps used by evaluate_temp
TEMPERATURES = [0.5, 0.7, 0.9, 1, 1.2, 1.5]

rid = int(rng.integers(len(img_name_val)))
image = img_name_val[rid]
# NOTE(review): `[1:]` drops the first element — the first *character* if the
# captions are plain strings, as read from the CSV. Confirm this strips what is
# intended (removing a leading '<start>' token would need a different slice).
real_caption = cap_val[rid][1:]
print ('Real Caption:', real_caption)

for t in TEMPERATURES:
    res, a = evaluate_temp(image, t)
    print(f"{t}, {' '.join(res)}, score {a}")
Image.open(image)