In [1]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.autograd import Variable

MAX_WIDTH_HEIGHT = 500

In [2]:
def dilat(tensor, dilations):
    """
    tensor  (tensor): Tensor to be dilated
    dilations (list): List of dilation factor for each dim,
                      None for no dilation on this dim

    For each dilated dim of size s, it splits this dim into to (s // dil, dil) dimensions
    Otherwise, does nothing to the dim
    """
    assert len(tensor.shape) == len(dilations)
    assert all(
        (dil is None) or (sh % dil == 0) for dil, sh in zip(dilations, tensor.shape)
    ), "dilation should divide dimension"

    new_dim = []
    content_indices = []
    group_indices = []
    for dim, dilat in zip(tensor.shape, dilations):
        if dilat is None:
            new_dim.append(dim)
        else:
            new_dim.extend([dim // dilat, dilat])

    tensor_dilat = tensor.view(*new_dim)
    return tensor_dilat


def dilated_attention(V, Q, K, dilation=1):
    try:
        x_dilation, y_dilation = dilation
    except:
        x_dilation = y_dilation = dilation

    # V shape: B x W x H x n_head x d_v
    # Q shape: B x W x H x n_head x d_k
    # K shape: B x W x H x n_head x d_k
    batch_size, width, height, n_head, d_v = V.shape
    d_k = Q.shape[-1]

    # B x W/dil x dil x H/dil x dil x n_head x d
    K_dilated = dilat(K, [None, x_dilation, y_dilation, None, None])
    Q_dilated = dilat(Q, [None, x_dilation, y_dilation, None, None])
    V_dilated = dilat(V, [None, x_dilation, y_dilation, None, None])

    # b = batch, h = head, d = dimension
    # each pixel can attend pixels in its own group,
    # i.e. all position at a distance multiple of (x_dil, y_dil)
    # blocks are (x_dil X y_dil) rectangle of the image
    # each pixel of block (x,y) in the group (i,j) attend pixel in same group in different blocks (v,w)
    # a dot product is done over the d dimension
    # the head and batch dimension are kept
    attention_coefficients = torch.einsum("bxiyjhd,bviwjhd->bxiyjhvw", [Q_dilated, K_dilated])
    #attention_coefficients = torch.einsum("bxiyjhd,bviwjhd->bxiyjvwh", [Q_dilated, K_dilated])
    attention_shape = attention_coefficients.size()
    attention_coefficients = attention_coefficients.view(attention_shape[:-2] + (-1,))
    attention_probs = nn.Softmax(dim=-1)(attention_coefficients)
    attention_coefficients = attention_probs.view(attention_shape)



    # the attention_coefficients are used to compute the weighted sum of the values
    # each pixel in block (x,y) and group (i,j) sums the values of
    # the pixes in group (i,j) at any other block position (v,w)
    new_V = torch.einsum("bxiyjhvw,bviwjhd->bxiyjhd", [attention_coefficients, V_dilated])
    new_V = new_V.contiguous().view(batch_size, width, height, n_head*d_v)
    #print(new_V.shape)
    return new_V, attention_coefficients

In [3]:
def get_unfolded(tensor,kernel_size):
    """
    tensor  (tensor): Tensor to be dilated, in shape (batch_size, W, H, nhead*d)
    kernel_size: size of the square to be attended

    output:
        tensor_unf: results in shape (batch, nhead*d, W, H, kernel_size, kernel_size)

    """
    B, W, H, D = tensor.shape
    tensor = tensor.permute(0,3,1,2)
    unf = nn.Unfold(kernel_size=kernel_size, dilation=1, padding=int((kernel_size-1)/2), stride=1)
    tensor_unf = unf(tensor)
    tensor_unf = tensor_unf.view((B, D, W, H , kernel_size, kernel_size))
    return tensor_unf


def local_attention(V, Q, K, kernel_size=5):

    # V shape: B x W x H x n_head x d_v
    # Q shape: B x W x H x n_head x d_k
    # K shape: B x W x H x n_head x d_k
    batch_size, width, height, n_head, d_v = V.shape
    V = V.view((batch_size, width, height, -1))
    #Q = Q.view((batch_size, width, height, -1))
    K = K.view((batch_size, width, height, -1))

    d_k = Q.shape[-1]

    K_field = get_unfolded(K,kernel_size).view((batch_size, n_head, d_v, width, height, kernel_size, kernel_size))
    V_field = get_unfolded(V,kernel_size).view((batch_size, n_head, d_v, width, height, kernel_size, kernel_size))
    #Q_field = get_unfolded(Q)

    # b = batch, h = head, d = dimension
    # each pixel can attend pixels in its own group,
    # i.e. all position at a distance multiple of (x_dil, y_dil)
    # blocks are (x_dil X y_dil) rectangle of the image
    # each pixel of block (x,y) in the group (i,j) attend pixel in same group in different blocks (v,w)
    # a dot product is done over the d dimension
    # the head and batch dimension are kept
    attention_coefficients = torch.einsum("bwhnd,bndwhxy->bwhnxy", [Q, K_field])
    #attention_coefficients = torch.einsum("bxiyjhd,bviwjhd->bxiyjvwh", [Q_dilated, K_dilated])
    attention_shape = attention_coefficients.size()
    attention_coefficients = attention_coefficients.view(attention_shape[:-2] + (-1,))
    attention_probs = nn.Softmax(dim=-1)(attention_coefficients)
    attention_coefficients = attention_probs.view(attention_shape)



    # the attention_coefficients are used to compute the weighted sum of the values
    # each pixel in block (x,y) and group (i,j) sums the values of
    # the pixes in group (i,j) at any other block position (v,w)
    new_V = torch.einsum("bwhnxy,bndwhxy->bwhnd", [attention_coefficients, V_field])
    #print(new_V.shape)
    new_V = new_V.contiguous().view(batch_size, width, height, n_head* d_v)
    return new_V, attention_coefficients

In [4]:
class BertSelfAttentionDilation(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False, dilations=None, kernel_size=5):
        super(BertSelfAttentionDilation, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = output_attentions
        self.keep_multihead_output = keep_multihead_output
        self.multihead_output = None

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dilations = dilations
        self.kernel_size = kernel_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.proj = nn.Linear(4*config.hidden_size, config.hidden_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        if self.dilations is not None:
            q_shape = mixed_query_layer.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
            query_layer = mixed_query_layer.view(*q_shape)
            key_layer = mixed_key_layer.view(*q_shape)
            value_layer = mixed_value_layer.view(*q_shape)
            context_layer_dil, attention_probs = dilated_attention(value_layer, query_layer, key_layer, dilation=self.dilations)
            context_layer_row, attention_probs_row = dilated_attention(value_layer, query_layer, key_layer,
                                                                       dilation=(1, value_layer.shape[2]))
            context_layer_col,attention_probs_col = dilated_attention(value_layer, query_layer, key_layer,
                                                                       dilation=(value_layer.shape[1],1))
            context_layer_local,attention_probs_local = local_attention(value_layer, query_layer, key_layer,
                                                                       self.kernel_size)
            #print(context_layer_dil.shape, context_layer_local.shape, context_layer_row.shape)
            context_layer_cat = torch.cat((context_layer_dil,context_layer_row,context_layer_col,context_layer_local),
                                          dim=-1)
            context_layer = self.proj(context_layer_cat)

            #TODO: Linear projection or weighted sum?
        else:
            query_layer = self.transpose_for_scores(mixed_query_layer)
            key_layer = self.transpose_for_scores(mixed_key_layer)
            value_layer = self.transpose_for_scores(mixed_value_layer)
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

            # Normalize the attention scores to probabilities.
            attention_probs = nn.Softmax(dim=-1)(attention_scores)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)

            # Mask heads if we want to
            if head_mask is not None:
                attention_probs = attention_probs * head_mask

            context_layer = torch.matmul(attention_probs, value_layer)
            if self.keep_multihead_output:
                self.multihead_output = context_layer
                self.multihead_output.retain_grad()

            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
        if self.output_attentions:
            return attention_probs, context_layer
        return context_layer

In [5]:
config = dict(
    dataset="Cifar10",
    model="bert",
    optimizer="SGD",
    optimizer_decay_at_epochs=[150, 250],
    optimizer_decay_with_factor=10.0,
    optimizer_learning_rate=0.1,
    optimizer_momentum=0.9,
    optimizer_weight_decay=0.0001,
    batch_size=16,
    num_epochs=300,
    seed=42,
    # added for BERT, some are useless
    vocab_size_or_config_json_file=-1,
    hidden_size=128,  # 768,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=512,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=512,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    # BERT Image specific
    mask_dimension=5,
)


In [6]:
device = torch.device('cuda:0')

In [7]:
!nvidia-smi

Mon Jul  8 17:28:47 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 375.51                 Driver Version: 375.51                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla K40c          Off  | 0000:04:00.0     Off |                    0 |
| 25%   49C    P0    66W / 235W |      0MiB / 11439MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla K40c          Off  | 0000:84:00.0     Off |                    0 |
| 23%   38C    P0    68W / 235W |      0MiB / 11439MiB |     98%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------

In [8]:
torch.cuda.empty_cache()

In [9]:
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probabilitiy for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The sttdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
        """
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = BertConfig(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file."""
        with open(json_file_path, "w", encoding='utf-8') as writer:
            writer.write(self.to_json_string())

try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
    #logger.info("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .")
    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-12):
            """Construct a layernorm module in the TF style (epsilon inside the square root).
            """
            super(BertLayerNorm, self).__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias


In [10]:
import sys
bert_config = BertConfig.from_dict(config)

In [11]:
attndil = BertSelfAttentionDilation(config=bert_config,dilations=4)
attndil = attndil.to(device)

In [12]:
config

{'dataset': 'Cifar10',
 'model': 'bert',
 'optimizer': 'SGD',
 'optimizer_decay_at_epochs': [150, 250],
 'optimizer_decay_with_factor': 10.0,
 'optimizer_learning_rate': 0.1,
 'optimizer_momentum': 0.9,
 'optimizer_weight_decay': 0.0001,
 'batch_size': 16,
 'num_epochs': 300,
 'seed': 42,
 'vocab_size_or_config_json_file': -1,
 'hidden_size': 128,
 'num_hidden_layers': 4,
 'num_attention_heads': 4,
 'intermediate_size': 512,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'attention_probs_dropout_prob': 0.1,
 'max_position_embeddings': 512,
 'type_vocab_size': 2,
 'initializer_range': 0.02,
 'layer_norm_eps': 1e-12,
 'mask_dimension': 5}

In [13]:
class BertSelfAttention(nn.Module):
    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = output_attentions
        self.keep_multihead_output = keep_multihead_output
        self.multihead_output = None

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        
        #print(query_layer.shape)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        
        #print(attention_scores.shape)
        
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        if self.keep_multihead_output:
            self.multihead_output = context_layer
            self.multihead_output.retain_grad()

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if self.output_attentions:
            return attention_probs, context_layer
        return context_layer


In [14]:
attnos = []
for i in range(12):
    attnos.append(BertSelfAttention(config=bert_config).to(device))


In [15]:
attnos = []
for i in range(12):
    attnos.append(BertSelfAttentionDilation(config=bert_config, dilations=4).to(device))

In [15]:
batch_size = 14
width = 32
height = 32
n_head = 16
d = 128
dil = 4

print("batch_size", batch_size)
print("width", width)
print("height", height)
print("n_head", n_head)
print("d", d)
print("dil", dil)

#batch = torch.rand(batch_size, width, height, d)
#batch_flat = torch.rand(batch_size, width*height, d).to(device)
#attention_mask = torch.ones(batch_size, width, height)


batch_size 14
width 32
height 32
n_head 16
d 128
dil 4


In [41]:
attention_mask_flat = torch.ones(width*height, width*height).to(device)

In [42]:
batch = batch.to(device)
attention_mask = attention_mask.to(device)

In [16]:
from tqdm import tqdm
import math

In [17]:
num_steps = 20

for i in tqdm(range(num_steps)):
    context_layers = []
    attention_mask_flat = torch.ones(width*height, width*height).to(device)
    batch_flat = torch.rand(batch_size, width*height, d).to(device)
    for attnorig in attnos:
        context_layers.append(attnorig(batch_flat, attention_mask_flat))
    

  0%|          | 0/20 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 224.00 MiB (GPU 0; 11.17 GiB total capacity; 10.79 GiB already allocated; 115.75 MiB free; 479.00 KiB cached)

In [18]:
num_steps = 20

for i in tqdm(range(num_steps)):
    context_layers = []
    batch = torch.rand(batch_size, width, height, d).to(device)
    attention_mask = torch.ones(batch_size, width, height).to(device)
    for attnorig in attnos:
        context_layers.append(attnorig(batch, attention_mask))

100%|██████████| 20/20 [00:30<00:00,  1.52s/it]
