In [1]:
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor

import torch.nn.functional as F

In [2]:
def enc_block(in_ch: int, out_ch: int, kernel_size=3, stride=2, bn=True):
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, 
                     stride=stride, padding=1)
    relu = nn.ReLU(inplace=True)
    if bn:
        layers = [conv, nn.BatchNorm2d(out_ch), relu]
    else:
        layers = [conv, relu]
    return nn.Sequential(*layers)

In [3]:
class DecoderBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, kernel_size=3, 
                 scale_factor=2, bn=True):
        super().__init__()
        self.upsample = partial(F.interpolate, 
                                scale_factor=scale_factor, 
                                mode='nearest')
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, 
                         stride=1, padding=1)
        relu = nn.ReLU(inplace=True)
        if bn:
            layers = [conv, nn.BatchNorm2d(out_ch), relu]
        else:
            layers = [conv, relu]
        self.layers = nn.Sequential(*layers)
  
    def forward(self, x):
        x = self.upsample(x)
        out = self.layers(x)
        return out

In [4]:
def select_block(x, y):
    N, C, H, W = x.shape
    half_C = C // 2

    low  = half_C * y
    high = half_C * (y + 1)

    x = x.clone()
    for i in range(N):
        x[i, low[i]:high[i]] = 0
    return x

In [5]:
model = nn.Sequential(
    enc_block(3, 8, stride=1, bn=False),
    enc_block(8, 16),
    enc_block(16, 32),
    enc_block(32, 64),
    enc_block(64, 128)
#     DecoderBlock(32, 16),
#     DecoderBlock(16, 8)
)

In [6]:
x = torch.rand((1, 3, 256, 256))
x1 = model(x)
x1.shape

torch.Size([1, 128, 16, 16])

In [257]:
def act(h: Tensor, y: Tensor, cls: int, eps=1e-7) -> Tensor:
    h = h[y == cls].clone()
    N, C, H, W = h.shape
    
    half_C = C // 2
    low  = half_C * cls
    high = half_C * (cls + 1)
    
    for i in range(N):
        h[i, low:high] = 0
        
    n_el = h.numel() / max(1, N * 2)
    act = h.abs().sum((1, 2, 3)) / n_el
    return act

In [258]:
h = torch.ones(3, 8, 1, 1)
y = torch.tensor([1, 0, 0])

act(h, y, 0), act(h, y, 1)

(tensor([1., 1.]), tensor([1.]))

$$
L_{ACT} =
\sum_{x ∈ S_0}
|a_0(x) − 1| + |a_1(x)| +
\sum_{x ∈ S_1}
|a_1(x) − 1| + |a_0(x)|
$$

In [259]:
def act_loss(h: Tensor, y: Tensor) -> int:
    pos = y.nonzero().reshape(-1)
    neg = (y - 1).nonzero().reshape(-1)
    h_neg, h_pos = h[neg], h[pos]
    
    neg_loss = (act(h_neg, y[neg], 0) - 1).abs().sum() + act(h_neg, y[pos], 1).sum()
    pos_loss = (act(h_pos, y[pos], 1) - 1).abs().sum() + act(h_pos, y[neg], 0).sum()
    print('0: {}'.format(neg_loss))
    print('1: {}'.format(pos_loss))
    return (neg_loss + pos_loss) / y.size(0)

In [284]:
emb_size = 6

h = torch.zeros(4, emb_size, 1, 1)
y = torch.tensor([1, 0, 1, 0])

pos = y.nonzero().reshape(-1)
neg = (y - 1).nonzero().reshape(-1)

# h[pos] = 0.2

h[neg, emb_size//2:emb_size] = 0 # neg
h[neg, 0:emb_size//2] = 12      # pos

h[pos, emb_size//2:emb_size] = 12 # neg
h[pos, 0:emb_size//2] = 0        # pos

act_loss(h, y)

0: 26.0
1: 26.0


tensor(13.)

In [36]:
h.nonzero()

tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 2, 0, 0],
        [3, 0, 0, 0],
        [3, 1, 0, 0],
        [3, 2, 0, 0]])

In [58]:
act(h, 1)

tensor([0.2500, 0.0000, 0.2500, 0.0000])