# imports

In [8]:
from collections import OrderedDict

import numpy as np
import nni
import torch

import nni.retiarii.nn.pytorch as nn
import pytorch_lightning as pl

from nni import trace
from nni.retiarii import model_wrapper, fixed_arch
from nni.retiarii.nn.pytorch import Cell
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategy import DARTS as DartsStrategy
from nni.retiarii.evaluator.pytorch import Lightning, LightningModule, Trainer
from nni.retiarii.evaluator.pytorch.lightning import DataLoader
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.utilities.types import STEP_OUTPUT

from torch import optim, tensor, zeros_like
from typing import Any

from torch.utils.data import Dataset

Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  __import__("pkg_resources").declare_namespace(__name__)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(parent)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  __import__("pkg_resources").declare_namespace(__name__)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(parent)

# search space

In [29]:
def conv_2d(C_in, C_out, kernel_size=3, dilation=1, padding=1, activation=None):
    """Build a ``Conv2d -> BatchNorm2d -> activation`` stack.

    The convolution is created without a bias term. ``padding`` is used
    exactly as given; it is not derived from ``kernel_size``/``dilation``,
    so a caller that changes the kernel must also pass a matching padding
    if the spatial size should be preserved.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        kernel_size: forwarded to ``nn.Conv2d``.
        dilation: forwarded to ``nn.Conv2d``.
        padding: forwarded to ``nn.Conv2d``.
        activation: module placed last; a fresh ``nn.ReLU()`` when ``None``.

    Returns:
        ``nn.Sequential`` holding the three layers in order.
    """
    layers = [
        nn.Conv2d(C_in, C_out, kernel_size=kernel_size, dilation=dilation,
                  padding=padding, bias=False),
        nn.BatchNorm2d(C_out),
    ]
    # Built after the conv/BN, matching the original construction order.
    layers.append(activation if activation is not None else nn.ReLU())
    return nn.Sequential(*layers)

def depthwise_separable_conv(C_in, C_out, kernel_size=3, dilation=1, padding=1, activation=None):
    """Build a depthwise-separable conv: depthwise -> pointwise -> BN -> activation.

    The depthwise stage uses ``groups=C_in`` (one filter per input channel)
    and keeps ``C_in`` channels; a bias-free 1x1 pointwise conv then maps
    to ``C_out``. There is no normalization/activation between the two
    convolutions. ``padding`` is used as given (not derived from the kernel).

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        kernel_size: depthwise kernel size.
        dilation: depthwise dilation.
        padding: depthwise padding.
        activation: module placed last; a fresh ``nn.ReLU()`` when ``None``.

    Returns:
        ``nn.Sequential`` of the four stages.
    """
    depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, dilation=dilation,
                          padding=padding, groups=C_in, bias=False)
    pointwise = nn.Conv2d(C_in, C_out, 1, bias=False)
    norm = nn.BatchNorm2d(C_out)
    act = activation if activation is not None else nn.ReLU()
    return nn.Sequential(depthwise, pointwise, norm, act)

def pools():
    """Return the candidate 2x downsampling ops, keyed by name.

    Only ``MaxPool2d`` (kernel 2, stride 2) is currently enabled; the
    commented entries are disabled alternatives kept for experimentation.

    Returns:
        ``OrderedDict`` mapping candidate name -> module.
    """
    candidates = OrderedDict()
    candidates["MaxPool2d"] = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
    # candidates["AvgPool2d"] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
    # candidates["DepthToSpace"] = nn.PixelShuffle(2)
    return candidates

def upsamples():
    """Return the candidate 2x upsampling ops, keyed by name.

    Returns:
        ``OrderedDict`` with ``Upsample_nearest`` then ``Upsample_bilinear``,
        both ``nn.Upsample`` with ``scale_factor=2``.
    """
    candidates = OrderedDict()
    for name, mode in (("Upsample_nearest", "nearest"), ("Upsample_bilinear", "bilinear")):
        candidates[name] = nn.Upsample(scale_factor=2, mode=mode)
    return candidates

def convs(C_in, C_out):
    """Return the candidate convolution ops for one searchable layer.

    Keys follow ``<family>_<kernel>_<activation>`` and are handed to ``Cell``
    together with the modules. NOTE(review): they presumably become the
    candidate names in exported architectures, so keep them stable.
    Every candidate maps ``C_in`` -> ``C_out`` channels at stride 1 with
    "same" padding, so all candidates yield the same spatial size and are
    interchangeable.

    Disabled candidates are kept commented out so they can be re-enabled.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.

    Returns:
        ``OrderedDict`` mapping candidate name -> freshly built module.
    """
    # "Same" padding at stride 1: padding = (kernel_size - 1) * dilation // 2
    conv_dict = OrderedDict([
        # conv_2d / depthwise_separable_conv default to kernel_size=3,
        # padding=1, so 1x1 candidates must override both explicitly;
        # otherwise they silently build 3x3 convs under a "1x1" name.
        ("conv2d_1x1_Relu", conv_2d(C_in, C_out, kernel_size=1, padding=0)),
        # ("conv2d_1x1_SiLU", conv_2d(C_in, C_out, kernel_size=1, padding=0, activation=nn.SiLU())),

        # ("conv2d_3x3_Relu", conv_2d(C_in, C_out, kernel_size=3, padding=1)),
        # ("conv2d_3x3_SiLU", conv_2d(C_in, C_out, kernel_size=3, padding=1, activation=nn.SiLU())),

        # ("conv2d_5x5_Relu", conv_2d(C_in, C_out, kernel_size=5, padding=2)),
        # ("conv2d_5x5_SiLU", conv_2d(C_in, C_out, kernel_size=5, padding=2, activation=nn.SiLU())),

        # NOTE(review): with a 1x1 kernel the depthwise stage (groups=C_in)
        # only rescales each channel; enable convDS_3x3 for a spatial DS conv.
        ("convDS_1x1_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=1, padding=0)),
        # ("convDS_1x1_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=1, padding=0, activation=nn.SiLU())),

        # ("convDS_3x3_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=3, padding=1)),
        # ("convDS_3x3_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=3, padding=1, activation=nn.SiLU())),

        # ("convDS_5x5_Relu", depthwise_separable_conv(C_in, C_out, kernel_size=5, padding=2)),
        # ("convDS_5x5_SiLU", depthwise_separable_conv(C_in, C_out, kernel_size=5, padding=2, activation=nn.SiLU())),
    ])
    return conv_dict

@model_wrapper
class Autoencoder(nn.Module):
    """Searchable convolutional autoencoder built from single-op NNI ``Cell``s.

    Encoder: ``in_channels -> max_mid_channels``, then the channel count is
    floor-halved at each of the remaining ``depth - 1`` layers. Decoder:
    mirrors the encoder, doubling back up to ``max_mid_channels``, then a
    final layer maps to ``out_channels``. Each layer's op is chosen from
    ``convs(c_in, c_out)``. No pooling/upsampling is applied, so the spatial
    size is preserved as long as every ``convs`` candidate uses "same" padding.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        depth: number of encoder layers (== number of decoder layers); >= 1.
        max_mid_channels: integer width after the first encoder layer; must
            stay >= 1 after being halved ``depth - 1`` times.

    Raises:
        ValueError: if ``depth`` < 1 or ``max_mid_channels`` is too small for
            the requested ``depth``.
    """

    def __init__(self, in_channels, out_channels, depth=4, max_mid_channels=128):
        super().__init__()
        if depth < 1:
            raise ValueError(f'depth must be >= 1, got {depth}')
        if (max_mid_channels >> (depth - 1)) < 1:
            raise ValueError(
                f'max_mid_channels={max_mid_channels} cannot be halved '
                f'{depth - 1} times without reaching 0 channels'
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depth = depth
        self.max_mid_channels = max_mid_channels

        def make_cell(c_in, c_out, label):
            # One searchable layer: a single node that picks exactly one op
            # from `convs`, fed only by the preceding layer's output.
            return Cell(convs(c_in, c_out), num_nodes=1, num_ops_per_node=1,
                        num_predecessors=1, label=label)

        # Encoder: in -> M, then (M >> i) -> (M >> (i + 1)).
        self.encoder = nn.ModuleList()
        self.encoder.append(make_cell(in_channels, max_mid_channels, 'encoder conv 1'))
        for i in range(depth - 1):
            self.encoder.append(make_cell(max_mid_channels >> i,
                                          max_mid_channels >> (i + 1),
                                          f'encoder conv {i+2}'))

        # Decoder: mirror of the encoder (narrowest first), then M -> out.
        # Labels count down to 'decoder conv 1' and the final projection is
        # 'decoder conv {depth}'. NOTE(review): labels presumably key the
        # exported architecture (e.g. for fixed_arch), so they are unchanged.
        self.decoder = nn.ModuleList()
        for i in range(depth - 2, -1, -1):
            self.decoder.append(make_cell(max_mid_channels >> (i + 1),
                                          max_mid_channels >> i,
                                          f'decoder conv {i+1}'))
        self.decoder.append(make_cell(max_mid_channels, out_channels, f'decoder conv {depth}'))

    def forward(self, x):
        """Run ``x`` through every encoder cell, then every decoder cell."""
        for cell in self.encoder:
            x = cell(x)
        for cell in self.decoder:
            x = cell(x)
        return x

    def test(self, batch_size=1, height=128, width=128):
        """Smoke-test the model on a random input and print the output shape.

        Runs in eval mode under ``torch.no_grad()`` so the check neither
        updates BatchNorm running statistics nor builds an autograd graph;
        the previous train/eval mode is restored afterwards.

        Args:
            batch_size: batch dimension of the random input.
            height: spatial height of the random input.
            width: spatial width of the random input.

        Returns:
            The output tensor of the forward pass.
        """
        x = torch.randn(batch_size, self.in_channels, height, width)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                y = self.forward(x)
        finally:
            self.train(was_training)
        print(f'output: {y.shape}')
        return y

# Build the search-space model and smoke-test one forward pass.
model = Autoencoder(
    in_channels=784,
    out_channels=4,
    depth=4,
)
model.test()





input shape: torch.Size([1, 784, 128, 128])
encoder layer 1 shape: torch.Size([1, 128, 128, 128])
encoder layer 2 shape: torch.Size([1, 64, 128, 128])
encoder layer 3 shape: torch.Size([1, 32, 128, 128])
encoder layer 4 shape: torch.Size([1, 16, 128, 128])
decoder layer 5 shape: torch.Size([1, 32, 128, 128])
decoder layer 6 shape: torch.Size([1, 64, 128, 128])
decoder layer 7 shape: torch.Size([1, 128, 128, 128])
decoder layer 8 shape: torch.Size([1, 4, 128, 128])
output: torch.Size([1, 4, 128, 128])


# eval