In [1]:
# Re-import edited modules (e.g. the local PVM_PyCUDA / FormattingFiles helpers)
# automatically before each cell executes, so code changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import absolute_import, print_function, division
from builtins import *
import h5py
import math
from collections import OrderedDict
import numpy as np
from pycuda import driver, compiler, gpuarray, tools
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [3]:
# Select the GPU *before* importing pycuda.autoinit: autoinit creates the CUDA
# context at import time, and pycuda's default-context helper reads the
# CUDA_DEVICE environment variable to choose the device. Keep this order.
os.environ['CUDA_DEVICE'] = '1' # pick your device

# -- initialize the device
import pycuda.autoinit

In [4]:
# -- Network geometry / hyper-parameters ---------------------------------------
n_color = 3                       # channels per pixel
input_edge = 8                    # bottom-layer units see input_edge x input_edge patches (also tried: 3, 6)
input_size = input_edge * input_edge * n_color
hidden_size = 16                  # hidden size of regular units (also tried: 49, 8)
inner_hidden_size = 16            # hidden size passed to break_stuff for the broken units
output_sizes = [1] * 6            # one entry per layer of `structure` (also tried: [1, 2*2, 4*4, 8*8, 6*6, 16*16])
inner_output_size = 1             # output size passed to break_stuff for the broken units
heat_map_size = 256               # must be a perfect square: the heat map is heat_map_edge x heat_map_edge
structure = [24, 12, 6, 3, 2, 1]  # presumably units per edge, bottom layer first
                                  # (also tried: [32, 16, 8, 4, 2, 1], [16, 8, 4, 3, 2, 1])
# Bottom-layer units '_0_x_y' with x, y in [break_start, break_end) are handed to break_stuff.
break_start = 9
break_end = 16

edge_n_pixels = input_edge * structure[0]
# round() guards against sqrt returning e.g. 15.999... for a perfect square.
heat_map_edge = int(round(math.sqrt(heat_map_size)))

# Fail fast on inconsistent settings instead of deep inside make_connections/reshape.
assert len(output_sizes) == len(structure), 'need one output size per layer in structure'
assert heat_map_edge * heat_map_edge == heat_map_size, 'heat_map_size must be a perfect square'
assert 0 <= break_start <= break_end <= structure[0], 'break range must lie inside the bottom layer'

In [5]:
# Training frames, opened read-only. Each item of the file is later unpacked as
# (frames, rows, cols, colors) — presumably one dataset per clip; verify.
train_data_path = '/media/sdb/720p_formatted_training.hdf5'
train_data_reformat_nontracker = h5py.File(train_data_path, 'r')

In [6]:
from FormattingFiles import norm_and_heatmap, flatten_image, unflatten_image
from RectangularGridConstructor import shape_check, make_connections, break_stuff

In [7]:
import PVM_PyCUDA

In [8]:
# Build the full connection map for the hierarchy, then pass the square patch of
# bottom-layer units [break_start, break_end)^2 to break_stuff together with
# their inner hidden/output sizes.
connect_dict = make_connections(structure, input_size, hidden_size, output_sizes,
                                context_from_top_0_0=True)
break_unit_list = ['_0_{}_{}'.format(x, y)
                   for x in range(break_start, break_end)
                   for y in range(break_start, break_end)]
connect_dict = break_stuff(connect_dict, break_unit_list, (input_edge, input_edge),
                           inner_hidden_size, inner_output_size)

In [9]:
# dim is a tuple (height, width, number of colors). Use the configured n_color
# instead of a literal 3 so the frame layout follows the channel setting.
dim = (edge_n_pixels, edge_n_pixels, n_color)
input_shape = (input_edge, input_edge)
# basic_index[r, c, ch] holds the row-major flat position of that pixel/channel,
# so pushing it through the (un)flatten helpers yields index permutations.
basic_index = np.arange(np.prod(dim)).reshape(dim)
# Presumably: image positions in the order the network consumes its input.
flat_map = flatten_image(basic_index, input_shape)
# Used later as vec[rev_flat_map] to lay a flat network vector back out as an image.
rev_flat_map = unflatten_image(basic_index.flatten(), dim, input_shape)

In [15]:
# Construct the saccadic PVM. NOTE(review): norm=255.0 presumably rescales 8-bit
# pixels, and omega/gamma presumably parameterize the damped fovea dynamics
#   x'' + 2*gamma*omega*x' + omega**2 (x - x_eq) = 0
# — confirm against PVM_PyCUDA.PinholeSaccadicPVM.
pvm_options = dict(x_w=1, y_w=1, norm=255.0, omega=0.8, gamma=0.9)
saccadePVM = PVM_PyCUDA.PinholeSaccadicPVM(connect_dict, flat_map, rev_flat_map, dim,
                                           **pvm_options)

In [None]:
# Resume from a saved parameter file (the name suggests 1,000,000 training steps).
fname = './SaccadeFullSetTraining/720p_test_teo_1000000'
# Earlier checkpoint: './SaccadeFullSetTraining/structure_16_8_4_3_2_1_input_32_by_32_pvm_unit_in_2_by_2_learning_rate_0.01_hidden_8_3000000steps'
saccadePVM.load_parameters(fname)

In [None]:
# %%prun
# Train for n_train_steps frames at a constant learning rate; the average MSE is
# printed every 10000 frames and (save_every_print=True) parameters are saved to fname.
# NOTE(review): if the ..._1000000 checkpoint was loaded first, this still saves
# under ..._500000 — check that it does not overwrite an older checkpoint.
n_train_steps = 500000
learning_rate_list = [.01] * n_train_steps
fname = './SaccadeFullSetTraining/720p_test_teo_500000'
saccadePVM.train(train_data_reformat_nontracker, learning_rate_list,
                 print_every=10000, save_every_print=True,
                 filename=fname, interval=10000)

--------------------------------------------------------------------------------
                   AVG MSE over last 10000 frames
     10000 frames: 0.010693001504
     20000 frames: 0.00555444000307
     30000 frames: 0.00502997777988
     40000 frames: 0.00410998652096
     50000 frames: 0.00508850036622
     60000 frames: 0.00577795625619


In [12]:
# NOTE(review): this reopens the *training* file, so the animation below runs on
# frames the model was trained on. Point test_data_path at a held-out file if a
# genuine test-set visualization is intended.
test_data_path = '/media/sdb/720p_formatted_training.hdf5'
test_data_reformat_nontracker = h5py.File(test_data_path, 'r')

In [14]:
%matplotlib tk

# for the differential equation x'' + 2 * gamma * omega * x' + omega**2 (x - x^{eq}) = 0

fig = plt.figure(figsize=(15, 5))
ax1 = fig.add_subplot(131)
ax2 = fig.add_subplot(132)
ax3 = fig.add_subplot(133)
# ax4 = fig.add_subplot(224)

unflattened_idx_array = rev_flat_map # not a copy

# def window_sum(arr, width, height):
#     N_rows, N_cols, n_colors = arr.shape
#     N_s_row = N_rows - height + 1
#     N_s_col = N_cols - width + 1
#     summed_arr = np.zeros((N_s_row, N_s_col))
#     for row in xrange(N_s_row):
#         for col in xrange(N_s_col):
#             summed_arr[row, col] = np.sum(arr[row:row + height,
#                                               col:col + width, :])
#     return summed_arr
L_y = saccadePVM.sub_frame_height
L_x  = saccadePVM.sub_frame_width

def gen_func():
    """Run the PVM over every test clip and yield one tuple per frame.

    For each dataset in ``test_data_reformat_nontracker`` the PVM state is reset
    with the fovea window centred in the frame, then every frame is passed to
    ``saccadePVM.forward``.

    Yields
    ------
    tuple
        ``(image, pred, err, x, y)``: the frame as stored in the file, the first
        ``L_input`` entries of ``pred`` and ``err`` copied to the host with
        ``.get()``, and the fovea position after the step.
    """
    for _, clip in test_data_reformat_nontracker.items():
        n_frame, height, width, _ = clip.shape

        centre_x = (width - saccadePVM.sub_frame_width) // 2
        centre_y = (height - saccadePVM.sub_frame_height) // 2
        saccadePVM.reset_state(x_pos=centre_x, y_pos=centre_y)

        for frame_idx in range(n_frame):
            frame = clip[frame_idx, ...]
            saccadePVM.forward(frame)
            pred_host = saccadePVM.pred[:saccadePVM.L_input].get()
            err_host = saccadePVM.err[:saccadePVM.L_input].get()
            yield frame, pred_host, err_host, saccadePVM.x, saccadePVM.y
                
def update(vals):
    """Redraw all panels for one frame produced by ``gen_func``.

    The existing artists (im1..im3, ln1..ln4, created before the animation starts)
    are modified in place. Calling ``imshow`` every frame would add a new
    AxesImage to each axes per frame, growing memory and slowing full redraws.

    Parameters
    ----------
    vals : tuple
        ``(image, pred, err, x, y)`` as yielded by ``gen_func``.

    Returns
    -------
    tuple
        The changed artists, as required by ``FuncAnimation(..., blit=True)``.
    """
    image, pred, err, x, y = vals
    reordered_err = err[unflattened_idx_array]
    mag_err = abs(reordered_err - 0.5)

    # The four edges of the fovea window rectangle.
    ln1.set_data([x, x], [y, y + L_y])
    ln2.set_data([x + L_x, x + L_x], [y, y + L_y])
    ln3.set_data([x, x + L_x], [y, y])
    ln4.set_data([x, x + L_x], [y + L_y, y + L_y])

    for artist, data in ((im1, image),
                         (im2, pred[unflattened_idx_array]),
                         (im3, mag_err)):
        artist.set_data(data)
        # A fresh imshow would autoscale the colour limits of 2-D data; keep
        # that per-frame behaviour (no effect on RGB(A) data).
        artist.autoscale()

    return im1, im2, im3, ln1, ln2, ln3, ln4

# Prime the figure with the first frame. Use the builtin next(): generator
# objects have no .next() method on Python 3.
vals = next(gen_func())
image, pred, err, x, y = vals
reordered_err = err[unflattened_idx_array]
mag_err = abs(reordered_err - 0.5)

im1 = ax1.imshow(image, animated=True)
# Red outline of the fovea window: its four edges.
ln1, = ax1.plot([x, x], [y, y + L_y], 'r')
ln2, = ax1.plot([x + L_x, x + L_x], [y, y + L_y], 'r')
ln3, = ax1.plot([x, x + L_x], [y, y], 'r')
ln4, = ax1.plot([x, x + L_x], [y + L_y, y + L_y], 'r')

im2 = ax2.imshow(pred[unflattened_idx_array], animated=True)
im3 = ax3.imshow(mag_err, animated=True)

# frames=gen_func: FuncAnimation calls the generator function itself to obtain a
# fresh frame iterator. Keep the `ani` reference alive or the animation stops.
# save_count caps the frames cached for ani.save (presumably 10 minutes at 30 fps).
ani = animation.FuncAnimation(fig, update, frames=gen_func,
                              interval=5, blit=True, save_count=30*(10*60))

# To export a video instead of only displaying it:
# ani.save('PinholeSaccadicPVM_on_Full_test_gamma_0.9_omega_0.8_structure_16_8_4_3_2_1_input_32_by_32_pvm_unit_in_2_by_2_learning_rate_0.01_hidden_8_3000000steps.mp4',
#          writer='ffmpeg', fps=30, bitrate=-1,
#          extra_args=['-vcodec', 'libx264'])

plt.show()

In [None]:
# Debug inspection: display the PVM's input_frame_rev_shuf attribute.
saccadePVM.input_frame_rev_shuf

In [None]:
# Debug inspection: the length gen_func uses to slice pred/err before copying them to the host.
saccadePVM.L_input

In [None]:
# Scratch check: does round() return an int in this environment? Python 3's
# round(x) returns an int, while Python 2's builtin returned a float;
# `from builtins import *` presumably provides the Python 3 behaviour on Python 2.
isinstance(round(1.5), int)

In [None]:
# Debug: run one backward step on the first frame pair of the first dataset by
# calling the name-mangled private method PinholeSaccadicPVM.__backward.
for key, data in train_data_reformat_nontracker.items():
    frames, rows, cols, n_colors = data.shape
    saccadePVM.reset_state()
    # __backward consumes the current and the next frame, so stop one frame
    # early; a single-frame dataset would otherwise raise IndexError on data[i+1].
    for i in range(frames - 1):
#         saccadePVM.forward(data[i, ...])
        # Last argument presumably the learning rate (matches the 0.01 used in training).
        saccadePVM._PinholeSaccadicPVM__backward(data[i, ...], data[i + 1, ...], 0.01)
        break
    break