In [1]:
import torch.nn as nn
from transformers import BertModel, BertConfig

# 1. Define the inverted-bottleneck module
class InvertedBottleneck(nn.Module):
    """Two stacked linear projections: ``input_dim -> hidden_dim -> output_dim``.

    The first projection (``expand``) maps the input to ``hidden_dim`` features
    and the second (``squeeze``) maps that result to ``output_dim`` features.
    No activation is applied between the two projections.

    Args:
        input_dim: Number of input features of the first projection.
        hidden_dim: Number of features between the two projections.
        output_dim: Number of output features of the second projection.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Submodules are created in the order expand -> squeeze; the attribute
        # names are what appear in the module tree and in state_dict keys.
        self.expand = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.squeeze = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Apply ``expand`` and then ``squeeze`` to ``x`` and return the result."""
        widened = self.expand(x)
        return self.squeeze(widened)

# 2. Convert the given model's feed-forward layers to the inverted-bottleneck structure
def convert_to_inverted_bottleneck_bert(model, bottleneck_ratio=4):
    """Replace each encoder layer's ``intermediate.dense`` with an InvertedBottleneck.

    The model is modified in place and the same object is returned. Every
    replacement module is newly constructed, so its weights are freshly
    initialised; nothing is copied from the layer it replaces.

    Layers whose ``intermediate.dense`` is already an ``InvertedBottleneck``
    are left untouched, so calling this function again on the same model
    (for example when a notebook cell is re-run) is a no-op for those layers,
    even if a different ``bottleneck_ratio`` is passed.

    Args:
        model: Object exposing ``encoder.layer``, an iterable of layers that
            each have an ``intermediate.dense`` module with ``in_features``
            and ``out_features`` attributes.
        bottleneck_ratio: Multiplier applied to the layer's input width to get
            the width between the two projections. A non-integer product is
            truncated to an int.

    Returns:
        The same ``model`` instance, after conversion.
    """
    for layer in model.encoder.layer:
        original = layer.intermediate.dense

        # Already converted: InvertedBottleneck has no in_features/out_features,
        # so reading them below would raise AttributeError.
        if isinstance(original, InvertedBottleneck):
            continue

        input_dim = original.in_features
        # nn.Linear needs integer sizes; int() is a no-op for integer ratios.
        hidden_dim = int(input_dim * bottleneck_ratio)
        output_dim = original.out_features

        replacement = InvertedBottleneck(input_dim, hidden_dim, output_dim)

        # Put the new module on the same device (and floating dtype) as the
        # layer it replaces; otherwise a model that was moved to an accelerator
        # or cast before conversion ends up with mismatched parameters.
        weight = getattr(original, "weight", None)
        if weight is not None:
            replacement.to(device=weight.device)
            if weight.is_floating_point():
                replacement.to(dtype=weight.dtype)

        layer.intermediate.dense = replacement

    return model

# 3. Print the model's layer structure
def print_layers(model, indentation=0):
    """Print the tree of child modules under ``model``, one child per line.

    Each line shows the child's name followed by its class name, prefixed by
    two spaces per nesting level. A child that has children of its own is
    descended into recursively, one level deeper.

    Args:
        model: Object providing ``named_children()``; every child must
            provide ``children()`` (and ``named_children()`` if it is
            descended into).
        indentation: Nesting level used for the lines printed by this call.
    """
    prefix = '  ' * indentation
    for name, child in model.named_children():
        print(prefix + name, type(child).__name__)
        # Recurse only into children that contain at least one submodule.
        if any(True for _ in child.children()):
            print_layers(child, indentation + 1)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained KoBERT checkpoint.
from transformers import BertModel

kobert_model = BertModel.from_pretrained('monologg/kobert')

# Convert KoBERT to the inverted-bottleneck structure.
# The conversion modifies the model it is given and returns that same object,
# so `kobert_model` and `inverted_bottleneck_kobert` refer to one model.
inverted_bottleneck_kobert = convert_to_inverted_bottleneck_bert(model=kobert_model)

# Print the layer structure of the converted model.
print_layers(inverted_bottleneck_kobert)


Downloading model.safetensors: 100%|██████████| 369M/369M [00:31<00:00, 11.7MB/s] 
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


embeddings BertEmbeddings
  word_embeddings Embedding
  position_embeddings Embedding
  token_type_embeddings Embedding
  LayerNorm LayerNorm
  dropout Dropout
encoder BertEncoder
  layer ModuleList
    0 BertLayer
      attention BertAttention
        self BertSelfAttention
          query Linear
          key Linear
          value Linear
          dropout Dropout
        output BertSelfOutput
          dense Linear
          LayerNorm LayerNorm
          dropout Dropout
      intermediate BertIntermediate
        dense InvertedBottleneck
          expand Linear
          squeeze Linear
        intermediate_act_fn GELUActivation
      output BertOutput
        dense Linear
        LayerNorm LayerNorm
        dropout Dropout
    1 BertLayer
      attention BertAttention
        self BertSelfAttention
          query Linear
          key Linear
          value Linear
          dropout Dropout
        output BertSelfOutput
          dense Linear
          LayerNorm LayerNorm
          d