In [1]:
import math
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

In [2]:
# Project root. Defaults to the original author's checkout; set the
# DFDC_BASE_DIR environment variable to run the notebook on another machine.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
# Directory that holds the importable project sources (the `model` package).
SRC_DIR = os.path.join(BASE_DIR, 'src')

In [3]:
# Put the project sources first on the module search path so the `model.*`
# imports in the next cell resolve against this checkout.
sys.path.insert(0, SRC_DIR)

In [4]:
from model.zoo.common import encoder_block, decoder_block
from model.layers import conv2D, conv3D, Lambda
from model.loss import activation_loss
from model.ops import act, identity, select, pool_gru

In [5]:
# Toy activations: 5 samples x 6 channels, with two trailing singleton
# dimensions added by the `[:, :, None, None]` indexing.
h = torch.tensor([
    [1,1,0,  1,1,0],
    [0,0,0,  0,0,0],
    [1,0,1,  0,1,0],
    [1,1,1,  0,0,0],
    [0,0,0,  1,1,1]
], dtype=torch.float32, requires_grad=True)[:,:,None,None]

# One binary label per sample.
y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors (all zeros / all ones) used below to probe what
# `select` and `act` do for each label value in isolation.
all_neg = torch.zeros(y.size(0), dtype=torch.int64)
all_pos = torch.ones(y.size(0), dtype=torch.int64)

In [6]:
# Confirm the 4-D layout of the toy activations.
h.shape

torch.Size([5, 6, 1, 1])

In [7]:
# Drop the singleton dimensions to view `h` as a plain 5x6 matrix.
h.reshape(5, 6)

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [8]:
# With label 1 for every sample, the result shown below keeps only the last
# three channels of each row and zeroes the first three.
select(h, all_pos).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [9]:
# Per-sample activation for label 1; the values shown below equal the mean of
# the last three channels of each row.
act(h, all_pos)

tensor([0.6667, 0.0000, 0.3333, 0.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [10]:
# With label 0 for every sample, the result shown below keeps only the first
# three channels of each row and zeroes the last three.
select(h, all_neg).reshape(5, 6)

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<AsStridedBackward>)

In [11]:
# Per-sample activation for label 0; the values shown below equal the mean of
# the first three channels of each row.
act(h, all_neg)

tensor([0.6667, 0.0000, 0.6667, 1.0000, 0.0000], grad_fn=<ClampMaxBackward>)

In [12]:
# Sanity check: the label-0 and label-1 selections add back up to `h` exactly.
(select(h, all_neg) + select(h, all_pos) == h).all().item()

True

In [13]:
# Show the per-sample labels used in the next cells.
y

tensor([1, 0, 1, 0, 1])

In [14]:
# Select per sample by its own label: rows with y == 1 keep their last three
# channels, rows with y == 0 keep their first three (see the output below).
select(h, y).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<AsStridedBackward>)

In [15]:
# Per-sample activation of the channel group that matches each sample's label.
act(h, y)

tensor([0.6667, 0.0000, 0.3333, 1.0000, 1.0000], grad_fn=<ClampMaxBackward>)

In [16]:
# Sanity check: selecting with the labels and with the flipped labels (1 - y)
# together reconstructs `h` exactly.
(select(h, y) + select(h, (1-y)) == h).all().item()

True

$$
L_{ACT} =
\sum_{x \in S_0}
\left( |a_0(x) - 1| + |a_1(x)| \right) +
\sum_{x \in S_1}
\left( |a_1(x) - 1| + |a_0(x)| \right)
$$

In [17]:
# Activation loss on the toy data. The formula above evaluates to 10/3 for
# these inputs; the displayed 0.6667 is that value divided by the 5 samples.
# NOTE(review): the per-sample normalisation is inferred from this output —
# confirm against model.loss.activation_loss.
activation_loss(h, y)

tensor(0.6667, grad_fn=<DivBackward0>)

In [18]:
# Use the first GPU when one is available, otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [19]:
# Small random batch: N clips of D frames, each frame a 3-channel HxW image.
N, D, H, W = 2, 5, 128, 128

x = torch.randn((N, 3, D, H, W)).to(device)
# One random binary label per clip. This rebinds `y`, replacing the toy labels
# used in the cells above.
y = torch.randint(2, (N,)).to(device)

In [20]:
def affine(ch_in: int, ch_out: int):
    """Return a bias-free 1x1 convolution, i.e. a per-pixel linear map over channels."""
    return nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0, bias=False)


class SpatialAttention(nn.Module):
    """Self-attention over the spatial positions of a 2D feature map.

    Queries (`theta`) are computed at full resolution; keys (`phi`) and values
    (`g`) are 2x2 max-pooled first, so every position attends over a 4x smaller
    set of locations. The attended values are projected back to `ch` channels
    and added to the input, scaled by the learnable scalar `gamma`. `gamma`
    starts at zero, so a freshly constructed layer is the identity.

    Args:
        ch: number of input (and output) channels. The query/key projections
            use `ch // 8` channels and the value projection uses `ch // 2`.
    """

    def __init__(self, ch):
        super(SpatialAttention, self).__init__()
        # Channel multiplier
        self.ch = ch
        self.theta = affine(ch, ch//8)
        self.phi   = affine(ch, ch//8)
        self.g     = affine(ch, ch//2)
        self.o     = affine(ch//2, ch)
        # Learnable gain parameter
        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, x, y=None):
        """Apply attention to `x` of shape (N, ch, H, W); returns the same shape.

        `y` is accepted but not used by this layer. H and W must both be at
        least 2 because of the 2x2 pooling.
        """
        N, _, H, W = x.shape

        # Flatten the spatial dims instead of reshaping to a hard-coded
        # H * W // 4: max_pool2d floors odd sizes, so the pooled map may hold
        # fewer than H * W // 4 positions. Keeping the batch dim explicit also
        # guarantees samples are never folded into each other.
        theta = self.theta(x).flatten(2)                       # (N, ch//8, H*W)
        phi = F.max_pool2d(self.phi(x), [2, 2]).flatten(2)     # (N, ch//8, H'*W')
        g = F.max_pool2d(self.g(x), [2, 2]).flatten(2)         # (N, ch//2, H'*W')
        # Matmul and softmax to get attention maps: one distribution over the
        # pooled positions for every full-resolution position.
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        # Attention map times g path, folded back into an image.
        attended = torch.bmm(g, beta.transpose(1, 2)).view(N, -1, H, W)
        o = self.o(attended)
        return self.gamma * o + x

In [23]:
def enc_layer(in_ch: int, out_ch: int, kernel=3, stride=1,
              act_fn=nn.ReLU(inplace=True), zero_bn=False) -> nn.Module:
    """Build one conv -> batch-norm (-> activation) stage of the encoder.

    Args:
        in_ch: input channels of the convolution.
        out_ch: output channels of the convolution and size of the batch norm.
        kernel: kernel size forwarded to the project's `conv2D` helper.
        stride: stride forwarded to `conv2D`.
        act_fn: activation module appended last, or None for no activation.
            The default instance is created once at definition time and is
            therefore shared by every call that relies on it.
        zero_bn: initialise the batch-norm scale to 0 instead of 1.

    Returns:
        An `nn.Sequential` of the convolution, the batch norm and, if given,
        the activation.
    """
    modules = [conv2D(in_ch, out_ch, kernel=kernel, stride=stride, bias=False)]

    norm = nn.BatchNorm2d(out_ch)
    norm_gain = 0. if zero_bn else 1.
    nn.init.constant_(norm.weight, norm_gain)
    modules.append(norm)

    modules.extend([] if act_fn is None else [act_fn])
    return nn.Sequential(*modules)

In [24]:
class EncoderBlock(nn.Module):
    """Residual bottleneck block that halves the spatial resolution.

    The main path is 1x1 -> 3x3 (stride 2) -> 1x1, with the last batch norm's
    scale initialised to zero and no activation. The shortcut path is a 2x2
    average pool followed by a 1x1 conv + batch norm. Both paths are summed
    and passed through a ReLU.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels of the output feature map.
        h_ch: channels used inside the main path.
    """

    def __init__(self, in_ch: int, out_ch: int, h_ch: int):
        super().__init__()
        main_path = [
            enc_layer(in_ch, h_ch, kernel=1),
            enc_layer(h_ch, h_ch, kernel=3, stride=2),
            enc_layer(h_ch, out_ch, kernel=1, zero_bn=True, act_fn=None),
        ]
        self.conv = nn.Sequential(*main_path)

        shortcut = [
            nn.AvgPool2d(2, stride=2),
            enc_layer(in_ch, out_ch, kernel=1, act_fn=None),
        ]
        self.idconv = nn.Sequential(*shortcut)

    def forward(self, x):
        """Return relu(main(x) + shortcut(x))."""
        main = self.conv(x)
        skip = self.idconv(x)
        # The sum is a fresh tensor, so the in-place ReLU does not touch `x`.
        return torch.relu_(main + skip)

In [25]:
def upsample(scale: int):
    """Wrap nearest-neighbour `F.interpolate` with the given scale factor in the
    project's `Lambda` layer so it can be placed inside an `nn.Sequential`."""
    resize = partial(F.interpolate, scale_factor=scale, mode='nearest')
    return Lambda(resize)


def dec_layer(in_ch: int, out_ch: int, kernel=3, scale=1,
              act_fn=nn.ReLU(inplace=True), zero_bn=False) -> nn.Module:
    """Build one (upsample ->) conv -> batch-norm (-> activation) decoder stage.

    Args:
        in_ch: input channels of the convolution.
        out_ch: output channels of the convolution and size of the batch norm.
        kernel: kernel size forwarded to the project's `conv2D` helper.
        scale: nearest-neighbour upsampling factor applied before the
            convolution; no upsampling layer is added unless it is > 1.
        act_fn: activation module appended last, or None for no activation.
            The default instance is created once at definition time and is
            therefore shared by every call that relies on it.
        zero_bn: initialise the batch-norm scale to 0 instead of 1.

    Returns:
        An `nn.Sequential` of the selected layers, in the order listed above.
    """
    modules = []
    if scale > 1:
        modules.append(upsample(scale))

    modules.append(conv2D(in_ch, out_ch, kernel=kernel, stride=1, bias=False))

    norm = nn.BatchNorm2d(out_ch)
    norm_gain = 0. if zero_bn else 1.
    nn.init.constant_(norm.weight, norm_gain)
    modules.append(norm)

    if act_fn is not None:
        modules.append(act_fn)
    return nn.Sequential(*modules)

In [26]:
class DecoderBlock(nn.Module):
    """Residual bottleneck block that doubles the spatial resolution.

    The main path is 1x1 -> (2x upsample + 3x3) -> 1x1, with the last batch
    norm's scale initialised to zero and no activation. The shortcut path is a
    2x upsample followed by a 1x1 conv + batch norm. Both paths are summed and
    passed through a ReLU.

    Args:
        in_ch: channels of the input feature map.
        out_ch: channels of the output feature map.
        h_ch: channels used inside the main path.
    """

    def __init__(self, in_ch: int, out_ch: int, h_ch: int):
        super().__init__()
        main_path = [
            dec_layer(in_ch, h_ch, kernel=1),
            dec_layer(h_ch, h_ch, kernel=3, scale=2),
            dec_layer(h_ch, out_ch, kernel=1, zero_bn=True, act_fn=None),
        ]
        self.conv = nn.Sequential(*main_path)
        self.idconv = dec_layer(in_ch, out_ch, kernel=1, scale=2, act_fn=None)

    def forward(self, x):
        """Return relu(main(x) + shortcut(x))."""
        main = self.conv(x)
        skip = self.idconv(x)
        # The sum is a fresh tensor, so the in-place ReLU does not touch `x`.
        return torch.relu_(main + skip)

In [27]:
def stack_enc_blocks(width: int, start: int, end: int, wide=False):
    """Create the `EncoderBlock`s for levels `start` .. `end - 1`.

    The block at level `i` maps `width * 2**i` channels to `width * 2**(i+1)`.
    Its hidden width equals the output width when `wide` is true and the input
    width otherwise.

    Returns:
        A list of blocks in increasing level order (empty if `start >= end`).
    """
    widths = [(width * 2**lvl, width * 2**(lvl + 1)) for lvl in range(start, end)]
    return [
        EncoderBlock(narrow, broad, broad if wide else narrow)
        for narrow, broad in widths
    ]


def stack_dec_blocks(width: int, start: int, end: int, wide=False):
    """Create the `DecoderBlock`s for levels `end - 1` down to `start`.

    The block at level `i` maps `width * 2**(i+1)` channels to `width * 2**i`.
    Its hidden width equals the input width when `wide` is true and the output
    width otherwise.

    Returns:
        A list of blocks in decreasing level order (empty if `start >= end`).
    """
    blocks = []
    for lvl in reversed(range(start, end)):
        broad, narrow = width * 2**(lvl + 1), width * 2**lvl
        hidden = broad if wide else narrow
        blocks.append(DecoderBlock(broad, narrow, hidden))
    return blocks

In [28]:
class Samwise(nn.Module):
    """Per-frame autoencoder with two auxiliary classification branches.

    Every frame of a clip is encoded independently. The encoding is passed
    through the project's `select` op together with the labels and then
    decoded back to an image. The encoding is also split in two along the
    channel axis; each half feeds its own stack of encoder blocks. The two
    branch outputs are averaged over frames and spatial positions,
    concatenated, and mapped to a single output by a bias-free linear layer
    (after dropout).

    Args:
        image_shape: (C, H, W) of one frame; H must equal W.
        width: channels after the stem; must be even.
        enc_depth: encoder depth (the stem plus `enc_depth - 1` blocks).
        aux_depth: number of blocks in each auxiliary branch.
        p_drop: dropout probability before the final linear layer.
        wide: forwarded to `stack_enc_blocks` / `stack_dec_blocks`.

    Raises:
        AttributeError: for non-square images, for
            `enc_depth + aux_depth > log2(H) + 1`, or for an odd `width`.
    """

    def __init__(self, image_shape: Tuple[int, int, int], width: int,
                 enc_depth: int, aux_depth: int, p_drop=0.0, wide=False):
        super().__init__()
        channels, height, img_width = image_shape
        if height != img_width:
            raise AttributeError("Only square images are supported!")

        depth_limit = math.log2(height) + 1
        if enc_depth + aux_depth > depth_limit:
            raise AttributeError(
                f"enc_depth + aux_depth should be <= {int(depth_limit)} given the "
                f"image_size ({height}, {height})")

        if width % 2 != 0:
            raise AttributeError("width must be even number")

        # Level at which the encoder stops and the auxiliary branches start.
        top = enc_depth - 1

        self.encoder = nn.Sequential(
            conv2D(channels, width),
            nn.ReLU(inplace=True),
            *stack_enc_blocks(width, 0, top, wide=wide))

        dec_blocks = stack_dec_blocks(width, 0, top, wide=wide)
        self.decoder = nn.Sequential(
            *dec_blocks,
            conv2D(width, channels, kernel=3, stride=1),
            nn.Tanh())

        # Each branch sees half of the encoder channels, hence `width // 2`.
        for idx in (0, 1):
            branch = stack_enc_blocks(width // 2, top, top + aux_depth, wide=wide)
            setattr(self, 'aux_{}'.format(idx), nn.Sequential(*branch))

        # Channel count of the two concatenated branch outputs.
        n_features = width * 2**(enc_depth + aux_depth - 1)
        self.aux_out = nn.Sequential(
            nn.Dropout(p=p_drop),
            nn.Linear(n_features, 1, bias=False))

    def forward(self, x: FloatTensor, y: LongTensor):
        """Run the model on a clip batch `x` of shape (N, C, D, H, W).

        Returns:
            (hidden, x_rec, y_hat): the per-frame encodings and per-frame
            reconstructions, each stacked along dim 2, and the output of the
            auxiliary head.
        """
        _, _, n_frames, _, _ = x.shape
        hidden, x_rec, aux_0, aux_1 = [], [], [], []

        for t in range(n_frames):
            code = self.encoder(x[:, :, t])
            recon = self.decoder(select(code, y))

            first_half, second_half = torch.chunk(code, 2, dim=1)
            branch_0 = self.aux_0(first_half)
            branch_1 = self.aux_1(second_half)

            # Add a frame axis so the per-frame results can be concatenated.
            hidden.append(code.unsqueeze(2))
            x_rec.append(recon.unsqueeze(2))
            aux_0.append(branch_0.unsqueeze(2))
            aux_1.append(branch_1.unsqueeze(2))

        hidden = torch.cat(hidden, dim=2)
        x_rec = torch.cat(x_rec, dim=2)
        pooled = [reduce_frames(aux_0), reduce_frames(aux_1)]
        y_hat = self.aux_out(torch.cat(pooled, dim=1))
        return hidden, x_rec, y_hat

In [29]:
def reduce_frames(v: List[Tensor]) -> Tensor:
    """Average per-frame feature maps over frames and all trailing dimensions.

    The tensors in `v` are concatenated along dim 2, everything from dim 2
    onwards is flattened, and the mean over that flattened axis is returned,
    leaving a tensor with only the first two dimensions.
    """
    frames = torch.cat(v, dim=2)
    flat = frames.flatten(start_dim=2)
    return flat.mean(dim=2)

In [30]:
# Larger random batch for the smoke test below: N clips of D frames, each
# frame a 3-channel HxW image. Rebinds N, D, H, W, x and y from earlier cells.
N, D, H, W = 16, 10, 256, 256

x = torch.randn((N, 3, D, H, W)).to(device)
# One random binary label per clip.
y = torch.randint(2, (N,)).to(device)

In [31]:
# Instantiate Samwise for HxW frames and move it to the selected device.
# enc_depth + aux_depth = 8 is within the limit log2(256) + 1 = 9 enforced by
# the constructor.
sam = Samwise(
    image_shape=(3, H, W), 
    width=8,
    enc_depth=5,
    aux_depth=3,
    p_drop=0.1,
    wide=False
).to(device)

In [32]:
# Forward pass; `out` is the (hidden, x_rec, y_hat) tuple returned by Samwise.forward.
out = sam(x, y)

In [33]:
# Print mean, standard deviation and shape of the input and of each model
# output; the names 'h', 'x_hat', 'y_hat' label the three entries of `out`
# in order.
for var, name in zip([x] + list(out), ['x', 'h', 'x_hat', 'y_hat']):
    mean, std = var.mean().item(), var.std().item()
    shape = ', '.join(map(str, var.shape))
    print('{:5s} | mean {: .03f}, std {:.03f} | ({})'.format(name, mean, std, shape))

x     | mean  0.000, std 1.000 | (16, 3, 10, 256, 256)
h     | mean  0.397, std 0.585 | (16, 128, 10, 16, 16)
x_hat | mean -0.070, std 0.535 | (16, 3, 10, 256, 256)
y_hat | mean  0.311, std 0.098 | (16, 1)


In [34]:
# NOTE(review): `crash_nb` is not importable (see the ModuleNotFoundError
# below). This looks like a deliberate way to stop "Run All" before the
# not-yet-executed cells that follow — confirm, and consider an explicit
# marker instead.
import crash_nb

ModuleNotFoundError: No module named 'crash_nb'

In [None]:
def build_enc_blocks(start_width: int, start: int, end: int,
                     attention: Optional[List[int]] = None):
    """Create encoder blocks for levels `start` .. `end - 1`, with optional attention.

    The block at level `i` maps `start_width * 2**i` channels to
    `start_width * 2**(i+1)`. For every level listed in `attention`, a
    `SpatialAttention` layer over the block's input channels is inserted
    directly before that block.

    Args:
        start_width: channel count at level 0.
        start: first level (inclusive).
        end: last level (exclusive).
        attention: levels that get an attention layer; None means none.

    Returns:
        A flat list of layers in increasing level order.
    """
    attended_levels = [] if attention is None else attention
    blocks = []
    for lvl in range(start, end):
        narrow = start_width * 2**lvl
        broad = start_width * 2**(lvl + 1)
        if lvl in attended_levels:
            blocks.append(SpatialAttention(narrow))
        blocks.append(encoder_block(narrow, broad))
    return blocks


def build_dec_blocks(start_width: int, start: int, end: int,
                     attention: Optional[List[int]] = None):
    """Create decoder blocks for levels `end - 1` down to `start`, with optional attention.

    The block at level `i` maps `start_width * 2**(i+1)` channels to
    `start_width * 2**i`. For every level listed in `attention`, a
    `SpatialAttention` layer over the block's output channels is inserted
    directly after that block.

    Args:
        start_width: channel count at level 0.
        start: lowest level (inclusive).
        end: highest level (exclusive).
        attention: levels that get an attention layer; None means none.

    Returns:
        A flat list of layers in decreasing level order.
    """
    attended_levels = [] if attention is None else attention
    blocks = []
    for lvl in reversed(range(start, end)):
        broad = start_width * 2**(lvl + 1)
        narrow = start_width * 2**lvl
        blocks.append(decoder_block(broad, narrow))
        if lvl in attended_levels:
            blocks.append(SpatialAttention(narrow))
    return blocks


def reduce_frames(v: List[Tensor]) -> Tensor:
    """Average per-frame feature maps over frames and all trailing dimensions.

    NOTE: this duplicates the `reduce_frames` helper defined earlier in this
    notebook; the annotation now matches it (`torch.cat` needs a sequence of
    tensors, not a single tensor).

    Args:
        v: per-frame tensors that agree in every dimension except dim 2.

    Returns:
        The mean over everything from dim 2 onwards, i.e. a tensor that keeps
        only the first two dimensions of the inputs.
    """
    v = torch.cat(v, dim=2).flatten(2)
    return v.mean(dim=2)

In [None]:
class Frodo(nn.Module):
    """Per-frame autoencoder with optional spatial attention and two auxiliary branches.

    Every frame of a clip is encoded independently. The encoding is passed
    through the project's `select` op together with the labels and then
    decoded back to an image. The encoding is also split in two along the
    channel axis; each half feeds its own stack of encoder blocks. The two
    branch outputs are averaged over frames and spatial positions,
    concatenated, and mapped to a single output by a bias-free linear layer
    (after dropout).

    Args:
        image_shape: (C, H, W) of one frame; H must equal W.
        width: channels after the stem; must be even.
        enc_depth: encoder depth (the stem plus `enc_depth - 1` blocks).
        aux_depth: number of blocks in each auxiliary branch.
        p_drop: dropout probability before the final linear layer.
        enc_attention: encoder levels that get a `SpatialAttention` layer;
            None (the default) or an empty sequence means no attention.
        dec_attention: same for the decoder.

    Raises:
        AttributeError: for non-square images, for
            `enc_depth + aux_depth > log2(H) + 1`, for an odd `width`, or when
            an attention level exceeds `enc_depth - 2`.
    """

    def __init__(self, image_shape: Tuple[int, int, int], width: int, 
                 enc_depth: int, aux_depth: int, p_drop: float,
                 enc_attention: Optional[List[int]] = None, 
                 dec_attention: Optional[List[int]] = None):
        super(Frodo, self).__init__()
        C, H, W = image_shape
        if H != W:
            raise AttributeError("Only square images are supported!")
            
        max_depth = math.log2(H) + 1         
        if enc_depth + aux_depth  > max_depth:
            raise AttributeError(
                f"enc_depth + aux_depth should be <= {int(max_depth)} given the "
                f"image_size ({H}, {H})")
            
        if width % 2:
            raise AttributeError("width must be even number")

        # The defaults are None, so normalise before concatenating; otherwise
        # constructing the model without attention raises a TypeError.
        enc_attention = [] if enc_attention is None else list(enc_attention)
        dec_attention = [] if dec_attention is None else list(dec_attention)

        # Compare every requested level explicitly. Testing the truthiness of
        # the offending values themselves would let an invalid level 0 through.
        max_att = enc_depth - 2
        if any(level > max_att for level in enc_attention + dec_attention):
            raise AttributeError(f"Can't place attention higher than {max_att}")
        
        stem = encoder_block(C, width, stride=1, bn=False)        
        encoder = build_enc_blocks(width, 0, enc_depth - 1, enc_attention)
        self.encoder = nn.Sequential(stem, *encoder)
        
        decoder = build_dec_blocks(width, 0, enc_depth - 1, dec_attention)
        dec_out = conv2D(width, C, kernel=3, stride=1)
        self.decoder = nn.Sequential(*decoder, dec_out, nn.Tanh())
        
        # Each branch sees half of the encoder channels, hence `width // 2`.
        for i in range(2):
            aux = build_enc_blocks(
                width//2, enc_depth - 1, enc_depth - 1 + aux_depth, [])
            aux_branch = nn.Sequential(*aux)
            setattr(self, 'aux_{}'.format(i), aux_branch)
        
        # Channel count of the two concatenated branch outputs.
        out_dim = width * 2**(enc_depth + aux_depth - 1)
        self.aux_out = nn.Sequential(
            nn.Dropout(p=p_drop),
            nn.Linear(out_dim, 1, bias=False))

    def forward(self, x: FloatTensor, y: LongTensor):
        """Run the model on a clip batch `x` of shape (N, C, D, H, W).

        Returns:
            (hidden, x_rec, y_hat): the per-frame encodings and per-frame
            reconstructions, each stacked along dim 2, and the output of the
            auxiliary head.
        """
        _, _, D, _, _ = x.shape
        hidden, x_rec, aux_0, aux_1 = [], [], [], []
        
        for f in range(D):
            h = self.encoder(x[:, :, f])
            hc = select(h, y)
            x1 = self.decoder(hc)
            
            h0, h1 = torch.chunk(h, 2, dim=1)
            a0 = self.aux_0(h0)
            a1 = self.aux_1(h1)
            
            # Add a frame axis so the per-frame results can be concatenated.
            for val, arr in zip([h, x1, a0, a1], 
                                [hidden, x_rec, aux_0, aux_1]):
                arr.append(val.unsqueeze(2))
            
        hidden = torch.cat(hidden, dim=2)
        x_rec = torch.cat(x_rec, dim=2)
        aux_0 = reduce_frames(aux_0)
        aux_1 = reduce_frames(aux_1)
        aux = torch.cat([aux_0, aux_1], dim=1)
        y_hat = self.aux_out(aux)
        return hidden, x_rec, y_hat

In [None]:
# Instantiate Frodo with one attention layer at level 0 in both the encoder
# and the decoder (the highest level allowed here is enc_depth - 2 = 3), then
# move it to the selected device.
frodo = Frodo(
    image_shape=(3, H, W), 
    width=8,
    enc_depth=5,
    aux_depth=3,
    p_drop=0.1,
    enc_attention=[0],
    dec_attention=[0]
).to(device)

In [None]:
# Forward pass; `out` is the (hidden, x_rec, y_hat) tuple returned by
# Frodo.forward and replaces the Samwise result bound to `out` earlier.
out = frodo(x, y)

In [None]:
# Print mean, standard deviation and shape of the input and of each model
# output; the names 'h', 'x_hat', 'y_hat' label the three entries of `out`
# in order.
for var, name in zip([x] + list(out), ['x', 'h', 'x_hat', 'y_hat']):
    mean, std = var.mean().item(), var.std().item()
    shape = ', '.join(map(str, var.shape))
    print('{:5s} | mean {: .03f}, std {:.03f} | ({})'.format(name, mean, std, shape))