In [31]:
import math
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

In [32]:
# Project root. Override with the DFDC_BASE_DIR environment variable so the
# notebook can run outside the original author's machine; the default keeps
# the previous location.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
SRC_DIR = os.path.join(BASE_DIR, 'src')

In [33]:
# Put the project sources first on the import path. Rebuilding the list in
# place (instead of a blind insert) keeps the cell idempotent: re-running it
# does not accumulate duplicate SRC_DIR entries.
sys.path[:] = [SRC_DIR] + [p for p in sys.path if p != SRC_DIR]

In [34]:
from model.zoo.common import AutoEncoder, encoder_block, decoder_block
from model.layers import conv3D, Lambda
from model.ops import act, identity, select, pool_gru

In [35]:
# Toy "hidden state": 5 samples x 6 channels (two halves of 3 channels, one
# half per class), with two singleton spatial dims appended.
toy_rows = [
    [1, 1, 0, 1, 1, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 0, 1, 0, 1, 0],
    [1, 1, 1, 0, 0, 0],
    [0, 0, 0, 1, 1, 1],
]
h = torch.tensor(toy_rows, dtype=torch.float32, requires_grad=True)[:, :, None, None]

# Per-sample binary labels.
y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors with one entry per sample.
all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)

In [36]:
# Toy batch: 5 samples x 6 channels with singleton spatial dims.
h.shape

torch.Size([5, 6, 1, 1])

In [37]:
# Flatten to (5, 6) to inspect the raw toy activations.
h.reshape(5, 6)

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [38]:
# Label 1 keeps the second half of the channels (3-5) and zeroes the first half.
select(h, all_pos).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [39]:
# Per-sample class-1 activation; in the output shown it equals the mean |value|
# of channels 3-5.
act(h, all_pos)

tensor([0.6667, 0.0000, 0.3333, 0.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [40]:
# Label 0 keeps the first half of the channels (0-2) and zeroes the second half.
select(h, all_neg).reshape(5, 6)

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<AsStridedBackward>)

In [41]:
# Per-sample class-0 activation; in the output shown it equals the mean |value|
# of channels 0-2.
act(h, all_neg)

tensor([0.6667, 0.0000, 0.6667, 1.0000, 0.0000], grad_fn=<ClampMaxBackward>)

In [42]:
# Sanity check: the class-0 and class-1 selections partition h exactly.
torch.equal(select(h, all_neg) + select(h, all_pos), h)

True

In [43]:
# Mixed labels used for the per-sample selection checks below.
y

tensor([1, 0, 1, 0, 1])

In [44]:
# With mixed labels, each row keeps the channel half that matches its own label.
select(h, y).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [45]:
# Activation of the label-matching half for each sample.
act(h, y)

tensor([0.6667, 0.0000, 0.3333, 1.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [46]:
# Sanity check: selections with y and its complement 1 - y partition h exactly.
torch.equal(select(h, y) + select(h, 1 - y), h)

True

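$$
L_{ACT} =
\sum_{x \in S_0} \big( |a_0(x) - 1| + |a_1(x)| \big)
+ \sum_{x \in S_1} \big( |a_1(x) - 1| + |a_0(x)| \big)
$$

where $S_0$ and $S_1$ are the samples labelled 0 and 1, and $a_c(x)$ is the activation of the half of the code selected for class $c$ (`act(x, c)` above). `act_loss` below additionally divides this sum by the batch size.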

In [47]:
def zeros(n: int, device: Optional[torch.device] = None) -> Tensor:
    """Return an int64 tensor of ``n`` zeros (all-class-0 labels).

    ``device`` is optional so this helper is call-compatible with the
    device-taking ``zeros(n, device)`` defined later in the notebook.
    """
    return torch.zeros(n, dtype=torch.int64, device=device)


def ones(n: int, device: Optional[torch.device] = None) -> Tensor:
    """Return an int64 tensor of ``n`` ones (all-class-1 labels).

    ``device`` is optional so this helper is call-compatible with the
    device-taking ``ones(n, device)`` defined later in the notebook.
    """
    return torch.ones(n, dtype=torch.int64, device=device)


def act_loss(x: Tensor, y: Tensor) -> Tensor:
    """Batch-averaged activation loss L_ACT.

    Samples labelled 0 (S_0) are pushed towards ``a_0 = 1, a_1 = 0`` and
    samples with a non-zero label (S_1) towards ``a_1 = 1, a_0 = 0``, where
    ``a_c = act(samples, c)``.

    Args:
        x: batch of codes, indexed by sample along dim 0.
        y: 1-D labels for ``x`` (expected to be 0/1), on the same device as ``x``.

    Returns:
        Scalar tensor: the summed loss divided by the batch size
        (0 for an empty batch instead of NaN).
    """
    pos_mask = y != 0
    x0, x1 = x[~pos_mask], x[pos_mask]
    n0, n1 = x0.size(0), x1.size(0)

    def labels(n: int, value: int) -> Tensor:
        # Constant labels on x's device, so act() never mixes CPU labels with a
        # CUDA batch. Built locally instead of via the notebook-level
        # zeros()/ones(), which are redefined later in the notebook.
        return torch.full((n,), value, dtype=torch.int64, device=x.device)

    a0_x0 = act(x0, labels(n0, 0))
    a1_x0 = act(x0, labels(n0, 1))

    a1_x1 = act(x1, labels(n1, 1))
    a0_x1 = act(x1, labels(n1, 0))

    # act() outputs are presumably non-negative (they look like a clamped mean
    # |value| in the outputs above), so |a_1| / |a_0| are used without abs().
    neg_loss = (a0_x0 - 1).abs() + a1_x0
    pos_loss = (a1_x1 - 1).abs() + a0_x1

    return (neg_loss.sum() + pos_loss.sum()) / max(y.size(0), 1)

In [48]:
# ACT loss on the toy batch with mixed labels.
act_loss(h, y)

tensor(0.6667, grad_fn=<DivBackward0>)

In [49]:
# Return type of FakeDetector.forward: (hidden, xs_hat, y_hat, seq).
DetectorOut = Tuple[FloatTensor, FloatTensor, FloatTensor, FloatTensor]


def middle_block(in_ch: int, out_ch: int, kernel=3, stride=2, bn=True) -> nn.Module:
    """Build a ``conv3D -> ReLU [-> BatchNorm3d]`` stage.

    The convolution is created without bias when ``bn`` is True.
    NOTE(review): BN comes after the ReLU here, so the dropped conv bias is not
    strictly redundant with BN's shift; confirm this ordering is intended.
    """
    stages = [
        conv3D(in_ch, out_ch, kernel=kernel, stride=stride, bias=not bn),
        nn.ReLU(inplace=True),
    ]
    if bn:
        stages += [nn.BatchNorm3d(out_ch)]
    return nn.Sequential(*stages)

In [50]:
class FakeDetector(nn.Module):
    """Per-frame autoencoder followed by a recurrent clip classifier.

    Every frame goes through ``AutoEncoder``. The stacked codes are optionally
    reduced by ``self.middle``, then fed to a GRU as one time step per frame.
    The GRU output is pooled by ``pool_gru`` and mapped to one logit.

    Args:
        img_size: side of the square input frames; must be a multiple of 32.
        enc_depth: depth passed to ``AutoEncoder``.
        enc_width: width passed to ``AutoEncoder``.
        mid_layers: output channels of the stride-2 ``middle_block`` stages
            applied to the codes. Empty or None falls back to ``pool_size``.
        out_ch: feature size fed to the final linear layer; must be even
            (the GRU hidden size is ``out_ch // 2``).
        pool_size: ``(D, H)`` for an ``AdaptiveAvgPool3d((D, H, H))`` used when
            ``mid_layers`` is empty.

    Raises:
        AttributeError: on an incompatible configuration.
        AssertionError: if ``mid_layers`` would shrink the code to size 0.
    """

    def __init__(self, img_size: int, enc_depth: int, enc_width: int,
                 mid_layers: List[int], out_ch: int,
                 pool_size: Tuple[int, int] = None):
        super(FakeDetector, self).__init__()
        if img_size % 32:
            raise AttributeError("img_size should be a multiple of 32")
        if out_ch % 2:
            raise AttributeError("out_ch should be an even number")

        size_factor = 2 ** (enc_depth - 1)
        if size_factor > img_size:
            raise AttributeError(
                'Encoder dims (%d, %d) are incompatible with image '
                'size (%d, %d)' % (enc_depth, enc_width, img_size, img_size))
        # Assumed AutoEncoder code geometry: side shrinks and channels grow by
        # size_factor (TODO confirm against model.zoo.common.AutoEncoder).
        emb_size = img_size // size_factor
        emb_ch = enc_width * size_factor

        self.encoder = AutoEncoder(in_ch=3, depth=enc_depth, width=enc_width)

        if emb_size == 1:
            # Codes are already 1x1 spatially: feed them to the GRU unchanged.
            self.middle = Lambda(identity)
            rnn_in = emb_ch

        elif mid_layers:
            n_mid = len(mid_layers)
            channels = [emb_ch] + list(mid_layers)
            out_size = emb_size // 2 ** n_mid
            if not out_size:
                raise AssertionError('Too many middle layers...')
            layers = [middle_block(channels[i], channels[i + 1], stride=2)
                      for i in range(n_mid)]
            self.middle = nn.Sequential(*layers)
            rnn_in = channels[-1] * out_size ** 2

        elif pool_size is not None:
            pool_d, pool_hw = pool_size
            self.middle = nn.AdaptiveAvgPool3d((pool_d, pool_hw, pool_hw))
            rnn_in = emb_ch * pool_hw ** 2

        else:
            raise AttributeError(
                'Both mid_layers and pool_size are missing. '
                'Unable to build model with provided configuration')

        # forward builds the sequence as (N, T, features), so the GRU must be
        # batch_first. pool_gru pools over dim 1 of the output: Bilbo's logits
        # in this notebook come out as (N, 1). Without batch_first the GRU
        # would recur over the batch axis.
        self.rnn = nn.GRU(rnn_in, out_ch // 2, batch_first=True)
        self.out = nn.Linear(out_ch, 1, bias=False)

    def forward(self, x: FloatTensor, y: LongTensor) -> DetectorOut:
        """Encode each frame of ``x`` (N, C, D, H, W) and classify the clip.

        Returns:
            ``(hidden, xs_hat, y_hat, seq)``:
            - ``hidden``: frame codes stacked on dim 2.
            - ``xs_hat``: reconstructions stacked on dim 2.
            - ``y_hat``: the logits.
            - ``seq``: the (N, T, features) GRU input.
        """
        N, C, D, H, W = x.shape
        hidden, xs_hat = [], []

        for f in range(D):
            h, x_hat = self.encoder(x[:, :, f], y)
            hidden.append(h.unsqueeze(2))
            xs_hat.append(x_hat.unsqueeze(2))

        hidden = torch.cat(hidden, dim=2)
        xs_hat = torch.cat(xs_hat, dim=2)

        mid = self.middle(hidden)
        # Move the time axis in front of the channels before flattening, so
        # each step holds exactly one time slice. A plain reshape of
        # (N, C, T, h, w) to (N, T, -1) would mix channels of different frames.
        # Use the middle's actual depth, because adaptive pooling may change it.
        seq = mid.permute(0, 2, 1, 3, 4).reshape(N, mid.size(2), -1)
        seq_out = self.rnn(seq)
        seq_out = pool_gru(seq_out)
        y_hat = self.out(seq_out)

        return hidden, xs_hat, y_hat, seq

In [51]:
# Second CUDA device (index 1) used for all model runs below.
device = torch.device('cuda', 1)

In [52]:
# img_size = 256

# model = FakeDetector(
#     img_size=img_size, 
#     enc_depth=5, 
#     enc_width=8, 
#     mid_layers=[128, 128],
#     out_ch=128
# ).to(device)

In [53]:
# N, n_frames = 16, 10

# x = torch.randn((N, 3, n_frames, img_size, img_size), device=device)
# y = torch.randint(2, (N,), device=device)

# h, x_hat, y_hat, seq = model(x, y)

In [54]:
# for layer in model.modules():
#     for cls in [nn.Conv2d, nn.Conv3d, nn.Linear]:
#         if isinstance(layer, cls):
#             nn.init.kaiming_uniform_(layer.weight, a=0)
#             if layer.bias is not None:
#                 layer.bias.data.zero_()

In [55]:
# for var, name in zip([x, h, x_hat, y_hat, seq], ['x', 'h', 'x_hat', 'y_hat', 'seq']):
#     mean, std = var.mean().item(), var.std().item()
#     shape = ', '.join(map(str, var.shape))
#     print('{:5s} | mean {: .03f}, std {:.03f} | ({})'.format(name, mean, std, shape))

In [64]:
from model.layers import conv2D
from model.ops import select, pool_gru


class ShrinkFork(nn.Module):
    """Encoder stage with an auxiliary downsampling "fork" branch.

    ``main`` is ``encoder_block(in_ch, out_ch)`` with its default stride; going
    by the shapes printed in this notebook, that halves the resolution.
    ``fork`` chains ``fork_depth`` stride-2 ``encoder_block`` stages applied to
    the main output, each doubling the channel count. A non-positive
    ``fork_depth`` gives an empty (identity) fork.
    """

    def __init__(self, in_ch: int, out_ch: int, fork_depth: int):
        super(ShrinkFork, self).__init__()
        self.main = encoder_block(in_ch, out_ch)
        branch = []
        width = out_ch
        for _ in range(fork_depth):
            branch.append(encoder_block(width, width * 2, stride=2))
            width *= 2
        self.fork = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``(main(x), fork(main(x)))``."""
        trunk = self.main(x)
        return trunk, self.fork(trunk)


class Bilbo(nn.Module):
    """Frame autoencoder with multi-scale fork features and a BiGRU clip head.

    Each frame goes through ``stem`` and then ``enc_depth - 1`` ShrinkFork
    stages (stage ``i`` uses fork depth ``fork_depth - i``).
    - The main path yields the code ``h``. ``select(h, y)`` is decoded into a
      reconstruction of the frame.
    - The fork outputs of all stages share one spatial size and are
      concatenated along channels.
    - These fork features are fed to a bidirectional GRU, one time step per
      frame, whose pooled output gives one logit per clip.

    Args:
        image_shape: (C, H, W) of a frame; H must equal W.
        enc_depth: stem plus ``enc_depth - 1`` ShrinkFork stages; must be >= 2.
        enc_width: channels after the stem; doubled at every stage.
        fork_depth: fork depth of the first stage; must lie in
            ``[enc_depth - 2, log2(H) - 1]``.
        rnn_width: GRU hidden size per direction.

    Raises:
        AttributeError: on an incompatible configuration.
    """

    def __init__(self, image_shape: Tuple[int, int, int], enc_depth: int,
                 enc_width: int, fork_depth: int, rnn_width: int):
        super(Bilbo, self).__init__()
        C, H, W = image_shape
        if H != W:
            raise AttributeError("Only square images are supported!")
        if enc_depth < 2:
            # With no ShrinkFork stage there are no fork features for the GRU.
            raise AttributeError("enc_depth should be >= 2")
        max_fork_depth = math.log2(H) - 1
        if fork_depth > max_fork_depth:
            raise AttributeError(
                "fork_depth should be <= {} given the image_size "
                " ({}, {})".format(int(max_fork_depth), H, H))
        # A stage with a negative fork depth gets an empty fork. Its output then
        # has a different spatial size from the other stages, and torch.cat in
        # forward fails. Reject such configurations up front.
        min_fork_depth = enc_depth - 2
        if fork_depth < min_fork_depth:
            raise AttributeError(
                "fork_depth should be >= {} given enc_depth={}".format(
                    min_fork_depth, enc_depth))

        self.stem = encoder_block(C, enc_width, stride=1, bn=False)
        encoder_layers = [ShrinkFork(enc_width * 2**i, enc_width * 2**(i+1),
                                     fork_depth=(fork_depth-i))
                          for i in range(0, enc_depth - 1)]
        self.encoder = nn.ModuleList(encoder_layers)

        decoder_layers = [decoder_block(enc_width * 2**(i+1), enc_width * 2**i)
                          for i in reversed(range(0, enc_depth - 1))]
        last = conv2D(enc_width, C, kernel=3, stride=1)
        self.decoder = nn.Sequential(*decoder_layers, last, nn.Tanh())

        # Features per frame: (enc_depth - 1) forks, each with
        # enc_width * 2**(fork_depth + 1) channels at side H / 2**(fork_depth + 1).
        # This matches (2, 2048, 5, 2, 2) for the 128px config used here.
        rnn_in = enc_width * (H * W / 4) / 2**(fork_depth - 1) * (enc_depth - 1)
        assert not math.modf(rnn_in)[0], \
            "fork feature size is not an integer for image size {}".format(H)
        rnn_in = int(rnn_in)
        # forward builds the sequence as (N, T, features), so the GRU must be
        # batch_first. pool_gru pools over dim 1: the logits come out as
        # (N, 1), not (T, 1).
        self.gru = nn.GRU(rnn_in, rnn_width, bidirectional=True, batch_first=True)
        self.out = nn.Linear(rnn_width * 4, 1, bias=False)

    def forward(self, x: FloatTensor, y: LongTensor):
        """Encode/reconstruct each frame of ``x`` (N, C, D, H, W) and classify the clip.

        Returns:
            ``(hidden, x_rec, y_hat, features)``:
            - ``hidden``: codes (N, C_h, D, h, w).
            - ``x_rec``: reconstructions, same shape as ``x``.
            - ``y_hat``: logits (N, 1).
            - ``features``: fork features (N, C_f, D, h', w').
        """
        N, C, D, H, W = x.shape
        hidden, x_rec, features = [], [], []

        for f in range(D):
            forks = []
            h = self.stem(x[:, :, f])
            for stage in self.encoder:
                h, c = stage(h)
                forks.append(c)
            hidden.append(h.unsqueeze(2))
            features.append(torch.cat(forks, dim=1).unsqueeze(2))

            x1 = self.decoder(select(h, y))
            x_rec.append(x1.unsqueeze(2))

        hidden = torch.cat(hidden, dim=2)
        x_rec = torch.cat(x_rec, dim=2)
        features = torch.cat(features, dim=2)

        # One GRU step per frame: bring the frame axis in front of the channels
        # before flattening. A plain reshape of (N, C, D, h, w) to (N, D, -1)
        # would mix channels from different frames.
        seq = features.permute(0, 2, 1, 3, 4).reshape(N, D, -1)
        gru_out = pool_gru(self.gru(seq))

        y_hat = self.out(gru_out)
        return hidden, x_rec, y_hat, features

In [65]:
# Device for the Bilbo/Frodo smoke tests below.
device = torch.device('cuda:1')

In [161]:
N, D, H, W = 2, 5, 128, 128  # clips, frames per clip, frame height, frame width

# Seed so the random smoke-test inputs (and model inits that follow) are
# reproducible under Restart & Run All.
torch.manual_seed(0)
# Allocate directly on the target device instead of creating on CPU and copying.
x = torch.randn((N, 3, D, H, W), device=device)
y = torch.randint(2, (N,), device=device)

In [67]:
# Smoke-test configuration for Bilbo on the random clips above.
bilbo_kwargs = dict(
    image_shape=(3, H, W),
    enc_depth=5,
    enc_width=8,
    fork_depth=5,
    rnn_width=64,
)
bilbo = Bilbo(**bilbo_kwargs).to(device)

In [68]:
out = bilbo(x, y)

# Shapes of (hidden, x_rec, y_hat, features).
for tensor in out:
    print(tensor.shape)

torch.Size([2, 128, 5, 8, 8])
torch.Size([2, 3, 5, 128, 128])
torch.Size([2, 1])
torch.Size([2, 2048, 5, 2, 2])


In [69]:
def print_tensor_stats(tensors, names):
    """Print one aligned line per tensor: name, mean, std and shape."""
    for tensor, label in zip(tensors, names):
        shape = ', '.join(str(dim) for dim in tensor.shape)
        print('{:5s} | mean {: .03f}, std {:.03f} | ({})'.format(
            label, tensor.mean().item(), tensor.std().item(), shape))


print_tensor_stats([x, *out], ['x', 'h', 'x_hat', 'y_hat', 'feat'])

x     | mean  0.001, std 1.000 | (2, 3, 5, 128, 128)
h     | mean -0.000, std 1.000 | (2, 128, 5, 8, 8)
x_hat | mean -0.006, std 0.668 | (2, 3, 5, 128, 128)
y_hat | mean  0.325, std 0.180 | (2, 1)
feat  | mean -0.000, std 0.999 | (2, 2048, 5, 2, 2)


In [110]:
class Frodo(nn.Module):
    """Per-frame convolutional autoencoder with label-selected decoding.

    The encoder is a stride-1 stem followed by ``enc_depth - 1``
    ``encoder_block`` stages, each doubling the channel count. The decoder
    mirrors those stages with ``decoder_block``, then applies a 3x3 conv back
    to the input channels and ``tanh``. Each frame's code goes through
    ``select(code, y)`` before decoding.

    Args:
        image_shape: (C, H, W) of one frame; H must equal W.
        enc_depth: number of encoder levels, at most ``log2(H) + 1``.
        enc_width: channel count after the stem.
    """

    def __init__(self, image_shape: Tuple[int, int, int], enc_depth: int, enc_width: int):
        super(Frodo, self).__init__()
        C, H, W = image_shape
        if H != W:
            raise AttributeError("Only square images are supported!")

        depth_limit = math.log2(H) + 1
        if enc_depth > depth_limit:
            raise AttributeError(
                "enc_depth should be <= {} given the image_size "
                " ({}, {})".format(int(depth_limit), H, H))

        # Channel width per level: enc_width, 2*enc_width, 4*enc_width, ...
        widths = [enc_width * 2 ** level for level in range(enc_depth)]
        steps = list(zip(widths, widths[1:]))

        stem = encoder_block(C, enc_width, stride=1, bn=False)
        down = [encoder_block(c_in, c_out) for c_in, c_out in steps]
        self.encoder = nn.Sequential(stem, *down)

        up = [decoder_block(c_out, c_in) for c_in, c_out in reversed(steps)]
        head = conv2D(enc_width, C, kernel=3, stride=1)
        self.decoder = nn.Sequential(*up, head, nn.Tanh())

    def forward(self, x: FloatTensor, y: LongTensor):
        """Return ``(codes, reconstructions)`` for ``x`` of shape (N, C, D, H, W).

        Both outputs stack the per-frame results along dim 2.
        """
        _, _, num_frames, _, _ = x.shape
        codes, recons = [], []

        for t in range(num_frames):
            code = self.encoder(x[:, :, t])
            recon = self.decoder(select(code, y))
            codes.append(code.unsqueeze(2))
            recons.append(recon.unsqueeze(2))

        return torch.cat(codes, dim=2), torch.cat(recons, dim=2)

In [154]:
# Smoke-test Frodo with the same frame size as the random clips above.
frodo = Frodo(image_shape=(3, H, W), enc_depth=5, enc_width=8)
frodo = frodo.to(device)

In [155]:
# NOTE(review): rebinds `out` (previously Bilbo's outputs) to Frodo's (hidden, x_rec).
out = frodo(x, y)

In [156]:
# Mean / std / shape of the input and Frodo's outputs.
named = zip(['x', 'h', 'x_hat'], (x, *out))
for name, var in named:
    stats = var.mean().item(), var.std().item()
    dims = ', '.join(str(d) for d in var.shape)
    print('{:5s} | mean {: .03f}, std {:.03f} | ({})'.format(name, *stats, dims))

x     | mean  0.001, std 1.000 | (2, 3, 5, 128, 128)
h     | mean  0.000, std 1.000 | (2, 128, 5, 8, 8)
x_hat | mean -0.020, std 0.622 | (2, 3, 5, 128, 128)


In [90]:
def zeros(n: int, device: Optional[torch.device] = None) -> Tensor:
    """Return an int64 tensor of ``n`` zeros (all-class-0 labels) on ``device``.

    ``device`` defaults to None (torch's default device). Earlier cells that
    call ``zeros(n)`` therefore keep working after this redefinition.
    """
    return torch.zeros(n, dtype=torch.int64, device=device)


def ones(n: int, device: Optional[torch.device] = None) -> Tensor:
    """Return an int64 tensor of ``n`` ones (all-class-1 labels) on ``device``.

    ``device`` defaults to None (torch's default device). Earlier cells that
    call ``ones(n)`` therefore keep working after this redefinition.
    """
    return torch.ones(n, dtype=torch.int64, device=device)

In [107]:
from model.ops import reshape_as

In [131]:
def act_5dims(h: Tensor, y: Tensor, average=False) -> Tensor:
    """Label-selected activation for a stack of per-frame codes.

    ``h`` is split along dim 1 into halves ``h0`` / ``h1``. For each sample the
    half matching its label is kept (0 -> ``h0``, 1 -> ``h1``), and its mean
    absolute value is returned, capped at 1.

    Args:
        h: 5-D tensor (N, C, D, H, W) with an even C; presumably N clips,
            C code channels and D frames.
        y: per-sample labels (N,), expected to be 0/1. They are broadcast
            against ``h`` via ``reshape_as``.
        average: if True, average over all non-batch dims and return shape
            (N,). Otherwise average over channels and space per frame and
            return shape (N, D).

    Returns:
        Activations clamped to at most 1.

    Raises:
        ValueError: if ``h`` is not 5-D or its channel count is odd.
    """
    N, C, D, _, _ = h.shape
    if C % 2:
        raise ValueError("h must have an even number of channels, got {}".format(C))

    y = reshape_as(y, h)
    h0, h1 = h.chunk(2, dim=1)
    a = h0 * (1 - y) + h1 * y

    if average:
        n_el = a.numel() / max(N, 1)
        a = a.abs().sum(tuple(range(1, a.ndim))) / n_el
    else:
        n_el = a.numel() / max(N * D, 1)
        a = a.abs().sum((1, 3, 4)) / n_el

    # For simplicity, and without losing generality, the activation is
    # constrained to be at most 1 (a(x) <= 1).
    return a.clamp(max=1)

In [142]:
# Frodo's reconstruction has the same shape as the input clip `x`.
out[1].shape

torch.Size([2, 3, 5, 128, 128])

In [163]:
def decide(h: Tensor, device: Optional[torch.device] = None) -> Tensor:
    """Per-clip fraction of frames whose code votes for class 1.

    For every frame the class-0 and class-1 activations from ``act_5dims`` are
    compared, and the frame votes for the larger one (index from ``torch.max``).
    Returns the mean vote per sample: shape (N,), values in [0, 1].

    Args:
        h: codes of shape (N, C, D, H, W), e.g. the first output of ``Frodo``.
        device: device for the label tensors; defaults to ``h.device``.
    """
    if device is None:
        device = h.device
    N = h.size(0)
    # A hard decision: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        a0 = act_5dims(h, zeros(N, device), average=False)
        a1 = act_5dims(h, ones(N, device), average=False)
        _, y_pred = torch.max(torch.stack([a0, a1], dim=2), dim=2)
    return y_pred.float().mean(1)

In [164]:
# NOTE(review): rebinds `h` (the toy tensor from the top of the notebook) to Frodo's codes.
h, x_rec = frodo(x, y)

# Fraction of frames per clip voting for class 1, from the (untrained) Frodo codes.
decide(h, device)

tensor([0.2000, 0.4000], device='cuda:1')