In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import unittest


In [25]:
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels//8, 1)
        self.key = nn.Conv2d(in_channels, in_channels//8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        query = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)
        key = self.key(x).view(batch_size, -1, height * width)
        attention = F.softmax(torch.bmm(query, key), dim=-1)
        value = self.value(x).view(batch_size, -1, height * width)
        out = torch.bmm(value, attention.permute(0, 2, 1)).view(batch_size, channels, height, width)
        out = self.gamma * out + x

        return out

In [26]:
class TestSelfAttention(unittest.TestCase):
    def test_forward_pass(self):
        batch_size = 4
        in_channels = 16
        height = 32
        width = 32

        # Instantiate the SelfAttention module
        self_attention = SelfAttention(in_channels)

        # Create a random input tensor of shape (batch_size, in_channels, height, width)
        x = Variable(torch.rand(batch_size, in_channels, height, width))

        # Perform a forward pass through the SelfAttention module
        out = self_attention(x)

        # Check if the output has the same shape as the input tensor
        self.assertEqual(out.shape, (batch_size, in_channels, height, width))

    def test_gamma_parameter(self):
        in_channels = 16
        self_attention = SelfAttention(in_channels)
        gamma = self_attention.gamma.item()

        # Check if the initial value of gamma is 0
        self.assertEqual(gamma, 0)

    def test_query_key_value_shapes(self):
        in_channels = 16
        self_attention = SelfAttention(in_channels)

        # Check if the query, key, and value layers have the correct shapes
        self.assertEqual(self_attention.query.weight.shape, (in_channels // 8, in_channels, 1, 1))
        self.assertEqual(self_attention.key.weight.shape, (in_channels // 8, in_channels, 1, 1))
        self.assertEqual(self_attention.value.weight.shape, (in_channels, in_channels, 1, 1))


if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

...
----------------------------------------------------------------------
Ran 3 tests in 0.013s

OK


In [19]:
class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

In [20]:
class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

In [21]:
class AttentionResidualWaveUNet(nn.Module):
    def __init__(self, in_channels=1):
        super(AttentionResidualWaveUNet, self).__init__()

        # Encoder
        self.enc1 = EncoderBlock(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = EncoderBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = EncoderBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        # Middle
        self.middle = EncoderBlock(256, 512)
        self.attention = SelfAttention(512)

        # Decoder
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DecoderBlock(256*2, 128)
        self.up2 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2, output_padding=1)  # Add output_padding
        self.dec2 = DecoderBlock(128*2, 64)
        self.up1 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dec1 = DecoderBlock(64*2, 1)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        pool1 = self.pool1(enc1)
        enc2 = self.enc2(pool1)
        pool2 = self.pool2(enc2)
        enc3 = self.enc3(pool2)
        pool3 = self.pool3(enc3)

        # Middle
        middle = self.middle(pool3)
        attention_out = self.attention(middle)

        # Decoder
        up3 = self.up3(attention_out)
        merge3 = torch.cat([up3, enc3], dim=1)
        dec3 = self.dec3(merge3)
        up2 = self.up2(dec3)
        merge2 = torch.cat([up2, enc2], dim=1)
        dec2 = self.dec2(merge2)
        up1 = self.up1(dec2)
        merge1 = torch.cat([up1, enc1], dim=1)
        dec1 = self.dec1(merge1)

        return dec1

In [22]:
# Test the model
if __name__ == "__main__":
    model = AttentionResidualWaveUNet()
    x = torch.randn(1, 1, 128, 128)
    y = model(x)
    print(y.shape)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 65 but got size 64 for tensor number 1 in the list.