In [30]:
import torch
import torch.nn as nn

In [35]:
def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

In [103]:
def attention_filter(attention_feature_map, kernel_size=3, mean=6, stddev=5):
    attention_map = torch.abs(attention_feature_map)
    
    attention_mask = attention_map > 2 * torch.mean(attention_map)
    attention_mask = attention_mask.float()
    
    w = torch.randn(kernel_size, kernel_size)
    truncated_normal_(w, mean, stddev)
    w = w / torch.sum(w)
    
    # [filter_height, filter_width, in_channels, out_channels]
    w = torch.unsqueeze(w, 2)
    w.repeat(1, 1, attention_mask.shape[3])
    w = torch.unsqueeze(w, 3)
    w.repeat(1, 1, 1, attention_mask.shape[3])
    # attention_map = tf.nn.conv2d(attention_mask, w, strides=[1, 1, 1, 1], padding='SAME')
    gaussian_filter = nn.Conv2d(attention_mask.shape[1], attention_mask.shape[1], (kernel_size, kernel_size))
    gaussian_filter.weight.data = w
    gaussian_filter.weight.requires_grad = False
    pad_filter = nn.Sequential(
        nn.ReflectionPad2d((1, 1, 1, 1)),
        gaussian_filter
    )
    attention_map = pad_filter(attention_mask)
    attention_map = attention_map - torch.min(attention_map)
    attention_map = attention_map / torch.max(attention_map)
    return attention_map

In [98]:
a = torch.arange(27).view(1, 3, 3, 3).float()
a.mean(dim=(0, 1))

tensor([[ 9., 10., 11.],
        [12., 13., 14.],
        [15., 16., 17.]])

In [104]:
x = attention_filter(a)
print(x)

tensor([[[[0.0124, 0.0124, 0.0124, 0.0124, 0.0124],
          [0.0124, 0.0124, 0.0124, 0.0124, 0.0124],
          [0.0124, 0.0124, 0.0124, 0.0124, 0.0124],
          [0.0124, 0.0124, 0.0124, 0.0124, 0.0124],
          [0.0124, 0.0124, 0.0124, 0.0124, 0.0124]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

         [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]]], grad_fn=<DivBackward0>)


In [82]:
a.max()

tensor(26.)