In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# --- Simulation configuration ---
# Seed for the random number generator created further down the notebook.
seed = 0

# Layer sizes; the output layer has as many units as the input layer.
input_dim = 5
hidden_dim = 5  # 10 (presumably an alternative value the author tried)
output_dim = input_dim

# Parameters of the Gaussian-shaped input activity profile.
input_peak_rate = 1.
sigma = 0.5
# Selects the circularly shifted (np.roll) variant when input patterns are built.
wrap = True

# Uninitialised (input_dim x input_dim) buffer; every row is written before use.
input_pattern_matrix = np.empty((input_dim, input_dim))

In [3]:
# Index of every input unit along the latent dimension.
input_indexes = np.arange(input_dim)
center_index = (input_dim - 1) // 2
# Gaussian-shaped activity profile peaking at the central unit.
field = input_peak_rate * np.exp(-(((input_indexes - center_index) / sigma) ** 2.))
# Seeded generator shared by all stochastic steps that follow.
this_random = np.random.RandomState(seed)

# Row i holds the profile whose peak sits on unit i.
for i in input_indexes:
    if wrap:
        # Shift the central profile circularly so activity wraps around the edges.
        row = np.roll(field, i - center_index)
    else:
        # Re-evaluate the Gaussian centred on unit i, without wrap-around.
        row = input_peak_rate * np.exp(-(((input_indexes - i) / sigma) ** 2.))
    input_pattern_matrix[i, :] = row

In [4]:
# Left panel: the central activity profile. Right panel: the full pattern matrix.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 2.5))
profile_ax, pattern_ax = axes

profile_ax.plot(input_indexes, field)
profile_ax.set(ylabel='Input unit activity', xlabel='Latent dimension')

image = pattern_ax.imshow(input_pattern_matrix)
fig.colorbar(image, ax=pattern_ax)
pattern_ax.set(xlabel='Input pattern', ylabel='Input unit', title='Input activities')

fig.tight_layout()

<IPython.core.display.Javascript object>

In [5]:
# Initial input-to-hidden weights: uniform in [-scale, scale), with the scale
# divided by the number of input units.
initial_weight_scale = 1.
hidden_weight_scale = initial_weight_scale / input_dim
weight_matrix = this_random.uniform(
    low=-hidden_weight_scale,
    high=hidden_weight_scale,
    size=[hidden_dim, input_dim],
)
# Independent snapshot of the starting weights for later comparison / resets.
original_weight_matrix = np.array(weight_matrix, copy=True)

In [6]:
# Response of every hidden unit (rows) to every input pattern (columns)
# under the initial weights, plus a snapshot for later comparison.
hidden_activities = np.dot(weight_matrix, input_pattern_matrix)
original_hidden_activities = np.array(hidden_activities, copy=True)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 2.5))
panels = (
    (original_weight_matrix, 'Hidden weights', 'Input units'),
    (hidden_activities, 'Hidden activities', 'Input patterns'),
)
for panel_ax, (panel_values, panel_title, panel_xlabel) in zip(axes, panels):
    image = panel_ax.imshow(panel_values)
    fig.colorbar(image, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel('Hidden units')
    panel_ax.set_xlabel(panel_xlabel)

fig.tight_layout()

<IPython.core.display.Javascript object>

In [7]:
def oja_delta_w(pre, post, weight, learning_rate):
    """Weight change prescribed by Oja's learning rule.

    Evaluates ``learning_rate * post * (pre - post * weight)``. Only
    arithmetic operators are used, so the arguments may be scalars or NumPy
    arrays that broadcast against each other.

    Parameters
    ----------
    pre : float or ndarray
        Presynaptic activity.
    post : float or ndarray
        Postsynaptic activity (or whatever signal drives the update).
    weight : float or ndarray
        Current value of the weight(s) being updated.
    learning_rate : float
        Step size applied to the update.

    Returns
    -------
    float or ndarray
        The increment to add to ``weight``.
    """
    return learning_rate * post * (pre - post * weight)

In [8]:
# Train the input-to-hidden weights with Oja's rule. Each pass visits every
# input pattern once, in a freshly shuffled order.
learning_rate = 0.05
# Restart from the stored initial weights and activities.
weight_matrix = np.copy(original_weight_matrix)
hidden_activities = np.copy(original_hidden_activities)
weight_matrix_history = []       # one weight snapshot per block
hidden_activities_history = []   # one activity snapshot per block
pattern_index_history = []       # NOTE: nothing is appended to this list in this cell
blocks = 4
block_size = 10                  # shuffled passes over all patterns per block
# Patterns are addressed by column below, so count columns rather than rows.
num_patterns = input_pattern_matrix.shape[1]
for block_o in range(blocks):
    for block_i in range(block_size):
        input_pattern_indexes = np.arange(num_patterns)
        this_random.shuffle(input_pattern_indexes)
        for pattern_index in input_pattern_indexes:
            # Presynaptic (input) activity varies along the columns of the
            # weight matrix, postsynaptic (hidden) activity along its rows.
            pre_activity = input_pattern_matrix[:, pattern_index][np.newaxis, :]
            post_activity = hidden_activities[:, pattern_index][:, np.newaxis]
            # Broadcasting evaluates the rule for every (hidden, input) pair
            # at once, with the same per-element arithmetic as an explicit
            # double loop over units.
            delta_w_matrix = oja_delta_w(pre_activity, post_activity, weight_matrix, learning_rate)
            weight_matrix = np.add(weight_matrix, delta_w_matrix)
            # Refresh the responses to all patterns with the updated weights.
            hidden_activities = weight_matrix.dot(input_pattern_matrix)
    # Record the state reached at the end of each block.
    weight_matrix_history.append(np.copy(weight_matrix))
    hidden_activities_history.append(np.copy(hidden_activities))

In [9]:
# Top row: weights; bottom row: hidden activities.
# Column 0 is the initial state, columns 1.. are the per-block snapshots.
fig, axes = plt.subplots(nrows=2, ncols=blocks + 1, figsize=(10., 4.25))
weights_row, activities_row = axes

image = weights_row[0].imshow(original_weight_matrix)
fig.colorbar(image, ax=weights_row[0])
weights_row[0].set(ylabel='Hidden unit', xlabel='Input unit', title='Weights\nInitial')

image = activities_row[0].imshow(original_hidden_activities)
activities_row[0].set(ylabel='Hidden unit', xlabel='Input pattern', title='Hidden activities\n')
fig.colorbar(image, ax=activities_row[0])

snapshots = zip(weight_matrix_history, hidden_activities_history)
for column, (weights_snapshot, activities_snapshot) in enumerate(snapshots, start=1):
    image = weights_row[column].imshow(weights_snapshot)
    fig.colorbar(image, ax=weights_row[column])
    weights_row[column].set_title('Block %i' % column)
    image = activities_row[column].imshow(activities_snapshot)
    fig.colorbar(image, ax=activities_row[column])

fig.tight_layout()
# Widen the gaps after tight_layout so colorbars and titles do not collide.
fig.subplots_adjust(wspace=0.3, hspace=0.8)
fig.show()

<IPython.core.display.Javascript object>

In [10]:
# Sum of squared weights per hidden unit (one value per row), first for the
# initial weights and then for the trained ones.
for matrix in (original_weight_matrix, weight_matrix):
    print((matrix ** 2.).sum(axis=1))

[0.01073482 0.06515769 0.07291377 0.116091   0.0871146 ]
[0.38815252 0.80569246 0.78571612 0.8894828  0.83283324]


In [11]:
# Compare the distribution of hidden weights before and after training.
# Both histograms share the same bin edges so the curves are comparable.
min_weight = min(np.min(original_weight_matrix), np.min(weight_matrix))
max_weight = max(np.max(original_weight_matrix), np.max(weight_matrix))
edges = np.linspace(min_weight, max_weight, 20)
# np.histogram returns one count per bin (len(edges) - 1 values). Plot each
# count at the centre of its bin; using the left edges would shift both
# curves by half a bin width.
bin_centers = 0.5 * (edges[:-1] + edges[1:])
plt.figure()
hist, _ = np.histogram(original_weight_matrix, bins=edges)
plt.plot(bin_centers, hist, label='Before')
hist, _ = np.histogram(weight_matrix, bins=edges)
plt.plot(bin_centers, hist, label='After')
plt.legend(loc='best', frameon=False)
plt.ylabel('Count')
plt.xlabel('Weight')

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Weight')

In [12]:
# Initial hidden-to-output weights: uniform in [-scale, scale), with the scale
# divided by the number of hidden units.
output_weight_scale = initial_weight_scale / hidden_dim
output_weight_matrix = this_random.uniform(
    low=-output_weight_scale,
    high=output_weight_scale,
    size=[output_dim, hidden_dim],
)
# Independent snapshot so each training variant can restart from the same weights.
original_output_weight_matrix = np.array(output_weight_matrix, copy=True)

In [13]:
# Response of every output unit (rows) to every input pattern (columns)
# under the initial output weights, plus a snapshot for later resets.
output_activities = np.dot(output_weight_matrix, hidden_activities)
original_output_activities = np.array(output_activities, copy=True)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 2.5))
panels = (
    (original_output_weight_matrix, 'Output weights', 'Hidden units'),
    (original_output_activities, 'Output activities', 'Input patterns'),
)
for panel_ax, (panel_values, panel_title, panel_xlabel) in zip(axes, panels):
    image = panel_ax.imshow(panel_values)
    fig.colorbar(image, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel('Output units')
    panel_ax.set_xlabel(panel_xlabel)

fig.tight_layout()

<IPython.core.display.Javascript object>

In [14]:
# Drive Oja's rule with the target output: the postsynaptic term of the
# update is the desired output (the input pattern itself), not the output
# the network currently produces.
output_weight_matrix = np.copy(original_output_weight_matrix)
output_activities = np.copy(original_output_activities)
output_weight_matrix_history = []   # one weight snapshot per block
output_activities_history = []      # one activity snapshot per block
blocks = 4
block_size = 10                     # shuffled passes over all patterns per block
# Patterns are addressed by column below, so count columns rather than rows.
num_patterns = input_pattern_matrix.shape[1]
for block_o in range(blocks):
    for block_i in range(block_size):
        input_pattern_indexes = np.arange(num_patterns)
        this_random.shuffle(input_pattern_indexes)
        for pattern_index in input_pattern_indexes:
            this_input = input_pattern_matrix[:, pattern_index]
            target_post = this_input
            # Presynaptic (hidden) activity varies along the columns of the
            # output weight matrix, the target along its rows. Broadcasting
            # evaluates the rule for every (output, hidden) pair at once,
            # with the same per-element arithmetic as an explicit double loop.
            delta_w_matrix = oja_delta_w(
                hidden_activities[:, pattern_index][np.newaxis, :],
                target_post[:, np.newaxis],
                output_weight_matrix,
                learning_rate,
            )
            output_weight_matrix = np.add(output_weight_matrix, delta_w_matrix)
            # Refresh the responses to all patterns with the updated weights.
            output_activities = output_weight_matrix.dot(hidden_activities)
    # Record the state reached at the end of each block.
    output_weight_matrix_history.append(np.copy(output_weight_matrix))
    output_activities_history.append(np.copy(output_activities))

In [17]:
# Top row: output weights; bottom row: output activities.
# Column 0 is the initial state, columns 1.. are the per-block snapshots.
fig, axes = plt.subplots(nrows=2, ncols=blocks + 1, figsize=(10., 4.25))
weights_row, activities_row = axes

image = weights_row[0].imshow(original_output_weight_matrix)
fig.colorbar(image, ax=weights_row[0])
weights_row[0].set(ylabel='Output unit', xlabel='Hidden unit', title='Weights\nInitial')

image = activities_row[0].imshow(original_output_activities)
activities_row[0].set(ylabel='Output unit', xlabel='Input pattern', title='Output activities\n')
fig.colorbar(image, ax=activities_row[0])

snapshots = zip(output_weight_matrix_history, output_activities_history)
for column, (weights_snapshot, activities_snapshot) in enumerate(snapshots, start=1):
    image = weights_row[column].imshow(weights_snapshot)
    fig.colorbar(image, ax=weights_row[column])
    weights_row[column].set_title('Block %i' % column)
    image = activities_row[column].imshow(activities_snapshot)
    fig.colorbar(image, ax=activities_row[column])

fig.tight_layout()
# Widen the gaps after tight_layout so colorbars and titles do not collide.
fig.subplots_adjust(wspace=0.3, hspace=0.8)
fig.show()

<IPython.core.display.Javascript object>

In [16]:
# For each column (input pattern), the row index of the most active output unit.
output_activities.argmax(axis=0)

array([0, 1, 3, 3, 4])

In [18]:
# Drive Oja's rule with the delta (target - output): the postsynaptic term of
# the update is the per-unit error between the desired output (the input
# pattern itself) and the output currently produced for that pattern.
output_weight_matrix = np.copy(original_output_weight_matrix)
output_activities = np.copy(original_output_activities)
output_weight_matrix_history = []   # one weight snapshot per block
output_activities_history = []      # one activity snapshot per block
blocks = 4
block_size = 10                     # shuffled passes over all patterns per block
# Patterns are addressed by column below, so count columns rather than rows.
num_patterns = input_pattern_matrix.shape[1]
for block_o in range(blocks):
    for block_i in range(block_size):
        input_pattern_indexes = np.arange(num_patterns)
        this_random.shuffle(input_pattern_indexes)
        for pattern_index in input_pattern_indexes:
            this_input = input_pattern_matrix[:, pattern_index]
            target_post = this_input
            # Signed error for every output unit. (An alternative the author
            # left commented out used the scalar sum of squared errors.)
            delta_post = np.subtract(target_post, output_activities[:, pattern_index])
            # Presynaptic (hidden) activity varies along the columns of the
            # output weight matrix, the error along its rows. Broadcasting
            # evaluates the rule for every (output, hidden) pair at once,
            # with the same per-element arithmetic as an explicit double loop.
            delta_w_matrix = oja_delta_w(
                hidden_activities[:, pattern_index][np.newaxis, :],
                delta_post[:, np.newaxis],
                output_weight_matrix,
                learning_rate,
            )
            output_weight_matrix = np.add(output_weight_matrix, delta_w_matrix)
            # Refresh the responses to all patterns with the updated weights.
            output_activities = output_weight_matrix.dot(hidden_activities)
    # Record the state reached at the end of each block.
    output_weight_matrix_history.append(np.copy(output_weight_matrix))
    output_activities_history.append(np.copy(output_activities))

In [19]:
# Two-row grid: output weights on top, output activities underneath.
# The first column shows the initial state; each further column shows the
# snapshot stored at the end of one training block.
fig, axes = plt.subplots(nrows=2, ncols=blocks + 1, figsize=(10., 4.25))
top_axes, bottom_axes = axes

image = top_axes[0].imshow(original_output_weight_matrix)
fig.colorbar(image, ax=top_axes[0])
top_axes[0].set(ylabel='Output unit', xlabel='Hidden unit', title='Weights\nInitial')

image = bottom_axes[0].imshow(original_output_activities)
bottom_axes[0].set(ylabel='Output unit', xlabel='Input pattern', title='Output activities\n')
fig.colorbar(image, ax=bottom_axes[0])

history = zip(output_weight_matrix_history, output_activities_history)
for block_number, (block_weights, block_activities) in enumerate(history, start=1):
    image = top_axes[block_number].imshow(block_weights)
    fig.colorbar(image, ax=top_axes[block_number])
    top_axes[block_number].set_title('Block %i' % block_number)
    image = bottom_axes[block_number].imshow(block_activities)
    fig.colorbar(image, ax=bottom_axes[block_number])

fig.tight_layout()
# Widen the gaps after tight_layout so colorbars and titles do not collide.
fig.subplots_adjust(wspace=0.3, hspace=0.8)
fig.show()

<IPython.core.display.Javascript object>

In [20]:
# Sum of squared weights per output unit (one value per row), first for the
# initial output weights and then for the trained ones.
for matrix in (original_output_weight_matrix, output_weight_matrix):
    print((matrix ** 2.).sum(axis=1))

[0.04574758 0.01588866 0.0496846  0.09045911 0.09928844]
[0.14649839 0.23776723 0.20094431 0.2750129  0.12207599]


In [21]:
# For each column (input pattern), the row index of the most active output unit.
output_activities.argmax(axis=0)

array([0, 1, 2, 3, 4])

In [22]:
from scipy.optimize import least_squares

In [23]:
def least_sq_output(w, hidden_activities, target):
    """Residuals between ``target`` and a linear read-out of ``hidden_activities``.

    Parameters
    ----------
    w : array_like
        Flat output weight vector with
        ``target.shape[0] * hidden_activities.shape[0]`` entries, laid out
        row by row (one row per output unit).
    hidden_activities : ndarray, shape (n_hidden, n_patterns)
        Hidden unit responses, one column per pattern.
    target : ndarray, shape (n_output, n_patterns)
        Desired output responses, one column per pattern.

    Returns
    -------
    ndarray
        1-D array ``target - W.dot(hidden_activities)``, flattened; this is
        the residual-vector form that ``scipy.optimize.least_squares`` expects.
    """
    # Derive the weight shape from the arguments instead of reading the
    # notebook-level original_output_weight_matrix, so the function does not
    # depend on hidden global state.
    weight_shape = (target.shape[0], hidden_activities.shape[0])
    output_activities = np.reshape(w, weight_shape).dot(hidden_activities)
    return target.flatten() - output_activities.flatten()

In [24]:
# Fit the output weights by bounded least squares, starting from the initial
# output weights. The scalar bounds confine every weight to [-1, 1].
result = least_squares(
    least_sq_output,
    original_output_weight_matrix.flatten(),
    bounds=(-1., 1.),
    args=(hidden_activities, input_pattern_matrix),
)

In [25]:
# Display the result object returned by the least_squares call.
result

 active_mask: array([-1,  1,  0,  0,  1,  1,  0,  0,  0,  1,  0,  1,  0,  0, -1,  0,  0,
        0,  1,  0,  0,  0, -1,  0,  1])
        cost: 0.5696601228573192
         fun: array([ 0.34974758, -0.19357678,  0.07839942, -0.07768488,  0.15275858,
       -0.2516891 ,  0.26119754, -0.07439478,  0.13834507, -0.00554463,
        0.17274983, -0.19028067,  0.31040156, -0.25131515, -0.09140326,
       -0.15793701,  0.16572756, -0.25428257,  0.28651414,  0.15246553,
        0.26237893,  0.10483541, -0.0887385 ,  0.06816802,  0.48893745])
        grad: array([ 6.05145213e-02, -4.75849940e-02, -2.58074842e-09, -5.68874729e-10,
       -4.55885220e-02, -9.49696340e-02,  2.06518830e-11,  9.64253000e-10,
        6.92850777e-11, -3.97923762e-03,  8.18449954e-10, -5.44215170e-02,
        7.73650657e-10, -6.42170234e-10,  1.39797133e-02, -1.23561820e-09,
       -5.67909081e-10,  1.23528968e-09, -6.22438593e-02, -3.15869622e-10,
       -2.19560919e-09,  3.64167869e-09,  1.24666640e-01, -8.94392757e-03,

In [27]:
# Fold the flat solution vector back into the shape of the output weight
# matrix and compute the read-out it produces for every input pattern.
least_sq_output_weights = np.reshape(np.array(result.x), output_weight_matrix.shape)
least_sq_output_activities = np.dot(least_sq_output_weights, hidden_activities)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(7, 2.5))
panels = (
    (least_sq_output_weights, 'Output weights', 'Hidden units'),
    (least_sq_output_activities, 'Output activities', 'Input patterns'),
)
for panel_ax, (panel_values, panel_title, panel_xlabel) in zip(axes, panels):
    image = panel_ax.imshow(panel_values)
    fig.colorbar(image, ax=panel_ax)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylabel('Output units')
    panel_ax.set_xlabel(panel_xlabel)
fig.suptitle('Least squares')
fig.tight_layout()

<IPython.core.display.Javascript object>

In [28]:
# For each column (input pattern), the row index of the most active output unit.
least_sq_output_activities.argmax(axis=0)

array([0, 1, 2, 3, 4])

In [29]:
# Sum of squared least-squares weights per output unit (one value per row).
print((least_sq_output_weights ** 2.).sum(axis=1))

[4.40885776 2.82982931 2.98475996 1.93532787 3.09560766]
