In [1]:
import torch
from torch import nn
import pytorch_lightning as pl

In [108]:
class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lr = nn.LeakyReLU()
        self.conv1d = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

        self.mp = nn.MaxPool1d(kernel_size=2)
        pass

    def forward(self, x):
        print("DownBlock")
        print("x shape {}".format(x.shape))
        x = self.conv1d(x)
        x = self.lr(x)

        ctx = torch.clone(x)
        out = self.mp(x)       
        
        print("out {} ! ctx {}".format(out.shape, ctx.shape))

        return out, ctx

class UBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)

        self.up_sampling = nn.Sequential(
            nn.Upsample(scale_factor=2),
        )
        
        self.lr = nn.LeakyReLU()


    def add_ctx(self, x, ctx):
        # crop context (y)
        d_shape = (ctx.shape[-1]-x.shape[-1])//2 
        crop = ctx[:,:, d_shape:d_shape+x.shape[2]]
        #concatenate
        out = torch.cat([x, crop], 1) 
        return out



    def forward(self, x, ctx):
        print("UpBlock")
        print("x shape {} ! ctx {}".format(x.shape, ctx.shape))

        x = self.up_sampling(x)
        out = self.add_ctx(x, ctx)
        x = self.conv1d(x)
        x = self.lr(x)
        print("out shape {}".format(x.shape))

        return out




In [109]:
class UNet(pl.LightningModule):
    def __init__(self, channels):
        super().__init__()
        self.save_hyperparameters()
        self.channels = channels
        self.mid_conv = nn.Conv1d(in_channels=channels[-2],
                                out_channels=channels[-1],
                                kernel_size=3,
                                stride=1,
                                padding=1)

        
    def down_sampling(self, x):
        l_ctx = []
        for i in range(len(self.channels)-1):
            x, ctx = DBlock(in_channels=self.channels[i], out_channels=self.channels[i+1])(x)
            l_ctx.append(ctx)
        return x, l_ctx
    
    def up_sampling(self, x, l_ctx):
        for i in range(len(self.channels)-1, 0, -1):
            x = UBlock(in_channels=self.channels[i], out_channels=self.channels[i-1])(x, l_ctx[i-1])
        
        return x


    def forward(self, x):
        x, l_ctx = self.down_sampling(x)

        #x = self.mid_conv(x)

        out = self.up_sampling(x, l_ctx)
        return out


In [112]:
t = torch.randn(2, 2, 16)


unet = UNet([2,64,128, 256])
out = unet(t)


DownBlock
x shape torch.Size([2, 2, 16])
out torch.Size([2, 64, 8]) ! ctx torch.Size([2, 64, 16])
DownBlock
x shape torch.Size([2, 64, 8])
out torch.Size([2, 128, 4]) ! ctx torch.Size([2, 128, 8])
DownBlock
x shape torch.Size([2, 128, 4])
out torch.Size([2, 256, 2]) ! ctx torch.Size([2, 256, 4])
UpBlock
x shape torch.Size([2, 256, 2])
Upsampling
x shape torch.Size([2, 256, 4]) ! ctx torch.Size([2, 256, 4])
out shape torch.Size([2, 128, 4])
UpBlock
x shape torch.Size([2, 512, 4])
Upsampling
x shape torch.Size([2, 512, 8]) ! ctx torch.Size([2, 128, 8])


RuntimeError: Given groups=1, weight of size [64, 128, 3], expected input[2, 512, 8] to have 128 channels, but got 512 channels instead