In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from scale_cnn.convolution import ScaleConvolution
from scale_cnn.pooling import ScalePool

In [2]:
class SiCNN(nn.Module):
    """Scale-equivariant CNN: 6 ScaleConvolution layers, scale pooling, linear head.

    Layer 1 uses ``srange=0`` and ``stride=2``; layers 2-6 use ``srange=srange``.
    Every convolution uses a Dirichlet boundary condition and is followed by ReLU.

    Args:
        f_in: number of input feature channels.
        size: kernel size passed to every ScaleConvolution.
        ratio: scale ratio passed to every ScaleConvolution and to ScalePool.
        nratio: number of scales; ``forward`` replicates the input ``nratio``
            times along a new scale axis.
        srange: scale range of layers 2-6 (layer 1 always uses 0).
        padding: padding passed to every ScaleConvolution.
        nb_classes: output size of the final linear layer.
        factor: width multiplier; conv widths are ``int(factor * w)`` for each
            ``w`` in ``BASE_WIDTHS`` (truncated, not rounded).

    Raises:
        ValueError: if ``factor`` is so small that some layer gets < 1 channel.
    """

    # Output widths of conv1..conv6 before scaling by ``factor``.
    BASE_WIDTHS = (12, 21, 36, 36, 64, 64)

    def __init__(self, f_in=1, size=5, ratio=2**(2/3), nratio=3, srange=2, padding=0, nb_classes=10, factor=1):
        super().__init__()
        self.f_in = f_in
        self.size = size
        self.ratio = ratio
        self.nratio = nratio
        self.srange = srange
        self.padding = padding
        self.nb_classes = nb_classes
        self.factor = factor

        c1, c2, c3, c4, c5, c6 = self._channel_widths(factor)

        def scale_conv(n_in, n_out, s_range, **extra):
            # Arguments shared by every layer; only channels, srange and stride vary.
            return ScaleConvolution(n_in, n_out, self.size, self.ratio, self.nratio,
                                    srange=s_range, boundary_condition="dirichlet",
                                    padding=self.padding, **extra)

        # Attribute names and registration order are kept identical so that
        # parameters() order and state_dict keys match earlier checkpoints.
        self.conv1 = scale_conv(self.f_in, c1, 0, stride=2)
        self.conv2 = scale_conv(c1, c2, self.srange)
        self.conv3 = scale_conv(c2, c3, self.srange)
        self.conv4 = scale_conv(c3, c4, self.srange)
        self.conv5 = scale_conv(c4, c5, self.srange)
        self.conv6 = scale_conv(c5, c6, self.srange)
        self.pool = ScalePool(self.ratio)

        self.fc = nn.Linear(c6, self.nb_classes, bias=True)

    @classmethod
    def _channel_widths(cls, factor):
        """Return the output channel counts of conv1..conv6 for width multiplier ``factor``."""
        widths = tuple(int(factor * w) for w in cls.BASE_WIDTHS)
        if min(widths) < 1:
            raise ValueError(f"factor={factor!r} yields a layer with fewer than 1 channel: {widths}")
        return widths

    def forward(self, x):
        """Compute class logits.

        Args:
            x: input tensor; presumably [batch, feature, y, x] with ``f_in``
               features (it must be 4-D for the 5-D ``repeat`` below) — TODO confirm.

        Returns:
            Output of the linear head, ``nb_classes`` values per sample
            (assuming ScalePool returns [batch, feature], as annotated originally).
        """
        # Add a scale axis and replicate the input across all nratio scales:
        # [batch, feature, y, x] -> [batch, 1, feature, y, x] -> [batch, nratio, feature, y, x]
        x = x.unsqueeze(1).repeat(1, self.nratio, 1, 1, 1)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6):
            x = F.relu(conv(x))
        x = self.pool(x)  # [batch, feature]
        return self.fc(x)

In [3]:
# Total trainable-parameter count for nratio=6, srange=4, factor=1.
sum(map(torch.Tensor.numel, SiCNN(f_in=3, size=5, ratio=2 ** (1 / 3), nratio=6, srange=4, factor=1).parameters()))

1960183

In [17]:
# Total parameter count for nratio=6, srange=0, factor=3.
# NOTE(review): the saved output below may come from an older SiCNN definition
# (the repr saved in the next cell does not match the current class) — re-run to confirm.
sum(map(torch.Tensor.numel, SiCNN(f_in=3, size=5, ratio=2 ** (1 / 3), nratio=6, srange=0, factor=3).parameters()))

1963729

In [18]:
# Display the layer layout for the factor=3, srange=0 configuration.
# NOTE(review): the saved repr below is stale relative to the current SiCNN source,
# which gives conv1 3 → int(3*12)=36 channels and fc in_features=int(3*64)=192.
# Restart & Run All to refresh.
SiCNN(
    factor=3,
    srange=0,
    nratio=6,
    ratio=2 ** (1 / 3),
    size=5,
    f_in=3,
)

SiCNN(
  (conv1): ScaleConvolution (size=5, 3 → 108, n=6±0, dirichlet)
  (conv2): ScaleConvolution (size=5, 108 → 192, n=6±0, dirichlet)
  (conv3): ScaleConvolution (size=5, 192 → 288, n=6±0, dirichlet)
  (conv4): ScaleConvolution (size=5, 288 → 288, n=6±0, dirichlet)
  (conv5): ScaleConvolution (size=5, 288 → 576, n=6±0, dirichlet)
  (conv6): ScaleConvolution (size=5, 576 → 576, n=6±0, dirichlet)
  (pool): ScalePool
  (fc): Linear(in_features=576, out_features=10, bias=True)
)

In [19]:
# Total parameter count for ratio=2^(2/3), nratio=3, srange=2, factor=1.
sum(map(torch.Tensor.numel, SiCNN(f_in=3, size=5, ratio=2 ** (2 / 3), nratio=3, srange=2, factor=1).parameters()))

9125306

In [34]:
# Total parameter count for srange=0 with the width multiplier tuned to 2.23
# (presumably chosen to roughly match the srange=2, factor=1 count above).
sum(map(torch.Tensor.numel, SiCNN(f_in=3, size=5, ratio=2 ** (2 / 3), nratio=3, srange=0, factor=2.23).parameters()))

9069796

116130

907180

909849