In [1]:
import math
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

In [2]:
# Project root. Override with the DFDC_BASE_DIR environment variable so the
# notebook can run outside the original author's machine; the previous
# hardcoded path stays as the default.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
SRC_DIR = os.path.join(BASE_DIR, 'src')

In [3]:
# Put the project's `src` directory first on the import path. Existing entries
# are removed first, so re-running this cell does not accumulate duplicates.
sys.path[:] = [p for p in sys.path if p != SRC_DIR]
sys.path.insert(0, SRC_DIR)

In [4]:
from model.zoo.common import encoder_block, decoder_block
from model.layers import conv2D, conv3D, relu, get_a_from_act_fn, ActivationFn, Lambda, MaxMean2D, MaxMean3D, EncoderBlock, DecoderBlock
from model.loss import activation_loss
from model.ops import act, identity, select, pool_gru
from model.efficient_attention.efficient_attention import EfficientAttention

In [5]:
# Toy embedding: 5 samples x 6 channels with 1x1 spatial dims. Channels 0-2 form
# the class-0 half and channels 3-5 the class-1 half (see the select() outputs below).
h = torch.tensor(
    [
        [1, 1, 0,   1, 1, 0],
        [0, 0, 0,   0, 0, 0],
        [1, 0, 1,   0, 1, 0],
        [1, 1, 1,   0, 0, 0],
        [0, 0, 0,   1, 1, 1],
    ],
    dtype=torch.float32,
    requires_grad=True,
)[:, :, None, None]

y = torch.tensor([1, 0, 1, 0, 1])

# Label vectors of the same length as y: everything class 0 / everything class 1.
all_neg = torch.zeros(len(y), dtype=torch.long)
all_pos = torch.ones_like(all_neg)

In [6]:
# Shape of the toy embedding: (samples, channels, 1, 1).
h.size()

torch.Size([5, 6, 1, 1])

In [7]:
# Drop the trailing 1x1 spatial dims for display.
h.reshape(len(h), -1)

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [8]:
# Keep only the class-1 channel half of every sample; the class-0 half is zeroed.
select(h, all_pos).reshape(len(h), -1)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [9]:
# Per-sample activation of the class-1 channel half. For this 0/1 toy input the
# values equal the fraction of ones in channels 3-5 of each row.
act(h, all_pos)

tensor([0.6667, 0.0000, 0.3333, 0.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [10]:
# Keep only the class-0 channel half of every sample; the class-1 half is zeroed.
select(h, all_neg).reshape(len(h), -1)

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<AsStridedBackward>)

In [11]:
# Per-sample activation of the class-0 channel half. For this 0/1 toy input the
# values equal the fraction of ones in channels 0-2 of each row.
act(h, all_neg)

tensor([0.6667, 0.0000, 0.6667, 1.0000, 0.0000], grad_fn=<ClampMaxBackward>)

In [12]:
# The two channel halves partition h: adding them back reconstructs it exactly.
bool((select(h, all_neg) + select(h, all_pos)).eq(h).all())

True

In [13]:
# Toy labels used by the per-sample select/act checks below.
y

tensor([1, 0, 1, 0, 1])

In [14]:
# Per-sample selection: row i keeps only the channel half of its own label y[i].
select(h, y).reshape(len(h), -1)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [15]:
# Activation of each sample's own-label channel half (class-1 half when y[i] == 1,
# class-0 half when y[i] == 0).
act(h, y)

tensor([0.6667, 0.0000, 0.3333, 1.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [16]:
# Same partition check, using per-sample labels and their complement.
bool((select(h, y) + select(h, 1 - y)).eq(h).all())

True

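$$
L_{ACT} =
\sum_{x \in S_0} \Big( |a_0(x) - 1| + |a_1(x)| \Big) +
\sum_{x \in S_1} \Big( |a_1(x) - 1| + |a_0(x)| \Big)
$$

Here $S_c$ is the set of samples with label $c$. $a_c(x)$ is the activation of the channel half assigned to class $c$ (`act(h, all_neg)` / `act(h, all_pos)` above).

The value printed by `activation_loss` below equals this sum divided by the batch size. For the toy batch the per-sample terms are $1, 1, 4/3, 0, 0$, so $L_{ACT} = 10/3$, and $10/3 \,/\, 5 \approx 0.6667$.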

In [17]:
# Per-sample L_ACT terms for the toy batch are 1, 1, 4/3, 0, 0 (sum 10/3); the
# printed 0.6667 equals 10/3 / 5, i.e. the sum is averaged over the batch.
activation_loss(h, y)

tensor(0.6667, grad_fn=<DivBackward0>)

In [18]:
# Use the first GPU when available; fall back to CPU so the notebook still runs
# (more slowly) on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [19]:
# Smaller alternative input configuration, kept commented out. The next cell
# defines the configuration actually used and overrides these values if run after.
# N, D, H, W = 2, 5, 128, 128

# x = torch.randn((N, 3, D, H, W)).to(device)
# y = torch.randint(2, (N,)).to(device)

In [20]:
# Seed the global torch RNG so the random inputs below (and any later random
# ops, such as weight initialisation) are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# Smoke-test batch: N clips of D RGB frames, each H x W.
# NOTE: this rebinds `y`, replacing the toy labels used in the activation-loss demo above.
N, D, H, W = 3, 10, 256, 256

x = torch.randn((N, 3, D, H, W)).to(device)
y = torch.randint(2, (N,)).to(device)

In [21]:
# https://github.com/fastai/course-v3/blob/master/nbs/dl2/11_train_imagenette.ipynb
def init_cnn(m, a=0.0):
    """Recursively initialise ``m`` and all of its submodules in place.

    Every module with a non-None ``bias`` gets it zeroed. ``nn.Conv2d`` and
    ``nn.Linear`` weights get Kaiming-normal initialisation.

    Args:
        m: module to initialise.
        a: negative slope passed to ``nn.init.kaiming_normal_``. It is forwarded
            to every submodule, so layers nested inside containers
            (e.g. ``nn.Sequential``) use the same slope as ``m`` itself.
    """
    if getattr(m, 'bias', None) is not None:
        nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, a=a)
    for child in m.children():
        # propagate `a`; dropping it would silently reset children to a=0.0
        init_cnn(child, a=a)

In [22]:
def stack_enc_blocks(width: int, start: int, end: int, wide=False,
                     act_fn: Optional[ActivationFn] = relu):
    """Build one ``EncoderBlock`` per stage in ``range(start, end)``.

    Stage ``i`` maps ``width * 2**i`` channels to ``width * 2**(i + 1)``. The
    block's hidden width is the output width when ``wide`` is set, otherwise
    the input width.

    Returns:
        list of encoder blocks in increasing stage order.
    """
    def block_for(stage):
        c_in = width * 2 ** stage
        c_out = width * 2 ** (stage + 1)
        return EncoderBlock(c_in, c_out, c_out if wide else c_in, act_fn=act_fn)

    return [block_for(stage) for stage in range(start, end)]


def stack_dec_blocks(width: int, start: int, end: int, wide=False,
                     act_fn: Optional[ActivationFn] = relu):
    """Build one ``DecoderBlock`` per stage, mirroring ``stack_enc_blocks``.

    Stages are visited from ``end - 1`` down to ``start``. Stage ``i`` maps
    ``width * 2**(i + 1)`` channels back to ``width * 2**i``. The block's hidden
    width is the input width when ``wide`` is set, otherwise the output width.

    Returns:
        list of decoder blocks in decreasing stage order.
    """
    def block_for(stage):
        c_in = width * 2 ** (stage + 1)
        c_out = width * 2 ** stage
        return DecoderBlock(c_in, c_out, c_in if wide else c_out, act_fn=act_fn)

    return [block_for(stage) for stage in reversed(range(start, end))]


class RNNBlock(nn.Module):
    """GRU over the depth (frame) axis that returns one feature vector per sample.

    Args:
        in_ch: size of the per-frame feature vector (dim 1 of the input).
        rnn_ch: GRU hidden size.
        bidirectional: use a bidirectional GRU. The output then concatenates the
            forward and backward final states.

    Attributes:
        out_ch: size of the returned feature vector (``rnn_ch`` or ``2 * rnn_ch``).
    """

    def __init__(self, in_ch: int, rnn_ch: int, bidirectional=False):
        super().__init__()
        self.gru = nn.GRU(in_ch, rnn_ch, bidirectional=bidirectional)
        self.out_ch = rnn_ch * (2 if bidirectional else 1)

    def forward(self, x):
        """Encode ``x`` of shape ``(N, C, D)`` into a tensor of shape ``(N, out_ch)``."""
        # N, C, D -> D, N, C (nn.GRU defaults to sequence-first input)
        x = x.permute(2, 0, 1)
        _, h_n = self.gru(x)
        # h_n has shape (num_layers * num_directions, N, rnn_ch); its last
        # `num_dirs` entries are the top layer's final states: forward after the
        # last frame and, if bidirectional, backward after the first frame.
        # Using output[-1] instead would give a backward state that has seen
        # only the last frame.
        num_dirs = 2 if self.gru.bidirectional else 1
        return torch.cat(list(h_n[-num_dirs:]), dim=1)


class Samwise(nn.Module):
    """Frame-wise 2D encoder/decoder with two auxiliary classification branches.

    Each frame of a clip is encoded by a shared 2D encoder. During training, the
    channel half of the embedding chosen by the label (via ``select``) is decoded
    back to image space. Each channel half also feeds its own auxiliary encoder
    branch. The per-frame auxiliary features are stacked along the depth axis,
    reduced with the configured ``reduce`` strategy, concatenated, and passed to a
    small MLP head that outputs one logit per sample.

    Args:
        image_shape: ``(C, H, W)`` of a single frame; ``H`` must equal ``W``.
        width: number of channels after the stem. Must be even because each
            auxiliary branch works on half of the embedding.
        enc_depth: the encoder stacks ``enc_depth - 1`` blocks after the stem.
        aux_depth: number of encoder blocks in each auxiliary branch.
        wide: forwarded to ``stack_enc_blocks`` / ``stack_dec_blocks``.
        act_fn: activation module used throughout the network.
        reduce: one of ``'rnn'``, ``'mean'``, ``'max-mean'``, ``'mean-max'``.
        rnn_dim: GRU hidden size; required when ``reduce == 'rnn'``.
        p_emb_drop: ``Dropout2d`` probability at the start of each aux branch.
        p_out_drop: ``Dropout`` probability before the classification head.
        train: when True, ``forward`` also runs the decoder and requires ``y``.
            Stored as ``is_train``; this is independent of ``nn.Module.train()`` /
            ``eval()``.

    Raises:
        AttributeError: on non-square images, too large ``enc_depth + aux_depth``,
            odd ``width``, missing ``rnn_dim`` for ``'rnn'``, or an unknown
            ``reduce`` value.
    """

    def __init__(self, image_shape: Tuple[int, int, int], width: int,
                 enc_depth: int, aux_depth: int, wide=False,
                 act_fn: Optional[ActivationFn] = relu,
                 reduce: str = 'mean', rnn_dim: Optional[int] = None,
                 p_emb_drop=0.1, p_out_drop=0.1, train=True):
        super().__init__()
        C, H, W = image_shape
        if H != W:
            raise AttributeError("Only square images are supported!")

        max_depth = math.log2(H)
        if enc_depth + aux_depth > max_depth:
            raise AttributeError(
                f"enc_depth + aux_depth should be <= {int(max_depth)} given the "
                f"image_size ({H}, {H})")

        # each auxiliary branch consumes half of the embedding channels
        if width % 2:
            raise AttributeError("width must be even number")

        # NOTE(review): `act_fn` is a single module instance. It is placed directly
        # in the stem and the head and forwarded to every encoder/decoder block,
        # so a parametric activation (e.g. nn.PReLU) shares its parameters across
        # all of them. Confirm this is intended.
        stem = [conv2D(C, width, bias=False), act_fn]
        encoder = stack_enc_blocks(width, 0, enc_depth - 1, wide=wide, act_fn=act_fn)
        self.encoder = nn.Sequential(*stem, *encoder)

        decoder = stack_dec_blocks(width, 0, enc_depth - 1, wide=wide, act_fn=act_fn)
        dec_out = conv2D(width, C, kernel=3, stride=1, bias=False)
        self.decoder = nn.Sequential(*decoder, dec_out, nn.Tanh())

        # aux_0 / aux_1 continue the encoder on the first / second channel half
        for i in range(2):
            aux_branch = stack_enc_blocks(
                width // 2, enc_depth - 1, enc_depth - 1 + aux_depth,
                wide=wide, act_fn=act_fn)
            if p_emb_drop > 0:
                aux_branch = [nn.Dropout2d(p=p_emb_drop)] + aux_branch
            setattr(self, f'aux_{i}', nn.Sequential(*aux_branch))

        # Each aux branch outputs aux_dim // 4 channels, so aux_dim // 2 is both
        # branches concatenated. The 'max-mean' / 'rnn' sizes presumably assume
        # MaxMean3D doubles the channels (max and mean) — verify in model.layers.
        aux_dim = width * 2 ** (enc_depth + aux_depth)
        out_dim = self._build_reducers(reduce, aux_dim, rnn_dim)

        out = [nn.Linear(out_dim, out_dim // 4),
               act_fn,
               nn.Linear(out_dim // 4, 1, bias=False)]
        if p_out_drop > 0:
            out = [nn.Dropout(p=p_out_drop)] + out
        self.out = nn.Sequential(*out)
        init_cnn(self.out, a=get_a_from_act_fn(act_fn))
        self.is_train = train

    def _build_reducers(self, reduce: str, aux_dim: int,
                        rnn_dim: Optional[int]) -> int:
        """Register ``reduce_0`` / ``reduce_1`` and return the head's input size."""
        if reduce == 'rnn':
            if not rnn_dim:
                raise AttributeError("GRU dim is missing")
            out_dim = 0
            for i in range(2):
                pool = MaxMean3D(reduce_frames=False)
                rnn = RNNBlock(aux_dim // 2, rnn_dim, bidirectional=True)
                setattr(self, f'reduce_{i}', nn.Sequential(pool, rnn))
                # the outputs of both branches are concatenated before the head
                out_dim = rnn.out_ch * 2
            return out_dim

        if reduce == 'mean':
            reducer = Lambda(lambda x: x.mean(dim=(2, 3, 4)))
            out_dim = aux_dim // 2
        elif reduce in ('max-mean', 'mean-max'):
            reducer = MaxMean3D()
            out_dim = aux_dim
        else:
            raise AttributeError(
                f"reduce={reduce} - invalid value, available options: "
                "[rnn, mean, max-mean, mean-max]")
        # parameter-free reducers can be shared by both branches
        for i in range(2):
            setattr(self, f'reduce_{i}', reducer)
        return out_dim

    def forward(self, x: FloatTensor, y: Optional[LongTensor] = None):
        """Run the model on a batch of clips.

        Args:
            x: input of shape ``(N, C, D, H, W)``; frames are taken along ``D``.
            y: per-sample labels. Required when ``self.is_train`` is True, because
                they choose which embedding half is decoded.

        Returns:
            ``(hidden, x_rec, y_hat)``:
            - per-frame embeddings stacked on dim 2;
            - stacked reconstructions (``None`` when ``is_train`` is False);
            - logits of shape ``(N, 1)``.

        Raises:
            ValueError: if ``self.is_train`` is True and ``y`` is None.
        """
        if self.is_train and y is None:
            raise ValueError(
                "y is required in forward() when Samwise is built with train=True")
        _, _, D, _, _ = x.shape
        hidden, x_rec, aux_0, aux_1 = [], [], [], []

        for f in range(D):
            h = self.encoder(x[:, :, f])

            if self.is_train:
                # decode only the label-selected channel half of the embedding
                hc = select(h, y)
                x_rec.append(self.decoder(hc).unsqueeze(2))

            h0, h1 = torch.chunk(h, 2, dim=1)
            a0 = self.aux_0(h0)
            a1 = self.aux_1(h1)

            # add a depth axis (dim 2) so frames can be concatenated afterwards
            for val, arr in zip([h, a0, a1], [hidden, aux_0, aux_1]):
                arr.append(val.unsqueeze(2))

        hidden = torch.cat(hidden, dim=2)
        x_rec = torch.cat(x_rec, dim=2) if self.is_train else None

        aux_out = []
        for i, aux_i in enumerate([aux_0, aux_1]):
            aux_i = torch.cat(aux_i, dim=2)
            aux_out.append(getattr(self, f'reduce_{i}')(aux_i))
        aux_out = torch.cat(aux_out, dim=1)

        y_hat = self.out(aux_out)
        return hidden, x_rec, y_hat

    @staticmethod
    def to_y(enc: Tensor, x_rec: Tensor, y_hat: Tensor):
        """Turn logits ``y_hat`` of shape ``(N, 1)`` into probabilities of shape ``(N,)``.

        ``enc`` and ``x_rec`` are unused. Probabilities are clipped to
        ``[0.05, 0.95]``.
        """
        y_pred = y_hat.detach()
        y_pred = torch.sigmoid(y_pred).squeeze_(1)
        return y_pred.clamp_(0.05, 0.95)

In [23]:
# NOTE(review): a single nn.PReLU instance is passed as act_fn. Samwise puts this
# same module in its stem and head and forwards it to every block, so the slope
# parameter is shared network-wide. Confirm this is intended.
sam = Samwise(
    # frame geometry and backbone
    image_shape=(3, H, W),
    width=8,
    enc_depth=5,
    aux_depth=3,
    wide=False,
    act_fn=nn.PReLU(),
    # temporal reduction over frames
    reduce='rnn',
    rnn_dim=256,
    # regularisation
    p_emb_drop=0.15,
    p_out_drop=0.15,
).to(device)

In [24]:
# Smoke test: one forward pass. `out` is (hidden, x_rec, y_hat); their shapes
# are summarised in the table printed further below.
out = sam(x, y)

In [25]:
# Display the module tree (the saved output below is truncated).
sam

Samwise(
  (encoder): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): PReLU(num_parameters=1)
    (2): EncoderBlock(
      (conv): Sequential(
        (0): Sequential(
          (0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
        (1): Sequential(
          (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
        (2): Sequential(
          (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (id_conv): Sequential(
        (0): AvgPool2d(kernel_size=2, stride=2, paddin

In [26]:
# Mean / std / shape of the input and of each forward output.
for var, name in zip([x, *out], ('x', 'h', 'x_hat', 'y_hat')):
    mean = var.mean().item()
    std = var.std().item()
    shape = ', '.join(str(dim) for dim in var.shape)
    print(f'{name:5s} | mean {mean: .03f}, std {std:.03f} | ({shape})')

x     | mean  0.000, std 1.000 | (3, 3, 10, 256, 256)
h     | mean  0.299, std 0.662 | (3, 128, 10, 16, 16)
x_hat | mean  0.099, std 0.539 | (3, 3, 10, 256, 256)
y_hat | mean -0.846, std 0.305 | (3, 1)
