In [1]:
import torch
from functools import partial

import flexconv
from flexconv import FlexConv, SeparableFlexConv, CCNN, CCNNBlock, TCNBlock, PreActResNetBlock, ResNetBlock

In [2]:
# Shape settings; in_channels / out_channels / n_dims are reused by the CCNNBlock cell below.
in_channels = 3
out_channels = 32
n_dims = 2  # dimensionality of the data: 1 for sequences, 2 for images, and so on

# The same layer built in its full and its separable form.
conv = FlexConv(in_channels, out_channels, n_dims)
sep_conv = SeparableFlexConv(in_channels, out_channels, n_dims)

# Only the full FlexConv is rendered below.
conv

FlexConv(
  (Kernel): MAGNet(
    (linears): ModuleList(
      (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
    )
    (output_linear): Conv2d(32, 96, kernel_size=(1, 1), stride=(1, 1))
    (filters): ModuleList(
      (0-3): 4 x MAGNetLayer(
        (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
      )
    )
  )
)

In [3]:
# End-to-end CCNN on 2D data: 3 input channels, 380 hidden channels, 6 blocks.
# out_channels=10 is presumably the number of target classes -- confirm for your task.
model = CCNN(
    in_channels=3,
    out_channels=10,
    no_hidden=380,
    no_blocks=6,
    data_dim=2,
)
model

CCNN(
  (dropout_in): Dropout(p=0, inplace=False)
  (conv1): SeparableFlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (channel_mixer): Conv2d(3, 380, kernel_size=(1, 1), stride=(1, 1))
  )
  (norm1): BatchNorm2d(380, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (nonlinear): GELU(approximate='none')
  (blocks): Sequential(
    (0): S4Block(
      (conv1): SeparableFlexConv(
        (Kernel): MAGNet(
          (linears): ModuleList(
            (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (output_linear): Conv2d(32, 380, kernel_size=(1, 1), stride=(1, 1))
          (filters): ModuleList(
            (0-3): 4 x MAGNe

In [4]:
# A single CCNNBlock built from the shape settings of the FlexConv cell above.
# Its printout shows a separable FlexConv, a norm, a 1x1 linear, dropout and a
# 1x1 shortcut projecting in_channels -> out_channels.
ccnn_block = CCNNBlock(
    in_channels,
    out_channels,
    n_dims,
)
ccnn_block

CCNNBlock(
  (conv1): SeparableFlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (channel_mixer): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
  )
  (nonlinears): ModuleList(
    (0-1): 2 x GELU(approximate='none')
  )
  (norm1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
  (dp): Dropout2d(p=0.15, inplace=False)
  (shortcut): Sequential(
    (0): Conv2d(3, 32, kernel_size=(1, 1), stride=(1, 1))
  )
)

The other blocks (TCNBlock, ResNetBlock, PreActResNetBlock) are also available, but they do not yet have a nice wrapper interface.
You can still build them by passing the layer types explicitly, as follows:

In [5]:
# Layer factories injected into the block. TCNBlock, ResNetBlock and
# PreActResNetBlock all take them positionally, in this order.
conv = partial(FlexConv, data_dim=2)  # bind data_dim up front: the block expects it to be already set
nonlinearity = torch.nn.GELU
# NOTE(review): the three layer types below are the 2D variants, matching data_dim=2;
# presumably they should be swapped together for other data dimensionalities.
norm = torch.nn.BatchNorm2d
linearLayerType = flexconv.linear.Linear2d
dropoutType = torch.nn.Dropout2d
dropout = 0.1  # rate handed to the block alongside dropoutType

# 3 input channels -> 16 output channels.
tcn_block = TCNBlock(
    3, 16,
    conv, nonlinearity, norm, linearLayerType, dropoutType, dropout,
)
tcn_block

shortcut used


TCNBlock(
  (conv1): FlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (conv2): FlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (nonlinearities): ModuleList(
    (0-2): 3 x GELU(approximate='none')
  )
  (norm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=Tru

In [6]:
# Same layer factories as for the TCN block; redefined here so the cell runs on its own.
conv = partial(FlexConv, data_dim=2)  # bind data_dim up front: the block expects it to be already set
nonlinearity = torch.nn.GELU
# NOTE(review): 2D variants to match data_dim=2; presumably change all three together.
norm = torch.nn.BatchNorm2d
linearLayerType = flexconv.linear.Linear2d
dropoutType = torch.nn.Dropout2d
dropout = 0.1  # rate handed to the block alongside dropoutType

# TCNBlock, ResNetBlock and PreActResNetBlock share this constructor signature.
resnet_block = ResNetBlock(
    3, 16,
    conv, nonlinearity, norm, linearLayerType, dropoutType, dropout,
)
resnet_block

shortcut used


ResNetBlock(
  (conv1): FlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (conv2): FlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (nonlinearities): ModuleList(
    (0-2): 3 x GELU(approximate='none')
  )
  (norm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=

In [7]:
# Same layer factories again, so this cell also runs on its own.
conv = partial(FlexConv, data_dim=2)  # bind data_dim up front: the block expects it to be already set
nonlinearity = torch.nn.GELU
# NOTE(review): 2D variants to match data_dim=2; presumably change all three together.
norm = torch.nn.BatchNorm2d
linearLayerType = flexconv.linear.Linear2d
dropoutType = torch.nn.Dropout2d
dropout = 0.1  # rate handed to the block alongside dropoutType

# Pre-activation variant, built with the same positional arguments as ResNetBlock.
pre_act_resnet_block = PreActResNetBlock(
    3, 16,
    conv, nonlinearity, norm, linearLayerType, dropoutType, dropout,
)
pre_act_resnet_block

shortcut used


PreActResNetBlock(
  (conv1): FlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 48, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (conv2): FlexConv(
    (Kernel): MAGNet(
      (linears): ModuleList(
        (0-2): 3 x Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      )
      (output_linear): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1))
      (filters): ModuleList(
        (0-3): 4 x MAGNetLayer(
          (linear): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (nonlinearities): ModuleList(
    (0-2): 3 x GELU(approximate='none')
  )
  (norm1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm2): BatchNorm2d(16, eps=1e-05, momentum=0.1, a