In [4]:
import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from torch.autograd import Variable 
from torchvision import transforms 
from data_loader import build_vocab 
from model import EncoderCNN, DecoderRNN
from model import ResNet, ResidualBlock
from PIL import Image


embed_size = 256
hidden_size = 512 
num_layers = 1 
decoder_path = './models/3object/decoder-10-150.pkl'
encoder_path = './models/3object/encoder-10-150.pkl'


def to_var(x, volatile=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, volatile=volatile)

def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([64, 64], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

transform = transforms.Compose([ 
    transforms.ToTensor(), 
    transforms.Normalize((0.033, 0.032, 0.033), 
                         (0.027, 0.027, 0.027))])



vocab_path = './data/vocab_3object.pkl'
with open(vocab_path, 'rb') as f:
    vocab = pickle.load(f)
len_vocab = vocab.idx

# Build Models
encoder = ResNet(ResidualBlock, [3, 3, 3], len_vocab)
encoder.eval()  # evaluation mode (BN uses moving mean/variance)
decoder = DecoderRNN(len_vocab, hidden_size, 
                     len(vocab), num_layers)


# Load the trained model parameters
encoder.load_state_dict(torch.load(encoder_path))
decoder.load_state_dict(torch.load(decoder_path))

# If use gpu
if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()

    

In [5]:
test_dir = 'data/3object_test/bitmap/'
cap_dir = 'data/3object_test/caption/'
test_list = os.listdir(test_dir)
cnt = 0 

shape_match = 0
location_match = 0
radius_match = 0 
rect_x_match = 0 
rect_y_match = 0 
length_match = 0
for fname in test_list: 
    #cnt+=1 
    if cnt>2:
        break;
    test_path = test_dir + fname
    test_image = load_image(test_path, transform)
    image_tensor = to_var(test_image)
    
    # Generate caption from image
    feature = encoder(image_tensor)
    sampled_ids = decoder.sample(feature)
    ids_arr = []
    for element in sampled_ids: 
        temp = element.cpu().data.numpy()
        ids_arr.append(int(temp))
    print(fname)

    # Decode word_ids to words
    sampled_caption = []
    for word_id in ids_arr:
        word = vocab.idx2word[word_id]
        sampled_caption.append(word)
        if word == '<end>':
            break
    in_caption = sampled_caption[1:-1]
    print(in_caption)
    
    #read target caption
    cap_target = cap_dir+fname
    cap_path = cap_target.split('.')[0]+'.svg'
    trg_caption = [] 
    with open(cap_path, 'r') as f:
        target = f.readline().split()
        for word in target:
            trg_caption.append(word)
    print(trg_caption)
    print('-------------------------------------------------------------------')
    
    
    if(len(in_caption) == len(trg_caption)):        
        length_match += 1
        
print(length_match) 


9826.png
['circle', '3', '10', '60', 'lime', 'circle', '3', '10', '60', 'cyan', 'rect', '2', '3', '60', '60', 'cyan']
['circle', '1', '10', '90', 'orange', 'circle', '4', '5', '40', 'red']
-------------------------------------------------------------------
9875.png
['rect', '5', '5', '60', '60', 'yellow', 'rect', '9', '6', '60', '60', 'purple']
['rect', '9', '6', '40', '40', 'purple', 'rect', '5', '6', '80', '80', 'cyan']
-------------------------------------------------------------------
9807.png
['rect', '8', '1', '80', '80', 'pink', 'rect', '7', '1', '80', '80', 'purple']
['rect', '6', '2', '90', '90', 'blue', 'rect', '8', '1', '80', '80', 'green']
-------------------------------------------------------------------
9921.png
['circle', '10', '10', '60', 'lime', 'circle', '9', '3', '100', 'cyan', 'rect', '8', '10', '60', '60', 'cyan']
['circle', '8', '10', '90', 'deep_pink', 'circle', '10', '4', '60', 'green']
-------------------------------------------------------------------
9867.pn

9799.png
['circle', '4', '2', '100', 'lime', 'circle', '2', '2', '100', 'lime']
['circle', '3', '2', '110', 'orange', 'circle', '1', '2', '60', 'green']
-------------------------------------------------------------------
9891.png
['circle', '2', '6', '80', 'lime', 'circle', '3', '10', '80', 'cyan', 'rect', '2', '7', '80', '80', 'lime']
['circle', '2', '7', '50', 'skyblue', 'circle', '3', '7', '110', 'red']
-------------------------------------------------------------------
9989.png
['circle', '7', '1', '100', 'lime', 'circle', '7', '7', '70', 'pink', 'rect', '8', '3', '100', '100', 'lime']
['circle', '8', '2', '80', 'green', 'circle', '6', '7', '80', 'red']
-------------------------------------------------------------------
9998.png
['circle', '4', '3', '100', 'cyan', 'circle', '5', '8', '70', 'pink', 'rect', '2', '3', '100', '100', 'lime']
['circle', '4', '2', '70', 'lime', 'circle', '3', '5', '110', 'lime', 'rect', '4', '9', '70', '70', 'orange']
-------------------------------------

9791.png
['circle', '6', '5', '100', 'lime', 'circle', '1', '5', '80', 'pink', 'rect', '1', '4', '100', '100', 'lime']
['circle', '1', '7', '70', 'pink', 'circle', '6', '4', '110', 'spring_green']
-------------------------------------------------------------------
9833.png
['circle', '2', '9', '60', 'lime', 'circle', '2', '8', '100', 'lime', 'rect', '8', '3', '60', '60', 'cyan']
['circle', '2', '10', '90', 'yellow', 'circle', '7', '8', '80', 'purple', 'rect', '1', '4', '90', '90', 'yellow']
-------------------------------------------------------------------
9957.png
['circle', '7', '5', '90', 'lime', 'circle', '1', '5', '90', 'lime']
['circle', '1', '6', '50', 'green', 'circle', '7', '5', '110', 'purple']
-------------------------------------------------------------------
9715.png
['circle', '4', '4', '80', 'lime', 'circle', '1', '3', '80', 'cyan', 'rect', '7', '4', '80', '80', 'lime']
['circle', '2', '5', '100', 'lime', 'circle', '6', '2', '50', 'deep_pink']
--------------------------

9857.png
['circle', '6', '1', '80', 'lime', 'circle', '7', '1', '80', 'cyan']
['circle', '7', '1', '110', 'skyblue', 'circle', '6', '2', '60', 'red']
-------------------------------------------------------------------
9955.png
['circle', '2', '2', '100', 'lime', 'circle', '10', '2', '80', 'pink']
['circle', '10', '1', '80', 'purple', 'circle', '2', '1', '110', 'orange']
-------------------------------------------------------------------
9804.png
['rect', '5', '1', '60', '60', 'cyan', 'rect', '2', '5', '80', '80', 'lime', 'circle', '5', '4', '60', 'cyan']
['rect', '3', '2', '40', '40', 'purple', 'rect', '6', '4', '110', '110', 'lime', 'circle', '2', '1', '40', 'yellow']
-------------------------------------------------------------------
9721.png
['rect', '10', '10', '60', '60', 'purple', 'rect', '9', '5', '90', '90', 'purple']
['rect', '10', '5', '100', '100', 'skyblue', 'rect', '8', '9', '80', '80', 'yellow']
-------------------------------------------------------------------
9882.png


9941.png
['circle', '9', '4', '60', 'lime', 'circle', '1', '3', '80', 'cyan', 'rect', '8', '5', '60', '60', 'cyan']
['circle', '9', '2', '60', 'red', 'circle', '1', '5', '60', 'deep_pink', 'rect', '8', '4', '70', '70', 'pink']
-------------------------------------------------------------------
9722.png
['rect', '8', '1', '60', '60', 'purple', 'rect', '8', '9', '60', '60', 'purple', 'circle', '8', '9', '60', 'lime']
['rect', '8', '9', '100', '100', 'orange', 'rect', '6', '1', '50', '50', 'spring_green']
-------------------------------------------------------------------
9984.png
['rect', '3', '4', '60', '60', 'cyan', 'rect', '8', '2', '30', '30', 'lime', 'circle', '3', '5', '60', 'cyan']
['rect', '10', '1', '20', '20', 'yellow', 'rect', '4', '5', '110', '110', 'red', 'circle', '2', '4', '20', 'green']
-------------------------------------------------------------------
9779.png
['rect', '9', '5', '60', '60', 'purple', 'rect', '7', '5', '60', '60', 'purple']
['rect', '8', '4', '40', '40',

9746.png
['circle', '5', '9', '100', 'cyan', 'circle', '7', '7', '110', 'pink', 'rect', '10', '8', '100', '100', 'lime']
['circle', '4', '7', '120', 'yellow', 'circle', '8', '9', '80', 'yellow', 'rect', '6', '7', '120', '120', 'blue']
-------------------------------------------------------------------
9852.png
['rect', '9', '5', '60', '60', 'purple', 'rect', '9', '5', '60', '60', 'purple']
['rect', '8', '6', '90', '90', 'red', 'rect', '9', '4', '80', '80', 'skyblue']
-------------------------------------------------------------------
9908.png
['circle', '4', '9', '60', 'cyan', 'circle', '2', '3', '80', 'cyan', 'rect', '8', '5', '60', '60', 'cyan']
['rect', '4', '1', '90', '90', 'red', 'rect', '8', '10', '70', '70', 'orange', 'circle', '2', '6', '80', 'red']
-------------------------------------------------------------------
9813.png
['rect', '9', '5', '30', '30', 'lime', 'rect', '1', '5', '30', '30', 'lime', 'circle', '10', '5', '30', 'lime']
['rect', '8', '5', '30', '30', 'orange', 'r

9736.png
['rect', '10', '2', '100', '100', 'cyan', 'rect', '3', '10', '90', '90', 'purple']
['rect', '9', '1', '110', '110', 'pink', 'rect', '3', '10', '90', '90', 'pink']
-------------------------------------------------------------------
9916.png
['circle', '7', '4', '40', 'lime', 'circle', '1', '5', '60', 'cyan', 'rect', '7', '3', '40', '40', 'cyan']
['circle', '8', '1', '60', 'pink', 'circle', '4', '6', '30', 'cyan', 'rect', '6', '1', '60', '60', 'skyblue']
-------------------------------------------------------------------
9848.png
['circle', '9', '3', '60', 'lime', 'circle', '7', '8', '80', 'cyan', 'rect', '8', '3', '60', '60', 'cyan']
['circle', '9', '8', '80', 'orange', 'circle', '9', '3', '50', 'lime', 'rect', '6', '3', '80', '80', 'orange']
-------------------------------------------------------------------
9713.png
['rect', '7', '5', '60', '60', 'yellow', 'rect', '8', '5', '60', '60', 'blue', 'circle', '3', '5', '60', 'cyan']
['rect', '7', '2', '50', '50', 'yellow', 'rect', 

9748.png
['rect', '2', '6', '60', '60', 'yellow', 'rect', '5', '6', '90', '90', 'purple', 'circle', '6', '7', '60', 'lime']
['circle', '6', '8', '30', 'blue', 'circle', '2', '5', '60', 'deep_pink']
-------------------------------------------------------------------
9706.png
['rect', '10', '4', '100', '100', 'cyan', 'rect', '2', '10', '110', '110', 'purple', 'circle', '10', '10', '100', 'lime']
['circle', '4', '5', '50', 'pink', 'circle', '10', '8', '50', 'blue', 'rect', '9', '10', '50', '50', 'blue']
-------------------------------------------------------------------
9812.png
['rect', '2', '8', '60', '60', 'cyan', 'rect', '5', '7', '90', '90', 'pink']
['rect', '1', '6', '50', '50', 'yellow', 'rect', '4', '9', '110', '110', 'spring_green']
-------------------------------------------------------------------
9898.png
['circle', '7', '10', '120', 'lime', 'circle', '9', '9', '120', 'cyan']
['circle', '4', '9', '90', 'red', 'circle', '8', '8', '120', 'blue']
---------------------------------

9910.png
['rect', '4', '9', '100', '100', 'purple', 'rect', '9', '10', '50', '50', 'purple', 'circle', '6', '3', '100', 'lime']
['circle', '4', '8', '60', 'spring_green', 'circle', '10', '3', '20', 'pink', 'rect', '7', '10', '60', '60', 'green']
-------------------------------------------------------------------
9740.png
['circle', '4', '1', '100', 'lime', 'circle', '3', '2', '60', 'cyan']
['circle', '4', '3', '100', 'pink', 'circle', '4', '1', '110', 'blue']
-------------------------------------------------------------------
9798.png
['rect', '7', '4', '80', '80', 'lime', 'rect', '7', '1', '80', '80', 'blue', 'circle', '8', '10', '80', 'cyan']
['circle', '7', '1', '60', 'red', 'circle', '9', '10', '40', 'orange', 'rect', '6', '5', '60', '60', 'pink']
-------------------------------------------------------------------
9823.png
['rect', '7', '8', '60', '60', 'yellow', 'rect', '7', '5', '110', '110', 'purple', 'circle', '8', '7', '60', 'skyblue']
['rect', '8', '5', '50', '50', 'deep_pink

In [6]:
print(length_match/len(test_list))

0
