In [90]:

from collections import OrderedDict
import torch
import nni.retiarii.nn.pytorch as nn
from nni import trace
from nni.retiarii.nn.pytorch import Cell
from components import attention, pools, upsamples, convs, transposed_conv_2d

# --- Search-space configuration ---------------------------------------------
nodes = 1      # number of nodes inside every multi-node Cell
filters = 64   # width after the first encoder stage; doubled/halved in place below
depth = 4      # number of encoder stages (and of decoder stages)

# `filters` is mutated while the stages are built, so keep the starting width
# around: the final 1x1 convolution needs it.
base_filters = filters

enConvList = nn.ModuleList()
decConvList = nn.ModuleList()
upList = nn.ModuleList()

# Encoder stage 1: a single-node cell taking the 1-channel input.
# NOTE(review): convs(a, b) presumably returns conv candidates mapping a -> b
# channels -- confirm against `components`.
enConvList.append(
    Cell(
        op_candidates=convs(1, filters),
        num_nodes=1,
        num_ops_per_node=1,
        num_predecessors=1,
        label="encoder 1",
    ))

# Encoder stages 2..depth: each stage is asked for twice the current width,
# split evenly over `nodes` nodes.
for i in range(depth - 1):
    enConvList.append(
        Cell(
            op_candidates=convs(filters, filters * 2 // nodes),
            num_nodes=nodes,
            num_ops_per_node=1,
            num_predecessors=1,
            label=f"encoder {i+2}",
        ))
    filters *= 2

bottleneck = Cell(
    op_candidates=convs(filters, filters * 2 // nodes),
    num_nodes=nodes,
    num_ops_per_node=1,
    num_predecessors=1,
    label="bottleneck",
)

# Decoder: a stride-2 transposed conv halves the channel count, then a cell
# consumes the upsampled tensor concatenated with the matching skip
# connection (hence `filters * 2` input channels).
for i in range(depth):
    upList.append(nn.ConvTranspose2d(filters * 2, filters, kernel_size=2, stride=2))
    decConvList.append(
        Cell(
            op_candidates=convs(filters * 2, filters // nodes),
            num_nodes=nodes,
            num_ops_per_node=1,
            num_predecessors=1,
            label=f"decoder {i+1}",
        ))
    filters //= 2

# The last decoder stage is built while `filters == base_filters`, so that is
# the width the output convolution receives (assuming `nodes` divides it).
# Deriving it here instead of hard-coding 64 keeps the head valid when the
# configuration above is changed.
outconv = nn.Conv2d(in_channels=base_filters, out_channels=1, kernel_size=1)

# Shape trace: push one random single-channel 64x64 tensor through the
# encoder, bottleneck and decoder, printing the tensor shape after every step.
# The tensor is named `x` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the kernel session.
x = torch.randn(1, 1, 64, 64)

skips = []

# One pooling module is enough for every encoder stage; there is no need to
# construct a fresh one on each iteration.
downsample = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

print(f'Input shape: {x.shape}\n')
for enconv in enConvList:
    x = enconv([x])
    skips.append(x)  # kept for the decoder's skip connections
    print(f'Conv {enconv.label} shape: {x.shape}')
    x = downsample(x)
    print(f'Pool {enconv.label} shape: {x.shape}\n')

outout = bottleneck([x])
print(f'Conv {bottleneck.label} shape: {outout.shape}\n')

# The decoder consumes the skip tensors deepest-first.
skips = skips[::-1]
for deconv, ups, skip in zip(decConvList, upList, skips):
    outout = ups(outout)
    print(f'Upsample {deconv.label} shape: {outout.shape}')

    print(f'\t\tupsam shape: {outout.shape}')
    print(f'\t\tskips shape: {skip.shape}')
    # Concatenate along the channel axis before the decoder cell.
    outout = torch.cat((outout, skip), dim=1)
    print(f'\t\tconcat shape: {outout.shape}')
    outout = deconv([outout])
    print(f'Conv {deconv.label} shape: {outout.shape}\n\n')

print('Enter the last conv layer\n')
outout = outconv(outout)
print(f'Out shape: {outout.shape}\n\n')




Input shape: torch.Size([1, 1, 64, 64])

Conv encoder 1 shape: torch.Size([1, 64, 64, 64])
Pool encoder 1 shape: torch.Size([1, 64, 32, 32])

Conv encoder 2 shape: torch.Size([1, 128, 32, 32])
Pool encoder 2 shape: torch.Size([1, 128, 16, 16])

Conv encoder 3 shape: torch.Size([1, 256, 16, 16])
Pool encoder 3 shape: torch.Size([1, 256, 8, 8])

Conv encoder 4 shape: torch.Size([1, 512, 8, 8])
Pool encoder 4 shape: torch.Size([1, 512, 4, 4])

Conv bottleneck shape: torch.Size([1, 1024, 4, 4])

Upsample decoder 1 shape: torch.Size([1, 512, 8, 8])
		upsam shape: torch.Size([1, 512, 8, 8])
		skips shape: torch.Size([1, 512, 8, 8])
		concat shape: torch.Size([1, 1024, 8, 8])
Conv decoder 1 shape: torch.Size([1, 512, 8, 8])


Upsample decoder 2 shape: torch.Size([1, 256, 16, 16])
		upsam shape: torch.Size([1, 256, 16, 16])
		skips shape: torch.Size([1, 256, 16, 16])
		concat shape: torch.Size([1, 512, 16, 16])
Conv decoder 2 shape: torch.Size([1, 256, 16, 16])


Upsample decoder 3 shape: torc



# Add pools: make the pooling and upsampling layers searchable cells

In [97]:

# --- Search-space configuration ---------------------------------------------
nodes = 4      # number of nodes inside every multi-node conv Cell
filters = 64   # width after the first encoder stage; doubled/halved in place below
depth = 4      # number of encoder stages (and of decoder stages)

# `filters` is mutated while the stages are built, so keep the starting width
# around: the final 1x1 convolution needs it.
base_filters = filters

enConvList = nn.ModuleList()
decConvList = nn.ModuleList()
upList = nn.ModuleList()
poolList = nn.ModuleList()

# Stage 1: a single-node conv cell on the 1-channel input plus its pooling cell.
# NOTE(review): convs(a, b) presumably returns conv candidates mapping a -> b
# channels, and pools() channel-preserving 2x downsampling candidates --
# confirm against `components`.
enConvList.append(
    Cell(
        op_candidates=convs(1, filters),
        num_nodes=1,
        num_ops_per_node=1,
        num_predecessors=1,
        label="encoder 1",
    ))
poolList.append(
    Cell(
        op_candidates=pools(),
        num_nodes=1,
        num_ops_per_node=1,
        num_predecessors=1,
        label="pool 1",
    ))

# Stages 2..depth: one pooling cell and one conv cell per stage. Each conv
# cell is asked for twice the current width, split evenly over `nodes` nodes.
# NOTE(review): with num_nodes > 1 a node may be wired to an earlier node of
# the same cell, whose output has filters*2//nodes channels rather than
# `filters` -- confirm the conv candidates cope with that during search.
for i in range(depth - 1):
    poolList.append(
        Cell(
            op_candidates=pools(),
            num_nodes=1,
            num_ops_per_node=1,
            num_predecessors=1,
            label=f"pool {i+2}",
        ))
    enConvList.append(
        Cell(
            op_candidates=convs(filters, filters * 2 // nodes),
            num_nodes=nodes,
            num_ops_per_node=1,
            num_predecessors=1,
            label=f"encoder {i+2}",
        ))
    filters *= 2

bottleneck = Cell(
    op_candidates=convs(filters, filters * 2 // nodes),
    num_nodes=nodes,
    num_ops_per_node=1,
    num_predecessors=1,
    label="bottleneck",
)

# Decoder: a searchable upsampling cell halves the channel count, then a conv
# cell consumes the upsampled tensor concatenated with the matching skip
# connection (hence `filters * 2` input channels).
for i in range(depth):
    upList.append(
        Cell(
            op_candidates=upsamples(filters * 2, filters),
            num_nodes=1,
            num_ops_per_node=1,
            num_predecessors=1,
            label=f"upsample {i+1}",
        ))
    decConvList.append(
        Cell(
            op_candidates=convs(filters * 2, filters // nodes),
            num_nodes=nodes,
            num_ops_per_node=1,
            num_predecessors=1,
            label=f"decoder {i+1}",
        ))
    filters //= 2

# The last decoder stage is built while `filters == base_filters`, so that is
# the width the output convolution receives (assuming `nodes` divides it).
# Deriving it here instead of hard-coding 64 keeps the head valid when the
# configuration above is changed.
outconv = nn.Conv2d(in_channels=base_filters, out_channels=1, kernel_size=1)

# Shape trace: push one random single-channel 64x64 tensor through the
# encoder, bottleneck and decoder, printing the tensor shape after every step.
# The tensor is named `x` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the kernel session.
x = torch.randn(1, 1, 64, 64)

skips = []

print(f'Input shape: {x.shape}\n')
for enconv, pl in zip(enConvList, poolList):
    x = enconv([x])
    skips.append(x)  # kept for the decoder's skip connections
    print(f'Conv {enconv.label} shape: {x.shape}')
    # Pass a one-element list, matching how every other Cell is called here.
    x = pl([x])
    # The pool line is tagged with the encoder's label, not the pool cell's.
    print(f'Pool {enconv.label} shape: {x.shape}\n')

outout = bottleneck([x])
print(f'Conv {bottleneck.label} shape: {outout.shape}\n')

# The decoder consumes the skip tensors deepest-first.
skips = skips[::-1]
for deconv, ups, skip in zip(decConvList, upList, skips):
    outout = ups([outout])
    print(f'Upsample {deconv.label} shape: {outout.shape}')

    print(f'\t\tupsam shape: {outout.shape}')
    print(f'\t\tskips shape: {skip.shape}')
    # Concatenate along the channel axis before the decoder cell.
    outout = torch.cat((outout, skip), dim=1)
    print(f'\t\tconcat shape: {outout.shape}')
    outout = deconv([outout])
    print(f'Conv {deconv.label} shape: {outout.shape}\n\n')

print('Enter the last conv layer\n')
outout = outconv(outout)
print(f'Out shape: {outout.shape}\n\n')

Input shape: torch.Size([1, 1, 64, 64])

Conv encoder 1 shape: torch.Size([1, 64, 64, 64])
Pool encoder 1 shape: torch.Size([1, 64, 32, 32])

Conv encoder 2 shape: torch.Size([1, 128, 32, 32])
Pool encoder 2 shape: torch.Size([1, 128, 16, 16])

Conv encoder 3 shape: torch.Size([1, 256, 16, 16])
Pool encoder 3 shape: torch.Size([1, 256, 8, 8])

Conv encoder 4 shape: torch.Size([1, 512, 8, 8])
Pool encoder 4 shape: torch.Size([1, 512, 4, 4])

Conv bottleneck shape: torch.Size([1, 1024, 4, 4])

Upsample decoder 1 shape: torch.Size([1, 512, 8, 8])
		upsam shape: torch.Size([1, 512, 8, 8])
		skips shape: torch.Size([1, 512, 8, 8])
		concat shape: torch.Size([1, 1024, 8, 8])
Conv decoder 1 shape: torch.Size([1, 512, 8, 8])


Upsample decoder 2 shape: torch.Size([1, 256, 16, 16])
		upsam shape: torch.Size([1, 256, 16, 16])
		skips shape: torch.Size([1, 256, 16, 16])
		concat shape: torch.Size([1, 512, 16, 16])
Conv decoder 2 shape: torch.Size([1, 256, 16, 16])


Upsample decoder 3 shape: torc

