In [1]:
import os
import math
import torch
import torch.nn as nn
from models.modules.encoder import RnnUserEncoder
from models.modules.weighter import CnnWeighter

ModuleNotFoundError: No module named 'thop'

In [2]:
os.path.split("/data/v-pezhang/Code/GateFormer/src/data/ckpts/UserOneTowerGateFormer-weight_Cnn-4/large/185045.model")[-1]

'185045.model'

In [2]:
class CnnWeighter(nn.Module):
    def __init__(self, manager):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv1d(
                in_channels=manager.gate_embedding_dim,
                out_channels=manager.gate_hidden_dim,
                kernel_size=3,
                padding=1
            ),
            nn.ReLU()
        )
        self.weightPooler = nn.Sequential(
            nn.Linear(manager.gate_hidden_dim, manager.gate_hidden_dim),
            nn.ReLU(),
            nn.Dropout(manager.dropout_p),
            nn.Linear(manager.gate_hidden_dim, 1)
        )

        nn.init.xavier_normal_(self.cnn[0].weight)


    def _compute_weight(self, embeddings):
        weights = self.weightPooler(embeddings).squeeze(-1)
        return weights


    def forward(self, embedding):
        """
        Args:
            token_id: [B, L]
            attn_mask: [B, L]

        Returns:
            weights: [B, L]
        """
        original_shape = embedding.shape[:-1]
        cnn_input = embedding.transpose(-1, -2)
        conv_embedding = self.cnn(cnn_input).transpose(-1, -2).view(*original_shape, -1)
        weight = self._compute_weight(conv_embedding)
        return weight

In [3]:
class m:
    hidden_dim = 768
    gate_embedding_dim = 768
    gate_hidden_dim = 768
    dropout_p = 0.1
    vocab_size = 30522
manager = m()
model = CnnWeighter(manager)
x = torch.rand(1, 32, manager.gate_embedding_dim)

# FLOPs of original BERT
macs, params = profile(model, inputs=(x,))
macs * 2

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[91m[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[00m
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[91m[WARN] Cannot find rule for <class '__main__.CnnWeighter'>. Treat it as zero Macs and zero Params.[00m


151093248.0

In [10]:
# FLOPs of BaseBert
L = 32

embedding = 2 * L * 768 * 3
bert_project = 12 * L * 64 * 64* 2 * 3
bert_attn = 12 * L * 64 * L * 2 + 12 * L * L * 64 * 2
bert_intm =  L * 768 * 768 * 2 +  L * 768 * 2 +  L * 768 * 3072 * 4

bert = (bert_project + bert_attn + bert_intm) * 12

all = embedding + bert
all * 12 / 1e9

50.743148544

In [3]:
# FLOPs of Synthesizer
L = 32

embedding = 2 * L * 768 * 3
bert_project = 12 * L * 64 * 64 * 2
synthesizer_attn = 12 * (L * 64 * L * 2 + L * L * L * 2) + 12 * L * L * 64 * 2
bert_intm =  L * 768 * 768 * 2 +  L * 768 * 2 +  L * 768 * 3072 * 4

bert = (bert_project + synthesizer_attn + bert_intm) * 12

all = embedding + bert
all

4162535424

In [9]:
# FLOPs of baseline pruners
K = 8

embedding = 2 * K * 768 * 3
bert_project = 12 * K * 64 * 2 * 3 * 64
bert_attn = 12 * K * 64 * K * 2 + 12 * K * K * 64 * 2
bert_intm =  K * 768 * 768 * 2 +  K * 768 * 2 +  K * 768 * 3072 * 4

bert = (bert_project + bert_attn + bert_intm) * 12

pruner = embedding + bert
pruner * 12 / 1e9

12.60085248

In [10]:
import math
# FLOPs of InfoGate
L = 32
K = 8
D = 300

gate_embedding = 2 * L * D
# gate_weighting = 151093248
gate_weighting = L * 300 * 300 * 3

gate_sort = L * math.log2(L)
gate_all = gate_embedding + gate_weighting + gate_sort

embedding = 2 * K * 768 * 3
bert_project = 12 * K * 64 * 2 * 3 * 64
bert_attn = 12 * K * 64 * K * 2 + 12 * K * K * 64 * 2
bert_intm =  K * 768 * 768 * 2 +  K * 768 * 2 +  K * 768 * 3072 * 4

bert = embedding + (bert_project + bert_attn + bert_intm) * 12

infogate = bert + gate_all
infogate * 12 / 1e9, gate_all

(12.7047648, 8659360.0)

In [12]:
# FLOPs of InfoGate(Trans)
L = 32
K = 8
D = 300

gate_embedding = 2 * L * D
gate_weighting = 12 * L * 64 * 64 * 2 * 3 + 12 * L * 64 * L * 2 + 12 * L * L * 64 * 2 + L * 768 * 768 * 2 +  L * 768 * 2 +  L * 768 * 3072 * 4
gate_sort = L * math.log2(L)
gate_all = gate_embedding + gate_weighting + gate_sort

embedding = 2 * K * 768 * 3
bert_project = 12 * K * 64 * 2 * 3 * 64
bert_attn = 12 * K * 64 * K * 2 + 12 * K * K * 64 * 2
bert_intm =  K * 768 * 768 * 2 +  K * 768 * 2 +  K * 768 * 3072 * 4

bert = (bert_project + bert_attn + bert_intm) * 12

infogate = embedding + bert + gate_all
infogate * 12 / 1e9, gate_all

(16.829533056, 352390048.0)

In [2]:
import math
# FLOPs of KeyBERT
L = 32
K = 8
D = 300

embedding = 2 * L * 768 * 3
bert_project = 12 * L * 64 * 64* 2 * 3
bert_attn = 12 * L * 64 * L * 2 + 12 * L * L * 64 * 2
bert_intm =  L * 768 * 768 * 2 +  L * 768 * 2 +  L * 768 * 3072 * 4
gate_sort = L * math.log2(L)
gate_all = embedding + (bert_project + bert_attn + bert_intm) * 6 + gate_sort

embedding = 2 * K * 768 * 3
bert_project = 12 * K * 64 * 2 * 3 * 64
bert_attn = 12 * K * 64 * K * 2 + 12 * K * K * 64 * 2
bert_intm =  K * 768 * 768 * 2 +  K * 768 * 2 +  K * 768 * 3072 * 4

bert = embedding + (bert_project + bert_attn + bert_intm) * 12

infogate = bert + gate_all
infogate * 12 / 1e9, gate_all

(37.973313408, 2114371744.0)

In [9]:
pruner / infogate, all / infogate, all / pruner

(0.7716033974308345, 6.222750033481267, 8.064700148030473)

In [14]:
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(30522, 768)
    def forward(self, x):
        return self.embedding(x)


m = M()
x = torch.rand(1, 32, 768)
macs, params = profile(m, inputs=(x,))

[91m[WARN] Cannot find rule for <class 'torch.nn.modules.sparse.Embedding'>. Treat it as zero Macs and zero Params.[00m
[91m[WARN] Cannot find rule for <class '__main__.M'>. Treat it as zero Macs and zero Params.[00m


RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

In [None]:
mac