In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
#
from UNet.blocks import *

## BasicConvBlock

In [None]:
# Hyper-parameters for the block under test. The input tensor and the
# constructor are both built from these, so they cannot drift apart.
c_in = 32
c_out = 64
n_blocks = 2
#
x = torch.rand((10, c_in, 128, 128))
#
# NOTE(review): the two trailing positional booleans are BasicConvBlock
# flags whose meaning is defined in UNet.blocks — confirm there.
model = BasicConvBlock(c_in, c_out, n_blocks, True, True)
y = model(x)

print(x.shape)
print(y.shape)
model

In [None]:
# Same block with c_in == c_out and both boolean flags disabled.
# Run a forward pass so this configuration is actually exercised;
# previously the input tensor was created but never used.
x = torch.rand((10, 32, 128, 128))
model = BasicConvBlock(32, 32, 2, False, False)
y = model(x)

print(x.shape)
print(y.shape)
model

## DownBlock

In [None]:
# Configuration comes first, so the input tensor and the expected output
# shape are derived from it instead of repeated as literals.
bs = 10
img_size = 256
c_in = 64
c_out = 128
#
n_blocks = 2
#
x = torch.rand((bs, c_in, img_size, img_size))
#
model = DownBlock(c_in, c_out, n_blocks)
y = model(x)
#
# Expected: channels c_in -> c_out and each spatial dim halved
# (for these values: (10, 128, 128, 128), as previously asserted).
assert y.shape == (bs, c_out, img_size // 2, img_size // 2), y.shape
print(x.shape)
print(y.shape)
model

## Encoder

In [None]:
# --- input batch: bs samples, n_c channels, img_size x img_size ---
bs = 10
n_c = 3
img_size = 64
x = torch.rand((bs, n_c, img_size, img_size))

# --- encoder channel layout ---
# chs_tail starts at the input channel count n_c; chs_down starts where
# chs_tail ends (4 here).
chs_tail = [n_c, 4]
chs_down = [4, 8, 16, 32, 64, 128, 256]
n_conv_blocks = 1

model = Encoder(chs_tail, chs_down, n_conv_blocks=n_conv_blocks)
model

In [None]:
# Print the shape of every feature map the encoder returns.
# The loop variable is deliberately not `x`, so the input batch is preserved
# and this cell can be re-run without feeding a feature map back in.
xx = model(x)
for feat in xx:
    print(feat.shape)

## UpBlock

In [None]:
# Step-by-step walk through UpBlock: `model.up` brings the deeper map to the
# skip tensor's spatial size (4x4 -> 8x8 here). The two are then concatenated
# along channels and passed through `model.convs`.
x_down = torch.rand((10, 64, 4, 4))
x_side = torch.rand((10, 32, 8, 8))
c_down = 64
c_out = 32
n_blocks = 1

# Alternative configuration (256 -> 128 channels, 1x1 -> 2x2), kept for reference:
# x_down = torch.rand((10, 256, 1, 1))
# x_side = torch.rand((10, 128, 2, 2))
# c_down = 256
# c_out = 128

model = UpBlock(c_down, c_out, n_blocks)
#
x_up = model.up(x_down)
# torch.cat along dim=1 needs matching batch and spatial dims; fail with a readable message.
assert x_up.shape[0] == x_side.shape[0] and x_up.shape[2:] == x_side.shape[2:], (x_up.shape, x_side.shape)

print(x_down.shape)
print(x_up.shape)
print(x_side.shape)
x_cat = torch.cat([x_up, x_side], dim=1)
print(x_cat.shape)
x = model.convs(x_cat)
print(x.shape)

In [None]:
# Single end-to-end UpBlock call on the same deep and skip tensors.
x = model(x_down, x_side)
print(x.size())

## Decoder

In [None]:
# --- input batch: bs samples, n_c channels, img_size x img_size ---
bs = 10
n_c = 3
img_size = 64
x = torch.rand((bs, n_c, img_size, img_size))

# --- channel layout ---
# The decoder walks the encoder's down-path channels in reverse order.
chs_tail = [n_c, 4]
chs_down = [4, 8, 16, 32, 64, 128, 256]
chs_up = list(reversed(chs_down))
# chs_head starts at chs_up's last entry (4) and ends at 1
# (presumably the number of output channels — confirm in Decoder).
chs_head = [4, 1]
n_conv_blocks = 1

encoder = Encoder(chs_tail, chs_down, n_conv_blocks=n_conv_blocks)
decoder = Decoder(chs_head, chs_up, n_conv_blocks=n_conv_blocks)

In [None]:
# Encode the batch and print each returned feature map's shape.
# Use a dedicated loop name so the input batch `x` is not overwritten;
# this keeps the cell re-runnable.
xx = encoder(x)
for feat in xx:
    print(feat.shape)

In [None]:
# The decoder receives the encoder's feature maps in reverse order
# (presumably deepest first — confirm against Decoder.forward).
x = decoder(xx[::-1])
print(x.size())

## UNet

The full model: the encoder's list of feature maps is reversed and passed to the decoder.

In [None]:
class UNet(nn.Module):
    """U-Net assembled from the ``Encoder`` and ``Decoder`` blocks of ``UNet.blocks``.

    The encoder returns a sequence of feature maps. That sequence is reversed
    and passed to the decoder, whose output is returned by ``forward``.

    Args:
        chs_tail: channel list forwarded to ``Encoder`` as its first argument.
        chs_down: channel list forwarded to ``Encoder`` as its second argument.
        chs_up: channel list forwarded to ``Decoder`` as its second argument
            (the notebook builds it as ``chs_down[::-1]``).
        chs_head: channel list forwarded to ``Decoder`` as its first argument.
        n_conv_blocks: forwarded to both ``Encoder`` and ``Decoder``
            (presumably conv blocks per stage — confirm in UNet.blocks).
    """

    def __init__(self, chs_tail, chs_down, chs_up, chs_head, n_conv_blocks):
        super().__init__()
        # Pass n_conv_blocks by keyword, as the standalone Encoder/Decoder
        # cells in this notebook do, so it cannot bind to a different
        # positional parameter of those constructors.
        self.encoder = Encoder(chs_tail, chs_down, n_conv_blocks=n_conv_blocks)
        self.decoder = Decoder(chs_head, chs_up, n_conv_blocks=n_conv_blocks)

    def forward(self, x):
        """Encode ``x``, then decode the encoder's feature list in reverse order."""
        xx = self.encoder(x)
        return self.decoder(xx[::-1])

In [None]:
# --- input batch: bs samples, n_c channels, img_size x img_size ---
bs = 10
n_c = 3
img_size = 64
x = torch.rand((bs, n_c, img_size, img_size))

# --- UNet channel layout (shorter chs_down than the Decoder demo above) ---
chs_tail = [n_c, 4]
chs_down = [4, 8, 16, 32, 64]
chs_up = list(reversed(chs_down))
chs_head = [4, 1]
n_conv_blocks = 1

model = UNet(chs_tail, chs_down, chs_up, chs_head, n_conv_blocks)