In [1]:
import torch.nn as nn
from transformers import BertModel, BertConfig

class InvertedBottleneck(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(InvertedBottleneck, self).__init__()
        self.expand = nn.Linear(input_dim, hidden_dim)
        self.squeeze = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.squeeze(self.expand(x))

def convert_to_inverted_bottleneck_bert(bert_model, bottleneck_ratio=4):
    for layer in bert_model.encoder.layer:
        input_dim = layer.intermediate.dense.in_features
        hidden_dim = input_dim * bottleneck_ratio
        output_dim = layer.intermediate.dense.out_features
        
        layer.intermediate.dense = InvertedBottleneck(input_dim, hidden_dim, output_dim)
        
    return bert_model

bert_config = BertConfig()
bert_model = BertModel(bert_config)

# BERT를 inverted-bottleneck BERT로 변환
inverted_bottleneck_bert = convert_to_inverted_bottleneck_bert(bert_model)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def print_layers(model, indentation=0):
    for name, child in model.named_children():
        print('  ' * indentation + name, type(child).__name__)
        if len(list(child.children())) > 0:
            print_layers(child, indentation + 1)

print_layers(inverted_bottleneck_bert)

embeddings BertEmbeddings
  word_embeddings Embedding
  position_embeddings Embedding
  token_type_embeddings Embedding
  LayerNorm LayerNorm
  dropout Dropout
encoder BertEncoder
  layer ModuleList
    0 BertLayer
      attention BertAttention
        self BertSelfAttention
          query Linear
          key Linear
          value Linear
          dropout Dropout
        output BertSelfOutput
          dense Linear
          LayerNorm LayerNorm
          dropout Dropout
      intermediate BertIntermediate
        dense InvertedBottleneck
          expand Linear
          squeeze Linear
        intermediate_act_fn GELUActivation
      output BertOutput
        dense Linear
        LayerNorm LayerNorm
        dropout Dropout
    1 BertLayer
      attention BertAttention
        self BertSelfAttention
          query Linear
          key Linear
          value Linear
          dropout Dropout
        output BertSelfOutput
          dense Linear
          LayerNorm LayerNorm
          d