In [1]:
%load_ext autoreload
%autoreload 2 

import torch
import transformers
import numpy as np

from collections import OrderedDict

import sys
sys.path.insert(0, '..')

from bert_model import BertForSequenceClassification, BertConfig
from decompose_bert import BertForSequenceClassificationDecomposed


In [2]:
# Reference implementation: the Hugging Face BERT sequence classifier and its
# matching tokenizer, both loaded from the same checkpoint name.
# Note: the load warning reports that the classifier head is newly initialized
# for this checkpoint (it is not part of the pretrained weights).
checkpoint_name = "bert-base-uncased"
model_hf = transformers.AutoModelForSequenceClassification.from_pretrained(
    checkpoint_name
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_name)

# Switch to evaluation mode; as the last expression this also displays the model.
model_hf.eval()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [3]:
# Build the local BERT implementation from the exact configuration of the
# Hugging Face model: HF config -> plain dict -> local BertConfig.
hf_config_dict = model_hf.config.to_dict()
local_config = BertConfig.from_dict(hf_config_dict)
model = BertForSequenceClassification(local_config)


In [4]:
# Ordered (old, new) substring rewrites applied to every Hugging Face parameter
# name. Order matters: each rule operates on the result of the previous one.
key_rewrites = (
    # Classification-head names; these look like they target checkpoints whose
    # head uses `classifier.dense` / `classifier.out_proj` -- TODO confirm.
    ("classifier.dense", "bert.pooler.dense"),
    ("classifier.out_proj", "classifier"),
    # The local model is loaded with gamma/beta as the LayerNorm parameter names.
    # ref : https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
    ("LayerNorm.weight", "LayerNorm.gamma"),
    ("LayerNorm.bias", "LayerNorm.beta"),
    # Language-model-head names (`lm_head.*`).
    ("lm_head.layer_norm.weight", "cls.predictions.transform.LayerNorm.gamma"),
    ("lm_head.layer_norm.bias", "cls.predictions.transform.LayerNorm.beta"),
    ("lm_head.decoder", "cls.predictions.decoder"),
    ("lm_head.dense", "cls.predictions.transform.dense"),
    # NOTE(review): this rule maps a `.bias` key onto a name without any
    # parameter suffix -- looks unintended; confirm before relying on it.
    ("cls.predictions.decoder.bias", "cls.predictions.transform.dense"),
    ("lm_head.bias", "cls.predictions.bias"),
)

# Copy every tensor of the Hugging Face model under its rewritten name.
new_state_dict = OrderedDict()
for hf_name, tensor in model_hf.state_dict().items():
    local_name = hf_name
    for old_part, new_part in key_rewrites:
        local_name = local_name.replace(old_part, new_part)
    new_state_dict[local_name] = tensor

# Load the renamed weights into the local model, then switch it to evaluation
# mode (the last expression also displays the model).
model.load_state_dict(new_state_dict)
model.eval()


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12)
              (dropout): Dropout(p=0.1,

In [5]:
# Spot check of the weight transfer: the first encoder layer's query projection
# must hold the same values in both models.
hf_query_weight = model_hf.bert.encoder.layer[0].attention.self.query.weight
local_query_weight = model.bert.encoder.layer[0].attention.self.query.weight

torch.allclose(hf_query_weight, local_query_weight, atol=1e-5)


True

In [6]:
# Check that every encoder layer of the local model reproduces the output of
# the corresponding Hugging Face layer on the same random hidden states.
hf_layers = model_hf.bert.encoder.layer
local_layers = model.bert.encoder.layer
# zip() would silently stop at the shorter list, so compare lengths first.
assert len(hf_layers) == len(local_layers), (
    f"layer count mismatch: {len(hf_layers)} (HF) vs {len(local_layers)} (local)"
)

# Read the hidden size from the config instead of hard-coding 768.
hidden_size = model_hf.config.hidden_size
# Dedicated generator: reproducible inputs without reseeding the global RNG.
layer_rng = torch.Generator().manual_seed(0)

with torch.no_grad():
    for i, (f_1, f_2) in enumerate(zip(hf_layers, local_layers)):
        # test equal encoder layers
        x = torch.rand((100, 100, hidden_size), generator=layer_rng)
        # The HF layer's result is indexed with [0] to get the hidden states.
        out_hf = f_1(x)[0]
        # Second positional argument 0: presumably an all-zero (no-op) attention
        # mask for the local layer -- TODO confirm against bert_model.
        out_local = f_2(x, 0)
        assert torch.allclose(out_hf, out_local, atol=1e-5), (
            f"encoder layer {i} mismatch: max abs diff "
            f"{(out_hf - out_local).abs().max().item():.3e}"
        )


In [7]:
# Tokenize a batch of two identical sentences, padded/truncated to a fixed
# length of 128 tokens and returned as PyTorch tensors.
inputs = tokenizer(
    ["Hello, my dog is cute", "Hello, my dog is cute"],
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=128,
)

# Evaluation mode (idempotent) so dropout is disabled for the reference run.
model_hf.eval()

# Reference forward pass; no gradients are needed for a comparison.
with torch.no_grad():
    logits_hf = model_hf(**inputs)

# `logits_hf` is the full model output; display only its logits tensor.
logits_hf.logits


tensor([[ 0.0067, -0.0538],
        [ 0.0067, -0.0538]])

In [8]:
# Run the local BERT implementation on the same tokenized batch, without
# gradient tracking.
with torch.no_grad():
    logits = model(**inputs)

# Displayed so the values can be compared with the Hugging Face logits shown
# by the previous cell.
logits


tensor([[ 0.0067, -0.0538],
        [ 0.0067, -0.0538]])

In [9]:
# Decomposed model, configuration: debug on, bias excluded from the Shapley
# step, generalized variant. Flag semantics live in decompose_bert -- the
# comments here only restate the values passed.
decomposition_options = dict(
    num_labels=2,
    debug=True,
    num_contributions=2,
    shapley_include_bias=False,
    generalized=True,
)
model_decomposed = BertForSequenceClassificationDecomposed.from_pretrained(
    "bert-base-uncased", **decomposition_options)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Derive the mask dimensions from the tokenized batch instead of hard-coding
# the batch size (2) and sequence length (128).
batch_size, seq_len = inputs["input_ids"].shape

# Local seeded generator: the random split is reproducible across runs without
# altering the global RNG state.
mask_rng = torch.Generator().manual_seed(0)
first_group = (torch.rand(seq_len, generator=mask_rng) > 0.5).to(int)

# Two complementary 0/1 rows (one per contribution, matching
# num_contributions=2): each position is 1 in exactly one row.
beta_mask = torch.stack([first_group, 1 - first_group])
# Repeat the same mask for every example in the batch.
beta_mask = torch.stack([beta_mask] * batch_size)
print(beta_mask.shape)

# Forward pass of the decomposed model in evaluation mode, without gradients.
with torch.no_grad():
    model_decomposed.model.eval()
    decomposed_logits = model_decomposed(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        beta_mask=beta_mask)


torch.Size([2, 2, 128])
Bert Encoder Layer 0 error: 6.989978328607237e-16
Bert Encoder Layer 1 error: 1.3262312161643556e-15
Bert Encoder Layer 2 error: 1.469668336832308e-15
Bert Encoder Layer 3 error: 1.987404278810551e-15
Bert Encoder Layer 4 error: 2.269078172395863e-15
Bert Encoder Layer 5 error: 2.488831869733019e-15
Bert Encoder Layer 6 error: 2.6950499651323467e-15
Bert Encoder Layer 7 error: 2.7028122996895306e-15
Bert Encoder Layer 8 error: 2.7256980562439067e-15
Bert Encoder Layer 9 error: 2.8546187576946783e-15
Bert Encoder Layer 10 error: 2.777546592894938e-15
Bert Encoder Layer 11 error: 1.409679690032525e-15

Pooled output error:  6.996171683522143e-16

Bert Classifier logits error:  4.180683577104105e-16


In [11]:
# Decomposed model, configuration: debug on, bias included in the Shapley
# step, generalized variant. Flag semantics live in decompose_bert -- the
# comments here only restate the values passed.
decomposition_options = dict(
    num_labels=2,
    debug=True,
    num_contributions=2,
    shapley_include_bias=True,
    generalized=True,
)
model_decomposed = BertForSequenceClassificationDecomposed.from_pretrained(
    "bert-base-uncased", **decomposition_options)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Derive the mask dimensions from the tokenized batch instead of hard-coding
# the batch size (2) and sequence length (128).
batch_size, seq_len = inputs["input_ids"].shape

# Local seeded generator: the random split is reproducible across runs without
# altering the global RNG state.
mask_rng = torch.Generator().manual_seed(0)
first_group = (torch.rand(seq_len, generator=mask_rng) > 0.5).to(int)

# Two complementary 0/1 rows (one per contribution, matching
# num_contributions=2): each position is 1 in exactly one row.
beta_mask = torch.stack([first_group, 1 - first_group])
# Repeat the same mask for every example in the batch.
beta_mask = torch.stack([beta_mask] * batch_size)
print(beta_mask.shape)

# Forward pass of the decomposed model in evaluation mode, without gradients.
with torch.no_grad():
    model_decomposed.model.eval()
    decomposed_logits = model_decomposed(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        beta_mask=beta_mask)


torch.Size([2, 2, 128])
Bert Encoder Layer 0 error: 8.539088210246059e-16
Bert Encoder Layer 1 error: 1.982586159089037e-15
Bert Encoder Layer 2 error: 3.1153116672346503e-15
Bert Encoder Layer 3 error: 7.47926386656403e-15
Bert Encoder Layer 4 error: 2.695628248642653e-14
Bert Encoder Layer 5 error: 1.8440875813350766e-13
Bert Encoder Layer 6 error: 1.364401917188876e-12
Bert Encoder Layer 7 error: 7.79094108507314e-12
Bert Encoder Layer 8 error: 2.3997366853811223e-11
Bert Encoder Layer 9 error: 6.985405154082852e-11
Bert Encoder Layer 10 error: 1.611596716656634e-10
Bert Encoder Layer 11 error: 1.4537406219001946e-10

Pooled output error:  5.70506828567058e-11

Bert Classifier logits error:  2.187477976534069e-11


In [36]:
# Decomposed model, configuration: debug off, bias excluded from the Shapley
# step, non-generalized variant. Flag semantics live in decompose_bert -- the
# comments here only restate the values passed.
decomposition_options = dict(
    num_labels=2,
    debug=False,
    num_contributions=2,
    shapley_include_bias=False,
    generalized=False,
)
model_decomposed = BertForSequenceClassificationDecomposed.from_pretrained(
    "bert-base-uncased", **decomposition_options)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [42]:
# Derive the mask dimensions from the tokenized batch instead of hard-coding
# the batch size (2) and sequence length (128).
batch_size, seq_len = inputs["input_ids"].shape

# Local seeded generator: the random split is reproducible across runs without
# altering the global RNG state, so the displayed logits can be regenerated.
mask_rng = torch.Generator().manual_seed(0)
first_group = (torch.rand(seq_len, generator=mask_rng) > 0.5).to(int)

# Two complementary 0/1 rows (one per contribution, matching
# num_contributions=2): each position is 1 in exactly one row.
beta_mask = torch.stack([first_group, 1 - first_group])
# Repeat the same mask for every example in the batch.
beta_mask = torch.stack([beta_mask] * batch_size)
print(beta_mask.shape)

# Forward pass of the decomposed model in evaluation mode, without gradients.
with torch.no_grad():
    model_decomposed.model.eval()
    decomposed_logits = model_decomposed(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        beta_mask=beta_mask)


torch.Size([2, 2, 128])


In [44]:
# Display the result of the decomposed forward pass from the previous cell.
decomposed_logits


tensor([[[ 0.3257,  0.1336],
         [ 0.3257,  0.1336]],

        [[-0.4999, -0.4101],
         [-0.4999, -0.4101]],

        [[-0.1630, -0.0822],
         [-0.1630, -0.0822]]], dtype=torch.float64)