diff --git a/.gitignore b/.gitignore index 2d21e4db..1cf4b4d7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ *.o *.mod build -data/mnist/*.dat +data/*/*.dat diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b6b1b22..0e32cf92 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,7 +78,7 @@ foreach(execid mnist network_save network_sync set_activation_function) add_test(test_${execid} bin/test_${execid}) endforeach() -foreach(execid mnist save_and_load simple sine) +foreach(execid mnist montesinos_uni montesinos_multi save_and_load simple sine) add_executable(example_${execid} src/tests/example_${execid}.f90) target_link_libraries(example_${execid} neural ${LIBS}) add_test(example_${execid} bin/example_${execid}) diff --git a/README.md b/README.md index 6fd3b5df..20a93766 100644 --- a/README.md +++ b/README.md @@ -322,6 +322,45 @@ for example on 16 cores using [OpenCoarrays](https://github.com/sourceryinstitut $ cafrun -n 16 ./example_mnist ``` +### Montesinos-Lopez et al. (2018) univariate example + +The Montesinos-Lopez et al. (2018) univariate example is extracted from the study: + +``` +Montesinos-Lopez et al. 2018. Multi-environment genomic prediction of plant traits using deep learners with dense architecture. G3, 8, 3813-3828. +``` + +This example uses the data from the dataset "Data\_Maize\_1to3", and was extracted using the R code in the Appendix of this paper. + + +The Montesinos-Lopez univariate data is included with the repo and you will have to unpack it first: + +``` +cd data/montesinos_uni +tar xzvf montesinos_uni.tar.gz +cd - +``` + +### Montesinos-Lopez et al. (2018) multivariate example + +The Montesinos-Lopez et al. (2018) multivariate example is extracted from the study: + +``` +Montesinos-Lopez et al. 2018. Multi-trait, multi-environment deep learning modeling for genomic-enabled prediction of plant traits. G3, 8, 3829-3840. +``` + +This example uses the data from the dataset "Data\_Maize\_set\_1", and was extracted using the R code in the Appendix B of this paper. + + +The Montesinos-Lopez multivariate data is included with the repo and you will have to unpack it first: + +``` +cd data/montesinos_multi +tar xzvf montesinos_multi.tar.gz +cd - +``` + + ## Contributing neural-fortran is currently a proof-of-concept with potential for diff --git a/data/montesinos_multi.tar.gz b/data/montesinos_multi.tar.gz new file mode 100644 index 00000000..2c43625b Binary files /dev/null and b/data/montesinos_multi.tar.gz differ diff --git a/data/montesinos_uni.tar.gz b/data/montesinos_uni.tar.gz new file mode 100644 index 00000000..df1920fc Binary files /dev/null and b/data/montesinos_uni.tar.gz differ diff --git a/src/lib/mod_network.f90 b/src/lib/mod_network.f90 index 4d3d7aa0..5b983198 100644 --- a/src/lib/mod_network.f90 +++ b/src/lib/mod_network.f90 @@ -23,15 +23,18 @@ module mod_network procedure, public, pass(self) :: init procedure, public, pass(self) :: load procedure, public, pass(self) :: loss - procedure, public, pass(self) :: output + procedure, public, pass(self) :: output_batch + procedure, public, pass(self) :: output_single procedure, public, pass(self) :: save procedure, public, pass(self) :: set_activation procedure, public, pass(self) :: sync procedure, public, pass(self) :: train_batch + procedure, public, pass(self) :: train_epochs procedure, public, pass(self) :: train_single procedure, public, pass(self) :: update - generic, public :: train => train_batch, train_single + generic, public :: output => output_batch, output_single + generic, public :: train => train_batch, train_epochs, train_single end type network_type @@ -159,7 +162,7 @@ pure real(rk) function loss(self, x, y) loss = 0.5 * sum((y - self % output(x))**2) / size(x) end function loss - pure function output(self, x) result(a) + pure function output_single(self, x) result(a) ! Use forward propagation to compute the output of the network. class(network_type), intent(in) :: self real(rk), intent(in) :: x(:) @@ -171,7 +174,21 @@ pure function output(self, x) result(a) a = self % layers(n) % activation(matmul(transpose(layers(n-1) % w), a) + layers(n) % b) end do end associate - end function output + end function output_single + + pure function output_batch(self, x) result(a) + class(network_type), intent(in) :: self + real(rk), intent(in) :: x(:,:) + real(rk), allocatable :: a(:,:) + + integer(ik) :: i + + allocate(a(self%dims(size(self%dims)),size(x,dim=2))) + do i = 1, size(x, dim=2) + a(:,i)=self%output(x(:,i)) + enddo + + end function output_batch subroutine save(self, filename) ! Saves the network to a file. @@ -255,6 +272,37 @@ subroutine train_batch(self, x, y, eta) end subroutine train_batch + subroutine train_epochs(self, x, y, eta,num_epochs,num_batch_size) + !Performs the training for nun_epochs epochs with mini-bachtes of size equal to num_batch_size + class(network_type), intent(in out) :: self + integer(ik),intent(in)::num_epochs,num_batch_size + real(rk), intent(in) :: x(:,:), y(:,:), eta + + integer(ik)::i,n,nsamples,nbatch + integer(ik)::batch_start,batch_end + + real(rk)::pos + + nsamples=size(y,dim=2) + + nbatch=nsamples/num_batch_size + + epoch: do n=1,num_epochs + mini_batches: do i=1,nbatch + + !pull a random mini-batch from the dataset + call random_number(pos) + batch_start=int(pos*(nsamples-num_batch_size+1)) + if(batch_start.eq.0)batch_start=1 + batch_end=batch_start+num_batch_size-1 + + call self%train(x(:,batch_start:batch_end),y(:,batch_start:batch_end),eta) + + enddo mini_batches + enddo epoch + + end subroutine train_epochs + pure subroutine train_single(self, x, y, eta) ! Trains a network using a single set of input data x and output data y, ! and learning rate eta. diff --git a/src/tests/example_mnist.f90 b/src/tests/example_mnist.f90 index 846bdca2..0c2c1db2 100644 --- a/src/tests/example_mnist.f90 +++ b/src/tests/example_mnist.f90 @@ -12,18 +12,15 @@ program example_mnist real(rk), allocatable :: tr_images(:,:), tr_labels(:) real(rk), allocatable :: te_images(:,:), te_labels(:) - !real(rk), allocatable :: va_images(:,:), va_labels(:) - real(rk), allocatable :: input(:,:), output(:,:) type(network_type) :: net integer(ik) :: i, n, num_epochs - integer(ik) :: batch_size, batch_start, batch_end - real(rk) :: pos + integer(ik) :: batch_size call load_mnist(tr_images, tr_labels, te_images, te_labels) - net = network_type([784, 10, 10]) + net = network_type([size(tr_images,dim=1), 10, size(label_digits(te_labels),dim=1)]) batch_size = 1000 num_epochs = 10 @@ -33,28 +30,12 @@ program example_mnist net % accuracy(te_images, label_digits(te_labels)) * 100, ' %' end if - epochs: do n = 1, num_epochs - mini_batches: do i = 1, size(tr_labels) / batch_size - - ! pull a random mini-batch from the dataset - call random_number(pos) - batch_start = int(pos * (size(tr_labels) - batch_size + 1)) - batch_end = batch_start + batch_size - 1 - - ! prepare mini-batch - input = tr_images(:,batch_start:batch_end) - output = label_digits(tr_labels(batch_start:batch_end)) - - ! train the network on the mini-batch - call net % train(input, output, eta=3._rk) - - end do mini_batches - - if (this_image() == 1) then - write(*, '(a,i2,a,f5.2,a)') 'Epoch ', n, ' done, Accuracy: ',& - net % accuracy(te_images, label_digits(te_labels)) * 100, ' %' - end if + call net%train(tr_images,label_digits(tr_labels),3._rk,num_epochs,batch_size) + + if (this_image() == 1) then + write(*, '(a,f5.2,a)') 'Epochs done, Accuracy: ',& + net % accuracy(te_images, label_digits(te_labels)) * 100, ' %' + endif - end do epochs end program example_mnist diff --git a/src/tests/example_montesinos_multi.f90 b/src/tests/example_montesinos_multi.f90 new file mode 100644 index 00000000..c9fd47ef --- /dev/null +++ b/src/tests/example_montesinos_multi.f90 @@ -0,0 +1,136 @@ +program example_montesinos_multi + use mod_kinds,only:ik,rk + use mod_network,only:network_type + implicit none + integer(ik)::ny1_tr,ny2_tr,nx1_tr,nx2_tr + integer(ik)::ny1_ts,ny2_ts,nx1_ts,nx2_ts + + integer(ik)::batch_size,num_epochs + + real(rk),allocatable::y_tr(:,:),x_tr(:,:) + real(rk),allocatable::y_ts(:,:),x_ts(:,:) + + type(network_type)::net + + call readfile('../data/montesinos_multi/y_tr.dat',ny1_tr,ny2_tr,y_tr) + call readfile('../data/montesinos_multi/x_tr.dat',nx1_tr,nx2_tr,x_tr) + + net=network_type([nx1_tr,50,50,ny1_tr]) + + batch_size=50 + num_epochs=50 + + !training + call net%train(x_tr,y_tr,3._rk,num_epochs,batch_size) + + call net%sync(1) + + !validation + call readfile('../data/montesinos_multi/y_ts.dat',ny1_ts,ny2_ts,y_ts) + call readfile('../data/montesinos_multi/x_ts.dat',nx1_ts,nx2_ts,x_ts) + + if(this_image().eq.1)then + write(*,*)'Correlation(s): ',corr_array(net%output(x_ts),y_ts) + endif + +contains + +subroutine readfile(filename,n,m,array) + character(len=*),intent(in)::filename + integer(ik),intent(out)::n,m + real(rk),allocatable,intent(out)::array(:,:) + + integer(ik)::un,i,io + + open(newunit=un,file=filename,status='old',action='read') + call numlines(un,m) + call numcol(un,n) + + allocate(array(n,m)) + rewind(un) + do i=1,m + read(un,*,iostat=io)array(:,i) + if(io.ne.0)exit + enddo + close(un) + +end subroutine + +pure function corr_array(array1,array2) result(a) + real(rk),intent(in)::array1(:,:),array2(:,:) + real(rk),allocatable::a(:) + + integer(ik)::i,n + + n=size(array1,dim=1) + + allocate(a(n)) + a=0.0_rk + do i=1,n + a(i)=corr(array1(i,:),array2(i,:)) + enddo + +end function + +pure real(rk) function corr(array1,array2) + real(rk),intent(in)::array1(:),array2(:) + + real(rk)::mean1,mean2 + + !brute force + + mean1=sum(array1)/size(array1) + mean2=sum(array2)/size(array2) + corr=dot_product(array1-mean1,array2-mean2)/sqrt(sum((array1-mean1)**2)*sum((array2-mean2)**2)) + +end function + +subroutine numlines(unfile,n) + implicit none + integer::io + integer,intent(in)::unfile + integer,intent(out)::n + rewind(unfile) + n=0 + do + read(unfile,*,iostat=io) + if (io.ne.0) exit + n=n+1 + enddo + rewind(unfile) +end subroutine + +subroutine numcol(unfile,n) + implicit none + integer,intent(in)::unfile + character(len=1000000)::a + integer,intent(out)::n + integer::curr,first,last,lena,stat,i + rewind(unfile) + read(unfile,"(a)")a + curr=1;lena=len(a);n=0 + do + first=0 + do i=curr,lena + if (a(i:i) /= " ") then + first=i + exit + endif + enddo + if (first == 0) exit + curr=first+1 + last=0 + do i=curr,lena + if (a(i:i) == " ") then + last=i + exit + endif + enddo + if (last == 0) last=lena + n=n+1 + curr=last+1 + enddo + rewind(unfile) +end subroutine + +end program diff --git a/src/tests/example_montesinos_uni.f90 b/src/tests/example_montesinos_uni.f90 new file mode 100644 index 00000000..8f0b8e85 --- /dev/null +++ b/src/tests/example_montesinos_uni.f90 @@ -0,0 +1,137 @@ +program example_montesinos_uni + use mod_kinds,only:ik,rk + use mod_network,only:network_type + implicit none + integer(ik)::ny1_tr,ny2_tr,nx1_tr,nx2_tr + integer(ik)::ny1_ts,ny2_ts,nx1_ts,nx2_ts + + integer(ik)::batch_size,num_epochs + + real(rk),allocatable::y_tr(:,:),x_tr(:,:) + real(rk),allocatable::y_ts(:,:),x_ts(:,:) + + type(network_type)::net + + call readfile('../data/montesinos_uni/y_tr.dat',ny1_tr,ny2_tr,y_tr) + call readfile('../data/montesinos_uni/x_tr.dat',nx1_tr,nx2_tr,x_tr) + + !net=network_type([nx1_tr,50,50,ny1_tr],'relu') + net=network_type([nx1_tr,50,50,ny1_tr]) + + batch_size=30 + num_epochs=20 + + !training + call net%train(x_tr,y_tr,3._rk,num_epochs,batch_size) + + call net%sync(1) + + !validation + call readfile('../data/montesinos_uni/y_ts.dat',ny1_ts,ny2_ts,y_ts) + call readfile('../data/montesinos_uni/x_ts.dat',nx1_ts,nx2_ts,x_ts) + + if(this_image().eq.1)then + write(*,*)'Correlation(s): ',corr_array(net%output(x_ts),y_ts) + endif + +contains + +subroutine readfile(filename,n,m,array) + character(len=*),intent(in)::filename + integer(ik),intent(out)::n,m + real(rk),allocatable,intent(out)::array(:,:) + + integer(ik)::un,i,io + + open(newunit=un,file=filename,status='old',action='read') + call numlines(un,m) + call numcol(un,n) + + allocate(array(n,m)) + rewind(un) + do i=1,m + read(un,*,iostat=io)array(:,i) + if(io.ne.0)exit + enddo + close(un) + +end subroutine + +pure function corr_array(array1,array2) result(a) + real(rk),intent(in)::array1(:,:),array2(:,:) + real(rk),allocatable::a(:) + + integer(ik)::i,n + + n=size(array1,dim=1) + + allocate(a(n)) + a=0.0_rk + do i=1,n + a(i)=corr(array1(i,:),array2(i,:)) + enddo + +end function + +pure real(rk) function corr(array1,array2) + real(rk),intent(in)::array1(:),array2(:) + + real(rk)::mean1,mean2 + + !brute force + + mean1=sum(array1)/size(array1) + mean2=sum(array2)/size(array2) + corr=dot_product(array1-mean1,array2-mean2)/sqrt(sum((array1-mean1)**2)*sum((array2-mean2)**2)) + +end function + +subroutine numlines(unfile,n) + implicit none + integer::io + integer,intent(in)::unfile + integer,intent(out)::n + rewind(unfile) + n=0 + do + read(unfile,*,iostat=io) + if (io.ne.0) exit + n=n+1 + enddo + rewind(unfile) +end subroutine + +subroutine numcol(unfile,n) + implicit none + integer,intent(in)::unfile + character(len=1000000)::a + integer,intent(out)::n + integer::curr,first,last,lena,stat,i + rewind(unfile) + read(unfile,"(a)")a + curr=1;lena=len(a);n=0 + do + first=0 + do i=curr,lena + if (a(i:i) /= " ") then + first=i + exit + endif + enddo + if (first == 0) exit + curr=first+1 + last=0 + do i=curr,lena + if (a(i:i) == " ") then + last=i + exit + endif + enddo + if (last == 0) last=lena + n=n+1 + curr=last+1 + enddo + rewind(unfile) +end subroutine + +end program