In [1]:
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor

import torch.nn.functional as F

In [2]:
def enc_block(in_ch: int, out_ch: int, kernel=3,
              stride=2, pad=0, bn=True):
    """Assemble one encoder stage: Conv2d -> (BatchNorm2d) -> ReLU.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels of the convolution.
        kernel: Convolution kernel size.
        stride: Convolution stride (2 by default).
        pad: Padding passed to the convolution.
        bn: When true, a ``BatchNorm2d`` over ``out_ch`` channels is placed
            between the convolution and the activation.

    Returns:
        An ``nn.Sequential`` holding the layers in application order.
    """
    stage = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel,
                       stride=stride, padding=pad)]
    if bn:
        stage.append(nn.BatchNorm2d(out_ch))
    stage.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stage)

In [3]:
class DecoderBlock(nn.Module):
    """Decoder stage: nearest-neighbour upsampling, then Conv2d -> (BatchNorm2d) -> ReLU.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels of the convolution.
        kernel: Convolution kernel size.
        scale: Spatial scale factor applied before the convolution.
        pad: Padding passed to the convolution.
        bn: When true, a ``BatchNorm2d`` follows the convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel=3,
                 scale=2, pad=0, bn=True):
        super().__init__()
        # The resize has no parameters, so it is stored as a plain callable
        # rather than as a sub-module.
        self.upsample = partial(F.interpolate, scale_factor=scale, mode='nearest')

        stage = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel,
                           stride=1, padding=pad)]
        if bn:
            stage.append(nn.BatchNorm2d(out_ch))
        stage.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*stage)

    def forward(self, x):
        """Upsample ``x`` and run it through the convolutional stack."""
        return self.layers(self.upsample(x))

In [4]:
def select_old(x: Tensor, y: Tensor) -> Tensor:
    """Return a copy of ``x`` with one half of the channels zeroed per sample.

    For sample ``i`` the channels ``[half * y[i], half * (y[i] + 1))`` are set
    to zero, where ``half = C // 2``. Note that this blanks the half *indexed
    by* the label, whereas ``select`` keeps that half and blanks the other.

    Args:
        x: Tensor of shape (N, C, H, W); it is not modified.
        y: One integer label per sample.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    batch, channels, _, _ = x.shape
    half = channels // 2

    start = y * half
    stop = start + half

    out = x.clone()
    for idx in range(batch):
        out[idx, start[idx]:stop[idx]] = 0
    return out

In [5]:
def select(h: Tensor, y: Tensor) -> Tensor:
    """Keep the channel half that matches each sample's label, zero the other.

    The channels are split into two chunks along dim 1. A label of 0 keeps the
    first chunk and a label of 1 keeps the second; the chunk that is not kept
    is multiplied by zero.

    Args:
        h: Tensor of shape (N, C, H, W).
        y: One label per sample (N values).

    Returns:
        Tensor with the same shape as ``h``.
    """
    batch, _, _, _ = h.shape
    gate = y.reshape(batch, 1, 1, 1)  # broadcastable over channels and space

    first, second = h.chunk(2, dim=1)
    return torch.cat([first * (1 - gate), second * gate], dim=1)

In [6]:
def act(h: Tensor, y: Tensor) -> Tensor:
    """Per-sample activation indicator of the label-selected channel half.

    The channels are split into two chunks along dim 1; label 0 picks the
    first chunk and label 1 the second. The mean absolute value of the picked
    chunk is capped at 1 and rounded up, so the result is 1 when the chunk
    holds any non-zero value and 0 when it is entirely zero.

    Args:
        h: Tensor of shape (N, C, H, W).
        y: One label per sample (N values).

    Returns:
        Floating-point tensor of shape (N,).
    """
    batch, channels, height, width = h.shape
    gate = y.reshape(batch, 1, 1, 1)

    first, second = h.chunk(2, dim=1)
    picked = first * (1 - gate) + second * gate

    # Number of elements in one channel half, used to turn the sum into a mean.
    half_size = channels * height * width / 2
    mean_abs = picked.abs().sum((1, 2, 3)) / half_size

    # For simplicity, and without loss of generality, the activation is
    # constrained to be at most 1 before rounding up.
    return mean_abs.clamp_max(1).ceil()

In [7]:
# Toy latent batch: 5 samples x 6 channels, expanded with two singleton
# spatial axes so it has the (N, C, H, W) layout the helpers expect.
h = torch.tensor([
    [1, 1, 0,   1, 1, 0],
    [0, 0, 0,   0, 0, 0],
    [1, 0, 1,   0, 1, 0],
    [1, 1, 1,   0, 0, 0],
    [0, 0, 0,   1, 1, 1],
]).unsqueeze(-1).unsqueeze(-1)

# One binary label per sample.
y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors of the same length and integer dtype as ``y``.
all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)

In [8]:
# Drop the singleton spatial axes to view the toy latent as a samples x channels table.
h.reshape(5, 6)

tensor([[1, 1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]])

In [9]:
# With every label forced to 1, only the second half of each sample's channels is kept.
select(h, all_pos).reshape(5, 6)

tensor([[0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]])

In [10]:
# Per-sample indicator: does the second (label-1) channel half hold any non-zero value?
act(h, all_pos)

tensor([1., 0., 1., 0., 1.])

In [11]:
# With every label forced to 0, only the first half of each sample's channels is kept.
select(h, all_neg).reshape(5, 6)

tensor([[1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0]])

In [12]:
# Per-sample indicator: does the first (label-0) channel half hold any non-zero value?
act(h, all_neg)

tensor([1., 0., 1., 1., 0.])

In [13]:
# The per-sample labels used by the next two cells.
y

tensor([1, 0, 1, 0, 1])

In [14]:
# Each sample keeps the channel half that matches its own label.
select(h, y).reshape(5, 6)

tensor([[0, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]])

In [15]:
# Per-sample indicator for the channel half picked by each sample's own label.
act(h, y)

tensor([1., 0., 1., 1., 1.])

$$
L_{ACT} =
\sum_{x \in S_0}
\bigl( |a_0(x) - 1| + |a_1(x)| \bigr) +
\sum_{x \in S_1}
\bigl( |a_1(x) - 1| + |a_0(x)| \bigr)
$$

In [16]:
def zeros(n: int, device=None) -> Tensor:
    """Return a length-``n`` int64 tensor filled with zeros.

    Args:
        n: Number of elements.
        device: Optional device on which to allocate the tensor. ``None``
            (the default) uses PyTorch's default device, which is what this
            helper did before the parameter existed.

    Returns:
        A 1-D ``torch.int64`` tensor of zeros.
    """
    return torch.zeros(n, dtype=torch.int64, device=device)


def ones(n: int, device=None) -> Tensor:
    """Return a length-``n`` int64 tensor filled with ones.

    Args:
        n: Number of elements.
        device: Optional device on which to allocate the tensor. ``None``
            (the default) uses PyTorch's default device, which is what this
            helper did before the parameter existed.

    Returns:
        A 1-D ``torch.int64`` tensor of ones.
    """
    return torch.ones(n, dtype=torch.int64, device=device)


def act_loss(x: Tensor, y: Tensor) -> Tensor:
    """Activation loss L_ACT, averaged over the batch.

    A sample labelled 0 is penalised by ``|a_0(x) - 1| + |a_1(x)|`` and a
    sample labelled 1 by ``|a_1(x) - 1| + |a_0(x)|``, where ``a_k`` is
    ``act`` evaluated with every label set to ``k``.

    Args:
        x: Latent batch of shape (N, C, H, W), as required by ``act``.
        y: One label per sample. Zero means class 0; any non-zero value is
            treated as class 1, so every sample is counted exactly once.

    Returns:
        A scalar tensor: the summed penalties divided by ``y.size(0)``.
    """
    # A boolean split guarantees the two groups partition the batch; the mask
    # is moved next to ``x`` so indexing works when ``y`` lives elsewhere.
    is_pos = (y != 0).to(x.device)
    x0, x1 = x[~is_pos], x[is_pos]
    n0, n1 = x0.size(0), x1.size(0)

    # The constant label vectors must sit on the same device as the
    # activations, otherwise the products inside ``act`` fail off-CPU.
    a0_x0 = act(x0, zeros(n0).to(x.device))
    a1_x0 = act(x0, ones(n0).to(x.device))

    a1_x1 = act(x1, ones(n1).to(x.device))
    a0_x1 = act(x1, zeros(n1).to(x.device))

    # ``act`` returns non-negative values, so the "wrong half" terms need no abs().
    neg_loss = (a0_x0 - 1).abs() + a1_x0
    pos_loss = (a1_x1 - 1).abs() + a0_x1

    return (neg_loss.sum() + pos_loss.sum()) / y.size(0)

In [17]:
# Activation loss of the toy batch under its own labels.
act_loss(h, y)

tensor(0.6000)

In [18]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder whose latent code is gated by a label.

    The encoder is a stride-1 stem followed by ``depth - 1`` stride-2 stages
    that double the channel count each time, starting from ``size``. The
    decoder mirrors it with ``DecoderBlock`` stages that halve the channel
    count, followed by a final convolution back to ``in_ch`` channels and a
    ``Tanh``.

    Args:
        in_ch: Number of channels of the input (and of the reconstruction).
        depth: Number of encoder stages including the stem.
        size: Channel count produced by the stem.
        pad: Padding used by every convolution.
    """

    def __init__(self, in_ch: int, depth: int, size=8, pad=1):
        super().__init__()
        self.encoder = self._build_encoder(in_ch, depth, size, pad)
        self.decoder = self._build_decoder(in_ch, depth, size, pad)

    @staticmethod
    def _build_encoder(in_ch: int, depth: int, size: int, pad: int) -> nn.Module:
        """Stem without batch norm, then the downsampling stages."""
        widths = [size * 2**level for level in range(depth)]
        stages = [enc_block(in_ch, size, stride=1, pad=pad, bn=False)]
        for narrow, wide in zip(widths, widths[1:]):
            stages.append(enc_block(narrow, wide, pad=pad))
        return nn.Sequential(*stages)

    @staticmethod
    def _build_decoder(out_ch: int, depth: int, size: int, pad: int) -> nn.Module:
        """Upsampling stages from widest to narrowest, then the output head."""
        widths = [size * 2**level for level in range(depth)][::-1]
        stages = []
        for wide, narrow in zip(widths, widths[1:]):
            stages.append(DecoderBlock(wide, narrow, pad=pad))
        head = nn.Conv2d(size, out_ch, 3, stride=1, padding=pad)
        return nn.Sequential(*stages, head, nn.Tanh())

    def forward(self, x, y):
        """Encode ``x``, gate the latent with ``y`` and decode it.

        Returns:
            A pair ``(h, x_hat)``: the ungated latent and the reconstruction
            decoded from the gated latent.
        """
        latent = self.encoder(x)
        recon = self.decoder(select(latent, y))
        return latent, recon

In [19]:
# RGB input, five encoder stages (stem + four downsampling stages).
model = Autoencoder(in_ch=3, depth=5, pad=1)
model

Autoencoder(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affi

In [20]:
# Smoke test: push a random batch with random binary labels through the model
# and inspect the shapes of the latent and of the reconstruction.
N = 5
x = torch.rand(N, 3, 256, 256)
y = torch.randint(0, 2, (N,))

h, x_hat = model(x, y)
h.shape, x_hat.shape

(torch.Size([5, 128, 16, 16]), torch.Size([5, 3, 256, 256]))