In [49]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models, transforms, utils
import copy
%matplotlib inline

In [38]:
model = models.vgg19(pretrained=True)

In [40]:
model.features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace)
  (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (17): ReLU(inplace)
  (18): MaxPool2d(kernel_size=2, stride=2, padding=0, 

2がblock1_conv2に対応
16がblock3_conv4に対応

In [17]:
image_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
    ])

In [31]:
sample = cv2.imread('sample1.jpg')

In [57]:
image = image_transform(sample).unsqueeze(0)

In [75]:
class Featex():
    def __init__(self, model, use_cuda):
        self.use_cuda = use_cuda
        self.feature1 = None
        self.feature2 = None
        self.model= copy.deepcopy(model.eval())
        self.model = self.model[:17]
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model[2].register_forward_hook(self.save_feature1)
        self.model[16].register_forward_hook(self.save_feature2)
        
    def save_feature1(self, module, input, output):
        self.feature1 = output.detach()
    
    def save_feature2(self, module, input, output):
        self.feature2 = output.detach()
        
    def __call__(self, input):
        if self.use_cuda:
            input = input.cuda()
        _ = self.model(input)
        # resize feature2 to the same size of feature1
        self.feature2 = F.interpolate(self.feature2, size=(self.feature1.size()[2], self.feature1.size()[3]), mode='bilinear', align_corners=True)
        return torch.cat((self.feature1, self.feature2), dim=1)

In [76]:
extract_features = Featex(model.features, use_cuda=False)

In [77]:
feat = extract_features(image)

In [78]:
feat.shape

torch.Size([1, 320, 225, 384])

In [89]:
class MyNormLayer():
    def __call__(self, x1, x2):
        bs, _ , H, W = x1.size()
        _, h, w, _ = x2.size()
        x1 = x1.view(bs, -1, H*W)
        x2 = x2.view(bs, -1, h*w)
        concat = torch.cat((x1, x2), dim=2)
        x_mean = torch.mean(concat, dim=2)
        x_std = torch.std(concat, dim=2)
        x1 = (x1 - x_mean) / x_std
        x2 = (x2 - x_std) / x_std
        x1 = x1.view(bs, H, W, -1)
        x2 = x2.view(bs, h, w, -1)
        return [x1, x2]

In [97]:
class CreateModel():
    def __init__(self, alpha, model, use_cuda):
        self.alpha = alpha
        self.featex = Featex(model, use_cuda)
    def __call__(self, template, image):
        T_feat = self.featex(template)
        I_feat = self.featex(image)
        I_feat, T_feat = MyNormLayer()(I_feat, T_feat)
        dist = torch.einsum("xcab,xcde->xabde", torch.norm(I_feat, dim=1), torch.norm(T_feat, dim=1))
        conf_map = QATM(self.alpha)(dist)
        return conf_map

In [None]:
def call( self, x ):
        batch_size, ref_row, ref_col, qry_row, qry_col = [ tf.shape(x)[k] for k in range(5) ]
        x = tf.reshape( x, [batch_size, ref_row * ref_col, qry_row * qry_col ] )
        xm_ref = x - K.max(x,axis=1,keepdims=True)
        conf_ref = tf.nn.softmax( self.coef_ref*xm_ref, axis=1 )
        xm_qry = x - K.max(x,axis=2,keepdims=True)
        conf_qry = tf.nn.softmax( self.coef_qry*xm_qry, axis=2 )
        confidence = K.sqrt(conf_ref * conf_qry )
        conf_values, ind3 = tf.nn.top_k( confidence, k=1 ) # batch_size, ref_size, 1
        ind1, ind2 = tf.meshgrid( tf.range( batch_size ), 
                                  tf.range( ref_row * ref_col ), indexing='ij' )
        ind1 = K.flatten( ind1 )
        ind2 = K.flatten( ind2 )
        ind3 = K.flatten( ind3 )
        indices = K.stack([ind1,ind2,ind3],axis=1)
        values = tf.gather_nd( confidence, indices )
        values = tf.reshape( values, [batch_size, ref_row, ref_col, 1])
        return values
    
    def compute_output_shape( self, input_shape ):
        bs, H, W, _, _ = input_shape
        return (bs, H, W, 1)

In [None]:
class QATM():
    def __init__(self, alpha):
        self.alpha = alpha
        
    def __call__(self, x):
        batch_size, ref_row, ref_col, qry_row, qry_col = x.size()
        x = x.view(batch_size, ref_row*ref_col, qry_row*qry_col)
        xm_ref = x - torch.max(x, dim=1, keepdim=True)[0]
        conf_ref = F.softmax(self.alpha*xm_ref, dim=1)
        xm_qry = x - torch.max(x, dim=2, keepdim=True)[0]
        conf_qry = F.softmax(self.alpha * xm_qry, dim=2)
        confidence = torch.sqrt(conf_ref * conf_qry)

In [129]:
a = torch.rand(2,3,4)

In [135]:
F.softmax(a, dim=0)

tensor([[[0.5270, 0.6975, 0.6501, 0.5351],
         [0.4933, 0.6219, 0.3581, 0.4878],
         [0.5656, 0.4960, 0.3836, 0.4446]],

        [[0.4730, 0.3025, 0.3499, 0.4649],
         [0.5067, 0.3781, 0.6419, 0.5122],
         [0.4344, 0.5040, 0.6164, 0.5554]]])

In [131]:
torch.max(a, dim=1,keepdim=True)[0].shape

torch.Size([2, 1, 4])

In [122]:
torch.cat((torch.rand(10, 10), torch.rand(20, 30)), dim=(0,1))

TypeError: cat(): argument 'dim' must be int, not tuple

In [88]:
torch.mean(torch.rand(10, 20, 15), dim=0).shape

torch.Size([20, 15])