residual_attention_layer and components_test

aguzel committed Apr 10, 2023
1 parent f049cfc commit 68d54a6
Showing 2 changed files with 97 additions and 0 deletions.
67 changes: 67 additions & 0 deletions odak/learn/models/components.py
@@ -236,3 +236,70 @@ def forward(self, x):
mean = torch.mean(x, dim = 1, keepdim = True)
result = (x - mean) * (var + eps).rsqrt() * self.k
return result
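# Context: the forward above applies a layer-norm-style transform,
# result = k * (x - mean) / sqrt(var + eps); var and eps are assumed
# to be computed earlier in forward, outside the lines shown in this hunk.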

class residual_attention_layer(torch.nn.Module):
"""
A residual attention layer.
"""
def __init__(
self,
input_channels = 2,
output_channels = 2,
kernel_size = 3,
bias = False,
activation = torch.nn.ReLU()
):
"""
A residual attention layer class.
Parameters
----------
input_channels : int
Number of input channels.
output_channels : int
Number of output channels.
kernel_size : int
Kernel size.
bias : bool
Set to True to let convolutional layers have bias term.
activation : torch.nn
Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
"""
super().__init__()
self.activation = torch.nn.ReLU() if isinstance(activation, type(None)) else activation
self.convolution = torch.nn.Sequential(
torch.nn.Conv2d(
input_channels,
output_channels,
kernel_size = kernel_size,
padding = kernel_size // 2,
bias = bias
),
torch.nn.BatchNorm2d(output_channels),
)


def forward(self, x_1, x_2):
"""
Forward model.
Parameters
----------
x_1 : torch.tensor
First input data.
x_2 : torch.tensor
Second input data.
Returns
-------
result : torch.tensor
Estimated output.
"""
x_1_out = self.convolution(x_1)
x_2_out = self.convolution(x_2)
result = self.activation(x_1_out + x_2_out) * x_1
return result
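
A minimal usage sketch of the new layer (not part of the diff); the shapes mirror the test below, and both inputs are assumed to share the default channel count of 2:

import torch
import odak.learn.models.components as components

x_1 = torch.randn(1, 2, 32, 32) # first input; also gates the output
x_2 = torch.randn(1, 2, 32, 32) # second input, same shape as x_1
layer = components.residual_attention_layer(
                                            input_channels = 2,
                                            output_channels = 2,
                                            kernel_size = 3
                                           )
y = layer(x_1, x_2) # activation(conv(x_1) + conv(x_2)) * x_1
print(y.shape) # torch.Size([1, 2, 32, 32])

Note that the same convolution block (shared weights, followed by batch normalization) processes both inputs before their sum is activated and used to scale x_1.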

30 changes: 30 additions & 0 deletions test/test_learn_components.py
@@ -0,0 +1,30 @@
import odak.learn.models.components as components
import torch
import sys

def test():
# test residual block
x = torch.randn(1, 2, 32, 32)
residual_inference = components.residual_layer()
y = residual_inference(x)
print(y.shape)
# test convolution layer
convolution_inference = components.convolution_layer()
y = convolution_inference(x)
print(y.shape)
# test double convolution layer
double_convolution_inference = components.double_convolution()
y = double_convolution_inference(x)
print(y.shape)
# test normalization layer
normalization_inference = components.normalization()
y = normalization_inference(x)
print(y.shape)
# test attention layer
residual_attention_layer_inference = components.residual_attention_layer()
y = residual_attention_layer_inference(x , x)
print(y.shape)
assert y.shape == x.shape

if __name__ == '__main__':
sys.exit(test())
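
The test doubles as a smoke check: it can be run directly with python test/test_learn_components.py or collected by pytest, assuming odak is importable from the working environment.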
