In [1]:
import torch

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
import transformers

In [4]:
transformers.__file__

'/home/wk247/workspace/stitching/src/transformers/__init__.py'

### Tokenizer

In [5]:
from transformers import BertTokenizer

# vocabs are identical for small and large
tokenizer = BertTokenizer('../vocab.txt')

### Models

In [6]:
# from transformers
from transformers import AutoModel

small_model = AutoModel.from_pretrained("prajjwal1/bert-mini").to(device)
large_model = AutoModel.from_pretrained("prajjwal1/bert-small").to(device)

Some weights of the model checkpoint at prajjwal1/bert-mini were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.predictions

In [7]:
# define small model 1, 2
import copy

small_model1 = small_model
small_model2 = copy.deepcopy(small_model)

In [8]:
# change configs to return intermediate outputs
# https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForPreTraining.forward
small_model.config.output_hidden_states=True
small_model.config.output_attentions=True
large_model.config.output_hidden_states=True
large_model.config.output_attentions=True

In [9]:
# model configs
print("small model")
print(f'- num_parameters: {small_model.num_parameters()}')
print(f"- hidden_size: {small_model.config.hidden_size}")
print(f"- num_attention_heads: {small_model.config.num_attention_heads}")
print(f"- num_hidden_layers: {small_model.config.num_hidden_layers}")
print()

print("large model")
print(f'- num_parameters: {large_model.num_parameters()}')
print(f"- hidden_size: {large_model.config.hidden_size}")
print(f"- num_attention_heads: {large_model.config.num_attention_heads}")
print(f"- num_hidden_layers: {large_model.config.num_hidden_layers}")

small model
- num_parameters: 11170560
- hidden_size: 256
- num_attention_heads: 4
- num_hidden_layers: 4

large model
- num_parameters: 28763648
- hidden_size: 512
- num_attention_heads: 8
- num_hidden_layers: 4


### Stitched config/model

In [10]:
from transformers import StitchedBertConfig

stitched_config = StitchedBertConfig(**small_model.config.to_dict())
stitched_config

StitchedBertConfig {
  "_name_or_path": "prajjwal1/bert-mini",
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 1024,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "stitched_bert",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "output_attentions": true,
  "output_hidden_states": true,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "stitch_hidden_size": 512,
  "stitch_intermediate_size": 2048,
  "transformers_version": "4.17.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [11]:
from transformers import StitchedBertModel

stitched_model = StitchedBertModel(stitched_config).to(device)

#### 1. forward

In [12]:
# forward
text = "Example input"
encoded_input = tokenizer(text, return_tensors='pt').to(device)
# small_output = small_model(**encoded_input, return_dict=True)

In [13]:
output = stitched_model(**encoded_input, return_dict=True)

In [14]:
# output shapes
print("last_hidden_state:", output["last_hidden_state"].shape)
print("pooler_output:", output["pooler_output"].shape)

last_hidden_state: torch.Size([1, 4, 512])
pooler_output: torch.Size([1, 512])


In [15]:
# includes the last hidden states
assert torch.isclose(output.hidden_states[-1], output.last_hidden_state).all()

len(output.hidden_states)

5

In [16]:
# Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
len(output.attentions)

4

In [17]:
# attentions
for attn in output.attentions:
    print(attn.shape)

torch.Size([2, 4, 4, 4])
torch.Size([2, 4, 4, 4])
torch.Size([2, 4, 4, 4])
torch.Size([2, 4, 4, 4])


#### 2. copy modules

In [18]:
import copy

In [19]:
# copy embeddings
# NOTE: check if deepcopy is okay
stitched_model.embeddings1 = copy.deepcopy(small_model1.embeddings)
stitched_model.embeddings2 = copy.deepcopy(small_model2.embeddings)

In [20]:
assert torch.isclose(small_model1.embeddings.word_embeddings.weight, stitched_model.embeddings1.word_embeddings.weight).all()
assert torch.isclose(small_model2.embeddings.word_embeddings.weight, stitched_model.embeddings2.word_embeddings.weight).all()

In [21]:
l1 = torch.nn.Linear(10, 100, bias=True)
l2 = torch.nn.Linear(10, 100, bias=True)
if all((l1.bias !=  None, l2.bias != None)) :
    print("bias")

bias


In [22]:
all((1, 1, 1))

True

In [23]:
def copy_linear(tgt, src1, src2):
    """
    all: torch.nn.Linear
    """
    # check if bias exists
    assert None not in (tgt.bias, src1.bias, src2.bias) or not any((tgt.bias, src1.bias, src2.bias))
    
    tgt_out_dim, tgt_in_dim = tgt.weight.size()
    src1_out_dim, src1_in_dim = src1.weight.size()
    src2_out_dim, src2_in_dim = src2.weight.size()

    assert tgt_out_dim == src1_out_dim + src2_out_dim
    assert tgt_in_dim == src1_in_dim + src2_in_dim

    # NOTE: check indexing
    tgt.weight.detach()[:src1_out_dim, :src1_in_dim] = src1.weight.detach()
    tgt.weight.detach()[-src2_out_dim:, -src2_in_dim:] = src2.weight.detach()
    
    if tgt.bias is not None:
        # copy bias
        tgt.bias.detach()[:src1_out_dim] = src1.bias.detach()
        tgt.bias.detach()[-src2_out_dim:] = src2.bias.detach()

In [24]:
def copy_layernorm(tgt, src1, src2):
    """
    all: torch.nn.modules.normalization.LayerNorm
    """
    tgt_dim, src1_dim, src2_dim = tgt.weight.size(0), src1.weight.size(0), src2.weight.size(0)
    assert tgt_dim == src1_dim + src2_dim

    # # NOTE: check indexing
    # copy weights
    tgt.weight.detach()[:src1_dim] = src1.weight.detach()
    tgt.weight.detach()[-src2_dim:] = src2.weight.detach()

In [25]:
# copy within layers
for layer_st, layer_1, layer_2 in zip(stitched_model.encoder.layer, small_model1.encoder.layer, small_model2.encoder.layer):
    assert type(layer_st.attention1) == type(layer_1.attention)
    assert type(layer_st.attention2) == type(layer_2.attention)

    # copy attention modules
    layer_st.attention1 = copy.deepcopy(layer_1.attention)
    layer_st.attention2 = copy.deepcopy(layer_2.attention)

    # copy intermediate ffn
    copy_linear(layer_st.intermediate.dense, layer_1.intermediate.dense, layer_2.intermediate.dense)

    # copy output ffn
    copy_linear(layer_st.output.dense, layer_1.output.dense, layer_2.output.dense)
    copy_layernorm(layer_st.output.LayerNorm, layer_1.output.LayerNorm, layer_2.output.LayerNorm)

In [26]:
# copy pooler
copy_linear(stitched_model.pooler.dense, small_model1.pooler.dense, small_model2.pooler.dense)


# TODO: copy sanity check, incorporate to bertmodel

In [26]:
# nonlinear activation is inside BertIntermediate
large_model.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=512, out_features=512, bias=True)
      (key): Linear(in_features=512, out_features=512, bias=True)
      (value): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=512, out_features=512, bias=True)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=512, out_features=2048, bias=True)
  )
  (output): BertOutput(
    (dense): Linear(in_features=2048, out_features=512, bias=True)
    (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [None]:
large_model.generate()

In [44]:
tokenizer.decode()

TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'