In [None]:
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

from train_utils import *
from preprocess_utils import *

import yaml
import pickle

# --- Configuration ----------------------------------------------------------
# Every hyperparameter and dataset selector is read from a single YAML file.
config_path = 'configs/12heads20kvoc.yaml'

with open(config_path) as f:
    config_file = yaml.safe_load(f)

# Model architecture settings (forwarded to build_caption_model below).
EMBED_DIM, FF_DIM, NUM_HEADS = (
    config_file[key] for key in ('EMBED_DIM', 'FF_DIM', 'NUM_HEADS')
)
# Text settings: caption length after vectorization and vocabulary cap.
SEQ_LENGTH, VOCAB_SIZE = (
    config_file[key] for key in ('SEQ_LENGTH', 'VOCAB_SIZE')
)
# Optimisation settings.
BATCH_SIZE, EPOCHS, LEARNING_RATE = (
    config_file[key] for key in ('BATCH_SIZE', 'EPOCHS', 'LEARNING_RATE')
)
USE_FEATURES = config_file['USE_FEATURES']
# Dataset selectors handed to load_data.
COCO, FLICKR30K, FLICKR8K = (
    config_file[key] for key in ('COCO', 'FLICKR30K', 'FLICKR8K')
)

# --- Data -------------------------------------------------------------------
train_files, train_captions, val_files, val_captions = load_data(
    coco=COCO,
    flickr30k=FLICKR30K,
    flickr8k=FLICKR8K,
)
# Wrap each caption in its own single-element list.
# NOTE(review): presumably the shape make_dataset / adapt expect — confirm
# against preprocess_utils.
train_captions = [[caption] for caption in train_captions]
val_captions = [[caption] for caption in val_captions]

# --- Text vectorization -----------------------------------------------------
# Integer-encodes captions to a fixed length of SEQ_LENGTH tokens, keeping at
# most VOCAB_SIZE vocabulary entries.
vectorization = TextVectorization(
    standardize=custom_standardization,
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
)

# The vocabulary is learned from the training captions only.
vectorization.adapt(train_captions)

# --- Datasets ---------------------------------------------------------------
train_dataset = make_dataset(
    train_files,
    train_captions,
    load_feature,
    vectorization,
    BATCH_SIZE,
)

validation_dataset = make_dataset(
    val_files,
    val_captions,
    load_feature,
    vectorization,
    BATCH_SIZE,
)

# --- Model ------------------------------------------------------------------
caption_model = build_caption_model(
    EMBED_DIM,
    FF_DIM,
    NUM_HEADS,
    SEQ_LENGTH,
    VOCAB_SIZE,
    USE_FEATURES,
)

In [None]:
# Pickle the vocabulary: store the vectorization layer's config and weights in
# one file (presumably so the layer can be rebuilt later without re-adapting).
# The `with` block guarantees the file is flushed and closed even if pickling
# raises, instead of leaving an open handle alive in the kernel.
with open("20000Voc.pkl", "wb") as vocab_file:
    pickle.dump(
        {
            'config': vectorization.get_config(),
            'weights': vectorization.get_weights(),
        },
        vocab_file,
    )

In [None]:
# Stop training once the monitored quantity has not improved for 3 consecutive
# epochs, then roll the model back to its best weights. `val_loss` is the
# Keras default monitor; it is spelled out here to make the contrast with the
# checkpoint callback (which tracks `val_acc`) explicit.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Write weights-only checkpoints under 'checkpoints/', keeping only the epoch
# with the highest `val_acc`.
# NOTE(review): `val_acc` is assumed to be a metric reported by caption_model
# itself, since compile() below registers no metrics — confirm in train_utils.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/',
    monitor='val_acc',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

# reduction='none' keeps the loss un-reduced (per element).
# NOTE(review): presumably the model's own training step masks and averages
# it — verify against the caption model implementation.
loss = keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    reduction='none',
)
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

caption_model.compile(optimizer=optimizer, loss=loss)

In [None]:
# Train for up to EPOCHS epochs, validating after each one; early stopping and
# best-weights checkpointing are handled by the callbacks.
training_callbacks = [early_stopping, checkpoint_callback]

history = caption_model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS,
    callbacks=training_callbacks,
)

In [None]:
# Save the weights held by the model after training, tagged with the head
# count and vocabulary size of this run.
final_weights_path = f'checkpoints/{NUM_HEADS}_{VOCAB_SIZE}'
caption_model.save_weights(final_weights_path)