In [1]:
# Automatically reload edited project modules (e.g. pyvrl.models.backbones.r3d)
# before each cell runs, so backbone changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torchsummary import summary

from pyvrl.models.backbones.r3d import R2Plus1D

### Default backbone

In [3]:
# Keyword arguments for R2Plus1D: depth 18, 4 stages, a stem without pooling, and
# spatial + temporal down-sampling in every stage except the first.
backbone_args = {
    "depth": 18,
    "num_stages": 4,
    "stem": {
        "temporal_kernel_size": 3,
        "temporal_stride": 1,
        "in_channels": 3,
        "with_pool": False,
    },
    "down_sampling": [False, True, True, True],
    "down_sampling_temporal": [False, True, True, True],
    "channel_multiplier": 1.0,
    "bottleneck_multiplier": 1.0,
    "with_bn": True,
    "zero_init_residual": False,
}

In [4]:
# Instantiate the project's R2Plus1D backbone from the config dict above.
backbone = R2Plus1D(**backbone_args)

In [5]:
# Dummy input clip of shape (batch=1, channels=3, frames=16, height=112, width=112).
# A locally seeded generator makes the input reproducible across runs without
# touching the global torch RNG (which the model's weight init may use).
x = torch.randn(1, 3, 16, 112, 112, generator=torch.Generator().manual_seed(0))

In [6]:
# Forward pass for shape inspection only. no_grad skips building the autograd
# graph for this 16-frame clip, which saves memory; y's values are unchanged.
with torch.no_grad():
    y = backbone(x)

In [7]:
# A 16x112x112 clip should map to 512 channels, 2 frames and a 7x7 spatial grid.
# The assertion makes a regression in r3d.py fail loudly on Restart & Run All.
assert y.shape == (1, 512, 2, 7, 7), f"unexpected output shape {tuple(y.shape)}"
y.shape

torch.Size([1, 512, 2, 7, 7])

In [8]:
# Per-layer output shapes and parameter counts (input size excludes the batch dim).
# device="cpu" because the backbone is never moved to a GPU, and torchsummary's
# default device is "cuda". The input size is taken from `x` so it cannot drift
# from the clip used above.
summary(backbone, tuple(x.shape[1:]), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 45, 16, 56, 56]           6,615
       BatchNorm3d-2       [-1, 45, 16, 56, 56]              90
              ReLU-3       [-1, 45, 16, 56, 56]               0
            Conv3d-4       [-1, 64, 16, 56, 56]           8,640
       BatchNorm3d-5       [-1, 64, 16, 56, 56]             128
              ReLU-6       [-1, 64, 16, 56, 56]               0
            Conv3d-7      [-1, 144, 16, 56, 56]          82,944
       BatchNorm3d-8      [-1, 144, 16, 56, 56]             288
              ReLU-9      [-1, 144, 16, 56, 56]               0
           Conv3d-10       [-1, 64, 16, 56, 56]          27,648
      BatchNorm3d-11       [-1, 64, 16, 56, 56]             128
             ReLU-12       [-1, 64, 16, 56, 56]               0
           Conv3d-13      [-1, 144, 16, 56, 56]          82,944
      BatchNorm3d-14      [-1, 144, 16,

### Customized backbone

In [65]:
# NOTE(review): this "customized" config is currently identical to the default
# config defined earlier, so this section does not customize anything yet.
# TODO: change the fields being experimented with.
# This cell was also executed (In [65]) after the cells that use it (In [59]-[64]),
# so their recorded outputs were not produced by this exact config.
backbone_args = {
    "depth": 18,
    "num_stages": 4,
    "stem": {
        "temporal_kernel_size": 3,
        "temporal_stride": 1,
        "in_channels": 3,
        "with_pool": False,
    },
    "down_sampling": [False, True, True, True],
    "down_sampling_temporal": [False, True, True, True],
    "channel_multiplier": 1.0,
    "bottleneck_multiplier": 1.0,
    "with_bn": True,
    "zero_init_residual": False,
}

[autoreload of pyvrl.models.backbones.r3d failed: Traceback (most recent call last):
  File "/home/fmthoker/anaconda3/envs/ctp4/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/home/fmthoker/anaconda3/envs/ctp4/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/home/fmthoker/anaconda3/envs/ctp4/lib/python3.7/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/home/fmthoker/anaconda3/envs/ctp4/lib/python3.7/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 630, in _exec
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/gpfs/scratch1/shared/fmthoker/pbagad/CtP-ssl/pyvrl/models/backbones/r3d.py", line 450, in <module>
    clas

In [59]:
# Instantiate the customized backbone. This overwrites the default `backbone`
# from the previous section.
backbone = R2Plus1D(**backbone_args)

In [60]:
# Dummy input clip of shape (batch=1, channels=3, frames=16, height=112, width=112).
# This uses the same fixed seed as the default section, so both backbones see an
# identical, reproducible input without touching the global torch RNG.
x = torch.randn(1, 3, 16, 112, 112, generator=torch.Generator().manual_seed(0))

In [61]:
# Forward pass for shape inspection only. no_grad avoids building the autograd
# graph; y's values are unchanged.
with torch.no_grad():
    y = backbone(x)

In [62]:
# Display the module tree (repr) of the customized backbone.
backbone

R2Plus1D(
  (stem): Sequential(
    (conv_s): Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (bn_s): BatchNorm3d(45, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu_s): ReLU(inplace=True)
    (conv_t): Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
    (bn_t): BatchNorm3d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (relu_t): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (conv_s): Conv3d(64, 144, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn_s): BatchNorm3d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (relu_s): ReLU(inplace=True)
        (conv_t): Conv3d(144, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
      )
      (bn1): BatchNorm3d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=

In [63]:
# NOTE(review): the spatial size here is 56x56, versus 7x7 for the default
# backbone. The cells in this section ran out of order (In [59]-[65]), and the
# autoreload of r3d.py failed while it was being edited. This shape may therefore
# not reflect the current file or the config shown above.
# TODO: re-check with Restart & Run All.
y.shape

torch.Size([1, 512, 2, 56, 56])

In [64]:
# Per-layer output shapes and parameter counts of the customized backbone.
# device="cpu" because the model stays on CPU, and torchsummary's default device
# is "cuda". The input size is derived from `x` to stay in sync with the clip above.
summary(backbone, tuple(x.shape[1:]), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 45, 16, 56, 56]           6,615
       BatchNorm3d-2       [-1, 45, 16, 56, 56]              90
              ReLU-3       [-1, 45, 16, 56, 56]               0
            Conv3d-4       [-1, 64, 16, 56, 56]           8,640
       BatchNorm3d-5       [-1, 64, 16, 56, 56]             128
              ReLU-6       [-1, 64, 16, 56, 56]               0
            Conv3d-7      [-1, 144, 16, 56, 56]          82,944
       BatchNorm3d-8      [-1, 144, 16, 56, 56]             288
              ReLU-9      [-1, 144, 16, 56, 56]               0
           Conv3d-10       [-1, 64, 16, 56, 56]          27,648
      BatchNorm3d-11       [-1, 64, 16, 56, 56]             128
             ReLU-12       [-1, 64, 16, 56, 56]               0
           Conv3d-13      [-1, 144, 16, 56, 56]          82,944
      BatchNorm3d-14      [-1, 144, 16,

### PyTorch R2Plus1D

In [9]:
from pyvrl.models.backbones.r3d_pytorch import r2plus1d_18

In [10]:
# Reference R(2+1)D-18 from pyvrl.models.backbones.r3d_pytorch, for comparison.
# This cell prints nothing itself, so the numbers in its output are printed by
# the constructor. NOTE(review): they look like per-block intermediate channel
# widths (leftover debug prints?) — TODO confirm or remove.
# This also rebinds `backbone`, replacing the customized model.
backbone = r2plus1d_18()

144
144
230
288
460
576
921
1152


In [11]:
# Forward pass of the reference backbone for shape inspection only. no_grad
# avoids building the autograd graph; y's values are unchanged.
with torch.no_grad():
    y = backbone(x)

In [12]:
# The reference implementation should produce the same (1, 512, 2, 7, 7) output
# as the default pyvrl backbone for a 16x112x112 clip. Fail loudly if not.
assert y.shape == (1, 512, 2, 7, 7), f"unexpected output shape {tuple(y.shape)}"
y.shape

torch.Size([1, 512, 2, 7, 7])

In [13]:
# Per-layer output shapes and parameter counts of the reference backbone.
# device="cpu" because the model stays on CPU, and torchsummary's default device
# is "cuda". The input size is derived from `x` to stay in sync with the clip used.
summary(backbone, tuple(x.shape[1:]), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 45, 16, 56, 56]           6,615
       BatchNorm3d-2       [-1, 45, 16, 56, 56]              90
              ReLU-3       [-1, 45, 16, 56, 56]               0
            Conv3d-4       [-1, 64, 16, 56, 56]           8,640
       BatchNorm3d-5       [-1, 64, 16, 56, 56]             128
              ReLU-6       [-1, 64, 16, 56, 56]               0
            Conv3d-7      [-1, 144, 16, 56, 56]          82,944
       BatchNorm3d-8      [-1, 144, 16, 56, 56]             288
              ReLU-9      [-1, 144, 16, 56, 56]               0
           Conv3d-10       [-1, 64, 16, 56, 56]          27,648
      BatchNorm3d-11       [-1, 64, 16, 56, 56]             128
             ReLU-12       [-1, 64, 16, 56, 56]               0
           Conv3d-13      [-1, 144, 16, 56, 56]          82,944
      BatchNorm3d-14      [-1, 144, 16,