# Storing Max States
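This notebook validates a method for storing the states and positions associated with the maximum activations of nodes in trained networks. Specifically, it tests boolean masking and fancy indexing with NumPy arrays. Each node keeps a fixed number of stored values (four here). A new activation overwrites that node's smallest stored value only when the new activation is larger.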

In [1]:
import numpy as np

In [2]:
# Seed NumPy's global RNG so this draw, and every later np.random.rand call in the
# notebook, is reproducible under Restart Kernel -> Run All.
np.random.seed(0)

N_NODES = 10   # length of one state vector (one entry per row of max_values)
N_STORED = 4   # number of maximum values kept per row

# A: one candidate state; max_values: per-row table of stored maxima, starting empty (zeros).
A = np.random.rand(N_NODES)
max_values = np.zeros((N_NODES, N_STORED))
print(A)

[ 0.65377149  0.97670074  0.89319191  0.65940942  0.86941973  0.5848019
  0.49380642  0.71930155  0.36795871  0.28061257]


In [3]:
# Row-wise test: True where the new value in A beats the smallest value already stored in that row.
max_mask = np.greater(A, max_values.min(axis=1))
max_mask

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)

In [4]:
# Column holding each row's smallest stored value, i.e. the slot a larger new value would overwrite.
idx = max_values.argmin(axis=1)
idx

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [5]:
# Pair each row index with its argmin column so fancy indexing addresses one cell per row.
# The row count comes from the array itself rather than a hard-coded 10.
rows = np.arange(max_values.shape[0])
current_min = max_values[rows, idx]
# Rows where max_mask is False are written back with their existing minimum (a no-op).
max_values[rows, idx] = np.where(max_mask, A, current_min)
max_values

array([[ 0.65377149,  0.        ,  0.        ,  0.        ],
       [ 0.97670074,  0.        ,  0.        ,  0.        ],
       [ 0.89319191,  0.        ,  0.        ,  0.        ],
       [ 0.65940942,  0.        ,  0.        ,  0.        ],
       [ 0.86941973,  0.        ,  0.        ,  0.        ],
       [ 0.5848019 ,  0.        ,  0.        ,  0.        ],
       [ 0.49380642,  0.        ,  0.        ,  0.        ],
       [ 0.71930155,  0.        ,  0.        ,  0.        ],
       [ 0.36795871,  0.        ,  0.        ,  0.        ],
       [ 0.28061257,  0.        ,  0.        ,  0.        ]])

In [6]:
def update_next_state(state=None, values=None):
    """Fold one state vector into a table of stored maxima, in place.

    For each row, the smallest value currently stored in ``values`` is
    overwritten by that row's entry of ``state``, but only where the new entry
    is strictly larger. Rows where it is not larger are left unchanged. The
    state and the updated table are printed.

    Parameters
    ----------
    state : array-like, optional
        One value per row of ``values``. When omitted, a vector of length
        ``values.shape[0]`` is drawn with ``np.random.rand``.
    values : numpy.ndarray, optional
        2-D table that is updated in place. Defaults to the notebook-level
        ``max_values``. The global is looked up at call time, so a later
        ``max_values = np.zeros(...)`` reset is picked up.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``state`` does not have exactly one entry per row of ``values``.
    """
    if values is None:
        values = max_values  # resolved at call time, not at definition time
    n_rows = values.shape[0]
    state = np.random.rand(n_rows) if state is None else np.asarray(state)
    if state.shape != (n_rows,):
        raise ValueError(f"state must have shape ({n_rows},), got {state.shape}")
    print(state)
    rows = np.arange(n_rows)
    idx = np.argmin(values, axis=1)
    # Gather each row's minimum once; this equals np.amin(values, axis=1).
    current_min = values[rows, idx]
    values[rows, idx] = np.where(state > current_min, state, current_min)
    print(values)

In [7]:
# Second update: each row's new value goes into that row's current minimum slot
# (here the first still-zero column, since argmin returns the first minimum) if it is larger.
update_next_state()

[ 0.67395851  0.31278337  0.33944912  0.89094268  0.34939364  0.04903622
  0.65198444  0.74319213  0.88560719  0.70092124]
[[ 0.65377149  0.67395851  0.          0.        ]
 [ 0.97670074  0.31278337  0.          0.        ]
 [ 0.89319191  0.33944912  0.          0.        ]
 [ 0.65940942  0.89094268  0.          0.        ]
 [ 0.86941973  0.34939364  0.          0.        ]
 [ 0.5848019   0.04903622  0.          0.        ]
 [ 0.49380642  0.65198444  0.          0.        ]
 [ 0.71930155  0.74319213  0.          0.        ]
 [ 0.36795871  0.88560719  0.          0.        ]
 [ 0.28061257  0.70092124  0.          0.        ]]


In [8]:
# Start from an empty (all-zero) table, then push ten random states through the update
# to watch the columns fill and, afterwards, the row minima get replaced.
max_values = np.zeros((10, 4))
for trial in range(10):
    print("Trial %i: " % trial)
    update_next_state()

Trial 0: 
[ 0.22214984  0.39161895  0.8536798   0.70126864  0.09105564  0.04636342
  0.20885335  0.20002238  0.32582889  0.59259226]
[[ 0.22214984  0.          0.          0.        ]
 [ 0.39161895  0.          0.          0.        ]
 [ 0.8536798   0.          0.          0.        ]
 [ 0.70126864  0.          0.          0.        ]
 [ 0.09105564  0.          0.          0.        ]
 [ 0.04636342  0.          0.          0.        ]
 [ 0.20885335  0.          0.          0.        ]
 [ 0.20002238  0.          0.          0.        ]
 [ 0.32582889  0.          0.          0.        ]
 [ 0.59259226  0.          0.          0.        ]]
Trial 1: 
[ 0.07430636  0.5302218   0.89404339  0.68504785  0.5586259   0.93512223
  0.06795093  0.62927661  0.17030579  0.3145529 ]
[[ 0.22214984  0.07430636  0.          0.        ]
 [ 0.39161895  0.5302218   0.          0.        ]
 [ 0.8536798   0.89404339  0.          0.        ]
 [ 0.70126864  0.68504785  0.          0.        ]
 [ 0.09105564  0.55

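As expected, during the first four trials each row's new value fills the next empty (zero) column. After that, a row changes only when its new value is greater than the smallest of its four stored values (that is, greater than at least one of them). In that case the smallest value is replaced by the new one; otherwise the row is left unchanged.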