In [1]:
import math
import numpy as np
import os
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from torch import FloatTensor, LongTensor, Tensor
from torchvision import transforms as T
from tqdm.notebook import tqdm
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

In [2]:
# Project root. Defaults to the original developer checkout; set the
# DFDC_BASE_DIR environment variable to run the notebook from another location.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
# Source directory that is put on sys.path so the `model.*` imports resolve.
SRC_DIR = os.path.join(BASE_DIR, 'src')

In [3]:
# Put the project sources first on the import path. Any earlier copy of the
# entry is removed first so that re-running this cell does not accumulate
# duplicates; SRC_DIR still ends up at index 0, so import resolution is the
# same as with a plain insert.
while SRC_DIR in sys.path:
    sys.path.remove(SRC_DIR)
sys.path.insert(0, SRC_DIR)

In [4]:
from model.zoo.common import encoder_block, decoder_block
from model.layers import conv2D, conv3D, relu, get_a_from_act_fn, ActivationFn, Lambda, MaxMean2D, MaxMean3D, EncoderBlock, DecoderBlock
from model.loss import activation_loss
from model.ops import act, identity, select, pool_gru
from model.efficient_attention.efficient_attention import EfficientAttention

In [21]:
from model.ops import reshape_as

In [33]:
def act(h: Tensor, y: Tensor) -> Tensor:
    """Per-sample mean absolute activation of the label-selected half of ``h``.

    ``h`` is split into two equal chunks along dim 1. The chunks are blended
    as ``first * (1 - y) + second * y``, so for 0/1 labels the first chunk is
    kept where ``y == 0`` and the second where ``y == 1``. The absolute values
    of the result are summed over every non-leading axis and divided by
    ``numel / y.size(0)``, i.e. the number of elements per sample.

    NOTE: this local definition shadows the ``act`` imported from
    ``model.ops`` earlier in the notebook.

    Args:
        h: activations with at least two axes; dim 1 is halved by ``chunk``.
        y: one label per sample along dim 0. Presumably ``reshape_as`` only
            reshapes it so that it broadcasts against ``h`` -- TODO confirm
            against ``model.ops.reshape_as``.

    Returns:
        A tensor with one value per sample along dim 0. The target value of 1
        for this quantity (see the L_ACT formula) is imposed by the loss, not
        by this function.
    """
    batch_size = y.size(0)
    mask = reshape_as(y, h)
    first_half, second_half = h.chunk(2, dim=1)
    selected = first_half * (1 - mask) + second_half * mask

    # Elements per sample; max(..., 1) guards the division for an empty batch.
    per_sample = selected.numel() / max(batch_size, 1)
    non_batch_dims = tuple(range(1, selected.ndim))
    return selected.abs().sum(non_batch_dims) / per_sample

In [5]:
# Toy activations: 5 rows of 6 values, with two trailing singleton axes added
# so the tensor is 4-D. Because of the indexing, `h` is a non-leaf view of the
# tensor that requires grad.
h = torch.tensor(
    [
        [1, 1, 0,  1, 1, 0],
        [0, 0, 0,  0, 0, 0],
        [1, 0, 1,  0, 1, 0],
        [1, 1, 1,  0, 0, 0],
        [0, 0, 0,  1, 1, 1],
    ],
    dtype=torch.float32,
    requires_grad=True,
)[..., None, None]

# One 0/1 label per row of `h`.
y = torch.tensor([1, 0, 1, 0, 1])

# Constant label vectors with the same length and integer dtype as `y`.
all_neg = torch.zeros_like(y)
all_pos = torch.ones_like(y)

In [6]:
# Confirm the two trailing singleton axes added by the [:, :, None, None] indexing.
h.shape

torch.Size([5, 6, 1, 1])

In [7]:
# Drop the singleton axes to view h as a 5x6 matrix, one row per label in y.
h.reshape(5, 6)

tensor([[1., 1., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<ViewBackward>)

In [8]:
# With every label set to 1, the first three columns of each row come back
# zeroed and the last three are unchanged (compare the output with h).
select(h, all_pos).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<ViewBackward>)

In [26]:
# Mean absolute value of the second half of each row, e.g. row 0 -> (1+1+0)/3.
act(h, all_pos)

tensor([0.6667, 0.0000, 0.3333, 0.0000, 1.0000], grad_fn=<DivBackward0>)

In [10]:
# With every label set to 0, the last three columns of each row come back
# zeroed and the first three are unchanged (compare the output with h).
select(h, all_neg).reshape(5, 6)

tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], grad_fn=<ViewBackward>)

In [28]:
# Mean absolute value of the first half of each row, e.g. row 3 -> (1+1+1)/3.
act(h, all_neg)

tensor([0.6667, 0.0000, 0.6667, 1.0000, 0.0000], grad_fn=<DivBackward0>)

In [12]:
# Sanity check: the all-0 and all-1 selections are complementary, so adding
# them must reproduce h element for element.
(select(h, all_neg) + select(h, all_pos) == h).all().item()

True

In [13]:
# Mixed per-row labels used by the following cells.
y

tensor([1, 0, 1, 0, 1])

In [14]:
# Per-row selection: rows with y == 0 keep their first three columns, rows
# with y == 1 keep their last three; the other half is zeroed (see output).
select(h, y).reshape(5, 6)

tensor([[0., 0., 0., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 0., 0., 1., 1., 1.]], grad_fn=<ViewBackward>)

In [29]:
# Mean absolute value of the half picked by each row's own label.
act(h, y)

tensor([0.6667, 0.0000, 0.3333, 1.0000, 1.0000], grad_fn=<DivBackward0>)

In [16]:
# Sanity check: selecting with y and with the flipped labels (1 - y) is
# complementary, so the two results must add back up to h.
(select(h, y) + select(h, (1-y)) == h).all().item()

True

$$
L_{ACT} =
\sum_{x \in S_0}
\left( |a_0(x) - 1| + |a_1(x)| \right) +
\sum_{x \in S_1}
\left( |a_1(x) - 1| + |a_0(x)| \right)
$$

In [30]:
# For this batch the L_ACT formula sums to 10/3 (rows contribute 1, 1, 4/3, 0, 0).
# The displayed 0.6667 equals that sum divided by the 5 samples, so
# activation_loss presumably averages over the batch -- confirm in model.loss.
activation_loss(h, y)

tensor(0.6667, grad_fn=<DivBackward0>)

In [67]:
# Same toy batch as before, except that the last row now holds large and
# negative values. Note that this rebinds the notebook-level names h and y.
h = torch.tensor(
    [
        [1, 1, 0,     1,    1,    0],
        [0, 0, 0,     0,    0,    0],
        [1, 0, 1,     0,    1,    0],
        [1, 1, 1,     0,    0,    0],
        [0, 0, 3.99,  -0.6, -1.5, 1],
    ],
    dtype=torch.float32,
    requires_grad=True,
)[..., None, None]

y = torch.tensor([1, 0, 1, 0, 1])

# NOTE(review): the displayed result 0.8667 (= 4.3333 / 5) implies the last row
# contributes exactly 1.0, whereas the uncapped formula would give about 1.363
# (|3.1/3 - 1| + 3.99/3). Presumably activation_loss caps the activations at 1
# -- confirm in model.loss.
activation_loss(h, y)

tensor(0.8667, grad_fn=<DivBackward0>)