Skip to content

Commit 678b2c0

Browse files
committed
Require passing only out_features to linear2d(); tidy up
1 parent 6f33ebe commit 678b2c0

File tree

6 files changed

+24
-14
lines changed

6 files changed

+24
-14
lines changed

example/linear2d.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ program linear2d_example
1616

1717
net = network([ &
1818
input(3, 4), &
19-
linear2d(3, 1), &
19+
linear2d(1), &
2020
flatten() &
2121
])
2222

src/nf/nf_layer_constructors.f90

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,14 @@ module function reshape(output_shape) result(res)
185185
!! Resulting layer instance
186186
end function reshape
187187

188-
module function linear2d(sequence_length, out_features) result(res)
189-
integer, intent(in) :: sequence_length, out_features
188+
module function linear2d(out_features) result(res)
189+
!! Rank-2 (sequence_length, out_features) linear layer constructor.
190+
!! sequence_length is determined at layer initialization, based on the
191+
!! output shape of the previous layer.
192+
integer, intent(in) :: out_features
193+
!! Number of output features
190194
type(layer) :: res
195+
!! Resulting layer instance
191196
end function linear2d
192197

193198
end interface

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,14 @@ module function reshape(output_shape) result(res)
150150

151151
end function reshape
152152

153-
module function linear2d(sequence_length, out_features) result(res)
154-
integer, intent(in) :: sequence_length, out_features
153+
154+
module function linear2d(out_features) result(res)
155+
integer, intent(in) :: out_features
155156
type(layer) :: res
156157

157158
res % name = 'linear2d'
158-
res % layer_shape = [sequence_length, out_features]
159159
allocate(res % p, source=linear2d_layer(out_features))
160+
160161
end function linear2d
161162

162163
end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ impure elemental module subroutine init(self, input)
301301
call this_layer % init(input % layer_shape)
302302
end select
303303

304-
! The shape of conv2d, maxpool2d, or flatten layers is not known
304+
! The shape of linear2d, conv2d, maxpool2d, or flatten layers is not known
305305
! until we receive an input layer.
306306
select type(this_layer => self % p)
307307
type is(conv2d_layer)
@@ -310,6 +310,8 @@ impure elemental module subroutine init(self, input)
310310
self % layer_shape = shape(this_layer % output)
311311
type is(flatten_layer)
312312
self % layer_shape = shape(this_layer % output)
313+
type is(linear2d_layer)
314+
self % layer_shape = shape(this_layer % output)
313315
end select
314316

315317
self % input_layer_shape = input % layer_shape

src/nf/nf_linear2d_layer.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ module nf_linear2d_layer
1111
type, extends(base_layer) :: linear2d_layer
1212
integer :: sequence_length, in_features, out_features, batch_size
1313

14-
real, allocatable :: weights(:, :)
14+
real, allocatable :: weights(:,:)
1515
real, allocatable :: biases(:)
16-
real, allocatable :: output(:, :)
17-
real, allocatable :: gradient(:, :) ! input gradient
18-
real, allocatable :: dw(:, :) ! weight gradients
16+
real, allocatable :: output(:,:)
17+
real, allocatable :: gradient(:,:) ! input gradient
18+
real, allocatable :: dw(:,:) ! weight gradients
1919
real, allocatable :: db(:) ! bias gradients
2020

2121
contains
@@ -40,13 +40,13 @@ end function linear2d_layer_cons
4040
interface
4141
pure module subroutine forward(self, input)
4242
class(linear2d_layer), intent(in out) :: self
43-
real, intent(in) :: input(:, :)
43+
real, intent(in) :: input(:,:)
4444
end subroutine forward
4545

4646
pure module subroutine backward(self, input, gradient)
4747
class(linear2d_layer), intent(in out) :: self
48-
real, intent(in) :: input(:, :)
49-
real, intent(in) :: gradient(:, :)
48+
real, intent(in) :: input(:,:)
49+
real, intent(in) :: gradient(:,:)
5050
end subroutine backward
5151

5252
module subroutine init(self, input_shape)

src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ module function linear2d_layer_cons(out_features) result(res)
1010
type(linear2d_layer) :: res
1111

1212
res % out_features = out_features
13+
1314
end function linear2d_layer_cons
1415

1516

@@ -34,6 +35,7 @@ module subroutine init(self, input_shape)
3435

3536
allocate(self % dw(self % in_features, self % out_features))
3637
self % dw = 0
38+
3739
allocate(self % db(self % out_features))
3840
self % db = 0
3941

0 commit comments

Comments
 (0)