In [1]:
import numpy as np

In [2]:
# Shapes of the two weight matrices that will later live inside one flat vector.
W1shape, W2shape = (5, 2), (2, 10)

In [3]:
# Independent arrays filled with consecutive integers: 0..9 for W1, 10..29 for W2.
W1 = np.arange(10).reshape(W1shape)
W2 = np.arange(10, 30).reshape(W2shape)

In [4]:
(W1, W2)

(array([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]]),
 array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]]))

In [5]:
# flatten() returns a new 1-D array (a copy), in row-major order.
W1.flatten(order="C")

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [6]:
# A list of two separate flattened copies -- not yet a single vector.
Ws = [W.flatten() for W in (W1, W2)]
Ws

In [7]:
# Join both matrices into one 1-D vector. np.concatenate (like np.hstack on
# 1-D inputs) allocates a new array, so Ws does not share memory with W1 or W2.
Ws = np.concatenate([W1.ravel(), W2.ravel()])
Ws

In [8]:
# First element of the concatenated vector (originally W1[0, 0]).
Ws[0]

In [9]:
# Overwrite the first element of the concatenated vector.
Ws[:1] = 1111
Ws

In [10]:
# W1 is unaffected: Ws was built from copies of W1 and W2, not views of them.
W1

array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]])

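Nope, that didn't work. `flatten()` and `np.hstack` return new arrays (copies), so changing `Ws` does not touch `W1` or `W2`. We must first construct `Ws`, then create `W1` and `W2` as views into `Ws`.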

In [11]:
# Total number of scalar weights across both matrices (rows * cols each).
n_weights = sum(rows * cols for rows, cols in (W1shape, W2shape))
n_weights

30

In [12]:
# Allocate the single flat vector that will back every weight matrix.
# Size it from n_weights (computed from the shapes) instead of hard-coding 30,
# so it stays consistent if W1shape or W2shape change.
Ws = np.zeros(n_weights)
Ws

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Now, create our weight matrices as views into this vector.

In [13]:
# Carve consecutive, non-overlapping slices out of the flat vector Ws and
# reshape each one into a weight matrix. Basic slicing of a contiguous 1-D
# array, followed by reshape, yields views, so W1 and W2 share memory with Ws.
# A loop over the shapes replaces copy-pasted per-matrix bookkeeping, so adding
# another matrix only means adding its shape to the tuple.
first_index = 0
views = []
for shape in (W1shape, W2shape):
    last_index = first_index + int(np.prod(shape))
    views.append(Ws[first_index:last_index].reshape(shape))
    first_index = last_index
W1, W2 = views

# Guard the point of this cell: fail loudly if any matrix is a copy rather
# than a view, or if the shapes do not use exactly every element of Ws.
assert all(np.shares_memory(W, Ws) for W in (W1, W2)), "weight matrix is not a view into Ws"
assert first_index == Ws.size, "weight shapes do not account for every element of Ws"

W1.shape, W2.shape

((5, 2), (2, 10))

In [14]:
(W1, W2)

(array([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]]),
 array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

In [15]:
# The flat vector itself; it is still all zeros at this point.
Ws

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Previously we initialized `W1` and `W2` like this:
```
W1 = np.arange(0, 10).reshape(5, 2)
W2 = np.arange(10, 30).reshape(2, 10)
```
Now we must assign into the existing arrays `W1` and `W2` (which are views into `Ws`) rather than create new arrays. Plain `W1 = ...` would rebind the name to a new array and break its link to `Ws`.

In [16]:
# Write into the existing arrays (which are views into Ws) instead of rebinding
# the names W1 and W2 to brand-new arrays.
W1[...] = np.arange(10).reshape(W1shape)
W2[...] = np.arange(10, 30).reshape(W2shape)
(W1, W2)

(array([[0., 1.],
        [2., 3.],
        [4., 5.],
        [6., 7.],
        [8., 9.]]),
 array([[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]]))

In [17]:
# The in-place assignments to W1 and W2 wrote through to the shared vector.
Ws

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29.])

In [18]:
# Modify the flat vector directly; the change should show up in W1.
Ws[:1] = 1111
Ws

array([1.111e+03, 1.000e+00, 2.000e+00, 3.000e+00, 4.000e+00, 5.000e+00,
       6.000e+00, 7.000e+00, 8.000e+00, 9.000e+00, 1.000e+01, 1.100e+01,
       1.200e+01, 1.300e+01, 1.400e+01, 1.500e+01, 1.600e+01, 1.700e+01,
       1.800e+01, 1.900e+01, 2.000e+01, 2.100e+01, 2.200e+01, 2.300e+01,
       2.400e+01, 2.500e+01, 2.600e+01, 2.700e+01, 2.800e+01, 2.900e+01])

In [19]:
# W1[0, 0] now reflects the write to Ws[0], since W1 is a view into Ws.
W1

array([[1.111e+03, 1.000e+00],
       [2.000e+00, 3.000e+00],
       [4.000e+00, 5.000e+00],
       [6.000e+00, 7.000e+00],
       [8.000e+00, 9.000e+00]])

In [20]:
# Modify a matrix entry (the first element of W2 in row-major order, i.e. W2[0, 0]).
W2.flat[0] = -999
W2

array([[-999.,   11.,   12.,   13.,   14.,   15.,   16.,   17.,   18.,
          19.],
       [  20.,   21.,   22.,   23.,   24.,   25.,   26.,   27.,   28.,
          29.]])

In [21]:
# Ws[10], the element backing W2[0, 0], now holds -999.
Ws

array([ 1.111e+03,  1.000e+00,  2.000e+00,  3.000e+00,  4.000e+00,
        5.000e+00,  6.000e+00,  7.000e+00,  8.000e+00,  9.000e+00,
       -9.990e+02,  1.100e+01,  1.200e+01,  1.300e+01,  1.400e+01,
        1.500e+01,  1.600e+01,  1.700e+01,  1.800e+01,  1.900e+01,
        2.000e+01,  2.100e+01,  2.200e+01,  2.300e+01,  2.400e+01,
        2.500e+01,  2.600e+01,  2.700e+01,  2.800e+01,  2.900e+01])

In [22]:
Ws.shape[0]

30

If we want to update all of the weights, for example with an optimization algorithm, remember to use `Ws[:] = ` rather than `Ws = `. Plain `Ws = ` binds the name to a new array, leaving `W1` and `W2` as views of the old one.

In [23]:
# Update every weight at once by assigning into the existing buffer (Ws[:] = ...),
# so W1 and W2, which are views into Ws, see the new values.
# Derive the length from Ws instead of hard-coding 30, so this still works
# if the weight shapes change.
Ws[:] = np.arange(Ws.size) * 100
Ws

array([   0.,  100.,  200.,  300.,  400.,  500.,  600.,  700.,  800.,
        900., 1000., 1100., 1200., 1300., 1400., 1500., 1600., 1700.,
       1800., 1900., 2000., 2100., 2200., 2300., 2400., 2500., 2600.,
       2700., 2800., 2900.])

In [24]:
(W1, W2)

(array([[  0., 100.],
        [200., 300.],
        [400., 500.],
        [600., 700.],
        [800., 900.]]),
 array([[1000., 1100., 1200., 1300., 1400., 1500., 1600., 1700., 1800.,
         1900.],
        [2000., 2100., 2200., 2300., 2400., 2500., 2600., 2700., 2800.,
         2900.]]))