In [1]:
import os
os.chdir('../')

import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresize

from hart.data import disp
from hart.data.kitti.tools import get_data
from hart.model import util
from hart.model.attention_ops import FixedStdAttention,RATMAttention
from hart.model.eval_tools import log_norm, log_ratios, log_values, make_expr_logger
from hart.model.tracker import HierarchicalAttentiveRecurrentTracker as HART
from hart.model.nn import AlexNetModel, IsTrainingLayer
from hart.train_tools import TrainSchedule, minimize_clipped

import glob
import moviepy.editor as mpy
import cv2

import matplotlib.pyplot as plt
% matplotlib inline

In [2]:
def load_imgs(img_dir):
    """Read every entry of `img_dir` into one float32 array.

    Entries are visited in the order given by `sorted(os.listdir(img_dir))`,
    i.e. plain string ordering of the file names. Each one is read with
    `imread` and resized with `imresize` to the first two entries of the
    module-level `img_size`.

    Args:
        img_dir: directory holding the frames (every entry is read).

    Returns:
        Array of shape `[n_entries, 1] + list(img_size)`, dtype float32.
    """
    names = sorted(os.listdir(img_dir))
    out_shape = [len(names), 1] + list(img_size)
    frames = np.empty(out_shape, dtype=np.float32)
    target_hw = img_size[:2]

    for idx in range(len(names)):
        full_path = os.path.join(img_dir, names[idx])
        frames[idx, 0] = imresize(imread(full_path), target_hw)

    return frames
def _box_to_corners(box):
    """Turn a (y, x, y_extent, x_extent) box into two integer (x, y) corner points."""
    y, x, dy, dx = [float(v) for v in np.asarray(box).squeeze()[:4]]
    # cv2 drawing functions expect integer pixel coordinates, so round explicitly
    # instead of handing over floats.
    top_left = (int(round(x)), int(round(y)))
    bottom_right = (int(round(x + dx)), int(round(y + dy)))
    return top_left, bottom_right


def write_imgs_glimpses_objmasks_to_file(imgs,glimpse,obj_mask,pred_bbox,pred_att,base_path):
    """Write three PNG files per time step: annotated frame, glimpse and object mask.

    For every index `i` along the first axis of `imgs` the following files are
    written with `plt.imsave`:

    * `<base_path>_img_<i>.png`     - the frame (divided by 255) with two
      rectangles drawn on it: `pred_bbox[i]` with colour (0, 0, 255) and
      `pred_att[i]` with colour (255, 0, 0).
    * `<base_path>_glimpse_<i>.png` - `glimpse[i]`.
    * `<base_path>_objMask_<i>.png` - `obj_mask[i]`, gray colormap, range [0, 1].

    Args:
        imgs: array of frames, indexed by time step on the first axis.
        glimpse: per-step glimpses, indexed like `imgs`.
        obj_mask: per-step object masks, indexed like `imgs`.
        pred_bbox: per-step boxes; each squeezes to (y, x, y_extent, x_extent).
        pred_att: per-step attention boxes, same layout as `pred_bbox`.
        base_path: prefix of every output file (directory must already exist).

    The rectangles are drawn on a copy of each frame, so `imgs` is left
    untouched and calling the function twice produces the same files.
    """
    n_images = imgs.shape[0]
    # `range` rather than `xrange` so the loop runs on Python 2 and Python 3.
    for i in range(n_images):
        # Work on a contiguous copy: cv2 draws in place and would otherwise
        # modify the caller's array.
        frame = imgs[i].squeeze().copy()

        # cv2.rectangle takes the image, the top-left corner, the bottom-right
        # corner, the colour and the line thickness.
        bbox_tl, bbox_br = _box_to_corners(pred_bbox[i])
        att_tl, att_br = _box_to_corners(pred_att[i])

        #Blue will be prediction BBOX
        cv2.rectangle(frame, bbox_tl, bbox_br, (0,0,255), 1)
        #Red will be attention BBOX
        cv2.rectangle(frame, att_tl, att_br, (255,0,0), 1)

        plt.imsave(base_path+"_img_"+str(i)+".png", frame / 255.)
        plt.imsave(base_path+"_glimpse_"+str(i)+".png",glimpse[i].squeeze())
        plt.imsave(base_path+"_objMask_"+str(i)+".png",obj_mask[i].squeeze(), cmap='gray', vmin=0., vmax=1.)

        

In [3]:
# ---- Paths -----------------------------------------------------------------
alexnet_dir = 'checkpoints'

# Input sequence. Exactly one `img_dir` is active; the others are earlier
# choices kept for reference (they used to be assigned and then overwritten).
#img_dir = 'data/kitti_tracking/small_test/0016'
#img_dir = 'data/kitti_tracking/small/0015'
#img_dir = 'data/my_experiments/frames_small/ApplesRollingAmongstApples'
img_dir = 'data/my_experiments/frames_small/CarOnRoad'

# checkpoint_path = 'checkpoints/kitti/pretrained/2017_07_06_16.41/model.ckpt-142320'
checkpoint_path = 'checkpoints/kitti/pretrained/model.ckpt-347346'

# ---- Shapes ----------------------------------------------------------------
batch_size = 1
# Frames are resized to img_size[:2]; stored as arrays right away so that
# re-running the cell does not depend on a tuple-to-array rebinding step.
img_size = np.asarray((187, 621, 3))
crop_size = np.asarray((56, 56, 3))
# Shape of the initial bounding-box placeholder fed to the model.
bbox_shape = (1, 1, 4)

# ---- Model hyper-parameters --------------------------------------------------
rnn_units = 100
norm = 'batch'
keep_prob = .75

# NOTE(review): `keys` is not used in the visible cells - confirm before removing.
keys = ['img', 'bbox', 'presence']

In [4]:
# Build the inference graph from scratch: clear the default graph and fix the
# seed before any op is created.
tf.reset_default_graph()
util.set_random_seed(0)

# Graph inputs:
#   x  - float32 frames, shape [None, batch_size] + img_size
#   y0 - float32 initial bounding box, shape bbox_shape
#   p0 - constant uint8 ones shaped like y0 without its last axis
x = tf.placeholder(tf.float32, [None, batch_size] + list(img_size), name='image')
y0 = tf.placeholder(tf.float32, bbox_shape, name='bbox')
p0 = tf.ones(y0.get_shape()[:-1], dtype=tf.uint8, name='presence')

# The same IsTrainingLayer instance is handed to both the feature builder and
# the tracker.
is_training = IsTrainingLayer()
builder = AlexNetModel(alexnet_dir, layer='conv3', n_out_feature_maps=5, upsample=False, normlayer=norm,
                       keep_prob=keep_prob, is_training=is_training)

# NOTE(review): the keyword values below presumably have to match the
# configuration the restored checkpoint was trained with - confirm before
# changing any of them.
model = HART(x, y0, p0, batch_size, crop_size, builder, rnn_units,
             bbox_gain=[-4.78, -1.8, -3., -1.8],
             zoneout_prob=(.05, .05),
             normalize_glimpse=True,
             attention_module=FixedStdAttention,
             debug=True,
             transform_init_features=True,
             transform_init_state=True,
             dfn_readout=True,
             feature_shape=(14, 14),
             is_training=is_training)

In [5]:
# Saver used further down to restore the checkpoint into the session.
saver = tf.train.Saver()
# NOTE(review): the session is never closed in the visible cells.
sess = tf.Session()

In [6]:
# Run the global variable initializer first, then restore the checkpoint at
# `checkpoint_path` into the same session.
sess.run(tf.global_variables_initializer())
saver.restore(sess, checkpoint_path)
# Put the model into test mode (hart-specific call; presumably switches off
# training-only behaviour - confirm against hart.model.tracker).
model.test_mode(sess)

INFO:tensorflow:Restoring parameters from checkpoints/kitti/pretrained/model.ckpt-347346


In [7]:
# Load every frame found in the selected sequence directory.
imgs = load_imgs(img_dir)

# Corner coordinates of the box that initialises the tracker.
x_start, x_end = 396, 434
y_start, y_end = 68, 90

# Box layout is [y_origin, x_origin, y_add, x_add].
# (previously tried: bbox = [88, 250, 18, 25])
bbox = [
    y_start,
    x_start,
    y_end - y_start,
    x_end - x_start,
]

In [8]:
# Tensors fetched from the tracker: predicted boxes, attention boxes,
# glimpses and object masks.
tensors = [
    model.pred_bbox,
    model.att_pred_bbox,
    model.glimpse,
    model.obj_mask,
]
# Inputs: all frames plus the initial box reshaped to the placeholder shape.
feed_dict = {
    x: imgs,
    y0: np.reshape(bbox, bbox_shape),
}
pred_bbox, pred_att, glimpse, obj_mask = sess.run(tensors, feed_dict)

In [9]:
# Disabled inline preview: the plotting code below is wrapped in a string
# literal, so none of it runs - the cell only evaluates (and echoes) the string.
# If re-enabled it would draw one row per frame with three panels: the frame
# with the 'pred' (blue) and 'att' (green) rectangles, the glimpse, and the
# object mask.
'''
n = imgs.shape[0]
fig, axes = plt.subplots(n, 3, figsize=(20, 2*n))
for i, ax in enumerate(axes):
    ax[0].imshow(imgs[i].squeeze() / 255.)
    ax[1].imshow(glimpse[i].squeeze())
    ax[2].imshow(obj_mask[i].squeeze(), cmap='gray', vmin=0., vmax=1.)
    disp.rect(pred_bbox[i].squeeze(), 'b', ax=ax[0])
    disp.rect(pred_att[i].squeeze(), 'g', ax=ax[0])
    for a in ax:
        a.xaxis.set_visible(False)
        a.yaxis.set_visible(False)
        
axes[0, 0].plot([], c='g', label='att')
axes[0, 0].plot([], c='b', label='pred')
axes[0, 0].legend(loc='center right')
axes[0, 0].set_xlim([0, img_size[1]])
axes[0, 0].set_ylim([img_size[0], 0])
'''

"\nn = imgs.shape[0]\nfig, axes = plt.subplots(n, 3, figsize=(20, 2*n))\nfor i, ax in enumerate(axes):\n    ax[0].imshow(imgs[i].squeeze() / 255.)\n    ax[1].imshow(glimpse[i].squeeze())\n    ax[2].imshow(obj_mask[i].squeeze(), cmap='gray', vmin=0., vmax=1.)\n    disp.rect(pred_bbox[i].squeeze(), 'b', ax=ax[0])\n    disp.rect(pred_att[i].squeeze(), 'g', ax=ax[0])\n    for a in ax:\n        a.xaxis.set_visible(False)\n        a.yaxis.set_visible(False)\n        \naxes[0, 0].plot([], c='g', label='att')\naxes[0, 0].plot([], c='b', label='pred')\naxes[0, 0].legend(loc='center right')\naxes[0, 0].set_xlim([0, img_size[1]])\naxes[0, 0].set_ylim([img_size[0], 0])\n"

In [10]:
# Name of the experiment; used for the results directory and as file prefix.
filename = "CarOnRoad"
path = 'results/%s'%filename

# Create <path>/gifs (and <path> itself) only when it is missing. Unlike a
# bare `except: pass`, a genuine failure such as a permission error is raised
# here instead of surfacing later as a confusing error while saving images.
if not os.path.isdir(path+'/gifs'):
    os.makedirs(path+'/gifs')

# Writes <path>/<filename>_img_<i>.png, _glimpse_<i>.png and _objMask_<i>.png.
write_imgs_glimpses_objmasks_to_file(imgs,glimpse,obj_mask,pred_bbox,pred_att,'%s/%s'%(path,filename))

In [11]:
# Print the attention box predicted for the first time step; it can be compared
# with the initial `bbox` fed to the model.
print(pred_att[0])

[[[  67.79806519  395.47900391   22.78027916   38.12802505]]]


In [12]:
# Base names of the four GIFs ('.gif' is appended when writing).
gif_imgs_name = path+'/gifs/'+filename+'_images'
gif_glimpses_name = path+'/gifs/'+filename+'_glimpses'
gif_objMasks_name = path+'/gifs/'+filename+'_objMasks'
# NOTE: no '_' before 'original', unlike the three names above; kept as is so
# the output path does not change.
gif_orig_name = path+'/gifs/'+filename+'original'

#fps = 4
fps = 23


def _frame_index(png_path):
    """Return the integer formed by the digits that end the file name.

    Works for '<prefix>_<n>.png', '<prefix>-<n>.png' and '<n>.png' alike, so
    the sort key no longer has to be edited per dataset. Raises ValueError
    when the file name does not end in digits.
    """
    stem = os.path.splitext(os.path.basename(png_path))[0]
    digits = stem[len(stem.rstrip('0123456789')):]
    return int(digits)


def _sorted_frames(pattern):
    """Glob `pattern` and return the matches ordered by trailing frame number."""
    return sorted(glob.glob(pattern), key=_frame_index)


# PNGs written by write_imgs_glimpses_objmasks_to_file, in frame order.
file_list_imgs = _sorted_frames('%s/*_img_*.png'%path)
file_list_glimpses = _sorted_frames('%s/*_glimpse_*.png'%path)
file_list_objMasks = _sorted_frames('%s/*_objMask_*.png'%path)
# Untouched input frames of the sequence, in frame order.
file_list_original = _sorted_frames('%s/*.png'%img_dir)

clip_imgs = mpy.ImageSequenceClip(file_list_imgs, fps=fps)
clip_glimpses = mpy.ImageSequenceClip(file_list_glimpses, fps=fps)
clip_objMasks = mpy.ImageSequenceClip(file_list_objMasks, fps=fps)
clip_original = mpy.ImageSequenceClip(file_list_original, fps=fps)

clip_imgs.write_gif('{}.gif'.format(gif_imgs_name), fps=fps)
clip_glimpses.write_gif('{}.gif'.format(gif_glimpses_name), fps=fps)
clip_objMasks.write_gif('{}.gif'.format(gif_objMasks_name), fps=fps)
clip_original.write_gif('{}.gif'.format(gif_orig_name), fps=fps)



[MoviePy] Building file results/CarOnRoad/gifs/CarOnRoad_images.gif with imageio


100%|██████████| 450/450 [00:47<00:00,  9.55it/s]


[MoviePy] Building file results/CarOnRoad/gifs/CarOnRoad_glimpses.gif with imageio



100%|██████████| 450/450 [00:03<00:00, 121.36it/s]


[MoviePy] Building file results/CarOnRoad/gifs/CarOnRoad_objMasks.gif with imageio



100%|██████████| 450/450 [00:00<00:00, 835.12it/s]


[MoviePy] Building file results/CarOnRoad/gifs/CarOnRoadoriginal.gif with imageio



100%|██████████| 450/450 [00:46<00:00,  9.63it/s]
