Skip to content

Commit

Permalink
addition of maxabs
Browse files Browse the repository at this point in the history
  • Loading branch information
Vandenplas, Jeremie committed Jun 14, 2024
1 parent f506901 commit 3f2f7f3
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
8 changes: 7 additions & 1 deletion example/dense_mnist.f90
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@ program dense_mnist
net, validation_images, label_digits(validation_labels)) * 100, ' %'

block
real :: output_metrics(10,2) ! 2 metrics; 1st is default loss function (quadratic), other is Pearson corr.
real, allocatable :: output_metrics(:,:) ! 2 metrics; 1st is default loss function (quadratic), other is Pearson corr.
output_metrics = net % evaluate(validation_images, label_digits(validation_labels), metric=corr())
print *, "Metrics: quadratic loss, Pearson corr.:", sum(output_metrics, 1) / size(output_metrics, 1)
end block

block
real, allocatable :: output_metrics(:,:) ! 3 metrics; 1st is default loss function (quadratic), others are Pearson corr.
output_metrics = net % evaluate(validation_images, label_digits(validation_labels), metrics=[corr(), corr()])
print *, "Metrics: quadratic loss, Pearson corr.:", sum(output_metrics, 1) / size(output_metrics, 1)
end block

end do epochs

contains
Expand Down
2 changes: 1 addition & 1 deletion src/nf.f90
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module nf
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: mse, quadratic
use nf_metrics, only: corr
use nf_metrics, only: corr, maxabs
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop, adam, adagrad
use nf_activation, only: activation_function, elu, exponential, &
Expand Down
25 changes: 23 additions & 2 deletions src/nf/nf_metrics.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module nf_metrics
private
public :: metric_type
public :: corr
public :: maxabs

type, abstract :: metric_type
contains
Expand All @@ -27,6 +28,12 @@ end function metric_interface
procedure, nopass :: eval => corr_eval
end type corr

type, extends(metric_type) :: maxabs
!! Maximum absolute difference
contains
procedure, nopass :: eval => maxabs_eval
end type maxabs

contains

pure module function corr_eval(true, predicted) result(res)
Expand All @@ -37,7 +44,7 @@ pure module function corr_eval(true, predicted) result(res)
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting loss value
!! Resulting correlation value
real :: m_true, m_pred

m_true = sum(true) / size(true)
Expand All @@ -48,4 +55,18 @@ pure module function corr_eval(true, predicted) result(res)

end function corr_eval

end module nf_metrics
pure function maxabs_eval(true, predicted) result(res)
!! Maximum absolute difference function:
!!
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting maximum absolute difference value

res = maxval(abs(true - predicted))

end function maxabs_eval

end module nf_metrics

0 comments on commit 3f2f7f3

Please sign in to comment.