Skip to content

Commit 915cec5

Browse files
committed
Add URL to Keras CNN MNIST model; make URL constants more concise
1 parent 94cc86b commit 915cec5

File tree

5 files changed

+33
-16
lines changed

5 files changed

+33
-16
lines changed

example/mnist_from_keras.f90

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@ program mnist_from_keras
44
! from an HDF5 file and running an inferrence on the testing dataset.
55

66
use nf, only: network, label_digits, load_mnist
7-
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
7+
use nf_datasets, only: download_and_unpack, keras_dense_mnist_url
88

99
implicit none
1010

1111
type(network) :: net
1212
real, allocatable :: training_images(:,:), training_labels(:)
1313
real, allocatable :: validation_images(:,:), validation_labels(:)
1414
real, allocatable :: testing_images(:,:), testing_labels(:)
15-
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
15+
character(*), parameter :: keras_dense_path = 'keras_dense_mnist.h5'
1616
logical :: file_exists
1717

18-
inquire(file=test_data_path, exist=file_exists)
19-
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
18+
inquire(file=keras_dense_path, exist=file_exists)
19+
if (.not. file_exists) call download_and_unpack(keras_dense_mnist_url)
2020

2121
call load_mnist(training_images, training_labels, &
2222
validation_images, validation_labels, &
@@ -25,7 +25,7 @@ program mnist_from_keras
2525
print '("Loading a pre-trained MNIST model from Keras")'
2626
print '(60("="))'
2727

28-
net = network(test_data_path)
28+
net = network(keras_dense_path)
2929

3030
call net % print_info()
3131

src/nf/nf_datasets.f90

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@ module nf_datasets
88

99
private
1010

11-
public :: download_and_unpack, keras_model_dense_mnist_url, mnist_url
11+
public :: &
12+
download_and_unpack, &
13+
keras_cnn_mnist_url, &
14+
keras_dense_mnist_url, &
15+
mnist_url
1216

1317
character(*), parameter :: keras_snippets_baseurl = &
1418
'https://github.com/neural-fortran/keras-snippets/files'
1519
character(*), parameter :: neural_fortran_baseurl = &
1620
'https://github.com/modern-fortran/neural-fortran/files'
17-
character(*), parameter :: keras_model_dense_mnist_url = &
21+
character(*), parameter :: keras_cnn_mnist_url = &
22+
keras_snippets_baseurl // '/8892585/keras_cnn_mnist.tar.gz'
23+
character(*), parameter :: keras_dense_mnist_url = &
1824
keras_snippets_baseurl // '/8788739/keras_dense_mnist.tar.gz'
1925
character(*), parameter :: mnist_url = &
2026
neural_fortran_baseurl // '/8498876/mnist.tar.gz'

test/test_dense_network_from_keras.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ program test_dense_network_from_keras
22

33
use iso_fortran_env, only: stderr => error_unit
44
use nf, only: network
5-
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
5+
use nf_datasets, only: download_and_unpack, keras_dense_mnist_url
66

77
implicit none
88

@@ -12,7 +12,7 @@ program test_dense_network_from_keras
1212
logical :: ok = .true.
1313

1414
inquire(file=test_data_path, exist=file_exists)
15-
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
15+
if (.not. file_exists) call download_and_unpack(keras_dense_mnist_url)
1616

1717
net = network(test_data_path)
1818

test/test_io_hdf5.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program test_io_hdf5
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
4+
use nf_datasets, only: download_and_unpack, keras_dense_mnist_url
55
use nf_io_hdf5, only: hdf5_attribute_string, get_hdf5_dataset
66

77
implicit none
@@ -14,7 +14,7 @@ program test_io_hdf5
1414
logical :: ok = .true.
1515

1616
inquire(file=test_data_path, exist=file_exists)
17-
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
17+
if (.not. file_exists) call download_and_unpack(keras_dense_mnist_url)
1818

1919
attr = hdf5_attribute_string(test_data_path, '.', 'backend')
2020

test/test_keras_read_model.f90

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
program test_keras_read_model
22

33
use iso_fortran_env, only: stderr => error_unit
4-
use nf_datasets, only: download_and_unpack, keras_model_dense_mnist_url
4+
use nf_datasets, only: download_and_unpack, keras_dense_mnist_url, &
5+
keras_cnn_mnist_url
56
use nf_keras, only: get_keras_h5_layers, keras_layer
67
use nf, only: layer, network, dense, input
78

89
implicit none
910

1011
character(:), allocatable :: model_config_string
11-
character(*), parameter :: test_data_path = 'keras_dense_mnist.h5'
12+
character(*), parameter :: keras_dense_path = 'keras_dense_mnist.h5'
13+
character(*), parameter :: keras_cnn_path = 'keras_cnn_mnist.h5'
1214

1315
type(keras_layer), allocatable :: keras_layers(:)
1416

@@ -19,10 +21,12 @@ program test_keras_read_model
1921
logical :: file_exists
2022
logical :: ok = .true.
2123

22-
inquire(file=test_data_path, exist=file_exists)
23-
if (.not. file_exists) call download_and_unpack(keras_model_dense_mnist_url)
24+
! First test the dense model
2425

25-
keras_layers = get_keras_h5_layers(test_data_path)
26+
inquire(file=keras_dense_path, exist=file_exists)
27+
if (.not. file_exists) call download_and_unpack(keras_dense_mnist_url)
28+
29+
keras_layers = get_keras_h5_layers(keras_dense_path)
2630

2731
if (size(keras_layers) /= 3) then
2832
ok = .false.
@@ -51,6 +55,13 @@ program test_keras_read_model
5155
'Keras second and third layers should be dense.. failed'
5256
end if
5357

58+
! Now testing for the CNN model
59+
60+
inquire(file=keras_cnn_path, exist=file_exists)
61+
if (.not. file_exists) call download_and_unpack(keras_cnn_mnist_url)
62+
63+
keras_layers = get_keras_h5_layers(keras_cnn_path)
64+
5465
if (ok) then
5566
print '(a)', 'test_keras_read_model: All tests passed.'
5667
else

0 commit comments

Comments
 (0)