New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add resnet50 example #3266

merged 16 commits into from Aug 3, 2016

update weights links and remove load_weights

  • Loading branch information...
MoyanZitto committed Jul 22, 2016
commit 23729b3620e86c2e8698d821db8191938fe661c2
@@ -16,8 +16,8 @@
provided by Kaiming He directly.
For now we provide pretrained weights for tensorflow backend, pls donwload pretrained file at:
(Don't be afraid of Chinese, just click the bottom at top right with '(98.2M)') (For China) (For Other contries)
If you are using theano backend,
you can transfer tf weights to th weights by hand under the instruction at:
@@ -118,28 +118,6 @@ def conv_block(input_tensor, nb_filter, stage, block, kernel_size=3):
return out
def load_weights(model, weights_path):
This function load the pretrained weights to the model
f = h5py.File(weights_path, 'r')
for layer in model.layers:
if[:3] == 'res':
layer.set_weights([f[]['weights'][:], f[]['bias'][:]])
elif[:2] == 'bn':
scale_name = 'scale'[2:]
weights = []
model.get_layer('conv1').set_weights([f['conv1']['weights'][:], f['conv1']['bias'][:]])
model.get_layer('fc1000').set_weights([f['fc1000']['weights'][:].T, f['fc1000']['bias'][:]])
return model
def read_img(img_path):
this function returns preprocessed image
@@ -167,15 +145,15 @@ def read_img(img_path):
# decenterize
img[:, :, 0] -= mean[0]
img[:, :, 1] -= mean[1]
img[:, :, 2] -= mean[2]
# 'RGB'->'BGR'
img = img[:, :, ::-1]
# 'tf'->'th'
img = np.transpose(img, (2, 0, 1))
# expand dim for test
img = np.expand_dims(img, axis=0)
@@ -227,7 +205,7 @@ def get_resnet50():
if __name__ == '__main__':
resnet_model = get_resnet50()
resnet_model = load_weights(resnet_model, 'resnet50.h5')
test_img1 = read_img('cat.jpg')
test_img2 = read_img('airplane.jpg')
# you may download synset_words from address given at the begining of this file
ProTip! Use n and p to navigate between commits in a pull request.