In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import json
import matplotlib.pyplot as plt
from utils import *
# Report the installed PyTorch version and the active compute device.
# NOTE(review): `device` is not defined in this cell; presumably it is exported
# by the star-import from `utils` - confirm.
print(f"torch.version= {torch.__version__!s}")
print(f"device= {device!s}")

In [None]:
from checkpoints import  models, word_maps

In [None]:
def caption_image_beam_search(encoder, decoder, image_path, word_map, beam_width=5):
    """
    Reads an image and captions it with beam search.

    The image is loaded with ``read_image_and_resize``, passed through the
    encoder, flattened over its spatial dimensions and handed to ``decode_one``
    together with the start/end token ids looked up in ``word_map``.

    :param encoder: encoder model (called as ``encoder(image)``)
    :param decoder: decoder model (passed through to ``decode_one``)
    :param image_path: path to image
    :param word_map: word map (word -> index); must contain '<start>' and '<end>'
    :param beam_width: number of sequences to consider at each decode-step
    :return: caption, weights for visualization (exactly what ``decode_one`` returns)
    :raises KeyError: if '<start>' or '<end>' is missing from ``word_map``
    """
    word_map_start = word_map['<start>']
    word_map_end = word_map['<end>']

    # Read image and process
    # NOTE(review): assumes read_image_and_resize returns a tensor already on the
    # same device as the encoder - confirm in utils.
    image = read_image_and_resize(image_path)

    # Captioning is inference only, so skip autograd bookkeeping for the whole
    # encode/decode pass (saves memory; the returned tensors carry no graph).
    with torch.no_grad():
        # Encode
        image = image.unsqueeze(0)  # add batch dim; original notes: (1, 3, 256, 256)
        encoder_out = encoder(image)  # original notes: (1, enc_image_size, enc_image_size, encoder_dim)
        enc_image_size = encoder_out.size(1)
        encoder_dim = encoder_out.size(3)

        # Flatten encoding
        encoder_out = encoder_out.view(1, -1, encoder_dim)  # (1, num_pixels, encoder_dim)

        # Decode
        seq, alphas = decode_one(decoder, encoder_out, encoder_dim, enc_image_size,
                                 word_map_start, word_map_end, beam_width)
    return seq, alphas

In [None]:
def run(image_path=None, beam_widths=(5,)):
    """
    Caption one image with every configured checkpoint and show the attention maps.

    For each (checkpoint, word-map file) pair from ``models`` / ``word_maps`` the
    encoder and decoder are loaded, put in eval mode, and used to caption the
    image once per beam width. The decoded words are printed and the attention
    weights are drawn with ``visualize_att``.

    :param image_path: path to the image to caption. When omitted, the
        notebook-level global ``img`` is used, which keeps the original
        zero-argument ``run()`` call working.
    :param beam_widths: iterable of beam widths to try; defaults to ``(5,)``
    :raises NameError: if ``image_path`` is omitted and no global ``img`` exists
    """
    if image_path is None:
        # Backward compatibility: the driver cell assigns the global `img`
        # before calling run() without arguments.
        image_path = img

    for checkpoint_path, word_map_path in zip(models, word_maps):
        print("Model: ",checkpoint_path)

        # Load models
        # SECURITY NOTE: torch.load unpickles arbitrary Python objects (here whole
        # model instances), so only load checkpoints from a trusted source.
        # NOTE(review): recent PyTorch releases changed the `weights_only`
        # default of torch.load; full-model checkpoints may need it set
        # explicitly there - confirm against the installed version.
        checkpoint = torch.load(checkpoint_path, map_location=str(device))
        decoder = checkpoint['decoder']
        decoder = decoder.to(device)
        decoder.eval()
        encoder = checkpoint['encoder']
        encoder = encoder.to(device)
        encoder.eval()

        # Load word map (word2ix); kept separate from the path variable so the
        # name never refers to two different kinds of object.
        with open(word_map_path, 'r') as j:
            word_map = json.load(j)
        rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

        # Encode, decode with attention and beam search
        for beam_width in beam_widths:
            print("Beam Size = ",beam_width)
            seq, alphas = caption_image_beam_search(encoder, decoder, image_path, word_map, beam_width)
            alphas = torch.FloatTensor(alphas.to("cpu"))

            decoded_seq = [rev_word_map[item] for item in seq]
            print(decoded_seq)

            # Visualize caption and attention of best sequence (once per beam width)
            visualize_att(image_path, seq, alphas, rev_word_map, smooth=False)

In [None]:
import os
import random
from IPython.display import Image as DisplayImage
# Pick one random JPEG from the image directory, show it, and caption it.
# Alternative image locations used during development:
#directory = '../data/flickr8k_images/'
directory = '../data/images/test/'
#directory = '../testimage/'

# Filter to JPEGs *before* sampling: sampling first and skipping non-JPEG picks
# would silently do nothing whenever the random pick was some other file.
# Sorting makes the pick reproducible under a fixed random seed.
jpg_files = sorted(name for name in os.listdir(directory) if name.endswith("jpg"))
if not jpg_files:
    print("No jpg images found in", directory)

# min(...) keeps random.sample from raising ValueError on an empty list.
for filename in random.sample(jpg_files, min(1, len(jpg_files))):
    print(filename)
    # `img` is deliberately a module-level name: run() reads it as the image path.
    img = os.path.join(directory, filename)
    display(DisplayImage(img, width=150, height=150))
    run()