In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from __future__ import absolute_import, print_function, division
from builtins import *
import time
import math
from collections import OrderedDict
import numpy as np
from pycuda import driver, compiler, gpuarray, tools
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation

In [3]:
# Select the GPU, then let PyCUDA set the device up.
# The environment variable is assigned BEFORE `pycuda.autoinit` is imported;
# keep this order, because the import below is what initializes the device.
# NOTE(review): presumably PyCUDA reads CUDA_DEVICE while creating its
# default context -- confirm against the PyCUDA documentation.
os.environ['CUDA_DEVICE'] = '1' # pick your device

# -- initialize the device
import pycuda.autoinit

In [4]:
# -- tunable sizes (values tried earlier are listed as alternatives)
n_color = 3                          # last axis of the image shape used below
input_edge = 2                       # alternatives: 3, 6
hidden_size = 8                      # alternatives: 16, 49
heat_map_size = 256
structure = [16, 8, 4, 3, 2, 1]      # alternatives: [96, 48, 24, 12, 6, 3, 2, 1], [32, 16, 8, 4, 2, 1]
output_sizes = 6 * [1]               # alternative: [1, 2*2, 4*4, 8*8, 6*6, 16*16]

# -- sizes derived from the values above
input_size = n_color * input_edge * input_edge
edge_n_pixels = structure[0] * input_edge
heat_map_edge = int(math.sqrt(heat_map_size))

In [5]:
import re
import pandas as pd

# Load every trace found in `datapath`. A trace is a pair of files sharing one
# stem: `<stem>.npy` (read with np.load) and `<stem>.csv` (read with
# pd.read_csv). `test_data` maps stem -> (array, DataFrame).
datapath = '/media/sdb/PVM_trace_set_raw/'
filelist = os.listdir(datapath)

# Gather each stem exactly once. Only the two expected extensions are
# considered, so unrelated files or sub-directories in the folder no longer get
# turned into bogus stems that make np.load fail.
trace_stems = set()
for filename in filelist:
    filename_wo_ext, extension = os.path.splitext(filename)
    if extension in ('.npy', '.csv'):
        trace_stems.add(filename_wo_ext)

test_data = {}
did_it = []  # stems that have been loaded, in load order
# os.listdir returns entries in arbitrary order; sorting makes the load order
# repeatable from run to run.
for filename_wo_ext in sorted(trace_stems):
    fname = os.path.join(datapath, filename_wo_ext)
    # An incomplete pair still raises here, exactly as before.
    arr = np.load(fname + '.npy')
    df = pd.read_csv(fname + '.csv')

    test_data[filename_wo_ext] = (arr, df)
    did_it.append(filename_wo_ext)

In [6]:
from FormattingFiles import norm_and_heatmap, unflatten_image
from RectangularGridConstructor import shape_check

In [7]:
# Target sizes handed to norm_and_heatmap: both are square.
heat_map_edges = (heat_map_edge,) * 2
input_layer_dim = (edge_n_pixels,) * 2

# Run every loaded (array, DataFrame) pair through norm_and_heatmap and keep
# the two values it returns under the same key.
test_data_reformat = {}
for key in test_data:
    data_arr, df = test_data[key]
    rescale_arr, ground_truth_heat_map_list = norm_and_heatmap(
        data_arr, df, heat_map_edges, input_layer_dim, norm=255.)
    test_data_reformat[key] = (rescale_arr, ground_truth_heat_map_list)

# The raw data is not used after this point in the notebook; drop the reference.
del test_data

In [117]:
import PVM_PyCUDA
from RectangularGridConstructor import make_connections

In [15]:
# Build the connection description for the configured layer structure; the
# result is passed to the PVM constructor.
connect_dict = make_connections(
    structure,
    input_size,
    hidden_size,
    output_sizes,
    context_from_top_0_0=True,
)

In [118]:
# Instantiate the saccadic PVM from the connection dictionary.
# The third argument uses `n_color` instead of a literal 3, so it stays in
# step with the configuration cell and with `dim` in the animation cell.
# NOTE(review): omega / gamma are presumably the coefficients of the damped
# oscillator equation quoted in the animation cell's comment -- confirm in
# PVM_PyCUDA.
saccadePVM = PVM_PyCUDA.PinholeSaccadicPVM(connect_dict, (input_edge, input_edge),
                                           (edge_n_pixels, edge_n_pixels, n_color),
                                           x_w=1, y_w=1, omega=0.8, gamma=0.9)

In [119]:
# Load previously saved parameters into the model.
# NOTE(review): the file name spells out structure, input size, unit input
# size and hidden size; these presumably have to match the configuration cell
# (structure=[16, 8, 4, 3, 2, 1], input_edge=2, hidden_size=8) -- confirm
# before changing either side.
fname = './SaccadeFullSetTraining/structure_16_8_4_3_2_1_input_32_by_32_pvm_unit_in_2_by_2_learning_rate_0.01_hidden_8_3000000steps'
saccadePVM.load_parameters(fname)

In [142]:
%matplotlib tk

# for the differential equation x'' + 2 * gamma * omega * x' + omega**2 (x - x^{eq}) = 0

# One row, three panels: ax1 shows the input frame with the window outline,
# ax2 the prediction, ax3 the error magnitude.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Index array used to reorder the flat prediction / error vectors for display
# (applied as `vector[unflattened_idx_array]`).
# NOTE(review): presumably unflatten_image lays the indices out as an image of
# shape `dim` built from (input_edge x input_edge) patches -- confirm in
# FormattingFiles.
dim = (edge_n_pixels, edge_n_pixels, n_color)
tot_size = np.prod(dim)
unflattened_idx_array = unflatten_image(
    np.arange(tot_size), dim, (input_edge, input_edge))

# def window_sum(arr, width, height):
#     N_rows, N_cols, n_colors = arr.shape
#     N_s_row = N_rows - height + 1
#     N_s_col = N_cols - width + 1
#     summed_arr = np.zeros((N_s_row, N_s_col))
#     for row in xrange(N_s_row):
#         for col in xrange(N_s_col):
#             summed_arr[row, col] = np.sum(arr[row:row + height,
#                                               col:col + width, :])
#     return summed_arr
# Height and width of the model's sub-frame; used to draw the red outline.
L_y, L_x = saccadePVM.sub_frame_height, saccadePVM.sub_frame_width

def gen_func():
    """Feed every frame of every sequence to the model, yielding display data.

    Walks `test_data_reformat` sequence by sequence and frame by frame, calls
    `saccadePVM.forward` on each frame and then yields the tuple
    `(image, pred, err, x, y)`:

    * image -- the frame that was just fed to the model
    * pred  -- first `L_input` entries of `saccadePVM.pred`, fetched with `.get()`
    * err   -- first `L_input` entries of `saccadePVM.err`, fetched with `.get()`
    * x, y  -- `saccadePVM.x` and `saccadePVM.y` after the forward step

    The model state is not reset between sequences.
    """
    for rescale_arr, _ in test_data_reformat.values():
        # The arrays are expected to be 4-D; only the frame count is used.
        n_frame, _, _, _ = rescale_arr.shape
        for frame_idx in range(n_frame):
            image = rescale_arr[frame_idx, ...]
            saccadePVM.forward(image)
            pred = saccadePVM.pred[:saccadePVM.L_input].get()
            err = saccadePVM.err[:saccadePVM.L_input].get()
            yield image, pred, err, saccadePVM.x, saccadePVM.y
                
def update(vals):
    """Redraw the three panels and the window outline for one frame.

    Parameters
    ----------
    vals : tuple
        `(image, pred, err, x, y)` as yielded by `gen_func`.

    Returns
    -------
    tuple
        The artists that changed (`im1, im2, im3, ln1, ln2, ln3, ln4`), as
        required by `FuncAnimation` when `blit=True`.

    Notes
    -----
    The module-level image artists `im1`, `im2`, `im3` are updated in place
    with `set_data`. Calling `ax.imshow` here instead would add three new
    image artists to the axes on every frame, so memory use and draw time
    would keep growing for the whole animation.
    """
    image, pred, err, x, y = vals
    reordered_err = err[unflattened_idx_array]
    # NOTE(review): presumably 0.5 is the value that encodes "no error" --
    # confirm against PVM_PyCUDA.
    mag_err = abs(reordered_err - 0.5)

    im1.set_data(image)
    im2.set_data(pred[unflattened_idx_array])
    im3.set_data(mag_err)
    for image_artist in (im1, im2, im3):
        # A fresh imshow call rescales colour limits to its data; keep that
        # per-frame behaviour for the reused artists.
        image_artist.autoscale()

    # Red outline from (x, y) to (x + L_x, y + L_y): left, right, top, bottom.
    x_far = x + L_x
    y_far = y + L_y
    ln1.set_xdata([x, x])
    ln1.set_ydata([y, y_far])
    ln2.set_xdata([x_far, x_far])
    ln2.set_ydata([y, y_far])
    ln3.set_xdata([x, x_far])
    ln3.set_ydata([y, y])
    ln4.set_xdata([x, x_far])
    ln4.set_ydata([y_far, y_far])

    return im1, im2, im3, ln1, ln2, ln3, ln4

# Pull a single frame to create the artists that `update` later modifies.
# `next(...)` works on both Python 2 and Python 3; the generator method
# `.next()` exists on Python 2 only.
# Note: this priming call already runs one `saccadePVM.forward` step.
vals = next(gen_func())
image, pred, err, x, y = vals
reordered_err = err[unflattened_idx_array]
mag_err = abs(reordered_err - 0.5)

# Input frame plus the four red segments outlining the window whose corner is
# at (x, y) and whose size is L_x by L_y: left, right, top, bottom.
im1 = ax1.imshow(image, animated=True)
ln1, = ax1.plot([x, x], [y, y + L_y], 'r')
ln2, = ax1.plot([x + L_x, x + L_x], [y, y + L_y], 'r')
ln3, = ax1.plot([x, x + L_x], [y, y], 'r')
ln4, = ax1.plot([x, x + L_x], [y + L_y, y + L_y], 'r')

# Prediction and error magnitude, reordered with the index array.
im2 = ax2.imshow(pred[unflattened_idx_array], animated=True)

im3 = ax3.imshow(mag_err, animated=True)

# im4 = ax4.imshow(heatmap.reshape(heat_map_edges),
#                  vmin=0, vmax=1, cmap='gray', animated=True)

# `frames=gen_func` makes the animation start a fresh generator of its own.
# NOTE(review): save_count is presumably 10 minutes at 30 frames per second,
# matching fps=30 in the disabled save call below -- confirm.
ani = animation.FuncAnimation(fig, update, frames=gen_func,
                              interval=5, blit=True, save_count=30*(10*60))

# ani.save('PinholeSaccadicPVM_on_Full_test_gamma_0.9_omega_0.8_structure_16_8_4_3_2_1_input_32_by_32_pvm_unit_in_2_by_2_learning_rate_0.01_hidden_8_3000000steps.mp4',
#          writer='ffmpeg', fps=30, bitrate=-1,
#          extra_args=['-vcodec', 'libx264'])

plt.show()

In [141]:
# Reset the model state with the window offset by half of the leftover space
# on each axis.
# NOTE(review): 96 is presumably the edge length of the input frames in
# pixels -- confirm against the data.
frame_edge = 96
saccadePVM.reset_state(
    x_pos=(frame_edge - saccadePVM.sub_frame_width) // 2,
    y_pos=(frame_edge - saccadePVM.sub_frame_height) // 2,
)

In [79]:
# Scratch check: randint's upper bound is exclusive, so this draws one integer
# from {-1, 0, 1}.
np.random.randint(-1, 2)

-1

In [140]:
# Record how the window position moves while the model is shown the SAME
# image (frame 0 of each sequence) for a fixed number of steps. Positions from
# all sequences go into one pair of lists, and the model state is not reset
# between sequences.
n_steps = 400
x_list, y_list = [], []
for rescale_arr, _ in test_data_reformat.values():
    image = rescale_arr[0, ...]
    for _step in range(n_steps):
        saccadePVM.forward(image)
        x_list.append(saccadePVM.x)
        y_list.append(saccadePVM.y)

# Draw the recorded path over the last image used, shifted by half the input
# edge on both axes.
# NOTE(review): the shift presumably moves the window corner to its centre --
# confirm.
half_edge = edge_n_pixels // 2
plt.imshow(image)
plt.plot(np.asarray(x_list) + half_edge, np.asarray(y_list) + half_edge, 'r')
plt.show()