Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore model from config and weights for model training #9

Merged
merged 8 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ target/
# exclude data from source control by default
/data/

# exclude model configs and weights
/models/config
/models/weights

# Mac OS-specific storage files
.DS_Store

Expand Down
Empty file added models/config/.gitkeep
Empty file.
Empty file added models/weights/.gitkeep
Empty file.
6 changes: 4 additions & 2 deletions src/features/nlp/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@ def text_to_sequence(self, text):
def sequence_to_text(self, sequence):
inv_vocab = dict((v, k) for k, v in self.vocab.items())
return " ".join(inv_vocab[x] for x in sequence if x >= 1)

def clean_text(self, text):
special_tokens = ("<start>", "<end>", "<oov>")
return " ".join(token for token in text.split(" ") if token not in special_tokens)
return " ".join(
token for token in text.split(" ") if token not in special_tokens
)

def save_to_json(self):
"""Save the current tokenizer (with vocab) as JSON
Expand Down
20 changes: 13 additions & 7 deletions src/models/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,30 @@
default="",
help="Filename (without extension) of the model config and weights to load",
)
@click.option("--mode", default="test", help="Evaluation mode: 'val' or 'test', default is 'test'")
@click.option(
"--mode", default="test", help="Evaluation mode: 'val' or 'test', default is 'test'"
)
def main(model_filename, mode):
model = build_saved_model(model_filename)
model = build_saved_model(model_filename, mode="inference")
tokenizer = CustomSpacyTokenizer.from_json()
dataset = read_split_dataset(mode, model.img_shape, model.caption_length, batch_size=5)
dataset = read_split_dataset(
mode, model.img_shape, model.caption_length, batch_size=5
)

print(f"MODEL {model_filename}")
for beam_width in 1, 3, 5:
print(f"BEAM WIDTH: {beam_width}")
total_bleu = [0., 0.]
total_bleu = [0.0, 0.0]
n = 0

for (image, captions), _ in tqdm(dataset):
n += 1
prediction = predict(model, image[0], tokenizer, beam_width)
clean_prediction = tokenizer.clean_text(prediction)

captions = [tokenizer.sequence_to_text(captions[i].numpy()) for i in range(5)]

captions = [
tokenizer.sequence_to_text(captions[i].numpy()) for i in range(5)
]
clean_captions = [tokenizer.clean_text(caption) for caption in captions]

example_bleu = bleu(clean_captions, clean_prediction)
Expand All @@ -45,4 +51,4 @@ def main(model_filename, mode):


if __name__ == "__main__":
main()
main()
9 changes: 4 additions & 5 deletions src/models/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ def bleu(captions_true, caption_pred):
n: int
N-grams to consider (correspond to computing BLEU-n)
"""
weights = [
(1, 0, 0, 0),
(1./4., 1./4., 1./4., 1./4.)
]
weights = [(1, 0, 0, 0), (1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0)]
captions_true = [caption.split(" ") for caption in captions_true]
caption_pred = caption_pred.split(" ")

return nltk.translate.bleu_score.sentence_bleu(captions_true, caption_pred, weights=weights)
return nltk.translate.bleu_score.sentence_bleu(
captions_true, caption_pred, weights=weights
)
6 changes: 3 additions & 3 deletions src/models/predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main(model_filename):
]
images = [load_image_jpeg(filename) for filename in image_filenames]

model = build_saved_model(model_filename)
model = build_saved_model(model_filename, mode="inference")

predictions = []
for image in images:
Expand All @@ -86,7 +86,7 @@ def main(model_filename):

# visualize results
# Create a grid of subplots
fig, axs = plt.subplots(math.ceil(len(images)/3), 3, figsize=(5, 5))
fig, axs = plt.subplots(math.ceil(len(images) / 3), 3, figsize=(5, 5))

if len(images) == 1:
image, caption = predictions[0]
Expand All @@ -95,7 +95,7 @@ def main(model_filename):
else:
i = 0
j = 0
for (image, caption) in predictions:
for image, caption in predictions:
# Plot each image in a separate subplot with a caption
axs[i][j].imshow(image)
axs[i][j].set_title(caption, fontsize="8")
Expand Down
65 changes: 53 additions & 12 deletions src/models/train_model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import os
import json
import time
import click
from pathlib import Path


import logging
from dotenv import find_dotenv, load_dotenv

import tensorflow as tf
from tensorflow import keras

import src.models.metrics as metrics
from src.features.nlp.tokenizer import CustomSpacyTokenizer
from src.models.model import ShowAndTell
from src.models.read_data import read_split_dataset
from src.models.utils import build_saved_model


@click.command()
Expand All @@ -27,10 +33,23 @@
@click.option(
"--learning_rate", default=0.0001, help="Learning rate for the Adam optimizer"
)
def main(img, n_rnn_neurons, embedding_size, batch_size, epochs, learning_rate):
@click.option("--model_filename", default=None, help="File name of the saved model")
def main(
img,
n_rnn_neurons,
embedding_size,
batch_size,
epochs,
learning_rate,
model_filename,
):
tf.random.set_seed(42)
logger = logging.getLogger(__name__)

project_dir = Path(__file__).resolve().parents[2]
models_dir = project_dir / "models"
config_dir = models_dir / "config"
weights_dir = models_dir / "weights"

# train config
img_shape = (img, img, 3)
Expand All @@ -42,11 +61,29 @@ def main(img, n_rnn_neurons, embedding_size, batch_size, epochs, learning_rate):
train_dataset = read_split_dataset("train", img_shape, caption_length, batch_size)
val_dataset = read_split_dataset("val", img_shape, caption_length, batch_size)

# create model - show and tell paper configuration
model = ShowAndTell(
n_rnn_neurons, img_shape, caption_length, embedding_size, vocab_size
)
model.build(input_shape=[(None,) + img_shape, (None, caption_length)])
if model_filename is not None:
# load config from saved json file and load model from saved h5 file
try:
logger.info(
f"Model config loaded from: {os.path.abspath(config_dir)}{os.sep}{model_filename}.json"
)
logger.info(
f"Model weights loaded from: {os.path.abspath(weights_dir)}{os.sep}{model_filename}.h5"
)
model = build_saved_model(model_filename, mode="training")
logger.info(
f"Model loaded from: {os.path.abspath(weights_dir)}{os.sep}{model_filename}.h5"
)
except OSError as e:
logger.error(e)
raise
else:
# create model - show and tell paper configuration
model = ShowAndTell(
n_rnn_neurons, img_shape, caption_length, embedding_size, vocab_size
)
model.build(input_shape=[(None,) + img_shape, (None, caption_length)])

model.summary()

# train model
Expand All @@ -59,10 +96,6 @@ def main(img, n_rnn_neurons, embedding_size, batch_size, epochs, learning_rate):

# save weights and config
timestr = time.strftime("%Y%m%d-%H%M%S")
weights_dir = models_dir / "weights"
os.makedirs(weights_dir, exist_ok=True)
config_dir = models_dir / "config"
os.makedirs(config_dir, exist_ok=True)

model_config_filename = f"{os.path.abspath(config_dir)}{os.sep}train_{timestr}.json"
weights_filename = f"{os.path.abspath(weights_dir)}{os.sep}train_{timestr}.h5"
Expand All @@ -71,10 +104,18 @@ def main(img, n_rnn_neurons, embedding_size, batch_size, epochs, learning_rate):
with open(model_config_filename, "w") as file:
file.write(model_json_str)
model.save_weights(weights_filename)
print(
f"saved model config and weights in {model_config_filename} and {weights_filename}"

logger.info(
f"Saved model config and weights in {model_config_filename} and {weights_filename}"
)


if __name__ == "__main__":
log_fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=log_fmt)

# find .env automagically by walking up directories until it's found, then
# load up the .env entries as environment variables
load_dotenv(find_dotenv())

main()
34 changes: 22 additions & 12 deletions src/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,35 @@ def load_image_jpeg(filename):
return image


def build_saved_model(model_filename):
def build_saved_model(model_filename, mode="training"):
project_dir = Path(__file__).resolve().parents[2]
model_config_dir = project_dir / "models" / "config"
weights_dir = project_dir / "models" / "weights"
if model_filename == "":

if not model_filename:
last_name = os.listdir(model_config_dir)[-1]
model_config_path = model_config_dir / last_name
weights_path = weights_dir / f"{last_name.removesuffix('.json')}.h5"
else:
model_config_path = model_config_dir / f"{model_filename}.json"
weights_path = weights_dir / f"{model_filename}.h5"

with model_config_path.open("r") as model_config_file:
model = keras.models.model_from_json(
model_config_file.read(), custom_objects={"ShowAndTell": ShowAndTell}
)
# build the model with the input shapes
model.build(input_shape=[[1] + model.img_shape, [1, model.caption_length]])
# load corresponding weights
model.load_weights(weights_path, by_name=True)
model.summary()
return model
try:
with model_config_path.open("r") as model_config_file:
model = keras.models.model_from_json(
model_config_file.read(), custom_objects={"ShowAndTell": ShowAndTell}
)
# build the model with the input shapes
batch_size = None if mode == "training" else 1
model.build(
input_shape=[
[batch_size] + model.img_shape,
[batch_size, model.caption_length],
]
)
# load corresponding weights
model.load_weights(weights_path, by_name=True)
model.summary()
return model
except OSError as e:
raise e