In [1]:
import torch
from torch import nn
import pytorch_lightning as pl

In [89]:
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lr = nn.LeakyReLU()
        self.conv1d = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

        self.mp = nn.MaxPool1d(kernel_size=2)
        pass

    def forward(self, x):
        x = self.conv1d(x)
        x = self.lr(x)

        ctx = torch.clone(x)
        out = self.mp(x)       
        print("Down Sampling")
        print("out {} ! ctx {}".format(out.shape, ctx.shape))

        return out, ctx

class UBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

        self.up_sampling = nn.Sequential(
            nn.Upsample(scale_factor=2),
            self.conv1d
        )
        
        self.lr = nn.LeakyReLU()


    def add_ctx(self, x, ctx):
        # crop context (y)
        d_shape = [(ctx.shape[i]-x.shape[i])//2 for i in range(len(x.shape))]
        crop = ctx[d_shape[0]:d_shape[0]+x.shape[0],
            d_shape[1]:d_shape[1]+x.shape[1],
            d_shape[2]:d_shape[2]+x.shape[2],
        ]
        #concatenate
        out = torch.cat([x, crop], 1) 
        return out



    def forward(self, x, ctx):
        print("Upsampling")
        print("x shape {} ! ctx {}".format(x.shape, ctx.shape))
        x = self.up_sampling(x)
        out = self.add_ctx(x, ctx)
        x = self.conv1d(x)
        x = self.lr(x)

        return out




In [93]:
class UNet(pl.LightningModule):
    def __init__(self, channels):
        super().__init__()
        self.save_hyperparameters()
        self.channels = channels
        self.mid_conv = nn.Conv1d(in_channels=channels[-2],
                                out_channels=channels[-1],
                                kernel_size=3,
                                stride=1,
                                padding=1)

        
    def down_sampling(self, x):
        l_ctx = []
        for i in range(len(self.channels)-1):
            x, ctx = DBlock(in_channels=self.channels[i], out_channels=self.channels[i+1])(x)
            l_ctx.append(ctx)
        return x, l_ctx
    
    def up_sampling(self, x, l_ctx):
        for i in range(len(self.channels)-1, 0, -1):
            print(i)
            print("in_channels={}, out_channels={}".format(self.channels[i], self.channels[i-1]))
            x = UBlock(in_channels=self.channels[i], out_channels=self.channels[i-1])(x, l_ctx[i-1])
        
        return x


    def forward(self, x):
        x, l_ctx = self.down_sampling(x)

        x = self.mid_conv(x)

        out = self.up_sampling(x, l_ctx)
        return out


In [94]:
t = torch.randn(2, 2, 10)


unet = UNet([2,4,8])
out = unet(t)


Down Sampling
x shape torch.Size([2, 4, 5]) ! ctx torch.Size([2, 4, 10])
1
in_channels=4, out_channels=2


RuntimeError: Given groups=1, weight of size [2, 4, 3], expected input[2, 8, 10] to have 4 channels, but got 8 channels instead