# Receptive field size calculator

Computing the receptive field (RF) by hand is tedious, so this notebook provides a calculator for it.

Method:

1. Forward with `x.requires_grad = True`
2. Choose the featuremap (or output) whose RF you want to measure
3. Backward from the center element of chosen featuremap
4. Calc RF using nonzero grads

In [1]:
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def calc_rf(net, C_in, size):
    """Empirically measure the receptive field (RF) of ``net``'s centre output.

    A random probe input is pushed through ``net``, the centre element of
    channel 0 of the output is back-propagated, and the RF is reported as the
    bounding box of all input pixels (over every input channel) that receive
    a nonzero gradient.

    Args:
        net: callable mapping a ``(1, C_in, size, size)`` tensor to a 4-D
            tensor. Its final output is used as the target featuremap.
        C_in: number of input channels of the probe.
        size: height and width of the (square) probe input.

    Returns:
        ``[rf_height, rf_width]`` as a list of ints.

    Raises:
        ValueError: if no input pixel receives a nonzero gradient, so no
            bounding box can be formed.

    Note:
        The bounding box is taken inside the probe input, so the result can
        never exceed ``size``. A result equal to ``size`` suggests the
        measurement was clipped; retry with a larger ``size``.
        NOTE(review): layers that normalise over spatial positions are
        expected to spread gradient over the whole input -- confirm before
        trusting results for nets that contain them.
    """
    # 1. Forward with x.requires_grad=True
    x = torch.rand(1, C_in, size, size, requires_grad=True)
    r = net(x)
    # Step 2 is skipped: the final output is assumed to be the target featuremap.
    # 3. Backward from the centre element
    r[0, 0, r.size(2) // 2, r.size(3) // 2].backward()

    # 4. Calc RF from nonzero grads. Take the union over all input channels so
    #    that the result does not depend on channel 0 happening to be used.
    touched = (x.grad[0] != 0).any(dim=0)
    coords = touched.nonzero()
    if coords.numel() == 0:
        raise ValueError(
            "No input pixel received a nonzero gradient; "
            "cannot determine the receptive field"
        )

    lowest = coords.min(0).values
    highest = coords.max(0).values
    rf_size = highest - lowest + 1

    return rf_size.tolist()

## Code blocks

code blocks from MaHFG.

In [3]:
def dispatcher(dispatch_fn):
    """Decorator that adds common key handling to a name -> object lookup.

    The wrapped function is called as ``dispatch_fn(key, *args)`` except that:
      * a callable ``key`` is returned unchanged, without calling
        ``dispatch_fn``;
      * a ``None`` key is replaced by the string ``'none'`` first.
    """
    def decorated(key, *args):
        if not callable(key):
            # None is spelled 'none' in the lookup tables.
            name = 'none' if key is None else key
            return dispatch_fn(name, *args)
        # Caller already supplied the object itself; nothing to look up.
        return key

    return decorated


@dispatcher
def norm_dispatch(norm):
    """Look up the 2-D normalisation layer constructor for a name.

    Recognised names (case-insensitive): ``'none'``, ``'in'``, ``'bn'``.
    Any other name raises ``KeyError``.
    """
    norm_layers = {
        'none': nn.Identity,
        # affine=False is spelled out explicitly (noted as the default upstream).
        'in': partial(nn.InstanceNorm2d, affine=False),
        'bn': nn.BatchNorm2d,
    }
    return norm_layers[norm.lower()]


@dispatcher
def w_norm_dispatch(w_norm):
    """Look up the weight-normalisation wrapper for a name.

    Only ``'none'`` (case-insensitive) is available; it maps to an identity
    function. Any other name raises ``KeyError``.
    """
    # NOTE: unlike the other dispatchers, the entries here are functions that
    # are applied to an already-constructed layer, not classes.
    wrappers = {
        'none': lambda x: x
    }
    return wrappers[w_norm.lower()]


@dispatcher
def activ_dispatch(activ, norm=None):
    """Look up the activation layer constructor for a name.

    Recognised names (case-insensitive): ``"none"``, ``"relu"``, ``"lrelu"``
    (LeakyReLU with slope 0.2). Any other name raises ``KeyError``.
    ``norm`` is accepted but not used by this lookup.
    """
    activations = {
        "none": nn.Identity,
        "relu": nn.ReLU,
        "lrelu": partial(nn.LeakyReLU, negative_slope=0.2)
    }
    return activations[activ.lower()]


@dispatcher
def pad_dispatch(pad_type):
    """Look up the 2-D padding layer constructor for a name.

    Recognised names (case-insensitive): ``"zero"``, ``"replicate"``,
    ``"reflect"``. Any other name raises ``KeyError``.
    """
    pad_layers = {
        "zero": nn.ZeroPad2d,
        "replicate": nn.ReplicationPad2d,
        "reflect": nn.ReflectionPad2d
    }
    return pad_layers[pad_type.lower()]


class LinearBlock(nn.Module):
    """Pre-activation linear block: norm -> activation -> (dropout) -> linear.

    Args:
        C_in: number of input features.
        C_out: number of output features.
        norm: ``'none'``, ``'bn'`` or ``'frn'`` (case-insensitive), ``None``
            (same as ``'none'``), or a callable that is used directly as the
            norm constructor and called with ``C_in``.
        activ: activation name or callable, resolved by ``activ_dispatch``.
        bias: whether the linear layer has a bias term.
        w_norm: weight-normalisation name or callable, resolved by
            ``w_norm_dispatch`` and applied to the linear layer.
        dropout: dropout probability; a dropout layer is only created when
            this is greater than 0.

    Raises:
        ValueError: if ``norm`` names an unsupported normalisation, or names
            ``'frn'`` while ``FilterResponseNorm1d`` is not available.
    """
    def __init__(self, C_in, C_out, norm='none', activ='relu', bias=True, w_norm='none',
                 dropout=0.):
        super().__init__()
        activ = activ_dispatch(activ, norm)
        norm = self._resolve_norm(norm)
        w_norm = w_norm_dispatch(w_norm)
        self.norm = norm(C_in)
        self.activ = activ()
        if dropout > 0.:
            self.dropout = nn.Dropout(p=dropout)
        self.linear = w_norm(nn.Linear(C_in, C_out, bias))

    @staticmethod
    def _resolve_norm(norm):
        """Return the 1-D norm constructor for ``norm`` (name, None or callable)."""
        # Mirror the dispatcher convention used for the other options:
        # callables pass through and None means 'none'.
        if callable(norm):
            return norm
        key = 'none' if norm is None else norm.lower()

        if key == 'bn':
            return nn.BatchNorm1d
        if key == 'none':
            return nn.Identity
        if key == 'frn':
            # FilterResponseNorm1d is not defined alongside this class here;
            # report that as an unsupported norm instead of a bare NameError.
            try:
                return FilterResponseNorm1d
            except NameError as err:
                raise ValueError(
                    "LinearBlock norm 'frn' requires FilterResponseNorm1d, "
                    "which is not available"
                ) from err
        raise ValueError(f"LinearBlock supports BN only (but {norm} is given)")

    def forward(self, x):
        """Apply norm, activation, optional dropout, then the linear layer."""
        x = self.norm(x)
        x = self.activ(x)
        # The dropout attribute only exists when dropout > 0 was requested.
        if hasattr(self, 'dropout'):
            x = self.dropout(x)
        return self.linear(x)


class ConvBlock(nn.Module):
    """Pre-activation conv block.

    Forward order: norm -> activation -> (2x upsample) -> (dropout) -> pad ->
    conv -> (2x average-pool downsample).

    ``norm``, ``activ``, ``w_norm`` and ``pad_type`` are resolved through the
    corresponding ``*_dispatch`` helpers. Padding is applied by a separate
    padding layer, so the convolution itself is built without padding.
    ``size`` is accepted but not used by this block.
    """
    def __init__(self, C_in, C_out, kernel_size=3, stride=1, padding=1, norm='none',
                 activ='relu', bias=True, upsample=False, downsample=False, w_norm='none',
                 pad_type='zero', dropout=0., size=None):
        # A 1x1 convolution must be requested without padding.
        assert kernel_size != 1 or padding == 0
        super().__init__()
        self.C_in, self.C_out = C_in, C_out

        activ_cls = activ_dispatch(activ, norm)
        norm_cls = norm_dispatch(norm)
        wrap_weight = w_norm_dispatch(w_norm)
        pad_cls = pad_dispatch(pad_type)
        self.upsample, self.downsample = upsample, downsample

        self.norm = norm_cls(C_in)
        self.activ = activ_cls()
        # The dropout attribute is only created when it is actually requested.
        if dropout > 0.:
            self.dropout = nn.Dropout2d(p=dropout)
        self.pad = pad_cls(padding)
        self.conv = wrap_weight(nn.Conv2d(C_in, C_out, kernel_size, stride, bias=bias))

    def forward(self, x):
        """Run the block on ``x`` and return the resulting featuremap."""
        h = self.activ(self.norm(x))
        if self.upsample:
            h = F.interpolate(h, scale_factor=2)
        if hasattr(self, 'dropout'):
            h = self.dropout(h)
        h = self.conv(self.pad(h))
        return F.avg_pool2d(h, 2) if self.downsample else h


class ResBlock(nn.Module):
    """Pre-activation residual block built from two ``ConvBlock`` layers.

    The weight normalisation is whatever ``w_norm`` resolves to via
    ``w_norm_dispatch``. ``upsample`` and ``downsample`` are mutually
    exclusive. A 1x1 skip convolution is created when the channel count
    changes or when the block resamples; otherwise the input is added back
    unchanged. With ``scale_var`` the sum is divided by sqrt(2).
    """
    def __init__(self, C_in, C_out, kernel_size=3, padding=1, upsample=False, downsample=False,
                 norm='none', w_norm='none', activ='relu', pad_type='zero', dropout=0.,
                 scale_var=False):
        assert not upsample or not downsample
        super().__init__()
        w_norm = w_norm_dispatch(w_norm)
        self.C_in, self.C_out = C_in, C_out
        self.upsample, self.downsample = upsample, downsample
        self.scale_var = scale_var

        # Options shared by both convolutions. Upsampling is delegated to the
        # first ConvBlock only; downsampling is handled in forward().
        shared = dict(w_norm=w_norm, pad_type=pad_type, dropout=dropout)
        self.conv1 = ConvBlock(C_in, C_out, kernel_size, 1, padding, norm, activ,
                               upsample=upsample, **shared)
        self.conv2 = ConvBlock(C_out, C_out, kernel_size, 1, padding, norm, activ,
                               **shared)

        # XXX upsample / downsample needs skip conv?
        needs_skip = C_in != C_out or upsample or downsample
        if needs_skip:
            self.skip = w_norm(nn.Conv2d(C_in, C_out, 1))

    def forward(self, x):
        """
        normal: pre-activ + convs + skip-con
        upsample: pre-activ + upsample + convs + skip-con
        downsample: pre-activ + convs + downsample + skip-con
        => pre-activ + (upsample) + convs + (downsample) + skip-con
        """
        # Residual branch.
        residual = self.conv2(self.conv1(x))
        if self.downsample:
            residual = F.avg_pool2d(residual, 2)

        # Shortcut branch: resampled and projected only when a skip conv exists.
        shortcut = x
        if hasattr(self, 'skip'):
            if self.upsample:
                shortcut = F.interpolate(shortcut, scale_factor=2)
            shortcut = self.skip(shortcut)
            if self.downsample:
                shortcut = F.avg_pool2d(shortcut, 2)

        out = residual + shortcut
        return out / np.sqrt(2) if self.scale_var else out

In [4]:
# Base channel width; later stages use multiples of it.
C = 32
w_norm, activ, pad_type = 'none', 'relu', 'zero'
ConvBlk = partial(ConvBlock, w_norm=w_norm, activ=activ, pad_type=pad_type)
ResBlk = partial(ResBlock, w_norm=w_norm, activ=activ, pad_type=pad_type)

# Each layer is annotated with its output resolution and the cumulative RF
# measured up to and including that layer.
feats = [
    # 64x64 (stride == 2) -> RF 3
    ConvBlk(1, C, stride=2, activ='none'),
    # 32x32 -> RF 13
    ResBlk(C, 2 * C, downsample=True),
    # 16x16 -> RF 33
    ResBlk(2 * C, 4 * C, downsample=True),
    # 8x8 -> RF 73
    ResBlk(4 * C, 8 * C, downsample=True),
    # 8x8 -> RF 125
    ResBlk(8 * C, 8 * C, downsample=False),
    # 8x8 -> RF 128
    ResBlk(8 * C, 8 * C, downsample=False),
]

# feats = [
#     ConvBlk(1, 1, stride=1, activ='none'), # 3
#     ConvBlk(1, 1, stride=1, activ='none'), # 5
#     ConvBlk(1, 1, stride=1, activ='none'), # 7
#     ConvBlk(1, 1, stride=1, activ='none'), # 9
# ]

In [5]:
# Measure the RF of every prefix of the layer stack on a 1-channel 128x128 input.
for n_layers, _ in enumerate(feats, start=1):
    net = nn.Sequential(*feats[:n_layers])
    print(n_layers - 1, calc_rf(net, 1, 128))

0 [3, 3]
1 [13, 13]
2 [33, 33]
3 [73, 73]
4 [125, 125]
5 [128, 128]
