In [1]:
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

from train_utils import *
from preprocess_utils import *

import yaml
import pickle

# Every hyperparameter and dataset switch used in this notebook is read from
# a single YAML file.
config_path = 'configs/12heads20kvoc.yaml'
with open(config_path) as config_stream:
    config_file = yaml.safe_load(config_stream)

# Sizes handed to build_caption_model (SEQ_LENGTH and VOCAB_SIZE also
# configure the TextVectorization layer).
EMBED_DIM = config_file['EMBED_DIM']
FF_DIM = config_file['FF_DIM']
NUM_HEADS = config_file['NUM_HEADS']
SEQ_LENGTH = config_file['SEQ_LENGTH']
VOCAB_SIZE = config_file['VOCAB_SIZE']

# Training-loop settings: dataset batch size, epoch count for fit(), and the
# Adam learning rate.
BATCH_SIZE = config_file['BATCH_SIZE']
EPOCHS = config_file['EPOCHS']
LEARNING_RATE = config_file['LEARNING_RATE']

# USE_FEATURES goes to build_caption_model; the remaining three flags are
# forwarded to load_data to choose the source datasets.
USE_FEATURES = config_file['USE_FEATURES']
COCO = config_file['COCO']
FLICKR30K = config_file['FLICKR30K']
FLICKR8K = config_file['FLICKR8K']

# Fetch the train/validation files and captions for whichever datasets the
# config flags enable (load_data comes from the project helper modules).
train_files, train_captions, val_files, val_captions = load_data(
    coco=COCO,
    flickr30k=FLICKR30K,
    flickr8k=FLICKR8K,
)

# Wrap each caption in its own one-element list; the wrapped form is what
# both vectorization.adapt and make_dataset receive below.
train_captions = [[caption] for caption in train_captions]
val_captions = [[caption] for caption in val_captions]

# Text-to-token-id layer: vocabulary capped at VOCAB_SIZE entries, every
# output sequence fixed to SEQ_LENGTH ids, and text cleaned with the
# project's custom_standardization callable before tokenisation.
vectorization = TextVectorization(
    standardize=custom_standardization,
    max_tokens=VOCAB_SIZE,
    output_sequence_length=SEQ_LENGTH,
    output_mode="int",
)

# The vocabulary is learned from the training captions only.
vectorization.adapt(train_captions)

# Build the training and validation input pipelines with the same loader,
# vectorizer and batch size; only the files and captions differ.
# NOTE(review): make_dataset is defined in the helper modules, so its
# parameter names are not visible here and the arguments stay positional.
train_dataset = make_dataset(
    train_files,
    train_captions,
    load_feature,
    vectorization,
    BATCH_SIZE,
)

validation_dataset = make_dataset(
    val_files,
    val_captions,
    load_feature,
    vectorization,
    BATCH_SIZE,
)

# Build the captioning model from the hyperparameters loaded from the YAML
# config. build_caption_model comes from the project helper modules; later
# cells use the result's cnn_model, encoder and decoder attributes.
caption_model = build_caption_model(
    EMBED_DIM, FF_DIM, NUM_HEADS, SEQ_LENGTH, VOCAB_SIZE, USE_FEATURES
)

In [4]:
# Sanity check: list the devices TensorFlow can see (e.g. whether a GPU is
# available) before starting a training run.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]

In [None]:
# Pickle the fitted TextVectorization layer's config and weights so the same
# vocabulary can be rebuilt later without re-adapting.
# The payload is assembled first so that a failure in get_config/get_weights
# cannot leave a truncated file behind, and the file is written inside a
# `with` block so the handle is always flushed and closed.
vocab_state = {'config': vectorization.get_config(),
               'weights': vectorization.get_weights()}
with open("20kVocab.pkl", "wb") as vocab_file:
    pickle.dump(vocab_state, vocab_file)

In [2]:
# Stop training when the validation loss has not improved for 3 consecutive
# epochs, and roll the model back to the best weights seen.
# ('val_loss' is EarlyStopping's default monitor, spelled out here.)
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Write weights-only checkpoints, keeping just the epoch with the highest
# 'val_acc'.
# NOTE(review): assumes the model reports a metric named 'acc' so that Keras
# logs 'val_acc' — confirm against the model definition.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/',
    monitor='val_acc',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

# Adam with the configured learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Unreduced sparse cross-entropy on logits: reduction='none' returns one loss
# value per element instead of an average.
# NOTE(review): presumably the model masks and averages these itself — confirm.
loss = keras.losses.SparseCategoricalCrossentropy(
    reduction='none',
    from_logits=True,
)

caption_model.compile(loss=loss, optimizer=optimizer)

In [None]:
# Optional: uncomment to restore previously saved weights before running the cells below.
#caption_model.load_weights('checkpoints/3head20k')

In [None]:
# Train for at most EPOCHS epochs, validating after each one; the two
# callbacks handle early stopping and best-checkpoint saving. The value
# returned by fit() is kept in `history`.
history = caption_model.fit(
    train_dataset,
    validation_data=validation_dataset,
    callbacks=[early_stopping, checkpoint_callback],
    epochs=EPOCHS,
)

Epoch 1/7
Epoch 2/7
Epoch 3/7

In [4]:
# Score the current weights on the validation dataset. The recorded output is
# a two-element list — presumably [loss, accuracy]; confirm against the
# metrics the model defines.
caption_model.evaluate(validation_dataset)



[4.826329231262207, 0.32416287064552307]

In [None]:
# Save the current weights under a name built from this run's head count and
# vocabulary size, so different configurations do not overwrite each other.
caption_model.save_weights(f'checkpoints/{NUM_HEADS}_{VOCAB_SIZE}')

In [8]:
# Token id -> token string table, in vocabulary order, used to turn predicted
# ids back into words.
vocab = vectorization.get_vocabulary()
index_lookup = dict(enumerate(vocab))

# One step fewer than SEQ_LENGTH: generate_caption drops the last position of
# the vectorized caption before feeding it to the decoder.
max_decoded_sentence_length = SEQ_LENGTH - 1
#valid_images = list(val_mapping.keys())


def generate_caption(img_path):
    """Greedily decode and print a caption for one stored image array.

    Args:
        img_path: Path handed to ``np.load``. The example call in this
            notebook passes a ``.npy`` file under ``data/features/`` —
            presumably precomputed image features rather than raw pixels;
            confirm against ``caption_model.cnn_model``.

    Returns:
        None. The caption is printed after a ``PREDICTED CAPTION:`` prefix.

    Relies on the notebook-level names ``caption_model``, ``vectorization``,
    ``index_lookup`` and ``max_decoded_sentence_length``.
    """
    # Load the stored array and add a leading batch axis of size one.
    img = np.load(img_path)
    img = tf.expand_dims(img, 0)

    # Run it through the model's CNN stage, then the Transformer encoder.
    img = caption_model.cnn_model(img)
    encoded_img = caption_model.encoder(img, training=False)

    # Greedy decoding: start from the start marker and append the most likely
    # next token at each step.
    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        # Re-vectorize the caption so far and drop the last position.
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        # True wherever the token id is non-zero (id 0 is presumably padding).
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = caption_model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )

        # Highest-scoring vocabulary entry at the current step.
        sampled_token_index = np.argmax(predictions[0, i, :])
        sampled_token = index_lookup[sampled_token_index]
        # Stop on the end marker or on the empty token. Compare against the
        # bare marker: tokens produced by whitespace splitting carry no
        # leading space, so a test for " <end>" would never match and
        # decoding would run on past the end of the sentence.
        if sampled_token.strip() in ("<end>", ""):
            break
        decoded_caption += " " + sampled_token

    print("PREDICTED CAPTION:", end=" ")
    print(decoded_caption.replace("<start> ", "").replace(" <end>", "").strip())


# Check the prediction for one stored feature file. The commented-out loop
# would caption every entry of valid_images instead, but that name is only
# defined in a commented-out line earlier in this cell.
generate_caption('data/features/coco_000000365655.npy')
#for x in valid_images:
#    generate_caption(x)

PREDICTED CAPTION: bir adam bir kamyon ve bir kamyon
