In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
torch.cuda.is_available()

True

In [11]:
class StandardBlock(nn.Module):
    def __init__(self, channels, dilation):
        super(StandardBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)
        self.elu1 = nn.ELU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)
        self.elu2 = nn.ELU()

    def forward(self, x):
        x = self.elu1(self.conv1(x))
        x = self.elu2(self.conv2(x))
        return x

In [12]:
class ResidualBlock(nn.Module):
    def __init__(self, channels, dilation):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(channels)
        self.elu1 = nn.ELU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(channels)
        self.elu2 = nn.ELU()
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)

    def forward(self, x):
        identity = x
        out = self.elu1(self.bn1(self.conv1(x)))
        out = self.elu2(self.bn2(self.conv2(out)))
        out = self.conv3(out)
        return out + identity

In [18]:
class OCTDenoisingNet(nn.Module):
    def __init__(self):
        super(OCTDenoisingNet, self).__init__()
        self.init_channels = 64

        # Downsampling Tower
        self.down1 = StandardBlock(self.init_channels, dilation=1)
        self.conv_down1 = nn.Conv2d(self.init_channels, self.init_channels, kernel_size=3, stride=2, padding=1)

        self.down2 = ResidualBlock(self.init_channels, dilation=2)
        self.conv_down2 = nn.Conv2d(self.init_channels, self.init_channels, kernel_size=3, stride=2, padding=1)

        self.down3 = ResidualBlock(self.init_channels, dilation=4)
        self.conv_down3 = nn.Conv2d(self.init_channels, self.init_channels, kernel_size=3, stride=2, padding=1)

        # Bottleneck (latent space)
        self.bottleneck = StandardBlock(self.init_channels, dilation=1)

        # Upsampling Tower
        self.up1 = ResidualBlock(self.init_channels, dilation=4)
        self.tconv1 = nn.ConvTranspose2d(self.init_channels, self.init_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.up2 = ResidualBlock(self.init_channels, dilation=4)
        self.tconv2 = nn.ConvTranspose2d(self.init_channels, self.init_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.up3 = StandardBlock(self.init_channels, dilation=1)
        self.tconv3 = nn.ConvTranspose2d(self.init_channels, self.init_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        # Multi-scale fusion layers
        self.fuse_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.init_channels, self.init_channels, kernel_size=1),
                nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
            ) for _ in range(4)
        ])

        # Final Output
        self.final_conv = nn.Conv2d(self.init_channels * 5, 1, kernel_size=1)

        self.tanh = nn.Tanh()

    def forward(self, x):
        skips = []

        # Downsampling
        x1 = self.down1(x)
        skips.append(x1)
        x = self.conv_down1(x1)

        x2 = self.down2(x)
        skips.append(x2)
        x = self.conv_down2(x2)

        x3 = self.down3(x)
        skips.append(x3)
        x = self.conv_down3(x3)

        # Bottleneck
        x4 = self.bottleneck(x)
        skips.append(x4)

        # Upsampling
        x = self.up1(x4)
        x = self.tconv1(x)
        x = x + F.interpolate(skips[2], size=x.shape[2:], mode='bilinear', align_corners=False)

        x = self.up2(x)
        x = self.tconv2(x)
        x = x + F.interpolate(skips[1], size=x.shape[2:], mode='bilinear', align_corners=False)

        x = self.up3(x)
        x = self.tconv3(x)
        x = x + F.interpolate(skips[0], size=x.shape[2:], mode='bilinear', align_corners=False)

        # Multi-scale fusion
        fused = [F.interpolate(self.fuse_convs[i](skips[i]), size=x.shape[2:], mode='bilinear', align_corners=False) for i in range(4)]
        x = torch.cat([x] + fused, dim=1)

        # Output layer
        x = self.final_conv(x)
        x = self.tanh(x)
        return x

In [19]:
model = OCTDenoisingNet()
model.eval()
dummy_input = torch.randn(1, 64, 496, 384)  # Simulated input with 64 feature channels
output = model(dummy_input)

In [20]:
output

tensor([[[[-0.0765, -0.0318, -0.0448,  ..., -0.0875, -0.0102, -0.0804],
          [-0.0575, -0.0031, -0.1226,  ..., -0.0551,  0.0615,  0.0162],
          [ 0.0191,  0.0627,  0.0969,  ..., -0.0339,  0.0275, -0.0084],
          ...,
          [ 0.0656, -0.1134, -0.0059,  ...,  0.0018, -0.1280, -0.0882],
          [-0.1242,  0.0196, -0.0886,  ..., -0.1279, -0.0266,  0.0512],
          [-0.0792, -0.0649, -0.1312,  ...,  0.1052, -0.0417, -0.0307]]]],
       grad_fn=<TanhBackward0>)