In [11]:
import argparse
import copy
import sys

sys.path.append('../../')
import sopa.src.models.odenet_cifar10.layers as cifar10_models
from sopa.src.models.odenet_cifar10.utils import *

In [12]:
parser = argparse.ArgumentParser()
# Architecture params
parser.add_argument('--is_odenet', type=eval, default=True, choices=[True, False])
parser.add_argument('--network', type=str, choices=['metanode34', 'metanode18', 'metanode10', 'metanode6', 'metanode4',
                                                    'premetanode34', 'premetanode18', 'premetanode10', 'premetanode6',
                                                    'premetanode4'],
                    default='premetanode10')
parser.add_argument('--in_planes', type=int, default=64)

# Type of layer's output normalization
parser.add_argument('--normalization_resblock', type=str, default='NF',
                    choices=['BN', 'GN', 'LN', 'IN', 'NF'])
parser.add_argument('--normalization_odeblock', type=str, default='NF',
                    choices=['BN', 'GN', 'LN', 'IN', 'NF'])
parser.add_argument('--normalization_bn1', type=str, default='NF',
                    choices=['BN', 'GN', 'LN', 'IN', 'NF'])
parser.add_argument('--num_gn_groups', type=int, default=32, help='Number of groups for GN normalization')

# Type of layer's weights  normalization
parser.add_argument('--param_normalization_resblock', type=str, default='PNF',
                    choices=['WN', 'SN', 'PNF'])
parser.add_argument('--param_normalization_odeblock', type=str, default='PNF',
                    choices=['WN', 'SN', 'PNF'])
parser.add_argument('--param_normalization_bn1', type=str, default='PNF',
                    choices=['WN', 'SN', 'PNF'])
# Type of activation
parser.add_argument('--activation_resblock', type=str, default='ReLU',
                    choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])
parser.add_argument('--activation_odeblock', type=str, default='ReLU',
                    choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])
parser.add_argument('--activation_bn1', type=str, default='ReLU',
                    choices=['ReLU', 'GeLU', 'Softsign', 'Tanh', 'AF'])

args, unknown_args = parser.parse_known_args()

In [13]:
# Initialize Neural ODE model.
# Work on a deep copy so that later edits to `config` never touch the parsed `args`.
config = copy.deepcopy(args)

# Each per-layer option exists in three flavours; the model factory receives
# them as 3-tuples in this fixed order.
layer_groups = ('resblock', 'odeblock', 'bn1')

norm_layers = tuple(get_normalization(getattr(config, 'normalization_' + group))
                    for group in layer_groups)
param_norm_layers = tuple(get_param_normalization(getattr(config, 'param_normalization_' + group))
                          for group in layer_groups)
act_layers = tuple(get_activation(getattr(config, 'activation_' + group))
                   for group in layer_groups)

# The architecture constructor (e.g. premetanode10) is looked up by name
# in the layers module and called with the three layer-spec tuples.
model = getattr(cifar10_models, config.network)(
    norm_layers, param_norm_layers, act_layers,
    config.in_planes, is_odenet=config.is_odenet,
)

In [14]:
# Display the constructed network as this cell's output; its repr lists the nested module tree.
model

MetaNODE(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): Identity()
  (layer1): MetaLayer(
    (blocks_res): Sequential(
      (0): PreBasicBlock(
        (bn1): Identity()
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (shortcut): Sequential()
      )
    )
    (blocks_ode): ModuleList(
      (0): MetaODEBlock(
        (rhs_func): PreBasicBlock2(
          (bn1): Identity()
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): Identity()
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (shortcut): Sequential()
        )
      )
    )
  )
  (layer2): MetaLayer(
    (blocks_res): Sequential(
      (0): PreBasicBlock(
        (bn1): Identity()
     