In [15]:
import torch
import torch.nn as nn

In [16]:
# Dummy batch for shape-checking the U-Net: 2 single-channel 256x256 images.
X = torch.zeros(2, 1, 256, 256)

In [17]:
# Scratch lookup of the layer classes used by the models in this notebook.
# These are bare expressions: nothing is constructed, and the notebook only
# echoes the value of the last one (nn.ReLU).
nn.Conv2d
nn.MaxPool2d
nn.ConvTranspose2d
nn.ReLU

torch.nn.modules.activation.ReLU

In [57]:
def cprint(key, x):
    """Print a label followed by the shape of a tensor.

    Args:
        key: Label for the line; left-aligned and padded to 15 characters.
        x: Tensor whose shape is reported.

    The shape is read directly from the tensor rather than via a NumPy
    conversion, so this also works for tensors NumPy cannot represent
    (e.g. non-CPU tensors or dtypes such as bfloat16) and never copies data.
    The printed text is the same plain tuple, e.g. ``(2, 1, 256, 256)``.
    """
    print(f"{key:15s} {tuple(x.shape)}")

In [81]:
class Unet(nn.Module):
    """Four-level U-Net with a one-step "time delay" on the decoder output.

    The encoder widens 1 -> 32 -> 64 -> 128 -> 256 channels with a 2x2 max-pool
    between levels; the decoder mirrors it with stride-2 transposed
    convolutions and skip connections (channel-wise concatenation with the
    encoder output of the same level). All 3x3 convolutions use padding 1, so
    spatial size only changes at the pool / up-sampling steps. Height and
    width must therefore be divisible by 8 for the skip connections to line up.

    Before the final 1x1 convolution, the last decoder feature map is
    concatenated with the (detached) feature map of the previous ``forward``
    call, or with zeros when no compatible previous map exists. The module is
    therefore stateful across calls; use ``reset_state()`` to forget the cache.

    ``forward`` prints the shape of every intermediate tensor via ``cprint``.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.

    Note:
        ``final_activation`` (a Sigmoid) is constructed but not applied in
        ``forward``; the returned tensor is the raw ``final_layer`` output.
    """

    def __init__(self, n_classes):
        super().__init__()
        self._build_encoder()
        self._build_decoder()
        # 32*2 input channels: current decoder features + previous-call features.
        self.final_layer = nn.Conv2d(32*2, n_classes, 1, stride=1, padding=0)
        self.activation = nn.ReLU()
        self.final_activation = nn.Sigmoid()
        # Detached decoder features from the previous forward() call, if any.
        self.prev_state = None

    def _build_encoder(self):
        """Create the encoder layers: two 3x3 convs per level, pooled between levels."""
        print('Building encoder')
        self.conv_1a = nn.Conv2d(
            in_channels=1,
            out_channels=32,
            kernel_size=(3,3),
            stride=1,
            padding=1
        )
        self.conv_1b = nn.Conv2d(32, 32, 3, stride=1, padding=1)

        self.pool_2 = nn.MaxPool2d(2)
        self.conv_2a = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv_2b = nn.Conv2d(64, 64, 3, stride=1, padding=1)

        self.pool_3 = nn.MaxPool2d(2)
        self.conv_3a = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_3b = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.pool_4 = nn.MaxPool2d(2)
        self.conv_4a = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_4b = nn.Conv2d(256, 256, 3, stride=1, padding=1)

    def _build_decoder(self):
        """Create the decoder layers: up-sample, then two 3x3 convs per level.

        The first conv of each level takes twice the level's width because its
        input is the up-sampled tensor concatenated with the encoder skip.
        """
        print('Building decoder')
        self.up_3 = nn.ConvTranspose2d(256, 128, 2, stride=2, padding=0)
        self.D_conv_3a = nn.Conv2d(128*2, 128, 3, stride=1, padding=1)
        self.D_conv_3b = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.up_2 = nn.ConvTranspose2d(128, 64, 2, stride=2, padding=0)
        self.D_conv_2a = nn.Conv2d(64*2, 64, 3, stride=1, padding=1)
        self.D_conv_2b = nn.Conv2d(64, 64, 3, stride=1, padding=1)

        self.up_1 = nn.ConvTranspose2d(64, 32, 2, stride=2, padding=0)
        self.D_conv_1a = nn.Conv2d(32*2, 32, 3, stride=1, padding=1)
        self.D_conv_1b = nn.Conv2d(32, 32, 3, stride=1, padding=1)

    def reset_state(self):
        """Forget the cached features so the next call is paired with zeros."""
        self.prev_state = None

    def _conv_pair(self, x, conv_a, conv_b, tag):
        """Apply conv_a -> ReLU -> conv_b -> ReLU, printing both shapes as '<tag>a'/'<tag>b'."""
        out = self.activation(conv_a(x))
        cprint(f"{tag}a", out)
        out = self.activation(conv_b(out))
        cprint(f"{tag}b", out)
        return out

    def _up_and_merge(self, x, up, skip, tag):
        """Up-sample x, concatenate the encoder skip along channels, print the shape."""
        out = torch.cat([up(x), skip], dim=1)
        cprint(tag, out)
        return out

    def _previous_features(self, out):
        """Return the cached features if they can be concatenated with out, else zeros.

        The cache is unusable when it does not exist yet or when its shape,
        device or dtype differs from the current features (e.g. a different
        batch size, or the model was moved to another device between calls).
        """
        cached = self.prev_state
        if (
            cached is None
            or cached.shape != out.shape
            or cached.device != out.device
            or cached.dtype != out.dtype
        ):
            return torch.zeros_like(out)
        return cached

    def forward(self, x):
        """Run the network on x and return the final_layer output.

        Args:
            x: Input tensor with 1 channel in dimension 1 and spatial
                dimensions divisible by 8.

        Returns:
            Tensor with ``n_classes`` channels and the same batch and spatial
            dimensions as x (no final activation applied).

        Side effects:
            Prints intermediate shapes and overwrites ``self.prev_state`` with
            the detached last decoder feature map.
        """
        # Log the actual argument, not a notebook-level global.
        cprint('X', x)

        # Encoder
        print("Encoder")
        E1_out = self._conv_pair(x, self.conv_1a, self.conv_1b, "E1")
        print("\n")

        out = self.pool_2(E1_out)
        cprint("pool_2", out)
        E2_out = self._conv_pair(out, self.conv_2a, self.conv_2b, "E2")
        print("\n")

        out = self.pool_3(E2_out)
        cprint("pool_3", out)
        E3_out = self._conv_pair(out, self.conv_3a, self.conv_3b, "E3")
        print("\n")

        out = self.pool_4(E3_out)
        cprint("pool_4", out)
        out = self._conv_pair(out, self.conv_4a, self.conv_4b, "E4")
        print("\n")

        # Decoder
        print("Decoder")
        out = self._up_and_merge(out, self.up_3, E3_out, "up_3")
        out = self._conv_pair(out, self.D_conv_3a, self.D_conv_3b, "D3")
        print("\n")

        out = self._up_and_merge(out, self.up_2, E2_out, "up_2")
        out = self._conv_pair(out, self.D_conv_2a, self.D_conv_2b, "D2")
        print("\n")

        out = self._up_and_merge(out, self.up_1, E1_out, "up_1")
        out = self._conv_pair(out, self.D_conv_1a, self.D_conv_1b, "D1")

        # Time delay
        # ----------------------------------------
        prev_state = self._previous_features(out)
        # Detach so gradients never flow into an earlier call's graph.
        self.prev_state = out.detach()
        out = torch.cat([out, prev_state], dim=1)
        # ----------------------------------------

        out = self.final_layer(out)
        cprint("out", out)
        return out

In [82]:
# Instantiate the U-Net with three output channels.
unet = Unet(n_classes=3)

Building encoder
Building decoder


In [83]:
# Run one forward pass on the dummy batch; the model prints the shape of every
# intermediate tensor, and the trailing expression displays the output shape.
out = unet(X)
out.shape

X               (2, 1, 256, 256)
Encoder
E1a             (2, 32, 256, 256)
E1b             (2, 32, 256, 256)


pool_2          (2, 32, 128, 128)
E2a             (2, 64, 128, 128)
E2b             (2, 64, 128, 128)


pool_3          (2, 64, 64, 64)
E3a             (2, 128, 64, 64)
E3b             (2, 128, 64, 64)


pool_4          (2, 128, 32, 32)
E4a             (2, 256, 32, 32)
E4b             (2, 256, 32, 32)


Decoder
up_3            (2, 256, 64, 64)
D3a             (2, 128, 64, 64)
D3b             (2, 128, 64, 64)


up_2            (2, 128, 128, 128)
D2a             (2, 64, 128, 128)
D2b             (2, 64, 128, 128)


up_1            (2, 64, 256, 256)
D1a             (2, 32, 256, 256)
D1b             (2, 32, 256, 256)
out             (2, 3, 256, 256)


torch.Size([2, 3, 256, 256])

In [None]:
# Display the shape of whichever dummy input X is currently bound to.
X.shape

In [10]:
# CIFAR10-sized dummy batch: 2 three-channel 32x32 images.
X = torch.zeros(2, 3, 32, 32)

In [11]:
class CNN(nn.Module):
    """Two stacked 3x3 convolutions (3 -> 8 -> 16 channels).

    Both layers use stride 1 and padding 1, so the spatial size of the input
    is preserved. No non-linearity is applied between or after the layers.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Feed x through conv1 and then conv2, returning the 16-channel result."""
        return self.conv2(self.conv1(x))

In [12]:
# Instantiate the two-layer CNN defined in the previous cell.
cnn = CNN()

In [13]:
# Forward pass on the current dummy batch X; note this rebinds `out`.
out = cnn(X)

In [14]:
# Display the output shape of the CNN forward pass.
out.shape

torch.Size([2, 16, 32, 32])