In [1]:
# this will be useful if you need to reload any module after some changes
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from pycuda import gpuarray, compiler
from collections import OrderedDict
import os
import h5py

In [3]:
# Select the GPU to run on; if the next line is commented out, device 0 is used.
# NOTE: this must run before `import pycuda.autoinit` below, since that import is
# what creates the CUDA context (presumably reading CUDA_DEVICE at that moment).
os.environ['CUDA_DEVICE'] = '1'

# Importing pycuda.autoinit automatically initializes a CUDA context.
import pycuda.autoinit

from PVM_PyCUDA import OnTheFlyPVM

In [4]:
# Two helpers for mapping an image into a one-dimensional array and back
# (in this notebook unflatten_image is only referenced from commented-out code).
from FormattingFiles import flatten_image, unflatten_image
# make_connections builds the connection dictionary; break_stuff is only
# referenced from the commented-out "broken units" experiment below.
from RectangularGridConstructor import make_connections, break_stuff

In [5]:
# PVM hyperparameters (these differ from the values used in the original paper).
n_color = 3
# Each bottom-level unit receives an input_edge_x-by-input_edge_y pixel patch
# with n_color channels.
input_edge_x, input_edge_y = 2, 2
input_size = input_edge_x * input_edge_y * n_color
hidden_size = 5
inner_hidden_size = 5
inner_output_size = 0
# Number of units per level as (n_x, n_y), from the bottom (input) level upward;
# the final entry 1 presumably denotes a single top unit.
structure = [(64, 48), (32, 24), (16, 12), (8, 6), (4, 3), (3, 2), (2, 1), 1]
# Alternative hierarchies tried previously:
#[(64, 48), (64, 48), (32, 24), (16, 12), (8, 6), (4, 3), (3, 2), (2, 1), 1]
#[(128, 96), (64, 48), (32, 24), (16, 12), (8, 6), (4, 3), (3, 2), (2, 1), 1]#[(96, 96), 48, 24, 12, 6, 3, 2, 1]#
# One output size per level. Derived from `structure` (previously a literal 8)
# so that switching to a 9-level alternative above cannot leave it out of sync.
output_sizes = [0] * len(structure)

# Unit-index ranges used by the commented-out break_stuff experiment in the next cell.
# break_start_x = 16
# break_end_x = 49
# break_start_y = 12
# break_end_y = 37

# Size of the full input image in pixels.
edge_n_pixels_x, edge_n_pixels_y = input_edge_x * structure[0][0], input_edge_y * structure[0][1]

In [6]:
# To build a PVM instance you must specify how its units are connected.
# Connectivity is described by a dictionary, so in principle it can be as general
# as you want. make_connections constructs a layered hierarchy of rectangular grids
# with nearest-neighbor lateral connections, as was done in the paper.
connect_dict = make_connections(structure, input_size, hidden_size,
                                output_sizes, context_from_top_0_0=True)

# Disabled experiment: replace a rectangular region of bottom-level units via break_stuff.
# break_unit_list = []
# for x in range(break_start_x, break_end_x):
#     for y in range(break_start_y, break_end_y):
#         break_unit_list.append('_0_{}_{}'.format(x, y))

# connect_dict = break_stuff(connect_dict, 
#                            break_unit_list, 
#                            (input_edge_y, input_edge_x), 
#                            inner_hidden_size,
#                            inner_output_size)


# dim is a tuple (height, width, number of colors). Use n_color (not a literal 3)
# so the image shape stays consistent with input_size.
dim = (edge_n_pixels_y, edge_n_pixels_x, n_color)
input_shape = (input_edge_y, input_edge_x)
# basic_index labels every pixel/channel with its flat index in a
# (height, width, color) image. flatten_image presumably rearranges these labels
# into the input layout the PVM expects (TODO confirm), giving flat_map.
basic_index = np.arange(np.prod(dim)).reshape(dim)
flat_map = flatten_image(basic_index, input_shape)
# rev_flat_map = unflatten_image(basic_index.flatten(), dim, input_shape)

In [7]:
def connect_next_nearest_neighbors(connect_dict):
    """Add next-nearest-neighbor lateral connections to a PVM connection dict.

    Keys are expected to have the form ``'_<level>_<x>_<y>'`` and values to be
    6-tuples ``(unit_count, input_size, hidden_size, output_size,
    fedfromlist, latsuplist)``. For every unit, the units on the same level at
    (dx, dy) offsets (0, +2), (0, -2), (-2, 0), (+2, 0), (-1, -1), (+1, -1),
    (-1, +1), (+1, +1) -- in that order -- are appended to its ``latsuplist``
    when they exist in the dict and are not already listed.

    Parameters
    ----------
    connect_dict : dict
        Connection dictionary, e.g. as returned by ``make_connections``.
        It is not modified.

    Returns
    -------
    dict
        A shallow copy of ``connect_dict`` (same mapping type and key order)
        whose entries carry new ``latsuplist`` lists. Because neighbors that
        are already present are skipped, applying the function repeatedly
        gives the same result as applying it once.
    """
    # (dx, dy) offsets in the order N, S, W, E, SW, SE, NW, NE.
    offsets = ((0, 2), (0, -2), (-2, 0), (2, 0),
               (-1, -1), (1, -1), (-1, 1), (1, 1))
    # Work on a copy (keeps OrderedDict type/order) instead of mutating the
    # caller's lists in place; in-place mutation made re-running the calling
    # cell append duplicate lateral connections.
    new_dict = connect_dict.copy()
    for key, val in connect_dict.items():
        unit_count, input_size, hidden_size, \
            output_size, fedfromlist, latsuplist = val
        # Keys start with '_', so the first field is the (empty) prefix.
        prefix, lvl_str, x_str, y_str = key.split('_')
        x = int(x_str)
        y = int(y_str)
        new_latsup = list(latsuplist)
        already_listed = set(new_latsup)
        for dx, dy in offsets:
            # Reuse the original coordinate text when the offset is zero so the
            # candidate key is formed exactly like the unit's own key.
            candidate = '_'.join([prefix, lvl_str,
                                  str(x + dx) if dx else x_str,
                                  str(y + dy) if dy else y_str])
            # Dict membership is O(1) (the former list lookup was O(n)).
            if candidate in connect_dict and candidate not in already_listed:
                new_latsup.append(candidate)
                already_listed.add(candidate)
        new_dict[key] = (unit_count, input_size, hidden_size,
                         output_size, fedfromlist, new_latsup)
    return new_dict

In [8]:
# Extend every unit's lateral connections (latsuplist) with its next-nearest
# neighbors on the same level. When re-executing, re-run the make_connections
# cell first so the neighbors are added to a freshly built dictionary.
connect_dict = connect_next_nearest_neighbors(connect_dict)

In [9]:
# Directory holding the data sets (alternative used previously: '~/Downloads/').
# expanduser makes a '~/...' path work too; without it, the '~' would be passed
# literally to h5py and the open would fail.
path = '/media/sdb/'
train_filename = os.path.join(os.path.expanduser(path),
                              'PVM_train_set_96h_by_128w.hdf5')
# Other data sets used with this notebook:
# 'PVM_movement_integration_train_set_96h_by_128w_no_position.hdf5'
# test_filename = os.path.join(os.path.expanduser(path), 'PVM_test_set.hdf5')

# Opened read-only; the handle stays open because later cells read from it.
train_data = h5py.File(train_filename, 'r')
# test_data_reformat_nontracker =  h5py.File(test_filename, 'r')

In [10]:
# Instantiate the PyCUDA PVM from the connection dictionary and the pixel map.
# NOTE(review): norm=255. presumably rescales 8-bit pixel values to [0, 1] -- confirm in OnTheFlyPVM.
pvm = OnTheFlyPVM(connect_dict, flat_map, norm=255.)

In [11]:
# Base path for saving/loading this model's parameters and MSE history.
# The hidden size is taken from the hyperparameter cell so the name cannot drift
# out of sync with the model actually built (a hard-coded name was already stale
# once: see the hidden_size_10 error recorded below). The step count and learning
# rate in the name are expected to match the training cell.
model_dir = '/home/mhazoglou/PVM_PyCUDA/MotionIntegration'
fname = os.path.join(
    model_dir,
    'model_for_transfer_learning_exp1_plus_NNN_hidden_size_{}_3000000steps_0.01'
    .format(hidden_size))

In [12]:
# Load previously saved parameters for this model; skip this cell when training
# from scratch. Judging by the recorded error below, load_parameters expects
# companion files such as fname + '_connections.pkl'. That error output is stale:
# it comes from an earlier run whose file name used hidden_size_10.
# NOTE(review): '.pkl' suggests pickle -- only load files from a trusted source.
pvm.load_parameters(fname)

FileNotFoundError: [Errno 2] No such file or directory: '/home/mhazoglou/PVM_PyCUDA/MotionIntegration/model_for_transfer_learning_exp1_plus_NNN_hidden_size_10_3000000steps_0.01_connections.pkl'

In [None]:
# Constant learning rate of 0.01, presumably one entry per training step.
# Progress (average MSE) is reported every report_every frames;
# save_every_print=True presumably also saves to fname at each report.
n_train_steps = 3000000
report_every = 100000
learning_rate_list = [0.01] * n_train_steps
pvm.train(train_data, learning_rate_list,
          print_every=report_every,
          save_every_print=True,
          filename=fname,
          interval=report_every)

--------------------------------------------------------------------------------
                   AVG MSE over last 100000 frames


In [None]:
# Save the MSE values tracked during training (presumably to a file derived from fname).
pvm.save_mse(fname)

In [11]:
# Animate the model's behavior on the data; use a keyboard interrupt to stop early.
# scale=5 presumably magnifies the displayed frames.
# NOTE(review): the printed banner says "Animating testing data" even though
# train_data is passed here.
pvm.quick_animate(train_data, scale=5)

--------------------------------------------------------------------------------
Animating testing data
--------------------------------------------------------------------------------
Use a Keyboard interruption to exit early.
--------------------------------------------------------------------------------
