In [1]:
import torch
import torch.nn.functional as F

### Input activations

In [2]:
# Fix the RNG seed so the random input and weight are identical on every
# Restart & Run All; without it each run sees different tensors.
torch.manual_seed(0)
# Input activations fed to F.conv2d: (batch, channels, height, width).
x = torch.randn(1, 64, 32, 32)
# Convolution weight: (out_channels, in_channels, kernel_h, kernel_w).
w = torch.randn(64, 64, 3, 3)

In [3]:
# All-zero mask with the same shape, dtype and device as the conv weight.
mask_base = torch.zeros_like(w)
# F.conv2d weights are laid out as (out_channels, in_channels / groups, kH, kW),
# so dim 0 is the output-channel axis and dim 1 is the input-channel axis.
out_channels = w.size(0)
in_channels = w.size(1)
# Number of channel groups used for the block-diagonal weight mask.
cg_groups = 4

In [4]:
# Channels owned by each group along the weight's dim 0 (out) and dim 1 (in).
cg_out_chunk_size = out_channels // cg_groups
cg_in_chunk_size = in_channels // cg_groups
# Set the diagonal blocks of the mask to 1: group g keeps only the weights
# that connect its own slice of dim 1 to its own slice of dim 0.
for g in range(cg_groups):
    out_block = slice(g * cg_out_chunk_size, (g + 1) * cg_out_chunk_size)
    in_block = slice(g * cg_in_chunk_size, (g + 1) * cg_in_chunk_size)
    mask_base[out_block, in_block] = 1

In [5]:
# Zero out every weight outside the mask, then run a bias-free 3x3 convolution;
# stride 1 with padding 1 keeps the spatial size of x.
masked_w = w * mask_base
Yp = F.conv2d(x, masked_w, bias=None, stride=1, padding=1)
print(list(Yp.shape))

[1, 64, 32, 32]


In [6]:
# Number of consecutive channels that are averaged together.
slice_size = 16
# The (slice_size, 1, 1) window with a matching stride averages non-overlapping
# runs of slice_size entries along dim 1 and leaves the last two dims untouched.
pool_window = (slice_size, 1, 1)
Yp_ = F.avg_pool3d(Yp, kernel_size=pool_window, stride=pool_window)
# Repeat each averaged entry slice_size times along dim 1 to undo the shrink.
Yp_ = torch.repeat_interleave(Yp_, slice_size, dim=1)

In [7]:
class Greater_Than(torch.autograd.Function):
    """Binary step ``input > 0`` with a pass-through (straight-through) gradient.

    The forward pass emits 1.0 where the input is strictly positive and 0.0
    elsewhere. The backward pass ignores the comparison and hands the incoming
    gradient to the input unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a float32 tensor with 1.0 where ``input > 0`` and 0.0 elsewhere."""
        return torch.gt(input, 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` (copied) as the gradient with respect to ``input``."""
        # forward takes a single tensor argument, so exactly one gradient is
        # returned here rather than a padded tuple.
        return grad_output.clone()

In [8]:
# Gate hyper-parameters: slope of the hard sigmoid, scale applied to the
# thresholded activations, and the initial per-channel threshold.
k = 0.25
cg_alpha = 2.0
cg_threshold_init = -6.0

# Callables used by the gate (note that cg_bn is an instance norm here).
cg_bn = F.instance_norm
cg_gt = Greater_Than.apply
# One threshold per channel, shaped (1, C, 1, 1) so it broadcasts over Yp_.
cg_threshold = torch.full((1, out_channels, 1, 1), cg_threshold_init)

# Normalise the averaged partial sums, shift by the threshold and scale.
normed = cg_bn(Yp_)
pre_d = cg_alpha * (normed - cg_threshold)
# Hard sigmoid clamp(k * (pre_d + 2), 0, 1) compared against 0.5; with
# k = 0.25 the decision is 1 exactly where pre_d > 0.
soft_gate = F.hardtanh(k * (pre_d + 2), min_val=0., max_val=1.)
d = cg_gt(soft_gate - 0.5)

In [9]:
# List the distinct gate decisions present in d.
print(torch.unique(d))

tensor([1.])
