<a href="https://colab.research.google.com/github/namirinz/CompetitionCode-beta/blob/master/Image_Captioning/KME_Image_captioning_bahdanau.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install nami --upgrade
!pip install tqdm
!rm -rf .keras

Collecting nami
  Downloading https://files.pythonhosted.org/packages/9e/2c/2458895840e65939b355b2f34512d73ba3dadad3f746c53dec58470962f6/nami-1.2.1.12-py3-none-any.whl
Installing collected packages: nami
Successfully installed nami-1.2.1.12


In [2]:
!nvidia-smi

Mon Nov  9 13:50:10 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import nami
import time
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange

Downloading data from https://firebasestorage.googleapis.com/v0/b/ysc-kme-25095.appspot.com/o/model.h5?alt=media&token=b88ce51f-0ffe-406d-9875-d5593d10f619
Downloading data from https://firebasestorage.googleapis.com/v0/b/ysc-kme-25095.appspot.com/o/vocab.json?alt=media&token=c82c614a-07c7-4566-a47a-9c3626422833


In [4]:
(X_train, y_train), (X_test, y_test) = nami.datasets.kme.load_data()

Downloading data from https://firebasestorage.googleapis.com/v0/b/ysc-kme-25095.appspot.com/o/Dict_segment.json?alt=media&token=d89b3c42-a078-41b7-ab37-689c892d0957
Downloading data from https://firebasestorage.googleapis.com/v0/b/ysc-kme-25095.appspot.com/o/kme_dataset.npy?alt=media&token=6d8172ad-807d-4faa-b1b3-79470ce1b302


# Import Dataset

In [5]:
from nami.datasets.kme import load_data
(image_train, caption_train), (image_test, caption_test) = load_data()
print(image_train.shape, caption_train.shape)
print(image_test.shape, caption_test.shape)

(545, 224, 224, 3) (545,)
(137, 224, 224, 3) (137,)


## Preprocessing Dataset

In [6]:
from nami.AI.kme_tokenize import Tokenizer
tokenizer = Tokenizer()



In [7]:
tokenizer.fit_on_texts(caption_train)

In [8]:
text2seq_train = tokenizer.text_to_sequences(caption_train)
seq2text_train = tokenizer.sequences_to_text(text2seq_train)

In [9]:
text2seq_test = tokenizer.text_to_sequences(caption_test)
seq2text_test = tokenizer.sequences_to_text(text2seq_test)

In [10]:
text2seq_train.shape

(545, 35)

## Hyperparameter

In [11]:
BATCH_SIZE = 16
BUFFER_SIZE = 1000
units = 512
vocab_size = len(tokenizer.word2index)

num_steps_train = len(image_train) // BATCH_SIZE
num_steps_test = len(image_test) // BATCH_SIZE

max_length_train = text2seq_train.shape[1]
max_length_test = text2seq_test.shape[1]

encoding_size = 256 # Determines dimension of the encodings of images

## tf.data.Dataset ImageDataGenerator

In [12]:
img_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range = 30,
    width_shift_range = 0.05,
    height_shift_range = 0.05,
    horizontal_flip = True,
    vertical_flip = True,
)

In [13]:
with tf.device('/cpu:0'):
  dataset = tf.data.Dataset.from_generator(
    lambda : img_gen_train.flow(x = image_train, y = text2seq_train, batch_size = 16, shuffle = True),
    output_types =  (tf.float32, tf.float32),
  ).cache(      
  ).prefetch(
      tf.data.experimental.AUTOTUNE
  )

In [14]:
img_gen_val = tf.keras.preprocessing.image.ImageDataGenerator()

In [15]:
with tf.device('/cpu:0'):
  dataset_val = tf.data.Dataset.from_generator(
    lambda : img_gen_val.flow(x = image_test, y = text2seq_test, batch_size = 16, shuffle = False),
    output_types =  (tf.float32, tf.float32),
  ).cache(      
  ).prefetch(
    tf.data.experimental.AUTOTUNE
  )

In [16]:
#dataset = tf.data.Dataset.from_tensor_slices((image_train, text2seq_train)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#dataset_val = tf.data.Dataset.from_tensor_slices((image_test, text2seq_test)).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Model Architecture

## EfficientNetB7 

In [17]:
from tensorflow.keras.applications import efficientnet

EffNet = efficientnet.EfficientNetB7(include_top = False, weights=None, input_shape=(224, 224, 3))
# Output shape = [batch_size, 7, 7, 2560]

## CNN Encoder (Encoder output of feature extraction)

In [18]:
from tensorflow.keras.models import Sequential
encoder = Sequential([
    tf.keras.layers.Lambda(efficientnet.preprocess_input, input_shape=(224, 224, 3)),
    EffNet,
    tf.keras.layers.Dense(encoding_size,activation='relu',name="encoding_layer"),
    tf.keras.layers.Reshape(target_shape=(7*7, encoding_size))
], name= "CNN_feature_extraction")

In [19]:
encoder.summary()

Model: "CNN_feature_extraction"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lambda (Lambda)              (None, 224, 224, 3)       0         
_________________________________________________________________
efficientnetb7 (Functional)  (None, 7, 7, 2560)        64097687  
_________________________________________________________________
encoding_layer (Dense)       (None, 7, 7, 256)         655616    
_________________________________________________________________
reshape (Reshape)            (None, 49, 256)           0         
Total params: 64,753,303
Trainable params: 64,442,576
Non-trainable params: 310,727
_________________________________________________________________


## Bahdanau Attention (Class)

In [20]:
class BahdanauAttention(tf.keras.Model):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, features, hidden):
    # features(CNN_encoder output) shape == (batch_size, 25, embedding_dim)

    # hidden shape == (batch_size, hidden_size)
    # hidden_with_time_axis shape == (batch_size, 1, hidden_size)
    hidden_with_time_axis = tf.expand_dims(hidden, 1)

    # attention_hidden_layer shape == (batch_size, 25, units)
    attention_hidden_layer = (tf.nn.tanh(self.W1(features) +
                                         self.W2(hidden_with_time_axis)))

    # score shape == (batch_size, 25, 1)
    # This gives you an unnormalized score for each image feature.
    score = self.V(attention_hidden_layer)

    # attention_weights shape == (batch_size, 25, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = attention_weights * features
    context_vector = tf.reduce_sum(context_vector, axis=1)

    return context_vector, attention_weights

## RNN Decoder (Class)


In [21]:
class RNN_Decoder(tf.keras.Model):
  def __init__(self, encoding_size, units, vocab_size):
    super(RNN_Decoder, self).__init__()
    self.units = units

    self.embedding = tf.keras.layers.Embedding(vocab_size, encoding_size)
    self.gru = tf.keras.layers.GRU(self.units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc1 = tf.keras.layers.Dense(self.units)
    self.fc2 = tf.keras.layers.Dense(vocab_size)

    self.attention = BahdanauAttention(self.units)

  def __call__(self,features, x, hidden):
    # defining attention as a separate model
    context_vector, attention_weights = self.attention(features, hidden)

    # x shape after passing through embedding == (batch_size, 1, encoding_size)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, encoding_size + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x, initial_state=[hidden])

    # shape == (batch_size, max_length, hidden_size)
    x = self.fc1(output)

    # x shape == (batch_size * max_length, hidden_size)
    x = tf.reshape(x, (-1, x.shape[2]))

    # output shape == (batch_size * max_length, vocab)
    x = self.fc2(x)

    return x, state, attention_weights

  def reset_state(self, batch_size):
    return tf.zeros((batch_size, self.units))

In [22]:
decoder = RNN_Decoder(encoding_size = encoding_size, units = units, vocab_size = vocab_size)

# Sub Class Model

In [23]:
"""
class ImageBahda(tf.keras.Model):
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.hidden_state = None
    self.decoder_input = None
  
  def compile(self, word_loss_fn, eos_loss_fn, encoder_opt, decoder_opt, metrics):
    super().__init__()
    self.word_loss_fn = word_loss_fn
    self.eos_loss_fn = eos_loss_fn
    self.encoder_opt = encoder_opt
    self.decoder_opt = decoder_opt
    self.metrics = metrics

  def train_step(self, input):
    image, caption = input
    batch_size = tf.shape(input)[0]
    
    self.hidden_state = self.decoder.reset_state(batch_size = batch_size)
    self.decoder_input = tf.expand_dims([tokenizer.word2index['<start>']] * batch_size, 1)
    
    with tf.GradientTape() as dec_tape, tf.GradientTape() as enc_tape:
      img_feature = encoder(img_tensor)
        
      for i in range(1, max_length_train):
        # passing the features through the decoder
        predictions, hidden_state, _ = decoder(img_feature, dec_input, hidden_state)

        word_loss += word_loss_function(target[:,i], predictions)
        eos_loss += eos_loss_function(target[:, i], predictions)
        loss += word_loss + eos_loss

        # using teacher forcing
        dec_input = tf.expand_dims(target[:, i], 1)
      
      # Metrics
      #train_acc.update_state(target[:, i], predictions)
"""

"\nclass ImageBahda(tf.keras.Model):\n  def __init__(self, encoder, decoder):\n    super().__init__()\n    self.encoder = encoder\n    self.decoder = decoder\n    self.hidden_state = None\n    self.decoder_input = None\n  \n  def compile(self, word_loss_fn, eos_loss_fn, encoder_opt, decoder_opt, metrics):\n    super().__init__()\n    self.word_loss_fn = word_loss_fn\n    self.eos_loss_fn = eos_loss_fn\n    self.encoder_opt = encoder_opt\n    self.decoder_opt = decoder_opt\n    self.metrics = metrics\n\n  def train_step(self, input):\n    image, caption = input\n    batch_size = tf.shape(input)[0]\n    \n    self.hidden_state = self.decoder.reset_state(batch_size = batch_size)\n    self.decoder_input = tf.expand_dims([tokenizer.word2index['<start>']] * batch_size, 1)\n    \n    with tf.GradientTape() as dec_tape, tf.GradientTape() as enc_tape:\n      img_feature = encoder(img_tensor)\n        \n      for i in range(1, max_length_train):\n        # passing the features through the decode

In [24]:
#model =  ImageBahda(encoder = encoder, decoder = decoder)

# Custom Loss Function


In [25]:
def word_loss_function(real, pred):
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction = 'none')
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_mean(loss_)

In [26]:
def eos_loss_function(real, pred):
  loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

  end_token = tokenizer.word2index['<end>']
  pred_id = tf.math.argmax(pred, axis=-1)

  real_eos = tf.cast(tf.math.equal(real, end_token), dtype=tf.float32)
  pred_eos = tf.cast(tf.math.equal(pred_id, end_token), dtype=tf.float32)

  loss_ = loss_object(real_eos, pred_eos)
  return loss_

# Metrics

In [27]:
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
val_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

# Train Step

In [28]:
from tensorflow.keras.optimizers import Adam, Nadam
enc_opt = Adam(learning_rate=0.0004, beta_1=0.9, beta_2=0.999)
dec_opt = Adam(learning_rate=0.0002, beta_1=0.7, beta_2=0.799)

In [29]:
@tf.function
def train_step(img_tensor, target):
  batch_size = img_tensor.shape[0]
  
  hidden_state = decoder.reset_state(batch_size = batch_size)
  dec_input = tf.expand_dims([tokenizer.word2index['<start>']] * batch_size, 1)

  loss, total_loss = 0, 0
  word_loss, total_word_loss = 0, 0
  eos_loss, total_eos_loss = 0, 0

  with tf.GradientTape() as dec_tape, tf.GradientTape() as enc_tape:
    img_feature = encoder(img_tensor)
        
    for i in range(1, max_length_train):
      # passing the features through the decoder
      predictions, hidden_state, _ = decoder(img_feature, dec_input, hidden_state)

      word_loss += word_loss_function(target[:,i], predictions)
      eos_loss += eos_loss_function(target[:, i], predictions)
      loss += (word_loss + eos_loss)

      # using teacher forcing
      dec_input = tf.expand_dims(target[:, i], 1)
      train_acc.update_state(target[:, i], predictions)

  total_loss = (loss / max_length_train)
  total_eos_loss = (eos_loss / max_length_train)
  total_word_loss = (loss / max_length_train)

  decoder_trainable_variables = decoder.trainable_variables
  encoder_trainable_variables = encoder.trainable_variables

  dec_gradients = dec_tape.gradient(loss, decoder_trainable_variables)
  enc_gradients = enc_tape.gradient(loss, encoder_trainable_variables)
    
  dec_opt.apply_gradients(zip(dec_gradients, decoder_trainable_variables))
  enc_opt.apply_gradients(zip(enc_gradients, encoder_trainable_variables))

  return total_word_loss, total_eos_loss, total_loss

# Validate Step

In [30]:
@tf.function
def valid_step(img_tensor, target):
  batch_size = target.shape[0]
  dec_input = tf.expand_dims([tokenizer.word2index['<start>']] * batch_size, 1)
  hidden_state = decoder.reset_state( batch_size = batch_size )

  with tf.GradientTape() as tape:  
    img_feature = encoder(img_tensor)
    
    eos_loss, word_loss = 0, 0

    for i in range(1, max_length_test):
      predictions, hidden_state, _ = decoder(img_feature, dec_input, hidden_state)
        
      word_loss += word_loss_function(target[:,i], predictions)
      
      eos_loss += eos_loss_function(target[:,i], predictions) 
        
      dec_input = tf.expand_dims(target[:, i], 1)

      val_acc.update_state(target[:, i], predictions)

  word_loss /= max_length_test
  eos_loss /= max_length_test
  return word_loss, eos_loss, word_loss + eos_loss

# Training

In [31]:
def add_plot_arr(loss1, loss2, loss3, arr1, arr2, arr3, num_step):
  arr1.append(loss1 / num_step)
  arr2.append(loss2 / num_step)
  arr3.append(loss3 / num_step)
  return arr1, arr2, arr3

In [32]:
def fit(epochs):
  loss1, loss2, loss3 = [], [], []
  val_loss1, val_loss2, val_loss3 = [], [], []
  min_loss = 9999
  for epoch in trange(1, epochs+1):
      train_total_loss = 0
      train_total_word_loss, val_total_word_loss = 0, 0
      train_total_eos_loss, val_total_eos_loss = 0, 0

      val_total_loss, val_loss = 0, 0
      
      train_acc.reset_states()
      val_acc.reset_states()
      # Training Step
      for img_tensor, target in tqdm(dataset.take(num_steps_train), leave = False):
        
        word_loss, eos_loss, loss = train_step(img_tensor, target)
        train_total_loss += loss
        train_total_word_loss += word_loss 
        train_total_eos_loss += eos_loss
      
      # Validation Step
      for img_tensor ,target in tqdm(dataset_val.take(num_steps_test), leave = False):
        val_word_loss, val_eos_loss, val_loss = valid_step(img_tensor ,target)
        val_total_word_loss += val_word_loss
        val_total_eos_loss += val_eos_loss
        val_total_loss += val_loss

      if val_total_loss < min_loss:
        print("SAVE WEIGHT")
        min_loss = val_total_loss
        encoder.save_weights('encoder_weight.h5')
        decoder.save_weights('decoder_weight.h5')

      # storing the epoch end loss value to plot later
      loss1, loss2, loss3 = add_plot_arr(
          train_total_loss, train_total_word_loss, train_total_eos_loss,
          loss1, loss2, loss3, num_steps_train)

      val_loss1, val_loss2, val_loss3 = add_plot_arr(
          val_total_loss, val_total_word_loss, val_total_eos_loss,
          val_loss1, val_loss2, val_loss3, num_steps_test)

      print(f"EPOCH : {epoch} word_loss : {(train_total_word_loss / num_steps_train):.4f} "
          f"eos_loss : {(train_total_eos_loss / num_steps_train):.4f} total_loss : {(train_total_loss / num_steps_train):.4f} "
          f"acc : {train_acc.result().numpy():.4f}\n"
          f"\t val_word_loss : {(val_total_word_loss / num_steps_test):.4f} val_eos_loss : {(val_total_eos_loss / num_steps_test):.4f} "
          f"val_total_loss : {(val_total_loss / num_steps_test):.4f} "
          f"val_acc : {val_acc.result().numpy():.4f}"
          )
      print('-'*30, end='\n\n')
  return {'loss' : loss1, 'word_loss' : loss2, 'eos_loss' : loss3, 'acc': round(train_acc.result().numpy(), 4),
          'val_loss' : val_loss1, 'val_word_loss': val_loss2, 'val_eos_loss' : val_loss3, 'val_acc' : round(val_acc.result().numpy(), 4)
          }

In [33]:
history = fit(epochs = 40)

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 1 word_loss : 45.1092 eos_loss : 0.7729 total_loss : 45.1092 acc : 0.0617
	 val_word_loss : 1.2175 val_eos_loss : 1.0957 val_total_loss : 2.3132 val_acc : 0.0799
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 2 word_loss : 43.7453 eos_loss : 0.9091 total_loss : 43.7453 acc : 0.0667
	 val_word_loss : 1.1833 val_eos_loss : 0.6715 val_total_loss : 1.8547 val_acc : 0.0814
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 3 word_loss : 41.3329 eos_loss : 0.8242 total_loss : 41.3329 acc : 0.0711
	 val_word_loss : 1.1212 val_eos_loss : 0.7144 val_total_loss : 1.8356 val_acc : 0.1006
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 4 word_loss : 40.0342 eos_loss : 0.9047 total_loss : 40.0342 acc : 0.0938
	 val_word_loss : 1.0229 val_eos_loss : 1.0656 val_total_loss : 2.0885 val_acc : 0.1341
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 5 word_loss : 39.4003 eos_loss : 1.0711 total_loss : 39.4003 acc : 0.1201
	 val_word_loss : 0.9206 val_eos_loss : 1.0488 val_total_loss : 1.9694 val_acc : 0.1447
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 6 word_loss : 37.1974 eos_loss : 1.0722 total_loss : 37.1974 acc : 0.1344
	 val_word_loss : 0.8353 val_eos_loss : 1.0430 val_total_loss : 1.8783 val_acc : 0.1608
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 7 word_loss : 35.3814 eos_loss : 1.0702 total_loss : 35.3814 acc : 0.1412
	 val_word_loss : 0.7841 val_eos_loss : 1.0382 val_total_loss : 1.8223 val_acc : 0.1666
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 8 word_loss : 33.7336 eos_loss : 1.0223 total_loss : 33.7336 acc : 0.1494
	 val_word_loss : 0.7602 val_eos_loss : 1.0448 val_total_loss : 1.8050 val_acc : 0.1656
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 9 word_loss : 33.3225 eos_loss : 1.0614 total_loss : 33.3225 acc : 0.1545
	 val_word_loss : 0.7286 val_eos_loss : 1.0391 val_total_loss : 1.7677 val_acc : 0.1752
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 10 word_loss : 31.9261 eos_loss : 1.0106 total_loss : 31.9261 acc : 0.1596
	 val_word_loss : 0.6967 val_eos_loss : 1.0466 val_total_loss : 1.7433 val_acc : 0.1779
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 11 word_loss : 29.7634 eos_loss : 0.8816 total_loss : 29.7634 acc : 0.1635
	 val_word_loss : 0.6852 val_eos_loss : 1.0387 val_total_loss : 1.7239 val_acc : 0.1812
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 12 word_loss : 31.2716 eos_loss : 1.0517 total_loss : 31.2716 acc : 0.1659
	 val_word_loss : 0.6612 val_eos_loss : 0.9744 val_total_loss : 1.6356 val_acc : 0.1872
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 13 word_loss : 30.5332 eos_loss : 1.0299 total_loss : 30.5332 acc : 0.1682
	 val_word_loss : 0.6463 val_eos_loss : 0.9958 val_total_loss : 1.6421 val_acc : 0.1885
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 14 word_loss : 30.4131 eos_loss : 1.0578 total_loss : 30.4131 acc : 0.1712
	 val_word_loss : 0.6392 val_eos_loss : 0.9964 val_total_loss : 1.6356 val_acc : 0.1898
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 15 word_loss : 30.2485 eos_loss : 1.0666 total_loss : 30.2485 acc : 0.1728
	 val_word_loss : 0.6205 val_eos_loss : 1.0355 val_total_loss : 1.6560 val_acc : 0.1953
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 16 word_loss : 29.7132 eos_loss : 1.0530 total_loss : 29.7132 acc : 0.1745
	 val_word_loss : 0.6175 val_eos_loss : 1.0397 val_total_loss : 1.6572 val_acc : 0.1898
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 17 word_loss : 29.6036 eos_loss : 1.0652 total_loss : 29.6036 acc : 0.1767
	 val_word_loss : 0.6059 val_eos_loss : 0.9749 val_total_loss : 1.5807 val_acc : 0.1918
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

SAVE WEIGHT
EPOCH : 18 word_loss : 29.1774 eos_loss : 1.0433 total_loss : 29.1774 acc : 0.1787
	 val_word_loss : 0.6020 val_eos_loss : 0.6790 val_total_loss : 1.2810 val_acc : 0.1920
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 19 word_loss : 28.5635 eos_loss : 1.0165 total_loss : 28.5635 acc : 0.1799
	 val_word_loss : 0.5985 val_eos_loss : 1.0400 val_total_loss : 1.6385 val_acc : 0.1976
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 20 word_loss : 28.5262 eos_loss : 1.0209 total_loss : 28.5262 acc : 0.1820
	 val_word_loss : 0.5917 val_eos_loss : 0.7527 val_total_loss : 1.3444 val_acc : 0.1930
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 21 word_loss : 27.2421 eos_loss : 0.9294 total_loss : 27.2421 acc : 0.1838
	 val_word_loss : 0.5872 val_eos_loss : 1.0417 val_total_loss : 1.6289 val_acc : 0.1983
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 22 word_loss : 28.2810 eos_loss : 1.0342 total_loss : 28.2810 acc : 0.1852
	 val_word_loss : 0.5820 val_eos_loss : 1.0219 val_total_loss : 1.6040 val_acc : 0.1993
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 23 word_loss : 27.7792 eos_loss : 1.0039 total_loss : 27.7792 acc : 0.1869
	 val_word_loss : 0.5794 val_eos_loss : 1.0222 val_total_loss : 1.6016 val_acc : 0.2019
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 24 word_loss : 26.5664 eos_loss : 0.9059 total_loss : 26.5664 acc : 0.1869
	 val_word_loss : 0.5769 val_eos_loss : 1.0113 val_total_loss : 1.5882 val_acc : 0.1996
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 25 word_loss : 27.0733 eos_loss : 0.9618 total_loss : 27.0733 acc : 0.1861
	 val_word_loss : 0.5788 val_eos_loss : 0.7078 val_total_loss : 1.2866 val_acc : 0.1968
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 26 word_loss : 24.8400 eos_loss : 0.7642 total_loss : 24.8400 acc : 0.1886
	 val_word_loss : 0.5807 val_eos_loss : 0.7101 val_total_loss : 1.2908 val_acc : 0.1953
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

EPOCH : 27 word_loss : 24.4743 eos_loss : 0.7415 total_loss : 24.4743 acc : 0.1898
	 val_word_loss : 0.5757 val_eos_loss : 0.8119 val_total_loss : 1.3876 val_acc : 0.2014
------------------------------



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

KeyboardInterrupt: ignored

In [None]:
#plt.plot(history['loss'], label='train_total_loss')
plt.plot(history['val_loss'], marker='v', label='val_total_loss')
plt.legend()
plt.show()

# Caption It

In [None]:
decoder.load_weights('decoder_weight.hdf5')
encoder.load_weights('encoder_weight.hdf5')

In [None]:
attention_features_shape = 49
features_shape = 2560
max_test_length = max
def evaluate(image, max_length):
    attention_plot = np.zeros((max_length, attention_features_shape))

    hidden_state = decoder.reset_state(batch_size = 1)

    image = tf.reshape(image, shape=(1,image.shape[0],image.shape[1],image.shape[2]))
    features = encoder(image)

    dec_input = tf.expand_dims([tokenizer.word2index['<start>']], 0)
    result = []

    for i in range(max_length): # *

        predictions, hidden_state, attention_weights = decoder(features, dec_input, hidden_state)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        predicted_id = np.argmax(predictions, axis=-1)[0]

        result.append(tokenizer.index2word[predicted_id])

        if tokenizer.index2word[predicted_id] == '<end>':
            return result, attention_plot

        dec_input = tf.expand_dims([predicted_id], 0)

    attention_plot = attention_plot[:len(result), :]
    return result, attention_plot

In [None]:
def plot_attention(image, result, attention_plot):

    fig = plt.figure(figsize=(10, 10))
    fig.suptitle(''.join(result[:-1]), fontsize=22, y = 1.03)
    len_result = len(result)
    for l in range(len_result):
        temp_att = np.resize(attention_plot[l], (8, 8))
        ax = fig.add_subplot(len_result//2, len_result//2, l+1)
        ax.set_title(result[l])
        img = ax.imshow(image)
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())
    
    plt.tight_layout()
    plt.show()

In [None]:
# captions on the validation set
result, attention_plot = evaluate(image_test[123], max_length_test)

print ('Real Caption:', caption_test[123])
print ('Prediction Caption:', ''.join(result[:-1]))

plot_attention(image_test[123], result, attention_plot)

In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [None]:
from nami.AI.kme_tokenize import Segmentation
kme_segment = Segmentation()

In [None]:
best_score = {}
chencherry = SmoothingFunction()

for i in trange(len(image_test)):
  word_arr, _ = evaluate(image_test[i], max_length_test)
  real_word,_ = kme_segment.word_segmentation(caption_test[i])

  score = sentence_bleu([real_word], word_arr[:-1], smoothing_function = chencherry.method4)
  best_score[i] = score
  #print(f"{i} | {score:.4f}", end=' ')
  #print(''.join(word_arr[:-1]), end=' : ')
  #print(caption_test[i])

In [None]:
sort_best_score = sorted(best_score.items(), key=lambda x: x[1], reverse=True)

# BLEU SCORE

In [None]:
for i in range(100):
  print(sort_best_score[i][0], sort_best_score[i][1])

In [None]:
print(caption_test[0])
plt.imshow(image_test[0])

In [None]:
# captions on the validation set
result, attention_plot = evaluate(image_test[19], max_length_test)

print ('Real Caption:', caption_test[19])
print ('Prediction Caption:', ''.join(result[:-1]))
plt.imshow(image_test[19])
plt.show()
plot_attention(image_test[19], result, attention_plot)


In [None]:
# captions on the validation set
result, attention_plot = evaluate(image_test[25], max_length_test)
2
print ('Real Caption:', caption_test[25])
print ('Prediction Caption:', ''.join(result[:-1]))
plt.imshow(image_test[25])
plt.show()
plot_attention(image_test[25], result, attention_plot)


In [None]:
# captions on the validation set
result, attention_plot = evaluate(image_train[98], max_length_train)

print ('Real Caption:', caption_train[98])
print ('Prediction Caption:', ''.join(result[:-1]))
plt.imshow(image_train[98])
plt.show()
plot_attention(image_train[98], result, attention_plot)


In [None]:
plt.imshow(image_train[0,:,:,0], cmap='gray')

In [None]:
# captions on the validation set
result, attention_plot = evaluate(image_test[100], max_length_test)

print ('Real Caption:', caption_test[100])
print ('Prediction Caption:', ''.join(result[:-1]))
plt.imshow(image_test[100])
plt.show()
plot_attention(image_test[100], result, attention_plot)
