In [4]:
import sys, os
sys.path.append('/home/hyunjoon/github/tracking-pytorch/')
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18 as _resnet

In [20]:
class NormLayer(nn.Module):
    '''Normalise each spatial position of a feature map across its channels.

    For an input ``x`` of shape (N, C, H, W) the per-pixel channel mean ``mu``
    and channel variance ``var`` (clamped at zero) are computed, and the output
    is ``(x - mu) / sqrt(var + eps)``, broadcast over the channel axis.

    Args:
        kernel_size: int or pair; only its first entry is inspected.
        stride, padding: forwarded to ``nn.AvgPool2d`` when a pool is built.
        eps: added to the variance before the square root.
    '''
    def __init__(self, kernel_size, stride, padding=(0, 0), eps=1e-06):
        super().__init__()
        ksize = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        # NOTE(review): the statistics are only pooled when kernel_size is 1,
        # which looks inverted (a 1x1 average pool does no spatial averaging).
        # Building the pool for larger kernels would, however, make ``mu``
        # spatially smaller than ``x`` whenever stride > 1 and break the
        # subtraction in forward(), so the current behaviour is kept.
        # TODO: confirm the intent.
        self.pool = nn.AvgPool2d(ksize, stride, padding) if ksize[0] == 1 else None
        self.eps = eps

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        sq_mean = (x * x).mean(dim=1, keepdim=True)
        if self.pool is not None:
            mean, sq_mean = self.pool(mean), self.pool(sq_mean)
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding; relu clamps it.
        var = F.relu(sq_mean - mean * mean, inplace=True)
        return (x - mean) / torch.sqrt(var + self.eps)
    
    
class DepthToSpace(nn.Module):
    '''Move channel blocks into spatial blocks.

    Maps (N, C, H, W) to (N, C // bs**2, H * bs, W * bs). Output pixel
    (h * bs + i, w * bs + j) of channel c is taken from input channel
    (i * bs + j) * (C // bs**2) + c at position (h, w).
    '''
    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        n, c, h, w = x.size()
        bs = self.bs
        c_out = c // (bs ** 2)
        # split channels into (row offset i, column offset j, output channel)
        blocks = x.view(n, bs, bs, c_out, h, w)
        # interleave the offsets with the spatial axes: (n, c_out, h, i, w, j)
        blocks = blocks.permute(0, 3, 4, 1, 5, 2).contiguous()
        return blocks.view(n, c_out, h * bs, w * bs)
    

class Resnet18(nn.Module):
    '''ResNet-18 backbone that returns a stack of multi-level feature maps.

    Branches, for an image batch ``x``:
      * colour: ``x`` 4x average-pooled, projected to 32 channels by a 1x1
        conv + BatchNorm + ReLU, channel-normalised, then pooled (4, stride 2);
      * gradient: the ``layer1`` activation, channel-normalised and pooled
        (4, stride 2);
      * deep: ``layer2`` pooled down, ``layer3`` as is, and ``layer4``
        upsampled with DepthToSpace(2), each channel-normalised and
        concatenated on channels.

    All branches are concatenated along the channel axis.

    Args:
        pretrained: forwarded to torchvision's ``resnet18``.
    '''
    def __init__(self, pretrained=False):
        super().__init__()
        net = _resnet(pretrained=pretrained)
        # Drop the classification head (the last two children: avgpool, fc)
        # and keep base_names[i] naming base_layers[i], so forward() can
        # pick stages out by name.
        children = list(net.named_children())[:-2]
        self.base_names = [name for name, _ in children]
        self.base_layers = nn.ModuleList([module for _, module in children])

        self.color_layer = nn.Sequential(
            nn.AvgPool2d(4, stride=4),
            nn.Conv2d(3, 32, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),  # acts on a fresh BatchNorm output: safe in place
            NormLayer(4, 2),
            nn.AvgPool2d(4, 2, padding=0),
        )

        # The branches below read backbone activations that are also inputs to
        # later backbone layers (and outputs of in-place ReLUs inside them).
        # Autograd has saved those tensors, so mutating them in place would make
        # backward() fail. The out-of-place ReLU yields the same values.
        self.norm_layer = nn.Sequential(
            nn.ReLU(),
            NormLayer(4, 2),
            nn.AvgPool2d(4, 2, padding=0),
        )

        # padding=1 halves the layer2 map exactly (e.g. 28 -> 14), so it lines
        # up with layer3; padding=0 gave 13 and broke the concatenation.
        self.downsample_layer = nn.Sequential(
            nn.ReLU(),
            NormLayer(1, 1),
            nn.AvgPool2d(4, 2, padding=1),
        )

        self.upsample_layer = nn.Sequential(
            nn.ReLU(),
            DepthToSpace(2),
            NormLayer(1, 1),
        )

        self.mid_layer = nn.Sequential(
            nn.ReLU(),
            NormLayer(1, 1),
        )

    @staticmethod
    def _resize(t, size):
        '''Bilinearly resize ``t`` to spatial ``size``; no-op when it already matches.'''
        if t.shape[-2:] == size:
            return t
        return F.interpolate(t, size=tuple(size), mode='bilinear', align_corners=False)

    def forward(self, x):
        '''Return the channel-wise concatenation of colour, gradient and deep features.

        Args:
            x: image batch; the colour branch's Conv2d(3, ...) requires 3 input
               channels (presumably (N, 3, H, W)).
        Returns:
            A single tensor on the colour branch's spatial grid.
        '''
        color_feat = self.color_layer(x)

        # Run the backbone, remembering each named stage's output.
        stage_out = {}
        for name, layer in zip(self.base_names, self.base_layers):
            x = layer(x)
            stage_out[name] = x

        grad_feat = self.norm_layer(stage_out['layer1'])

        fd = self.downsample_layer(stage_out['layer2'])
        fm = self.mid_layer(stage_out['layer3'])
        fu = self.upsample_layer(stage_out['layer4'])
        # Guard against off-by-one sizes for inputs that are not multiples of 32.
        deep_size = fm.shape[-2:]
        deep_feat = torch.cat([self._resize(f, deep_size) for f in (fd, fm, fu)], dim=1)

        # NOTE(review): the colour/gradient branches sit at roughly stride 8 while
        # the deep branch is coarser (presumably stride 16), so they cannot be
        # concatenated directly. The deep features are bilinearly upsampled to
        # the colour grid. TODO: confirm this is the intended way to combine scales.
        out_size = color_feat.shape[-2:]
        feat = torch.cat([color_feat,
                          self._resize(grad_feat, out_size),
                          self._resize(deep_feat, out_size)], dim=1)
        return feat

In [21]:
# Build the feature extractor without pretrained backbone weights
# (the default, spelled out for the reader).
net = Resnet18(pretrained=False)

In [22]:
# Top-level submodules of the extractor, in registration order.
[child for child in net.children()]

[ModuleList(
   (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (2): ReLU(inplace=True)
   (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (4): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
 

In [45]:
# NOTE(review): `roi_align`, `x` and `roi` are not defined or imported in any
# cell visible here, so this only runs with leftover kernel state (note the
# out-of-order execution counts). TODO: import roi_align and build `x`/`roi`
# in an earlier cell so the notebook survives Restart & Run All.
roi_align(x, roi, output_size=(3, 3))

tensor([[[[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

    

In [27]:
# Top-left 3x3 window of the last two axes of `x`, for every leading index.
# Appears to serve as the expected result for the roi_align cell (the printed
# outputs match). NOTE(review): `x` is not defined in any visible cell.
x[:, :, :3, :3]

tensor([[[[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

         [[ 0.,  1.,  2.],
          [29., 30., 31.],
          [58., 59., 60.]],

    