Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
[![Latest Release](https://img.shields.io/github/v/release/nedtaylor/diffstruc?sort=semver)](https://github.com/nedtaylor/diffstruc/releases "View on GitHub")
[![FPM](https://img.shields.io/badge/fpm-0.12.0-purple)](https://github.com/fortran-lang/fpm "View Fortran Package Manager")
[![GCC compatibility](https://img.shields.io/badge/gcc-15.2.0-green)](https://gcc.gnu.org/gcc-15/ "View GCC")
[![IFX compatibility](https://img.shields.io/badge/ifx-2025.2.0-green)](https://www.intel.com/content/www/us/en/developer/tools/oneapi/fortran-compiler.html "View ifx")

# diffstruc

Expand Down Expand Up @@ -52,7 +53,11 @@ The library has the following dependencies
- [fpm](https://github.com/fortran-lang/fpm)

The library has been developed and tested using the following compilers:
- gfortran -- gcc 15.1.0
- gfortran -- gcc 15.2.0
- ifx -- ifx 2025.2.0

> **_NOTE:_** diffstruc is known to be incompatible with all versions of the gfortran compiler below `14.3.0` due to issues with the calling of the `final` procedure of `array_type`.



### Building with fpm
Expand Down
10 changes: 5 additions & 5 deletions src/diffstruc/diffstruc_operations_linalg_sub.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
contains

!###############################################################################
function matmul_arrays(a, b) result(c)
module function matmul_arrays(a, b) result(c)
!! Matrix multiplication of two autodiff arrays
implicit none
class(array_type), intent(in), target :: a, b
Expand Down Expand Up @@ -54,7 +54,7 @@ function matmul_arrays(a, b) result(c)
end if
end function matmul_arrays
!-------------------------------------------------------------------------------
function matmul_real2d(a, b) result(c)
module function matmul_real2d(a, b) result(c)
!! Matrix multiplication of a real array and an autodiff array
implicit none
class(array_type), intent(in), target :: a
Expand Down Expand Up @@ -91,7 +91,7 @@ function matmul_real2d(a, b) result(c)
c%owns_right_operand = .true.
end function matmul_real2d
!-------------------------------------------------------------------------------
function real2d_matmul(a, b) result(c)
module function real2d_matmul(a, b) result(c)
!! Matrix multiplication of two autodiff arrays
implicit none
real(real32), dimension(:,:), intent(in) :: a
Expand Down Expand Up @@ -195,7 +195,7 @@ end function get_partial_matmul_right


!###############################################################################
function outer_product_arrays(a, b) result(c)
module function outer_product_arrays(a, b) result(c)
!! Outer product of two autodiff arrays
implicit none
class(array_type), intent(in), target :: a, b
Expand Down Expand Up @@ -265,7 +265,7 @@ end function get_partial_outer_product_right


!###############################################################################
function transpose_array(a) result(c)
module function transpose_array(a) result(c)
!! Transpose an autodiff array
implicit none
class(array_type), intent(in), target :: a
Expand Down
129 changes: 89 additions & 40 deletions src/diffstruc/diffstruc_types.f90
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ module diffstruc__types
public :: operator(+), operator(-), operator(*), operator(/), operator(**)
public :: sum, mean, spread, unspread, exp, log

type array_ptr
type(array_type), pointer :: p => null()
end type array_ptr


!-------------------------------------------------------------------------------
! Automatic differentiation derived type
!-------------------------------------------------------------------------------
Expand Down Expand Up @@ -53,17 +52,28 @@ module diffstruc__types
type(array_type), pointer :: right_operand => null()
!! Right operand for backward pass
character(len=32) :: operation = 'none'
!! Name of the operation that created this array
logical :: owns_left_operand = .false.
!! Boolean whether this array owns its left operand
logical :: owns_right_operand = .false.
!! Boolean whether this array owns its right operand
logical :: owns_gradient = .true.
!! Flag indicating if this array owns its gradient memory
!! Boolean whether this array owns its gradient memory
logical :: fix_pointer = .false.
!! Boolean to prevent pointer changes during graph operations
!! ... i.e. during duplicate_graph or nullify_graph, this array will not be
!! ... duplicated or nullified
logical :: is_temporary = .true.
!! Boolean indicating if array is temporary

real(real32), dimension(:), allocatable :: direction
!! Direction vector for forward-mode differentiation
!! ... i.e. derivative wrt this direction will be computed

procedure(get_partial), pass(this), pointer :: get_partial_left => null()
!! Pointer procedure for getting partial derivative wrt left operand
procedure(get_partial), pass(this), pointer :: get_partial_right => null()
!! Pointer procedure for getting partial derivative wrt right operand

contains
procedure, pass(this) :: allocate => allocate_array
Expand All @@ -76,36 +86,54 @@ module diffstruc__types
!! Procedure for shallow assignment of array
procedure, pass(this) :: assign_and_deallocate_source
!! Procedure for assigning and deallocating source array
procedure :: assign => assign_array
generic, public :: assignment(=) => assign
!! Overloaded assignment operator

procedure, pass(this) :: set => set_array
!! Procedure for setting array
procedure, pass(this) :: extract => extract_array
!! Procedure for extracting array as a standard real array

procedure, pass(this) :: set_direction
!! Set the direction vector for forward-mode differentiation
procedure, pass(this) :: set_requires_grad
!! Set whether gradients are required
procedure, pass(this) :: grad_reverse
!! Reverse-mode: accumulate gradients wrt all inputs
procedure, pass(this) :: grad_forward
!! Forward-mode: return derivative wrt variable pointer

!! Backward pass for gradient computation
procedure, pass(this) :: zero_grad
!! Zero the gradient
procedure, pass(this) :: zero_all_grads
!! Zero the gradients
procedure, pass(this) :: reset_graph
!! Reset the gradient graph of this array
procedure, pass(this) :: duplicate_graph
!! Duplicate the computation graph and return pointer to new graph
procedure, pass(this) :: nullify_graph
!! Nullify the computation graph
procedure, pass(this) :: get_ptr_from_id
!! Get pointer to array in graph by its ID
procedure, pass(this) :: detach
!! Detach from computation graph
procedure, pass(this) :: set_requires_grad
!! Set requires_grad flag
procedure, pass(this) :: create_result => create_result_array
!! Helper to safely create result arrays

procedure, pass(this) :: print_graph
!! Print the computation graph

procedure :: add_arrays, add_real1d, add_real2d, add_scalar
generic, public :: operator(+) => &
add_arrays, add_real1d, add_real2d, add_scalar
!! Overloaded addition operator

procedure :: subtract_arrays, subtract_real1d, subtract_scalar, negate_array
generic, public :: operator(-) => &
subtract_arrays, subtract_real1d, subtract_scalar, negate_array
!! Overloaded subtraction operator

procedure :: assign => assign_array
generic, public :: assignment(=) => assign
!! Overloaded assignment operator

final :: finalise_array
!! Finaliser for array type
Expand Down Expand Up @@ -179,9 +207,8 @@ module function grad_forward(this, variable) result(output)
type(array_type), pointer :: output
end function grad_forward

module subroutine grad_reverse(this, record_graph, reset_graph)
module subroutine grad_reverse(this, reset_graph)
class(array_type), intent(inout) :: this
logical, intent(in), optional :: record_graph
logical, intent(in), optional :: reset_graph
end subroutine grad_reverse
end interface
Expand Down Expand Up @@ -258,9 +285,23 @@ module subroutine print_graph(this)
class(array_type), intent(in) :: this
end subroutine print_graph
end interface
!-------------------------------------------------------------------------------


!-------------------------------------------------------------------------------
! Pointer wrapper type
!-------------------------------------------------------------------------------
type :: array_ptr
type(array_type), pointer :: p => null()
end type array_ptr
!-------------------------------------------------------------------------------


! Operation interfaces
!-------------------------------------------------------------------------------
! Operation interfaces
!-------------------------------------------------------------------------------

! Arithmetic reduction interfaces that are directly overloaded into array_type
!-----------------------------------------------------------------------------
interface mean
module function mean_array(a, dim) result(c)
Expand Down Expand Up @@ -310,7 +351,9 @@ end function unspread_array
end interface


interface operator(+)
! Arithmetic operator interfaces that are directly overloaded into array_type
!-----------------------------------------------------------------------------
interface
module function add_arrays(a, b) result(c)
class(array_type), intent(in), target :: a, b
type(array_type), pointer :: c
Expand All @@ -322,39 +365,18 @@ module function add_real2d(a, b) result(c)
type(array_type), pointer :: c
end function add_real2d

module function real2d_add(a, b) result(c)
real(real32), dimension(:,:), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function real2d_add

module function add_real1d(a, b) result(c)
class(array_type), intent(in), target :: a
real(real32), dimension(:), intent(in) :: b
type(array_type), pointer :: c
end function add_real1d

module function real1d_add(a, b) result(c)
real(real32), dimension(:), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function real1d_add

module function add_scalar(a, b) result(c)
class(array_type), intent(in), target :: a
real(real32), intent(in) :: b
type(array_type), pointer :: c
end function add_scalar

module function scalar_add(a, b) result(c)
real(real32), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function scalar_add
end interface


interface operator(-)
module function subtract_arrays(a, b) result(c)
class(array_type), intent(in), target :: a, b
type(array_type), pointer :: c
Expand All @@ -372,16 +394,42 @@ module function subtract_scalar(a, b) result(c)
type(array_type), pointer :: c
end function subtract_scalar

module function scalar_subtract(a, b) result(c)
module function negate_array(a) result(c)
class(array_type), intent(in), target :: a
type(array_type), pointer :: c
end function negate_array
end interface


! Arithmetic operator interfaces that are not directly overloaded into array_type
interface operator(+)
module function real2d_add(a, b) result(c)
real(real32), dimension(:,:), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function real2d_add

module function real1d_add(a, b) result(c)
real(real32), dimension(:), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function real1d_add


module function scalar_add(a, b) result(c)
real(real32), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function scalar_subtract
end function scalar_add
end interface

module function negate_array(a) result(c)
class(array_type), intent(in), target :: a

interface operator(-)
module function scalar_subtract(a, b) result(c)
real(real32), intent(in) :: a
class(array_type), intent(in), target :: b
type(array_type), pointer :: c
end function negate_array
end function scalar_subtract
end interface


Expand Down Expand Up @@ -482,5 +530,6 @@ module function log_array(a) result(c)
type(array_type), pointer :: c
end function log_array
end interface
!-------------------------------------------------------------------------------

end module diffstruc__types
Loading