4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -56,8 +56,8 @@ add_library(neural-fortran
src/nf/nf_maxpool2d_layer.f90
src/nf/nf_maxpool2d_layer_submodule.f90
src/nf/nf_metrics.f90
- src/nf/nf_multihead_attention.f90
- src/nf/nf_multihead_attention_submodule.f90
+ src/nf/nf_multihead_attention_layer.f90
+ src/nf/nf_multihead_attention_layer_submodule.f90
src/nf/nf_network.f90
src/nf/nf_network_submodule.f90
src/nf/nf_optimizers.f90
14 changes: 5 additions & 9 deletions src/nf/nf_cross_attention_layer.f90
@@ -20,22 +20,18 @@ module nf_cross_attention_layer
end type cross_attention_layer

interface cross_attention_layer
- module function cross_attention_layer_cons(n_heads) result(res)
- !! This function returns the `cross_attention_layer` instance.
- integer, intent(in) :: sequence_length, model_dimension, n_heads
- type(cross_attention_layer) :: res
- end function cross_attention_layer_cons
+ module procedure cross_attention_layer_cons
end interface cross_attention_layer

contains
- module function cross_attention_layer_cons(n_heads) result(res)
+ function cross_attention_layer_cons(n_heads) result(res)
!! This function returns the `cross_attention_layer` instance.
integer, intent(in) :: n_heads
type(cross_attention_layer) :: res
res % n_heads = n_heads
end function cross_attention_layer_cons

- pure module subroutine backward(self, input, gradient)
+ pure subroutine backward(self, input, gradient)
!! Cross Attention Back propagation
class(cross_attention_layer), intent(in out) :: self
real, intent(in) :: input(:, :, :)
@@ -46,7 +42,7 @@ pure module subroutine backward(self, input, gradient)
self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
end subroutine backward

- pure module subroutine forward(self, input)
+ pure subroutine forward(self, input)
!! Cross Attention Forward propagation
!! Input Shape (kind, sequence_length, model_dimension)
!! where kind is 1 for Query and 2 for Key-Value
@@ -56,7 +52,7 @@ pure module subroutine forward(self, input)
call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
end subroutine forward

- module subroutine init(self, input_shape)
+ subroutine init(self, input_shape)
class(cross_attention_layer), intent(in out) :: self
integer, intent(in) :: input_shape(:)

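For context, here is a minimal standalone sketch (not part of this PR, module name `demo_layer_m` is hypothetical) of the Fortran pattern applied in this file: the constructor becomes an ordinary contained procedure, and the generic interface refers to it with a `module procedure` statement instead of repeating its full interface body, since the `module function`/`module subroutine` forms belong with implementations deferred to a submodule.

! Minimal sketch, assuming nothing beyond standard Fortran 2008:
! a derived type whose generic constructor interface points at a
! contained procedure rather than declaring a separate interface body.
module demo_layer_m
  implicit none
  private
  public :: demo_layer

  type :: demo_layer
    integer :: n_heads = 0
  end type demo_layer

  interface demo_layer
    ! Resolves demo_layer(n_heads) to the contained constructor below.
    module procedure demo_layer_cons
  end interface demo_layer

contains

  function demo_layer_cons(n_heads) result(res)
    !! Returns a `demo_layer` instance with the requested head count.
    integer, intent(in) :: n_heads
    type(demo_layer) :: res
    res % n_heads = n_heads
  end function demo_layer_cons

end module demo_layer_m

Because call sites go through the generic interface, moving a procedure out of a submodule and dropping the `module` prefix does not change how the layer is constructed.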
14 changes: 5 additions & 9 deletions src/nf/nf_self_attention_layer.f90
@@ -20,22 +20,18 @@ module nf_self_attention_layer
end type self_attention_layer

interface self_attention_layer
- module function self_attention_layer_cons(n_heads) result(res)
- !! This function returns the `self_attention_layer` instance.
- integer, intent(in) :: n_heads
- type(self_attention_layer) :: res
- end function self_attention_layer_cons
+ module procedure self_attention_layer_cons
end interface self_attention_layer

contains
- module function self_attention_layer_cons(n_heads) result(res)
+ function self_attention_layer_cons(n_heads) result(res)
!! This function returns the `self_attention_layer` instance.
integer, intent(in) :: n_heads
type(self_attention_layer) :: res
res % n_heads = n_heads
end function self_attention_layer_cons

- pure module subroutine backward(self, input, gradient, attention_mask)
+ pure subroutine backward(self, input, gradient, attention_mask)
!! Self Attention back propagation
!! Returns sum of Query, Key and Value gradients
class(self_attention_layer), intent(in out) :: self
@@ -50,7 +46,7 @@ pure module subroutine backward(self, input, gradient, attention_mask)
+ self % value_layer % gradient
end subroutine backward

- pure module subroutine forward(self, input)
+ pure subroutine forward(self, input)
!! Cross Attention forward propagation
!! Passes input three times into MultiHead Attention
!! Input Shape: (sequence_length, model_dimension)
@@ -60,7 +56,7 @@ pure module subroutine forward(self, input)
call self % common_forward(input, input, input)
end subroutine forward

- module subroutine init(self, input_shape)
+ subroutine init(self, input_shape)
class(self_attention_layer), intent(in out) :: self
integer, intent(in) :: input_shape(:)

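And a hypothetical usage sketch (not from this PR), assuming the types, their generic constructors, and the `n_heads` component are public in the modules shown above, to illustrate that callers are unaffected by this refactor:

program attention_usage_demo
  ! Assumes the public generic constructors shown in the diffs above.
  use nf_self_attention_layer, only: self_attention_layer
  use nf_cross_attention_layer, only: cross_attention_layer
  implicit none

  type(self_attention_layer)  :: sa
  type(cross_attention_layer) :: ca

  ! Both constructors are reached through their generic interfaces,
  ! so nothing changes for callers when the implementations move
  ! out of submodules into the modules themselves.
  sa = self_attention_layer(n_heads=8)
  ca = cross_attention_layer(n_heads=8)

  print *, sa % n_heads, ca % n_heads   ! assumes n_heads is an accessible component
end program attention_usage_demo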