In [56]:
import itertools

import tensorflow as tf
from tensorflow.keras import layers

specification = {
    'LeViT_128S': {
        'embed_dim': [128, 256, 384],
        'key_dim': 16,
        'num_heads': [4, 6, 8],
        'depth': [2, 3, 4],
        'drop_path': 0,
        'weights': 'https://huggingface.co/facebook/levit-128S/resolve/main/pytorch_model.bin'
    },
    'LeViT_256': {
        'embed_dim': [256, 384, 512],
        'key_dim': 32,
        'num_heads': [4, 6, 8],
        'depth': [4, 4, 4],
        'drop_path': 0,
        'weights': 'https://huggingface.co/facebook/levit-256/resolve/main/pytorch_model.bin'
    },
    'LeViT_384': {
        'embed_dim': [384, 512, 768],
        'key_dim': 32,
        'num_heads': [6, 9, 12],
        'depth': [4, 4, 4],
        'drop_path': 0.1,
        'weights': 'https://huggingface.co/facebook/levit-384/resolve/main/pytorch_model.bin'
    },
}


@tf.function
def hard_swish(features):
    """Computes a hard version of the swish function.

    This operation can be used to reduce computational cost and improve
    quantization for edge devices.

    Args:
        features: A `Tensor` representing preactivation values.

    Returns:
        The activation value.
    """
    return features * tf.nn.relu6(features + tf.cast(3., features.dtype)) * (1. / 6.)


class Backbone(layers.Layer):
    def __init__(self, out_channels):
        super(Backbone, self).__init__(name='backbone')
        self.convolution_layer1 = layers.Conv2D(filters=out_channels // 8,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer1')
        self.batch_norm1 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm1')
        self.convolution_layer2 = layers.Conv2D(filters=out_channels // 4,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer2')
        self.batch_norm2 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm2')
        self.convolution_layer3 = layers.Conv2D(filters=out_channels // 2,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer3')
        self.batch_norm3 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm3')
        self.convolution_layer4 = layers.Conv2D(filters=out_channels,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer4')
        self.batch_norm4 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm4')

    def call(self, x):
        x = hard_swish(self.batch_norm1(self.convolution_layer1(x)))
        x = hard_swish(self.batch_norm2(self.convolution_layer2(x)))
        x = hard_swish(self.batch_norm3(self.convolution_layer3(x)))
        x = hard_swish(self.batch_norm4(self.convolution_layer4(x)))
        return x


class Residual(layers.Layer):
    def __init__(self, module, drop_rate=0., name='residual'):
        super(Residual, self).__init__(name=name)
        self.module = module
        self.dropout = layers.Dropout(drop_rate)

    def call(self, x, training):
        return x + self.dropout(self.module(x), training=training)


class LinearNorm(layers.Layer):
    def __init__(self, out_channels, bn_weight_init=1, name='linearnorm'):
        super(LinearNorm, self).__init__(name=name)
        self.batch_norm = layers.BatchNormalization(gamma_initializer=tf.constant_initializer(bn_weight_init))
        self.linear = layers.Dense(out_channels, activation=None)

    def call(self, x):
        x = self.linear(x)
        shape = x.get_shape().as_list()
        x = tf.reshape(self.batch_norm(tf.reshape(x, (-1, shape[2]))), shape)
        return x


class Downsample(layers.Layer):
    def __init__(self, stride, resolution, name='downsample'):
        super(Downsample, self).__init__(name=name)
        self.stride = stride
        self.resolution = resolution

    def call(self, x):
        batch_size, _, channels = x.get_shape().as_list()
        x = tf.reshape(x, (batch_size, self.resolution, self.resolution, channels))
        x = x[:, ::self.stride, ::self.stride]
        return tf.reshape(x, (batch_size, -1, channels))


class NormLinear(layers.Layer):
    def __init__(self, out_channels, bias=True, std=0.02, drop=0.0):
        super(NormLinear, self).__init__(name='stem')
        self.batch_norm = layers.BatchNormalization()
        self.dropout = layers.Dropout(drop)
        self.linear = layers.Dense(out_channels,
                                   activation=None,
                                   use_bias=bias,
                                   kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0., stddev=std))

    def call(self, x, training):
        x = self.batch_norm(x)
        x = self.dropout(x, training=training)
        x = self.linear(x)
        return x


class MLP(layers.Layer):
    """
    MLP Layer with `2X` expansion in contrast to ViT with `4X`.
    """
    def __init__(self, input_dim, hidden_dim, name='mlp'):
        super(MLP, self).__init__(name=name)
        self.linear_up = LinearNorm(hidden_dim)
        self.linear_down = LinearNorm(input_dim)

    def call(self, x):
        return self.linear_down(hard_swish(self.linear_up(x)))


class Attention(layers.Layer):
    def __init__(self, input_dim, key_dim, num_attention_heads=8, attention_ratio=4, resolution=14, name='attention'):
        super(Attention, self).__init__(name=name)
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio

        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads

        self.queries_keys_values = LinearNorm(self.out_dim_keys_values)
        self.projection = LinearNorm(input_dim)

    def call(self, x):
        batch_size, seq_length, _ = tf.shape(x)

        queries_keys_values = self.queries_keys_values(x)

        query, key, value = tf.split(tf.reshape(queries_keys_values, (batch_size, seq_length, self.num_attention_heads, -1)), [
                self.key_dim, self.key_dim, self.attention_ratio * self.key_dim],axis=3)

        query = tf.transpose(query, (0, 2, 1, 3))
        key = tf.transpose(key, (0, 2, 1, 3))
        value = tf.transpose(value, (0, 2, 1, 3))
        attention = tf.matmul(query, key, transpose_b=True) * self.scale
        attention = tf.nn.softmax(attention, axis=-1)
        hidden_state = tf.reshape(tf.transpose(tf.matmul(attention, value), (0, 1, 3, 2)),
                                  (batch_size, seq_length, self.out_dim_projection))
        hidden_state = self.projection(hard_swish(hidden_state))
        return hidden_state


class AttentionDownsample(layers.Layer):
    def __init__(self,
                 input_dim,
                 output_dim,
                 key_dim,
                 num_attention_heads,
                 attention_ratio,
                 stride,
                 resolution_in,
                 resolution_out,
                 name='attention_downsample'):
        super(AttentionDownsample, self).__init__(name=name)

        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
        self.resolution_out = resolution_out
        self.resolution_in = resolution_in
        # resolution_in is the intial resolution, resoloution_out is final resolution after downsampling
        self.keys_values = LinearNorm(self.out_dim_keys_values)
        self.queries_subsample = Downsample(stride, resolution_in)
        self.queries = LinearNorm(key_dim * num_attention_heads)
        self.projection = LinearNorm(output_dim)

    def call(self, x):
        batch_size, seq_length, _ = tf.shape(x)

        key, value = tf.split(tf.reshape(self.keys_values(x), (
            batch_size, seq_length, self.num_attention_heads,
            -1)), [self.key_dim, self.attention_ratio * self.key_dim],
                      axis=3)

        key = tf.transpose(key, (0, 2, 1, 3))
        value = tf.transpose(value, (0, 2, 1, 3))

        query = self.queries(self.queries_subsample(x))
        query = tf.transpose(
            tf.reshape(query, (batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim)),
            (0, 2, 1, 3))

        attention = tf.matmul(query, key, transpose_b=True) * self.scale
        attention = tf.nn.softmax(attention, axis=-1)
        x = tf.reshape(tf.transpose(tf.matmul(attention, value), (0, 1, 3, 2)),
                       (batch_size, -1, self.out_dim_projection))
        x = self.projection(hard_swish(x))
        return x


def levit_stage(embed_dim,
                key_dim,
                num_attention_heads,
                resolution,
                depth,
                attention_ratio,
                mlp_ratio,
                drop_path,
                name='stage'):
    stages = []
    for i in range(depth):
        stages.append(
            Residual(
                Attention(input_dim=embed_dim,
                          key_dim=key_dim,
                          num_attention_heads=num_attention_heads,
                          attention_ratio=attention_ratio,
                          resolution=resolution,
                          name=name + '/attention' + str(i)),
                drop_path,
                name=name + '/attention' + str(i) + '/residual'))

        if mlp_ratio > 0:
            h = int(embed_dim * mlp_ratio)
            stages.append(
                Residual(MLP(input_dim=embed_dim, hidden_dim=h, name=name + '/mlp' + str(i)),
                         drop_path,
                         name=name + '/mlp' + str(i) + '/residual'))
    return tf.keras.Sequential(stages, name=name)


def levit_downsample(input_dim,
                     output_dim,
                     resolution,
                     resolution_out,
                     down_ops,
                     drop_path,
                     name='stage_downsample'):
    stages = []
    stages.append(
        AttentionDownsample(input_dim=input_dim,
                            output_dim=output_dim,
                            key_dim=down_ops['key_dim'],
                            num_attention_heads=down_ops['num_heads'],
                            attention_ratio=down_ops['attn_ratio'],
                            stride=down_ops['stride'],
                            resolution_in=resolution,
                            resolution_out=resolution_out,
                            name=name + '/attention'))
    if down_ops['mlp_ratio'] > 0:  # mlp_ratio
        h = int(output_dim * down_ops['mlp_ratio'])
        stages.append(
            Residual(MLP(input_dim=output_dim, hidden_dim=h, name=name + '/mlp'), drop_path, name=name + '/residual'))
    return tf.keras.Sequential(stages, name=name)

class LeVIT(tf.keras.Model):
    def __init__(self,
                 image_dim,
                 patch_size,
                 num_classes,
                 embed_dim=[192],
                 key_dim=[64],
                 depth=[12],
                 num_heads=[3],
                 attention_ratio=[2],
                 mlp_ratio=[2],
                 down_ops={},
                 distillation=True,
                 drop_path=0.,
                 name='Levit'):
        super(LeVIT, self).__init__(name=name)
        input_resolution_stage1 = image_dim // patch_size
        input_resolution_stage2 = (input_resolution_stage1 - 1) // down_ops[1]['stride'] + 1
        input_resolution_stage3 = (input_resolution_stage2 - 1) // down_ops[2]['stride'] + 1

        self.backbone = Backbone(embed_dim[0])

        self.stage1 = levit_stage(embed_dim=embed_dim[0],
                        key_dim=key_dim[0],
                        num_attention_heads=num_heads[0],
                        resolution=input_resolution_stage1,
                        depth=depth[0],
                        attention_ratio=attention_ratio[0],
                        mlp_ratio=mlp_ratio[0],
                        drop_path=drop_path,
                        name='stage1')

        self.stage1_downsample = levit_downsample(input_dim=embed_dim[0],
                                                  output_dim=embed_dim[1],
                                                  resolution=input_resolution_stage1,
                                                  resolution_out=input_resolution_stage2,
                                                  down_ops=down_ops[1],
                                                  drop_path=drop_path,
                                                  name='stage1_downsample')

        self.stage2 = levit_stage(embed_dim=embed_dim[1],
                        key_dim=key_dim[1],
                        num_attention_heads=num_heads[1],
                        resolution=input_resolution_stage2,
                        depth=depth[1],
                        attention_ratio=attention_ratio[1],
                        mlp_ratio=mlp_ratio[1],
                        drop_path=drop_path,
                        name='stage2')

        self.stage2_downsample = levit_downsample(input_dim=embed_dim[1],
                                                  output_dim=embed_dim[2],
                                                  resolution=input_resolution_stage2,
                                                  resolution_out=input_resolution_stage3,
                                                  down_ops=down_ops[2],
                                                  drop_path=drop_path,
                                                  name='stage2_downsample')

        self.stage3 = levit_stage(embed_dim=embed_dim[2],
                                  key_dim=key_dim[2],
                                  num_attention_heads=num_heads[2],
                                  resolution=input_resolution_stage3,
                                  depth=depth[2],
                                  attention_ratio=attention_ratio[2],
                                  mlp_ratio=mlp_ratio[2],
                                  drop_path=drop_path,
                                  name='stage3')

        self.class_head = NormLinear(num_classes)

    def call(self, x):
        x = self.backbone(x)
        batch_size, _, _, channels = tf.shape(x)
        x = tf.reshape(x, (batch_size, -1, channels))
        x = self.stage1(x)
        x = self.stage1_downsample(x)
        x = self.stage2(x)
        x = self.stage2_downsample(x)
        x = self.stage3(x)
        x = tf.math.reduce_mean(x, -1)
        x = self.class_head(x)
        return x



def levit(input_shape,
          patch_size,
          num_classes,
          embed_dim=[192],
          key_dim=[64],
          depth=[12],
          num_heads=[3],
          attention_ratio=[2],
          mlp_ratio=[2],
          down_ops={},
          distillation=True,
          drop_path=0.):
    inputs = layers.Input(shape=input_shape)
    input_resolution_stage1 = input_shape[0] // patch_size
    input_resolution_stage2 = (input_resolution_stage1 - 1) // down_ops[1]['stride'] + 1
    input_resolution_stage3 = (input_resolution_stage2 - 1) // down_ops[2]['stride'] + 1

    x = Backbone(embed_dim[0])(inputs)
    batch_size, input_resolution_stage1, input_resolution_stage1, channels = tf.shape(x)
    x = tf.reshape(x, (batch_size, input_resolution_stage1 * input_resolution_stage1, channels))

    x = levit_stage(embed_dim=embed_dim[0],
                    key_dim=key_dim[0],
                    num_attention_heads=num_heads[0],
                    resolution=input_resolution_stage1,
                    depth=depth[0],
                    attention_ratio=attention_ratio[0],
                    mlp_ratio=mlp_ratio[0],
                    drop_path=drop_path,
                    name='stage1')(x)

    x = levit_downsample(input_dim=embed_dim[0],
                         output_dim=embed_dim[1],
                         resolution=input_resolution_stage1,
                         resolution_out=input_resolution_stage2,
                         down_ops=down_ops[1],
                         drop_path=drop_path,
                         name='stage1_downsample')(x)

    x = levit_stage(embed_dim=embed_dim[1],
                    key_dim=key_dim[1],
                    num_attention_heads=num_heads[1],
                    resolution=input_resolution_stage2,
                    depth=depth[1],
                    attention_ratio=attention_ratio[1],
                    mlp_ratio=mlp_ratio[1],
                    drop_path=drop_path,
                    name='stage2')(x)

    x = levit_downsample(input_dim=embed_dim[1],
                         output_dim=embed_dim[2],
                         resolution=input_resolution_stage2,
                         resolution_out=input_resolution_stage3,
                         down_ops=down_ops[2],
                         drop_path=drop_path,
                         name='stage2_downsample')(x)

    x = levit_stage(embed_dim=embed_dim[2],
                    key_dim=key_dim[2],
                    num_attention_heads=num_heads[2],
                    resolution=input_resolution_stage3,
                    depth=depth[2],
                    attention_ratio=attention_ratio[2],
                    mlp_ratio=mlp_ratio[2],
                    drop_path=drop_path,
                    name='stage3')(x)

    x = tf.math.reduce_mean(x, -1)
    x = NormLinear(num_classes)(x)

    return tf.keras.Model(inputs=inputs, outputs=x)


def LeViT_128S(image_dim, num_classes=1000, distillation=False, pretrained=False):
    return model_factory(image_dim=image_dim,
                         **specification['LeViT_128S'],
                         num_classes=num_classes,
                         distillation=distillation,
                         pretrained=pretrained)


def LeViT_256(image_dim, num_classes=1000, distillation=False, pretrained=False):
    return model_factory(image_dim=image_dim,
                         **specification['LeViT_256'],
                         num_classes=num_classes,
                         distillation=distillation,
                         pretrained=pretrained)


def LeViT_384(image_dim, num_classes=1000, distillation=False, pretrained=False):
    return model_factory(image_dim=image_dim,
                         **specification['LeViT_384'],
                         num_classes=num_classes,
                         distillation=distillation,
                         pretrained=pretrained)


def model_factory(image_dim,
                  embed_dim,
                  key_dim,
                  depth,
                  num_heads,
                  drop_path,
                  weights,
                  num_classes,
                  distillation,
                  pretrained):
    model = LeVIT(
        image_dim,
        patch_size=16,
        embed_dim=embed_dim,
        num_heads=num_heads,
        key_dim=[key_dim] * 3,
        depth=depth,
        attention_ratio=[2, 2, 2],
        mlp_ratio=[2, 2, 2],
        down_ops={
            1: {
                'key_dim': key_dim, 'num_heads': embed_dim[0] // key_dim, 'attn_ratio': 4, 'mlp_ratio': 2, 'stride': 2
            },
            2: {
                'key_dim': key_dim, 'num_heads': embed_dim[1] // key_dim, 'attn_ratio': 4, 'mlp_ratio': 2, 'stride': 2
            },
        },
        num_classes=num_classes,
        drop_path=drop_path,
        distillation=distillation)

    return model

In [59]:
from levit_torch import LeViT_128S, LeViT_256, LeViT_384

In [62]:
from levit_torch import LeViT_128S, LeViT_256, LeViT_384
model_128 = LeViT_128S(num_classes=10)
print(sum(p.numel() for p in model_128.parameters() if p.requires_grad))
model_256 = LeViT_256(num_classes=10)
print(sum(p.numel() for p in model_256.parameters() if p.requires_grad))
model_384 = LeViT_384(num_classes=10)
print(sum(p.numel() for p in model_384.parameters() if p.requires_grad))

7010140
17871982
37596990


In [68]:
attention_biases = tf.Variable(tf.zeros((10, len([10, 20]))))


In [65]:
from levit_tf import LeViT_128S, LeViT_256, LeViT_384
import numpy as np

model_128 = LeViT_128S(image_dim=224, num_classes=10)
model_128(tf.ones((4, 224, 224, 3)))
print(np.sum([np.prod(v.get_shape().as_list()) for v in model_128.trainable_variables]))
model_256 = LeViT_256(image_dim=224, num_classes=10)
model_256(tf.ones((4, 224, 224, 3)))
print(np.sum([np.prod(v.get_shape().as_list()) for v in model_256.trainable_variables]))
model_384 = LeViT_384(image_dim=224, num_classes=10)
model_384(tf.ones((4, 224, 224, 3)))
print(np.sum([np.prod(v.get_shape().as_list()) for v in model_384.trainable_variables]))

7019226
17894122
37628538


In [None]:
        model = load_bit_m_r50_1(framework="tf")
        self.assertEqual(, 23496256)

    def test_bit_torch(self):
        model = load_bit_m_r50_1(framework="torch")
        self.assertEqual(, 23496256)

In [57]:
model = LeViT_128S(image_dim=224, num_classes=10)

In [58]:
output = model(tf.ones((4, 224, 224, 3)))
print(output.shape)

(4, 10)


In [8]:
embed_dim = [256, 384, 512]
key_dim = [32, 32, 32]
num_heads = [4, 6, 8]
depth =  [4, 4, 4]
drop_path =  0
attention_ratio=[2, 2, 2]
mlp_ratio=[2, 2, 2]
down_ops=[
    #('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
    [32, embed_dim[0] // 32, 4, 2, 2],
    [32, embed_dim[1] // 32, 4, 2, 2],
]

In [19]:
x = Backbone(embed_dim[0])(tf.ones((4, 224, 224, 3)))

In [20]:
batch_size, _, _ , channels = tf.shape(x)

In [21]:
input_shape = [224, 224, 3]
patch_size = 16
input_resolution_stage1 = input_shape[0] // patch_size
input_resolution_stage2 = (input_resolution_stage1 - 1) // down_ops[0][4] + 1
input_resolution_stage3 = (input_resolution_stage2 - 1) // down_ops[1][4] + 1

In [22]:
x = Backbone(embed_dim[0])(tf.ones((None, 224, 224, 3)))
x = tf.reshape(x, (None, -1, 256))
x = levit_stage(embed_dim=embed_dim[0],
                key_dim=key_dim[0],
                num_attention_heads=num_heads[0],
                resolution=input_resolution_stage1,
                depth=depth[0],
                attention_ratio=attention_ratio[0],
                mlp_ratio=mlp_ratio[0],
                drop_path=drop_path,
                name='stage1')(x)

x = levit_downsample(input_dim=embed_dim[0],
                         output_dim=embed_dim[1],
                         resolution=input_resolution_stage1,
                         resolution_out=input_resolution_stage2,
                         key_dim=down_ops[0][0],
                         num_heads=down_ops[0][1],
                         attn_ratio=down_ops[0][2],
                         mlp_ratio=down_ops[0][3],
                         stride=down_ops[0][4],
                         drop_path=drop_path,
                         name='stage1_downsample')(x)

x = levit_stage(embed_dim=embed_dim[1], key_dim=key_dim[1], num_attention_heads=num_heads[1], resolution=input_resolution_stage2, depth=depth[1], attention_ratio=attention_ratio[1], mlp_ratio=mlp_ratio[1], drop_path=drop_path, name='stage2')(x)

x = levit_downsample(input_dim=embed_dim[1],
                        output_dim=embed_dim[2],
                        resolution=input_resolution_stage2,
                        resolution_out=input_resolution_stage3,
                        key_dim=down_ops[1][0],
                        num_heads=down_ops[1][1],
                        attn_ratio=down_ops[1][2],
                        mlp_ratio=down_ops[1][3],
                        stride=down_ops[1][4],
                        drop_path=drop_path,
                        name='stage2_downsample')(x)

x = levit_stage(embed_dim=embed_dim[2], key_dim=key_dim[2], num_attention_heads=num_heads[2], resolution=input_resolution_stage3, depth=depth[2], attention_ratio=attention_ratio[2], mlp_ratio=mlp_ratio[2], drop_path=drop_path, name='stage3')(x)

ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.

In [306]:
x = tf.math.reduce_mean(x, -1)
print(x.shape)
x = NormLinear(10)(x)

(4, 16)


In [304]:
x.shape

TensorShape([4, 10])

In [290]:
downsample = AttentionDownsample(input_dim=embed_dim[0],
                                 output_dim=embed_dim[1],
                                 key_dim=down_ops[0][0],
                                 num_attention_heads=down_ops[0][1],
                                 attention_ratio=down_ops[0][2],
                                 stride=down_ops[0][4],
                                 resolution_in=input_resolution_stage1,
                                 resolution_out=input_resolution_stage2)


attention_ratio 4
key_dim 32
num_attention_heads 8
out_dim_keys_values 1280
out_dim_projection 1024


In [200]:
hidden_state.shape

TensorShape([4, 196, 256])

In [186]:
query.shape, key.shape, value.shape

(TensorShape([4, 4, 196, 32]),
 TensorShape([4, 4, 196, 32]),
 TensorShape([4, 4, 196, 64]))

In [198]:
hidden_state.shape

TensorShape([4, 196, 256])

InvalidArgumentError: Exception encountered when calling layer 'attention' (type Attention).

{{function_node __wrapped__Transpose_device_/job:localhost/replica:0/task:0/device:CPU:0}} transpose expects a vector of size 4. But input(1) is a vector of size 2 [Op:Transpose]

Call arguments received by layer 'attention' (type Attention):
  • x=tf.Tensor(shape=(4, 196, 256), dtype=float32)

In [159]:
x.shape

TensorShape([4, 14, 14, 256])

<tf.Tensor: shape=(4, 196, 256), dtype=float32, numpy=
array([[[-9.5882276e-03, -2.1390039e-03,  2.2422159e-03, ...,
          6.2611694e-03,  8.9580305e-03,  2.3321765e-03],
        [-9.5882276e-03, -2.1390039e-03,  2.2422159e-03, ...,
          6.2611694e-03,  8.9580305e-03,  2.3321765e-03],
        [-9.5882276e-03, -2.1390039e-03,  2.2422159e-03, ...,
          6.2611694e-03,  8.9580305e-03,  2.3321765e-03],
        ...,
        [ 4.9301329e-05,  2.9083407e-03,  4.7255424e-03, ...,
          3.0111913e-03,  5.1256054e-04, -4.8995493e-03],
        [ 4.9301329e-05,  2.9083407e-03,  4.7255424e-03, ...,
          3.0111913e-03,  5.1256054e-04, -4.8995493e-03],
        [-2.9130569e-03, -2.9273028e-03,  1.0818355e-03, ...,
         -2.9805668e-03,  4.4568088e-03, -9.3067074e-03]],

       [[-9.5882276e-03, -2.1390039e-03,  2.2422159e-03, ...,
          6.2611694e-03,  8.9580305e-03,  2.3321765e-03],
        [-9.5882276e-03, -2.1390039e-03,  2.2422159e-03, ...,
          6.2611694e-03,  8.

In [3]:



x = levit_downsample(input_dim=embed_dim[0], output_dim=embed_dim[1], resolution=input_resolution_stage1, resolution_out=input_resolution_stage2, key_dim=down_ops[0][0], num_heads=down_ops[0][1], attn_ratio=down_ops[0][2], mlp_ratio=down_ops[0][3], stride=down_ops[0][4], drop_path=drop_path)(x)

x = levit_stage(embed_dim=embed_dim[1], key_dim=key_dim[1], num_attention_heads=num_heads[1], resolution=input_resolution_stage2, depth=depth[1], attention_ratio=attention_ratio[1], mlp_ratio=mlp_ratio[1], drop_path=drop_path, name='stage2')(x)

x = levit_downsample(input_dim=embed_dim[1], output_dim=embed_dim[2], resolution=input_resolution_stage2, resolution_out=input_resolution_stage3, key_dim=down_ops[1][0], num_heads=down_ops[1][1], attn_ratio=down_ops[1][2], mlp_ratio=down_ops[1][3], stride=down_ops[1][4], drop_path=drop_path)(x)

x = levit_stage(embed_dim=embed_dim[2], key_dim=key_dim[2], num_attention_heads=num_heads[2], resolution=input_resolution_stage3, depth=depth[2], attention_ratio=attention_ratio[2], mlp_ratio=mlp_ratio[2], drop_path=drop_path, name='stage3')(x)

x = tf.math.reduce_mean(x, -1)
x = NormLinear(num_classes)(x)

In [72]:
import tensorflow as tf 
tf.ones((10, 224, 224, 10)).softmax(axis=1)

AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'softmax'

In [9]:
import torch 
input_data = torch.ones(4, 3, 224, 224)

In [12]:
patch_embded_output = model.patch_embed(input_data)

In [73]:
patch_embded_output.shape

torch.Size([4, 256, 14, 14])

In [81]:
patch_embded_output1 = patch_embded_output.flatten(2).transpose(1, 2)

In [28]:
patch_embded_output1.shape

torch.Size([4, 196, 256])

In [25]:
class LinearNorm(torch.nn.Module):
    def __init__(self, in_features, out_features, bn_weight_init=1):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
        self.batch_norm = torch.nn.BatchNorm1d(out_features)
        torch.nn.init.constant_(self.batch_norm.weight, bn_weight_init)

    def forward(self, x):
        x = self.linear(x)
        return self.batch_norm(x.flatten(0, 1)).reshape_as(x)

queries_keys_values = LinearNorm(256, 512)

In [27]:
queries_keys_values(patch_embded_output1).shape

torch.Size([4, 196, 512])

In [124]:
data = tf.reshape(tf.ones((4, 196, 512)), (4, 196, 4, -1))

In [31]:
reshape_output = queries_keys_values(patch_embded_output1).view(4, 196, 4, -1)


In [91]:
reshape_output.shape

torch.Size([4, 196, 4, 128])

In [33]:
query, key, value = reshape_output.split([32, 32, 2 * 32], dim=3)


In [127]:
query, key, value = tf.split(data, [32, 32, 2 * 32], axis=3)


In [34]:
query.shape, key.shape, value.shape

(torch.Size([4, 196, 4, 32]),
 torch.Size([4, 196, 4, 32]),
 torch.Size([4, 196, 4, 64]))

In [128]:
query.shape, key.shape, value.shape

(TensorShape([4, 196, 4, 32]),
 TensorShape([4, 196, 4, 32]),
 TensorShape([4, 196, 4, 64]))

In [35]:
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)

In [36]:
query.shape, key.shape, value.shape

(torch.Size([4, 4, 196, 32]),
 torch.Size([4, 4, 196, 32]),
 torch.Size([4, 4, 196, 64]))

In [39]:
key.transpose(-2, -1).shape

torch.Size([4, 4, 32, 196])

In [43]:
attention = query @ key.transpose(-2, -1) * 32**-0.5

In [44]:
attention.shape

torch.Size([4, 4, 196, 196])

In [42]:
32**-0.5

0.1767766952966369

In [45]:
import itertools
resolution = 14
points = list(itertools.product(range(resolution), range(resolution)))

len_points = len(points)
attention_offsets, indices = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        indices.append(attention_offsets[offset])

In [92]:
len(indices)

38416

In [98]:
from tensorflow.keras import layers

@tf.function
def hard_swish(features):
    """Computes a hard version of the swish function.

    This operation can be used to reduce computational cost and improve
    quantization for edge devices.

    Args:
        features: A `Tensor` representing preactivation values.

    Returns:
        The activation value.
    """
    return features * tf.nn.relu6(features + tf.cast(3., features.dtype)) * (1. / 6.)

class Backbone(layers.Layer):
    def __init__(self, out_channels):
        super(Backbone, self).__init__(name='backbone')
        self.convolution_layer1 = layers.Conv2D(filters=out_channels // 8,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer1')
        self.batch_norm1 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm1')
        self.convolution_layer2 = layers.Conv2D(filters=out_channels // 4,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer2')
        self.batch_norm2 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm2')
        self.convolution_layer3 = layers.Conv2D(filters=out_channels // 2,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer3')
        self.batch_norm3 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm3')
        self.convolution_layer4 = layers.Conv2D(filters=out_channels,
                                                kernel_size=3,
                                                strides=2,
                                                padding="same",
                                                use_bias=False,
                                                name='convolution_layer4')
        self.batch_norm4 = layers.BatchNormalization(gamma_initializer='ones', name='batch_norm4')

    def call(self, x):
        x = hard_swish(self.batch_norm1(self.convolution_layer1(x)))
        x = hard_swish(self.batch_norm2(self.convolution_layer2(x)))
        x = hard_swish(self.batch_norm3(self.convolution_layer3(x)))
        x = hard_swish(self.batch_norm4(self.convolution_layer4(x)))
        return x

In [99]:
model = Backbone(256)

In [105]:
for weight in model.weights:
    print(weight.name)

backbone/convolution_layer1/kernel:0
backbone/batch_norm1/gamma:0
backbone/batch_norm1/beta:0
backbone/convolution_layer2/kernel:0
backbone/batch_norm2/gamma:0
backbone/batch_norm2/beta:0
backbone/convolution_layer3/kernel:0
backbone/batch_norm3/gamma:0
backbone/batch_norm3/beta:0
backbone/convolution_layer4/kernel:0
backbone/batch_norm4/gamma:0
backbone/batch_norm4/beta:0
backbone/batch_norm1/moving_mean:0
backbone/batch_norm1/moving_variance:0
backbone/batch_norm2/moving_mean:0
backbone/batch_norm2/moving_variance:0
backbone/batch_norm3/moving_mean:0
backbone/batch_norm3/moving_variance:0
backbone/batch_norm4/moving_mean:0
backbone/batch_norm4/moving_variance:0


In [101]:
model(tf.ones((3, 224, 224, 3))).shape

TensorShape([3, 14, 14, 256])

In [52]:
torch.LongTensor(indices).view(len_points, len_points)

tensor([[  0,   1,   2,  ..., 193, 194, 195],
        [  1,   0,   1,  ..., 192, 193, 194],
        [  2,   1,   0,  ..., 191, 192, 193],
        ...,
        [193, 192, 191,  ...,   0,   1,   2],
        [194, 193, 192,  ...,   1,   0,   1],
        [195, 194, 193,  ...,   2,   1,   0]])

In [64]:
import tensorflow as tf
data = tf.reshape(indices, (len_points, len_points))

In [66]:
data.backing_device

'/job:localhost/replica:0/task:0/device:CPU:0'

In [57]:
attention_biases = torch.nn.Parameter(torch.zeros(4, len(attention_offsets)))


In [59]:
attention_biases.shape

torch.Size([4, 196])

In [60]:
attention_bias_cache = {}

In [62]:
attention_bias_cache['cpu'] = attention_biases[:, torch.LongTensor(indices).view(len_points, len_points)]

In [61]:
torch.LongTensor(indices).view(len_points, len_points)

tensor([[  0,   1,   2,  ..., 193, 194, 195],
        [  1,   0,   1,  ..., 192, 193, 194],
        [  2,   1,   0,  ..., 191, 192, 193],
        ...,
        [193, 192, 191,  ...,   0,   1,   2],
        [194, 193, 192,  ...,   1,   0,   1],
        [195, 194, 193,  ...,   2,   1,   0]])

In [63]:
attention_bias_cache['cpu']

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 