Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/lib/mod_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module mod_layer
real(rk), allocatable :: z(:) ! arg. to activation function
procedure(activation_function), pointer, nopass :: activation => null()
procedure(activation_function), pointer, nopass :: activation_prime => null()
character(len=:), allocatable :: activation_str ! activation character string
contains
procedure, public, pass(self) :: set_activation
end type layer_type
Expand Down Expand Up @@ -115,7 +116,7 @@ subroutine dw_co_sum(dw)
end do
end subroutine dw_co_sum

pure subroutine set_activation(self, activation)
pure elemental subroutine set_activation(self, activation)
! Sets the activation function. Input string must match one of
! provided activation functions, otherwise it defaults to sigmoid.
! If activation not present, defaults to sigmoid.
Expand All @@ -125,21 +126,27 @@ pure subroutine set_activation(self, activation)
case('gaussian')
self % activation => gaussian
self % activation_prime => gaussian_prime
self % activation_str = 'gaussian'
case('relu')
self % activation => relu
self % activation_prime => relu_prime
self % activation_str = 'relu'
case('sigmoid')
self % activation => sigmoid
self % activation_prime => sigmoid_prime
self % activation_str = 'sigmoid'
case('step')
self % activation => step
self % activation_prime => step_prime
self % activation_str = 'step'
case('tanh')
self % activation => tanhf
self % activation_prime => tanh_prime
self % activation_str = 'tanh'
case default
self % activation => sigmoid
self % activation_prime => sigmoid_prime
self % activation_str = 'sigmoid'
end select
end subroutine set_activation

Expand Down
32 changes: 24 additions & 8 deletions src/lib/mod_network.f90
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ module mod_network
procedure, public, pass(self) :: loss
procedure, public, pass(self) :: output
procedure, public, pass(self) :: save
procedure, public, pass(self) :: set_activation
procedure, public, pass(self) :: set_activation_equal
procedure, public, pass(self) :: set_activation_layers
procedure, public, pass(self) :: sync
procedure, public, pass(self) :: train_batch
procedure, public, pass(self) :: train_single
procedure, public, pass(self) :: update

generic, public :: set_activation => set_activation_equal, set_activation_layers
generic, public :: train => train_batch, train_single

end type network_type
Expand Down Expand Up @@ -136,13 +138,18 @@ subroutine load(self, filename)
! Loads the network from file.
class(network_type), intent(in out) :: self
character(len=*), intent(in) :: filename
integer(ik) :: fileunit, n, num_layers
integer(ik) :: fileunit, n, num_layers, layer_idx
integer(ik), allocatable :: dims(:)
character(len=100) :: buffer ! activation string
open(newunit=fileunit, file=filename, status='old', action='read')
read(fileunit, fmt=*) num_layers
allocate(dims(num_layers))
read(fileunit, fmt=*) dims
call self % init(dims)
do n = 1, num_layers
read(fileunit, fmt=*) layer_idx, buffer
call self % layers(layer_idx) % set_activation(trim(buffer))
end do
do n = 2, size(self % dims)
read(fileunit, fmt=*) self % layers(n) % b
end do
Expand Down Expand Up @@ -181,6 +188,9 @@ subroutine save(self, filename)
open(newunit=fileunit, file=filename)
write(fileunit, fmt=*) size(self % dims)
write(fileunit, fmt=*) self % dims
do n = 1, size(self % dims)
write(fileunit, fmt=*) n, self % layers(n) % activation_str
end do
do n = 2, size(self % dims)
write(fileunit, fmt=*) self % layers(n) % b
end do
Expand All @@ -190,17 +200,23 @@ subroutine save(self, filename)
close(fileunit)
end subroutine save

pure subroutine set_activation(self, activation)
pure subroutine set_activation_equal(self, activation)
! A thin wrapper around layer % set_activation().
! This method can be used to set an activation function
! for all layers at once.
class(network_type), intent(in out) :: self
character(len=*), intent(in) :: activation
integer :: n
do concurrent(n = 1:size(self % layers))
call self % layers(n) % set_activation(activation)
end do
end subroutine set_activation
call self % layers(:) % set_activation(activation)
end subroutine set_activation_equal

pure subroutine set_activation_layers(self, activation)
! A thin wrapper around layer % set_activation().
! This method can be used to set different activation functions
! for each layer separately.
class(network_type), intent(in out) :: self
character(len=*), intent(in) :: activation(size(self % layers))
call self % layers(:) % set_activation(activation)
end subroutine set_activation_layers

subroutine sync(self, image)
! Broadcasts network weights and biases from
Expand Down
14 changes: 14 additions & 0 deletions src/tests/test_network_save.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ program test_network_save
print *, 'Initializing 2 networks with random weights and biases'
net1 = network_type([768, 30, 10])
net2 = network_type([768, 30, 10])

print *, 'Save network 1 into file'
call net1 % save('test_network.dat')
call net2 % load('test_network.dat')
Expand All @@ -15,4 +16,17 @@ program test_network_save
all(net1 % layers(n) % w == net2 % layers(n) % w),&
', biases equal:', all(net1 % layers(n) % b == net2 % layers(n) % b)
end do
print *, ''

print *, 'Setting different activation functions for each layer of network 1'
call net1 % set_activation([character(len=10) :: 'sigmoid', 'tanh', 'gaussian'])
print *, 'Save network 1 into file'
call net1 % save('test_network.dat')
call net2 % load('test_network.dat')
print *, 'Load network 2 from file'
do n = 1, size(net1 % layers)
print *, 'Layer ', n, ', activation functions equal:',&
associated(net1 % layers(n) % activation, net2 % layers(n) % activation),&
'(network 1: ', net1 % layers(n) % activation_str, ', network 2: ', net2 % layers(n) % activation_str,')'
end do
end program test_network_save