In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt

# Number of complex cells; also the number of connection weights each
# feature detector holds (see initialize / make_complex_layer).
num_complex = 4
# Weight given to the current firing value when a trace is updated
# (see find_new_trace); the old trace is weighted by 1 - delta.
delta = 0.2
# Scale factor applied to every weight change (see find_weight_change).
alpha = 0.02

In [None]:
from pprint import pprint

def main():
    """Build the bar-stimulus frame sequences used to train the model.

    Every frame is an ``N x N x 4`` integer array. The last axis carries one
    flag per feature detector; index 0 is set for vertical bars, 1 for
    horizontal bars, 2 for the "da" diagonal and 3 for the "db" diagonal.

    Returns:
        dict mapping a sequence id to a list of frames:
            1 -- horizontal bar, one frame per row index (N frames)
            2 -- vertical bar, one frame per column index (N frames)
            3 -- "da" diagonals: cells with row + col == s,
                 for s = 0 .. 2N-2 (2N-1 frames)
            4 -- "db" diagonals: cells with col - row == k,
                 for k = N-1 down to -(N-1) (2N-1 frames)
            5 -- the first four vertical frames followed by sequence 3
                 from its fifth frame onwards
            6 -- the first four vertical frames followed by sequence 4
                 reversed, from its fifth frame onwards
    """
    N = 8
    num_detectors = 4

    def make_frame(cells, detector):
        """Return a blank frame with `detector` switched on at each (row, col) in `cells`."""
        frame = np.zeros((N, N, num_detectors), dtype=int)
        for row, col in cells:
            frame[row, col, detector] = 1
        return frame

    dataset = {}

    # horizontal: frame i lights up every column of row i
    dataset[1] = [make_frame([(i, j) for j in range(N)], 1) for i in range(N)]

    # vertical: frame i lights up every row of column i.
    # Exactly one frame per column -- the first column must not be stored twice.
    dataset[2] = [make_frame([(j, i) for j in range(N)], 0) for i in range(N)]

    # ldiagonal - da: sweep over the constant (row + col) lines
    dataset[3] = [
        make_frame([(r, s - r) for r in range(N) if 0 <= s - r < N], 2)
        for s in range(2 * N - 1)
    ]

    # rdiagonal - db: sweep over the constant (col - row) lines,
    # starting at the single cell (0, N-1) and ending at (N-1, 0)
    dataset[4] = [
        make_frame([(r, r + k) for r in range(N) if 0 <= r + k < N], 3)
        for k in range(N - 1, -N, -1)
    ]

    # Mixed sequences: four distinct vertical frames, then the tail of a
    # diagonal sweep. Both use the same vertical prefix.
    vertical_prefix = dataset[2][:4]
    dataset[5] = vertical_prefix + dataset[3][4:]
    dataset[6] = vertical_prefix + dataset[4][::-1][4:]

    return dataset

        
            
# Run the dataset builder when this code executes as the top-level module.
# The returned dataset is discarded here.
if __name__ == '__main__':
    main()

In [None]:
# Build the stimulus dataset and pull out the four single-orientation
# sequences (horizontal, vertical, and the two diagonals).
data = main()
hframes, vframes, daframes, dbframes = data[1], data[2], data[3], data[4]

In [10]:
import copy

def make_grid(n):
    """Return an n-by-n grid (a list of n rows) of freshly built simple units."""
    rows = []
    for _ in range(n):
        rows.append([make_simple_unit() for _ in range(n)])
    return rows

def make_simple_unit():
    """Return a simple unit: one freshly initialized feature detector per orientation."""
    unit = {}
    for orientation in ('v', 'h', 'da', 'db'):
        unit[orientation] = initialize()
    return unit

def initialize():
    """Return a non-firing feature detector with one random weight in [0, 0.1] per complex cell."""
    weights = []
    for _ in range(num_complex):
        weights.append(random.uniform(0, 0.1))
    return {'firing': 0, 'weights': weights}

def make_complex_unit():
    """Return a complex unit that is not firing and has a zero trace."""
    unit = {}
    unit['firing'] = 0
    unit['trace'] = 0
    return unit

def make_complex_layer():
    """Return the complex layer: a list of num_complex independent complex units."""
    layer = []
    for _ in range(num_complex):
        layer.append(make_complex_unit())
    return layer

def find_weight_change(trace, firing, weight):
    """Return the weight increment for one time step.

    The increment is the gap between the detector's firing value and its
    current weight, scaled by the complex cell's trace and by alpha.
    """
    gap = firing - weight
    return alpha * trace * gap

def find_new_trace(firing, trace):
    """Return the complex cell's trace after one time step.

    The old trace is weighted by (1 - delta) and the current firing value
    by delta.
    """
    carried_over = (1 - delta) * trace
    contributed = delta * firing
    return carried_over + contributed

def which_firing(grid, complex_layer):
    """Choose the complex cell that fires and apply one weight-update step.

    For each feature detector in `grid`:
      1. if it is firing, its current weights are added to a running total
         per complex cell;
      2. every one of its weights is then adjusted in place by
         find_weight_change, using the matching complex cell's current trace.

    Args:
        grid: nested list of simple units (as built by make_grid). The
            'weights' lists inside it are modified in place.
        complex_layer: list of complex units; only each unit's 'trace' is read.

    Returns:
        The index of the complex cell with the largest summed input. When
        several cells tie, the highest index wins.
    """
    sums = {cell: 0 for cell in range(num_complex)}
    for row in grid:
        for unit in row:
            for orientation in unit:
                detector = unit[orientation]
                firing = detector['firing']
                weights = detector['weights']
                # Accumulate with the weights as they stand *before* this
                # step's update.
                if firing == 1:
                    for cell in sums:
                        sums[cell] += firing * weights[cell]
                for cell, weight in enumerate(weights):
                    weights[cell] += find_weight_change(
                        complex_layer[cell]['trace'], firing, weight)
    # .items() (not the Python-2-only .iteritems()) keeps this runnable on
    # both Python 2 and Python 3. Comparing (total, index) pairs resolves
    # ties in favour of the larger index.
    return max((total, cell) for cell, total in sums.items())[1]

def read_frame(frame):
    """Return the (row, col, feature_detector) index of every entry of `frame` equal to 1.

    Indices are reported in the order np.ndenumerate visits the array.
    """
    return [
        (r, c, detector)
        for (r, c, detector), value in np.ndenumerate(frame)
        if value == 1
    ]

def run_model(data, num_epochs=100):
    """Train the simple-cell grid on every frame sequence in `data`.

    Each sequence is presented `num_epochs` times in a row. For each frame,
    the feature detectors flagged in the frame are switched on, which_firing
    picks the winning complex cell (and updates the weights), all traces are
    updated, and the firing flags are cleared again.

    Args:
        data: dict mapping a sequence id to a list of frames. Key 1 must be
            present: the length of its first frame sets the grid size.
        num_epochs: number of passes over each sequence (default 100).

    Returns:
        (complex_record, grid_record): deep-copied snapshots of the complex
        layer and of the grid, taken after every pass over a sequence.
    """
    grid = make_grid(len(data[1][0]))
    complex_layer = make_complex_layer()
    fd2orient = {0:'v', 1:'h', 2:'da', 3:'db'}
    grid_record = []
    complex_record = []
    # .values() works on Python 2 and 3 alike (.iteritems() is Python 2 only);
    # the sequence ids themselves are not needed here.
    for frames in data.values():
        # train on this sequence a given number of times
        for _ in range(num_epochs):
            for frame in frames:
                # switch on the feature detectors flagged in this frame
                active = [grid[row][col][fd2orient[fd]]
                          for row, col, fd in read_frame(frame)]
                for detector in active:
                    detector['firing'] = 1

                # determine which complex cell fires (also updates weights)
                winner = which_firing(grid, complex_layer)
                complex_layer[winner]['firing'] = 1

                # determine new trace of each complex unit
                for complex_unit in complex_layer:
                    complex_unit['trace'] = find_new_trace(
                        complex_unit['firing'], complex_unit['trace'])

                # reset units before the next frame
                for detector in active:
                    detector['firing'] = 0
                complex_layer[winner]['firing'] = 0

            # snapshot once per pass; deep copies so later training
            # does not alter earlier records
            complex_record.append(copy.deepcopy(complex_layer))
            grid_record.append(copy.deepcopy(grid))

    return complex_record, grid_record
                

In [11]:
# Train on the full dataset; keep the per-pass snapshots of the complex
# layer and of the simple-cell grid for the analysis below.
complex_record, grid_record = run_model(data)

In [12]:
# For each orientation, collect the weight lists from every snapshot,
# indexed as [snapshot][row][column].
vweights, hweights, daweights, dbweights = [
    [[[cell[orientation]['weights'] for cell in cells] for cells in snapshot]
     for snapshot in grid_record]
    for orientation in ('v', 'h', 'da', 'db')
]
# Trace of every complex cell in every snapshot: traces[snapshot][cell].
traces = [[unit['trace'] for unit in snapshot] for snapshot in complex_record]

In [13]:
import matplotlib.pyplot as plt

Traceback (most recent call last):
  File "/Users/rankine/anaconda/lib/python2.7/site-packages/IPython/core/ultratb.py", line 970, in get_records
    return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
  File "/Users/rankine/anaconda/lib/python2.7/site-packages/IPython/core/ultratb.py", line 233, in wrapped
    return f(*args, **kwargs)
  File "/Users/rankine/anaconda/lib/python2.7/site-packages/IPython/core/ultratb.py", line 267, in _fixed_getinnerframes
    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
  File "/Users/rankine/anaconda/lib/python2.7/inspect.py", line 1049, in getinnerframes
    framelist.append((tb.tb_frame,) + getframeinfo(tb, context))
  File "/Users/rankine/anaconda/lib/python2.7/inspect.py", line 1009, in getframeinfo
    filename = getsourcefile(frame) or getfile(frame)
  File "/Users/rankine/anaconda/lib/python2.7/inspect.py", line 454, in getsourcefile
    if hasattr(getmodule(object, filename), '__loader__'):
 

ERROR: Internal Python error in the inspect module.
Below is the traceback from this internal error.


Unfortunately, your original traceback can not be constructed.



TypeError: 'NoneType' object is not iterable

In [14]:
# x-axis: one step per recorded snapshot
steps = range(len(traces))
points = zip(steps, traces)
# one trace series per complex cell, in cell order
cc0traces, cc1traces, cc2traces, cc3traces = [
    [snapshot[cell] for snapshot in traces] for cell in range(4)
]

In [None]:
# Scatter each complex cell's trace against the snapshot index.
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(steps, cc0traces, c='b', label='0')
ax1.scatter(steps, cc1traces, c='r', label='1')
ax1.scatter(steps, cc2traces, c='g', label='2')
# 'p' is not a matplotlib colour code; use the named colour instead.
ax1.scatter(steps, cc3traces, c='purple', label='3')
ax1.set_xlabel('snapshot')
ax1.set_ylabel('trace')
ax1.legend(title='complex cell')
plt.show()