In [1]:
import math
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

In [2]:
# Project root. Defaults to the original checkout location, but can be
# overridden with the DFDC_BASE_DIR environment variable so the notebook also
# runs on machines where the project lives somewhere else.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
# Directory holding the importable `model` package.
SRC_DIR = os.path.join(BASE_DIR, 'src')

In [3]:
# Make the project sources importable. The membership check keeps repeated
# executions of this cell from stacking duplicate entries on sys.path.
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

In [4]:
from model.zoo.common import encoder_block, decoder_block
from model.layers import conv2D, conv3D, Lambda, stack_enc_blocks, stack_dec_blocks, MaxMean2D, EncoderBlock, DecoderBlock
from model.loss import activation_loss
from model.ops import act, identity, select, pool_gru
from model.efficient_attention.efficient_attention import EfficientAttention

In [5]:
# Toy fixture: a 5 x 6 matrix of 0/1 values, lifted to 4-D by appending two
# singleton axes, giving shape (5, 6, 1, 1). It is built with requires_grad
# so results derived from it carry a grad_fn.
_h_rows = [
    [1, 1, 0, 1, 1, 0],
    [0, 0, 0, 0, 0, 0],
    [1, 0, 1, 0, 1, 0],
    [1, 1, 1, 0, 0, 0],
    [0, 0, 0, 1, 1, 1],
]
h = torch.tensor(_h_rows, dtype=torch.float32, requires_grad=True)[..., None, None]

# One integer label per row of the fixture.
y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors (same length and integer dtype as `y`) used to probe
# the behaviour for "every row is label 0" / "every row is label 1".
all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)

In [6]:
# Sanity check: the [:, :, None, None] indexing appended two singleton axes,
# so the 5 x 6 fixture is now 4-D.
h.shape

torch.Size([5, 6, 1, 1])

In [7]:
# Drop the singleton axes to view the fixture as a plain 5 x 6 matrix.
h.reshape(5, 6)

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [8]:
# With every label set to 1, the displayed result keeps only the last three
# columns of each row; the first three are zeroed.
select(h, all_pos).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [9]:
# One scalar per row for label 1. For this fixture the displayed values equal
# the mean of the three columns that select(h, all_pos) keeps.
act(h, all_pos)

tensor([0.6667, 0.0000, 0.3333, 0.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [10]:
# With every label set to 0, the displayed result keeps only the first three
# columns of each row; the last three are zeroed.
select(h, all_neg).reshape(5, 6)

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<AsStridedBackward>)

In [11]:
# One scalar per row for label 0. For this fixture the displayed values equal
# the mean of the three columns that select(h, all_neg) keeps.
act(h, all_neg)

tensor([0.6667, 0.0000, 0.6667, 1.0000, 0.0000], grad_fn=<ClampMaxBackward>)

In [12]:
# The label-0 and label-1 selections are complementary: summed element-wise
# they reproduce h exactly.
(select(h, all_neg) + select(h, all_pos) == h).all().item()

True

In [13]:
# Mixed per-row labels used by the following cells.
y

tensor([1, 0, 1, 0, 1])

In [14]:
# Selection driven by each row's own label: in the displayed result, rows
# labelled 1 keep their last three columns and rows labelled 0 their first three.
select(h, y).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [15]:
# Per-row activation of the half selected by each row's own label.
act(h, y)

tensor([0.6667, 0.0000, 0.3333, 1.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [16]:
# Selecting with the labels and with the flipped labels (1 - y) partitions h:
# the two selections sum back to h exactly.
(select(h, y) + select(h, (1-y)) == h).all().item()

True

$$
L_{ACT} =
\sum_{x \in S_0}
\left( |a_0(x) - 1| + |a_1(x)| \right) +
\sum_{x \in S_1}
\left( |a_1(x) - 1| + |a_0(x)| \right)
$$

In [17]:
# Evaluate the activation loss on the toy fixture.
# NOTE(review): the displayed 0.6667 matches the formula's summed terms
# (about 3.333 for this fixture) divided by the 5 rows, so the implementation
# presumably averages over the batch rather than summing; confirm in model.loss.
activation_loss(h, y)

tensor(0.6667, grad_fn=<DivBackward0>)

In [18]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable GPU instead of failing on the
# first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [19]:
# N, D, H, W = 2, 5, 128, 128

# x = torch.randn((N, 3, D, H, W)).to(device)
# y = torch.randint(2, (N,)).to(device)

In [20]:
# Seed the global RNG so the random batch below (and anything initialised
# after it) is reproducible across "Restart & Run All".
torch.manual_seed(0)

# Batch size, number of frames, frame height and frame width.
N, D, H, W = 16, 10, 256, 256

# Random 3-channel clip batch laid out as (N, 3, D, H, W), plus one random
# binary label per sample. NOTE: this rebinds `y`, replacing the 5-element
# toy label vector defined earlier in the notebook.
x = torch.randn((N, 3, D, H, W)).to(device)
y = torch.randint(2, (N,)).to(device)

In [21]:
class RNNBlock(nn.Module):
    """Run a recurrent layer over a sequence and pool its outputs over time.

    Args:
        in_ch: feature size of each sequence element fed to the RNN.
        rnn_ch: hidden size of the RNN.
        cls: recurrent layer class, called as ``cls(in_ch, rnn_ch, bidirectional=...)``.
        bidirectional: forwarded to ``cls``.

    Attributes:
        out_ch: feature size of the pooled output. The factor 2 accounts for
            concatenating max- and mean-pooled features; it is doubled again
            for a bidirectional RNN.
    """

    def __init__(self, in_ch: int, rnn_ch: int, cls=nn.GRU, bidirectional=False):
        super().__init__()
        self.rnn = cls(in_ch, rnn_ch, bidirectional=bidirectional)
        directions = 2 if bidirectional else 1
        self.out_ch = rnn_ch * 2 * directions

    def forward(self, x):
        """Pool the RNN outputs of a 3-D input over its last axis.

        The input is permuted with (2, 0, 1), i.e. its last axis becomes the
        leading (sequence) axis before it is fed to the RNN. The final hidden
        state returned by the RNN is discarded.

        Returns:
            The max over the sequence axis concatenated with the mean over the
            sequence axis, joined along dim 1.
        """
        sequence = x.permute(2, 0, 1)
        outputs, _ = self.rnn(sequence)
        pooled_max = outputs.max(0).values
        pooled_mean = outputs.mean(0)
        return torch.cat([pooled_max, pooled_mean], dim=1)

In [22]:
def stack_enc_blocks(width: int, start: int, end: int, wide=False, 
                     attention: Optional[List[int]] = None):
    """Build the list of encoder layers for levels ``start`` .. ``end - 1``.

    At level ``i`` the block maps ``width * 2**i`` channels to
    ``width * 2**(i + 1)`` channels. Its hidden width equals the output width
    when ``wide`` is true and the input width otherwise.

    Args:
        width: base channel count.
        start: first level (inclusive).
        end: last level (exclusive).
        wide: selects the hidden width of each block, see above.
        attention: levels that get an ``EfficientAttention`` layer inserted
            *before* their encoder block; ``None`` means no attention.

    Returns:
        A flat list of layers in forward order.
    """
    attended = attention if attention is not None else []
    blocks = []
    for level in range(start, end):
        ch_in, ch_out = width * 2**level, width * 2**(level + 1)
        if level in attended:
            # The third attention argument is ch_in // 4, so ch_in must divide by 4.
            assert ch_in % 4 == 0
            blocks.append(EfficientAttention(ch_in, ch_in, ch_in // 4, ch_in))
        ch_hidden = ch_out if wide else ch_in
        blocks.append(EncoderBlock(ch_in, ch_out, ch_hidden))
    return blocks


def stack_dec_blocks(width: int, start: int, end: int, wide=False, 
                     attention: Optional[List[int]] = None):
    """Build the list of decoder layers for levels ``end - 1`` down to ``start``.

    At level ``i`` the block maps ``width * 2**(i + 1)`` channels to
    ``width * 2**i`` channels. Its hidden width equals the input width when
    ``wide`` is true and the output width otherwise.

    Args:
        width: base channel count.
        start: lowest level (inclusive).
        end: highest level (exclusive).
        wide: selects the hidden width of each block, see above.
        attention: levels that get an ``EfficientAttention`` layer appended
            *after* their decoder block; ``None`` means no attention.

    Returns:
        A flat list of layers in forward order (highest level first).
    """
    # The default is None; without this guard `i in attention` raises
    # TypeError as soon as the loop body runs.
    if attention is None:
        attention = []
    layers = []
    for i in reversed(range(start, end)):
        in_ch = width * 2**(i+1)
        out_ch = width * 2**i
        h_ch = in_ch if wide else out_ch
        layers.append(DecoderBlock(in_ch, out_ch, h_ch))
        if i in attention:
            # The third attention argument is out_ch // 4, so out_ch must divide by 4.
            assert not out_ch % 4
            layers.append(EfficientAttention(out_ch, out_ch, out_ch//4, out_ch))
    return layers

In [23]:
class SenyaGanjubas(nn.Module):
    """Frame-wise encoder/decoder with two auxiliary branches and an RNN head.

    Each frame of a clip is encoded independently. During training the
    encoding, after ``select(h, y)``, is also decoded back to an image. The
    encoding is split channel-wise into two halves, each half is processed by
    its own auxiliary branch and pooled, and the per-frame results are fed to
    an ``RNNBlock`` followed by a linear layer that emits one logit per clip.

    Args:
        image_shape: ``(C, H, W)`` of a single frame; ``H`` must equal ``W``.
        width: base channel count; must be even (the auxiliary branches are
            built with ``width // 2``).
        enc_depth: the encoder is built from levels ``0 .. enc_depth - 2``.
        aux_depth: number of levels in each auxiliary branch.
        rnn_dim: hidden size passed to ``RNNBlock``.
        wide: forwarded to ``stack_enc_blocks`` / ``stack_dec_blocks``.
        p_emb_drop: if > 0, a ``Dropout2d`` with this probability is prepended
            to each auxiliary branch.
        p_out_drop: if > 0, a ``Dropout`` with this probability is prepended
            to the output layer.
        train: stored as ``self.is_train``; controls whether ``forward`` runs
            the decoder. This flag is separate from ``nn.Module.training``.
        enc_att: encoder levels that get attention; ``None`` means none.
        dec_att: decoder levels that get attention; ``None`` means none.
        aux_att: auxiliary-branch levels that get attention, given relative to
            the first auxiliary level; ``None`` means none.

    Raises:
        AttributeError: for non-square images, for
            ``enc_depth + aux_depth > log2(H)``, or for an odd ``width``.
    """

    def __init__(self, image_shape: Tuple[int, int, int], width: int,
                 enc_depth: int, aux_depth: int, rnn_dim: int, wide=False,
                 p_emb_drop=0.1, p_out_drop=0.1, train=True,
                 enc_att: Optional[List[int]] = None, 
                 dec_att: Optional[List[int]] = None,
                 aux_att: Optional[List[int]] = None):
        super(SenyaGanjubas, self).__init__()
        C, H, W = image_shape
        if H != W:
            raise AttributeError("Only square images are supported!")

        max_depth = math.log2(H)
        if enc_depth + aux_depth > max_depth:
            raise AttributeError(
                f"enc_depth + aux_depth should be <= {int(max_depth)} given the "
                f"image_size ({H}, {H})")

        if width % 2:
            raise AttributeError("width must be even number")

        # Normalise the optional attention specs to lists up front. Previously
        # the default aux_att=None left `att` unassigned (UnboundLocalError)
        # and dec_att=None was forwarded unchanged to stack_dec_blocks.
        enc_att = [] if enc_att is None else list(enc_att)
        dec_att = [] if dec_att is None else list(dec_att)
        # Auxiliary levels continue where the encoder stops, so shift the
        # relative indices by the first auxiliary level.
        aux_start = enc_depth - 1
        aux_levels = [] if aux_att is None else [a + aux_start for a in aux_att]

        stem = [conv2D(C, width, bias=False), nn.ReLU(inplace=True)]
        encoder = stack_enc_blocks(width, 0, enc_depth - 1, wide=wide, attention=enc_att)
        self.encoder = nn.Sequential(*stem, *encoder)

        decoder = stack_dec_blocks(width, 0, enc_depth - 1, wide=wide, attention=dec_att)
        dec_out = conv2D(width, C, kernel=3, stride=1, bias=False)
        self.decoder = nn.Sequential(*decoder, dec_out, nn.Tanh())

        # Two structurally identical branches, one per channel half of the
        # encoder output (hence width // 2).
        for i in range(2):
            aux_branch = stack_enc_blocks(
                width // 2, aux_start, aux_start + aux_depth, 
                wide=wide, attention=aux_levels)
            if p_emb_drop > 0:
                aux_branch = [nn.Dropout2d(p=p_emb_drop)] + aux_branch
            setattr(self, f'aux_{i}', nn.Sequential(*aux_branch))

        self.pool = MaxMean2D()
        # NOTE(review): this size presumes MaxMean2D doubles the channel count
        # of each branch before the two branches are concatenated; confirm.
        aux_dim = width * 2 ** (enc_depth + aux_depth)
        self.rnn = RNNBlock(aux_dim, rnn_dim, bidirectional=True)

        out_layers = [nn.Linear(self.rnn.out_ch, 1, bias=False)]
        if p_out_drop > 0:
            out_layers = [nn.Dropout(p=p_out_drop)] + out_layers
        self.out = nn.Sequential(*out_layers)
        self.is_train = train

    def forward(self, x: FloatTensor, y: Optional[LongTensor] = None):
        """Encode every frame, optionally reconstruct it, and classify the clip.

        Args:
            x: 5-D input unpacked as ``(N, C, D, H, W)``; frames are taken
                along dim 2.
            y: labels passed to ``select`` when ``self.is_train`` is set.
                NOTE(review): ``select(h, None)`` is reached if ``y`` is
                omitted while ``is_train`` is true; confirm that is supported.

        Returns:
            ``(hidden, x_rec, y_hat)``: the per-frame encodings stacked along
            dim 2, the per-frame reconstructions stacked along dim 2 (``None``
            when ``self.is_train`` is false), and the output of the final
            linear head.
        """
        N, C, D, H, W = x.shape
        hidden, x_rec, aux_0, aux_1 = [], [], [], []

        for f in range(D):
            h = self.encoder(x[:, :, f])

            if self.is_train:
                hc = select(h, y)
                x1 = self.decoder(hc).unsqueeze(2)
                x_rec.append(x1)

            # Each auxiliary branch sees one half of the encoder channels.
            h0, h1 = torch.chunk(h, 2, dim=1)
            a0 = self.pool(self.aux_0(h0))
            a1 = self.pool(self.aux_1(h1))

            # Re-insert the frame axis so the per-frame results can be
            # concatenated along dim 2 below.
            for val, arr in zip([h, a0, a1], [hidden, aux_0, aux_1]):
                val = val.unsqueeze(2)
                arr.append(val)

        hidden = torch.cat(hidden, dim=2)
        x_rec = torch.cat(x_rec, dim=2) if self.is_train else None
        aux_0 = torch.cat(aux_0, dim=2)
        aux_1 = torch.cat(aux_1, dim=2)

        aux_out = torch.cat([aux_0, aux_1], dim=1)
        # N, C, D -> D, N, C (the permute happens inside RNNBlock.forward)
        seq = self.rnn(aux_out)
        y_hat = self.out(seq)
        return hidden, x_rec, y_hat

    @staticmethod
    def to_y(enc: Tensor, x_rec: Tensor, y_hat: Tensor):
        """Turn the outputs of ``forward`` into detached probabilities.

        ``enc`` and ``x_rec`` are accepted only so the tuple returned by
        ``forward`` can be splatted into this method; they are not used.

        Returns:
            ``sigmoid(y_hat)`` detached from the graph, with dim 1 squeezed.
        """
        y_pred = y_hat.detach()
        y_pred = torch.sigmoid(y_pred).squeeze_(1)
        return y_pred

In [24]:
# Hyper-parameters of the model under test, gathered in one place. Attention
# is inserted at level 2 of both the encoder and the decoder, and nowhere in
# the auxiliary branches.
sam_config = dict(
    image_shape=(3, H, W),
    width=8,
    enc_depth=5,
    aux_depth=3,
    rnn_dim=256,
    wide=False,
    p_emb_drop=0.15,
    p_out_drop=0.15,
    enc_att=[2],
    dec_att=[2],
    aux_att=[],
)
sam = SenyaGanjubas(**sam_config).to(device)

In [25]:
# Single forward pass on the random batch; `out` is the tuple
# (hidden, x_rec, y_hat) returned by SenyaGanjubas.forward.
out = sam(x, y)

In [26]:
# Display the module tree to inspect the layers that were actually built.
sam

SenyaGanjubas(
  (encoder): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): ReLU(inplace=True)
    (2): EncoderBlock(
      (conv): Sequential(
        (0): Sequential(
          (0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (2): Sequential(
          (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (id_conv): Sequential(
        (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
    

In [27]:
# Print one summary row (name, mean, std, shape) for the input and for each
# of the three tensors returned by the model.
names = ['x', 'h', 'x_hat', 'y_hat']
row_fmt = '{:5s} | mean {: .03f}, std {:.03f} | ({})'
for name, var in zip(names, [x, *out]):
    shape = ', '.join(str(dim) for dim in var.shape)
    mean = var.mean().item()
    std = var.std().item()
    print(row_fmt.format(name, mean, std, shape))

x     | mean  0.000, std 1.000 | (16, 3, 10, 256, 256)
h     | mean  0.397, std 0.584 | (16, 128, 10, 16, 16)
x_hat | mean  0.123, std 0.547 | (16, 3, 10, 256, 256)
y_hat | mean -0.030, std 0.134 | (16, 1)
