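# Experimenting with a spectral U-Net

A small U-Net whose convolutions (1×3 kernels) and pooling (1×2 windows) act only along the last axis of the input, leaving the second-to-last axis untouched.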

In [1]:
# imports
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary

## Generate a random 1×1×64×64 input tensor

In [2]:
SEED = 0
# Seed the global RNG so the input (and the weight init in later cells) is reproducible.
torch.manual_seed(SEED)
# NCHW layout as consumed by nn.Conv2d: (batch, channel, height, width).
# NOTE(review): the name `input` shadows the Python builtin; kept because later cells use it.
input = torch.rand(1, 1, 64, 64)

In [3]:
input.size()  # same torch.Size as `.shape`

torch.Size([1, 1, 64, 64])

## Define a U-Net that convolves along the last axis only

In [10]:
def double_conv(in_channels, out_channels):
    """Two (Conv2d -> ReLU) stages that convolve only along the last axis.

    Each Conv2d uses a (1, 3) kernel with (0, 1) padding, so the spatial size
    of the input is preserved and only neighbouring entries along the last
    axis are mixed.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both convolutions.

    Returns:
        nn.Sequential of [Conv2d, ReLU, Conv2d, ReLU].
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, (1, 3), padding=(0, 1)),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, (1, 3), padding=(0, 1)),
        nn.ReLU(inplace=True),
    )


class xUNet(nn.Module):
    """Two-level U-Net whose convolutions and pooling act on the last axis only.

    The encoder halves the last axis twice with (1, 2) max-pooling. The decoder
    doubles it back with bilinear upsampling and concatenates the matching
    encoder features (skip connections). The second-to-last axis is never
    resized. NOTE(review): presumably the last axis is the spectral one; confirm.

    Args:
        in_channels: channels of the input tensor (default 1).
        out_channels: channels of the output tensor (default 1).

    Shapes:
        Assumes a batched 4-D input (N, in_channels, H, W) with W >= 4; returns
        (N, out_channels, H, W). W need not be divisible by 4, because upsampled
        features are zero-padded to their skip connection's size before
        concatenation.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder: 16 -> 32 -> 64 feature channels.
        self.dconv_down1 = double_conv(in_channels, 16)
        self.dconv_down2 = double_conv(16, 32)
        self.dconv_down3 = double_conv(32, 64)
        # Halves only the last axis (floor division for odd lengths).
        self.maxpool = nn.MaxPool2d((1, 2))
        # Doubles only the last axis.
        self.upsample = nn.Upsample(scale_factor=(1, 2), mode='bilinear', align_corners=True)
        # Decoder: each input is upsampled features concatenated with the skip.
        self.dconv_up2 = double_conv(32 + 64, 32)
        self.dconv_up1 = double_conv(16 + 32, 16)
        # 1x1 conv projecting the 16 decoder features to the output channels.
        self.conv_last = nn.Conv2d(16, out_channels, 1)

    @staticmethod
    def _match_size(x, ref):
        """Zero-pad x on its last two axes so they match ref's last two axes."""
        # Pooling floors odd lengths, so 2 * (W // 2) can be one short of W.
        dh = ref.size(-2) - x.size(-2)
        dw = ref.size(-1) - x.size(-1)
        if dh or dw:
            # F.pad order: (left, right, top, bottom) for the last two axes.
            x = F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return x

    def forward(self, x):
        # Encoder, keeping pre-pool features for the skip connections.
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        x = self.dconv_down3(x)

        # Decoder: upsample, align with the skip, concatenate on channels.
        x = self._match_size(self.upsample(x), conv2)
        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = self._match_size(self.upsample(x), conv1)
        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        return self.conv_last(x)

In [11]:
# Fresh model with randomly initialised weights (default: 1 channel in, 1 out).
xunet: nn.Module = xUNet()

In [12]:
# torchsummary defaults to device="cuda" and builds CUDA tensors whenever CUDA is
# available, which clashes with this CPU-resident model; pin it to CPU.
# input_size excludes the batch axis, so take it from the actual input tensor.
summary(xunet, input_size=tuple(input.shape[1:]), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 64, 64]              64
              ReLU-2           [-1, 16, 64, 64]               0
            Conv2d-3           [-1, 16, 64, 64]             784
              ReLU-4           [-1, 16, 64, 64]               0
         MaxPool2d-5           [-1, 16, 64, 32]               0
            Conv2d-6           [-1, 32, 64, 32]           1,568
              ReLU-7           [-1, 32, 64, 32]               0
            Conv2d-8           [-1, 32, 64, 32]           3,104
              ReLU-9           [-1, 32, 64, 32]               0
        MaxPool2d-10           [-1, 32, 64, 16]               0
           Conv2d-11           [-1, 64, 64, 16]           6,208
             ReLU-12           [-1, 64, 64, 16]               0
           Conv2d-13           [-1, 64, 64, 16]          12,352
             ReLU-14           [-1, 64,



In [12]:
# Inference only: no autograd graph is needed for a shape check.
with torch.no_grad():
    output = xunet(input)
# With the default 1-in/1-out channels the U-Net preserves the input shape.
assert output.shape == input.shape, "U-Net should preserve the input shape"
output.shape



> <ipython-input-10-fa67e0d1d38a>(42)forward()
-> x = torch.cat([x, conv2], dim=1)
(Pdb) x.shape
torch.Size([1, 64, 64, 32])
(Pdb) conv2.shape
torch.Size([1, 32, 64, 32])
(Pdb) c


torch.Size([1, 1, 64, 64])