In [1]:
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor

import torch.nn.functional as F

In [2]:
def enc_block(in_ch: int, out_ch: int, kernel_size=3, stride=2, bn=True):
    """Build one encoder stage: Conv2d -> (optional) BatchNorm2d -> ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: square convolution kernel size (assumed odd).
        stride: convolution stride; the default 2 halves even H and W.
        bn: if True, insert ``BatchNorm2d(out_ch)`` after the convolution.

    Returns:
        ``nn.Sequential`` holding the layers in the order listed above.
    """
    # padding = kernel_size // 2 keeps the output size equal to
    # ceil(input / stride) for any odd kernel; for the default kernel_size=3
    # it is 1, so existing models are unchanged.
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                     stride=stride, padding=kernel_size // 2)
    layers = [conv]
    if bn:
        layers.append(nn.BatchNorm2d(out_ch))
    layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

In [3]:
class DecoderBlock(nn.Module):
    """One decoder stage: nearest-neighbour upsampling, then
    Conv2d (stride 1) -> (optional) BatchNorm2d -> ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: square convolution kernel size (assumed odd).
        scale_factor: spatial upsampling factor applied before the conv.
        bn: if True, insert ``BatchNorm2d(out_ch)`` after the convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size=3,
                 scale_factor=2, bn=True):
        super().__init__()
        # nn.Upsample calls F.interpolate with the same arguments, but as a
        # registered submodule it shows up in the model repr.
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='nearest')
        # kernel_size // 2 keeps H and W unchanged for odd kernels (1 for the
        # default kernel_size=3).
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size,
                         stride=1, padding=kernel_size // 2)
        layers = [conv]
        if bn:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor`` and apply the conv stack."""
        return self.layers(self.upsample(x))

In [4]:
# Encoder only: a full-resolution 3 -> 8 stage without BatchNorm, followed by
# four stride-2 stages that double the channels (8 -> 128) and halve H and W.
# A decoder appended here would have to start from 128 input channels.
model = nn.Sequential(
    enc_block(3, 8, stride=1, bn=False),
    *(
        enc_block(c_in, c_out)
        for c_in, c_out in ((8, 16), (16, 32), (32, 64), (64, 128))
    ),
)

In [5]:
# Shape check on a random 3-channel 256x256 input: the four stride-2 encoder
# stages downsample 256 -> 16 while the channel count grows to 128.
x = torch.rand((1, 3, 256, 256))
x1 = model(x)
x1.shape

torch.Size([1, 128, 16, 16])

In [6]:
def select_old(x, y):
    """Superseded loop version: zero, per sample, the channel half indexed by ``y``.

    For sample ``i`` the channels ``[C//2 * y[i], C//2 * (y[i] + 1))`` are
    set to 0 and every other channel is kept. The input is not modified;
    a clone is edited and returned.

    NOTE(review): this is the opposite of ``select``, which KEEPS the half
    indexed by ``y`` and zeros the other one. Do not swap one for the other.
    """
    n_samples, n_channels, _, _ = x.shape  # unpacking also rejects non-4-D input
    width = n_channels // 2

    starts = width * y
    stops = width * (y + 1)

    out = x.clone()
    for idx in range(n_samples):
        out[idx, starts[idx]:stops[idx]] = 0
    return out

In [7]:
def select(h, y):
    """Keep the channel half chosen by each sample's label and zero the other.

    ``h`` is split along dim 1 into a first and a second half. Sample ``i``
    keeps the first half when ``y[i] == 0`` and the second half when
    ``y[i] == 1``; labels other than 0/1 scale the halves linearly by
    ``1 - y[i]`` and ``y[i]``.

    Args:
        h: 4-D tensor ``(N, C, H, W)``.
        y: ``N`` per-sample labels.

    Returns:
        Tensor with the same shape as ``h``.
    """
    batch, _, _, _ = h.shape  # unpacking also rejects non-4-D input
    gate = y.reshape(batch, 1, 1, 1)

    first, second = h.chunk(2, dim=1)
    return torch.cat((first * (1 - gate), second * gate), dim=1)

In [8]:
def act(h: Tensor, y: Tensor, binarize: bool = True) -> Tensor:
    """Per-sample activation of the channel half selected by ``y``.

    ``h`` is split along dim 1 into ``h0`` (first half) and ``h1`` (second
    half). Sample ``i`` uses ``h0`` when ``y[i] == 0`` and ``h1`` when
    ``y[i] == 1``. The activation is the mean absolute value of that half:
    its L1 norm divided by ``C * H * W / 2``.

    Args:
        h: 4-D tensor ``(N, C, H, W)``; ``C`` should be even so both halves
            have the same shape.
        y: ``N`` per-sample labels, expected to be 0 or 1.
        binarize: if True (default), return 1 for every sample whose
            selected half is not all zeros and 0 otherwise. That step function
            has zero gradient everywhere, so pass ``False`` to get the
            continuous, differentiable mean absolute activation (e.g. when
            training with this as a loss term).

    Returns:
        Floating-point tensor of shape ``(N,)``.
    """
    N, C, H, W = h.shape

    y = y.reshape(N, 1, 1, 1)
    h0, h1 = h.chunk(2, dim=1)
    a = h0 * (1 - y) + h1 * y

    n_el = C * H * W / 2
    a = a.abs().sum((1, 2, 3)) / n_el

    if not binarize:
        return a
    # a >= 0 here, so clamping to <= 1 and taking ceil maps 0 -> 0 and any
    # positive activation -> 1.
    return a.clamp_max_(1).ceil()

In [9]:
# Toy batch: 5 samples, C = 6 channels, 1x1 spatial size. `select` keeps
# channels 0-2 for label 0 and channels 3-5 for label 1.
h = torch.tensor([
    [1, 1, 0,  1, 1, 0],
    [0, 0, 0,  0, 0, 0],
    [1, 0, 1,  0, 1, 0],
    [1, 1, 1,  0, 0, 0],
    [0, 0, 0,  1, 1, 1],
]).reshape(5, 6, 1, 1)

y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors: every sample labelled 0 / every sample labelled 1.
all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)

In [10]:
# Drop the trailing 1x1 spatial dims to view h as (samples, channels).
h.reshape(5, 6)

tensor([[1, 1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]])

In [11]:
# Every sample labelled 1: only the second channel half (channels 3-5) survives.
select(h, all_pos).reshape(5, 6)

tensor([[0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]])

In [12]:
# Label 1 for every sample: 1 where the second half is not all zeros, else 0.
act(h, all_pos)

tensor([1., 0., 1., 0., 1.])

In [13]:
# Every sample labelled 0: only the first channel half (channels 0-2) survives.
select(h, all_neg).reshape(5, 6)

tensor([[1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]])

In [14]:
# Label 0 for every sample: 1 where the first half is not all zeros, else 0.
act(h, all_neg)

tensor([1., 0., 1., 1., 0.])

In [15]:
# Per-sample labels used by the mixed-label examples below.
y

tensor([1, 0, 1, 0, 1])

In [16]:
# Mixed labels: each sample keeps the channel half indexed by its own label.
select(h, y).reshape(5, 6)

tensor([[0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]])

In [17]:
# Activation of the half selected by each sample's own label.
act(h, y)

tensor([1., 0., 1., 1., 1.])

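$$
L_{ACT} =
\sum_{x \in S_0} \big( |a_0(x) - 1| + |a_1(x)| \big)
+ \sum_{x \in S_1} \big( |a_1(x) - 1| + |a_0(x)| \big)
$$

where $S_0$ and $S_1$ are the samples labelled 0 and 1, and $a_c(x)$ is the activation of the $c$-th channel half of $x$ (computed by `act`). The `act_loss` implementation below divides this sum by the batch size.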

In [18]:
def zeros(n: int) -> Tensor:
    """Return a length-``n`` int64 vector of zeros (an all-class-0 label vector)."""
    return torch.full((n,), 0, dtype=torch.long)


def ones(n: int) -> Tensor:
    """Return a length-``n`` int64 vector of ones (an all-class-1 label vector)."""
    return torch.full((n,), 1, dtype=torch.long)


def act_loss(x: Tensor, y: Tensor) -> Tensor:
    """ACT (activation) loss for a batch, averaged over the batch size.

    ``S_0`` is the set of samples with ``y != 1`` and ``S_1`` the set with
    ``y != 0``; for 0/1 labels these are exactly the class-0 and class-1
    samples. With ``a_c(x) = act(x, c)`` the loss is

        sum_{x in S_0} (|a_0(x) - 1| + a_1(x))
      + sum_{x in S_1} (|a_1(x) - 1| + a_0(x))

    divided by ``y.size(0)``. ``act`` returns non-negative values, so the
    ``a_1(x)`` / ``a_0(x)`` penalty terms need no ``abs``.

    Args:
        x: 4-D batch passed to ``act``, one sample per entry of dim 0.
        y: per-sample labels, expected 0 or 1, with ``y.size(0) == x.size(0)``.

    Returns:
        0-d floating-point tensor. An empty batch gives ``nan`` (0 / 0).
    """
    x0 = x[y != 1]  # S_0
    x1 = x[y != 0]  # S_1

    def const_labels(n: int, value: int) -> Tensor:
        # Constant int64 labels created on x's device. act() multiplies them
        # with slices of x, and CPU labels next to CUDA activations raise a
        # device-mismatch error.
        return torch.full((n,), value, dtype=torch.int64, device=x.device)

    n0, n1 = x0.size(0), x1.size(0)
    neg_loss = (act(x0, const_labels(n0, 0)) - 1).abs() + act(x0, const_labels(n0, 1))
    pos_loss = (act(x1, const_labels(n1, 1)) - 1).abs() + act(x1, const_labels(n1, 0))

    return (neg_loss.sum() + pos_loss.sum()) / y.size(0)

In [19]:
# Second toy batch for act_loss, with every sample treated as class 0
# (all_neg). The first half (channels 0-2) is active in every row, so each
# |a_0 - 1| term is 0. The second half is active only in the last two rows,
# so the a_1 terms sum to 2, and the loss is 2 / 5 = 0.4.
# NOTE(review): this rebinds `h` (and re-creates identical y / all_neg /
# all_pos), so re-running the earlier select/act demo cells after this one
# shows outputs different from those saved above.
h = torch.tensor([
    [1, 1, 1,  0, 0, 0],
    [1, 1, 0,  0, 0, 0],
    [1, 1, 0,  0, 0, 0],
    [1, 1, 1,  1, 0, 0],
    [1, 1, 1,  1, 0, 0],
]).reshape(5, 6, 1, 1)

y = torch.tensor([1, 0, 1, 0, 1])

all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)

act_loss(h, all_neg)

tensor(0.4000)