In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from scale_cnn.convolution import ScaleConvolution
from scale_cnn.pooling import ScalePool

In [5]:
class SiCNN_3(nn.Module):
    '''
    Scale equivariant arch with 3 convolutional layers.

    Pipeline: conv1 (stride 2, srange 0) -> conv2 -> conv3, each followed by ReLU,
    then ScalePool and a two-layer fully connected head (ReLU in between).

    Args:
        f_in: number of input feature channels.
        size: kernel size passed to every ScaleConvolution.
        ratio: scale ratio passed to every ScaleConvolution and to ScalePool.
        nratio: number of scales; forward() replicates the input nratio times
            along a new scale axis.
        srange: srange passed to conv2 and conv3 (conv1 always uses srange=0).
        padding: padding passed to every ScaleConvolution.
        nb_classes: number of output logits.
        features_factor: multiplier applied to the base hidden widths
            (96, 192 and 150), truncated with int().
    '''

    def __init__(self, f_in=1, size=5, ratio=2**(2/3), nratio=3, srange=1, padding=0, nb_classes=10, features_factor=1):
        super().__init__()
        self.f_in = f_in
        self.size = size
        self.ratio = ratio
        self.nratio = nratio
        self.srange = srange
        self.padding = padding
        self.nb_classes = nb_classes
        self.features_factor = features_factor

        # Hidden widths, computed once so every layer agrees on them.
        width_conv12 = int(96 * features_factor)   # conv1 / conv2 output channels
        width_conv3 = int(192 * features_factor)   # conv3 output channels (= fc1 input)
        width_fc = int(150 * features_factor)      # fc1 output / fc2 input

        self.conv1 = ScaleConvolution(self.f_in, width_conv12, self.size, self.ratio, self.nratio, srange = 0, boundary_condition = "dirichlet", padding=self.padding, stride = 2)
        self.conv2 = ScaleConvolution(width_conv12, width_conv12, self.size, self.ratio, self.nratio, srange = self.srange, boundary_condition = "dirichlet", padding=self.padding)
        self.conv3 = ScaleConvolution(width_conv12, width_conv3, self.size, self.ratio, self.nratio, srange = self.srange, boundary_condition = "dirichlet", padding=self.padding)
        self.pool = ScalePool(self.ratio)

        self.fc1 = nn.Linear(width_conv3, width_fc, bias=True)
        self.fc2 = nn.Linear(width_fc, self.nb_classes, bias=True)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape [batch, feature, y, x]; the unsqueeze/repeat
               below requires a 4-D input.
        Returns:
            logits of shape [batch, nb_classes].
        '''
        # Add a scale axis and copy the image onto every scale: [batch, sigma, feature, y, x]
        x = x.unsqueeze(1).repeat(1, self.nratio, 1, 1, 1)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.pool(x)  # [batch, feature]
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [6]:
# Total parameter count of SiCNN_3 (3 input channels, srange=2, features_factor=1)
sum(map(torch.numel, SiCNN_3(3, srange=2, features_factor=1).parameters()))

3494044

In [7]:
# Total parameter count of SiCNN_3 (3 input channels, srange=0, features_factor=2)
sum(map(torch.numel, SiCNN_3(3, srange=0, features_factor=2).parameters()))

2898478

In [8]:
# Show the layer widths of the srange=0 model at features_factor=2.19795.
# NOTE(review): the factor presumably was tuned so this model's parameter count
# approaches the srange=2, features_factor=1 model counted above — confirm.
SiCNN_3(f_in=3, srange=0, features_factor=2.19795)

SiCNN_3(
  (conv1): ScaleConvolution (size=5, 3 → 211, n=3±0, dirichlet)
  (conv2): ScaleConvolution (size=5, 211 → 211, n=3±0, dirichlet)
  (conv3): ScaleConvolution (size=5, 211 → 422, n=3±0, dirichlet)
  (pool): ScalePool
  (fc1): Linear(in_features=422, out_features=329, bias=True)
  (fc2): Linear(in_features=329, out_features=10, bias=True)
)

In [20]:
class kanazawa(nn.Module):
    '''
    Scale equivariant arch, based on architecture in Kanazawa's paper https://arxiv.org/abs/1412.5104
    selecting srange = 1 is equivalent to the paper

    Pipeline: conv1 (3x3, stride 2, srange 0) -> conv2 (3x3), each followed by ReLU,
    then ScalePool and a two-layer fully connected head (ReLU in between).

    Args:
        f_in: number of input feature channels.
        ratio: scale ratio passed to every ScaleConvolution and to ScalePool.
        nratio: number of scales; forward() replicates the input nratio times
            along a new scale axis.
        srange: srange passed to conv2 (conv1 always uses srange=0).
        nb_classes: number of output logits.
        features_factor: multiplier applied to the base hidden widths
            (36, 64 and 150), truncated with int().
    '''

    def __init__(self, f_in, ratio=2**(2/3), nratio=3, srange=1, nb_classes=10, features_factor=1):
        super().__init__()
        self.f_in = f_in
        self.ratio = ratio
        self.nratio = nratio
        self.srange = srange
        self.nb_classes = nb_classes
        self.features_factor = features_factor

        # Hidden widths, computed once so every layer agrees on them.
        width_conv1 = int(36 * features_factor)   # conv1 output channels
        width_conv2 = int(64 * features_factor)   # conv2 output channels (= fc1 input)
        width_fc = int(150 * features_factor)     # fc1 output / fc2 input

        self.conv1 = ScaleConvolution(self.f_in, width_conv1, 3, self.ratio, self.nratio, srange = 0, boundary_condition = "dirichlet", stride = 2)
        self.conv2 = ScaleConvolution(width_conv1, width_conv2, 3, self.ratio, self.nratio, srange = srange, boundary_condition = "dirichlet")
        self.pool = ScalePool(self.ratio)

        self.fc1 = nn.Linear(width_conv2, width_fc, bias = True)
        self.fc2 = nn.Linear(width_fc, self.nb_classes, bias = True)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape [batch, feature, y, x]; the unsqueeze/repeat
               below requires a 4-D input.
        Returns:
            logits of shape [batch, nb_classes].

        The original class defined no forward(), so calling the model raised
        NotImplementedError. This mirrors the SiCNN_3.forward pipeline in this notebook.
        NOTE(review): activation placement is assumed to follow SiCNN_3; confirm
        against the paper.
        '''
        # Add a scale axis and copy the image onto every scale: [batch, sigma, feature, y, x]
        x = x.unsqueeze(1).repeat(1, self.nratio, 1, 1, 1)

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)  # [batch, feature]
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [21]:
# Total parameter count of kanazawa (3 input channels, srange=2, default features_factor=1)
sum(map(torch.numel, kanazawa(3, srange=2).parameters()))

116012

In [25]:
# srange=0 variant widened by features_factor=1.92; compare with the srange=2 count above
sum(map(torch.numel, kanazawa(3, srange=0, features_factor=1.92).parameters()))

116130

In [27]:
# kanazawa conv1 output width at features_factor=1.92 (base width 36)
int(1.92 * 36)

69

In [28]:
# kanazawa conv2 output width at features_factor=1.92 (base width 64)
int(1.92 * 64)

122

In [29]:
# kanazawa fc1 hidden width at features_factor=1.92 (base width 150)
int(1.92 * 150)

288