diff --git a/README.md b/README.md index 662a841..bbdf8a5 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![Latest Release](https://img.shields.io/github/v/release/nedtaylor/diffstruc?sort=semver)](https://github.com/nedtaylor/diffstruc/releases "View on GitHub") [![FPM](https://img.shields.io/badge/fpm-0.12.0-purple)](https://github.com/fortran-lang/fpm "View Fortran Package Manager") [![GCC compatibility](https://img.shields.io/badge/gcc-15.2.0-green)](https://gcc.gnu.org/gcc-15/ "View GCC") +[![IFX compatibility](https://img.shields.io/badge/ifx-2025.2.0-green)](https://www.intel.com/content/www/us/en/developer/tools/oneapi/fortran-compiler.html "View ifx") # diffstruc @@ -52,7 +53,11 @@ The library has the following dependencies - [fpm](https://github.com/fortran-lang/fpm) The library has been developed and tested using the following compilers: -- gfortran -- gcc 15.1.0 +- gfortran -- gcc 15.2.0 +- ifx -- ifx 2025.2.0 + +> **_NOTE:_** diffstruc is known to be incompatible with all versions of the gfortran compiler below `14.3.0` due to issues with the calling of the `final` procedure of `array_type`. + ### Building with fpm diff --git a/src/diffstruc/diffstruc_operations_linalg_sub.f90 b/src/diffstruc/diffstruc_operations_linalg_sub.f90 index f458883..1ac71d4 100644 --- a/src/diffstruc/diffstruc_operations_linalg_sub.f90 +++ b/src/diffstruc/diffstruc_operations_linalg_sub.f90 @@ -6,7 +6,7 @@ contains !############################################################################### - function matmul_arrays(a, b) result(c) + module function matmul_arrays(a, b) result(c) !! Matrix multiplication of two autodiff arrays implicit none class(array_type), intent(in), target :: a, b @@ -54,7 +54,7 @@ function matmul_arrays(a, b) result(c) end if end function matmul_arrays !------------------------------------------------------------------------------- - function matmul_real2d(a, b) result(c) + module function matmul_real2d(a, b) result(c) !! Matrix multiplication of a real array and an autodiff array implicit none class(array_type), intent(in), target :: a @@ -91,7 +91,7 @@ function matmul_real2d(a, b) result(c) c%owns_right_operand = .true. end function matmul_real2d !------------------------------------------------------------------------------- - function real2d_matmul(a, b) result(c) + module function real2d_matmul(a, b) result(c) !! Matrix multiplication of two autodiff arrays implicit none real(real32), dimension(:,:), intent(in) :: a @@ -195,7 +195,7 @@ end function get_partial_matmul_right !############################################################################### - function outer_product_arrays(a, b) result(c) + module function outer_product_arrays(a, b) result(c) !! Outer product of two autodiff arrays implicit none class(array_type), intent(in), target :: a, b @@ -265,7 +265,7 @@ end function get_partial_outer_product_right !############################################################################### - function transpose_array(a) result(c) + module function transpose_array(a) result(c) !! Transpose an autodiff array implicit none class(array_type), intent(in), target :: a diff --git a/src/diffstruc/diffstruc_types.f90 b/src/diffstruc/diffstruc_types.f90 index 50b0855..c4b3edc 100644 --- a/src/diffstruc/diffstruc_types.f90 +++ b/src/diffstruc/diffstruc_types.f90 @@ -13,9 +13,8 @@ module diffstruc__types public :: operator(+), operator(-), operator(*), operator(/), operator(**) public :: sum, mean, spread, unspread, exp, log - type array_ptr - type(array_type), pointer :: p => null() - end type array_ptr + + !------------------------------------------------------------------------------- ! Automatic differentiation derived type !------------------------------------------------------------------------------- @@ -53,17 +52,28 @@ module diffstruc__types type(array_type), pointer :: right_operand => null() !! Right operand for backward pass character(len=32) :: operation = 'none' + !! Name of the operation that created this array logical :: owns_left_operand = .false. + !! Boolean whether this array owns its left operand logical :: owns_right_operand = .false. + !! Boolean whether this array owns its right operand logical :: owns_gradient = .true. - !! Flag indicating if this array owns its gradient memory + !! Boolean whether this array owns its gradient memory logical :: fix_pointer = .false. + !! Boolean to prevent pointer changes during graph operations + !! ... i.e. during duplicate_graph or nullify_graph, this array will not be + !! ... duplicated or nullified logical :: is_temporary = .true. + !! Boolean indicating if array is temporary real(real32), dimension(:), allocatable :: direction + !! Direction vector for forward-mode differentiation + !! ... i.e. derivative wrt this direction will be computed procedure(get_partial), pass(this), pointer :: get_partial_left => null() + !! Pointer procedure for getting partial derivative wrt left operand procedure(get_partial), pass(this), pointer :: get_partial_right => null() + !! Pointer procedure for getting partial derivative wrt right operand contains procedure, pass(this) :: allocate => allocate_array @@ -76,36 +86,54 @@ module diffstruc__types !! Procedure for shallow assignment of array procedure, pass(this) :: assign_and_deallocate_source !! Procedure for assigning and deallocating source array - procedure :: assign => assign_array - generic, public :: assignment(=) => assign - !! Overloaded assignment operator + procedure, pass(this) :: set => set_array !! Procedure for setting array procedure, pass(this) :: extract => extract_array !! Procedure for extracting array as a standard real array procedure, pass(this) :: set_direction + !! Set the direction vector for forward-mode differentiation + procedure, pass(this) :: set_requires_grad + !! Set whether gradients are required procedure, pass(this) :: grad_reverse !! Reverse-mode: accumulate gradients wrt all inputs procedure, pass(this) :: grad_forward !! Forward-mode: return derivative wrt variable pointer - !! Backward pass for gradient computation procedure, pass(this) :: zero_grad + !! Zero the gradient procedure, pass(this) :: zero_all_grads !! Zero the gradients procedure, pass(this) :: reset_graph + !! Reset the gradient graph of this array procedure, pass(this) :: duplicate_graph + !! Duplicate the computation graph and return pointer to new graph procedure, pass(this) :: nullify_graph + !! Nullify the computation graph procedure, pass(this) :: get_ptr_from_id + !! Get pointer to array in graph by its ID procedure, pass(this) :: detach !! Detach from computation graph - procedure, pass(this) :: set_requires_grad - !! Set requires_grad flag procedure, pass(this) :: create_result => create_result_array !! Helper to safely create result arrays procedure, pass(this) :: print_graph + !! Print the computation graph + + procedure :: add_arrays, add_real1d, add_real2d, add_scalar + generic, public :: operator(+) => & + add_arrays, add_real1d, add_real2d, add_scalar + !! Overloaded addition operator + + procedure :: subtract_arrays, subtract_real1d, subtract_scalar, negate_array + generic, public :: operator(-) => & + subtract_arrays, subtract_real1d, subtract_scalar, negate_array + !! Overloaded subtraction operator + + procedure :: assign => assign_array + generic, public :: assignment(=) => assign + !! Overloaded assignment operator final :: finalise_array !! Finaliser for array type @@ -179,9 +207,8 @@ module function grad_forward(this, variable) result(output) type(array_type), pointer :: output end function grad_forward - module subroutine grad_reverse(this, record_graph, reset_graph) + module subroutine grad_reverse(this, reset_graph) class(array_type), intent(inout) :: this - logical, intent(in), optional :: record_graph logical, intent(in), optional :: reset_graph end subroutine grad_reverse end interface @@ -258,9 +285,23 @@ module subroutine print_graph(this) class(array_type), intent(in) :: this end subroutine print_graph end interface +!------------------------------------------------------------------------------- + + +!------------------------------------------------------------------------------- +! Pointer wrapper type +!------------------------------------------------------------------------------- + type :: array_ptr + type(array_type), pointer :: p => null() + end type array_ptr +!------------------------------------------------------------------------------- - ! Operation interfaces +!------------------------------------------------------------------------------- +! Operation interfaces +!------------------------------------------------------------------------------- + + ! Arithmetic reduction interfaces that are directly overloaded into array_type !----------------------------------------------------------------------------- interface mean module function mean_array(a, dim) result(c) @@ -310,7 +351,9 @@ end function unspread_array end interface - interface operator(+) + ! Arithmetic operator interfaces that are directly overloaded into array_type + !----------------------------------------------------------------------------- + interface module function add_arrays(a, b) result(c) class(array_type), intent(in), target :: a, b type(array_type), pointer :: c @@ -322,39 +365,18 @@ module function add_real2d(a, b) result(c) type(array_type), pointer :: c end function add_real2d - module function real2d_add(a, b) result(c) - real(real32), dimension(:,:), intent(in) :: a - class(array_type), intent(in), target :: b - type(array_type), pointer :: c - end function real2d_add - module function add_real1d(a, b) result(c) class(array_type), intent(in), target :: a real(real32), dimension(:), intent(in) :: b type(array_type), pointer :: c end function add_real1d - module function real1d_add(a, b) result(c) - real(real32), dimension(:), intent(in) :: a - class(array_type), intent(in), target :: b - type(array_type), pointer :: c - end function real1d_add - module function add_scalar(a, b) result(c) class(array_type), intent(in), target :: a real(real32), intent(in) :: b type(array_type), pointer :: c end function add_scalar - module function scalar_add(a, b) result(c) - real(real32), intent(in) :: a - class(array_type), intent(in), target :: b - type(array_type), pointer :: c - end function scalar_add - end interface - - - interface operator(-) module function subtract_arrays(a, b) result(c) class(array_type), intent(in), target :: a, b type(array_type), pointer :: c @@ -372,16 +394,42 @@ module function subtract_scalar(a, b) result(c) type(array_type), pointer :: c end function subtract_scalar - module function scalar_subtract(a, b) result(c) + module function negate_array(a) result(c) + class(array_type), intent(in), target :: a + type(array_type), pointer :: c + end function negate_array + end interface + + + ! Arithmetic operator interfaces that are not directly overloaded into array_type + interface operator(+) + module function real2d_add(a, b) result(c) + real(real32), dimension(:,:), intent(in) :: a + class(array_type), intent(in), target :: b + type(array_type), pointer :: c + end function real2d_add + + module function real1d_add(a, b) result(c) + real(real32), dimension(:), intent(in) :: a + class(array_type), intent(in), target :: b + type(array_type), pointer :: c + end function real1d_add + + + module function scalar_add(a, b) result(c) real(real32), intent(in) :: a class(array_type), intent(in), target :: b type(array_type), pointer :: c - end function scalar_subtract + end function scalar_add + end interface - module function negate_array(a) result(c) - class(array_type), intent(in), target :: a + + interface operator(-) + module function scalar_subtract(a, b) result(c) + real(real32), intent(in) :: a + class(array_type), intent(in), target :: b type(array_type), pointer :: c - end function negate_array + end function scalar_subtract end interface @@ -482,5 +530,6 @@ module function log_array(a) result(c) type(array_type), pointer :: c end function log_array end interface +!------------------------------------------------------------------------------- end module diffstruc__types diff --git a/src/diffstruc/diffstruc_types_sub.f90 b/src/diffstruc/diffstruc_types_sub.f90 index c718ed8..cf0c552 100644 --- a/src/diffstruc/diffstruc_types_sub.f90 +++ b/src/diffstruc/diffstruc_types_sub.f90 @@ -466,17 +466,13 @@ end function grad_forward !############################################################################### - module subroutine grad_reverse(this, record_graph, reset_graph) + module subroutine grad_reverse(this, reset_graph) !! Perform backward pass starting from this array implicit none class(array_type), intent(inout) :: this - logical, intent(in), optional :: record_graph logical, intent(in), optional :: reset_graph - logical :: record_graph_ - record_graph_ = .true. - if(present(record_graph)) record_graph_ = record_graph if(present(reset_graph))then if(reset_graph) call this%reset_graph() end if @@ -488,7 +484,6 @@ module subroutine grad_reverse(this, record_graph, reset_graph) ! Safely initialise gradient without copying computation graph call this%grad%allocate(array_shape=[this%shape, size(this%val,2)]) this%grad%is_sample_dependent = this%is_sample_dependent - this%grad%requires_grad = record_graph_ this%grad%operation = 'none' this%grad%left_operand => null() this%grad%right_operand => null() @@ -505,11 +500,7 @@ module subroutine grad_reverse(this, record_graph, reset_graph) end if ! Recursively compute gradients - if(record_graph_)then - call reverse_mode_ptr(this, this%grad, 0) - else - call reverse_mode(this, this%grad, 0) - end if + call reverse_mode_ptr(this, this%grad, 0) end subroutine grad_reverse !############################################################################### @@ -638,7 +629,7 @@ end function forward_over_reverse !############################################################################### - module recursive subroutine reverse_mode_ptr(array, upstream_grad, depth) + recursive subroutine reverse_mode_ptr(array, upstream_grad, depth) !! Backward operation for arrays implicit none class(array_type), intent(inout) :: array @@ -679,13 +670,13 @@ end subroutine reverse_mode_ptr recursive subroutine accumulate_gradient_ptr(array, grad, depth) !! Accumulate gradient for array with safe memory management implicit none - type(array_type), intent(inout) :: array + type(array_type), intent(inout), target :: array type(array_type), intent(in), pointer :: grad integer, intent(in) :: depth integer :: s logical :: is_directional - type(array_type), pointer :: directional_grad + type(array_type), pointer :: directional_grad, tmp_ptr is_directional = .false. if(allocated(array%direction))then @@ -740,104 +731,6 @@ end subroutine accumulate_gradient_ptr !############################################################################### -!############################################################################### - module recursive subroutine reverse_mode(array, upstream_grad, depth) - !! Backward operation for arrays - implicit none - class(array_type), intent(inout) :: array - type(array_type), intent(in) :: upstream_grad - integer, intent(in) :: depth - - type(array_type) :: left_partial, right_partial - - ! write(*,'("Performing backward operation for: ",A,T60,"id: ",I0)') & - ! trim(array%operation), array%id - if(depth.gt.max_recursion_depth)then - write(0,*) "MAX RECURSION DEPTH REACHED IN REVERSE MODE", depth - return - end if - array%is_forward = .false. - if(associated(array%left_operand))then - if(array%left_operand%requires_grad)then - left_partial = array%get_partial_left(upstream_grad) - left_partial%is_temporary = .false. - call accumulate_gradient(array%left_operand, left_partial, depth) - end if - end if - if(associated(array%right_operand))then - if(array%right_operand%requires_grad)then - right_partial = array%get_partial_right(upstream_grad) - right_partial%is_temporary = .false. - call accumulate_gradient(array%right_operand, right_partial, depth) - end if - end if - ! write(*,*) "done operation: ", trim(array%operation) - end subroutine reverse_mode -!############################################################################### - - -!############################################################################### - recursive subroutine accumulate_gradient(array, grad, depth) - !! Accumulate gradient for array with safe memory management - implicit none - type(array_type), intent(inout) :: array - type(array_type), intent(inout) :: grad - integer, intent(in) :: depth - - integer :: s - - ! Apply direction if specified (in-place to avoid copy) - if(allocated(array%direction))then - if(size(array%direction).gt.0)then - do s = 1, size(grad%val, 2) - grad%val(:, s) = grad%val(:, s) * array%direction - end do - end if - end if - - if(.not. associated(array%grad))then - ! First gradient accumulation - allocate and set - allocate(array%grad) - if(array%is_sample_dependent)then - call array%grad%allocate(array_shape=[grad%shape,size(grad%val,2)]) - array%grad%val = grad%val - else - call array%grad%allocate(array_shape=[grad%shape,1]) - array%grad%val(:,1) = sum(grad%val, dim = 2) - end if - array%grad%is_scalar = array%is_scalar - array%grad%is_sample_dependent = array%is_sample_dependent - array%grad%requires_grad = .false. - array%grad%owns_gradient = .false. - array%owns_gradient = .true. - array%grad%is_temporary = array%is_temporary - else - ! Accumulate to existing gradient (in-place) - if(array%is_sample_dependent)then - array%grad%val = array%grad%val + grad%val - else - ! rtmp1 = real(size(grad%val,2), real32) - ! ! mean reduction - ! do concurrent(s = 1:size(grad%val,1)) - ! array%grad%val(s,1) = array%grad%val(s,1) + sum(grad%val(s,:)) / rtmp1 - ! end do - ! sum reduction - array%grad%val(:,1) = array%grad%val(:,1) + sum(grad%val, dim = 2) - end if - end if - - if(associated(array%left_operand).or.associated(array%right_operand))then - call reverse_mode(array, grad, depth+1) - end if - if(array%grad%is_temporary)then - call array%grad%nullify_graph() - call array%grad%deallocate() - deallocate(array%grad) - array%owns_gradient = .false. - end if - end subroutine accumulate_gradient -!############################################################################### - !##############################################################################! ! * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * ! @@ -877,7 +770,7 @@ end subroutine zero_all_grads !############################################################################### - module recursive subroutine zero_all_fixed_pointer_grads(this) + recursive subroutine zero_all_fixed_pointer_grads(this) !! Zero the gradients of this array implicit none type(array_type), intent(inout) :: this @@ -933,7 +826,7 @@ subroutine double_map_add(src_map, dst_map, src_ptr, dst_ptr) !! Add pointer pair to double-array map (grow if needed) implicit none type(array_ptr), allocatable :: src_map(:), dst_map(:) - type(array_type), pointer, intent(in) :: src_ptr, dst_ptr + type(array_type), intent(in), target :: src_ptr, dst_ptr integer :: n, i, newcap if(.not. allocated(src_map)) allocate(src_map(default_map_capacity)) if(.not. allocated(dst_map)) allocate(dst_map(default_map_capacity)) @@ -958,7 +851,7 @@ subroutine single_map_add(map, ptr) !! Add pointer to single-array map (grow if needed) implicit none type(array_ptr), allocatable :: map(:) - type(array_type), pointer, intent(in) :: ptr + type(array_type), intent(in), target :: ptr integer :: n, i if(.not. allocated(map))then @@ -981,7 +874,7 @@ function map_find(map, target) result(idx) !! Check if target pointer exists in single-array map implicit none type(array_ptr), allocatable :: map(:) - type(array_type), pointer, intent(in) :: target + type(array_type), intent(in), target :: target integer :: idx, n, i idx = 0 @@ -1273,8 +1166,18 @@ module subroutine detach(this) this%requires_grad = .false. this%operation = 'none' + if(this%owns_left_operand.and.associated(this%left_operand))then + call this%left_operand%deallocate() + deallocate(this%left_operand) + end if + if(this%owns_right_operand.and.associated(this%right_operand))then + call this%right_operand%deallocate() + deallocate(this%right_operand) + end if this%left_operand => null() this%right_operand => null() + this%owns_left_operand = .false. + this%owns_right_operand = .false. end subroutine detach !############################################################################### @@ -1297,7 +1200,7 @@ end subroutine set_requires_grad !############################################################################### - subroutine set_direction(this, direction) + module subroutine set_direction(this, direction) !! Set the direction for the array (for higher-order derivatives) implicit none class(array_type), intent(inout) :: this diff --git a/test/test_functions.f90 b/test/test_functions.f90 index d7a5f77..24cd056 100644 --- a/test/test_functions.f90 +++ b/test/test_functions.f90 @@ -30,7 +30,7 @@ program test_functions ! Compute first derivatives (gradient) call x%set_direction([1._real32, 1._real32]) - call f%grad_reverse( record_graph=.true.) + call f%grad_reverse() write(*,*) "First derivatives (gradient):" if(associated(x%grad)) then write(*,*) " df/dx1 =", x%grad%val(1,1) @@ -61,7 +61,7 @@ program test_functions ! Compute first derivatives (gradient) call x%set_direction([1._real32, 1._real32]) - call f%grad_reverse( record_graph=.true., reset_graph=.true. ) + call f%grad_reverse( reset_graph=.true. ) write(*,*) "First derivatives (gradient):" if(associated(x%grad)) then write(*,*) " df/dx1 =", x%grad%val(1,1) @@ -103,7 +103,7 @@ program test_functions call f%reset_graph() f = mean( x * tanh(x), dim = 1 ) write(*,*) "Function value f =", f%val(:,1) - call f%grad_reverse( record_graph=.true. ) + call f%grad_reverse() write(*,*) "Gradient (first derivatives):" if(associated(x%grad)) then write(*,*) " df/dx1 =", x%grad%val(:,1) @@ -118,7 +118,7 @@ program test_functions ! f = x**4 f = x * x * x!* x * x!tanh(x) !mean( x * tanh(x), dim = 1 ) write(*,*) "Function value f =", f%val(:,1) - call f%grad_reverse( record_graph=.true. ) + call f%grad_reverse() write(*,*) "Gradient (first derivatives):" if(associated(x%grad)) then write(*,*) " df/dx1 =", x%grad%val(:,1) @@ -142,7 +142,7 @@ program test_functions loss = sum(residual ** 2, dim=1) write(*,*) " Loss (should be close to 0):", loss%val(1,1) - call loss%grad_reverse( record_graph=.true. ) + call loss%grad_reverse() write(*,*) "Gradient of loss:" if(associated(x%grad)) then diff --git a/test/test_memory.f90 b/test/test_memory.f90 index 4b59a3c..d378ef6 100644 --- a/test/test_memory.f90 +++ b/test/test_memory.f90 @@ -35,7 +35,7 @@ program test_memory_detailed temp => x**2 + y * x + exp(x * 0.01_real32) call f%assign_and_deallocate_source(temp) f%is_temporary = .false. - call f%grad_reverse(record_graph=.false., reset_graph=.true.) + call f%grad_reverse(reset_graph=.true.) call f%nullify_graph() call f%deallocate() if (mod(i, 10) == 0) then @@ -78,7 +78,7 @@ program test_memory_detailed xgradgrad%is_temporary = .false. end if else - call temp%grad_reverse(record_graph=.true., reset_graph=.true.) + call temp%grad_reverse(reset_graph=.true.) end if ! Explicit cleanup of temp (THIS IS KEY TO AVOIDING LEAKS)