In [0]:
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial

In [0]:
# Utility Functions and Classes
def conv3(c_in, c_out, stride = 1, groups=1): 
  return nn.Conv2d(c_in, c_out, 3, stride, padding=1, bias=False, groups=groups)

def conv1(c_in, c_out, stride = 1, groups=1): 
  return nn.Conv2d(c_in, c_out, 1, stride, bias=False, groups=groups) 

class PrintBlock(nn.Module):
  def __init__(self, message):
    super().__init__()
    self.message = message
  def forward(self, x):
    print(self.message)
    print("input shape:", x.shape)
    return x

identity = lambda x: x

In [0]:
class ResXBlockA(nn.Module):
  def __init__(self, c_in, c_mid, stride, cardinality, res_path = None):
    super().__init__()

    c_out = c_mid * 2

    assert(c_mid % cardinality == 0)

    c_path = c_mid // cardinality
    self.body_paths = [nn.Sequential(
        
        conv1(c_in, c_path, stride), 
        nn.BatchNorm2d(c_path),
        nn.ReLU(),

        conv3(c_path, c_path),
        nn.BatchNorm2d(c_path),
        nn.ReLU(),
        
        conv1(c_mid, c_out))
        
        for _ in range(cardinality)]

    
    self.final_norm = nn.BatchNorm2d(c_out)
    
    self.res_path = identity if res_path is None else res_path

  def forward(self, inp): 
    
    body = sum([path(inp) for path in self.body_paths])
    return F.relu(self.final_norm(body) + self.res_path(inp))

In [0]:
class ResXBlockB(nn.Module):
  def __init__(self, c_in, c_mid, stride, cardinality, res_path = None):
    super().__init__()

    c_out = c_mid * 2

    assert(c_mid % cardinality == 0)

    c_path = c_mid // cardinality
    self.body_paths = [nn.Sequential(
        
        conv1(c_in, c_path, stride), 
        nn.BatchNorm2d(c_path),
        nn.ReLU(),

        conv3(c_path, c_path),
        nn.BatchNorm2d(c_path),
        nn.ReLU())
        
        for _ in range(cardinality)]

    self.tail = nn.Sequential(
        conv1(c_mid, c_out),
        nn.BatchNorm2d(c_out)
    )
    
    self.res_path = identity if res_path is None else res_path

  def forward(self, inp): 
    body = [path(inp) for path in self.body_paths]
    concat = torch.cat(body, dim=1)
    with_tail = self.tail(concat)
    
    return F.relu(with_tail + self.res_path(inp))

In [0]:
class ResXBlockC(nn.Module):
  def __init__(self, c_in, c_mid, stride, cardinality, res_path = None):
    super().__init__()

    c_out = c_mid * 2

    self.body = nn.Sequential(
        
        conv1(c_in, c_mid, stride), 
        nn.BatchNorm2d(c_mid),
        nn.ReLU(),

        conv3(c_mid, c_mid, groups=cardinality),
        nn.BatchNorm2d(c_mid),
        nn.ReLU(),
        
        conv1(c_mid, c_out),
        nn.BatchNorm2d(c_out))
    
    self.res_path = identity if res_path is None else res_path

  def forward(self, inp): return F.relu(self.body(inp) + self.res_path(inp))

In [0]:
def resnext50_32_4(n_classes):

  start = nn.Sequential(
      nn.Conv2d(3, 64, 7, stride=2, bias = False, padding=3),
      nn.ReLU(),
      nn.BatchNorm2d(64),
      nn.MaxPool2d(kernel_size = 3, stride = 2)
  )

  c_in = 64

  model = []
  layer_sizes = [3, 4, 6, 3]


  # Corresponds to layers conv{3,4,5}_1 in paper.
  downsample_layers = {1,2,3}
  for i, layer in enumerate(layer_sizes):
    for j in range(layer):
      print(i, j, "!")
      c_mid = c_in * 2 if i == 0 and j == 0 else (c_in if j == 0 else c_in // 2)
      stride = 2 if i in downsample_layers and j == 0 else 1
      
      res_path = nn.Sequential(
          conv1(c_in, c_mid * 2, stride),
          nn.BatchNorm2d(c_mid * 2)
      )
      
      model.append(PrintBlock(f"starting layer {i}-{j}"))
      model.append(ResXBlockC(c_in, c_mid, stride, res_path = res_path, cardinality=32))
      c_in = c_mid * 2

  head = nn.Sequential(
      nn.AdaptiveAvgPool2d((1,1)),
      nn.Flatten(1),
      nn.Linear(2048, n_classes)
    )

  return nn.Sequential(start, *model, head)
  

In [14]:
m = resnext50_32_4(10)

0 0 !
0 1 !
0 2 !
1 0 !
1 1 !
1 2 !
1 3 !
2 0 !
2 1 !
2 2 !
2 3 !
2 4 !
2 5 !
3 0 !
3 1 !
3 2 !


In [8]:
m

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): ReLU()
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): PrintBlock()
  (2): ResXBlockC(
    (body): Sequential(
      (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (res_path): Sequential(
      (0): Conv2d(64, 256, ker

In [15]:
m(torch.randn(3, 3 ,400 ,400))

starting layer 0-0
input shape: torch.Size([3, 64, 99, 99])
starting layer 0-1
input shape: torch.Size([3, 256, 99, 99])
starting layer 0-2
input shape: torch.Size([3, 256, 99, 99])
starting layer 1-0
input shape: torch.Size([3, 256, 99, 99])
starting layer 1-1
input shape: torch.Size([3, 512, 50, 50])
starting layer 1-2
input shape: torch.Size([3, 512, 50, 50])
starting layer 1-3
input shape: torch.Size([3, 512, 50, 50])
starting layer 2-0
input shape: torch.Size([3, 512, 50, 50])
starting layer 2-1
input shape: torch.Size([3, 1024, 25, 25])
starting layer 2-2
input shape: torch.Size([3, 1024, 25, 25])
starting layer 2-3
input shape: torch.Size([3, 1024, 25, 25])
starting layer 2-4
input shape: torch.Size([3, 1024, 25, 25])
starting layer 2-5
input shape: torch.Size([3, 1024, 25, 25])
starting layer 3-0
input shape: torch.Size([3, 1024, 25, 25])
starting layer 3-1
input shape: torch.Size([3, 2048, 13, 13])
starting layer 3-2
input shape: torch.Size([3, 2048, 13, 13])


tensor([[ 0.4474,  0.5927,  0.2363, -0.2834,  0.1953, -0.7881, -0.5076, -0.1221,
         -0.5136, -0.1898],
        [ 0.4442,  0.4538,  0.2975, -0.1393,  0.2027, -0.7488, -0.5546, -0.1239,
         -0.5387, -0.2435],
        [ 0.3667,  0.4718,  0.2794, -0.1907,  0.2499, -0.7338, -0.5354, -0.1475,
         -0.5283, -0.2091]], grad_fn=<AddmmBackward>)