diff --git a/.gitignore b/.gitignore
index adccfaa8..1cf4b4d7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
 *.o
 *.mod
-build*
-data/mnist/*.dat
+build
+data/*/*.dat
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cd5ca477..108e41b6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -91,7 +91,7 @@ foreach(execid mnist network_save network_sync set_activation_function)
   add_test(test_${execid} bin/test_${execid})
 endforeach()
 
-foreach(execid mnist save_and_load simple sine)
+foreach(execid mnist mnist_epochs save_and_load simple sine)
   add_executable(example_${execid} src/tests/example_${execid}.f90)
   target_link_libraries(example_${execid} neural ${LIBS})
   add_test(example_${execid} bin/example_${execid})
diff --git a/LICENSE b/LICENSE
index 754e4c43..f5d80357 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2018-2019 Milan Curcic and neural-fortran contributors
+Copyright (c) 2018-2020 Milan Curcic and neural-fortran contributors
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index 6fd3b5df..54b80239 100644
--- a/README.md
+++ b/README.md
@@ -266,13 +266,11 @@ program example_mnist
   batch_size = 100
   num_epochs = 10
 
-  if (this_image() == 1) then
-    write(*, '(a,f5.2,a)') 'Initial accuracy: ',&
-      net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
-  end if
+  if (this_image() == 1) print '(a,f5.2,a)', 'Initial accuracy: ', &
+    net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
 
   epochs: do n = 1, num_epochs
-    mini_batches: do i = 1, size(tr_labels) / batch_size
+    batches: do i = 1, size(tr_labels) / batch_size
 
       ! pull a random mini-batch from the dataset
       call random_number(pos)
@@ -286,12 +284,10 @@ program example_mnist
       ! train the network on the mini-batch
       call net % train(input, output, eta=3._rk)
 
-    end do mini_batches
+    end do batches
 
-    if (this_image() == 1) then
-      write(*, '(a,i2,a,f5.2,a)') 'Epoch ', n, ' done, Accuracy: ',&
-        net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
-    end if
+    if (this_image() == 1) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', &
+      net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
 
   end do epochs
 
@@ -322,6 +318,45 @@ for example on 16 cores using [OpenCoarrays](https://github.com/sourceryinstitut
 $ cafrun -n 16 ./example_mnist
 ```
 
+### Montesinos-Lopez et al. (2018) univariate example
+
+The Montesinos-Lopez et al. (2018) univariate example is extracted from the study:
+
+```
+Montesinos-Lopez et al. 2018. Multi-environment genomic prediction of plant traits using deep learners with dense architecture. G3, 8, 3813-3828.
+```
+
+This example uses the data from the dataset "Data\_Maize\_1to3", which was extracted using the R code in the Appendix of this paper.
+
+
+The Montesinos-Lopez univariate data is included with the repo and you will have to unpack it first:
+
+```
+cd data/montesinos_uni
+tar xzvf montesinos_uni.tar.gz
+cd -
+```
+
+### Montesinos-Lopez et al. (2018) multivariate example
+
+The Montesinos-Lopez et al. (2018) multivariate example is extracted from the study:
+
+```
+Montesinos-Lopez et al. 2018. Multi-trait, multi-environment deep learning modeling for genomic-enabled prediction of plant traits. G3, 8, 3829-3840.
+```
+
+This example uses the data from the dataset "Data\_Maize\_set\_1", which was extracted using the R code in Appendix B of this paper.
+
+
+The Montesinos-Lopez multivariate data is included with the repo and you will have to unpack it first:
+
+```
+cd data/montesinos_multi
+tar xzvf montesinos_multi.tar.gz
+cd -
+```
+
+
 ## Contributing
 
 neural-fortran is currently a proof-of-concept with potential for
diff --git a/data/montesinos_multi.tar.gz b/data/montesinos_multi.tar.gz
new file mode 100644
index 00000000..2c43625b
Binary files /dev/null and b/data/montesinos_multi.tar.gz differ
diff --git a/data/montesinos_uni.tar.gz b/data/montesinos_uni.tar.gz
new file mode 100644
index 00000000..df1920fc
Binary files /dev/null and b/data/montesinos_uni.tar.gz differ
diff --git a/src/lib/mod_network.f90 b/src/lib/mod_network.f90
index 5cd05532..01f69a60 100644
--- a/src/lib/mod_network.f90
+++ b/src/lib/mod_network.f90
@@ -23,23 +23,26 @@ module mod_network
     procedure, public, pass(self) :: init
     procedure, public, pass(self) :: load
     procedure, public, pass(self) :: loss
-    procedure, public, pass(self) :: output
+    procedure, public, pass(self) :: output_batch
+    procedure, public, pass(self) :: output_single
     procedure, public, pass(self) :: save
     procedure, public, pass(self) :: set_activation_equal
     procedure, public, pass(self) :: set_activation_layers
     procedure, public, pass(self) :: sync
     procedure, public, pass(self) :: train_batch
+    procedure, public, pass(self) :: train_epochs
     procedure, public, pass(self) :: train_single
     procedure, public, pass(self) :: update
 
+    generic, public :: output => output_batch, output_single
     generic, public :: set_activation => set_activation_equal, set_activation_layers
-    generic, public :: train => train_batch, train_single
+    generic, public :: train => train_batch, train_epochs, train_single
 
   end type network_type
 
   interface network_type
     module procedure :: net_constructor
-  endinterface network_type
+  end interface network_type
 
 contains
 
@@ -58,6 +61,7 @@ type(network_type) function net_constructor(dims, activation) result(net)
     call net % sync(1)
   end function net_constructor
 
+
   pure real(rk) function accuracy(self, x, y)
     ! Given input x and output y, evaluates the position of the
     ! maximum value of the output and returns the number of matches
@@ -74,6 +78,7 @@ pure real(rk) function accuracy(self, x, y)
     accuracy = real(good) / size(x, dim=2)
   end function accuracy
 
+
   pure subroutine backprop(self, y, dw, db)
     ! Applies a backward propagation through the network
     ! and returns the weight and bias gradients.
@@ -104,6 +109,7 @@ pure subroutine backprop(self, y, dw, db)
 
   end subroutine backprop
 
+
   pure subroutine fwdprop(self, x)
     ! Performs the forward propagation and stores arguments to activation
     ! functions and activations themselves for use in backprop.
@@ -119,6 +125,7 @@ pure subroutine fwdprop(self, x)
     end associate
   end subroutine fwdprop
 
+
   subroutine init(self, dims)
     ! Allocates and initializes the layers with given dimensions dims.
     class(network_type), intent(in out) :: self
@@ -134,6 +141,7 @@ subroutine init(self, dims)
     self % layers(size(dims)) % w = 0
   end subroutine init
 
+
   subroutine load(self, filename)
     ! Loads the network from file.
     class(network_type), intent(in out) :: self
@@ -142,23 +150,24 @@ subroutine load(self, filename)
     integer(ik), allocatable :: dims(:)
     character(len=100) :: buffer ! activation string
     open(newunit=fileunit, file=filename, status='old', action='read')
-    read(fileunit, fmt=*) num_layers
+    read(fileunit, *) num_layers
     allocate(dims(num_layers))
-    read(fileunit, fmt=*) dims
+    read(fileunit, *) dims
     call self % init(dims)
     do n = 1, num_layers
-      read(fileunit, fmt=*) layer_idx, buffer
+      read(fileunit, *) layer_idx, buffer
       call self % layers(layer_idx) % set_activation(trim(buffer))
     end do
     do n = 2, size(self % dims)
-      read(fileunit, fmt=*) self % layers(n) % b
+      read(fileunit, *) self % layers(n) % b
     end do
     do n = 1, size(self % dims) - 1
-      read(fileunit, fmt=*) self % layers(n) % w
+      read(fileunit, *) self % layers(n) % w
     end do
     close(fileunit)
   end subroutine load
 
+
   pure real(rk) function loss(self, x, y)
     ! Given input x and expected output y, returns the loss of the network.
     class(network_type), intent(in) :: self
@@ -166,8 +175,10 @@ pure real(rk) function loss(self, x, y)
     loss = 0.5 * sum((y - self % output(x))**2) / size(x)
   end function loss
 
-  pure function output(self, x) result(a)
+
+  pure function output_single(self, x) result(a)
     ! Use forward propagation to compute the output of the network.
+    ! This specific procedure is for a single sample of 1-d input data.
     class(network_type), intent(in) :: self
     real(rk), intent(in) :: x(:)
     real(rk), allocatable :: a(:)
@@ -178,7 +189,22 @@ pure function output(self, x) result(a)
         a = self % layers(n) % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b)
       end do
     end associate
-  end function output
+  end function output_single
+
+
+  pure function output_batch(self, x) result(a)
+    ! Use forward propagation to compute the output of the network.
+    ! This specific procedure is for a batch of 1-d input data.
+    class(network_type), intent(in) :: self
+    real(rk), intent(in) :: x(:,:)
+    real(rk), allocatable :: a(:,:)
+    integer(ik) :: i
+    allocate(a(self % dims(size(self % dims)), size(x, dim=2)))
+    do i = 1, size(x, dim=2)
+      a(:,i) = self % output_single(x(:,i))
+    end do
+  end function output_batch
+
 
   subroutine save(self, filename)
     ! Saves the network to a file.
@@ -200,6 +226,7 @@ subroutine save(self, filename)
     close(fileunit)
   end subroutine save
 
+
   pure subroutine set_activation_equal(self, activation)
     ! A thin wrapper around layer % set_activation().
     ! This method can be used to set an activation function
@@ -209,6 +236,7 @@ pure subroutine set_activation_equal(self, activation)
     call self % layers(:) % set_activation(activation)
   end subroutine set_activation_equal
 
+
   pure subroutine set_activation_layers(self, activation)
     ! A thin wrapper around layer % set_activation().
     ! This method can be used to set different activation functions
@@ -233,6 +261,7 @@ subroutine sync(self, image)
     end do layers
   end subroutine sync
 
+
   subroutine train_batch(self, x, y, eta)
     ! Trains a network using input data x and output data y,
     ! and learning rate eta. The learning rate is normalized
@@ -273,6 +302,38 @@ subroutine train_batch(self, x, y, eta)
 
   end subroutine train_batch
 
+
+  subroutine train_epochs(self, x, y, eta, num_epochs, batch_size)
+    ! Trains for num_epochs epochs with mini-batches of size equal to batch_size.
+    class(network_type), intent(in out) :: self
+    integer(ik), intent(in) :: num_epochs, batch_size
+    real(rk), intent(in) :: x(:,:), y(:,:), eta
+
+    integer(ik) :: i, n, nsamples, nbatch
+    integer(ik) :: batch_start, batch_end
+
+    real(rk) :: pos
+
+    nsamples = size(y, dim=2)
+    nbatch = nsamples / batch_size
+
+    epochs: do n = 1, num_epochs
+      batches: do i = 1, nbatch
+
+        ! pull a random mini-batch from the dataset
+        call random_number(pos)
+        batch_start = int(pos * (nsamples - batch_size + 1))
+        if (batch_start == 0) batch_start = 1
+        batch_end = batch_start + batch_size - 1
+
+        call self % train(x(:,batch_start:batch_end), y(:,batch_start:batch_end), eta)
+
+      end do batches
+    end do epochs
+
+  end subroutine train_epochs
+
+
   pure subroutine train_single(self, x, y, eta)
     ! Trains a network using a single set of input data x and output data y,
     ! and learning rate eta.
@@ -285,6 +346,7 @@ pure subroutine train_single(self, x, y, eta)
     call self % update(dw, db, eta)
   end subroutine train_single
 
+
   pure subroutine update(self, dw, db, eta)
     ! Updates network weights and biases with gradients dw and db,
     ! scaled by learning rate eta.
diff --git a/src/tests/example_mnist.f90 b/src/tests/example_mnist.f90
index 846bdca2..1192e07e 100644
--- a/src/tests/example_mnist.f90
+++ b/src/tests/example_mnist.f90
@@ -12,7 +12,6 @@ program example_mnist
 
   real(rk), allocatable :: tr_images(:,:), tr_labels(:)
   real(rk), allocatable :: te_images(:,:), te_labels(:)
-  !real(rk), allocatable :: va_images(:,:), va_labels(:)
   real(rk), allocatable :: input(:,:), output(:,:)
 
   type(network_type) :: net
@@ -23,18 +22,16 @@ program example_mnist
 
   call load_mnist(tr_images, tr_labels, te_images, te_labels)
 
-  net = network_type([784, 10, 10])
+  net = network_type([784, 30, 10])
 
-  batch_size = 1000
+  batch_size = 100
   num_epochs = 10
 
-  if (this_image() == 1) then
-    write(*, '(a,f5.2,a)') 'Initial accuracy: ',&
-      net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
-  end if
+  if (this_image() == 1) print '(a,f5.2,a)', 'Initial accuracy: ', &
+    net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
 
   epochs: do n = 1, num_epochs
-    mini_batches: do i = 1, size(tr_labels) / batch_size
+    batches: do i = 1, size(tr_labels) / batch_size
 
       ! pull a random mini-batch from the dataset
       call random_number(pos)
@@ -48,12 +45,10 @@ program example_mnist
       ! train the network on the mini-batch
       call net % train(input, output, eta=3._rk)
 
-    end do mini_batches
+    end do batches
 
-    if (this_image() == 1) then
-      write(*, '(a,i2,a,f5.2,a)') 'Epoch ', n, ' done, Accuracy: ',&
-        net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
-    end if
+    if (this_image() == 1) print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', &
+      net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
 
   end do epochs
 
diff --git a/src/tests/example_mnist_epochs.f90 b/src/tests/example_mnist_epochs.f90
new file mode 100644
index 00000000..08ba04a8
--- /dev/null
+++ b/src/tests/example_mnist_epochs.f90
@@ -0,0 +1,36 @@
+program example_mnist
+
+  ! A training example with the MNIST dataset.
+  ! Uses stochastic gradient descent and a mini-batch size of 100.
+  ! Can be run in serial or parallel mode without modifications.
+
+  use mod_kinds, only: ik, rk
+  use mod_mnist, only: label_digits, load_mnist
+  use mod_network, only: network_type
+
+  implicit none
+
+  real(rk), allocatable :: tr_images(:,:), tr_labels(:)
+  real(rk), allocatable :: te_images(:,:), te_labels(:)
+
+  type(network_type) :: net
+
+  integer(ik) :: i, n, num_epochs
+  integer(ik) :: batch_size
+
+  call load_mnist(tr_images, tr_labels, te_images, te_labels)
+
+  net = network_type([size(tr_images, dim=1), 10, size(label_digits(tr_labels), dim=1)])
+
+  batch_size = 100
+  num_epochs = 10
+
+  if (this_image() == 1) print '(a,f5.2,a)', 'Initial accuracy: ', &
+    net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
+
+  call net % train(tr_images, label_digits(tr_labels), 3._rk, num_epochs, batch_size)
+
+  if (this_image() == 1) print '(a,f5.2,a)', 'Epochs done, Accuracy: ', &
+    net % accuracy(te_images, label_digits(te_labels)) * 100, ' %'
+
+end program example_mnist
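
Outside the MNIST programs, the patch does not include a standalone demo of the two new generics, so here is a minimal usage sketch (not part of the patch itself; the network dimensions, learning rate, and variable names are invented for illustration). It shows `output` resolving to `output_single` for a rank-1 sample and to `output_batch` for a rank-2 batch, and the epoch-based `train` overload taking `(x, y, eta, num_epochs, batch_size)`:

```
program generic_usage_sketch
  ! Illustrative sketch only: dims, eta, and data are arbitrary.
  use mod_kinds, only: ik, rk
  use mod_network, only: network_type
  implicit none
  type(network_type) :: net
  real(rk) :: x1(4), xb(4, 32), yb(2, 32)
  real(rk), allocatable :: y1(:), yp(:,:)
  integer(ik) :: num_epochs, batch_size

  net = network_type([4, 8, 2])
  call random_number(x1)
  call random_number(xb)
  call random_number(yb)

  y1 = net % output(x1)   ! rank-1 input resolves to output_single; y1 has shape [2]
  yp = net % output(xb)   ! rank-2 input resolves to output_batch; yp has shape [2,32]

  num_epochs = 5
  batch_size = 8
  ! the extra num_epochs and batch_size arguments select train_epochs
  call net % train(xb, yb, 0.1_rk, num_epochs, batch_size)
end program generic_usage_sketch
```

Because generic resolution is driven by the argument list, existing single-batch `train(x, y, eta)` calls keep working unchanged alongside the new epoch-based overload.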