In [331]:
import math

import tensorflow as tf
import torch
import torch as tc
import torch.nn as nn
import torch.nn.functional as F

# Lookup table from an activation name (e.g. config.hidden_act) to the
# PyTorch function implementing it.
ACT2FN = dict(
    relu=F.relu,
    gelu=F.gelu,
    tanh=torch.tanh,
    sigmoid=torch.sigmoid,
)

# Seed TensorFlow's global RNG so TF random ops below (e.g. tf.random.uniform) are
# reproducible; PyTorch is seeded separately with tc.random.manual_seed later on.
tf.random.set_seed(42)

In [32]:
# Token ids for a batch of 8 sequences of length 30, drawn from [1, 1000) (TF side).
inp = tf.random.uniform(shape=(8, 30), minval=1, maxval=1000, dtype=tf.int32)

In [122]:
# Keras embedding table: 1000 ids -> 768-dim vectors.
embedding = tf.keras.layers.Embedding(1000, 768)

In [123]:
# Keras layers are channels-last: (batch, seq_len, hidden).
embed = embedding(inp)
embed.shape

TensorShape([8, 30, 768])

In [126]:
# A kernel_size=1 Conv1D mixes channels independently at each position: 768 -> 3072.
# Note this TF conv applies ReLU, while the PyTorch conv2 below has no activation.
conv = tf.keras.layers.Conv1D(filters=3072, kernel_size=1, strides=(1, ), padding="same", activation='relu')

In [127]:
conv(embed).shape

TensorShape([8, 30, 3072])

In [210]:
# Same experiment in PyTorch. randint's upper bound is exclusive, so ids are in [1, 1000) as above.
tc.random.manual_seed(42)
inp2 = tc.randint(1, 1000, (8, 30), dtype=tc.long)

In [211]:
# PyTorch embedding table with the same sizes as the Keras one (independently initialised).
embedding2 = tc.nn.Embedding(1000, 768)

In [231]:
embed2 = embedding2(inp2)
embed2.shape

torch.Size([8, 30, 768])

In [232]:
# Positional args: in_channels=768, out_channels=3072, kernel_size=1, stride=(1,).
conv2 = nn.Conv1d(768, 3072, 1, (1,), groups=1)

In [319]:
# nn.Conv1d is channels-first, so (batch, seq_len, hidden) is permuted to (batch, hidden, seq_len);
# the output stays channels-first: (batch, 3072, seq_len).
conv2(embed2.permute(0, 2, 1)).shape

torch.Size([8, 3072, 30])

In [154]:
class SqueezeBertConfig:
    """Hyper-parameters for the SqueezeBERT modules built in this notebook.

    Args:
        vocab_size: number of rows in the word-embedding table.
        hidden_size: channel width used by the attention and projection blocks.
        num_hidden_layers: stored only; not read by the cells shown here.
        num_attention_heads: heads in self-attention; must divide the attention's ``cin``.
        intermediate_size: channel width of the feed-forward up-projection.
        hidden_act: key into ``ACT2FN`` used by ``ConvActivation``.
        hidden_dropout_prob: dropout for the embeddings and the conv projection blocks.
        attention_probs_dropout_prob: dropout applied to the attention probabilities.
        max_position_embeddings: number of rows in the position-embedding table.
        type_vocab_size: number of rows in the token-type embedding table.
        initializer_range: stored only; not read by the cells shown here.
        layer_norm_eps: epsilon of the embeddings' LayerNorm.
        pad_token_id: ``padding_idx`` of the word-embedding table.
        embedding_size: width of the embedding tables.
        q_groups, k_groups, v_groups: ``groups`` of the query/key/value 1x1 convs.
        post_attention_groups: ``groups`` of the post-attention projection conv.
        intermediate_groups: ``groups`` of the feed-forward up-projection conv.
        output_groups: ``groups`` of the feed-forward down-projection conv.
        **kwargs: any additional settings; each is stored as an attribute of the same name.
    """

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        embedding_size=768,
        q_groups=4,
        k_groups=4,
        v_groups=4,
        post_attention_groups=1,
        intermediate_groups=4,
        output_groups=4,
        **kwargs
    ):
        # object.__init__ accepts no keyword arguments, so forwarding **kwargs to it made
        # any extra setting raise TypeError. Keep them as plain attributes instead.
        super().__init__()
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.pad_token_id = pad_token_id
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.embedding_size = embedding_size
        self.q_groups = q_groups
        self.k_groups = k_groups
        self.v_groups = v_groups
        self.post_attention_groups = post_attention_groups
        self.intermediate_groups = intermediate_groups
        self.output_groups = output_groups

In [155]:
# Default hyper-parameters shared by every SqueezeBERT module built below.
config = SqueezeBertConfig()

In [157]:
class SqueezeBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embedding lookups are summed element-wise, then LayerNorm and dropout
    are applied. The result is in (batch, seq_len, channels) layout.

    Note: the tables are ``config.embedding_size`` wide but LayerNorm is built with
    ``config.hidden_size``, so the two must be equal.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Embed a batch of token ids (or pre-computed word embeddings).

        Args:
            input_ids: LongTensor of shape (batch, seq_len); used for the word lookup
                when ``inputs_embeds`` is None.
            token_type_ids: optional, shape (batch, seq_len); defaults to all zeros.
            position_ids: optional; defaults to ``0 .. seq_len - 1`` broadcast over the batch.
            inputs_embeds: optional (batch, seq_len, embedding_size) tensor that replaces
                the word-embedding lookup.

        Returns:
            Tensor of shape (batch, seq_len, embedding_size).

        Raises:
            ValueError: if both ``input_ids`` and ``inputs_embeds`` are None, or if no
                ``position_ids`` are given and seq_len exceeds ``max_position_embeddings``.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]

        if position_ids is None:
            max_positions = self.position_ids.size(1)
            # Slicing past the table would silently give too few positions and fail
            # later with an opaque broadcast error; report the real cause instead.
            if seq_length > max_positions:
                raise ValueError(
                    "Sequence length (%d) exceeds max_position_embeddings (%d)" % (seq_length, max_positions)
                )
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [214]:
# Embedding module from the default config (30522-entry vocab, 768-dim output).
sbe = SqueezeBertEmbeddings(config)

In [215]:
# NOTE: sbe was not put in eval() mode, so its dropout is active and ebd is stochastic.
# inp2 ids are below 1000, within the vocabulary; output layout is (batch, seq_len, hidden).
ebd = sbe(inp2)
ebd.shape

torch.Size([8, 30, 768])

In [386]:
class SqueezeBertSelfAttention(nn.Module):
    """Multi-head self-attention with 1x1 (optionally grouped) Conv1d projections.

    All activations use the channels-first [N, C, W] layout (batch, channels,
    sequence length), which is why Q/K/V are ``nn.Conv1d(kernel_size=1)`` layers.
    Heads are formed by splitting C into C1 = num_attention_heads chunks of
    C2 = attention_head_size channels each.

    Args:
        config: provides ``num_attention_heads`` and ``attention_probs_dropout_prob``.
        cin: input (and output) channels; must be a multiple of the number of heads.
        q_groups, k_groups, v_groups: ``groups`` of the query/key/value convolutions
            (``nn.Conv1d`` requires ``cin`` to be divisible by each of them).

    Raises:
        ValueError: if ``cin`` is not a multiple of ``config.num_attention_heads``.
    """

    def __init__(self, config, cin, q_groups=1, k_groups=1, v_groups=1):
        super().__init__()
        if cin % config.num_attention_heads != 0:
            raise ValueError(
                "cin (%d) is not a multiple of the number of attention "
                "heads (%d)" % (cin, config.num_attention_heads)
            )
        self.num_attention_heads = config.num_attention_heads
        # Exact integer division: divisibility was checked just above.
        self.attention_head_size = cin // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=q_groups)
        self.key = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=k_groups)
        self.value = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=v_groups)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.softmax = nn.Softmax(dim=-1)

        self.matmul_qk = MatMulWrapper()
        self.matmul_qkv = MatMulWrapper()

    def transpose_for_scores(self, x):
        """
        - input: [N, C, W]
        - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
        """
        new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1])  # [N, C1, C2, W]
        x = x.view(*new_x_shape)
        return x.permute(0, 1, 3, 2)  # [N, C1, C2, W] --> [N, C1, W, C2]

    def transpose_key_for_scores(self, x):
        """
        - input: [N, C, W]
        - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
        """
        new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1])  # [N, C1, C2, W]
        x = x.view(*new_x_shape)
        # no `permute` needed: [N, C1, W, C2] @ [N, C1, C2, W] already yields [N, C1, W, W]
        return x

    def transpose_output(self, x):
        """
        - input: [N, C1, W, C2]
        - output: [N, C, W]
        """
        x = x.permute(0, 1, 3, 2).contiguous()  # [N, C1, C2, W]
        new_x_shape = (x.size()[0], self.all_head_size, x.size()[3])  # [N, C, W]
        x = x.view(*new_x_shape)
        return x

    def forward(self, hidden_states, attention_mask, output_attentions=True):
        """Attend over the sequence dimension of ``hidden_states``.

        Args:
            hidden_states: tensor in [N, C, W] layout.
            attention_mask: added as-is to the scaled [N, C1, W, W] scores before softmax,
                so it must be an additive mask broadcastable to that shape (e.g. [N, 1, 1, W]
                with 0 for positions to keep and a large negative value for positions to drop).
            output_attentions: when True, also return the pre-softmax scores.

        Returns:
            dict with ``"context_layer"`` ([N, C, W]) and, if ``output_attentions`` is True,
            ``"attention_score"`` ([N, C1, W, W], scaled and masked, before softmax).
        """
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)  # [N, C1, W, C2]
        key_layer = self.transpose_key_for_scores(mixed_key_layer)  # [N, C1, C2, W]
        value_layer = self.transpose_for_scores(mixed_value_layer)  # [N, C1, W, C2]

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_score = self.matmul_qk(query_layer, key_layer)  # [N, C1, W, W]
        attention_score = attention_score / math.sqrt(self.attention_head_size)
        # Apply the additive attention mask (broadcast over heads / query positions).
        attention_score = attention_score + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = self.softmax(attention_score)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = self.matmul_qkv(attention_probs, value_layer)  # [N, C1, W, C2]
        context_layer = self.transpose_output(context_layer)  # [N, C, W]

        result = {"context_layer": context_layer}
        if output_attentions:
            result["attention_score"] = attention_score
        return result

In [395]:
class MatMulWrapper(torch.nn.Module):
    """Parameter-free ``nn.Module`` around ``torch.matmul``.

    Used as the ``matmul_qk`` / ``matmul_qkv`` submodules of SqueezeBertSelfAttention.
    """

    def forward(self, mat1, mat2):
        """Return ``torch.matmul(mat1, mat2)``."""
        product = torch.matmul(mat1, mat2)
        return product

In [396]:
# Channel widths through one encoder layer: attention in/out (c0 -> c1),
# intermediate up-projection (c1 -> c2), output projection (c2 -> c3).
c0 = config.hidden_size
c1 = config.hidden_size
c2 = config.intermediate_size
c3 = config.hidden_size

In [397]:
# Self-attention over c0 channels with grouped Q/K/V 1x1 convs (config.*_groups).
attention = SqueezeBertSelfAttention(
            config=config, cin=c0, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups)

In [390]:
# Mask of ones with one entry per token: shape (batch, seq_len) = inp2.shape.
input_shape = inp2.shape
device = ebd.device
am = torch.ones(input_shape, device=device)
am.shape

torch.Size([8, 30])

In [391]:
# SqueezeBertSelfAttention.forward adds the mask straight onto its raw scores, so the
# (batch, seq_len) 1/0 keep-mask must become an additive mask: 0 where am == 1 (keep),
# -10000 where am == 0 (drop). Passing `am` itself would add +1 to kept positions and 0 to
# dropped ones, which masks nothing. Shape (batch, 1, 1, seq_len) broadcasts over heads
# and query positions; with the all-ones `am` above this is all zeros.
eam = (1.0 - am[:, None, None, :]) * -10000.0
eam.shape

torch.Size([8, 1, 1, 30])

In [392]:
# (batch, seq_len, hidden) -> (batch, hidden, seq_len): the conv-based modules expect [N, C, W].
hidden_states = ebd.permute(0, 2, 1)
hidden_states.shape

torch.Size([8, 768, 30])

In [561]:
# Stand-alone grouped 1x1 conv like the Q/K/V projections: groups=4 splits 768 channels into 4 x 192.
nc = nn.Conv1d(in_channels=768, out_channels=768, kernel_size=1, stride=(1,), groups=4)
nc(hidden_states).shape

torch.Size([8, 768, 30])

In [491]:
# PyTorch Conv1d weight layout: (out_channels, in_channels // groups, kernel_size).
nc.weight.shape

torch.Size([768, 192, 1])

In [393]:
# NOTE(review): the two "... shape:" lines recorded below came from print() calls that were
# in SqueezeBertSelfAttention.forward at the time this cell ran.
ot = attention(hidden_states, attention_mask=eam)

hidden_states shape:  torch.Size([8, 768, 30])
mixed_query_layer shape:  torch.Size([8, 768, 30])


In [384]:
# Context comes back in the input's [N, C, W] layout.
ot["context_layer"].shape

torch.Size([8, 768, 30])

In [385]:
# Pre-softmax scores: (batch, heads, query_len, key_len).
ot["attention_score"].shape

torch.Size([8, 12, 30, 30])

In [301]:
class SqueezeBertLayerNorm(nn.LayerNorm):
    """
    LayerNorm over the channel axis of an NCW tensor.

    N = batch, C = channels, W = sequence length. ``nn.LayerNorm`` normalizes the
    trailing dimension, so the input is viewed as NWC, normalized, and returned as NCW.
    """

    def __init__(self, hidden_size, eps=1e-12):
        # Creates self.weight, self.bias and self.eps exactly as nn.LayerNorm does.
        super().__init__(normalized_shape=hidden_size, eps=eps)

    def forward(self, x):
        nwc = x.permute(0, 2, 1)
        return super().forward(nwc).permute(0, 2, 1)

In [302]:
class ConvDropoutLayerNorm(nn.Module):
    """
    ConvDropoutLayerNorm: Conv, Dropout, LayerNorm

    A pointwise (kernel_size=1) grouped Conv1d, dropout, a residual add of
    ``input_tensor``, then LayerNorm over the channel axis. Tensors are NCW.
    """

    def __init__(self, cin, cout, groups, dropout_prob):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
        self.layernorm = SqueezeBertLayerNorm(cout)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.conv1d(hidden_states))
        return self.layernorm(projected + input_tensor)

In [367]:
class ConvDropoutLayerNorm2(nn.Module):
    """
    ConvDropoutLayerNorm2: Conv, Dropout, LayerNorm -- variant using a stock nn.LayerNorm.

    Same computation as ConvDropoutLayerNorm (pointwise grouped Conv1d, dropout,
    residual add, LayerNorm over channels, NCW layout), but the channel LayerNorm is
    a plain ``nn.LayerNorm`` wrapped in explicit permutes.

    Args:
        cin, cout: input / output channels of the 1x1 convolution.
        groups: ``groups`` of the convolution.
        dropout_prob: dropout probability applied to the conv output.
        eps: LayerNorm epsilon. Defaults to 1e-12 to match SqueezeBertLayerNorm's
            default, so both variants normalize identically; pass 1e-5 to get
            ``nn.LayerNorm``'s own default (the previous behaviour of this class).
    """

    def __init__(self, cin, cout, groups, dropout_prob, eps=1e-12):
        super().__init__()

        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
        self.layernorm = nn.LayerNorm(cout, eps=eps)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        x = self.conv1d(hidden_states)
        x = self.dropout(x)
        x = x + input_tensor
        # nn.LayerNorm normalizes the last dim: move channels last, normalize, move back.
        x = x.permute(0, 2, 1)
        x = self.layernorm(x)
        x = x.permute(0, 2, 1)
        return x

In [368]:
# Post-attention block: 1x1 conv (c0 -> c1) + dropout, residual add of hidden_states, channel LayerNorm.
ao = ConvDropoutLayerNorm(c0, c1, config.post_attention_groups, dropout_prob=config.hidden_dropout_prob)

In [369]:
# Same block built on a stock nn.LayerNorm. Its conv weights are initialised
# independently of ao's, so ao and ao2 outputs are not numerically comparable.
ao2 = ConvDropoutLayerNorm2(c0, c1, config.post_attention_groups, dropout_prob=config.hidden_dropout_prob)

In [370]:
# NOTE(review): ConvDropoutLayerNorm2.forward contains no print(), so the two
# torch.Size([8, 30, 768]) lines in the recorded output below look like stale
# output from an earlier version of that class.
attn_output2 = ao2(ot["context_layer"], hidden_states)
attn_output2.shape

torch.Size([8, 30, 768])
torch.Size([8, 30, 768])


torch.Size([8, 768, 30])

In [335]:
attn_output = ao(ot["context_layer"], hidden_states)
attn_output.shape

torch.Size([8, 768, 30])

In [325]:
class ConvActivation(nn.Module):
    """
    ConvActivation: Conv, Activation

    A pointwise (kernel_size=1) grouped Conv1d followed by the activation whose
    name ``act`` is looked up in ACT2FN (e.g. "gelu"). Tensors are NCW.
    """

    def __init__(self, cin, cout, groups, act):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
        # An unknown activation name raises KeyError here.
        self.act = ACT2FN[act]

    def forward(self, x):
        return self.act(self.conv1d(x))

In [345]:
# Groups for the intermediate (feed-forward) conv.
config.intermediate_groups

4

In [333]:
# Feed-forward up-projection: grouped 1x1 conv c1 -> c2 (768 -> 3072) + config.hidden_act ("gelu").
imd = ConvActivation(c1, c2, config.intermediate_groups, config.hidden_act)

In [341]:
fc_out = imd(attn_output)
fc_out.size()

torch.Size([8, 3072, 30])

In [340]:
# Feed-forward down-projection c2 -> c3 (3072 -> 768) with a residual add of attn_output.
fo = ConvDropoutLayerNorm(c2, c3, config.output_groups, dropout_prob=config.hidden_dropout_prob)

In [344]:
fo(fc_out, attn_output).shape

torch.Size([8, 768, 30])

In [689]:
config.intermediate_groups, config.output_groups

(4, 4)

## Grouped Conv1d: PyTorch vs. TensorFlow weight layouts

In [683]:
# PyTorch grouped Conv1d: 768 input channels split into 2 groups, 64 filters, kernel 3.
m = nn.Conv1d(768, 64, 3, stride=1, padding=0, groups=2)
# batch_size, hidden_size, seq_len  (channels-first)
# Named conv_input rather than `input` so the builtin input() is not shadowed in the kernel.
conv_input = torch.randn(4, 768, 512)
# padding=0 with kernel 3 gives 512 - 2 = 510 output positions; the weight layout is
# (out_channels, in_channels // groups, kernel_size) = (64, 384, 3).
m(conv_input).shape, m.weight.shape

(torch.Size([4, 64, 510]), torch.Size([64, 384, 3]))

In [686]:
class CnnTestModel(tf.keras.layers.Layer):
    """Single grouped Conv1D used to inspect TensorFlow's kernel layout.

    Input is channels-last (batch, seq_len, channels). The conv has 64 filters,
    kernel size 3, 'valid' padding, groups=2 and ReLU.
    """

    def __init__(self):
        super().__init__()
        self.n = tf.keras.layers.Conv1D(
            64,
            3,
            strides=1,
            padding='valid',
            groups=2,
            activation='relu',
        )

    # NOTE(review): `experimental_compile` is a deprecated alias of `jit_compile` in newer
    # TensorFlow; presumably XLA is enabled because grouped Conv1D may lack a non-XLA CPU
    # kernel -- confirm before changing.
    @tf.function(experimental_compile=True)
    def call(self, x):
        conv_out = self.n(x)
        # The kernel/bias exist only once the call above has built the conv layer.
        kernel, bias = self.n.weights[0], self.n.weights[1]
        # Python print: runs while tf.function traces `call`, not on every execution.
        print(kernel.shape, bias.shape)
        return conv_out

ctm = CnnTestModel()
# batch_size, seq_len, hidden_size  (Keras Conv1D is channels-last)
x = tf.random.normal((4, 512, 768))
out = ctm(x)
# 'valid' padding with kernel 3: 512 -> 510 positions, 64 filters.
out.shape

(3, 384, 64) (64,)
(3, 384, 64) (64,)


TensorShape([4, 510, 64])