In [1]:
import torch
from torch import nn
import nnj

In [31]:
def downblock_gen(in_chan, im_height, im_width):
    """Build the layer list for one encoder (down-sampling) stage.

    The list is meant to be unpacked into a container such as
    ``nnj.Sequential``. Layers, in order:

    1. ``nnj.Reshape(in_chan, im_height, im_width)``
    2. ``nnj.MaxPool2d(2)``
    3. 3x3 convolution ``in_chan -> 2 * in_chan`` (stride 1, padding 1), tanh
    4. 3x3 convolution ``2 * in_chan -> 2 * in_chan`` (stride 1, padding 1), tanh
    5. ``nnj.Flatten()``

    Args:
        in_chan: Number of channels entering the stage.
        im_height: Height passed to the leading ``Reshape``.
        im_width: Width passed to the leading ``Reshape``.

    Returns:
        list: Freshly constructed ``nnj`` layers in the order above.

    Raises:
        ValueError: If any argument is not a positive whole number.
    """
    # Callers may compute sizes with true division (e.g. ``64 / 2**i``), which
    # yields floats such as 32.0. Accept those, but hand real ints to Reshape
    # and reject fractional or non-positive sizes up front.
    dims = (in_chan, im_height, im_width)
    if any(d != int(d) or d < 1 for d in dims):
        raise ValueError(
            f"in_chan, im_height and im_width must be positive whole numbers, got {dims}"
        )
    in_chan, im_height, im_width = (int(d) for d in dims)

    out_chan = in_chan * 2
    downblock = [
        nnj.Reshape(in_chan, im_height, im_width),
        nnj.MaxPool2d(2),
        nnj.Conv2d(in_chan, out_chan, 3, stride=1, padding=1),
        nnj.Tanh(),
        nnj.Conv2d(out_chan, out_chan, 3, stride=1, padding=1),
        nnj.Tanh(),
        nnj.Flatten(),
    ]
    return downblock


def upblock_gen(in_chan, im_height, im_width):
    """Build the layer list for one decoder (up-sampling) stage.

    The list is meant to be unpacked into a container such as
    ``nnj.Sequential``. Layers, in order:

    1. ``nnj.Reshape(in_chan, im_height, im_width)``
    2. 3x3 convolution ``in_chan -> in_chan // 2`` (stride 1, padding 1), tanh
    3. 3x3 convolution ``in_chan // 2 -> in_chan // 4`` (stride 1, padding 1)
    4. ``nnj.Upsample(scale_factor=2)``, tanh
    5. ``nnj.Flatten()``

    Args:
        in_chan: Number of channels entering the stage.
        im_height: Height passed to the leading ``Reshape``.
        im_width: Width passed to the leading ``Reshape``.

    Returns:
        list: Freshly constructed ``nnj`` layers in the order above.

    Raises:
        ValueError: If any argument is not a positive whole number.
    """
    # Callers may compute sizes with true division (e.g. ``64 / 2**i``), which
    # yields floats such as 32.0. Accept those, but hand real ints to Reshape
    # and reject fractional or non-positive sizes up front.
    dims = (in_chan, im_height, im_width)
    if any(d != int(d) or d < 1 for d in dims):
        raise ValueError(
            f"in_chan, im_height and im_width must be positive whole numbers, got {dims}"
        )
    in_chan, im_height, im_width = (int(d) for d in dims)

    # Floor division keeps the channel counts in integer arithmetic instead of
    # round-tripping through float as ``int(in_chan / 2)`` did.
    half_chan = in_chan // 2
    quarter_chan = in_chan // 4
    upblock = [
        nnj.Reshape(in_chan, im_height, im_width),
        nnj.Conv2d(in_chan, half_chan, 3, stride=1, padding=1),
        nnj.Tanh(),
        nnj.Conv2d(half_chan, quarter_chan, 3, stride=1, padding=1),
        nnj.Upsample(scale_factor=2),
        nnj.Tanh(),
        nnj.Flatten(),
    ]
    return upblock


def middleblock_gen(in_chan, im_height, im_width):
    """Build the layer list for the bottleneck (innermost) stage.

    The list is meant to be unpacked into a container such as
    ``nnj.SkipConnection``. Layers, in order:

    1. ``nnj.Reshape(in_chan, im_height, im_width)``
    2. ``nnj.MaxPool2d(2)``
    3. 3x3 convolution ``in_chan -> 2 * in_chan`` (stride 1, padding 1), tanh
    4. 3x3 convolution ``2 * in_chan -> in_chan`` (stride 1, padding 1)
    5. ``nnj.Upsample(scale_factor=2)``, tanh
    6. ``nnj.Flatten()``

    Args:
        in_chan: Number of channels entering (and leaving) the stage.
        im_height: Height passed to the leading ``Reshape``.
        im_width: Width passed to the leading ``Reshape``.

    Returns:
        list: Freshly constructed ``nnj`` layers in the order above.

    Raises:
        ValueError: If any argument is not a positive whole number.
    """
    # Callers may compute sizes with true division (e.g. ``64 / 2**i``), which
    # yields floats such as 4.0. Accept those, but hand real ints to Reshape
    # and reject fractional or non-positive sizes up front.
    dims = (in_chan, im_height, im_width)
    if any(d != int(d) or d < 1 for d in dims):
        raise ValueError(
            f"in_chan, im_height and im_width must be positive whole numbers, got {dims}"
        )
    in_chan, im_height, im_width = (int(d) for d in dims)

    wide_chan = in_chan * 2
    middleblock = [
        nnj.Reshape(in_chan, im_height, im_width),
        nnj.MaxPool2d(2),
        nnj.Conv2d(in_chan, wide_chan, 3, stride=1, padding=1),
        nnj.Tanh(),
        nnj.Conv2d(wide_chan, in_chan, 3, stride=1, padding=1),
        nnj.Upsample(scale_factor=2),
        nnj.Tanh(),
        nnj.Flatten(),
    ]
    return middleblock

In [34]:
# Spatial size handed to the first encoder stage; deeper stages receive this
# size divided by a power of two.
im_height, im_width = 64, 64
# Channel count fed to the first encoder stage; stage i uses 2**i times this.
multiplication_factor = 8

In [38]:
# Shape bookkeeping (per sample, before flattening), writing C for
# multiplication_factor, H for im_height and W for im_width:
#   * downblocks[i] reshapes to (2**i * C, H / 2**i, W / 2**i), pools by 2 and
#     doubles the channels, so it emits (2**(i+1) * C, H / 2**(i+1), W / 2**(i+1)).
#   * The middle block sits directly after downblocks[3], so it must reshape to
#     (2**4 * C, H / 2**4, W / 2**4); it returns the same shape.
#   * NOTE(review): the decoder sizes below assume nnj.SkipConnection returns
#     its input concatenated with the output of its wrapped layers along dim 1
#     (confirm against the installed nnj). Under that assumption upblocks[i]
#     receives 2 * 2**(i+1) * C channels at (H / 2**(i+1), W / 2**(i+1)) and,
#     after quartering the channels and upsampling by 2, emits
#     (2**i * C, H / 2**i, W / 2**i), mirroring the input of downblocks[i].
# Floor division keeps every size an int; true division would hand floats to
# nnj.Reshape.
downblocks = [
    downblock_gen(2**i * multiplication_factor, im_height // 2**i, im_width // 2**i)
    for i in range(4)
]
upblocks = [
    upblock_gen(
        2 ** (i + 2) * multiplication_factor,
        im_height // 2 ** (i + 1),
        im_width // 2 ** (i + 1),
    )
    for i in range(4)
]
net = nnj.Sequential(
    *downblocks[0],
    nnj.SkipConnection(
        *downblocks[1],
        nnj.SkipConnection(
            *downblocks[2],
            nnj.SkipConnection(
                *downblocks[3],
                nnj.SkipConnection(
                    *middleblock_gen(
                        2**4 * multiplication_factor,
                        im_height // 2**4,
                        im_width // 2**4,
                    ),
                    add_hooks=True,
                ),
                *upblocks[3],
                add_hooks=True,
            ),
            *upblocks[2],
            add_hooks=True,
        ),
        *upblocks[1],
        add_hooks=True,
    ),
    *upblocks[0],
)

In [40]:
# Print the assembled network, then the raw layer lists of the outermost
# decoder and encoder stages, one per line.
print(net, upblocks[0], downblocks[0], sep="\n")

Sequential(
  (0): Reshape()
  (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): Tanh()
  (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): Tanh()
  (6): Flatten()
  (7): SkipConnection(
    (_F): Sequential(
      (0): Reshape()
      (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): Tanh()
      (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): Tanh()
      (6): Flatten()
      (7): SkipConnection(
        (_F): Sequential(
          (0): Reshape()
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): Tanh()
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), 