Extract weight matrix #91

Closed
BrianMiner opened this issue Apr 29, 2015 · 31 comments

@BrianMiner

Is it possible to extract the weight matrix from a network?

@fchollet
Member

Sure. The method model.save_weights() will do it for you and store the weights in an HDF5 file.

If you want to do it manually, you'd do something like:

for layer in model.layers:
    weights = layer.get_weights() # list of numpy arrays
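A slightly fuller version of the loop above collects the shape of every array each layer returns, keyed by layer name. This is a minimal sketch that assumes a modern Keras, where each layer has a .name attribute and get_weights() returns a list of NumPy arrays. The helper weight_shapes and the FakeLayer stand-ins are hypothetical and exist only so the example runs without Keras installed.

```python
import numpy as np


def weight_shapes(layers):
    """Map each layer's name to the shapes of the arrays from its get_weights().

    Only `.name` and `.get_weights()` are used, so any object exposing them
    works. Layers without weights map to an empty list.
    """
    return {layer.name: [w.shape for w in layer.get_weights()] for layer in layers}


# Hypothetical stand-ins for Keras layers, so the sketch runs without Keras.
class FakeLayer:
    def __init__(self, name, arrays):
        self.name = name
        self._arrays = arrays

    def get_weights(self):
        return self._arrays


layers = [
    FakeLayer("dense", [np.zeros((94, 10)), np.zeros(10)]),  # kernel, bias
    FakeLayer("dropout", []),                                 # no weights
    FakeLayer("dense_1", [np.zeros((10, 9)), np.zeros(9)]),
]
print(weight_shapes(layers))
# {'dense': [(94, 10), (10,)], 'dropout': [], 'dense_1': [(10, 9), (9,)]}
```

With a real model, the call would be weight_shapes(model.layers).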

@BrianMiner
Author

Thanks! Are the biases included here?

@BrianMiner
Author

I can't quite figure out the format of the weights returned. Even with a simple model with 0 hidden layers, I get back a large number of weights. I also see that activation parameters have weights associated with them.
Is this format documented, or is going through the code the only way? Are biases included here too?

I am running this simple model on the Otto dataset:

print("Building model...")

model = Sequential()
model.add(Dense(dims, 10, init='glorot_uniform'))
model.add(PReLU((10,)))
model.add(BatchNormalization((10,)))
model.add(Dropout(0.5))

model.add(Dense(10, nb_classes, init='glorot_uniform'))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer="adam")

print("Training model...")

model.fit(train_X, train_y, nb_epoch=1, batch_size=250, validation_sp

And then this:

for layer in model.layers:
    g = layer.get_config()
    h = layer.get_weights()
    print(g)
    print(h)

And the output is:

{'output_dim': 10, 'init': 'glorot_uniform', 'activation': 'linear',
'name': 'Dense', 'input_dim': 94}
[array([[ -7.99849009e-02, -2.45893497e-02, 4.69488823e-02,
6.59520472e-04, 5.47647455e-02, -5.37778421e-02,
-6.25924897e-02, -8.44473401e-02, -8.43347936e-02,
-1.41206931e-02],
[ -7.28297196e-03, -6.94909196e-02, 5.09896591e-02,
1.07518776e-01, -1.40819659e-01, 1.28122287e-03,
-7.18242022e-02, -7.47769234e-02, -5.61066325e-04,
5.22953511e-02],
[ 4.48867873e-02, -5.98358224e-02, -3.34403798e-02,
-6.98303095e-02, 1.42869021e-02, 9.69443043e-03,
-5.72198990e-02, 1.26906643e-01, 1.44663706e-01,
-1.05488360e-01],
[ -4.94764224e-03, -6.72982314e-02, 1.98947019e-02,
-2.34841952e-04, 8.15450053e-02, 9.93956783e-02,
5.33930437e-02, 9.88775421e-02, 2.62561318e-03,
-1.37490951e-01],
[ -9.69757727e-02, -1.93766233e-02, -2.67758527e-02,
1.24528297e-03, -4.38467242e-02, -3.73318840e-02,
-3.62683822e-02, 3.58399029e-02, -2.16549907e-03,
6.28375524e-02],
[ 2.30916925e-02, -5.03360711e-02, -3.48193628e-03,
-1.11732154e-02, -2.15673704e-02, -5.69503631e-02,
9.87205698e-03, -2.12394581e-02, 3.23201450e-02,
8.15617402e-02],
[ -1.22399136e-01, -1.53275598e-03, -2.49756349e-02,
2.55169260e-02, 6.42912644e-02, -8.35097613e-02,
-1.16564598e-01, 5.86628949e-02, -3.66701814e-02,
7.93936586e-02],
[ -8.18904005e-03, 9.61438433e-02, 1.05611869e-01,
-5.29767001e-02, 1.30970726e-01, -1.05098848e-01,
-5.67985073e-02, 8.96001608e-02, -4.06112779e-03,
7.38853071e-02],
[ -1.10601817e-01, 4.01777134e-02, -1.04958505e-01,
-1.86499188e-01, -2.37804665e-02, -3.76688537e-02,
-1.45232209e-01, -6.36418185e-02, -7.63368207e-02,
-1.17734138e-01],
[ -2.40944892e-02, -1.99559485e-02, -3.10534144e-02,
2.51051600e-02, -3.66193065e-02, -6.94967520e-02,
4.28971762e-02, -4.68141541e-02, 2.08197228e-02,
4.14676310e-02],
[ 1.85548637e-01, -2.00046197e-01, -1.74210642e-01,
9.19546529e-02, 2.09622917e-01, 1.66180606e-01,
2.18369084e-01, 1.06692421e-01, 2.23854503e-02,
-1.63984764e-01],
[ -2.25432936e-02, 1.21225411e-02, -1.77713444e-02,
2.49248395e-02, -9.09446976e-02, -2.12943360e-02,
-5.79632667e-02, -7.98705389e-02, 1.45603023e-02,
-3.29101095e-02],
[ 1.06982420e-02, 8.10667692e-03, -7.40422225e-03,
8.73337501e-02, -4.70680094e-02, -6.96424801e-04,
-3.52869543e-02, -5.75422152e-02, -2.50050629e-03,
5.30206962e-02],
[ -1.05554822e-01, 1.05913849e-01, -1.99050840e-01,
-1.09049315e-01, -2.13796040e-01, -2.46177639e-01,
3.93859830e-02, -1.60071728e-01, -2.43849880e-01,
-9.01017306e-02],
[ -2.66245883e-02, 1.02689584e-01, -1.13140709e-01,
-2.50536214e-01, -1.36009340e-01, -9.52662533e-02,
1.40512889e-02, -2.11842101e-01, -2.26055642e-01,
-1.35109722e-01],
[ -7.53730907e-03, -1.64213430e-03, -6.63893471e-02,
-7.66609176e-02, -3.59994294e-02, -9.46264700e-02,
2.58486499e-02, 2.08334686e-03, -8.75103806e-02,
-1.10772615e-02],
[ 1.52794017e-02, 7.91648709e-03, 1.07050460e-01,
-3.69858265e-02, 5.07007745e-02, -5.97656110e-02,
-3.49382162e-02, 1.55669800e-01, 1.03023857e-02,
-4.40477760e-02],
[ 4.07815059e-03, 2.83262161e-02, -1.48807663e-03,
9.73448516e-02, 3.56439926e-03, -1.23538034e-02,
9.76411415e-02, 2.06518137e-02, 6.17017622e-02,
-8.53835248e-02],
[ -8.09946660e-02, 2.61380996e-02, 2.63300387e-02,
9.93539858e-02, -1.10327820e-01, 4.95747644e-02,
-3.89292545e-02, -1.14508714e-02, 3.19129217e-02,
1.23465986e-01],
[ 3.59885269e-03, 2.76080271e-03, -2.78764395e-02,
4.24239723e-02, 1.53515202e-03, -2.69810376e-02,
1.63534000e-02, 5.40332907e-02, 1.14591409e-01,
3.83767457e-02],
[ 2.84191996e-03, -6.19442984e-02, -8.03084884e-03,
-1.51138132e-02, 4.29634390e-02, 3.55585086e-02,
-3.89301471e-02, -3.44196220e-02, -1.92320194e-02,
-6.95947810e-02],
[ 8.89051092e-03, -3.63276576e-02, 3.77512802e-02,
-8.61697928e-02, 4.37051203e-02, 1.25958620e-01,
1.27073221e-02, 1.74475895e-02, 3.51703513e-02,
1.76466516e-02],
[ 7.71217281e-02, 1.81580258e-02, 2.90213940e-02,
4.13172177e-02, 8.94682418e-02, -5.88464358e-03,
-8.54934082e-02, -6.62388395e-03, 6.46094670e-03,
-1.88742827e-02],
[ -2.91908240e-02, -2.65753849e-02, 1.52980919e-02,
6.58449558e-02, -7.86922291e-02, 4.58466034e-02,
-3.39622131e-02, 8.56353043e-02, -3.02583216e-02,
1.06242259e-01],
[ -3.46668468e-02, 5.48445800e-02, -9.99637775e-02,
-1.40776997e-01, -2.33112086e-01, -1.12747490e-01,
-3.38637078e-03, -2.72000452e-01, -2.43135959e-01,
1.40148791e-02],
[ 4.55380939e-02, -2.05104710e-01, -1.23208875e-01,
7.37236719e-02, 3.40874932e-02, 1.38913292e-01,
2.39351600e-01, -1.87328804e-02, 7.01187834e-02,
-5.02784378e-02],
[ -1.56907260e-02, -6.93088723e-02, -1.32831942e-01,
1.40660959e-02, 7.01844406e-02, 6.88994173e-02,
1.09614443e-01, 5.21141313e-03, 2.66728291e-02,
-2.12535714e-01],
[ 8.63241877e-03, -2.00266780e-01, 4.21728732e-03,
-2.71931798e-04, -3.74458866e-02, -1.02733646e-02,
1.26404038e-02, 4.53453957e-02, -5.47819209e-03,
-1.78313855e-01],
[ -3.40579999e-02, 3.68030410e-02, 2.78233896e-02,
-6.87630905e-02, 5.79211738e-02, 3.53004862e-02,
2.97676136e-02, -2.83821290e-03, -6.19672378e-02,
4.38129833e-02],
[ -1.29424714e-01, -5.27272972e-02, 6.69243394e-02,
1.19757129e-01, 3.84862554e-02, 9.18853869e-02,
-4.82323177e-02, 1.87875149e-02, 4.63434479e-02,
5.18075847e-02],
[ -4.84346313e-02, -7.05440572e-03, -1.17486716e-01,
-3.78191092e-02, -1.63198220e-02, 6.35379808e-02,
-2.28866377e-02, -4.73959864e-02, 6.47443882e-02,
1.71767526e-02],
[ 7.92322710e-02, -3.99799449e-03, -2.68663861e-02,
2.16343925e-02, 7.08523118e-02, 9.40224531e-03,
-2.73172165e-02, 3.90645337e-02, 1.43386517e-02,
-3.28124923e-02],
[ -7.49563779e-02, 8.96276663e-03, -3.00325036e-02,
-9.26680367e-02, -1.68292320e-01, -1.46136493e-01,
-2.32867781e-02, -1.37576449e-01, -1.08664407e-01,
-3.03448062e-02],
[ -1.58937859e-01, 1.24359548e-01, 1.72723048e-01,
-7.16377564e-02, 8.15923267e-02, -9.26028128e-02,
-3.52089676e-02, -1.71921631e-01, 1.19313224e-01,
5.04153287e-02],
[ 5.87777373e-03, -9.82951252e-02, -1.44826277e-01,
2.87068172e-03, 5.77770075e-02, 3.83663771e-02,
-7.79279419e-02, 5.30456023e-02, -4.32744374e-02,
-8.80488631e-02],
[ -9.51507092e-02, 2.48270494e-02, 1.85502184e-01,
-6.15492334e-02, 1.18012058e-01, -3.35554541e-02,
1.92680772e-03, 1.25766050e-03, -1.33313640e-01,
3.44885089e-02],
[ -8.23068727e-02, 6.00759323e-03, 3.07414607e-02,
3.78444659e-02, 1.59902793e-02, 1.32360708e-02,
-5.58786321e-02, -8.98881639e-03, 9.11022060e-02,
1.27618163e-01],
[ -8.39977134e-02, -2.19334460e-02, 2.93163173e-02,
-9.67890487e-03, -1.45040381e-01, -1.11580661e-02,
-5.45275977e-02, 3.19171362e-02, 6.28167381e-02,
5.43802318e-02],
[ -3.19924245e-02, 9.04734505e-02, -8.88122521e-03,
-6.81262022e-02, 3.81438529e-02, -1.10441657e-01,
-2.75602345e-02, 4.45400922e-02, 1.65449719e-01,
-4.30487717e-03],
[ -8.83774422e-02, 7.58000898e-02, -1.84174130e-01,
-8.53399835e-02, -1.66059125e-01, -1.49298746e-01,
2.59884981e-02, -1.58198291e-01, -1.70953601e-01,
-2.42574980e-02],
[ 3.76959883e-02, 3.19914391e-02, 8.79207932e-02,
3.22329258e-02, -4.57678893e-03, 6.73169597e-02,
3.67542760e-03, -1.08364242e-02, -7.73627399e-03,
1.05920958e-01],
[ 2.22782407e-01, -1.80813558e-01, -5.38938983e-02,
-1.37104372e-01, 1.53482621e-01, 1.41474566e-01,
6.76793766e-02, 9.57258689e-02, 1.20902074e-01,
-8.28063868e-02],
[ 1.45137965e-02, 4.45537228e-02, -2.40643314e-01,
-5.28557398e-02, -9.96888936e-02, 3.65267589e-02,
3.64279623e-02, -8.18503052e-02, -2.02683058e-01,
-1.82493162e-01],
[ 8.98117186e-02, -3.25222791e-02, 4.78793257e-03,
-3.44254824e-02, -9.32265642e-02, -3.37535001e-02,
-2.59333102e-02, -3.14748988e-02, 1.50924621e-02,
-3.88573589e-02],
[ 9.16648589e-02, -9.07544921e-03, 1.15557421e-02,
-6.69898791e-02, 6.43994728e-02, -1.03488028e-01,
-8.25832880e-02, 1.04004400e-01, -1.32508566e-02,
-2.85507593e-02],
[ 1.18606710e-01, -1.62804455e-01, -8.94677872e-02,
9.15753083e-02, -2.46587921e-02, 7.59351306e-02,
6.87165226e-02, -4.79627016e-02, -3.69506298e-02,
-1.07193629e-01],
[ 1.17239054e-02, -2.67437157e-02, -7.68840391e-04,
8.10397970e-02, -1.99904474e-02, 5.51565844e-02,
5.58673400e-04, 3.17302125e-02, 1.11154388e-01,
1.07863925e-01],
[ -1.04641810e-01, 2.78747603e-02, -1.45182739e-03,
-1.53233877e-01, -6.29874225e-02, -1.29592620e-01,
-1.51859556e-01, -4.14495814e-02, -8.45453923e-02,
-1.15042618e-01],
[ -2.20509269e-02, -1.47940439e-01, -4.76452491e-02,
-2.20822576e-02, -2.04331960e-03, 4.55278328e-02,
7.08764808e-02, -1.73128630e-02, 1.93976654e-02,
-2.61378301e-02],
[ 7.04435092e-02, 2.91919488e-03, 4.20234866e-02,
-7.55152665e-02, -4.63803985e-02, -1.07999505e-01,
-8.16274725e-02, 4.19000215e-02, 7.89168322e-02,
-3.86468662e-04],
[ 5.83150997e-03, -1.60543542e-01, -1.42290587e-02,
-6.85734380e-03, 8.61183085e-03, -4.91538703e-02,
3.61739483e-03, 3.47957593e-02, -9.38960015e-03,
1.87941107e-02],
[ 3.86580568e-02, -6.17538775e-02, -3.52311159e-02,
-6.14824396e-03, 3.28743091e-02, 2.30341515e-02,
4.17623711e-02, 3.78719485e-02, 6.40326123e-02,
-1.25724117e-01],
[ -2.69466654e-02, 8.76149869e-02, 8.43603169e-02,
-2.43710820e-02, -7.43506410e-03, -5.92964338e-03,
1.19778728e-02, 1.49802507e-02, -1.24118188e-01,
1.19810388e-01],
[ 1.43340287e-02, -8.91459774e-02, 3.11272246e-02,
1.25600087e-01, 7.31098870e-02, -9.06883544e-03,
3.69710535e-03, -5.08471161e-02, 7.35172675e-02,
-1.54848409e-01],
[ 4.85676467e-02, -3.05382790e-02, 3.27869005e-02,
6.41000141e-02, 7.11917729e-02, 1.18288425e-01,
6.83290253e-03, 9.93292852e-02, -3.07955208e-03,
-1.28590112e-02],
[ -3.03334700e-02, 1.08860071e-02, 8.97625740e-02,
-4.83023573e-02, -3.12615198e-02, -8.61628871e-02,
-1.04541486e-02, 1.33751888e-01, 8.89201773e-02,
-1.06815586e-01],
[ 1.64302198e-01, -1.73583248e-01, -1.29705113e-01,
1.15130432e-01, 5.06143426e-02, 1.38332743e-01,
8.68229188e-02, 2.03039990e-02, 9.97073441e-03,
-1.84693280e-01],
[ 1.39708295e-02, -1.20358859e-01, 2.09749290e-02,
1.59206874e-01, -1.31967745e-01, 9.07331382e-04,
-2.98483952e-02, -2.11940084e-02, 5.57213581e-02,
1.17593488e-01],
[ 1.95239811e-02, 2.48040812e-02, 1.13217120e-01,
-1.32227109e-01, 4.16574397e-02, 4.98457202e-02,
-3.24562114e-02, -6.51531123e-02, -7.33154552e-02,
-3.98947626e-02],
[ 2.66432299e-01, -2.97074213e-01, -5.99969370e-02,
-5.87481183e-02, 1.60074011e-01, 1.30176286e-01,
2.77686796e-01, 1.31290317e-01, 1.22633535e-01,
-1.09230713e-01],
[ 2.00969073e-02, -1.84964992e-01, -5.83985571e-03,
-7.41362658e-02, 1.38365878e-02, 1.06839467e-01,
-2.74847583e-02, -1.92903743e-02, -5.20260714e-03,
2.50301918e-04],
[ -6.17150407e-02, 8.06537180e-02, 7.02343393e-02,
-5.42272016e-02, 9.97773977e-03, -1.98928364e-01,
-2.17237932e-01, -6.73968443e-03, -1.97736271e-01,
5.00901130e-02],
[ -1.66149509e-02, -2.62200709e-03, -2.08381446e-02,
2.57409840e-02, -2.42515149e-02, -3.21673100e-02,
1.49380691e-02, -2.72776353e-02, -8.78237320e-02,
1.68875553e-02],
[ -5.57623273e-02, 3.65157464e-02, -1.13528556e-01,
9.79066670e-03, -5.63900992e-02, -3.71015632e-02,
-1.53407485e-04, -1.72846924e-01, -1.09153670e-01,
-3.32763024e-02],
[ 5.97840241e-03, 4.14377038e-02, 4.48701934e-03,
-6.99469980e-02, 4.48644760e-02, 1.05679300e-01,
-6.51186109e-02, 9.74034880e-03, 2.48868577e-02,
-3.36158742e-02],
[ 1.20077827e-02, -2.03385970e-02, -2.75987445e-03,
-1.16149630e-02, 1.06375004e-01, 5.39069232e-02,
-8.91773269e-02, 2.75564667e-03, 4.46157051e-02,
5.58006479e-02],
[ -1.16916048e-02, 4.60958859e-02, -2.17421025e-02,
-1.86881864e-01, 4.66562070e-02, 4.02356717e-02,
-1.15677821e-01, 4.66245330e-02, -3.02338327e-02,
1.65767924e-03],
[ -6.92209031e-02, -5.00074868e-02, 3.97814441e-02,
2.01908782e-01, -1.24789202e-01, 1.66400945e-01,
-1.04113481e-01, 2.69553450e-02, 4.85165628e-02,
1.82033767e-01],
[ -1.28536858e-01, -3.02310722e-03, 8.69889310e-02,
8.70096528e-02, -4.53932410e-02, 1.66790953e-01,
-1.28925981e-01, -1.83845627e-02, 2.31594216e-02,
1.55204566e-01],
[ -6.96121049e-02, 5.98652894e-02, -4.84150225e-03,
1.57034354e-02, 3.75668782e-02, -1.64650619e-02,
-1.70003480e-02, 1.68722860e-03, -1.58515948e-02,
-4.94031462e-02],
[ -6.33468947e-02, -8.75024885e-03, 5.29224644e-03,
1.02262168e-01, -1.58498505e-02, 5.67770925e-02,
-5.12286565e-02, 7.78941169e-02, 9.62220805e-02,
1.12947449e-01],
[ -1.71659989e-01, 1.16973947e-01, -1.05326386e-01,
-1.66891420e-01, -1.01974213e-01, -6.23915854e-02,
2.04867068e-02, -2.09272962e-01, -4.98385084e-02,
6.39348977e-02],
[ -2.33088886e-02, 4.77167627e-04, 5.12577177e-03,
8.72510224e-02, -6.98289790e-02, -3.83786347e-02,
-1.61912813e-01, 9.99245590e-03, 6.40901957e-03,
1.06622519e-03],
[ -3.22696147e-03, 7.50338364e-04, -4.13295523e-02,
1.11532490e-02, 2.13605470e-03, 7.01725919e-02,
6.56769972e-02, -1.22168268e-02, -3.98609596e-02,
4.02942635e-02],
[ -6.30274025e-02, -8.49640888e-02, 1.38809509e-02,
2.04496514e-01, -1.47865061e-01, 1.14922183e-01,
-1.48793624e-01, 5.45763123e-02, 4.44691877e-02,
1.79254325e-01],
[ -3.70425750e-02, -2.35610560e-02, -4.93268816e-03,
1.31040411e-02, 2.73776025e-02, 1.15118941e-01,
-1.38940692e-01, 3.52187024e-02, 1.86298744e-02,
1.01562678e-01],
[ 1.76901211e-03, 7.25849291e-02, 4.30484675e-02,
-5.29386234e-02, 8.85783359e-02, -2.10795505e-03,
-1.75977381e-02, -6.26338776e-02, 1.10745137e-02,
-3.01402319e-02],
[ -6.51010148e-03, 4.46026348e-03, 3.29594028e-02,
7.42153750e-02, 4.51823405e-02, 2.38861625e-02,
-1.51084355e-01, 2.58145276e-02, 9.78499330e-02,
1.13541447e-02],
[ -1.70992478e-02, -4.46845213e-02, 6.57091110e-02,
-7.22719951e-02, 1.38301387e-01, 3.39854928e-02,
-2.71362382e-03, 5.89179971e-02, 3.98174930e-02,
-2.14658967e-02],
[ 1.47533650e-01, -1.04364285e-01, -8.94515699e-02,
6.64953902e-02, 2.39417285e-02, -2.14146403e-02,
6.62148791e-02, -1.08153465e-02, 5.16891148e-02,
-7.39043785e-02],
[ 2.40624194e-03, 5.38938896e-02, 3.42756790e-02,
4.45555846e-02, 1.21189843e-01, -4.61384428e-02,
4.66344444e-02, 2.32483197e-02, 1.73072110e-02,
-5.31442668e-02],
[ 4.43278059e-02, -1.25351815e-01, 3.30721381e-03,
-4.40161903e-02, 4.02718580e-02, -2.78869336e-02,
5.86890902e-02, 7.31859328e-02, 2.70040392e-02,
1.43055350e-02],
[ -7.64451720e-02, 2.79163924e-02, -4.95644115e-02,
1.40135399e-04, -3.44646526e-02, -5.21001244e-02,
-1.36563861e-02, -3.01335168e-02, -2.14250554e-02,
3.37404738e-02],
[ 5.63956295e-03, 1.31333839e-02, 1.86413210e-02,
3.75087111e-02, 1.14190531e-01, -1.23009164e-02,
5.04972589e-02, -1.08671571e-02, -1.72429908e-02,
6.35628664e-03],
[ -7.81176131e-02, 1.22820378e-02, -4.88642529e-02,
-1.68663757e-02, -4.93644076e-02, -8.76378246e-03,
6.76276358e-02, -2.72725790e-02, -1.12157596e-01,
6.67786778e-02],
[ -1.35052643e-01, 6.17655018e-02, -9.91510117e-02,
-1.36961912e-01, -9.93016461e-02, -2.23026506e-01,
5.43391763e-02, -7.92043386e-02, -4.64242573e-02,
-1.31950335e-01],
[ 9.02050417e-02, 4.89889014e-02, -3.76353176e-02,
-8.21240609e-02, 5.91559984e-02, -2.88601734e-02,
-2.75170553e-02, 4.00444237e-02, 1.02223471e-01,
-8.85973866e-03],
[ -2.79819946e-02, 1.57770779e-02, -1.74634653e-01,
-1.50595889e-01, -1.98462406e-01, -1.58971182e-01,
5.48719855e-02, -1.66760843e-01, -2.75554553e-01,
2.64535338e-02],
[ -8.80281079e-02, -1.79294346e-02, 7.08099011e-02,
1.95550220e-02, -9.05946128e-02, 5.07943950e-02,
-2.29652551e-02, -4.74400823e-02, -2.48743014e-02,
-3.02452203e-02],
[ -1.41639591e-01, -1.14204327e-02, 6.37410597e-03,
3.71080625e-02, -2.52950025e-01, 9.53980112e-02,
-1.04131975e-01, 1.33543914e-02, 5.09378104e-02,
1.09227010e-01],
[ -1.41634459e-02, 7.14195268e-02, 6.84697215e-02,
-9.43577039e-02, 5.96288760e-03, -1.19336735e-01,
-3.07758752e-02, -5.98731421e-02, -9.55384648e-03,
3.84556520e-04],
[ -2.50815619e-02, 4.53682164e-02, 5.46400441e-02,
-2.23102338e-02, 1.17296376e-01, 5.90221088e-02,
-1.21542180e-02, 8.26175342e-03, 6.07650355e-02,
-3.03832471e-02],
[ 9.42653504e-02, -7.87588973e-02, -6.33857297e-02,
5.43245905e-02, 2.31799575e-02, 1.10878154e-01,
4.02822240e-02, 1.89821951e-02, 6.25729673e-02,
-4.26672129e-02],
[ -2.05037699e-02, 1.29917332e-01, -1.01022859e-01,
-9.12096671e-02, 9.67299134e-02, 4.83183980e-02,
-2.53932230e-02, -5.23461370e-02, 9.16840297e-02,
5.16286812e-02]]), array([ 0.13384319, 0.21280285,
0.12673509, 0.0481298 , 0.15588539,
0.16365858, 0.16476655, 0.21108415, 0.17096487, 0.1085302 ])]
{'name': 'PReLU', 'input_shape': (10,)}
[array([ 0.0933486 , 0.21287123, 0.00783342, 0.01653181, 0.17309322,
0.16339809, 0.17657923, 0.16427643, 0.15932031, 0.1788937 ])]
{'epsilon': 1e-06, 'mode': 0, 'name': 'BatchNormalization',
'input_shape': (10,)}
[array([-0.24547945, -0.21909049, 0.24213387, 0.2301027 , -0.2308098 ,
-0.26039818, 0.25852819, -0.2666662 , -0.23721958,
-0.21551781]), array([-0.18698727, 0.18234116, -0.19378199,
0.14785786, 0.20994197,
-0.17636872, -0.0286329 , 0.19325841, 0.15674127, 0.18424748])]
{'p': 0.5, 'name': 'Dropout'}
[]
{'output_dim': 9, 'init': 'glorot_uniform', 'activation': 'linear',
'name': 'Dense', 'input_dim': 10}
[array([[ 0.09531779, 0.08962424, 0.2058567 , -0.01249131, 0.30507946,
-0.53983095, -0.06707083, 0.19728828, 0.20076663],
[-0.32254363, -0.42794729, -0.25050942, -0.33067779, -0.41560606,
0.44383431, -0.41416482, 0.06048449, -0.41940528],
[ 0.44522338, -0.32401378, -0.47518886, 0.18116826, 0.37900766,
-0.25756208, 0.19068314, 0.20710987, 0.54537 ],
[-0.31705924, -0.24385525, -0.42045825, -0.06039081, -0.23035229,
0.13377106, -0.34439084, 0.50066419, -0.29471258],
[-0.46195437, 0.32470544, 0.23317512, -0.03074575, -0.03195262,
-0.12666796, -0.0689242 , 0.29082389, -0.03790227],
[ 0.08793353, 0.23260939, -0.02797571, 0.08840966, 0.26917765,
-0.42360304, 0.53772368, -0.30116916, 0.14197116],
[-0.02250078, -0.25081609, -0.09740926, 0.40152244, -0.22063047,
0.26022703, -0.34145828, -0.53870416, -0.48703162],
[-0.45320813, 0.50861172, 0.19536802, 0.07115475, 0.16693423,
-0.20375885, -0.26160637, -0.09466415, 0.0106583 ],
[-0.03785352, 0.45134675, 0.46924738, 0.003416 , -0.3582257 ,
-0.38137255, -0.47833458, -0.37664575, 0.31648622],
[-0.20752858, 0.18646851, 0.23139626, -0.32362606, -0.12266297,
0.43289664, -0.07509048, -0.52767154, -0.15752394]]),
array([-0.16717777, 0.16997226, 0.05989215, -0.16888715, -0.16850433,
0.16732873, -0.16105396, 0.09626825, -0.12726908])]
{'activation': 'softmax', 'name': 'Activation'}
[]

@fchollet
Member

PReLU and BatchNormalization do, of course, have weights. They are learnable layers.
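To see where each printed array comes from, the parameter counts can be derived from the shapes in the output above. A Dense layer from n_in to n_out holds an (n_in, n_out) weight matrix plus n_out biases. PReLU holds one learnable slope per unit. BatchNormalization in this version returned two length-10 arrays. The sketch below only does that arithmetic. The shapes are read off the printed output, and the helper name dense_params is hypothetical.

```python
def dense_params(n_in, n_out):
    """Weight matrix (n_in x n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


# Shapes read off the printed get_weights() output above (94 features, 9 classes).
params = {
    "Dense(94 -> 10)": dense_params(94, 10),
    "PReLU(10)": 10,                    # one learnable slope per unit
    "BatchNormalization(10)": 2 * 10,   # two length-10 arrays in the printout
    "Dropout": 0,                       # no learnable weights
    "Dense(10 -> 9)": dense_params(10, 9),
    "Activation(softmax)": 0,           # no learnable weights
}

for name, n in params.items():
    print(f"{name:>24}: {n}")
print("total:", sum(params.values()))  # 950 + 10 + 20 + 0 + 99 + 0 = 1079
```

Most of the "large number of weights" is the first Dense layer's 94 x 10 matrix.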

Also, the model above has one "hidden layer" of size 10 (but reasoning in terms of hidden layers can be confusing; it's better to think in terms of operations). You are doing one projection from the initial space to a space of size 10, then a second projection from that space to a space of size nb_classes.
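The two projections can be written out with NumPy. This sketch uses random weights with the shapes printed above (94 to 10 to 9). For brevity it leaves out PReLU, BatchNormalization, and Dropout. It assumes the Dense layout shown in the output, where the weight matrix has shape (input_dim, output_dim) and is applied as x @ W + b.

```python
import numpy as np

rng = np.random.default_rng(0)

# Shapes as printed above: Dense(94 -> 10), then Dense(10 -> 9).
W1, b1 = rng.normal(size=(94, 10)), np.zeros(10)
W2, b2 = rng.normal(size=(10, 9)), np.zeros(9)


def softmax(z):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


x = rng.normal(size=(5, 94))   # a batch of 5 samples with 94 features
h = x @ W1 + b1                # first projection: 94-d space -> 10-d space
probs = softmax(h @ W2 + b2)   # second projection: 10-d space -> 9 classes

print(h.shape, probs.shape)    # (5, 10) (5, 9)
```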

@mthrok
Contributor

mthrok commented Apr 29, 2015

I use the following function to print out the contents of a weight file dumped by Keras:

from __future__ import print_function

import h5py

def print_structure(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    f = h5py.File(weight_file_path, "r")  # open read-only so the file is never modified
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))

        if len(f.items())==0:
            return 

        for layer, g in f.items():
            print("  {}".format(layer))
            print("    Attributes:")
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

            print("    Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                print("      {}: {}".format(p_name, param.shape))
    finally:
        f.close()

The output looks something like this (it comes from my model, not yours):

  layer_0
    Attributes:
      nb_params: 2
      subsample: [1 1]
      init: glorot_uniform
      nb_filter: 32
      name: Convolution2D
      activation: linear
      border_mode: full
      nb_col: 3
      stack_size: 3
      nb_row: 3
    Dataset:
      param_0: (32, 3, 3, 3)
      param_1: (32,)
  layer_1
    Attributes:
      nb_params: 0
      activation: relu
      name: Activation
    Dataset:
  layer_2
    Attributes:
      nb_params: 2
      subsample: [1 1]
      init: glorot_uniform
      nb_filter: 32
      name: Convolution2D
      activation: linear
      border_mode: valid
      nb_col: 3
      stack_size: 32
      nb_row: 3
    Dataset:
      param_0: (32, 32, 3, 3)
      param_1: (32,)
  layer_3
    Attributes:
      nb_params: 0
      activation: relu
      name: Activation
    Dataset:
  layer_4
    Attributes:
      nb_params: 0
      name: MaxPooling2D
      ignore_border: True
      poolsize: [2 2]
    Dataset:

So I can tell that layer_0 is a Convolution2D layer. Its weights are stored in the 'param_0' dataset with shape (32, 3, 3, 3), which means 32 filters, each spanning 3 channels, 3 pixels in height and 3 pixels in width. Its bias is stored in 'param_1' with shape (32,), one value per filter.
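As a quick check of that reading, the arithmetic can be reproduced with NumPy arrays of the same shapes. This sketch assumes the (nb_filter, stack_size, nb_row, nb_col) kernel layout shown above, which is the older Theano-style ordering. Newer TensorFlow-backed Keras stores Conv2D kernels as (rows, cols, in_channels, filters) instead, so the indexing would differ there.

```python
import numpy as np

# Shapes from the print_structure output above (older Keras, Theano-style ordering).
param_0 = np.zeros((32, 3, 3, 3))  # (nb_filter, stack_size, nb_row, nb_col)
param_1 = np.zeros(32)             # one bias per filter

nb_filter, stack_size, nb_row, nb_col = param_0.shape
first_filter = param_0[0]          # one 3x3 kernel per input channel
print(first_filter.shape)          # (3, 3, 3)

# Parameter count = every kernel entry plus one bias per filter.
n_params = param_0.size + param_1.size
print(n_params)                    # 32*3*3*3 + 32 = 896
```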

To access them, use model.layers[0].params[0] for the weights and model.layers[0].params[1] for the bias.

@BrianMiner
Author

@mthrok thank you - this is really helpful!

@xuxy09

xuxy09 commented Jun 9, 2016

@mthrok It's extremely helpful. Thank you very much!

@vinayakumarr

Instead of storing weights, I want to store features. How can I do this?

@ronzillia

@mthrok I tried your function, and it reports:
print(" {}: {}".format(p_name, param.shape))

AttributeError: 'Group' object has no attribute 'shape'

However, when I run print(" {}: {}".format(p_name, param.shape)) on its own, it works. Do you have any idea why?

@Faur

Faur commented May 1, 2017

Has anything changed with Keras 2.x? It isn't working for me.

@extrospective

Changing param.shape to param on the line that failed seems to eliminate the AttributeError ronzillia mentions.

@jewes

jewes commented Oct 20, 2017

Some modifications to mthrok's answer to solve the issue "AttributeError: 'Group' object has no attribute 'shape'":

from __future__ import print_function

import h5py

def print_structure(weight_file_path):
    """
    Prints out the structure of HDF5 file.

    Args:
      weight_file_path (str) : Path to the file to analyze
    """
    f = h5py.File(weight_file_path, "r")  # open read-only so the file is never modified
    try:
        if len(f.attrs.items()):
            print("{} contains: ".format(weight_file_path))
            print("Root attributes:")
        for key, value in f.attrs.items():
            print("  {}: {}".format(key, value))

        if len(f.items())==0:
            return 

        for layer, g in f.items():
            print("  {}".format(layer))
            print("    Attributes:")
            for key, value in g.attrs.items():
                print("      {}: {}".format(key, value))

            print("    Dataset:")
            for p_name in g.keys():
                param = g[p_name]
                for k_name in param.keys():
                    print("      {}/{}: {}".format(p_name, k_name, param.get(k_name)[:]))
    finally:
        f.close()

It then prints something like this:

Root attributes:
  layer_names: ['dense_2']
  backend: tensorflow
  keras_version: 2.0.8
  dense_2
    Attributes:
      weight_names: ['dense_2/kernel:0' 'dense_2/bias:0']
    Dataset:
      dense_2/bias:0: [ 2.00016475]
      dense_2/kernel:0: [[ 2.99988198]]

@midsummer123

I'm currently working with a tied-weight autoencoder, which requires extracting the weight matrix from a previous convolutional layer. However, some of the code I tried was written for an older version and uses layer.W, which seems to return a matrix with different dimensions than the current layer.get_weights()[0]. Does anyone have any idea how to fix this?

@vbhavank

I'm trying to drop weights (zero out some weights) in individual layers while testing my ResNet50 model (trained for aerial scene classification), loaded from my model.h5 weight file in Keras. get_weights solves half the problem, but I'm not sure how to put the weights back after changing them.
Does anyone have an idea how to edit the weights of individual layers and then test the model in Keras?

@M00NSH0T

M00NSH0T commented Apr 14, 2018

Does anyone have something written that can put each layer into the columns of a spreadsheet? I can only see a small number of weights with this, and my network is pretty big. I don't see any zeroes in the printout, but I want to make sure my LeakyReLU layers aren't getting any (or at least many) weights zeroed out.

@2017develper

I have the weights of a model trained with MATLAB, in a file called weights.mat. I want to load these weights in Keras. How can I do that?
I think Keras can only load .h5 files, so I don't know how to load it.

@kristpapadopoulos

See my initial attempt to extract layer parameters into a CSV file. I needed to see the magnitude of the layer weights, and this was a way to view them.

https://github.com/kristpapadopoulos/keras_tools/blob/master/extract_parameters.py
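For the spreadsheet question above, one simple approach is to flatten each layer's arrays and write them as one CSV column per layer. This is a hypothetical sketch and not the linked extract_parameters.py. It takes a dict mapping layer names to lists of NumPy arrays. In modern Keras, such a dict could be built with {l.name: l.get_weights() for l in model.layers}. Layers without parameters are skipped, and shorter columns are padded with empty cells.

```python
import csv
import io
from itertools import zip_longest

import numpy as np


def weights_to_csv(weights_by_layer, stream):
    """Write one column per layer containing all of its parameters, flattened.

    `weights_by_layer` maps a layer name to a list of NumPy arrays (e.g. the
    kernel and bias from get_weights()). Layers with no parameters are skipped.
    """
    columns = {
        name: np.concatenate([a.ravel() for a in arrays]).tolist()
        for name, arrays in weights_by_layer.items()
        if arrays
    }
    writer = csv.writer(stream)
    writer.writerow(columns.keys())
    # zip_longest pads the shorter columns with empty cells.
    for row in zip_longest(*columns.values(), fillvalue=""):
        writer.writerow(row)


# Made-up arrays standing in for get_weights() output.
demo = {
    "dense": [np.array([[0.5, -1.0], [2.0, 0.0]]), np.array([0.1, 0.2])],
    "dropout": [],
    "dense_1": [np.array([[1.5], [-0.5]]), np.array([0.0])],
}
buf = io.StringIO()
weights_to_csv(demo, buf)
print(buf.getvalue())
```

Writing to open("weights.csv", "w", newline="") instead of a StringIO produces a file that opens directly in a spreadsheet.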

@fabrahman

@kristpapadopoulos I tried using your code and I am getting this error:

Parameter File: checkpoints/weights_multi_fb_20180513_2.best.hdf5
Extracting Model Parameters to CSV File...
Traceback (most recent call last):
File "print_weights.py", line 54, in
weights[layer].extend(param[k_name].value[:].flatten().tolist())
AttributeError: 'Group' object has no attribute 'value'

Do you have any idea?

@kristpapadopoulos

I made an update: if the group object has no parameters (i.e., the layer has no parameters), None is assigned to avoid this issue.

https://github.com/kristpapadopoulos/keras_tools/blob/master/extract_parameters.py

@deronnek

deronnek commented Aug 7, 2018

If you're just looking to print the weights, I would suggest using the h5dump utility:
https://support.hdfgroup.org/HDF5/doc/RM/Tools.html#Tools-Dump

@S601327412


Running jewes's modified print_structure function above, I get:

AttributeError: 'slice' object has no attribute 'encode'

@lukasmichel

I just encountered the same problem. I solved it by using model.save_weights("path") instead of model.save("path").

@Fordacre

@S601327412 It might be caused by a Layer wrapper. If you create a model with a Layer wrapper, there will be a nested group in your h5 structure, which isn't consistent with mthrok's code.
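A recursive walk avoids assuming a fixed nesting depth. In the sketch below, the helper iter_datasets is hypothetical. It treats anything with a .shape as a dataset and anything else with .items() as a group. That means it should work on an h5py.File opened read-only, and also on nested dicts of NumPy arrays, which the example uses. h5py's own Group.visititems is an alternative way to visit every object.

```python
import numpy as np


def iter_datasets(group, prefix=""):
    """Yield (path, shape) for every array-like leaf under a group, at any depth.

    Works on mapping-like containers whose leaves have a .shape attribute:
    h5py Files/Groups (datasets have .shape) or nested dicts of NumPy arrays.
    """
    for name, obj in group.items():
        path = f"{prefix}/{name}" if prefix else name
        if hasattr(obj, "shape"):   # a dataset / array: report it
            yield path, obj.shape
        else:                        # a (possibly nested) group: recurse
            yield from iter_datasets(obj, path)


# Nested layout like the Keras 2 file above: layer group -> weight group -> arrays.
demo = {
    "dense_2": {"dense_2": {"kernel:0": np.zeros((1, 1)), "bias:0": np.zeros(1)}},
    "dropout_1": {},                 # a layer with no weights
}
for path, shape in iter_datasets(demo):
    print(path, shape)
# dense_2/dense_2/kernel:0 (1, 1)
# dense_2/dense_2/bias:0 (1,)
```

With h5py, the equivalent call would be: with h5py.File(path, "r") as f, loop over iter_datasets(f).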

@mactul

mactul commented Sep 19, 2019

Hello,
Sorry for my English.
I am working on a small deep learning package for my calculator.
I have just implemented the predict function.
I would like to train my model with Keras and then pass the coefficients to my function.

My function takes as input a list of coefficients (w1, w2, w3, ..., wn).
How can I get a flat list of all the coefficients from a model without convolution?

Thank you very much for your reply.

@td2014
Contributor

td2014 commented Sep 19, 2019

Hi @mactul. Perhaps using something like the "get_weights()" function at the end of this code block might be what you are looking for:

from keras.models import Sequential
from keras.layers import Dense, Activation
# For a single-input model with 2 classes (binary classification):

model = Sequential()
model.add(Dense(8, activation='relu', input_dim=10))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Generate dummy data
import numpy as np
data = np.random.random((1000, 10))
labels = np.random.randint(2, size=(1000, 1))

# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)

# print out the weights (coefficients)
print(model.get_weights())

I hope this helps.
Thanks.
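model.get_weights() returns a list of NumPy arrays. For a stack of Dense layers, the order is kernel, bias, kernel, bias, and so on. A single flat list of coefficients can therefore be built by flattening each array in turn. In this minimal sketch, flatten_weights and unflatten_weights are hypothetical helpers. The inverse needs the original shapes, for example to pass an edited list back through model.set_weights().

```python
import numpy as np


def flatten_weights(weights):
    """Concatenate a list of arrays (as returned by get_weights()) into one flat list.

    Arrays are flattened in C (row-major) order, so for a Dense kernel of shape
    (n_in, n_out) all weights leaving input 0 come first, then input 1, etc.
    """
    shapes = [w.shape for w in weights]
    flat = np.concatenate([w.ravel() for w in weights]).tolist()
    return flat, shapes


def unflatten_weights(flat, shapes):
    """Inverse of flatten_weights: rebuild the list of arrays from a flat list."""
    arrays, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        arrays.append(np.array(flat[start:start + size]).reshape(shape))
        start += size
    return arrays


# Shapes matching the model above: Dense(8, input_dim=10), then Dense(1).
weights = [np.arange(80.0).reshape(10, 8), np.zeros(8), np.ones((8, 1)), np.zeros(1)]
flat, shapes = flatten_weights(weights)
print(len(flat))  # 80 + 8 + 8 + 1 = 97
```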

@mactul

mactul commented Sep 21, 2019

I have tried your method and it works well, but how can I turn this array into a list?

I want the coefficients of each neuron's inputs, and their position in the network.

A simple example:
w1-
       -w4
w2-        -w6
       -w5
w3-

[w1, w2, w3, w4, w5, w6]

Sorry if I'm not being very clear.

@td2014
Contributor

td2014 commented Sep 21, 2019

Hi @mactul. Below I created a simplified version with explicit details of the weights. I did a prediction using the model.predict in Keras and also a direct calculation using the weights and biases.

This simplified model has 4 inputs into 2 nodes, then into one output node:

                          x1    x2   x3    x4  (input)


w1_11  w1_12  w1_13  w1_14 (+bias)     w1_21  w1_22  w1_23  w1_24  (+bias)   (layer 1)


                             w2_1   w2_2   (+bias)       (layer 2-output)

I hope this helps.
Thanks.

# Simplified version
#
from keras.models import Sequential
from keras.layers import Dense, Activation
# For a single-input model with 2 classes (binary classification):

model = Sequential()
model.add(Dense(2, activation='linear', input_dim=4))
model.add(Dense(1, activation='linear'))
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Generate dummy data
import numpy as np
data = np.random.random((1000, 4))  # input is 1000 examples x 4 features
labels = np.random.randint(2, size=(1000, 1)) # output(label) is 1000 examples x 1

# Train the model, iterating on the data in batches of 32 samples
model.fit(data, labels, epochs=10, batch_size=32)
model.summary()
print()

# print out the weight coefficients of first layer
kernel_1, bias_1 = model.layers[0].get_weights()
print('kernel_1:')
print(kernel_1)
print('bias_1:')
print(bias_1)
print()
# print out the weight coefficients of second output layer
kernel_2, bias_2 = model.layers[1].get_weights()
print('kernel_2:')
print(kernel_2)
print('bias_2:')
print(bias_2)
print()


#predict with keras.model.predict
x_test = np.array([1, 1.5, -4 ,3])
x_test=np.expand_dims(x_test, axis = 0)
print('test input:')
print(x_test)
print('model.predict:')
print(model.predict(x_test))


#predict with direct calculation using manual summations
print()
node_1_sum = \
    x_test[0,0]*kernel_1[0,0]+ \
    x_test[0,1]*kernel_1[1,0]+ \
    x_test[0,2]*kernel_1[2,0]+ \
    x_test[0,3]*kernel_1[3,0]+ bias_1[0]

print('node_1_sum:')
print(node_1_sum)

node_2_sum = \
    x_test[0,0]*kernel_1[0,1]+ \
    x_test[0,1]*kernel_1[1,1]+ \
    x_test[0,2]*kernel_1[2,1]+ \
    x_test[0,3]*kernel_1[3,1]+ bias_1[1]

print('node_2_sum:')
print(node_2_sum)

#output layer

output_layer = node_1_sum*kernel_2[0] + node_2_sum*kernel_2[1] + bias_2[0]

print('final result of network using manual calculations = ', output_layer)
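The manual sums above are the entries of a matrix product, so each layer can equally be written as x @ kernel + bias. This is a standalone sketch with made-up weights of the same shapes (4 to 2 to 1, linear activations). It checks the explicit-sum version against the matrix version.

```python
import numpy as np

rng = np.random.default_rng(42)

# Made-up weights with the same shapes get_weights() returns for the model above.
kernel_1, bias_1 = rng.normal(size=(4, 2)), rng.normal(size=2)
kernel_2, bias_2 = rng.normal(size=(2, 1)), rng.normal(size=1)
x_test = np.array([[1, 1.5, -4, 3]])  # shape (1, 4), like the expanded test input

# Explicit sums: node j of layer 1 is sum_i x[i] * kernel_1[i, j] + bias_1[j].
nodes = [sum(x_test[0, i] * kernel_1[i, j] for i in range(4)) + bias_1[j] for j in range(2)]
manual = sum(nodes[j] * kernel_2[j, 0] for j in range(2)) + bias_2[0]

# The same computation as two matrix products (linear activations, no nonlinearity).
hidden = x_test @ kernel_1 + bias_1   # shape (1, 2)
output = hidden @ kernel_2 + bias_2   # shape (1, 1)

print(np.isclose(manual, output[0, 0]))  # True
```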

@AmberrrLiu

@td2014 Hi, I want to get the weights from two models and average them, then put the averaged weights into a new model. All three models have the same structure. How can I implement this? I hope you can help me. Thank you.

@td2014
Contributor

td2014 commented Oct 28, 2019

Hi @AmberrrLiu . I don't know for sure, but you might want to take a look at the first section of this page: https://keras.io/models/about-keras-models/ . It mentions get_weights, set_weights, save_weights, and load_weights functions. It might be possible to get the weights from each of your models, do the averaging using python/numpy, then set the weights in the new model. You can also save and reload if that works for you.

I hope this helps.
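The following sketches the averaging step along the lines of that suggestion. Models with the same architecture return get_weights() lists whose arrays have matching shapes. Those lists can be averaged element-wise and passed to set_weights() on the third model. The helper average_weights is hypothetical. The Keras calls appear only in comments, so the example runs with NumPy alone.

```python
import numpy as np


def average_weights(weight_lists):
    """Element-wise mean of several get_weights() lists from identically shaped models."""
    # zip groups the i-th array of every model together (e.g. all first-layer kernels).
    return [np.mean(np.stack(arrays), axis=0) for arrays in zip(*weight_lists)]


# With Keras, for models of the same structure, this would be used as:
#   avg = average_weights([model_a.get_weights(), model_b.get_weights()])
#   new_model.set_weights(avg)
w_a = [np.array([[1.0, 2.0]]), np.array([0.0])]
w_b = [np.array([[3.0, 4.0]]), np.array([1.0])]
avg = average_weights([w_a, w_b])
print(avg[0], avg[1])  # [[2. 3.]] [0.5]
```

The same get_weights / modify / set_weights pattern also applies to zeroing out selected weights before testing a model.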

@AmberrrLiu

AmberrrLiu commented Oct 30, 2019 via email

@yunjiangster

For TensorFlow, don't use tf.layers.dense, since it only returns the output of the dense layer. Instead, use tf.layers.Dense:

import tensorflow as tf  # TF 1.x API (tf.layers)

layer = tf.layers.Dense(units=units)
# Apply the layer to a zero-size dummy batch so its variables get created.
dummy_net = tf.zeros([0, net.get_shape()[1]], dtype=tf.float32)
layer.apply(dummy_net)
kernel, bias = layer.weights

From here on, you can use kernel/bias any way you like.
