In [1]:
import torch
from torch import nn
import pytorch_lightning as pl

In [17]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        self.lr = nn.LeakyReLU()

    def forward(self, x):
        x = self.conv(x)
        out = self.lr(x)
        return out


class DBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lr = nn.LeakyReLU()
        self.conv1 = ConvBlock(in_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)
        self.mp = nn.MaxPool1d(kernel_size=2)
        pass

    def forward(self, x):
        print("x shape {}".format(x.shape))
        x = self.conv1(x)
        x = self.conv2(x)
        
        ctx = torch.clone(x)
        out = self.mp(x)       
        
        print("out {} ! ctx {}".format(out.shape, ctx.shape))
        return out, ctx

class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)


    def forward(self, x):
        print("Bottleneck")
        print("x shape {}".format(x.shape))
        x = self.conv1(x)
        out = self.conv2(x)
        print("out shape {}".format(out.shape))
        return out


class UBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels)
        self.conv2 = ConvBlock(out_channels, out_channels)
        self.up_conv = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConvTranspose1d(in_channels=in_channels, out_channels=out_channels, stride=1, kernel_size=3, padding=1)
        )
        
        self.lr = nn.LeakyReLU()


    def add_ctx(self, x, ctx):
        # crop context (y)
        d_shape = (ctx.shape[-1]-x.shape[-1])//2 
        crop = ctx[:,:, d_shape:d_shape+x.shape[2]]
        #concatenate
        out = torch.cat([x, crop], 1) 
        return out



    def forward(self, x, ctx):
        print("x shape {} ! ctx {}".format(x.shape, ctx.shape))
        x = self.up_conv(x)
        print("x shape {} ! ctx {}".format(x.shape, ctx.shape))
        x = self.add_ctx(x, ctx)
        print("x shape {}".format(x.shape))
        x = self.conv1(x)
        out = self.conv2(x)
        print("out shape {}".format(out.shape))
        return out




In [22]:
class UNet(pl.LightningModule):
    def __init__(self, channels):
        super().__init__()
        self.save_hyperparameters()
        self.channels = channels
        self.bottleneck = Bottleneck(channels[-2], channels[-1])
        self.end = ConvBlock(in_channels=channels[1], out_channels=channels[0])
        
    def down_sampling(self, x):
        l_ctx = []
        for i in range(len(self.channels)-2):
            print("DownBlock __{} : {} -> {}".format(i, self.channels[i], self.channels[i+1]))
            x, ctx = DBlock(in_channels=self.channels[i], out_channels=self.channels[i+1])(x)
            l_ctx = [ctx] + l_ctx
        return x, l_ctx
    
    def up_sampling(self, x, l_ctx):
        reverse = self.channels[::-1] # reverse without the bottleneck channel
        for i in range(len(reverse)-2):
            print("UpBlock __{} : {} -> {}".format(i, reverse[i], reverse[i+1]))
            x = UBlock(in_channels=reverse[i], out_channels=reverse[i+1])(x, l_ctx[i])
        
        return x


    def forward(self, x):
        x, l_ctx = self.down_sampling(x)
        print("LEN  : ", len(l_ctx))
        x = self.bottleneck(x)
        x = self.up_sampling(x, l_ctx)
        out = self.end(x)

        return out


In [25]:
t = torch.randn(2, 2, 16)


unet = UNet([2, 64, 128, 256, 512, 1024])
out = unet(t)
print(out.shape)


DownBlock __0 : 2 -> 64
x shape torch.Size([2, 2, 16])
out torch.Size([2, 64, 8]) ! ctx torch.Size([2, 64, 16])
DownBlock __1 : 64 -> 128
x shape torch.Size([2, 64, 8])
out torch.Size([2, 128, 4]) ! ctx torch.Size([2, 128, 8])
DownBlock __2 : 128 -> 256
x shape torch.Size([2, 128, 4])
out torch.Size([2, 256, 2]) ! ctx torch.Size([2, 256, 4])
DownBlock __3 : 256 -> 512
x shape torch.Size([2, 256, 2])
out torch.Size([2, 512, 1]) ! ctx torch.Size([2, 512, 2])
LEN  :  4
Bottleneck
x shape torch.Size([2, 512, 1])
out shape torch.Size([2, 1024, 1])
UpBlock __0 : 1024 -> 512
x shape torch.Size([2, 1024, 1]) ! ctx torch.Size([2, 512, 2])
x shape torch.Size([2, 512, 2]) ! ctx torch.Size([2, 512, 2])
x shape torch.Size([2, 1024, 2])
out shape torch.Size([2, 512, 2])
UpBlock __1 : 512 -> 256
x shape torch.Size([2, 512, 2]) ! ctx torch.Size([2, 256, 4])
x shape torch.Size([2, 256, 4]) ! ctx torch.Size([2, 256, 4])
x shape torch.Size([2, 512, 4])
out shape torch.Size([2, 256, 4])
UpBlock __2 : 256 