In [1]:
import math
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

In [2]:
BASE_DIR = '/home/dmitry/projects/dfdc'
SRC_DIR = os.path.join(BASE_DIR, 'src')

In [3]:
sys.path.insert(0, SRC_DIR)

In [4]:
from model.autoencoder import AutoEncoder
from model.layers import conv3D, Lambda
from model.ops import act, identity, select, pool_gru

In [5]:
h = torch.tensor([
    [1,1,0,  1,1,0],
    [0,0,0,  0,0,0],
    [1,0,1,  0,1,0],
    [1,1,1,  0,0,0],
    [0,0,0,  1,1,1]
], dtype=torch.float32, requires_grad=True)[:,:,None,None]

y = torch.tensor([1, 0, 1, 0, 1])

all_neg = torch.zeros(y.size(0), dtype=torch.int64)
all_pos = torch.ones(y.size(0), dtype=torch.int64)

In [6]:
h.shape

torch.Size([5, 6, 1, 1])

In [7]:
h.reshape(5, 6)

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [8]:
select(h, all_pos).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [9]:
act(h, all_pos)

tensor([1., 0., 1., 0., 1.], grad_fn=<CeilBackward>)

In [10]:
select(h, all_neg).reshape(5, 6)

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<AsStridedBackward>)

In [11]:
act(h, all_neg)

tensor([1., 0., 1., 1., 0.], grad_fn=<CeilBackward>)

In [12]:
(select(h, all_neg) + select(h, all_pos) == h).all().item()

True

In [13]:
y

tensor([1, 0, 1, 0, 1])

In [14]:
select(h, y).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [15]:
act(h, y)

tensor([1., 0., 1., 1., 1.], grad_fn=<CeilBackward>)

In [16]:
(select(h, y) + select(h, (1-y)) == h).all().item()

True

$$
L_{ACT} =
\sum_{x ∈ S_0}
|a_0(x) − 1| + |a_1(x)| +
\sum_{x ∈ S_1}
|a_1(x) − 1| + |a_0(x)|
$$

In [17]:
def zeros(n: int) -> Tensor:
    return torch.zeros(n, dtype=torch.int64)


def ones(n: int) -> Tensor:
    return torch.ones(n, dtype=torch.int64)


def act_loss(x: Tensor, y: Tensor) -> Tensor:
    pos = y.nonzero().reshape(-1)
    neg = (y - 1).nonzero().reshape(-1)
    x0, x1 = x[neg], x[pos]
    n0, n1 = x0.size(0), x1.size(0)
    
    a0_x0 = act(x0, zeros(n0))
    a1_x0 = act(x0, ones(n0))
    
    a1_x1 = act(x1, ones(n1))
    a0_x1 = act(x1, zeros(n1))
    
    neg_loss = (a0_x0 - 1).abs() + a1_x0
    pos_loss = (a1_x1 - 1).abs() + a0_x1

    return (neg_loss.sum() + pos_loss.sum()) / y.size(0)

In [18]:
act_loss(h, y)

tensor(0.6000, grad_fn=<DivBackward0>)

In [19]:
DetectorOut = Tuple[FloatTensor, FloatTensor, FloatTensor]


def middle_block(in_ch: int, out_ch: int, kernel=3, stride=2, bn=True) -> nn.Module:
    conv = conv3D(in_ch, out_ch, kernel=kernel, stride=stride, bias=not bn)
    relu = nn.ReLU(inplace=True)
    layers = [conv, relu]
    if bn:
        layers.append(nn.BatchNorm3d(out_ch))
    return nn.Sequential(*layers)

In [20]:
class FakeDetector(nn.Module):
    def __init__(self, img_size: int, enc_depth: int, enc_width: int,
                 mid_layers: List[int], out_ch: int,
                 pool_size: Tuple[int, int] = None):
        super(FakeDetector, self).__init__()
        if img_size % 32:
            raise AttributeError("img_size should be a multiple of 32")
        if out_ch % 2:
            raise AttributeError("out_ch should be an even number")

        size_factor = 2 ** (enc_depth - 1)
        if size_factor > img_size:
            raise AttributeError(
                'Encoder dims (%d, %d) are incompatible with image '
                'size (%d, %d)' % (enc_depth, enc_width, img_size, img_size))
        emb_size = img_size // size_factor
        emb_ch = enc_width * size_factor

        self.encoder = AutoEncoder(in_ch=3, depth=enc_depth, width=enc_width)

        if img_size // 2 ** (enc_depth - 1) == 1:
            self.middle = Lambda(identity)
            rnn_in = emb_ch

        elif len(mid_layers) > 0:
            n_mid = len(mid_layers)
            mid_layers = [emb_ch] + mid_layers
            out_size = emb_size // 2 ** n_mid
            if not out_size:
                raise AssertionError('Too many middle layers...')
            layers = [middle_block(mid_layers[i], mid_layers[i + 1], stride=2)
                      for i in range(n_mid)]
            self.middle = nn.Sequential(*layers)
            rnn_in = mid_layers[-1] * out_size ** 2

        elif pool_size is not None:
            D, H = pool_size
            self.middle = nn.AdaptiveAvgPool3d((D, H, H))
            rnn_in = emb_ch * H ** 2

        else:
            raise AttributeError(
                'Both mid_layers and pool_size are missing. '
                'Unable to build model with provided configuration')

        self.rnn = nn.GRU(rnn_in, out_ch // 2)
        self.out = nn.Linear(out_ch, 1, bias=False)

    def forward(self, x: FloatTensor, y: LongTensor) -> DetectorOut:
        N, C, D, H, W = x.shape
        hidden, xs_hat = [], []

        for f in range(D):
            h, x_hat = self.encoder(x[:, :, f], y)
            hidden.append(h.unsqueeze(2))
            xs_hat.append(x_hat.unsqueeze(2))

        hidden = torch.cat(hidden, dim=2)
        xs_hat = torch.cat(xs_hat, dim=2)

        seq = self.middle(hidden).reshape(N, D, -1)
        seq_out = self.rnn(seq)
        seq_out = pool_gru(seq_out)
        y_hat = self.out(seq_out)

        return hidden, xs_hat, y_hat, seq

In [21]:
device = torch.device('cuda:1')

In [22]:
img_size = 256

model = FakeDetector(
    img_size=img_size, 
    enc_depth=5, 
    enc_width=8, 
    mid_layers=[128, 128],
    out_ch=128
).to(device)

In [23]:
N, n_frames = 16, 10

x = torch.randn((N, 3, n_frames, img_size, img_size), device=device)
y = torch.randint(2, (N,), device=device)

h, x_hat, y_hat, seq = model(x, y)

In [24]:
# for layer in model.modules():
#     for cls in [nn.Conv2d, nn.Conv3d, nn.Linear]:
#         if isinstance(layer, cls):
#             nn.init.kaiming_uniform_(layer.weight, a=0)
#             if layer.bias is not None:
#                 layer.bias.data.zero_()

In [25]:
for var, name in zip([x, h, x_hat, y_hat, seq], ['x', 'h', 'x_hat', 'y_hat', 'seq']):
    mean, std = var.mean().item(), var.std().item()
    shape = ', '.join(map(str, var.shape))
    print('{:5s} | mean {: .03f}, std {:.03f} | ({})'.format(name, mean, std, shape))

x     | mean  0.000, std 1.000 | (16, 3, 10, 256, 256)
h     | mean -0.000, std 1.000 | (16, 128, 10, 16, 16)
x_hat | mean -0.031, std 0.654 | (16, 3, 10, 256, 256)
y_hat | mean  0.708, std 0.052 | (16, 1)
seq   | mean -0.000, std 1.000 | (16, 10, 2048)


In [26]:
# conv_2 = nn.Sequential(
#     nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=True),
#     # nn.BatchNorm2d(8),
#     nn.ReLU(inplace=True)
# ).to(device)

# conv_3 = nn.Sequential(
#     conv3D(3, 8, kernel=3, stride=1, pad=1),
#     nn.ReLU(inplace=True)
# ).to(device)

In [27]:
# for model in [conv_2, conv_3]:
#     for layer in model:
#         for cls in [nn.Conv2d, nn.Conv3d, nn.Linear]:
#             if isinstance(layer, cls):
#                 nn.init.kaiming_normal_(layer.weight, a=0)
#                 if layer.bias is not None:
#                     layer.bias.data.zero_()

In [28]:
# x2 = conv_2(x[:, :, 0])
# x3 = conv_3(x)

# for var, name in zip([x2, x3], ['x2', 'x3']):
#     mean, std = var.mean().item(), var.std().item()
#     shape = ', '.join(map(str, var.shape))
#     print('{:5s}: mean {:.03f}, std {:.03f}, ({})'.format(name, mean, std, shape))