In [1]:
import math

from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor

import torch.nn.functional as F

In [2]:
def enc_block(in_ch: int, out_ch: int, kernel=3, 
              stride=2, pad=0, bn=True):
    """Build one encoder stage: Conv2d -> [BatchNorm2d] -> ReLU.

    Args:
        in_ch: input channels of the convolution.
        out_ch: output channels of the convolution (and of the BatchNorm).
        kernel: square kernel size.
        stride: convolution stride (2 by default, i.e. a downsampling stage).
        pad: zero padding applied by the convolution.
        bn: insert a BatchNorm2d between the convolution and the ReLU.

    Returns:
        An ``nn.Sequential`` whose children are, in order,
        ``[Conv2d, BatchNorm2d, ReLU]`` or ``[Conv2d, ReLU]``.
    """
    stage = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel,
                       stride=stride, padding=pad)]
    if bn:
        stage.append(nn.BatchNorm2d(out_ch))
    stage.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stage)

In [3]:
class DecoderBlock(nn.Module):
    """One decoder stage: nearest upsampling, then Conv2d -> [BatchNorm2d] -> ReLU.

    Args:
        in_ch: channels entering the convolution.
        out_ch: channels produced by the convolution.
        kernel: square kernel size of the convolution (stride is always 1).
        scale: ``scale_factor`` passed to ``F.interpolate``.
        pad: zero padding of the convolution.
        bn: insert a BatchNorm2d after the convolution.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel=3, 
                 scale=2, pad=0, bn=True):
        super().__init__()
        # Parameter-free upsampler, stored as a plain callable (not a submodule).
        self.upsample = partial(F.interpolate, scale_factor=scale,
                                mode='nearest')
        stage = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel,
                           stride=1, padding=pad)]
        if bn:
            stage.append(nn.BatchNorm2d(out_ch))
        stage.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*stage)

    def forward(self, x):
        """Upsample ``x`` and run it through the conv stack."""
        return self.layers(self.upsample(x))

In [4]:
def select_old(x: Tensor, y: Tensor) -> Tensor:
    """Loop-based channel masking (earlier version of ``select``).

    For sample ``i`` the channel slice
    ``[C // 2 * y[i], C // 2 * (y[i] + 1))`` is set to zero on a copy of
    ``x``; ``x`` itself is not modified.

    NOTE(review): this zeroes the half *indexed by* the label, i.e. it keeps
    the opposite half from ``select`` (which keeps the first half for
    label 0). It is not called anywhere in this notebook.
    """
    n, c, _, _ = x.shape
    width = c // 2
    starts = width * y
    stops = width * (y + 1)

    masked = x.clone()
    for row in range(n):
        masked[row, starts[row]:stops[row]] = 0
    return masked

In [5]:
def select(h: Tensor, y: Tensor) -> Tensor:
    """Keep the latent half matching each sample's label and zero the other.

    The channel axis (dim 1) is split in two with ``chunk``: label 0 keeps
    the first half, label 1 keeps the second half.

    Args:
        h: 4-D latent tensor ``(N, C, H, W)``.
        y: ``N`` labels, presumably 0/1 integers (used as the weights
            ``1 - y`` and ``y``).

    Returns:
        Tensor with the same shape as ``h``.
    """
    batch, _, _, _ = h.shape
    weight = y.reshape(batch, 1, 1, 1)
    first, second = h.chunk(2, dim=1)
    kept_first = first * (1 - weight)
    kept_second = second * weight
    return torch.cat([kept_first, kept_second], dim=1)

In [6]:
def act(h: Tensor, y: Tensor, binarize: bool = True) -> Tensor:
    """Per-sample activation of the latent half selected by the label.

    The channel axis (dim 1) of ``h`` is split in two with ``chunk`` (the same
    split ``select`` uses). For each sample, this takes the mean absolute
    value of the half picked by its label (0 -> first half, 1 -> second half).

    Args:
        h: latent tensor with batch on dim 0 and channels on dim 1; all
            trailing dims (e.g. ``H, W`` of ``(N, C, H, W)``) are averaged.
        y: ``N`` labels, expected to be 0/1 integers.
        binarize: if True (default, the original behaviour), clamp the
            activation to at most 1 and round it up with ``ceil``. The result
            is then 1 when the selected half has any non-zero entry and 0
            otherwise. ``ceil`` has zero gradient everywhere, so a loss built
            on the binarized value cannot train ``h``. Pass
            ``binarize=False`` to get the raw, differentiable mean |h|.

    Returns:
        Tensor of shape ``(N,)``.
    """
    n = h.size(0)
    first, second = h.chunk(2, dim=1)
    # Average each half over its own element count, so halves of unequal
    # size (odd channel count) are normalised correctly instead of being
    # broadcast against each other.
    a_first = first.abs().flatten(1).mean(dim=1)
    a_second = second.abs().flatten(1).mean(dim=1)

    y = y.reshape(n)
    a = a_first * (1 - y) + a_second * y

    if binarize:
        # Original behaviour: map any positive activation to exactly 1.
        a = a.clamp_max(1).ceil()
    return a

In [7]:
# Toy latent batch: 5 samples x 6 channels (two halves of 3), 1x1 spatial.
h = torch.tensor([
    [1, 1, 0,  1, 1, 0],
    [0, 0, 0,  0, 0, 0],
    [1, 0, 1,  0, 1, 0],
    [1, 1, 1,  0, 0, 0],
    [0, 0, 0,  1, 1, 1],
], dtype=torch.float32, requires_grad=True)[..., None, None]

y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors with the same length and dtype (int64) as `y`.
all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)
all_pos = torch.ones(y.size(0), dtype=torch.int64)

In [8]:
h.reshape(h.size(0), -1)  # toy latents as a 5 x 6 matrix, one row per sample

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [9]:
select(h, all_pos).reshape(len(h), -1)  # label 1 everywhere: first half of every row zeroed

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [10]:
act(h, y=all_pos)  # 1 iff the second half of the row has a non-zero entry

tensor([1., 0., 1., 0., 1.], grad_fn=<CeilBackward>)

In [11]:
select(h, all_neg).reshape(len(h), -1)  # label 0 everywhere: second half of every row zeroed

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<AsStridedBackward>)

In [12]:
act(h, y=all_neg)  # 1 iff the first half of the row has a non-zero entry

tensor([1., 0., 1., 1., 0.], grad_fn=<CeilBackward>)

In [13]:
y  # mixed labels from the setup cell; drive the per-row selection below

tensor([1, 0, 1, 0, 1])

In [14]:
select(h, y).reshape(len(h), -1)  # each row keeps the half chosen by its own label

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [15]:
act(h, y=y)  # activation of the half selected by each row's label

tensor([1., 0., 1., 1., 1.], grad_fn=<CeilBackward>)

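Activation loss, where $S_c$ is the set of samples labelled $y = c$ and $a_c(x)$ is the activation of the $c$-th half of the latent code (computed by `act`; `act_loss` below divides the total by the batch size):

$$
L_{ACT} =
\sum_{x \in S_0} \Big( |a_0(x) - 1| + |a_1(x)| \Big)
+ \sum_{x \in S_1} \Big( |a_1(x) - 1| + |a_0(x)| \Big)
$$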

In [16]:
def zeros(n: int) -> Tensor:
    """Return a length-``n`` int64 tensor of zeros (an all-label-0 vector)."""
    return torch.full((n,), 0, dtype=torch.long)


def ones(n: int) -> Tensor:
    """Return a length-``n`` int64 tensor of ones (an all-label-1 vector)."""
    return torch.full((n,), 1, dtype=torch.long)


def act_loss(x: Tensor, y: Tensor) -> Tensor:
    """Activation loss L_ACT, divided by the batch size.

    Samples with label 0 are pushed towards activation 1 on the first latent
    half and 0 on the second; samples with label 1 the other way round:

        sum_{y=0} |a_0 - 1| + a_1  +  sum_{y=1} |a_1 - 1| + a_0

    where ``a_c`` comes from ``act`` (non-negative, so ``|a| == a``).

    NOTE(review): ``act`` as written in this notebook ends with ``ceil()``,
    whose gradient is zero, so this loss yields no gradient w.r.t. ``x``
    unless ``act`` returns un-binarized values.

    Args:
        x: latent batch, e.g. ``(N, C, H, W)``.
        y: ``N`` labels, expected to be 0/1 integers.
    """
    neg_idx = (y - 1).nonzero().reshape(-1)
    pos_idx = y.nonzero().reshape(-1)

    def group_loss(group: Tensor, own: int) -> Tensor:
        # |a_own - 1| + a_other, summed over the samples of one class.
        n = group.size(0)
        a_own = act(group, torch.full((n,), own, dtype=torch.int64))
        a_other = act(group, torch.full((n,), 1 - own, dtype=torch.int64))
        return ((a_own - 1).abs() + a_other).sum()

    total = group_loss(x[neg_idx], 0) + group_loss(x[pos_idx], 1)
    return total / y.size(0)

In [17]:
act_loss(x=h, y=y)  # L_ACT on the toy batch, averaged over its 5 samples

tensor(0.6000, grad_fn=<DivBackward0>)

In [18]:
def rec_loss(x: Tensor, x_hat: Tensor) -> Tensor:
    """Mean absolute (L1) reconstruction error between ``x`` and ``x_hat``.

    For same-shaped inputs this equals ``(x - x_hat).abs().sum() / x.numel()``.
    """
    return F.l1_loss(input=x_hat, target=x, reduction='mean')

In [19]:
class Autoencoder(nn.Module):
    """Conv autoencoder whose latent code is split by label before decoding.

    Encoder: a stride-1 stem (``in_ch -> size``, no BatchNorm), followed by
    ``depth - 1`` stride-2 ``enc_block`` stages, each doubling the channels.
    The latent therefore has ``size * 2**(depth - 1)`` channels. The decoder
    mirrors this with ``DecoderBlock`` stages (x2 upsampling, halving the
    channels), then a 3x3 conv back to ``in_ch`` channels and ``Tanh``.
    """

    def __init__(self, in_ch: int, depth: int, size=8, pad=1):
        super().__init__()
        self.encoder = Autoencoder._build_encoder(in_ch, depth, size, pad)
        self.decoder = Autoencoder._build_decoder(in_ch, depth, size, pad)

    @staticmethod
    def _build_encoder(in_ch: int, depth: int, size: int, pad: int) -> nn.Module:
        """Stem conv followed by ``depth - 1`` channel-doubling stages."""
        stages = [enc_block(in_ch, size, stride=1, pad=pad, bn=False)]
        width = size
        for _ in range(depth - 1):
            stages.append(enc_block(width, 2 * width, pad=pad))
            width *= 2
        return nn.Sequential(*stages)

    @staticmethod
    def _build_decoder(out_ch: int, depth: int, size: int, pad: int) -> nn.Module:
        """``depth - 1`` channel-halving upsampling stages, then conv + Tanh."""
        stages = [DecoderBlock(size * 2**(i + 1), size * 2**i, pad=pad)
                  for i in reversed(range(depth - 1))]
        stages.append(nn.Conv2d(size, out_ch, 3, stride=1, padding=pad))
        stages.append(nn.Tanh())
        return nn.Sequential(*stages)

    def forward(self, x, y) -> Tuple[Tensor, Tensor]:
        """Encode ``x``, keep the label-selected latent half, and decode.

        Returns:
            ``(h, x_hat)``: the unmasked latent and the reconstruction
            decoded from ``select(h, y)``.
        """
        latent = self.encoder(x)
        x_hat = self.decoder(select(latent, y))
        return latent, x_hat

In [20]:
model = Autoencoder(in_ch=3, depth=8)  # RGB input, 7 downsampling stages

In [21]:
N = 5
S = 128  # input image side
x = torch.rand((N, 3, S, S))
y = torch.randint(2, (N,))

h, x_hat = model(x, y)

# The latent spatial size gets its own name so the image size `S`
# defined above is not silently overwritten by the latent side.
_, C, S_lat, _ = h.shape

h.shape, x_hat.shape, C * S_lat**2

(torch.Size([5, 1024, 1, 1]), torch.Size([5, 3, 128, 128]), 1024)

In [22]:
a_l = act_loss(h, y)
r_l = rec_loss(x, x_hat)
# Reconstruction term is down-weighted by 0.1 relative to the activation term.
loss = a_l + r_l * 0.1

a_l, r_l, loss

(tensor(1., grad_fn=<DivBackward0>),
 tensor(0.4331, grad_fn=<L1LossBackward>),
 tensor(1.0433, grad_fn=<AddBackward0>))

In [47]:
def identity(x: Tensor) -> Tensor:
    """Return ``x`` unchanged; used as a no-op "pooling" stage."""
    result = x
    return result


def pool_gru(out_gru: Tensor) -> Tensor:
    """Concatenate mean- and max-pooling of a GRU output over dim 1.

    ``out_gru`` is the ``(output, h_n)`` pair returned by ``nn.GRU``; only
    ``output`` is used. NOTE(review): dim 1 is the time axis only when the
    GRU was built with ``batch_first=True`` — confirm against the caller.

    Returns:
        ``cat([mean, max], dim=1)``, i.e. twice the hidden size on dim 1.
    """
    seq, _ = out_gru
    mean_pooled = seq.mean(dim=1)
    max_pooled = seq.max(dim=1).values
    return torch.cat([mean_pooled, max_pooled], dim=1)


class FakeDetector(nn.Module):
    """Video-level detector: per-frame Autoencoder, pooled latents, GRU head.

    Each of the ``N_fr`` frames goes through a shared ``Autoencoder``. The
    latents are (optionally) average-pooled and flattened to ``seq_in``
    features per frame, then run through a GRU over the frame axis. The
    result is mean/max pooled over time by ``pool_gru`` and mapped to one
    score by a bias-free linear layer.

    Args:
        img_size: side of the square input frames; must be a multiple of 32.
        enc_dim: ``(depth, size)`` of the Autoencoder.
        seq_size: ``(seq_in, seq_out)``: GRU input and hidden sizes.
    """

    def __init__(self, img_size: int, enc_dim: Tuple[int, int], 
                 seq_size: Tuple[int, int]):
        super(FakeDetector, self).__init__()
        self.autoenc = FakeDetector._build_encoder(img_size, enc_dim)
        seq_in, seq_out = seq_size
        self.pool = FakeDetector._build_pooling(img_size, enc_dim, seq_in)
        # forward() feeds (N, N_fr, seq_in), i.e. batch first. With the default
        # (seq, batch, feature) layout the GRU would recur over the batch axis
        # and mix different videos together.
        self.seq_model = nn.GRU(seq_in, seq_out, batch_first=True)
        # Input is [mean ; max] over time (see pool_gru), hence 2 * seq_out.
        self.out = nn.Linear(seq_out*2, 1, bias=False)

    @staticmethod
    def _build_encoder(img_size: int, enc_dim: Tuple[int, int]) -> 'Autoencoder':
        """Build the shared per-frame Autoencoder (3 input channels, pad=1).

        Raises:
            AttributeError: ``img_size`` is not a multiple of 32.
        """
        depth, size = enc_dim
        if img_size % 32:
            raise AttributeError('Image size should be a multiple of 32')
        return Autoencoder(in_ch=3, depth=depth, size=size, pad=1)

    @staticmethod
    def _build_pooling(img_size: int, enc_dim: Tuple[int, int], 
                       seq_in: int) -> Callable:
        """Return a callable that maps one frame's latent to ``seq_in`` features.

        One frame's latent is ``(emb_C, emb_S, emb_S)`` with
        ``emb_C = size * 2**(depth - 1)`` and ``emb_S = img_size // 2**(depth - 1)``.
        If that already flattens to ``seq_in``, the identity is returned.
        Otherwise an ``AdaptiveAvgPool3d`` is returned; it keeps all ``emb_C``
        channels and shrinks the grid to ``out_S x out_S``, where
        ``emb_C * out_S**2 == seq_in``.

        Raises:
            AttributeError: the encoder is too deep for ``img_size``, or
                ``seq_in`` is not ``emb_C`` times a positive perfect square.
        """
        enc_depth, enc_size = enc_dim
        size_factor = 2**(enc_depth-1)
        if size_factor > img_size:
            raise AttributeError('Encoder is too deep (%d) for spatial '
                                 'dim (%d, %d)' % (enc_depth, img_size, img_size))
        emb_S = img_size // size_factor
        emb_C = enc_size * size_factor

        if emb_C * emb_S**2 == seq_in:
            return identity

        # The pool keeps all emb_C channels, so the flattened size is
        # emb_C * out_S**2; solve for out_S with exact integer checks.
        cells, rem = divmod(seq_in, emb_C)
        out_S = int(round(math.sqrt(cells)))
        if rem or out_S == 0 or out_S * out_S != cells:
            raise AttributeError('Sequence input size is incompatible '
                                 'with encoder out dims')
        in_dim = (emb_C, emb_S, emb_S)
        out_dim = (emb_C, out_S, out_S)
        print('Using avg pooling: {} -> {}'.format(in_dim, out_dim))
        return nn.AdaptiveAvgPool3d(out_dim)

    def forward(self, x: Tensor, y: Tensor):
        """Run the detector on a batch of frame sequences.

        Args:
            x: ``(N, N_fr, C, H, W)`` frames.
            y: per-video labels, passed unchanged to the Autoencoder for
                every frame (presumably ``(N,)`` of 0/1 — confirm).

        Returns:
            ``(hidden, xs_hat, y_hat)``: the latents stacked as
            ``(N, N_fr, ...)``, the reconstructions stacked as
            ``(N, N_fr, ...)``, and the ``(N, 1)`` score (no sigmoid applied).
        """
        N, N_fr, C, H, W = x.shape
        per_frame = [self.autoenc(x[:, f], y) for f in range(N_fr)]
        hidden = torch.stack([h for h, _ in per_frame], dim=1)
        xs_hat = torch.stack([x_hat for _, x_hat in per_frame], dim=1)

        # (N, N_fr, features): batch-first, matching the GRU configuration.
        seq = self.pool(hidden).reshape(N, N_fr, -1)
        seq_out = pool_gru(self.seq_model(seq))
        y_hat = self.out(seq_out)

        return hidden, xs_hat, y_hat

In [51]:
img_size = 256
depth = 9  # NOTE(review): not used below; the encoder depth is enc_dim[0] = 5.
# Pass img_size instead of repeating the literal, so the model and the inputs
# built from img_size in later cells cannot drift apart.
model = FakeDetector(img_size=img_size, enc_dim=(5, 8), seq_size=(2048, 64))

Using avg pooling: (128, 16, 16) -> (128, 4, 4)


In [44]:
N = 2
n_frames = 5

# Random frames in [0, 1) and random 0/1 labels, one label per video.
x = torch.rand(N, n_frames, 3, img_size, img_size)
y = torch.randint(0, 2, (N,))

h, x_hat, y_hat = model(x, y)

# NOTE(review): the saved output shows 256 latent channels, but
# enc_dim=(5, 8) gives 8 * 2**4 = 128. This cell (In [44]) appears to have
# run before the model cell (In [51]); re-run it to refresh the output.
h.shape, x_hat.shape, y_hat.shape

(torch.Size([2, 5, 256, 16, 16]),
 torch.Size([2, 5, 3, 256, 256]),
 torch.Size([2, 1]))